Skip to content

Commit 216d757

Browse files
committed
Remove cast to long
1 parent 9264123 commit 216d757

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

stable_baselines3/a2c/a2c.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def train(self) -> None:
144144
for rollout_data in self.rollout_buffer.get(batch_size=None):
145145
actions = rollout_data.actions
146146
if isinstance(self.action_space, spaces.Discrete):
147-
# Convert discrete action from float to long
148-
actions = actions.long().flatten()
147+
# Flatten discrete actions for correct computation of log prob
148+
actions = actions.flatten()
149149

150150
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
151151
values = values.flatten()

stable_baselines3/ppo/ppo.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,8 @@ def train(self) -> None:
207207
for rollout_data in self.rollout_buffer.get(self.batch_size):
208208
actions = rollout_data.actions
209209
if isinstance(self.action_space, spaces.Discrete):
210-
# Convert discrete action from float to long
211-
actions = rollout_data.actions.long().flatten()
212-
210+
# Flatten discrete actions for correct computation of log prob
211+
actions = rollout_data.actions.flatten()
213212
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
214213
values = values.flatten()
215214
# Normalize advantage

0 commit comments

Comments
 (0)