Skip to content

Conversation

@araffin
Copy link
Member

@araffin araffin commented Dec 15, 2025

Description

closes #2137

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have opened an associated PR on the SB3-Contrib repository (if necessary)
  • I have opened an associated PR on the RL-Zoo3 repository (if necessary)
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes issues with saving and loading jitted (compiled) models and streamlines the loading process by utilizing PyTorch's built-in capabilities. The changes address issue #2137 by properly handling torch.compile wrapped models and simplifying the zip file loading logic.

Key changes:

  • Modified get_parameters() to extract state dicts from the original model when dealing with compiled models
  • Simplified load_from_zip_file() to use PyTorch's direct loading with weights_only=True
  • Added test coverage for saving/loading compiled models and reduced training timesteps across tests

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
stable_baselines3/common/base_class.py Updated get_parameters() to handle compiled models by extracting state dict from _orig_mod
stable_baselines3/common/save_util.py Simplified zip file loading by removing BytesIO workaround and enabling weights_only=True
tests/test_save_load.py Added test for compiled model save/load, reduced timesteps, and added PPO-specific parameters

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +170 to +174
# Check that loading after compiling works, see GH#2137
model = model_class.load(tmp_path / "test_save.zip")
model.policy = th.compile(model.policy)
model.save(tmp_path / "test_save.zip")
model_class.load(tmp_path / "test_save.zip")
Copy link

Copilot AI Dec 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test verifies that loading works after saving a compiled model, but doesn't verify the loaded model's behavior. Consider asserting that the loaded model produces the same predictions or has the same state dict as before saving.

Copilot uses AI. Check for mistakes.
@araffin araffin marked this pull request as ready for review December 17, 2025 14:11
@araffin araffin merged commit 8fccf7f into master Dec 18, 2025
4 checks passed
@araffin araffin deleted the chore/simplify-save branch December 18, 2025 17:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: PyTorch 2.7 compile results in bad state_dict keys and policy fails to load

5 participants