Skip to content

Commit ccf4709

Browse files
committed
Swtich from Hydra call to hydra instantiate.
1 parent b720b19 commit ccf4709

File tree

6 files changed

+20
-20
lines changed

6 files changed

+20
-20
lines changed

src/imitation_cli/airl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import hydra
88
import torch as th
99
from hydra.core.config_store import ConfigStore
10-
from hydra.utils import call
10+
from hydra.utils import instantiate
1111
from omegaconf import MISSING
1212

1313
from imitation.policies import serialize
@@ -65,7 +65,7 @@ def run_airl(cfg: RunConfig) -> Dict[str, Any]:
6565
from imitation.data import rollout
6666
from imitation.data.types import TrajectoryWithRew
6767

68-
trainer: airl.AIRL = call(cfg.airl)
68+
trainer: airl.AIRL = instantiate(cfg.airl)
6969

7070
checkpoints_path = pathlib.Path("checkpoints")
7171

src/imitation_cli/utils/environment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from stable_baselines3.common.vec_env import VecEnv
1010

1111
from hydra.core.config_store import ConfigStore
12-
from hydra.utils import call
12+
from hydra.utils import instantiate
1313
from omegaconf import MISSING
1414

1515
from imitation_cli.utils import randomness
@@ -40,7 +40,7 @@ def make(log_dir: Optional[str] = None, **kwargs) -> VecEnv:
4040
def make_rollout_venv(environment_config: Config) -> VecEnv:
4141
from imitation.data import wrappers
4242

43-
return call(
43+
return instantiate(
4444
environment_config,
4545
log_dir=None,
4646
post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)],

src/imitation_cli/utils/policy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from stable_baselines3.common.policies import BasePolicy
1212

1313
from hydra.core.config_store import ConfigStore
14-
from hydra.utils import call
14+
from hydra.utils import instantiate
1515
from omegaconf import MISSING
1616

1717
from imitation_cli.utils import activation_function_class as act_fun_class_cfg
@@ -91,9 +91,9 @@ def make_args(
9191
del kwargs["_target_"]
9292
del kwargs["environment"]
9393

94-
kwargs["activation_fn"] = call(activation_fn)
95-
kwargs["features_extractor_class"] = call(features_extractor_class)
96-
kwargs["optimizer_class"] = call(optimizer_class)
94+
kwargs["activation_fn"] = instantiate(activation_fn)
95+
kwargs["features_extractor_class"] = instantiate(features_extractor_class)
96+
kwargs["optimizer_class"] = instantiate(optimizer_class)
9797

9898
return dict(
9999
**kwargs,
@@ -182,7 +182,7 @@ def make(
182182
model = serialize.load_stable_baselines_model(
183183
Loaded.type_to_class(policy_type),
184184
filename,
185-
call(environment),
185+
instantiate(environment),
186186
)
187187
return model.policy
188188

src/imitation_cli/utils/reward_network.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from imitation.rewards.reward_nets import RewardNet
1111

1212
from hydra.core.config_store import ConfigStore
13-
from hydra.utils import call
13+
from hydra.utils import instantiate
1414
from omegaconf import MISSING
1515

1616
import imitation_cli.utils.environment as environment_cfg
@@ -99,11 +99,11 @@ def make(
9999
) -> RewardNet:
100100
from imitation.rewards import reward_nets
101101

102-
venv = call(environment)
102+
venv = instantiate(environment)
103103
reward_net = reward_nets.RewardEnsemble(
104104
venv.observation_space,
105105
venv.action_space,
106-
[call(ensemble_member_config) for _ in range(ensemble_size)],
106+
[instantiate(ensemble_member_config) for _ in range(ensemble_size)],
107107
)
108108
if add_std_alpha is not None:
109109
return reward_nets.AddSTDRewardWrapper(

src/imitation_cli/utils/rl_algorithm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import stable_baselines3 as sb3
1111
from stable_baselines3.common.vec_env import VecEnv
1212

13-
from hydra.utils import call, to_absolute_path
13+
from hydra.utils import instantiate, to_absolute_path
1414
from omegaconf import MISSING
1515

1616
from imitation_cli.utils import environment as environment_cfg
@@ -73,9 +73,9 @@ def make(
7373
return sb3.PPO(
7474
policy=sb3.common.policies.ActorCriticPolicy,
7575
policy_kwargs=policy_kwargs,
76-
env=call(environment),
77-
learning_rate=call(learning_rate),
78-
clip_range=call(clip_range),
76+
env=instantiate(environment),
77+
learning_rate=instantiate(learning_rate),
78+
clip_range=instantiate(clip_range),
7979
**kwargs,
8080
)
8181

src/imitation_cli/utils/trajectories.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313

1414
from hydra.core.config_store import ConfigStore
15-
from hydra.utils import call
15+
from hydra.utils import instantiate
1616
from omegaconf import MISSING
1717

1818
from imitation_cli.utils import environment as environment_cfg
@@ -60,9 +60,9 @@ def make(
6060
) -> Sequence[Trajectory]:
6161
from imitation.data import rollout
6262

63-
expert = call(expert_policy)
64-
env = call(expert_policy.environment)
65-
rng = call(rng)
63+
expert = instantiate(expert_policy)
64+
env = instantiate(expert_policy.environment)
65+
rng = instantiate(rng)
6666
return rollout.generate_trajectories(
6767
expert,
6868
env,

0 commit comments

Comments
 (0)