Commit e03aa1e
Add support for multi-dimensional action spaces to PPO variants (#318)
* Fix RecurrentPPO and MaskablePPO forward and predict do not reshape action before clip it (#317)
* test: add unit test for multi-dimensional action spaces to PPO variants
* reformat tests/test_lstm.py
* make type
* update changelog
* Update test and version
* Asserts are not needed anymore
* Update changelog.rst
* Fix tests

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 06c6e62 commit e03aa1e
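Here, "multi-dimensional" means an action space whose shape has more than one axis, for example a gymnasium Box of shape (2, 2) rather than the usual flat (n,). A minimal sketch (not part of the commit) of such a space:

import numpy as np
from gymnasium import spaces

# A hypothetical multi-dimensional continuous action space: (2, 2) instead of (n,).
space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
print(space.shape)           # (2, 2)
print(space.sample().shape)  # (2, 2); low and high are also (2, 2) arrays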

File tree

7 files changed: +33 −15 lines


docs/misc/changelog.rst

Lines changed: 4 additions & 2 deletions

@@ -3,7 +3,7 @@
 Changelog
 ==========

-Release 2.8.0a0 (WIP)
+Release 2.8.0a1 (WIP)
 --------------------------

 Breaking Changes:

@@ -19,7 +19,8 @@ New Features:

 Bug Fixes:
 ^^^^^^^^^^
-- Do not call ``forward()`` method directly in ``RecurrentPPO``
+- Fix ``RecurrentPPO`` and ``MaskablePPO`` ``forward()`` and ``predict()`` not reshaping actions before clipping them (@immortal-boy)
+- Do not call ``forward()`` method directly in ``RecurrentPPO`` (@immortal-boy)

 Deprecations:
 ^^^^^^^^^^^^^

@@ -657,3 +658,4 @@ Contributors:

 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
 @mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @danielpalen @corentinlger
+@immortal-boy

sb3_contrib/common/maskable/distributions.py

Lines changed: 0 additions & 5 deletions

@@ -138,19 +138,15 @@ def proba_distribution(
         return self

     def log_prob(self, actions: th.Tensor) -> th.Tensor:
-        assert self.distribution is not None, "Must set distribution parameters"
         return self.distribution.log_prob(actions)

     def entropy(self) -> th.Tensor:
-        assert self.distribution is not None, "Must set distribution parameters"
         return self.distribution.entropy()

     def sample(self) -> th.Tensor:
-        assert self.distribution is not None, "Must set distribution parameters"
         return self.distribution.sample()

     def mode(self) -> th.Tensor:
-        assert self.distribution is not None, "Must set distribution parameters"
         return th.argmax(self.distribution.probs, dim=1)

     def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:

@@ -164,7 +160,6 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.
         return actions, log_prob

     def apply_masking(self, masks: MaybeMasks) -> None:
-        assert self.distribution is not None, "Must set distribution parameters"
         self.distribution.apply_masking(masks)
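The removed asserts guarded against using the distribution before proba_distribution() had set its parameters. With them gone, the same misuse now surfaces as an AttributeError from the still-unset self.distribution, which is what the updated tests/test_distributions.py below expects. A minimal sketch of the new behaviour (assuming the constructor and calls exercised in that test):

import torch as th

from sb3_contrib.common.maskable.distributions import MaskableCategoricalDistribution

dist = MaskableCategoricalDistribution(2)

# Using the distribution before proba_distribution() now raises AttributeError
# ('NoneType' object has no attribute 'sample') instead of AssertionError.
try:
    dist.sample()
except AttributeError as exc:
    print(f"not initialized yet: {exc}")

# Once the logits are set, everything works as before.
dist.proba_distribution(th.randn(1, 2))
print(dist.sample().shape)  # torch.Size([1])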

sb3_contrib/common/maskable/policies.py

Lines changed: 2 additions & 1 deletion

@@ -139,6 +139,7 @@ def forward(
         distribution.apply_masking(action_masks)
         actions = distribution.get_actions(deterministic=deterministic)
         log_prob = distribution.log_prob(actions)
+        actions = actions.reshape((-1, *self.action_space.shape))  # type: ignore[misc]
         return actions, values, log_prob

     def extract_features(  # type: ignore[override]

@@ -304,7 +305,7 @@ def predict(
         with th.no_grad():
             actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
         # Convert to numpy
-        actions = actions.cpu().numpy()  # type: ignore[assignment]
+        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))  # type: ignore[assignment, misc]

         if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
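The functional change in both hunks is the added reshape: actions come out of the distribution flattened along the action dimensions, and reshape((-1, *self.action_space.shape)) restores the batch to the space's full shape before the actions are returned (and, in predict(), converted to numpy). A small sketch of what that one-liner does, with a hypothetical (2, 2) action shape:

import torch as th

# Hypothetical multi-dimensional action shape, e.g. from a space with shape (2, 2).
space_shape = (2, 2)

# A batch of 3 actions as the distribution produces them: flattened to 4 values each.
flat_actions = th.randn(3, 4)

# The reshape added in forward()/predict(): keep the batch axis, restore the space shape.
actions = flat_actions.reshape((-1, *space_shape))
print(actions.shape)  # torch.Size([3, 2, 2])

For the 0-D and 1-D spaces that were already supported, self.action_space.shape is () or (n,), so the same expression should leave the previous behaviour unchanged.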

sb3_contrib/common/recurrent/policies.py

Lines changed: 2 additions & 1 deletion

@@ -253,6 +253,7 @@ def forward(
         distribution = self._get_action_dist_from_latent(latent_pi)
         actions = distribution.get_actions(deterministic=deterministic)
         log_prob = distribution.log_prob(actions)
+        actions = actions.reshape((-1, *self.action_space.shape))  # type: ignore[misc]
         return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf)

     def get_distribution(

@@ -412,7 +413,7 @@ def predict(
             states = (states[0].cpu().numpy(), states[1].cpu().numpy())

         # Convert to numpy
-        actions = actions.cpu().numpy()
+        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))  # type: ignore[assignment]

         if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
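These are the same two call sites as in the maskable policy. The reshape before the Box handling shown in the trailing context matters because a multi-dimensional Box carries its full shape in low and high, so clipping a flattened batch of actions against those bounds does not broadcast. A sketch of the failure mode the reshape avoids (hypothetical shapes, not code from the repository):

import numpy as np
from gymnasium import spaces

space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)

# Without the reshape, predict() would hand a flattened batch, e.g. (1, 4), to np.clip.
flat_actions = np.random.uniform(-1, 1, size=(1, 4)).astype(np.float32)
try:
    np.clip(flat_actions, space.low, space.high)  # (1, 4) actions vs (2, 2) bounds
except ValueError as exc:
    print(f"clip fails: {exc}")

# With the reshape, the bounds broadcast as intended.
clipped = np.clip(flat_actions.reshape((-1, *space.shape)), space.low, space.high)
print(clipped.shape)  # (1, 2, 2)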

sb3_contrib/version.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-2.8.0a0
+2.8.0a1

tests/test_distributions.py

Lines changed: 5 additions & 5 deletions

@@ -82,19 +82,19 @@ def test_distribution_must_be_initialized(self):

         DIMS = 2
         dist = MaskableCategoricalDistribution(DIMS)
-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.log_prob(th.randint(DIMS - 1, (1, 3)))

-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.entropy()

-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.sample()

-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.mode()

-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.apply_masking(None)

         # But now we can

tests/test_lstm.py

Lines changed: 19 additions & 0 deletions

@@ -244,3 +244,22 @@ def make_env():
     # In CartPole-v1, a non-recurrent policy can easily get >= 450.
     # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
     evaluate_policy(model, env, reward_threshold=450)
+
+
+class MultiDimensionalActionSpaceEnv(gym.Env):
+    def __init__(self):
+        self.observation_space = spaces.Box(low=-1, high=1, shape=(10,), dtype=np.float32)
+        self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
+
+    def reset(self, seed=None, options=None):
+        super().reset(seed=seed)
+        return self.observation_space.sample(), {}
+
+    def step(self, action):
+        return self.observation_space.sample(), 1, np.random.rand() > 0.8, False, {}
+
+
+def test_ppo_multi_dimensional_action_space():
+    env = MultiDimensionalActionSpaceEnv()
+    model = RecurrentPPO("MlpLstmPolicy", env, n_steps=64, n_epochs=2).learn(64)
+    evaluate_policy(model, model.get_env())
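The new test only checks that training and evaluation run end to end. A complementary check one might add (hypothetical, not part of this commit) is that predict() now returns actions with the env's full (2, 2) shape, which is exactly what the added reshape provides:

from sb3_contrib import RecurrentPPO

def check_predict_action_shape():
    # Reuses the MultiDimensionalActionSpaceEnv defined in tests/test_lstm.py above.
    env = MultiDimensionalActionSpaceEnv()
    model = RecurrentPPO("MlpLstmPolicy", env, n_steps=64, n_epochs=2).learn(64)
    obs, _ = env.reset()
    action, _ = model.predict(obs, deterministic=True)
    # With the fix, the single (unvectorized) action matches the space shape.
    assert action.shape == env.action_space.shape  # (2, 2)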
