Skip to content

Commit c2d532c

Browse files
committed
Revert "Remove cast to long"
This reverts commit 216d757.
1 parent 216d757 commit c2d532c

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

stable_baselines3/a2c/a2c.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def train(self) -> None:
144144
for rollout_data in self.rollout_buffer.get(batch_size=None):
145145
actions = rollout_data.actions
146146
if isinstance(self.action_space, spaces.Discrete):
147-
# Flatten discrete actions for correct computation of log prob
148-
actions = actions.flatten()
147+
# Convert discrete action from float to long
148+
actions = actions.long().flatten()
149149

150150
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
151151
values = values.flatten()

stable_baselines3/ppo/ppo.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,9 @@ def train(self) -> None:
207207
for rollout_data in self.rollout_buffer.get(self.batch_size):
208208
actions = rollout_data.actions
209209
if isinstance(self.action_space, spaces.Discrete):
210-
# Flatten discrete actions for correct computation of log prob
211-
actions = rollout_data.actions.flatten()
210+
# Convert discrete action from float to long
211+
actions = rollout_data.actions.long().flatten()
212+
212213
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
213214
values = values.flatten()
214215
# Normalize advantage

0 commit comments

Comments
 (0)