From 04a6e8097c0d474a4671d794d6e197957e6f1b35 Mon Sep 17 00:00:00 2001 From: mscs Date: Wed, 17 Apr 2024 22:48:47 -0400 Subject: [PATCH 1/9] add weights_only param for passthrough. --- stable_baselines3/common/base_class.py | 15 ++++++++++----- stable_baselines3/common/save_util.py | 10 +++++++++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e6c7d3cfc..ad250f5c4 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -575,6 +575,7 @@ def set_parameters( load_path_or_dict: Union[str, TensorDict], exact_match: bool = True, device: Union[th.device, str] = "auto", + weights_only: bool = True, ) -> None: """ Load parameters from a given zip-file or a nested dictionary containing parameters for @@ -587,12 +588,15 @@ def set_parameters( module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters. :param device: Device on which the code should run. + :param weights_only: Set torch weights_only for passthrough into load function. + WARNING: keep weights_only=True to avoid possible arbitrary code execution! + See https://pytorch.org/docs/stable/generated/torch.load.html """ params = {} if isinstance(load_path_or_dict, dict): params = load_path_or_dict else: - _, params, _ = load_from_zip_file(load_path_or_dict, device=device) + _, params, _ = load_from_zip_file(load_path_or_dict, device=device, weights_only=weights_only) # Keep track which objects were updated. # `_get_torch_save_params` returns [params, other_pytorch_variables]. 
@@ -647,6 +651,7 @@ def load( # noqa: C901 custom_objects: Optional[Dict[str, Any]] = None, print_system_info: bool = False, force_reset: bool = True, + weights_only: bool = True, **kwargs, ) -> SelfBaseAlgorithm: """ @@ -670,6 +675,9 @@ def load( # noqa: C901 :param force_reset: Force call to ``reset()`` before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597 + :param weights_only: Set torch weights_only for passthrough into load function. + WARNING: keep weights_only=True to avoid possible arbitrary code execution! + See https://pytorch.org/docs/stable/generated/torch.load.html :param kwargs: extra arguments to change the model when loading :return: new model instance with loaded parameters """ @@ -678,10 +686,7 @@ def load( # noqa: C901 get_system_info() data, params, pytorch_variables = load_from_zip_file( - path, - device=device, - custom_objects=custom_objects, - print_system_info=print_system_info, + path, device=device, custom_objects=custom_objects, print_system_info=print_system_info, weights_only=weights_only ) assert data is not None, "No data found in the saved file" diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 2d8652006..20eb81366 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -380,6 +380,7 @@ def load_from_zip_file( device: Union[th.device, str] = "auto", verbose: int = 0, print_system_info: bool = False, + weights_only: bool = True, ) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]: """ Load model data from a .zip archive @@ -397,6 +398,9 @@ def load_from_zip_file( :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages :param print_system_info: Whether to print or not the system info about the saved model. + :param weights_only: Set torch weights_only for passthrough into load function. 
+ WARNING: keep weights_only=True to avoid possible arbitrary code execution! + See https://pytorch.org/docs/stable/generated/torch.load.html :return: Class parameters, model state_dicts (aka "params", dict of state_dict) and dict of pytorch variables """ @@ -447,7 +451,11 @@ def load_from_zip_file( file_content.seek(0) # Load the parameters with the right ``map_location``. # Remove ".pth" ending with splitext - th_object = th.load(file_content, map_location=device, weights_only=True) + if weights_only is False: + warnings.warn( + f"Unpickling unsafe objects! Loading full state_dict from {file_path}. See pytorch docs on torch.load for more info." + ) + th_object = th.load(file_content, map_location=device, weights_only=weights_only) # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": # PyTorch variables (not state_dicts) From e528cc910acba3d1c3f6a61f3cf11f1b6d1d9271 Mon Sep 17 00:00:00 2001 From: mscs Date: Wed, 17 Apr 2024 23:20:48 -0400 Subject: [PATCH 2/9] Add tests for warning for unsafe object loading in load_from_zip_file function, fix duplicate warning discovered by test --- stable_baselines3/common/save_util.py | 9 ++++----- tests/test_save_load.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 20eb81366..5432b4d41 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -430,7 +430,10 @@ def load_from_zip_file( "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.", UserWarning, ) - + if weights_only is False: + warnings.warn( + "Unpickling unsafe objects! Loading full state_dict. See pytorch docs on torch.load for more info." + ) if "data" in namelist and load_data: # Load class parameters that are stored # with either JSON or pickle (not PyTorch variables). 
@@ -451,10 +454,6 @@ def load_from_zip_file( file_content.seek(0) # Load the parameters with the right ``map_location``. # Remove ".pth" ending with splitext - if weights_only is False: - warnings.warn( - f"Unpickling unsafe objects! Loading full state_dict from {file_path}. See pytorch docs on torch.load for more info." - ) th_object = th.load(file_content, map_location=device, weights_only=weights_only) # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": diff --git a/tests/test_save_load.py b/tests/test_save_load.py index e7123e984..cbe241919 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -738,6 +738,16 @@ def test_load_invalid_object(tmp_path): PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0)) assert len(record) == 0 +def test_load_torch_weights_only(tmp_path): + # Test loading only the torch weights + path = str(tmp_path / "ppo_pendulum.zip") + model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0) + model.learn(1) + model.save(path) + # Load with custom object, no warnings + with warnings.catch_warnings(record=True) as record: + model.load(path, custom_objects=dict(learning_rate=lambda _: 1.0), weights_only=False) + assert len(record) == 1 def test_dqn_target_update_interval(tmp_path): # `target_update_interval` should not change when reloading the model. See GH Issue #1373. 
From 62d26b19a2a4dff82c0928bc30457fa6d22edfd9 Mon Sep 17 00:00:00 2001 From: mscs Date: Thu, 18 Apr 2024 00:43:03 -0400 Subject: [PATCH 3/9] Add gymnasium.spaces import in test_save_load.py --- tests/test_save_load.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index cbe241919..747cbe590 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -10,6 +10,7 @@ from copy import deepcopy import gymnasium as gym +from gymnasium.spaces import Box, Discrete import numpy as np import pytest import torch as th @@ -742,13 +743,20 @@ def test_load_torch_weights_only(tmp_path): # Test loading only the torch weights path = str(tmp_path / "ppo_pendulum.zip") model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0) - model.learn(1) + model.learn(10) model.save(path) # Load with custom object, no warnings with warnings.catch_warnings(record=True) as record: model.load(path, custom_objects=dict(learning_rate=lambda _: 1.0), weights_only=False) assert len(record) == 1 + # Load only the weights from a valid model + with warnings.catch_warnings(record=True) as record: + model.load(path, weights_only=True) + assert len(record) == 0 + + # TODO: Negative test case. I can cause this to fail with a saved model. Need to understand how / why. + def test_dqn_target_update_interval(tmp_path): # `target_update_interval` should not change when reloading the model. See GH Issue #1373. 
env = make_vec_env(env_id="CartPole-v1", n_envs=2) From 2d041331d4fd1677cf15cb2dd8b7823d0cc7206d Mon Sep 17 00:00:00 2001 From: Mark S C Smith Date: Thu, 18 Apr 2024 11:01:57 -0400 Subject: [PATCH 4/9] Reproduce issue loading model with weights_only=True when using a learning_rate_schedule with numpy scalars in the function --- tests/test_save_load.py | 44 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 747cbe590..de014ca2c 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -2,10 +2,13 @@ import io import json import os +import math +import pickle import pathlib import tempfile import warnings import zipfile + from collections import OrderedDict from copy import deepcopy @@ -739,6 +742,7 @@ def test_load_invalid_object(tmp_path): PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0)) assert len(record) == 0 + def test_load_torch_weights_only(tmp_path): # Test loading only the torch weights path = str(tmp_path / "ppo_pendulum.zip") @@ -755,7 +759,45 @@ def test_load_torch_weights_only(tmp_path): model.load(path, weights_only=True) assert len(record) == 0 - # TODO: Negative test case. I can cause this to fail with a saved model. Need to understand how / why. + # No pickle error. + def learning_rate_schedule(progress): + rate = 0.0003 + variation = 0.2 * rate * progress + new_rate = rate + variation * math.sin(progress * math.pi * 20) # positive and negative adjustments + return new_rate + + model = PPO( + policy="MlpPolicy", + env="Pendulum-v1", + learning_rate=learning_rate_schedule, + ) + model.save(path) + with warnings.catch_warnings(record=True) as record: + model.load(path, weights_only=True) + assert len(record) == 0 + + # Causes pickle error due to numpy scalars in the learning rate schedule + # _pickle.UnpicklingError: Weights only load failed. 
Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution.Do it only if you get the file from a trusted source. WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar + def learning_rate_schedule(progress): + rate = 0.0003 + variation = 0.2 * rate * progress + new_rate = rate + variation * np.sin(progress * np.pi * 20) + return new_rate + + model = PPO( + policy="MlpPolicy", + env="Pendulum-v1", + learning_rate=learning_rate_schedule, + ) + model.save(path) + + with pytest.raises(pickle.UnpicklingError) as record: + model.load(path, weights_only=True) + + with warnings.catch_warnings(record=True) as record: + model.load(path, weights_only=False) + assert len(record) == 1 + def test_dqn_target_update_interval(tmp_path): # `target_update_interval` should not change when reloading the model. See GH Issue #1373. From 1bc204d7f940d1b61146f4f574baad5538916b39 Mon Sep 17 00:00:00 2001 From: Mark S C Smith Date: Thu, 18 Apr 2024 11:15:54 -0400 Subject: [PATCH 5/9] formatting and commit-checks cleanup --- stable_baselines3/common/save_util.py | 6 +++--- tests/test_save_load.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 5432b4d41..2601c358f 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -431,9 +431,9 @@ def load_from_zip_file( UserWarning, ) if weights_only is False: - warnings.warn( - "Unpickling unsafe objects! Loading full state_dict. See pytorch docs on torch.load for more info." - ) + warnings.warn( + "Unpickling unsafe objects! Loading full state_dict. See pytorch docs on torch.load for more info." + ) if "data" in namelist and load_data: # Load class parameters that are stored # with either JSON or pickle (not PyTorch variables). 
diff --git a/tests/test_save_load.py b/tests/test_save_load.py index de014ca2c..764144615 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,19 +1,17 @@ import base64 import io import json -import os import math -import pickle +import os import pathlib +import pickle import tempfile import warnings import zipfile - from collections import OrderedDict from copy import deepcopy import gymnasium as gym -from gymnasium.spaces import Box, Discrete import numpy as np import pytest import torch as th @@ -776,8 +774,10 @@ def learning_rate_schedule(progress): model.load(path, weights_only=True) assert len(record) == 0 - # Causes pickle error due to numpy scalars in the learning rate schedule - # _pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution.Do it only if you get the file from a trusted source. WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar + # Causes pickle error due to numpy scalars in the learning rate schedule: + # _pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` + # will likely succeed, but it can result in arbitrary code execution.Do it only if you get the file from a + # trusted source. 
WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar def learning_rate_schedule(progress): rate = 0.0003 variation = 0.2 * rate * progress From 88907ceb9496b5564ecb3a0bc971ac905cee3dd5 Mon Sep 17 00:00:00 2001 From: Mark S C Smith Date: Thu, 18 Apr 2024 11:28:53 -0400 Subject: [PATCH 6/9] Update documentation to reflect new parameter --- docs/guide/save_format.rst | 5 +++++ docs/misc/changelog.rst | 1 + 2 files changed, 6 insertions(+) diff --git a/docs/guide/save_format.rst b/docs/guide/save_format.rst index 8bd9aa8ea..0917a1505 100644 --- a/docs/guide/save_format.rst +++ b/docs/guide/save_format.rst @@ -24,6 +24,11 @@ A zip-archived JSON dump, PyTorch state dictionaries and PyTorch variables. The is stored as a JSON file, model parameters and optimizers are serialized with ``torch.save()`` function and these files are stored under a single .zip archive. +Note that if you use unsafe objects in your torch model, ``torch.load()`` will raise an unpickling error. You can +use the "weights_only" argument to adjust whether or not to load unsafe objects using Pickle, but it will issue +a warning if set to False. (e.g.: if learning_rate_schedule contains the scalar np.pi, it will raise an error without +the "weights_only" argument set to False) + Any objects that are not JSON serializable are serialized with cloudpickle and stored as base64-encoded string in the JSON file, along with some information that was stored in the serialization. This allows inspecting stored objects without deserializing the object itself. 
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c1560201c..498922a28 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -38,6 +38,7 @@ Breaking Changes: - For safety, ``torch.load()`` is now called with ``weights_only=True`` when loading torch tensors, policy ``load()`` still uses ``weights_only=False`` as gymnasium imports are required for it to work + This can be overriden using the ``weights_only`` boolean argument in the ``load()`` method in sb3, which will be passed to ``torch.load()`` - When using ``huggingface_sb3``, you will now need to set ``TRUST_REMOTE_CODE=True`` when downloading models from the hub, as ``pickle.load`` is not safe. From c4fc775e69153ca4d63e765f554758a648cb4af3 Mon Sep 17 00:00:00 2001 From: mscs Date: Thu, 18 Apr 2024 19:44:55 -0400 Subject: [PATCH 7/9] Simplify testcase --- tests/test_save_load.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 764144615..d39b67561 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -745,11 +745,10 @@ def test_load_torch_weights_only(tmp_path): # Test loading only the torch weights path = str(tmp_path / "ppo_pendulum.zip") model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0) - model.learn(10) model.save(path) # Load with custom object, no warnings with warnings.catch_warnings(record=True) as record: - model.load(path, custom_objects=dict(learning_rate=lambda _: 1.0), weights_only=False) + model.load(path, weights_only=False) assert len(record) == 1 # Load only the weights from a valid model @@ -757,17 +756,10 @@ def test_load_torch_weights_only(tmp_path): model.load(path, weights_only=True) assert len(record) == 0 - # No pickle error. 
- def learning_rate_schedule(progress): - rate = 0.0003 - variation = 0.2 * rate * progress - new_rate = rate + variation * math.sin(progress * math.pi * 20) # positive and negative adjustments - return new_rate - model = PPO( policy="MlpPolicy", env="Pendulum-v1", - learning_rate=learning_rate_schedule, + learning_rate=math.sin(1), ) model.save(path) with warnings.catch_warnings(record=True) as record: @@ -778,16 +770,12 @@ def learning_rate_schedule(progress): # _pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` # will likely succeed, but it can result in arbitrary code execution.Do it only if you get the file from a # trusted source. WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar - def learning_rate_schedule(progress): - rate = 0.0003 - variation = 0.2 * rate * progress - new_rate = rate + variation * np.sin(progress * np.pi * 20) - return new_rate + model = PPO( policy="MlpPolicy", env="Pendulum-v1", - learning_rate=learning_rate_schedule, + learning_rate=lambda _: np.sin(1), ) model.save(path) From b7d82b63091751c2a5a0f2b1b3f07f7e77188763 Mon Sep 17 00:00:00 2001 From: mscs Date: Thu, 18 Apr 2024 21:21:50 -0400 Subject: [PATCH 8/9] formatting fix --- tests/test_save_load.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index d39b67561..4ca5ed563 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -771,7 +771,6 @@ def test_load_torch_weights_only(tmp_path): # will likely succeed, but it can result in arbitrary code execution.Do it only if you get the file from a # trusted source. 
WeightsUnpickler error: Unsupported class numpy.core.multiarray.scalar - model = PPO( policy="MlpPolicy", env="Pendulum-v1", From 25a38fc1f3d7f23788c5dce380509f4e190d308e Mon Sep 17 00:00:00 2001 From: mscs Date: Thu, 18 Apr 2024 21:25:19 -0400 Subject: [PATCH 9/9] Update changelog with weights_only details --- docs/misc/changelog.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 498922a28..1d069b3ae 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -38,13 +38,13 @@ Breaking Changes: - For safety, ``torch.load()`` is now called with ``weights_only=True`` when loading torch tensors, policy ``load()`` still uses ``weights_only=False`` as gymnasium imports are required for it to work - This can be overriden using the ``weights_only`` boolean argument in the ``load()`` method in sb3, which will be passed to ``torch.load()`` - When using ``huggingface_sb3``, you will now need to set ``TRUST_REMOTE_CODE=True`` when downloading models from the hub, as ``pickle.load`` is not safe. New Features: ^^^^^^^^^^^^^ - Log success rate ``rollout/success_rate`` when available for on policy algorithms (@corentinlger) +- The ``weights_only=True`` default used when loading torch tensors can be overridden using the ``weights_only`` boolean argument in the ``load()`` method in sb3, which will be passed to ``torch.load()`` (@markscsmith) Bug Fixes: ^^^^^^^^^^ @@ -1594,4 +1594,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah +@marekm4 @stagoverflow @rushitnshah @markscsmith