Commit 512160b: pre-commit run check

1 parent: f9ad20a
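The change set appears consistent with an automated formatting pass (for example, pre-commit run --all-files): four of the five changed line pairs look like whitespace-only edits to blank lines, and the fifth only adds spaces around the binary + in values[t + 1].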

1 file changed: 5 additions, 5 deletions

1 file changed

+5
-5
lines changed

cleanrl/ppo_continuous_action.py

Lines changed: 5 additions & 5 deletions
@@ -209,22 +209,22 @@ def get_action_and_value(self, x, action=None):
 
     for step in range(0, args.num_steps):
         global_step += args.num_envs
-
+
         ob = next_ob
         # ALGO LOGIC: action logic
         with torch.no_grad():
             action, logprob, _, value = agent.get_action_and_value(ob)
 
         # TRY NOT TO MODIFY: execute the game and log data.
         next_ob, reward, next_termination, next_truncation, info = envs.step(action.cpu().numpy())
-
+
         # Correct next obervation (for vec gym)
         real_next_ob = next_ob.copy()
         for idx, trunc in enumerate(next_truncation):
             if trunc:
                 real_next_ob[idx] = info["final_observation"][idx]
         next_ob = torch.Tensor(next_ob).to(device)
-
+
         # Collect trajectory
         obs[step] = torch.Tensor(ob).to(device)
         next_obs[step] = torch.Tensor(real_next_ob).to(device)
@@ -234,7 +234,7 @@ def get_action_and_value(self, x, action=None):
         next_terminations[step] = torch.Tensor(next_termination).to(device)
         next_dones[step] = torch.Tensor(np.logical_or(next_termination, next_truncation)).to(device)
         rewards[step] = torch.tensor(reward).to(device).view(-1)
-
+
         if "final_info" in info:
             for info in info["final_info"]:
                 if info and "episode" in info:
@@ -253,7 +253,7 @@ def get_action_and_value(self, x, action=None):
         else:
             value_mask = next_dones[t].bool()
             next_values[value_mask] = agent.get_value(next_obs[t][value_mask]).flatten()
-            next_values[~value_mask] = values[t+1][~value_mask]
+            next_values[~value_mask] = values[t + 1][~value_mask]
         delta = rewards[t] + args.gamma * next_values * (1 - next_terminations[t]) - values[t]
         advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * (1 - next_dones[t]) * lastgaelam
     returns = advantages + values
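For context, the lines touched by the last hunk implement the masked bootstrapping step of a GAE backward pass: environments whose episode ended at step t bootstrap from the value of the truncation-corrected next observation, while the rest reuse the value already stored for step t + 1. Below is a self-contained sketch of that pattern with dummy buffers; the tensor shapes, the stand-in critic get_value, and the handling of the final step are assumptions for illustration and are not taken from this file.

import torch

num_steps, num_envs, gamma, gae_lambda = 4, 3, 0.99, 0.95

# Stand-in rollout buffers (random placeholders for the real storage tensors).
rewards = torch.randn(num_steps, num_envs)
values = torch.randn(num_steps, num_envs)                 # V(s_t) recorded during the rollout
next_obs = torch.randn(num_steps, num_envs, 8)            # truncation-corrected next observations
next_terminations = torch.randint(0, 2, (num_steps, num_envs)).float()
next_truncations = torch.randint(0, 2, (num_steps, num_envs)).float()
next_dones = torch.logical_or(next_terminations.bool(), next_truncations.bool()).float()


def get_value(obs):
    # Stand-in for agent.get_value: any critic mapping observations to scalar values.
    return obs.sum(dim=-1, keepdim=True) * 0.01


advantages = torch.zeros(num_steps, num_envs)
lastgaelam = torch.zeros(num_envs)
with torch.no_grad():
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            # Assumed handling of the final step: bootstrap every env from its next observation.
            next_values = get_value(next_obs[t]).flatten()
        else:
            # Envs that finished at step t bootstrap from the corrected next observation;
            # the rest reuse the value recorded for step t + 1 (the line this commit reformats).
            value_mask = next_dones[t].bool()
            next_values = torch.empty(num_envs)
            next_values[value_mask] = get_value(next_obs[t][value_mask]).flatten()
            next_values[~value_mask] = values[t + 1][~value_mask]
        # Termination zeroes the bootstrap term; any done flag (termination or truncation)
        # also cuts the GAE recursion so advantages do not leak across episode boundaries.
        delta = rewards[t] + gamma * next_values * (1 - next_terminations[t]) - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * (1 - next_dones[t]) * lastgaelam
    returns = advantages + values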
