@@ -48,9 +48,13 @@ def test_save_load(tmp_path, model_class):
4848
4949 env = DummyVecEnv ([lambda : select_env (model_class )])
5050
51+ kwargs = {}
52+ if model_class == PPO :
53+ kwargs = {"n_steps" : 64 , "n_epochs" : 4 }
54+
5155 # create model
52- model = model_class ("MlpPolicy" , env , policy_kwargs = dict (net_arch = [16 ]), verbose = 1 )
53- model .learn (total_timesteps = 500 )
56+ model = model_class ("MlpPolicy" , env , policy_kwargs = dict (net_arch = [16 ]), verbose = 1 , ** kwargs )
57+ model .learn (total_timesteps = 150 )
5458
5559 env .reset ()
5660 observations = np .concatenate ([env .step ([env .action_space .sample ()])[0 ] for _ in range (10 )], axis = 0 )
@@ -159,10 +163,16 @@ def test_save_load(tmp_path, model_class):
159163 assert np .allclose (selected_actions , new_selected_actions , 1e-4 )
160164
161165 # check if learn still works
162- model .learn (total_timesteps = 500 )
166+ model .learn (total_timesteps = 150 )
163167
164168 del model
165169
170+ # Check that a model whose policy was wrapped with th.compile can still be saved and loaded again, see GH#2137
171+ model = model_class .load (tmp_path / "test_save.zip" )
172+ model .policy = th .compile (model .policy )
173+ model .save (tmp_path / "test_save.zip" )
174+ model_class .load (tmp_path / "test_save.zip" )
175+
166176 # clear file from os
167177 os .remove (tmp_path / "test_save.zip" )
168178
@@ -284,8 +294,8 @@ def test_exclude_include_saved_params(tmp_path, model_class):
284294
285295
286296def test_save_load_pytorch_var (tmp_path ):
287- model = SAC ("MlpPolicy" , "Pendulum-v1" , seed = 3 , policy_kwargs = dict (net_arch = [64 ], n_critics = 1 ))
288- model .learn (200 )
297+ model = SAC ("MlpPolicy" , "Pendulum-v1" , learning_starts = 10 , seed = 3 , policy_kwargs = dict (net_arch = [64 ], n_critics = 1 ))
298+ model .learn (110 )
289299 save_path = str (tmp_path / "sac_pendulum" )
290300 model .save (save_path )
291301 env = model .get_env ()
@@ -295,14 +305,14 @@ def test_save_load_pytorch_var(tmp_path):
295305
296306 model = SAC .load (save_path , env = env )
297307 assert th .allclose (log_ent_coef_before , model .log_ent_coef )
298- model .learn (200 )
308+ model .learn (50 )
299309 log_ent_coef_after = model .log_ent_coef
300310 # Check that the entropy coefficient is still optimized
301311 assert not th .allclose (log_ent_coef_before , log_ent_coef_after )
302312
303313 # With a fixed entropy coef
304314 model = SAC ("MlpPolicy" , "Pendulum-v1" , seed = 3 , ent_coef = 0.01 , policy_kwargs = dict (net_arch = [64 ], n_critics = 1 ))
305- model .learn (200 )
315+ model .learn (110 )
306316 save_path = str (tmp_path / "sac_pendulum" )
307317 model .save (save_path )
308318 env = model .get_env ()
@@ -313,7 +323,7 @@ def test_save_load_pytorch_var(tmp_path):
313323
314324 model = SAC .load (save_path , env = env )
315325 assert th .allclose (ent_coef_before , model .ent_coef_tensor )
316- model .learn (200 )
326+ model .learn (50 )
317327 ent_coef_after = model .ent_coef_tensor
318328 assert model .log_ent_coef is None
319329 # Check that the entropy coefficient is still the same
@@ -354,9 +364,9 @@ def test_save_load_replay_buffer(tmp_path, model_class):
354364 path = pathlib .Path (tmp_path / "logs/replay_buffer.pkl" )
355365 path .parent .mkdir (exist_ok = True , parents = True ) # to not raise a warning
356366 model = model_class (
357- "MlpPolicy" , select_env (model_class ), buffer_size = 1000 , policy_kwargs = dict (net_arch = [64 ]), learning_starts = 200
367+ "MlpPolicy" , select_env (model_class ), buffer_size = 1000 , policy_kwargs = dict (net_arch = [64 ]), learning_starts = 100
358368 )
359- model .learn (300 )
369+ model .learn (150 )
360370 old_replay_buffer = deepcopy (model .replay_buffer )
361371 model .save_replay_buffer (path )
362372 model .replay_buffer = None
@@ -410,14 +420,14 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
410420 learning_starts = 10 ,
411421 )
412422
413- model .learn (150 )
423+ model .learn (50 )
414424
415- model .learn (150 , reset_num_timesteps = False )
425+ model .learn (50 , reset_num_timesteps = False )
416426
417427 # Check that there is no warning
418428 assert len (recwarn ) == 0
419429
420- model .learn (150 )
430+ model .learn (50 )
421431
422432 if optimize_memory_usage :
423433 assert len (recwarn ) == 1
@@ -439,6 +449,10 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
439449 """
440450 kwargs = dict (policy_kwargs = dict (net_arch = [16 ]))
441451
452+ if model_class == PPO :
453+ kwargs ["n_steps" ] = 64
454+ kwargs ["n_epochs" ] = 2
455+
442456 # gSDE is only applicable for A2C, PPO and SAC
443457 if use_sde and model_class not in [A2C , PPO , SAC ]:
444458 pytest .skip ()
@@ -461,7 +475,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
461475
462476 # create model
463477 model = model_class (policy_str , env , verbose = 1 , ** kwargs )
464- model .learn (total_timesteps = 300 )
478+ model .learn (total_timesteps = 150 )
465479
466480 env .reset ()
467481 observations = np .concatenate ([env .step ([env .action_space .sample ()])[0 ] for _ in range (10 )], axis = 0 )
@@ -556,7 +570,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
556570
557571 # create model
558572 model = model_class (policy_str , env , verbose = 1 , ** kwargs )
559- model .learn (total_timesteps = 300 )
573+ model .learn (total_timesteps = 150 )
560574
561575 env .reset ()
562576 observations = np .concatenate ([env .step ([env .action_space .sample ()])[0 ] for _ in range (10 )], axis = 0 )
0 commit comments