Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weights only param #1902

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
5 changes: 5 additions & 0 deletions docs/guide/save_format.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ A zip-archived JSON dump, PyTorch state dictionaries and PyTorch variables. The
is stored as a JSON file, model parameters and optimizers are serialized with ``torch.save()`` function and these files
are stored under a single .zip archive.

Note that if your model contains objects that ``torch.load()`` considers unsafe, loading will raise an unpickling error. You can
use the ``weights_only`` argument to control whether such objects are deserialized with pickle; setting it to False allows them
but issues a warning. (e.g., if the learning rate schedule contains the NumPy scalar ``np.pi``, loading raises an error unless
``weights_only`` is set to False)

Any objects that are not JSON serializable are serialized with cloudpickle and stored as base64-encoded
string in the JSON file, along with some information that was stored in the serialization. This allows
inspecting stored objects without deserializing the object itself.
Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Log success rate ``rollout/success_rate`` when available for on policy algorithms (@corentinlger)
- This can be overridden using the ``weights_only`` boolean argument in the ``load()`` method in sb3, which will be passed to ``torch.load()`` (@markscsmith)

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -1593,4 +1594,4 @@ And all the contributors:
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah
@marekm4 @stagoverflow @rushitnshah @markscsmith
15 changes: 10 additions & 5 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ def set_parameters(
load_path_or_dict: Union[str, TensorDict],
exact_match: bool = True,
device: Union[th.device, str] = "auto",
weights_only: bool = True,
) -> None:
"""
Load parameters from a given zip-file or a nested dictionary containing parameters for
Expand All @@ -587,12 +588,15 @@ def set_parameters(
module and each of their parameters, otherwise raises an Exception. If set to False, this
can be used to update only specific parameters.
:param device: Device on which the code should run.
:param weights_only: Set torch weights_only for passthrough into load function.
WARNING: keep weights_only=True to avoid possible arbitrary code execution!
See https://pytorch.org/docs/stable/generated/torch.load.html
"""
params = {}
if isinstance(load_path_or_dict, dict):
params = load_path_or_dict
else:
_, params, _ = load_from_zip_file(load_path_or_dict, device=device)
_, params, _ = load_from_zip_file(load_path_or_dict, device=device, weights_only=weights_only)

# Keep track which objects were updated.
# `_get_torch_save_params` returns [params, other_pytorch_variables].
Expand Down Expand Up @@ -647,6 +651,7 @@ def load( # noqa: C901
custom_objects: Optional[Dict[str, Any]] = None,
print_system_info: bool = False,
force_reset: bool = True,
weights_only: bool = True,
**kwargs,
) -> SelfBaseAlgorithm:
"""
Expand All @@ -670,6 +675,9 @@ def load( # noqa: C901
:param force_reset: Force call to ``reset()`` before training
to avoid unexpected behavior.
See https://github.com/DLR-RM/stable-baselines3/issues/597
:param weights_only: Set torch weights_only for passthrough into load function.
WARNING: keep weights_only=True to avoid possible arbitrary code execution!
See https://pytorch.org/docs/stable/generated/torch.load.html
:param kwargs: extra arguments to change the model when loading
:return: new model instance with loaded parameters
"""
Expand All @@ -678,10 +686,7 @@ def load( # noqa: C901
get_system_info()

data, params, pytorch_variables = load_from_zip_file(
path,
device=device,
custom_objects=custom_objects,
print_system_info=print_system_info,
path, device=device, custom_objects=custom_objects, print_system_info=print_system_info, weights_only=weights_only
)

assert data is not None, "No data found in the saved file"
Expand Down
11 changes: 9 additions & 2 deletions stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def load_from_zip_file(
device: Union[th.device, str] = "auto",
verbose: int = 0,
print_system_info: bool = False,
weights_only: bool = True,
) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]:
"""
Load model data from a .zip archive
Expand All @@ -397,6 +398,9 @@ def load_from_zip_file(
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
:param print_system_info: Whether to print or not the system info
about the saved model.
:param weights_only: Set torch weights_only for passthrough into load function.
WARNING: keep weights_only=True to avoid possible arbitrary code execution!
See https://pytorch.org/docs/stable/generated/torch.load.html
:return: Class parameters, model state_dicts (aka "params", dict of state_dict)
and dict of pytorch variables
"""
Expand Down Expand Up @@ -426,7 +430,10 @@ def load_from_zip_file(
"The model was saved with SB3 <= 1.2.0 and thus cannot print system information.",
UserWarning,
)

if weights_only is False:
warnings.warn(
"Unpickling unsafe objects! Loading full state_dict. See pytorch docs on torch.load for more info."
)
if "data" in namelist and load_data:
# Load class parameters that are stored
# with either JSON or pickle (not PyTorch variables).
Expand All @@ -447,7 +454,7 @@ def load_from_zip_file(
file_content.seek(0)
# Load the parameters with the right ``map_location``.
# Remove ".pth" ending with splitext
th_object = th.load(file_content, map_location=device, weights_only=True)
th_object = th.load(file_content, map_location=device, weights_only=weights_only)
# "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
# PyTorch variables (not state_dicts)
Expand Down
47 changes: 47 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import base64
import io
import json
import math
import os
import pathlib
import pickle
import tempfile
import warnings
import zipfile
Expand Down Expand Up @@ -739,6 +741,51 @@ def test_load_invalid_object(tmp_path):
assert len(record) == 0


def test_load_torch_weights_only(tmp_path):
    """Check that the ``weights_only`` argument is passed through to ``torch.load()``.

    - ``weights_only=False`` must emit exactly one warning (unsafe unpickling).
    - ``weights_only=True`` must load silently when the saved objects are
      deserializable by the weights-only unpickler, and raise
      ``pickle.UnpicklingError`` when they are not (e.g. a NumPy scalar
      captured inside the learning rate schedule).
    """
    path = str(tmp_path / "ppo_pendulum.zip")
    model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0)
    model.save(path)

    # Loading with weights_only=False is allowed but must warn the user
    with warnings.catch_warnings(record=True) as record:
        # "always" prevents the default filter from suppressing repeated warnings,
        # which would make the count assertions below flaky
        warnings.simplefilter("always")
        model.load(path, weights_only=False)
    assert len(record) == 1

    # Loading only the weights from a valid model emits no warning
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        model.load(path, weights_only=True)
    assert len(record) == 0

    # A plain Python float in the schedule is safe to load with weights_only=True
    model = PPO(
        policy="MlpPolicy",
        env="Pendulum-v1",
        learning_rate=math.sin(1),
    )
    model.save(path)
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        model.load(path, weights_only=True)
    assert len(record) == 0

    # A NumPy scalar in the schedule causes a pickle error with weights_only=True:
    # _pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`
    # will likely succeed, but it can result in arbitrary code execution. Do it only if you get the file from a
    # trusted source. WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar
    model = PPO(
        policy="MlpPolicy",
        env="Pendulum-v1",
        learning_rate=lambda _: np.sin(1),
    )
    model.save(path)

    with pytest.raises(pickle.UnpicklingError):
        model.load(path, weights_only=True)

    # Falling back to weights_only=False works but warns exactly once
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        model.load(path, weights_only=False)
    assert len(record) == 1


def test_dqn_target_update_interval(tmp_path):
# `target_update_interval` should not change when reloading the model. See GH Issue #1373.
env = make_vec_env(env_id="CartPole-v1", n_envs=2)
Expand Down