Skip to content

Commit b720b19

Browse files
committed
Formatting, typing, and documentation fixes. Also, the implicit seed dependency was pulled out of the utils and made explicit in airl.py.
1 parent e459cd7 commit b720b19

16 files changed

+305
-103
lines changed

src/imitation_cli/airl.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from imitation_cli.algorithm_configurations import airl as airl_cfg
1515
from imitation_cli.utils import environment as environment_cfg
1616
from imitation_cli.utils import (
17-
policy,
1817
policy_evaluation,
18+
randomness,
1919
reward_network,
2020
rl_algorithm,
2121
trajectories,
@@ -26,7 +26,7 @@
2626
class RunConfig:
2727
"""Config for running AIRL."""
2828

29-
seed: int = 0
29+
rng: randomness.Config = randomness.Config(seed=0)
3030
total_timesteps: int = int(1e6)
3131
checkpoint_interval: int = 0
3232

@@ -39,11 +39,11 @@ class RunConfig:
3939

4040

4141
cs = ConfigStore.instance()
42-
environment_cfg.register_configs("environment")
43-
trajectories.register_configs("airl/demonstrations", "${environment}")
44-
rl_algorithm.register_configs("airl/gen_algo", "${environment}")
42+
environment_cfg.register_configs("environment", "${rng}")
43+
trajectories.register_configs("airl/demonstrations", "${environment}", "${rng}")
44+
rl_algorithm.register_configs("airl/gen_algo", "${environment}", "${rng.seed}")
4545
reward_network.register_configs("airl/reward_net", "${environment}")
46-
policy_evaluation.register_configs("evaluation", "${environment}")
46+
policy_evaluation.register_configs("evaluation", "${environment}", "${rng}")
4747

4848
cs.store(
4949
name="airl_run_base",

src/imitation_cli/config/airl_optuna.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ hydra:
2222
sweeper:
2323
params:
2424
environment: cartpole,pendulum
25-
airl/reward_net: basic,shaped,ensemble
25+
airl/reward_net: basic,shaped,small_ensemble

src/imitation_cli/config/airl_run.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ checkpoint_interval: 1
1515
airl:
1616
demo_batch_size: 128
1717
demonstrations:
18-
total_timesteps: 10
18+
total_timesteps: 10

src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ hydra:
2121
sweeper:
2222
params:
2323
environment: cartpole,pendulum
24-
airl/reward_net: basic,shaped,ensemble
24+
airl/reward_net: basic,shaped,small_ensemble
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Configurations to be used as ingredients for algorithm configurations."""

src/imitation_cli/utils/activation_function_class.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
1+
"""Classes for configuring activation functions."""
12
import dataclasses
23

34
from hydra.core.config_store import ConfigStore
45

56

67
@dataclasses.dataclass
78
class Config:
9+
"""Base class for activation function configs."""
10+
811
# Note: we don't define _target_ here so in the subclasses it can be defined last.
912
# This is the same pattern as we use in schedule.py.
1013
pass
1114

1215

1316
@dataclasses.dataclass
1417
class TanH(Config):
18+
"""Config for TanH activation function."""
19+
1520
_target_: str = "imitation_cli.utils.activation_function_class.TanH.make"
1621

1722
@staticmethod
@@ -23,6 +28,8 @@ def make() -> type:
2328

2429
@dataclasses.dataclass
2530
class ReLU(Config):
31+
"""Config for ReLU activation function."""
32+
2633
_target_: str = "imitation_cli.utils.activation_function_class.ReLU.make"
2734

2835
@staticmethod
@@ -34,10 +41,12 @@ def make() -> type:
3441

3542
@dataclasses.dataclass
3643
class LeakyReLU(Config):
44+
"""Config for LeakyReLU activation function."""
45+
3746
_target_: str = "imitation_cli.utils.activation_function_class.LeakyReLU.make"
3847

3948
@staticmethod
40-
def make() -> type:
49+
def make() -> type:
4150
import torch
4251

4352
return torch.nn.LeakyReLU
Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
"""Configuration for Gym environments."""
12
from __future__ import annotations
3+
24
import dataclasses
35
import typing
4-
from typing import Optional
6+
from typing import Optional, Union, cast
57

68
if typing.TYPE_CHECKING:
79
from stable_baselines3.common.vec_env import VecEnv
@@ -15,18 +17,21 @@
1517

1618
@dataclasses.dataclass
1719
class Config:
20+
"""Configuration for Gym environments."""
21+
1822
_target_: str = "imitation_cli.utils.environment.Config.make"
1923
env_name: str = MISSING # The environment to train on
2024
n_envs: int = 8 # number of environments in VecEnv
21-
parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv TODO: when setting this to true this is really slow for some reason
25+
# TODO: setting `parallel` to True is really slow for some reason
26+
parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv
2227
max_episode_steps: int = MISSING # Set to positive int to limit episode horizons
2328
env_make_kwargs: dict = dataclasses.field(
24-
default_factory=dict
29+
default_factory=dict,
2530
) # The kwargs passed to `spec.make`.
26-
rng: randomness.Config = randomness.Config()
31+
rng: randomness.Config = MISSING
2732

2833
@staticmethod
29-
def make(log_dir: Optional[str]=None, **kwargs) -> VecEnv:
34+
def make(log_dir: Optional[str] = None, **kwargs) -> VecEnv:
3035
from imitation.util import util
3136

3237
return util.make_vec_env(log_dir=log_dir, **kwargs)
@@ -38,13 +43,24 @@ def make_rollout_venv(environment_config: Config) -> VecEnv:
3843
return call(
3944
environment_config,
4045
log_dir=None,
41-
post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)]
46+
post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)],
4247
)
4348

4449

45-
def register_configs(group: str):
50+
def register_configs(
51+
group: str,
52+
default_rng: Union[randomness.Config, str] = MISSING,
53+
):
54+
default_rng = cast(randomness.Config, default_rng)
4655
cs = ConfigStore.instance()
47-
cs.store(group=group, name="gym_env", node=Config)
48-
cs.store(group=group, name="cartpole", node=Config(env_name="CartPole-v0", max_episode_steps=500))
49-
cs.store(group=group, name="pendulum", node=Config(env_name="Pendulum-v1", max_episode_steps=500))
50-
56+
cs.store(group=group, name="gym_env", node=Config(rng=default_rng))
57+
cs.store(
58+
group=group,
59+
name="cartpole",
60+
node=Config(env_name="CartPole-v0", max_episode_steps=500, rng=default_rng),
61+
)
62+
cs.store(
63+
group=group,
64+
name="pendulum",
65+
node=Config(env_name="Pendulum-v1", max_episode_steps=500, rng=default_rng),
66+
)

src/imitation_cli/utils/feature_extractor_class.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Register Hydra configs for stable_baselines3 feature extractors."""
12
import dataclasses
23

34
from hydra.core.config_store import ConfigStore
@@ -6,12 +7,18 @@
67

78
@dataclasses.dataclass
89
class Config:
10+
"""Base config for stable_baselines3 feature extractors."""
11+
912
_target_: str = MISSING
1013

1114

1215
@dataclasses.dataclass
1316
class FlattenExtractorConfig(Config):
14-
_target_: str = "imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make"
17+
"""Config for FlattenExtractor."""
18+
19+
_target_: str = (
20+
"imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make"
21+
)
1522

1623
@staticmethod
1724
def make() -> type:
@@ -22,6 +29,8 @@ def make() -> type:
2229

2330
@dataclasses.dataclass
2431
class NatureCNNConfig(Config):
32+
"""Config for NatureCNN."""
33+
2534
_target_: str = "imitation_cli.utils.feature_extractor_class.NatureCNNConfig.make"
2635

2736
@staticmethod

src/imitation_cli/utils/optimizer_class.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Register optimizer classes with Hydra."""
12
import dataclasses
23

34
from hydra.core.config_store import ConfigStore
@@ -6,11 +7,15 @@
67

78
@dataclasses.dataclass
89
class Config:
10+
"""Base config for optimizer classes."""
11+
912
_target_: str = MISSING
1013

1114

1215
@dataclasses.dataclass
1316
class Adam(Config):
17+
"""Config for Adam optimizer class."""
18+
1419
_target_: str = "imitation_cli.utils.optimizer_class.Adam.make"
1520

1621
@staticmethod
@@ -22,6 +27,8 @@ def make() -> type:
2227

2328
@dataclasses.dataclass
2429
class SGD(Config):
30+
"""Config for SGD optimizer class."""
31+
2532
_target_: str = "imitation_cli.utils.optimizer_class.SGD.make"
2633

2734
@staticmethod

0 commit comments

Comments
 (0)