Skip to content
3 changes: 2 additions & 1 deletion sb3_contrib/common/maskable/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def forward(
distribution.apply_masking(action_masks)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc]
return actions, values, log_prob

def extract_features( # type: ignore[override]
Expand Down Expand Up @@ -304,7 +305,7 @@ def predict(
with th.no_grad():
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
# Convert to numpy
actions = actions.cpu().numpy() # type: ignore[assignment]
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[assignment]

if isinstance(self.action_space, spaces.Box):
if self.squash_output:
Expand Down
3 changes: 2 additions & 1 deletion sb3_contrib/common/recurrent/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def forward(
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc]
return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf)

def get_distribution(
Expand Down Expand Up @@ -413,7 +414,7 @@ def predict(
states = (states[0].cpu().numpy(), states[1].cpu().numpy())

# Convert to numpy
actions = actions.cpu().numpy()
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[assignment]

if isinstance(self.action_space, spaces.Box):
if self.squash_output:
Expand Down