
Commit 625aba7

Author: Anonymous

Separate Implementation for Handling Truncation with PPO
1 parent 31a91a2 commit 625aba7

File tree

2 files changed: +388 additions, −35 deletions


cleanrl/ppo_continuous_action.py

Lines changed: 22 additions & 35 deletions
@@ -184,21 +184,18 @@ def get_action_and_value(self, x, action=None):
 
     # ALGO Logic: Storage setup
     obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
-    next_obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
     actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
     logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
     rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    next_dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    next_terminations = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
     values = torch.zeros((args.num_steps, args.num_envs)).to(device)
 
     # TRY NOT TO MODIFY: start the game
     global_step = 0
     start_time = time.time()
-    next_ob, _ = envs.reset(seed=args.seed)
-    next_ob = torch.Tensor(next_ob).to(device)
+    next_obs, _ = envs.reset(seed=args.seed)
+    next_obs = torch.Tensor(next_obs).to(device)
     next_done = torch.zeros(args.num_envs).to(device)
-    next_termination = torch.zeros(args.num_envs).to(device)
 
     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
@@ -209,53 +206,43 @@ def get_action_and_value(self, x, action=None):
 
         for step in range(0, args.num_steps):
             global_step += args.num_envs
+            obs[step] = next_obs
+            dones[step] = next_done
 
-            ob = next_ob
             # ALGO LOGIC: action logic
             with torch.no_grad():
-                action, logprob, _, value = agent.get_action_and_value(ob)
+                action, logprob, _, value = agent.get_action_and_value(next_obs)
+                values[step] = value.flatten()
+            actions[step] = action
+            logprobs[step] = logprob
 
             # TRY NOT TO MODIFY: execute the game and log data.
-            next_ob, reward, next_termination, next_truncation, info = envs.step(action.cpu().numpy())
-
-            # Correct next obervation (for vec gym)
-            real_next_ob = next_ob.copy()
-            for idx, trunc in enumerate(next_truncation):
-                if trunc:
-                    real_next_ob[idx] = info["final_observation"][idx]
-            next_ob = torch.Tensor(next_ob).to(device)
-
-            # Collect trajectory
-            obs[step] = torch.Tensor(ob).to(device)
-            next_obs[step] = torch.Tensor(real_next_ob).to(device)
-            actions[step] = torch.Tensor(action).to(device)
-            logprobs[step] = torch.Tensor(logprob).to(device)
-            values[step] = torch.Tensor(value.flatten()).to(device)
-            next_terminations[step] = torch.Tensor(next_termination).to(device)
-            next_dones[step] = torch.Tensor(np.logical_or(next_termination, next_truncation)).to(device)
+            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
+            next_done = np.logical_or(terminations, truncations)
             rewards[step] = torch.tensor(reward).to(device).view(-1)
+            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
 
-            if "final_info" in info:
-                for info in info["final_info"]:
+            if "final_info" in infos:
+                for info in infos["final_info"]:
                     if info and "episode" in info:
                         print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                         writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                         writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
 
         # bootstrap value if not done
         with torch.no_grad():
-            next_values = torch.zeros_like(values[0]).to(device)
+            next_value = agent.get_value(next_obs).reshape(1, -1)
             advantages = torch.zeros_like(rewards).to(device)
             lastgaelam = 0
             for t in reversed(range(args.num_steps)):
                 if t == args.num_steps - 1:
-                    next_values = agent.get_value(next_obs[t]).flatten()
+                    nextnonterminal = 1.0 - next_done
+                    nextvalues = next_value
                 else:
-                    value_mask = next_dones[t].bool()
-                    next_values[value_mask] = agent.get_value(next_obs[t][value_mask]).flatten()
-                    next_values[~value_mask] = values[t + 1][~value_mask]
-                delta = rewards[t] + args.gamma * next_values * (1 - next_terminations[t]) - values[t]
-                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * (1 - next_dones[t]) * lastgaelam
+                    nextnonterminal = 1.0 - dones[t + 1]
+                    nextvalues = values[t + 1]
+                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
+                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
             returns = advantages + values
 
         # flatten the batch
@@ -363,4 +350,4 @@ def get_action_and_value(self, x, action=None):
             push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")
 
     envs.close()
-    writer.close()
+    writer.close()
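
For context, the lines deleted above implemented truncation-aware bootstrapping directly in this file: on truncation, the value of the real final observation (taken from info["final_observation"]) is still bootstrapped into the return, while only a true termination zeroes it out. A minimal sketch of that GAE pass, reusing the deleted code's buffer names (obs-corrected next_obs, next_dones, next_terminations); the .clone() is a small simplification of mine, not part of the commit:

# Sketch of the truncation-aware GAE pass removed by this commit.
with torch.no_grad():
    advantages = torch.zeros_like(rewards).to(device)
    lastgaelam = 0
    for t in reversed(range(args.num_steps)):
        if t == args.num_steps - 1:
            # re-evaluate the stored, truncation-corrected next observation
            next_values = agent.get_value(next_obs[t]).flatten()
        else:
            # at episode boundaries, bootstrap from the corrected next observation;
            # otherwise reuse the value already computed for step t + 1
            value_mask = next_dones[t].bool()
            next_values = values[t + 1].clone()
            next_values[value_mask] = agent.get_value(next_obs[t][value_mask]).flatten()
        # only true terminations zero the bootstrap term; truncations keep it
        delta = rewards[t] + args.gamma * next_values * (1 - next_terminations[t]) - values[t]
        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * (1 - next_dones[t]) * lastgaelam
    returns = advantages + values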
