Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 2.8.0a1 (WIP)
Release 2.8.0a2 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -18,6 +18,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Fixed saving and loading of Torch compiled models (using ``th.compile()``) by updating ``get_parameters()``

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand All @@ -35,8 +36,10 @@ Deprecations:
Others:
^^^^^^^
- Updated to Python 3.10+ annotations
- Remove some unused variables (@unexploredtest)
- Improve type hints for distributions
- Removed some unused variables (@unexploredtest)
- Improved type hints for distributions
- Simplified zip file loading by removing Python 3.6 workaround and enabling ``weights_only=True`` (PyTorch 2.x)
- Sped up saving/loading tests

Documentation:
^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,8 @@ def get_parameters(self) -> dict[str, dict]:
params = {}
for name in state_dicts_names:
attr = recursive_getattr(self, name)
# Retrieve state dict
params[name] = attr.state_dict()
# Retrieve the state dict — from the original (uncompiled) model if it was wrapped by ``th.compile()`` (see GH#2137)
params[name] = getattr(attr, "_orig_mod", attr).state_dict()
return params

def save(
Expand Down
11 changes: 1 addition & 10 deletions stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,16 +439,7 @@ def load_from_zip_file(
pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
for file_path in pth_files:
with archive.open(file_path, mode="r") as param_file:
# File has to be seekable, but param_file is not, so load in BytesIO first
# fixed in python >= 3.7
file_content = io.BytesIO()
file_content.write(param_file.read())
# go to start of file
file_content.seek(0)
# Load the parameters with the right ``map_location``.
# Remove ".pth" ending with splitext
# Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911
th_object = th.load(file_content, map_location=device, weights_only=False)
th_object = th.load(param_file, map_location=device, weights_only=True)
# "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
# PyTorch variables (not state_dicts)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.8.0a1
2.8.0a2
44 changes: 29 additions & 15 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ def test_save_load(tmp_path, model_class):

env = DummyVecEnv([lambda: select_env(model_class)])

kwargs = {}
if model_class == PPO:
kwargs = {"n_steps": 64, "n_epochs": 4}

# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500)
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
model.learn(total_timesteps=150)

env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
Expand Down Expand Up @@ -159,10 +163,16 @@ def test_save_load(tmp_path, model_class):
assert np.allclose(selected_actions, new_selected_actions, 1e-4)

# check if learn still works
model.learn(total_timesteps=500)
model.learn(total_timesteps=150)

del model

# Check that loading after compiling works, see GH#2137
model = model_class.load(tmp_path / "test_save.zip")
model.policy = th.compile(model.policy)
model.save(tmp_path / "test_save.zip")
model_class.load(tmp_path / "test_save.zip")
Comment on lines +170 to +174
Copy link

Copilot AI Dec 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test verifies that loading works after saving a compiled model, but doesn't verify the loaded model's behavior. Consider asserting that the loaded model produces the same predictions or has the same state dict as before saving.

Copilot uses AI. Check for mistakes.

# clear file from os
os.remove(tmp_path / "test_save.zip")

Expand Down Expand Up @@ -284,8 +294,8 @@ def test_exclude_include_saved_params(tmp_path, model_class):


def test_save_load_pytorch_var(tmp_path):
model = SAC("MlpPolicy", "Pendulum-v1", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
model = SAC("MlpPolicy", "Pendulum-v1", learning_starts=10, seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(110)
save_path = str(tmp_path / "sac_pendulum")
model.save(save_path)
env = model.get_env()
Expand All @@ -295,14 +305,14 @@ def test_save_load_pytorch_var(tmp_path):

model = SAC.load(save_path, env=env)
assert th.allclose(log_ent_coef_before, model.log_ent_coef)
model.learn(200)
model.learn(50)
log_ent_coef_after = model.log_ent_coef
# Check that the entropy coefficient is still optimized
assert not th.allclose(log_ent_coef_before, log_ent_coef_after)

# With a fixed entropy coef
model = SAC("MlpPolicy", "Pendulum-v1", seed=3, ent_coef=0.01, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
model.learn(110)
save_path = str(tmp_path / "sac_pendulum")
model.save(save_path)
env = model.get_env()
Expand All @@ -313,7 +323,7 @@ def test_save_load_pytorch_var(tmp_path):

model = SAC.load(save_path, env=env)
assert th.allclose(ent_coef_before, model.ent_coef_tensor)
model.learn(200)
model.learn(50)
ent_coef_after = model.ent_coef_tensor
assert model.log_ent_coef is None
# Check that the entropy coefficient is still the same
Expand Down Expand Up @@ -354,9 +364,9 @@ def test_save_load_replay_buffer(tmp_path, model_class):
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
model = model_class(
"MlpPolicy", select_env(model_class), buffer_size=1000, policy_kwargs=dict(net_arch=[64]), learning_starts=200
"MlpPolicy", select_env(model_class), buffer_size=1000, policy_kwargs=dict(net_arch=[64]), learning_starts=100
)
model.learn(300)
model.learn(150)
old_replay_buffer = deepcopy(model.replay_buffer)
model.save_replay_buffer(path)
model.replay_buffer = None
Expand Down Expand Up @@ -410,14 +420,14 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
learning_starts=10,
)

model.learn(150)
model.learn(50)

model.learn(150, reset_num_timesteps=False)
model.learn(50, reset_num_timesteps=False)

# Check that there is no warning
assert len(recwarn) == 0

model.learn(150)
model.learn(50)

if optimize_memory_usage:
assert len(recwarn) == 1
Expand All @@ -439,6 +449,10 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
"""
kwargs = dict(policy_kwargs=dict(net_arch=[16]))

if model_class == PPO:
kwargs["n_steps"] = 64
kwargs["n_epochs"] = 2

# gSDE is only applicable for A2C, PPO and SAC
if use_sde and model_class not in [A2C, PPO, SAC]:
pytest.skip()
Expand All @@ -461,7 +475,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):

# create model
model = model_class(policy_str, env, verbose=1, **kwargs)
model.learn(total_timesteps=300)
model.learn(total_timesteps=150)

env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
Expand Down Expand Up @@ -556,7 +570,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):

# create model
model = model_class(policy_str, env, verbose=1, **kwargs)
model.learn(total_timesteps=300)
model.learn(total_timesteps=150)

env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
Expand Down