6 changes: 4 additions & 2 deletions docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

Release 2.8.0a0 (WIP)
Release 2.8.0a1 (WIP)
--------------------------

Breaking Changes:
@@ -19,7 +19,8 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Do not call ``forward()`` method directly in ``RecurrentPPO``
- Fixed ``RecurrentPPO`` and ``MaskablePPO`` ``forward()`` and ``predict()`` not reshaping actions to the action space shape before clipping them (@immortal-boy)
- Do not call ``forward()`` method directly in ``RecurrentPPO`` (@immortal-boy)

Deprecations:
^^^^^^^^^^^^^
@@ -657,3 +658,4 @@ Contributors:

@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @danielpalen @corentinlger
@immortal-boy
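
For context, a minimal sketch (not part of the diff; ``gymnasium`` and a (2, 2) ``Box`` space are assumed, mirroring the new test env) of why the reshape has to happen before clipping: the flattened network output cannot be clipped against multi-dimensional bounds, while the reshaped array can.

import numpy as np
from gymnasium import spaces

# Hypothetical multi-dimensional Box action space, same shape as the new test env.
action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)

# Flattened network output for one environment: shape (1, 4).
flat_actions = np.random.randn(1, 4).astype(np.float32)

# np.clip(flat_actions, action_space.low, action_space.high) would fail:
# (1, 4) does not broadcast against the (2, 2) bounds.
actions = flat_actions.reshape((-1, *action_space.shape))  # -> (1, 2, 2)
clipped = np.clip(actions, action_space.low, action_space.high)
assert clipped.shape == (1, 2, 2)
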
5 changes: 0 additions & 5 deletions sb3_contrib/common/maskable/distributions.py
@@ -138,19 +138,15 @@ def proba_distribution(
return self

def log_prob(self, actions: th.Tensor) -> th.Tensor:
assert self.distribution is not None, "Must set distribution parameters"
return self.distribution.log_prob(actions)

def entropy(self) -> th.Tensor:
assert self.distribution is not None, "Must set distribution parameters"
return self.distribution.entropy()

def sample(self) -> th.Tensor:
assert self.distribution is not None, "Must set distribution parameters"
return self.distribution.sample()

def mode(self) -> th.Tensor:
assert self.distribution is not None, "Must set distribution parameters"
return th.argmax(self.distribution.probs, dim=1)

def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
@@ -164,7 +160,6 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.
return actions, log_prob

def apply_masking(self, masks: MaybeMasks) -> None:
assert self.distribution is not None, "Must set distribution parameters"
self.distribution.apply_masking(masks)


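
With these asserts gone, calling any of these methods before ``proba_distribution()`` now raises Python's own ``AttributeError`` (``self.distribution`` is still ``None``), which is what the updated tests below check. A minimal sketch of the new behaviour, mirroring the test changes:

import pytest
from sb3_contrib.common.maskable.distributions import MaskableCategoricalDistribution

dist = MaskableCategoricalDistribution(2)

# proba_distribution() has not been called, so self.distribution is None and
# the attribute access inside sample() raises AttributeError.
with pytest.raises(AttributeError):
    dist.sample()
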
3 changes: 2 additions & 1 deletion sb3_contrib/common/maskable/policies.py
@@ -139,6 +139,7 @@ def forward(
distribution.apply_masking(action_masks)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc]
return actions, values, log_prob

def extract_features( # type: ignore[override]
@@ -304,7 +305,7 @@ def predict(
with th.no_grad():
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
# Convert to numpy
actions = actions.cpu().numpy() # type: ignore[assignment]
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[assignment, misc]

if isinstance(self.action_space, spaces.Box):
if self.squash_output:
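
The intended contract after this change, shown as a rough self-contained sketch (the tiny ``MultiDiscreteEnv`` below and its always-valid ``action_masks()`` are illustrative assumptions, not code from the PR): ``predict()`` returns actions shaped like the action space, with the batch axis squeezed for a single observation.

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from sb3_contrib import MaskablePPO


class MultiDiscreteEnv(gym.Env):
    """Minimal env with a MultiDiscrete action space; all actions are always valid."""

    def __init__(self):
        self.observation_space = spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32)
        self.action_space = spaces.MultiDiscrete([2, 3])

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 1.0, False, False, {}

    def action_masks(self):
        # Masks are assumed to be concatenated over the MultiDiscrete dims (2 + 3 entries).
        return np.ones(5, dtype=bool)


model = MaskablePPO("MlpPolicy", MultiDiscreteEnv(), n_steps=64, n_epochs=1).learn(64)
obs, _ = MultiDiscreteEnv().reset()
action, _ = model.predict(obs, deterministic=True)
# MultiDiscrete([2, 3]) has shape (2,), so a single prediction should too.
assert action.shape == (2,)
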
3 changes: 2 additions & 1 deletion sb3_contrib/common/recurrent/policies.py
@@ -253,6 +253,7 @@ def forward(
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape))  # type: ignore[misc]
return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf)

def get_distribution(
@@ -412,7 +413,7 @@ def predict(
states = (states[0].cpu().numpy(), states[1].cpu().numpy())

# Convert to numpy
actions = actions.cpu().numpy()
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[assignment]

if isinstance(self.action_space, spaces.Box):
if self.squash_output:
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
@@ -1 +1 @@
2.8.0a0
2.8.0a1
10 changes: 5 additions & 5 deletions tests/test_distributions.py
@@ -82,19 +82,19 @@ def test_distribution_must_be_initialized(self):

DIMS = 2
dist = MaskableCategoricalDistribution(DIMS)
with pytest.raises(AssertionError):
with pytest.raises(AttributeError):
dist.log_prob(th.randint(DIMS - 1, (1, 3)))

with pytest.raises(AssertionError):
with pytest.raises(AttributeError):
dist.entropy()

with pytest.raises(AssertionError):
with pytest.raises(AttributeError):
dist.sample()

with pytest.raises(AssertionError):
with pytest.raises(AttributeError):
dist.mode()

with pytest.raises(AssertionError):
with pytest.raises(AttributeError):
dist.apply_masking(None)

# But now we can
19 changes: 19 additions & 0 deletions tests/test_lstm.py
@@ -244,3 +244,22 @@ def make_env():
# In CartPole-v1, a non-recurrent policy can easily get >= 450.
# In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
evaluate_policy(model, env, reward_threshold=450)


class MultiDimensionalActionSpaceEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(low=-1, high=1, shape=(10,), dtype=np.float32)
        self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 1, np.random.rand() > 0.8, False, {}


def test_ppo_multi_dimensional_action_space():
    env = MultiDimensionalActionSpaceEnv()
    model = RecurrentPPO("MlpLstmPolicy", env, n_steps=64, n_epochs=2).learn(64)
    evaluate_policy(model, model.get_env())
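
A possible extra assertion for the new test (an assumption, reusing ``env`` and ``model`` from ``test_ppo_multi_dimensional_action_space`` above): check the shape returned by ``predict()`` directly, since that is what the reshape fix is about.

    # predict() should now return an action with the env's (2, 2) shape.
    obs, _ = env.reset()
    action, _ = model.predict(obs, deterministic=True)
    assert action.shape == env.action_space.shape  # (2, 2)
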