File tree Expand file tree Collapse file tree 2 files changed +4
-5
lines changed
Expand file tree Collapse file tree 2 files changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -144,8 +144,8 @@ def train(self) -> None:
144144 for rollout_data in self .rollout_buffer .get (batch_size = None ):
145145 actions = rollout_data .actions
146146 if isinstance (self .action_space , spaces .Discrete ):
147- # Convert discrete action from float to long
148- actions = actions .long (). flatten ()
147+ # Flatten discrete actions for correct computation of log prob
148+ actions = actions .flatten ()
149149
150150 values , log_prob , entropy = self .policy .evaluate_actions (rollout_data .observations , actions )
151151 values = values .flatten ()
Original file line number Diff line number Diff line change @@ -207,9 +207,8 @@ def train(self) -> None:
207207 for rollout_data in self .rollout_buffer .get (self .batch_size ):
208208 actions = rollout_data .actions
209209 if isinstance (self .action_space , spaces .Discrete ):
210- # Convert discrete action from float to long
211- actions = rollout_data .actions .long ().flatten ()
212-
210+ # Flatten discrete actions for correct computation of log prob
211+ actions = rollout_data .actions .flatten ()
213212 values , log_prob , entropy = self .policy .evaluate_actions (rollout_data .observations , actions )
214213 values = values .flatten ()
215214 # Normalize advantage
You can’t perform that action at this time.
0 commit comments