@@ -46,7 +46,7 @@
 from torch import nn as nn
 
 # Register custom envs
-import rl_zoo3.import_envs  # noqa: F401 pytype: disable=import-error
+import rl_zoo3.import_envs  # noqa: F401
 from rl_zoo3.callbacks import SaveVecNormalizeCallback, TrialEvalCallback
 from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
 from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule
@@ -116,13 +116,13 @@ def __init__(
         self.n_timesteps = n_timesteps
         self.normalize = False
         self.normalize_kwargs: Dict[str, Any] = {}
-        self.env_wrapper = None
+        self.env_wrapper: Optional[Callable] = None
         self.frame_stack = None
         self.seed = seed
         self.optimization_log_path = optimization_log_path
 
         self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
-        self.vec_env_wrapper = None
+        self.vec_env_wrapper: Optional[Callable] = None
 
         self.vec_env_kwargs: Dict[str, Any] = {}
         # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"}
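
The two `Optional[Callable]` annotations declare attributes that start as `None` and are only filled in later by `_preprocess_hyperparams()`. A minimal standalone sketch of the pattern (class and method names here are illustrative, not from the patch):

    from typing import Callable, Optional

    class Manager:
        def __init__(self) -> None:
            # Starts unset; assigned a real wrapper once hyperparams are read
            self.env_wrapper: Optional[Callable] = None

        def wrap(self, env):
            # mypy requires the None check before calling the attribute
            if self.env_wrapper is not None:
                env = self.env_wrapper(env)
            return env
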
@@ -138,7 +138,7 @@ def __init__(
         self.n_eval_envs = n_eval_envs
 
         self.n_envs = 1  # it will be updated when reading hyperparams
-        self.n_actions = None  # For DDPG/TD3 action noise objects
+        self.n_actions = 0  # For DDPG/TD3 action noise objects
         self._hyperparams: Dict[str, Any] = {}
         self.monitor_kwargs: Dict[str, Any] = {}
 
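Switching the default from `None` to `0` keeps `n_actions` a plain `int` for the type checker rather than `Optional[int]`. A sketch of why that matters, assuming the attribute later sizes a noise vector:

    import numpy as np

    n_actions: int = 0  # sentinel meaning "not known yet", still a plain int

    # Once the action space is read, the real value is stored:
    n_actions = 6
    # Downstream code can size noise arrays without an Optional[int] check:
    noise_mean = np.zeros(n_actions)
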
@@ -186,8 +186,10 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
 
         :return: the initialized RL model
         """
-        hyperparams, saved_hyperparams = self.read_hyperparameters()
-        hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(hyperparams)
+        unprocessed_hyperparams, saved_hyperparams = self.read_hyperparameters()
+        hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(
+            unprocessed_hyperparams
+        )
 
         self.create_log_folder()
         self.create_callbacks()
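
The rename gives each stage of the hyperparameter dict its own name, so the raw dict from disk and the processed one are never confused (and each variable keeps one consistent meaning for the type checker). A sketch of the data flow, with illustrative stand-in helpers:

    from typing import Any, Dict

    def read_config() -> Dict[str, Any]:
        # stands in for read_hyperparameters()
        return {"n_envs": 4, "policy": "MlpPolicy", "env_wrapper": "some.Wrapper"}

    def preprocess(raw: Dict[str, Any]) -> Dict[str, Any]:
        # stands in for _preprocess_hyperparams(): consumes zoo-only keys
        processed = dict(raw)
        processed.pop("env_wrapper", None)
        return processed

    unprocessed_hyperparams = read_config()
    hyperparams = preprocess(unprocessed_hyperparams)
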
@@ -221,7 +223,7 @@ def learn(self, model: BaseAlgorithm) -> None:
         """
         :param model: an initialized RL model
         """
-        kwargs = {}
+        kwargs: Dict[str, Any] = {}
         if self.log_interval > -1:
             kwargs = {"log_interval": self.log_interval}
 
@@ -245,6 +247,7 @@ def learn(self, model: BaseAlgorithm) -> None:
         self.callbacks[0].on_training_end()
         # Release resources
         try:
+            assert model.env is not None
             model.env.close()
         except EOFError:
             pass
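
`model.env` is typed `Optional[VecEnv]` in stable-baselines3, so without the `assert` mypy rejects the `.close()` call. The assert narrows the Optional; a self-contained sketch of the same narrowing:

    from typing import Optional

    class VecEnvLike:
        def close(self) -> None:
            print("closed")

    env: Optional[VecEnvLike] = VecEnvLike()

    # mypy: Item "None" of "Optional[VecEnvLike]" has no attribute "close"
    assert env is not None
    env.close()  # fine: type is narrowed to VecEnvLike after the assert
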
@@ -265,7 +268,9 @@ def save_trained_model(self, model: BaseAlgorithm) -> None:
 
         if self.normalize:
             # Important: save the running average, for testing the agent we need that normalization
-            model.get_vec_normalize_env().save(os.path.join(self.params_path, "vecnormalize.pkl"))
+            vec_normalize = model.get_vec_normalize_env()
+            assert vec_normalize is not None
+            vec_normalize.save(os.path.join(self.params_path, "vecnormalize.pkl"))
 
     def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
         """
@@ -293,7 +298,7 @@ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
             with open(self.config) as f:
                 hyperparams_dict = yaml.safe_load(f)
         elif self.config.endswith(".py"):
-            global_variables = {}
+            global_variables: Dict = {}
             # Load hyperparameters from python file
             exec(Path(self.config).read_text(), global_variables)
             hyperparams_dict = global_variables["hyperparams"]
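
For context, the `.py` branch `exec()`s the config file and then reads the module-level `hyperparams` name out of the resulting globals. A hypothetical config file satisfying that contract (file name and values are illustrative; the env-id keys follow the zoo's YAML convention):

    # hypothetical_hyperparams.py
    # Must define a top-level `hyperparams` dict; other names are ignored.
    hyperparams = {
        "CartPole-v1": {
            "policy": "MlpPolicy",
            "n_envs": 8,
            "n_timesteps": 100_000,
        },
    }
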
@@ -452,6 +457,9 @@ def _preprocess_action_noise(
         noise_std = hyperparams["noise_std"]
 
         # Save for later (hyperparameter optimization)
+        assert isinstance(
+            env.action_space, spaces.Box
+        ), f"Action noise can only be used with Box action space, not {env.action_space}"
         self.n_actions = env.action_space.shape[0]
 
         if "normal" in noise_type:
@@ -516,7 +524,7 @@ def create_callbacks(self):
 
     @staticmethod
     def entry_point(env_id: str) -> str:
-        return str(gym.envs.registry[env_id].entry_point)  # pytype: disable=module-attr
+        return str(gym.envs.registry[env_id].entry_point)
 
     @staticmethod
     def is_atari(env_id: str) -> bool:
@@ -618,7 +626,7 @@ def make_env(**kwargs) -> gym.Env:
                 env_kwargs=env_kwargs,
                 monitor_dir=log_dir,
                 wrapper_class=self.env_wrapper,
-                vec_env_cls=self.vec_env_class,
+                vec_env_cls=self.vec_env_class,  # type: ignore[arg-type]
                 vec_env_kwargs=self.vec_env_kwargs,
                 monitor_kwargs=self.monitor_kwargs,
             )
@@ -645,11 +653,11 @@ def make_env(**kwargs) -> gym.Env:
             # the other channel last); VecTransposeImage will throw an error
             for space in env.observation_space.spaces.values():
                 wrap_with_vectranspose = wrap_with_vectranspose or (
-                    is_image_space(space) and not is_image_space_channels_first(space)
+                    is_image_space(space) and not is_image_space_channels_first(space)  # type: ignore[arg-type]
                 )
         else:
             wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
-                env.observation_space
+                env.observation_space  # type: ignore[arg-type]
             )
 
         if wrap_with_vectranspose:
@@ -683,13 +691,16 @@ def _load_pretrained_agent(self, hyperparams: Dict[str, Any], env: VecEnv) -> Ba
         if os.path.exists(replay_buffer_path):
             print("Loading replay buffer")
             # `truncate_last_traj` will be taken into account only if we use HER replay buffer
+            assert hasattr(
+                model, "load_replay_buffer"
+            ), "The current model doesn't have a `load_replay_buffer` method to load the replay buffer"
             model.load_replay_buffer(replay_buffer_path, truncate_last_traj=self.truncate_last_trajectory)
         return model
 
     def _create_sampler(self, sampler_method: str) -> BaseSampler:
         # n_warmup_steps: Disable pruner until the trial reaches the given number of steps.
         if sampler_method == "random":
-            sampler = RandomSampler(seed=self.seed)
+            sampler: BaseSampler = RandomSampler(seed=self.seed)
         elif sampler_method == "tpe":
             sampler = TPESampler(n_startup_trials=self.n_startup_trials, seed=self.seed, multivariate=True)
         elif sampler_method == "skopt":
@@ -705,7 +716,7 @@ def _create_sampler(self, sampler_method: str) -> BaseSampler:
 
     def _create_pruner(self, pruner_method: str) -> BasePruner:
         if pruner_method == "halving":
-            pruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0)
+            pruner: BasePruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0)
         elif pruner_method == "median":
             pruner = MedianPruner(n_startup_trials=self.n_startup_trials, n_warmup_steps=self.n_evaluations // 3)
         elif pruner_method == "none":
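
Same fix in both factory methods: mypy pins a variable's type at its first assignment, so annotating the first branch with the abstract base class lets the other branches assign different subclasses. A standalone sketch with optuna's real pruner classes:

    from optuna.pruners import BasePruner, MedianPruner, SuccessiveHalvingPruner

    def create_pruner(method: str) -> BasePruner:
        if method == "halving":
            # Annotated as the base class so the next branch type-checks too
            pruner: BasePruner = SuccessiveHalvingPruner()
        elif method == "median":
            pruner = MedianPruner()
        else:
            raise ValueError(f"Unknown pruner: {method}")
        return pruner
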
@@ -718,17 +729,17 @@ def _create_pruner(self, pruner_method: str) -> BasePruner:
     def objective(self, trial: optuna.Trial) -> float:
         kwargs = self._hyperparams.copy()
 
-        # Hack to use DDPG/TD3 noise sampler
-        trial.n_actions = self.n_actions
-        # Hack when using HerReplayBuffer
-        trial.using_her_replay_buffer = kwargs.get("replay_buffer_class") == HerReplayBuffer
-        if trial.using_her_replay_buffer:
-            trial.her_kwargs = kwargs.get("replay_buffer_kwargs", {})
+        n_envs = 1 if self.algo == "ars" else self.n_envs
+
+        additional_args = {
+            "using_her_replay_buffer": kwargs.get("replay_buffer_class") == HerReplayBuffer,
+            "her_kwargs": kwargs.get("replay_buffer_kwargs", {}),
+        }
+        # Pass n_actions to initialize DDPG/TD3 noise sampler
         # Sample candidate hyperparameters
-        sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial)
+        sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial, self.n_actions, n_envs, additional_args)
         kwargs.update(sampled_hyperparams)
 
-        n_envs = 1 if self.algo == "ars" else self.n_envs
         env = self.create_envs(n_envs, no_log=True)
 
         # By default, do not activate verbose output to keep
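
This is the substantive change of the hunk: `n_actions`, `n_envs`, and the HER settings are now passed to the sampler as explicit arguments instead of being attached to the `trial` object as ad-hoc attributes. A hedged sketch of what a `HYPERPARAMS_SAMPLER` entry looks like under the new calling convention (the sampled values are illustrative, not the zoo's actual TD3 sampler):

    from typing import Any, Dict

    import numpy as np
    import optuna
    from stable_baselines3.common.noise import NormalActionNoise

    def sample_td3_params(
        trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: Dict[str, Any]
    ) -> Dict[str, Any]:
        # n_envs is available for samplers that scale per-env settings
        noise_std = trial.suggest_float("noise_std", 0.0, 1.0)
        hyperparams: Dict[str, Any] = {
            "gamma": trial.suggest_categorical("gamma", [0.95, 0.98, 0.99]),
            # n_actions arrives as an argument; no more `trial.n_actions` hack
            "action_noise": NormalActionNoise(
                mean=np.zeros(n_actions), sigma=noise_std * np.ones(n_actions)
            ),
        }
        if additional_args["using_her_replay_buffer"]:
            hyperparams["replay_buffer_kwargs"] = additional_args["her_kwargs"]
        return hyperparams
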
@@ -778,13 +789,15 @@ def objective(self, trial: optuna.Trial) -> float:
         )
 
         try:
-            model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
+            model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)  # type: ignore[arg-type]
             # Free memory
+            assert model.env is not None
             model.env.close()
             eval_env.close()
         except (AssertionError, ValueError) as e:
             # Sometimes, random hyperparams can generate NaN
             # Free memory
+            assert model.env is not None
             model.env.close()
             eval_env.close()
             # Prune hyperparams that generate NaNs