Skip to content

Commit d8a1f77

Browse files
authored
Merge branch 'master' into hybrid_PPO
2 parents b419142 + 124f167 commit d8a1f77

File tree

8 files changed

+32
-10
lines changed

8 files changed

+32
-10
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
22

3-
[![CI](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/workflows/CI/badge.svg)](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/actions) [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
3+
[![CI](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/actions/workflows/ci.yml/badge.svg)](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/actions/workflows/ci.yml) [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
44

55
# Stable-Baselines3 - Contrib (SB3-Contrib)
66

docs/misc/changelog.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,28 @@
33
Changelog
44
==========
55

6+
Release 2.7.1a3 (WIP)
7+
--------------------------
8+
9+
Breaking Changes:
10+
^^^^^^^^^^^^^^^^^
11+
12+
New Features:
13+
^^^^^^^^^^^^^
14+
15+
Bug Fixes:
16+
^^^^^^^^^^
17+
- Fix tensorboard log name for ``MaskablePPO``
18+
19+
Deprecations:
20+
^^^^^^^^^^^^^
21+
22+
Others:
23+
^^^^^^^
24+
25+
Documentation:
26+
^^^^^^^^^^^^^^
27+
628
Release 2.7.0 (2025-07-25)
729
--------------------------
830

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ exclude = """(?x)(
3535

3636
[tool.pytest.ini_options]
3737
# Deterministic ordering for tests; useful for pytest-xdist.
38-
env = ["PYTHONHASHSEED=0"]
38+
# env = ["PYTHONHASHSEED=0"]
3939

4040
filterwarnings = [
4141
# Tensorboard warnings

sb3_contrib/common/envs/invalid_actions_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
2020
dim = 1
2121
assert n_invalid_actions < dim, f"Too many invalid actions: {n_invalid_actions} < {dim}"
2222

23-
space = spaces.Discrete(dim)
23+
space = spaces.Discrete(dim) # type: ignore[var-annotated]
2424
self.n_invalid_actions = n_invalid_actions
2525
self.possible_actions = np.arange(space.n, dtype=int)
2626
self.invalid_actions: list[int] = []

sb3_contrib/common/recurrent/policies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def predict_values(
296296
features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor)
297297

298298
if self.lstm_critic is not None:
299-
latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
299+
latent_vf, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic)
300300
elif self.shared_lstm:
301301
# Use LSTM from the actor
302302
latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor)

sb3_contrib/ppo_mask/ppo_mask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def collect_rollouts(
221221
while n_steps < n_rollout_steps:
222222
with th.no_grad():
223223
# Convert to pytorch tensor or to TensorDict
224-
obs_tensor = obs_as_tensor(self._last_obs, self.device)
224+
obs_tensor = obs_as_tensor(self._last_obs, self.device) # type: ignore[arg-type]
225225

226226
# This is the only change related to invalid action masking
227227
if use_masking:
@@ -431,7 +431,7 @@ def learn( # type: ignore[override]
431431
total_timesteps: int,
432432
callback: MaybeCallback = None,
433433
log_interval: int = 1,
434-
tb_log_name: str = "PPO",
434+
tb_log_name: str = "MaskablePPO",
435435
reset_num_timesteps: bool = True,
436436
use_masking: bool = True,
437437
progress_bar: bool = False,

sb3_contrib/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.7.0
1+
2.7.1a3

tests/wrappers/test_action_masker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
class IdentityEnvDiscrete(IdentityEnv):
9-
def __init__(self, dim: int = 1, ep_length: int = 100):
9+
def __init__(self, dim=1, ep_length=100):
1010
"""
1111
Identity environment for testing purposes
1212
@@ -17,12 +17,12 @@ def __init__(self, dim: int = 1, ep_length: int = 100):
1717
self.useless_property = 1
1818
super().__init__(ep_length=ep_length, space=space)
1919

20-
def _action_masks(self) -> list[int]:
20+
def _action_masks(self): # -> list[bool]
2121
assert isinstance(self.action_space, spaces.Discrete)
2222
return [i == self.state for i in range(self.action_space.n)]
2323

2424

25-
def action_mask_fn(env: IdentityEnvDiscrete) -> list[int]:
25+
def action_mask_fn(env): # -> list[int]
2626
assert isinstance(env.action_space, spaces.Discrete)
2727
return [i == env.state for i in range(env.action_space.n)]
2828

0 commit comments

Comments
 (0)