
Commit ccbc8b0 (1 parent: f8892e7)

Simplify the class configurations using enums.

5 files changed: +54, -96 lines

src/imitation_cli/algorithm_configurations/airl.py (1 addition, 1 deletion)

@@ -24,7 +24,7 @@ class Config:
     reward_net: reward_network.Config = MISSING
     demo_batch_size: int = 64
     n_disc_updates_per_round: int = 2
-    disc_opt_cls: optimizer_class.Config = optimizer_class.Adam()
+    disc_opt_cls: optimizer_class.Config = optimizer_class.Adam
     gen_train_timesteps: Optional[int] = None
     gen_replay_buffer_capacity: Optional[int] = None
     init_tensorboard: bool = False
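
Note on this change: `optimizer_class.Adam` is no longer a dataclass subclass that must be instantiated, but a ready-made `Config` instance defined at module level (see optimizer_class.py below), so the default references it directly:

    # Before: Adam was a dataclass subclass, so the default was a fresh instance.
    disc_opt_cls: optimizer_class.Config = optimizer_class.Adam()
    # After: Adam is a module-level Config instance, referenced without a call.
    disc_opt_cls: optimizer_class.Config = optimizer_class.Adam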
src/imitation_cli/utils/activation_function_class.py (18 additions, 40 deletions)

@@ -1,59 +1,37 @@
 """Classes for configuring activation functions."""
 import dataclasses
+from enum import Enum

+import torch
 from hydra.core.config_store import ConfigStore


-@dataclasses.dataclass
-class Config:
-    """Base class for activation function configs."""
-
-    # Note: we don't define _target_ here so in the subclasses it can be defined last.
-    # This is the same pattern we use as in schedule.py.
-    pass
-
-
-@dataclasses.dataclass
-class TanH(Config):
-    """Config for TanH activation function."""
-
-    _target_: str = "imitation_cli.utils.activation_function_class.TanH.make"
-
-    @staticmethod
-    def make() -> type:
-        import torch
+class ActivationFunctionClass(Enum):
+    """Enum of activation function classes."""

-        return torch.nn.Tanh
+    TanH = torch.nn.Tanh
+    ReLU = torch.nn.ReLU
+    LeakyReLU = torch.nn.LeakyReLU


 @dataclasses.dataclass
-class ReLU(Config):
-    """Config for ReLU activation function."""
+class Config:
+    """Base class for activation function configs."""

-    _target_: str = "imitation_cli.utils.activation_function_class.ReLU.make"
+    activation_function_class: ActivationFunctionClass
+    _target_: str = "imitation_cli.utils.activation_function_class.Config.make"

     @staticmethod
-    def make() -> type:
-        import torch
-
-        return torch.nn.ReLU
-
+    def make(activation_function_class: ActivationFunctionClass) -> type:
+        return activation_function_class.value

-@dataclasses.dataclass
-class LeakyReLU(Config):
-    """Config for LeakyReLU activation function."""
-
-    _target_: str = "imitation_cli.utils.activation_function_class.LeakyReLU.make"
-
-    @staticmethod
-    def make() -> type:
-        import torch

-        return torch.nn.LeakyReLU
+TanH = Config(ActivationFunctionClass.TanH)
+ReLU = Config(ActivationFunctionClass.ReLU)
+LeakyReLU = Config(ActivationFunctionClass.LeakyReLU)


 def register_configs(group: str):
     cs = ConfigStore.instance()
-    cs.store(group=group, name="tanh", node=TanH)
-    cs.store(group=group, name="relu", node=ReLU)
-    cs.store(group=group, name="leaky_relu", node=LeakyReLU)
+    for cls in ActivationFunctionClass:
+        cs.store(group=group, name=cls.name.lower(), node=Config(cls))
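
With the enum layout, Hydra's standard `instantiate` resolves a config to the underlying activation class by locating `Config.make` through `_target_`. A minimal sketch of the round trip, assuming the package and Hydra are importable (the `instantiate` call is standard Hydra usage, not code from this commit):

    import torch
    from hydra.utils import instantiate

    from imitation_cli.utils import activation_function_class as afc

    # Hydra reads _target_ and calls Config.make(activation_function_class=...).
    act_cls = instantiate(afc.TanH)
    assert act_cls is torch.nn.Tanh  # the config resolves to the class itself
    layer = act_cls()                # callers build the module as before

One behavioral note: `register_configs` now derives store names from the enum member names, so the previously registered "leaky_relu" becomes "leakyrelu".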
src/imitation_cli/utils/feature_extractor_class.py (16 additions, 27 deletions)

@@ -1,46 +1,35 @@
 """Register Hydra configs for stable_baselines3 feature extractors."""
 import dataclasses
+from enum import Enum

+import stable_baselines3.common.torch_layers as torch_layers
 from hydra.core.config_store import ConfigStore
-from omegaconf import MISSING


-@dataclasses.dataclass
-class Config:
-    """Base config for stable_baselines3 feature extractors."""
+class FeatureExtractorClass(Enum):
+    """Enum of feature extractor classes."""

-    _target_: str = MISSING
+    FlattenExtractor = torch_layers.FlattenExtractor
+    NatureCNN = torch_layers.NatureCNN


 @dataclasses.dataclass
-class FlattenExtractorConfig(Config):
-    """Config for FlattenExtractor."""
+class Config:
+    """Base config for stable_baselines3 feature extractors."""

-    _target_: str = (
-        "imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make"
-    )
+    feature_extractor_class: FeatureExtractorClass
+    _target_: str = "imitation_cli.utils.feature_extractor_class.Config.make"

     @staticmethod
-    def make() -> type:
-        import stable_baselines3
-
-        return stable_baselines3.common.torch_layers.FlattenExtractor
+    def make(feature_extractor_class: FeatureExtractorClass) -> type:
+        return feature_extractor_class.value


-@dataclasses.dataclass
-class NatureCNNConfig(Config):
-    """Config for NatureCNN."""
-
-    _target_: str = "imitation_cli.utils.feature_extractor_class.NatureCNNConfig.make"
-
-    @staticmethod
-    def make() -> type:
-        import stable_baselines3
-
-        return stable_baselines3.common.torch_layers.NatureCNN
+FlattenExtractor = Config(FeatureExtractorClass.FlattenExtractor)
+NatureCNN = Config(FeatureExtractorClass.NatureCNN)


 def register_configs(group: str):
     cs = ConfigStore.instance()
-    cs.store(group=group, name="flatten", node=FlattenExtractorConfig)
-    cs.store(group=group, name="nature_cnn", node=NatureCNNConfig)
+    for cls in FeatureExtractorClass:
+        cs.store(group=group, name=cls.name.lower(), node=Config(cls))
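
The same name derivation applies here, and it changes the registered names: "flatten" becomes "flattenextractor" and "nature_cnn" becomes "naturecnn". A minimal sketch of what the loop registers, with an illustrative group name:

    from imitation_cli.utils import feature_extractor_class as fec

    # The group string below is illustrative, not taken from this commit.
    fec.register_configs(group="policy/features_extractor_class")
    # Equivalent to two explicit stores:
    #   cs.store(group=..., name="flattenextractor",
    #            node=fec.Config(fec.FeatureExtractorClass.FlattenExtractor))
    #   cs.store(group=..., name="naturecnn",
    #            node=fec.Config(fec.FeatureExtractorClass.NatureCNN))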
src/imitation_cli/utils/optimizer_class.py (16 additions, 25 deletions)

@@ -1,44 +1,35 @@
 """Register optimizer classes with Hydra."""
 import dataclasses
+from enum import Enum

+import torch
 from hydra.core.config_store import ConfigStore
-from omegaconf import MISSING


-@dataclasses.dataclass
-class Config:
-    """Base config for optimizer classes."""
+class OptimizerClass(Enum):
+    """Enum of optimizer classes."""

-    _target_: str = MISSING
+    Adam = torch.optim.Adam
+    SGD = torch.optim.SGD


 @dataclasses.dataclass
-class Adam(Config):
-    """Config for Adam optimizer class."""
+class Config:
+    """Base config for optimizer classes."""

-    _target_: str = "imitation_cli.utils.optimizer_class.Adam.make"
+    optimizer_class: OptimizerClass
+    _target_: str = "imitation_cli.utils.optimizer_class.Config.make"

     @staticmethod
-    def make() -> type:
-        import torch
-
-        return torch.optim.Adam
+    def make(optimizer_class: OptimizerClass) -> type:
+        return optimizer_class.value


-@dataclasses.dataclass
-class SGD(Config):
-    """Config for SGD optimizer class."""
-
-    _target_: str = "imitation_cli.utils.optimizer_class.SGD.make"
-
-    @staticmethod
-    def make() -> type:
-        import torch
-
-        return torch.optim.SGD
+Adam = Config(OptimizerClass.Adam)
+SGD = Config(OptimizerClass.SGD)


 def register_configs(group: str):
     cs = ConfigStore.instance()
-    cs.store(group=group, name="adam", node=Adam)
-    cs.store(group=group, name="sgd", node=SGD)
+    for cls in OptimizerClass:
+        cs.store(group=group, name=cls.name.lower(), node=Config(cls))
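
As in the other modules, resolving a config yields the optimizer class itself, which callers construct as usual. A minimal sketch (the model and learning rate are illustrative); note the registered names are unchanged here, since "Adam" and "SGD" lower-case to the old "adam" and "sgd":

    import torch
    from hydra.utils import instantiate

    from imitation_cli.utils import optimizer_class

    opt_cls = instantiate(optimizer_class.Adam)  # -> torch.optim.Adam
    model = torch.nn.Linear(4, 2)                # illustrative module
    opt = opt_cls(model.parameters(), lr=1e-3)   # construct the optimizer as usual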

src/imitation_cli/utils/policy.py (3 additions, 3 deletions)

@@ -65,20 +65,20 @@ class ActorCriticPolicy(Config):
     _target_: str = "imitation_cli.utils.policy.ActorCriticPolicy.make"
     lr_schedule: schedule.Config = schedule.FixedSchedule(3e-4)
     net_arch: Optional[Dict[str, List[int]]] = None
-    activation_fn: act_fun_class_cfg.Config = act_fun_class_cfg.TanH()
+    activation_fn: act_fun_class_cfg.Config = act_fun_class_cfg.TanH
     ortho_init: bool = True
     use_sde: bool = False
     log_std_init: float = 0.0
     full_std: bool = True
     use_expln: bool = False
     squash_output: bool = False
     features_extractor_class: feature_extractor_class_cfg.Config = (
-        feature_extractor_class_cfg.FlattenExtractorConfig()
+        feature_extractor_class_cfg.FlattenExtractor
     )
     features_extractor_kwargs: Optional[Dict[str, Any]] = None
     share_features_extractor: bool = True
     normalize_images: bool = True
-    optimizer_class: optimizer_class_cfg.Config = optimizer_class_cfg.Adam()
+    optimizer_class: optimizer_class_cfg.Config = optimizer_class_cfg.Adam
     optimizer_kwargs: Optional[Dict[str, Any]] = None

     @staticmethod
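
These three defaults now reference the shared module-level `Config` instances instead of constructing a fresh dataclass per field. A minimal sketch of the consequence, assuming `ActorCriticPolicy()` is constructible from its defaults (fields outside this hunk may require arguments):

    from imitation_cli.utils import optimizer_class, policy

    cfg = policy.ActorCriticPolicy()
    # The default is the shared module-level instance, referenced rather than copied:
    assert cfg.optimizer_class is optimizer_class.Adam

Sharing is safe as long as the configs stay immutable; Hydra's ConfigStore converts stored nodes to DictConfig objects, so composition does not mutate the shared instances.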
