forked from ikostrikov/pytorch-a2c-ppo-acktr-gail
train_functions.py
from a2c_ppo_acktr.storage import RolloutStorage
from a2c_ppo_acktr.envs import make_vec_envs
from a2c_ppo_acktr import algo, utils
import torch
from collections import deque
'''
First 4 parameters are for shared recurrent layer. Can freeze these by setting
requires_grad = False
'''
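

# Illustrative sketch, not called anywhere in this module: assuming the "first 4
# parameters" above refer to the first four entries of model.parameters(), the
# shared recurrent layer can be frozen as the note describes.
def freeze_shared_recurrent_layer(model, num_shared_params=4):
    for i, param in enumerate(model.parameters()):
        if i < num_shared_params:
            param.requires_grad = False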


def populate_rollouts(model, envs, rollouts, num_steps):
    # Step the policy and the environments for num_steps to fill one rollout batch
    for step in range(num_steps):
        with torch.no_grad():
            outputs = model.act(rollouts.obs[step],
                                rollouts.recurrent_hidden_states[step],
                                rollouts.masks[step])
        action = outputs['action']
        value = outputs['value']
        action_log_prob = outputs['action_log_probs']
        recurrent_hidden_states = outputs['rnn_hxs']
        # auxiliary_preds = outputs['auxiliary_preds']
        obs, reward, done, infos = envs.step(action)
        # masks zero out finished episodes; bad_masks flag transitions cut off by a time limit
        masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
        bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info.keys() else [1.0]
                                       for info in infos])
        rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob,
                        value, reward, masks, bad_masks)


def update_model(agent, rollouts, use_gae=False, gamma=0.99, gae_lambda=0.95,
                 after_update=True):
    # Compute the value of the last observation, used to bootstrap the returns
    with torch.no_grad():
        next_value = agent.actor_critic.get_value(rollouts.obs[-1],
                                                  rollouts.recurrent_hidden_states[-1],
                                                  rollouts.masks[-1]).detach()
    rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda)
    value_loss, action_loss, dist_entropy, approx_kl, clipfracs, auxiliary_loss = agent.update(rollouts)
    if after_update:
        # Carry the final observation, hidden states, and masks over as the start of the next rollout
        rollouts.after_update()
    return value_loss, action_loss, dist_entropy, approx_kl, clipfracs, auxiliary_loss
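

# Rough standalone sketch (not used by this module) of the generalized advantage
# estimation recursion that rollouts.compute_returns is expected to perform when
# use_gae=True. Tensor shapes here are assumptions for illustration only:
#   values: (num_steps + 1, num_processes, 1), with the last entry being the bootstrap value
#   rewards, masks: (num_steps, num_processes, 1), with masks[t] = 0.0 where the
#   episode ended on that transition
def gae_returns_sketch(rewards, values, masks, gamma=0.99, gae_lambda=0.95):
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(values[-1])
    for step in reversed(range(rewards.size(0))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * gae_lambda * masks[step] * gae
        returns[step] = gae + values[step]
    return returns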


def initialize_ppo_training(model, obs_rms=None, env_name='NavEnv-v0', env_kwargs={},
                            num_steps=10, num_processes=1, seed=0, ppo_epoch=4, clip_param=0.5,
                            num_mini_batch=1, value_loss_coef=0.5, entropy_coef=0.01,
                            auxiliary_loss_coef=0.3, gamma=0.99, lr=7e-4, eps=1e-5, max_grad_norm=0.5,
                            log_dir='/tmp/gym/', device=torch.device('cpu'),
                            capture_video=False):
    # Wrap the model with the PPO agent object
    agent = algo.PPO(model, clip_param, ppo_epoch, num_mini_batch,
                     value_loss_coef, entropy_coef, auxiliary_loss_coef, lr=lr,
                     eps=eps, max_grad_norm=max_grad_norm)
    # Initialize vectorized environments
    envs = make_vec_envs(env_name, seed, num_processes, gamma, log_dir, device, False,
                         capture_video=capture_video, env_kwargs=env_kwargs)
    # If loading a previously trained model, pass its obs_rms object so the
    # vectorized envs reuse the saved observation normalization statistics
    vec_norm = utils.get_vec_normalize(envs)
    if vec_norm is not None and obs_rms is not None:
        vec_norm.obs_rms = obs_rms
    # Initialize storage. The RolloutStorage object allocates empty tensors for the batch, e.g.
    #   obs has shape (num_steps+1, num_processes, obs_shape)
    #   rewards has shape (num_steps, num_processes, 1)
    # obs, recurrent_hidden_states, value_preds, and returns have batch size num_steps+1;
    # rewards, action_log_probs, actions, masks, auxiliary_preds, and auxiliary_truths
    # have batch size num_steps
    rollouts = RolloutStorage(num_steps, num_processes, envs.observation_space.shape,
                              envs.action_space, model.recurrent_hidden_state_size,
                              model.auxiliary_output_size)
    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)
    return agent, envs, rollouts
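

def run_training_sketch(model, num_updates=100, num_steps=10, device=torch.device('cpu')):
    """Illustrative end-to-end loop showing how the helpers above fit together.

    Not part of the original module: the caller supplies a policy model exposing
    act()/get_value() and the attributes RolloutStorage expects; the update and
    step counts here are placeholders.
    """
    agent, envs, rollouts = initialize_ppo_training(model, num_steps=num_steps, device=device)
    losses = None
    for _ in range(num_updates):
        populate_rollouts(model, envs, rollouts, num_steps)
        losses = update_model(agent, rollouts)
    envs.close()
    return losses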