Use has_attr for detecting masking support, fixes several issues (#276)

* Update SB3 dependency

* Use `has_attr` for detecting masking, fixes several issues
araffin authored Feb 4, 2025
1 parent c070fc2 commit e0986a2
Showing 5 changed files with 27 additions and 60 deletions.
66 changes: 18 additions & 48 deletions docs/misc/changelog.rst
@@ -3,22 +3,20 @@
 Changelog
 ==========
 
-Release 2.5.0 (2025-01-27)
+
+Release 2.6.0a0 (WIP)
 --------------------------
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
-- Upgraded to PyTorch 2.3.0
-- Dropped Python 3.8 support
-- Upgraded to Stable-Baselines3 >= 2.5.0
+- Upgraded to Stable-Baselines3 >= 2.6.0
 
 New Features:
 ^^^^^^^^^^^^^
-- Added Python 3.12 support
-- Added Numpy v2.0 support
 
 Bug Fixes:
 ^^^^^^^^^^
+- Fixed issues with ``SubprocVecEnv`` and ``MaskablePPO`` by using ``vec_env.has_attr()`` (pickling issues, mask function not present)
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -29,6 +27,20 @@ Others:
 Documentation:
 ^^^^^^^^^^^^^^
 
+Release 2.5.0 (2025-01-27)
+--------------------------
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Upgraded to PyTorch 2.3.0
+- Dropped Python 3.8 support
+- Upgraded to Stable-Baselines3 >= 2.5.0
+
+New Features:
+^^^^^^^^^^^^^
+- Added Python 3.12 support
+- Added Numpy v2.0 support
+
 Release 2.4.0 (2024-11-18)
 --------------------------
 
@@ -51,18 +63,12 @@ Bug Fixes:
 - Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False)
 - Fixed loading QRDQN changes `target_update_interval` (@jak3122)
 
-Deprecations:
-^^^^^^^^^^^^^
-
 Others:
 ^^^^^^^
 - Updated PyTorch version on CI to 2.3.1
 - Remove unnecessary SDE noise resampling in PPO/TRPO update
 - Switched to uv to download packages on GitHub CI
 
-Documentation:
-^^^^^^^^^^^^^^
-
 
 Release 2.3.0 (2024-03-31)
 --------------------------
@@ -88,13 +94,6 @@ New Features:
 - Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO
 - Log success rate ``rollout/success_rate`` when available for on policy algorithms
 
-
-Bug Fixes:
-^^^^^^^^^^
-
-Deprecations:
-^^^^^^^^^^^^^
-
 Others:
 ^^^^^^^
 - Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl)
@@ -121,20 +120,11 @@ New Features:
 - Added ``set_options`` for ``AsyncEval``
 - Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to TRPO
 
-Bug Fixes:
-^^^^^^^^^^
-
-Deprecations:
-^^^^^^^^^^^^^
-
 Others:
 ^^^^^^^
 - Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
 - Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
 
-Documentation:
-^^^^^^^^^^^^^^
-
 
 Release 2.1.0 (2023-08-17)
 --------------------------
@@ -153,14 +143,6 @@ Bug Fixes:
 ^^^^^^^^^^
 - Fixed MaskablePPO ignoring ``stats_window_size`` argument
 
-Deprecations:
-^^^^^^^^^^^^^
-
-Others:
-^^^^^^^
-
-Documentation:
-^^^^^^^^^^^^^^
 
 
 Release 2.0.0 (2023-06-22)
@@ -179,15 +161,11 @@ Breaking Changes:
 - Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package (@carlosluis, @arjun-kg, @tlpss)
 - Upgraded to Stable-Baselines3 >= 2.0.0
 
-New Features:
-^^^^^^^^^^^^^
 
 Bug Fixes:
 ^^^^^^^^^^
 - Fixed QRDQN update interval for multi envs
 
-Deprecations:
-^^^^^^^^^^^^^
 
 Others:
 ^^^^^^^
@@ -220,11 +198,6 @@ New Features:
 ^^^^^^^^^^^^^
 - Added ``stats_window_size`` argument to control smoothing in rollout logging (@jonasreiher)
 
-Bug Fixes:
-^^^^^^^^^^
-
-Deprecations:
-^^^^^^^^^^^^^
 
 Others:
 ^^^^^^^
@@ -300,9 +273,6 @@ New Features:
 ^^^^^^^^^^^^^
 - Added ``progress_bar`` argument in the ``learn()`` method, displayed using TQDM and rich packages
 
-Bug Fixes:
-^^^^^^^^^^
-
 Deprecations:
 ^^^^^^^^^^^^^
 - Deprecate parameters ``eval_env``, ``eval_freq`` and ``create_eval_env``
7 changes: 1 addition & 6 deletions sb3_contrib/common/maskable/utils.py
@@ -28,12 +28,7 @@ def is_masking_supported(env: GymEnv) -> bool:
     """
 
     if isinstance(env, VecEnv):
-        try:
-            # TODO: add VecEnv.has_attr()
-            env.get_attr(EXPECTED_METHOD_NAME)
-            return True
-        except AttributeError:
-            return False
+        return env.has_attr(EXPECTED_METHOD_NAME)
     else:
         try:
             env.get_wrapper_attr(EXPECTED_METHOD_NAME)
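
Note on the hunk above: the difference between fetching an attribute and asking whether it exists can be illustrated without SB3. The following is a minimal sketch under stated assumptions, not the library's implementation. `ToyMaskedEnv`, `ToyPlainEnv`, `reply_with_attr`, `reply_with_flag` and `survives_pickling` are hypothetical names standing in for an environment living in a worker process and for the reply that worker would send back to the parent. It shows one plausible reading of the "pickling issues" mentioned in the changelog entry: sending the attribute itself forces the bound method, and with it the whole env, through pickle, while sending a boolean does not.

```python
import pickle
import threading


class ToyMaskedEnv:
    """Hypothetical env that supports masking and holds an unpicklable resource."""

    def __init__(self):
        self._lock = threading.Lock()  # locks cannot be pickled

    def action_masks(self):
        return [True, False]


class ToyPlainEnv:
    """Hypothetical env without an ``action_masks`` method."""


def reply_with_attr(env, name):
    """get_attr-style reply: serialize the attribute itself for the parent process."""
    return pickle.dumps(getattr(env, name))


def reply_with_flag(env, name):
    """has_attr-style reply: serialize only whether the attribute exists."""
    return pickle.dumps(hasattr(env, name))


def survives_pickling(reply_fn, env, name):
    """Return True if the worker's reply can be serialized at all."""
    try:
        reply_fn(env, name)
        return True
    except (TypeError, AttributeError, pickle.PicklingError):
        # pickle signals unpicklable objects with any of these exception types
        return False


masked, plain = ToyMaskedEnv(), ToyPlainEnv()
# The bound method drags the env (and its lock) into the pickle, which fails.
print(survives_pickling(reply_with_attr, masked, "action_masks"))
# A boolean reply is always serializable, whether or not masking is supported.
print(pickle.loads(reply_with_flag(masked, "action_masks")))
print(pickle.loads(reply_with_flag(plain, "action_masks")))
```

Since `is_masking_supported` only needs a yes/no answer, a boolean reply is sufficient for it, which is presumably why the check in the hunk above collapses to a single call.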
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
@@ -1 +1 @@
-2.5.0
+2.6.0a0
2 changes: 1 addition & 1 deletion setup.py
@@ -67,7 +67,7 @@
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=2.5.0,<3.0",
+        "stable_baselines3>=2.6.0a0,<3.0",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",
10 changes: 6 additions & 4 deletions tests/test_invalid_actions.py
@@ -9,6 +9,7 @@
 from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.policies import ActorCriticPolicy
+from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
 
 from sb3_contrib import MaskablePPO
 from sb3_contrib.common.envs import InvalidActionEnvDiscrete, InvalidActionEnvMultiBinary, InvalidActionEnvMultiDiscrete
@@ -151,18 +152,19 @@ def test_masked_evaluation():
     assert masked_avg_rew > unmasked_avg_rew
 
 
-def test_supports_multi_envs():
+@pytest.mark.parametrize("vec_env_cls", [SubprocVecEnv, DummyVecEnv])
+def test_supports_multi_envs(vec_env_cls):
     """
     Learning and evaluation works with VecEnvs
     """
 
-    env = make_vec_env(make_env, n_envs=2)
+    env = make_vec_env(make_env, n_envs=2, vec_env_cls=vec_env_cls)
     assert is_masking_supported(env)
     model = MaskablePPO("MlpPolicy", env, n_steps=256, gamma=0.4, seed=32, verbose=1)
     model.learn(100)
     evaluate_policy(model, env, warn=False)
 
-    env = make_vec_env(IdentityEnv, n_envs=2, env_kwargs={"dim": 2})
+    env = make_vec_env(IdentityEnv, n_envs=2, env_kwargs={"dim": 2}, vec_env_cls=vec_env_cls)
     assert not is_masking_supported(env)
     model = MaskablePPO("MlpPolicy", env, n_steps=256, gamma=0.4, seed=32, verbose=1)
     with pytest.raises(ValueError):
@@ -224,7 +226,7 @@ def test_discrete_action_space_required():
     """
 
     env = IdentityEnvBox()
-    with pytest.raises(AssertionError):
+    with pytest.raises(AssertionError, match="The algorithm only supports"):
         MaskablePPO("MlpPolicy", env)
 
 
