Commit 8fccf7f

Fix saving jitted model and simplify load (#2205)
* chore: simplify load for newer PyTorch versions
* fix: save original model state dict when compiling policy
* Faster save/load tests
* Add test for saving/loading after JIT
* Update changelog and version
1 parent c6ce50f commit 8fccf7f
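
For context: ``th.compile()`` wraps a module in an ``OptimizedModule`` that stores the original network as ``_orig_mod``, so its ``state_dict()`` keys gain a ``_orig_mod.`` prefix and no longer match a freshly constructed, uncompiled policy on load (GH#2137). A minimal sketch of the mismatch and of the fallback the fix relies on, using a toy ``nn.Linear`` as a stand-in for a policy (not SB3 code):

```python
import torch as th
import torch.nn as nn

# Toy stand-in for a policy network (hypothetical example module)
policy = nn.Linear(4, 2)
compiled = th.compile(policy)

print(list(policy.state_dict()))    # ['weight', 'bias']
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']

# The fix reads parameters from the underlying module when present:
original = getattr(compiled, "_orig_mod", compiled)
assert list(original.state_dict()) == list(policy.state_dict())
```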

File tree

5 files changed, +39 -31 lines changed

docs/misc/changelog.rst

Lines changed: 6 additions & 3 deletions
@@ -4,7 +4,7 @@ Changelog
 ==========
 
 
-Release 2.8.0a1 (WIP)
+Release 2.8.0a2 (WIP)
 --------------------------
 
 Breaking Changes:
@@ -18,6 +18,7 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
+- Fixed saving and loading of Torch compiled models (using ``th.compile()``) by updating ``get_parameters()``
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
@@ -35,8 +36,10 @@ Deprecations:
 Others:
 ^^^^^^^
 - Updated to Python 3.10+ annotations
-- Remove some unused variables (@unexploredtest)
-- Improve type hints for distributions
+- Removed some unused variables (@unexploredtest)
+- Improved type hints for distributions
+- Simplified zip file loading by removing Python 3.6 workaround and enabling ``weights_only=True`` (PyTorch 2.x)
+- Sped up saving/loading tests
 
 Documentation:
 ^^^^^^^^^^^^^^

stable_baselines3/common/base_class.py

Lines changed: 2 additions & 2 deletions
@@ -812,8 +812,8 @@ def get_parameters(self) -> dict[str, dict]:
         params = {}
         for name in state_dicts_names:
             attr = recursive_getattr(self, name)
-            # Retrieve state dict
-            params[name] = attr.state_dict()
+            # Retrieve state dict, and from the original model if compiled (see GH#2137)
+            params[name] = getattr(attr, "_orig_mod", attr).state_dict()
         return params
 
     def save(
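
A note on why the ``getattr`` fallback is safe here: ``get_parameters()`` iterates over both policies and optimizers, and only compiled modules expose ``_orig_mod``; for everything else ``getattr`` returns the object itself. A small sketch (plain PyTorch, not SB3 code) under that assumption:

```python
import torch as th
import torch.nn as nn

net = nn.Linear(3, 1)
opt = th.optim.Adam(net.parameters())

# Compiled modules expose ._orig_mod; plain modules and optimizers do not,
# so getattr with a default degrades to a no-op for them.
for attr in (th.compile(net), net, opt):
    target = getattr(attr, "_orig_mod", attr)
    print(type(target).__name__, list(target.state_dict()))
```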

stable_baselines3/common/save_util.py

Lines changed: 1 addition & 10 deletions
@@ -439,16 +439,7 @@ def load_from_zip_file(
             pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
             for file_path in pth_files:
                 with archive.open(file_path, mode="r") as param_file:
-                    # File has to be seekable, but param_file is not, so load in BytesIO first
-                    # fixed in python >= 3.7
-                    file_content = io.BytesIO()
-                    file_content.write(param_file.read())
-                    # go to start of file
-                    file_content.seek(0)
-                    # Load the parameters with the right ``map_location``.
-                    # Remove ".pth" ending with splitext
-                    # Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911
-                    th_object = th.load(file_content, map_location=device, weights_only=False)
+                    th_object = th.load(param_file, map_location=device, weights_only=True)
                 # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
                 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
                     # PyTorch variables (not state_dicts)
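
The removed workaround predates Python 3.7, where ``zipfile``'s ``ZipExtFile`` became seekable, so ``th.load`` can now read the archive member directly instead of copying it into a ``BytesIO`` buffer first. ``weights_only=True`` also becomes possible once PyTorch 1.13 support is dropped (see GH#1911); it restricts unpickling to tensors and primitive containers. A standalone sketch of the round trip, with hypothetical file names (``model.zip``, ``policy.pth``):

```python
import zipfile

import torch as th
import torch.nn as nn

# Save a state dict into a zip archive member...
net = nn.Linear(4, 2)
with zipfile.ZipFile("model.zip", "w") as archive:
    with archive.open("policy.pth", mode="w") as f:
        th.save(net.state_dict(), f)

# ...then load it back directly: ZipExtFile is seekable on Python >= 3.7,
# and weights_only=True refuses arbitrary pickled objects.
with zipfile.ZipFile("model.zip") as archive:
    with archive.open("policy.pth", mode="r") as param_file:
        state = th.load(param_file, map_location="cpu", weights_only=True)

print(list(state))  # ['weight', 'bias']
```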

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.8.0a1
+2.8.0a2

tests/test_save_load.py

Lines changed: 29 additions & 15 deletions
@@ -48,9 +48,13 @@ def test_save_load(tmp_path, model_class):
 
     env = DummyVecEnv([lambda: select_env(model_class)])
 
+    kwargs = {}
+    if model_class == PPO:
+        kwargs = {"n_steps": 64, "n_epochs": 4}
+
     # create model
-    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
-    model.learn(total_timesteps=500)
+    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
+    model.learn(total_timesteps=150)
 
     env.reset()
     observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@@ -159,10 +163,16 @@ def test_save_load(tmp_path, model_class):
     assert np.allclose(selected_actions, new_selected_actions, 1e-4)
 
     # check if learn still works
-    model.learn(total_timesteps=500)
+    model.learn(total_timesteps=150)
 
     del model
 
+    # Check that loading after compiling works, see GH#2137
+    model = model_class.load(tmp_path / "test_save.zip")
+    model.policy = th.compile(model.policy)
+    model.save(tmp_path / "test_save.zip")
+    model_class.load(tmp_path / "test_save.zip")
+
     # clear file from os
     os.remove(tmp_path / "test_save.zip")

@@ -284,8 +294,8 @@ def test_exclude_include_saved_params(tmp_path, model_class):
 
 
 def test_save_load_pytorch_var(tmp_path):
-    model = SAC("MlpPolicy", "Pendulum-v1", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
-    model.learn(200)
+    model = SAC("MlpPolicy", "Pendulum-v1", learning_starts=10, seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
+    model.learn(110)
     save_path = str(tmp_path / "sac_pendulum")
     model.save(save_path)
     env = model.get_env()
@@ -295,14 +305,14 @@ def test_save_load_pytorch_var(tmp_path):
 
     model = SAC.load(save_path, env=env)
     assert th.allclose(log_ent_coef_before, model.log_ent_coef)
-    model.learn(200)
+    model.learn(50)
     log_ent_coef_after = model.log_ent_coef
     # Check that the entropy coefficient is still optimized
     assert not th.allclose(log_ent_coef_before, log_ent_coef_after)
 
     # With a fixed entropy coef
     model = SAC("MlpPolicy", "Pendulum-v1", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
-    model.learn(200)
+    model.learn(110)
     save_path = str(tmp_path / "sac_pendulum")
     model.save(save_path)
     env = model.get_env()
@@ -313,7 +323,7 @@ def test_save_load_pytorch_var(tmp_path):
 
     model = SAC.load(save_path, env=env)
     assert th.allclose(ent_coef_before, model.ent_coef_tensor)
-    model.learn(200)
+    model.learn(50)
     ent_coef_after = model.ent_coef_tensor
     assert model.log_ent_coef is None
     # Check that the entropy coefficient is still the same
@@ -354,9 +364,9 @@ def test_save_load_replay_buffer(tmp_path, model_class):
     path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
     path.parent.mkdir(exist_ok=True, parents=True)  # to not raise a warning
     model = model_class(
-        "MlpPolicy", select_env(model_class), buffer_size=1000, policy_kwargs=dict(net_arch=[64]), learning_starts=200
+        "MlpPolicy", select_env(model_class), buffer_size=1000, policy_kwargs=dict(net_arch=[64]), learning_starts=100
     )
-    model.learn(300)
+    model.learn(150)
     old_replay_buffer = deepcopy(model.replay_buffer)
     model.save_replay_buffer(path)
     model.replay_buffer = None
@@ -410,14 +420,14 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
         learning_starts=10,
     )
 
-    model.learn(150)
+    model.learn(50)
 
-    model.learn(150, reset_num_timesteps=False)
+    model.learn(50, reset_num_timesteps=False)
 
     # Check that there is no warning
     assert len(recwarn) == 0
 
-    model.learn(150)
+    model.learn(50)
 
     if optimize_memory_usage:
         assert len(recwarn) == 1
@@ -439,6 +449,10 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
     """
     kwargs = dict(policy_kwargs=dict(net_arch=[16]))
 
+    if model_class == PPO:
+        kwargs["n_steps"] = 64
+        kwargs["n_epochs"] = 2
+
     # gSDE is only applicable for A2C, PPO and SAC
     if use_sde and model_class not in [A2C, PPO, SAC]:
         pytest.skip()
@@ -461,7 +475,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
 
     # create model
     model = model_class(policy_str, env, verbose=1, **kwargs)
-    model.learn(total_timesteps=300)
+    model.learn(total_timesteps=150)
 
     env.reset()
     observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
@@ -556,7 +570,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
 
     # create model
     model = model_class(policy_str, env, verbose=1, **kwargs)
-    model.learn(total_timesteps=300)
+    model.learn(total_timesteps=150)
 
     env.reset()
     observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
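
The new regression check reduces to a compile-save-load round trip. A standalone sketch of the same pattern (mirroring the test above, assuming PPO on ``CartPole-v1`` rather than the parametrized setup):

```python
import torch as th

from stable_baselines3 import PPO

# Train briefly, compile the policy, save, then load again.
# Before this fix, the checkpoint stored "_orig_mod."-prefixed keys
# that a freshly built (uncompiled) policy could not load.
model = PPO("MlpPolicy", "CartPole-v1", n_steps=64, n_epochs=2)
model.learn(total_timesteps=150)
model.policy = th.compile(model.policy)
model.save("ppo_compiled_test")
PPO.load("ppo_compiled_test")  # raised a state-dict mismatch before the fix
```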
