@@ -209,22 +209,22 @@ def get_action_and_value(self, x, action=None):

         for step in range(0, args.num_steps):
             global_step += args.num_envs
-
+
             ob = next_ob
             # ALGO LOGIC: action logic
             with torch.no_grad():
                 action, logprob, _, value = agent.get_action_and_value(ob)

             # TRY NOT TO MODIFY: execute the game and log data.
             next_ob, reward, next_termination, next_truncation, info = envs.step(action.cpu().numpy())
-
+
             # Correct next observation (for vec gym)
             real_next_ob = next_ob.copy()
             for idx, trunc in enumerate(next_truncation):
                 if trunc:
                     real_next_ob[idx] = info["final_observation"][idx]
             next_ob = torch.Tensor(next_ob).to(device)
-
+
             # Collect trajectory
             obs[step] = torch.Tensor(ob).to(device)
             next_obs[step] = torch.Tensor(real_next_ob).to(device)
@@ -234,7 +234,7 @@ def get_action_and_value(self, x, action=None):
             next_terminations[step] = torch.Tensor(next_termination).to(device)
             next_dones[step] = torch.Tensor(np.logical_or(next_termination, next_truncation)).to(device)
             rewards[step] = torch.tensor(reward).to(device).view(-1)
-
+
             if "final_info" in info:
                 for info in info["final_info"]:
                     if info and "episode" in info:
@@ -253,7 +253,7 @@ def get_action_and_value(self, x, action=None):
                 else:
                     value_mask = next_dones[t].bool()
                     next_values[value_mask] = agent.get_value(next_obs[t][value_mask]).flatten()
-                    next_values[~value_mask] = values[t + 1][~value_mask]
+                    next_values[~value_mask] = values[t + 1][~value_mask]
                 delta = rewards[t] + args.gamma * next_values * (1 - next_terminations[t]) - values[t]
                 advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * (1 - next_dones[t]) * lastgaelam
             returns = advantages + values
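
The substantive logic this diff touches is a GAE computation that treats truncation differently from termination: the bootstrap value is kept when an episode is cut off by a time limit (truncation) but zeroed when it actually terminates, while either kind of episode end stops the recursive accumulation of advantage. Below is a minimal, self-contained sketch of that idea, not the commit's exact code: it assumes the rollout buffers (`rewards`, `values`, `next_values`, `next_terminations`, `next_dones`) are already filled `(num_steps, num_envs)` tensors, whereas the code above evaluates `agent.get_value` lazily for truncated steps inside the backward loop.

```python
import torch


def compute_gae(rewards, values, next_values, next_terminations, next_dones,
                gamma=0.99, gae_lambda=0.95):
    """GAE that bootstraps through truncations but not terminations.

    All inputs are (num_steps, num_envs) tensors; `next_values` holds the
    value estimate of the (real) observation that followed each step, which
    is an assumed precomputation for this sketch.
    """
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.shape[0])):
        # Terminated episodes have no future value; truncated ones still
        # bootstrap from the value of their final observation.
        delta = rewards[t] + gamma * next_values[t] * (1 - next_terminations[t]) - values[t]
        # Any episode end (termination or truncation) stops the accumulation
        # of advantage from later timesteps.
        lastgaelam = delta + gamma * gae_lambda * (1 - next_dones[t]) * lastgaelam
        advantages[t] = lastgaelam
    returns = advantages + values
    return advantages, returns
```

The two masks mirror the `delta` and `advantages[t] = lastgaelam = ...` lines in the diff above: `(1 - next_terminations[t])` gates the bootstrap value, while `(1 - next_dones[t])` gates the recursion.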