File tree Expand file tree Collapse file tree 2 files changed +5
-4
lines changed
Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -144,8 +144,8 @@ def train(self) -> None:
144144 for rollout_data in self .rollout_buffer .get (batch_size = None ):
145145 actions = rollout_data .actions
146146 if isinstance (self .action_space , spaces .Discrete ):
147- # Flatten discrete actions for correct computation of log prob
148- actions = actions .flatten ()
147+ # Convert discrete action from float to long
148+ actions = actions .long (). flatten ()
149149
150150 values , log_prob , entropy = self .policy .evaluate_actions (rollout_data .observations , actions )
151151 values = values .flatten ()
Original file line number Diff line number Diff line change @@ -207,8 +207,9 @@ def train(self) -> None:
207207 for rollout_data in self .rollout_buffer .get (self .batch_size ):
208208 actions = rollout_data .actions
209209 if isinstance (self .action_space , spaces .Discrete ):
210- # Flatten discrete actions for correct computation of log prob
211- actions = rollout_data .actions .flatten ()
210+ # Convert discrete action from float to long
211+ actions = rollout_data .actions .long ().flatten ()
212+
212213 values , log_prob , entropy = self .policy .evaluate_actions (rollout_data .observations , actions )
213214 values = values .flatten ()
214215 # Normalize advantage
You can’t perform that action at this time.
0 commit comments