@@ -184,21 +184,18 @@ def get_action_and_value(self, x, action=None):
 
     # ALGO Logic: Storage setup
     obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
-    next_obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
     actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
     logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
     rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    next_dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    next_terminations = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
     values = torch.zeros((args.num_steps, args.num_envs)).to(device)
 
     # TRY NOT TO MODIFY: start the game
     global_step = 0
     start_time = time.time()
-    next_ob, _ = envs.reset(seed=args.seed)
-    next_ob = torch.Tensor(next_ob).to(device)
+    next_obs, _ = envs.reset(seed=args.seed)
+    next_obs = torch.Tensor(next_obs).to(device)
     next_done = torch.zeros(args.num_envs).to(device)
-    next_termination = torch.zeros(args.num_envs).to(device)
 
     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
@@ -209,53 +206,43 @@ def get_action_and_value(self, x, action=None):
 
         for step in range(0, args.num_steps):
             global_step += args.num_envs
+            obs[step] = next_obs
+            dones[step] = next_done
 
-            ob = next_ob
             # ALGO LOGIC: action logic
             with torch.no_grad():
-                action, logprob, _, value = agent.get_action_and_value(ob)
+                action, logprob, _, value = agent.get_action_and_value(next_obs)
+                values[step] = value.flatten()
+            actions[step] = action
+            logprobs[step] = logprob
 
             # TRY NOT TO MODIFY: execute the game and log data.
-            next_ob, reward, next_termination, next_truncation, info = envs.step(action.cpu().numpy())
-
-            # Correct next obervation (for vec gym)
-            real_next_ob = next_ob.copy()
-            for idx, trunc in enumerate(next_truncation):
-                if trunc:
-                    real_next_ob[idx] = info["final_observation"][idx]
-            next_ob = torch.Tensor(next_ob).to(device)
-
-            # Collect trajectory
-            obs[step] = torch.Tensor(ob).to(device)
-            next_obs[step] = torch.Tensor(real_next_ob).to(device)
-            actions[step] = torch.Tensor(action).to(device)
-            logprobs[step] = torch.Tensor(logprob).to(device)
-            values[step] = torch.Tensor(value.flatten()).to(device)
-            next_terminations[step] = torch.Tensor(next_termination).to(device)
-            next_dones[step] = torch.Tensor(np.logical_or(next_termination, next_truncation)).to(device)
+            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
+            next_done = np.logical_or(terminations, truncations)
             rewards[step] = torch.tensor(reward).to(device).view(-1)
+            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
 
-            if "final_info" in info:
-                for info in info["final_info"]:
+            if "final_info" in infos:
+                for info in infos["final_info"]:
                     if info and "episode" in info:
                         print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                         writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                         writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
 
         # bootstrap value if not done
         with torch.no_grad():
-            next_values = torch.zeros_like(values[0]).to(device)
+            next_value = agent.get_value(next_obs).reshape(1, -1)
             advantages = torch.zeros_like(rewards).to(device)
             lastgaelam = 0
             for t in reversed(range(args.num_steps)):
                 if t == args.num_steps - 1:
-                    next_values = agent.get_value(next_obs[t]).flatten()
+                    nextnonterminal = 1.0 - next_done
+                    nextvalues = next_value
                 else:
-                    value_mask = next_dones[t].bool()
-                    next_values[value_mask] = agent.get_value(next_obs[t][value_mask]).flatten()
-                    next_values[~value_mask] = values[t + 1][~value_mask]
-                delta = rewards[t] + args.gamma * next_values * (1 - next_terminations[t]) - values[t]
-                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * (1 - next_dones[t]) * lastgaelam
+                    nextnonterminal = 1.0 - dones[t + 1]
+                    nextvalues = values[t + 1]
+                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
+                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
             returns = advantages + values
 
         # flatten the batch
@@ -363,4 +350,4 @@ def get_action_and_value(self, x, action=None):
             push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")
 
     envs.close()
-    writer.close()
+    writer.close()
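
Note on the revised advantage computation: the added lines restore the storage-then-bootstrap pattern, so the GAE backward pass can be read as a small standalone routine. The sketch below is illustrative only; the helper name compute_gae, its signature, and the default gamma/gae_lambda values are hypothetical, not part of this change.

import torch


def compute_gae(rewards, values, dones, next_value, next_done, gamma=0.99, gae_lambda=0.95):
    # rewards, values, dones: tensors of shape (num_steps, num_envs) filled during the rollout.
    # next_value: critic value for the observation after the last step; next_done: its done flag.
    # dones[t] marks whether obs[t] is the first observation of a new episode.
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done      # bootstrap from the live next observation
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]   # cut the bootstrap at episode boundaries
            nextvalues = values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    returns = advantages + values
    return advantages, returns

In the script itself the same loop runs inline under with torch.no_grad(), with next_value = agent.get_value(next_obs).reshape(1, -1) supplying the bootstrap value.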