1+ """Configuration for Gym environments."""
12from __future__ import annotations
3+
24import dataclasses
35import typing
4- from typing import Optional
6+ from typing import Optional , Union , cast
57
68if typing .TYPE_CHECKING :
79 from stable_baselines3 .common .vec_env import VecEnv
1517
1618@dataclasses .dataclass
1719class Config :
20+ """Configuration for Gym environments."""
21+
1822 _target_ : str = "imitation_cli.utils.environment.Config.make"
1923 env_name : str = MISSING # The environment to train on
2024 n_envs : int = 8 # number of environments in VecEnv
21- parallel : bool = False # Use SubprocVecEnv rather than DummyVecEnv TODO: when setting this to true this is really slow for some reason
25+ # TODO: when setting this to true this is really slow for some reason
26+ parallel : bool = False # Use SubprocVecEnv rather than DummyVecEnv
2227 max_episode_steps : int = MISSING # Set to positive int to limit episode horizons
2328 env_make_kwargs : dict = dataclasses .field (
24- default_factory = dict
29+ default_factory = dict ,
2530 ) # The kwargs passed to `spec.make`.
26- rng : randomness .Config = randomness . Config ()
31+ rng : randomness .Config = MISSING
2732
2833 @staticmethod
29- def make (log_dir : Optional [str ]= None , ** kwargs ) -> VecEnv :
34+ def make (log_dir : Optional [str ] = None , ** kwargs ) -> VecEnv :
3035 from imitation .util import util
3136
3237 return util .make_vec_env (log_dir = log_dir , ** kwargs )
@@ -38,13 +43,24 @@ def make_rollout_venv(environment_config: Config) -> VecEnv:
3843 return call (
3944 environment_config ,
4045 log_dir = None ,
41- post_wrappers = [lambda env , i : wrappers .RolloutInfoWrapper (env )]
46+ post_wrappers = [lambda env , i : wrappers .RolloutInfoWrapper (env )],
4247 )
4348
4449
45- def register_configs (group : str ):
50+ def register_configs (
51+ group : str ,
52+ default_rng : Union [randomness .Config , str ] = MISSING ,
53+ ):
54+ default_rng = cast (randomness .Config , default_rng )
4655 cs = ConfigStore .instance ()
47- cs .store (group = group , name = "gym_env" , node = Config )
48- cs .store (group = group , name = "cartpole" , node = Config (env_name = "CartPole-v0" , max_episode_steps = 500 ))
49- cs .store (group = group , name = "pendulum" , node = Config (env_name = "Pendulum-v1" , max_episode_steps = 500 ))
50-
56+ cs .store (group = group , name = "gym_env" , node = Config (rng = default_rng ))
57+ cs .store (
58+ group = group ,
59+ name = "cartpole" ,
60+ node = Config (env_name = "CartPole-v0" , max_episode_steps = 500 , rng = default_rng ),
61+ )
62+ cs .store (
63+ group = group ,
64+ name = "pendulum" ,
65+ node = Config (env_name = "Pendulum-v1" , max_episode_steps = 500 , rng = default_rng ),
66+ )
0 commit comments