
Commit 96b094e

Merge pull request #17 from openai/backprop
Add backprop support + simple BC example
2 parents f63a391 + b61623e commit 96b094e

7 files changed: +478 −19 lines changed

README.md

Lines changed: 21 additions & 0 deletions
@@ -96,6 +96,27 @@ A window should pop up which shows the video frame-by-frame, showing the predict
Note that `run_inverse_dynamics_model.py` is designed to be a demo of the IDM, not code to put it into practice.

# Using behavioural cloning to fine-tune the models

**Disclaimer:** This code is a rough demonstration only, not an exact recreation of what the original VPT paper did (but it does contain some preprocessing steps you want to be aware of)! As such, do not expect to replicate the original experiments with this code. The code has been designed to be runnable on consumer hardware (e.g., 8 GB of VRAM).

Setup:

* Install the requirements: `pip install -r requirements.txt`
* Download the `.weights` and `.model` files for the model you want to fine-tune.
* Download contractor data (see below) and place the `.mp4` and `.jsonl` files in the same directory (e.g., `data`). With the default settings, you need at least 12 recordings; a quick check is sketched below.
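
As a sanity check before training, you can verify the data directory with a small helper like the following. This snippet is not part of the repository; it only assumes that each `.mp4` recording is paired with a `.jsonl` action file of the same basename, which is how the contractor recordings are distributed.

```
# Hypothetical sanity check (not part of this repo): assumes each .mp4 recording
# has a matching .jsonl action file with the same basename in the data directory.
import glob
import os

data_dir = "data"
videos = sorted(glob.glob(os.path.join(data_dir, "*.mp4")))
missing = [v for v in videos if not os.path.exists(os.path.splitext(v)[0] + ".jsonl")]
print(f"Found {len(videos)} recordings; {len(missing)} are missing their .jsonl file")
assert len(videos) >= 12, "Default settings need at least 12 recordings"
```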

If you downloaded the "1x Width" models and placed some data under the `data` directory, you can run fine-tuning with

```
python behavioural_cloning.py --data-dir data --in-model foundation-model-1x.model --in-weights foundation-model-1x.weights --out-weights finetuned-1x.weights
```

You can then use `finetuned-1x.weights` when running the agent (an example invocation is sketched below). You can change the training settings at the top of `behavioural_cloning.py`.
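
For example, assuming you run agents with the repository's `run_agent.py` script and the same `--model`/`--weights` flags used in the earlier examples (adjust the paths to your setup), the invocation might look like:

```
python run_agent.py --model foundation-model-1x.model --weights finetuned-1x.weights
```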

Major limitations:
- Only trains a single step at a time, i.e., errors are not propagated through timesteps.
- Computes gradients one sample at a time to keep memory use low, which also slows down training.

# Contractor Demonstrations

### Versions

agent.py

Lines changed: 35 additions & 9 deletions
@@ -139,7 +139,11 @@ def reset(self):
         self.hidden_state = self.policy.initial_state(1)
 
     def _env_obs_to_agent(self, minerl_obs):
-        """Turn observation from MineRL environment into model's observation"""
+        """
+        Turn observation from MineRL environment into model's observation
+
+        Returns torch tensors.
+        """
         agent_input = resize_image(minerl_obs["pov"], AGENT_RESOLUTION)[None]
         agent_input = {"img": th.from_numpy(agent_input).to(self.device)}
         return agent_input
@@ -149,17 +153,39 @@ def _agent_action_to_env(self, agent_action):
         # This is quite important step (for some reason).
         # For the sake of your sanity, remember to do this step (manual conversion to numpy)
         # before proceeding. Otherwise, your agent might be a little derp.
-        action = {
-            "buttons": agent_action["buttons"].cpu().numpy(),
-            "camera": agent_action["camera"].cpu().numpy()
-        }
+        action = agent_action
+        if isinstance(action["buttons"], th.Tensor):
+            action = {
+                "buttons": agent_action["buttons"].cpu().numpy(),
+                "camera": agent_action["camera"].cpu().numpy()
+            }
         minerl_action = self.action_mapper.to_factored(action)
         minerl_action_transformed = self.action_transformer.policy2env(minerl_action)
         return minerl_action_transformed
 
-    def _env_action_to_agent(self, minerl_action):
-        """Turn action from MineRL to model's action"""
-        raise NotImplementedError()
+    def _env_action_to_agent(self, minerl_action_transformed, to_torch=False, check_if_null=False):
+        """
+        Turn action from MineRL to model's action.
+
+        Note that this will add batch dimensions to the action.
+        Returns numpy arrays, unless `to_torch` is True, in which case it returns torch tensors.
+
+        If `check_if_null` is True, check if the action is null (no action) after the initial
+        transformation. This matches the behaviour done in OpenAI's VPT work.
+        If the action is null, return None instead.
+        """
+        minerl_action = self.action_transformer.env2policy(minerl_action_transformed)
+        if check_if_null:
+            if np.all(minerl_action["buttons"] == 0) and np.all(minerl_action["camera"] == self.action_transformer.camera_zero_bin):
+                return None
+
+        # Add batch dims if not existent
+        if minerl_action["camera"].ndim == 1:
+            minerl_action = {k: v[None] for k, v in minerl_action.items()}
+        action = self.action_mapper.from_factored(minerl_action)
+        if to_torch:
+            action = {k: th.from_numpy(v).to(self.device) for k, v in action.items()}
+        return action
 
     def get_action(self, minerl_obs):
         """
@@ -177,4 +203,4 @@ def get_action(self, minerl_obs):
             stochastic=True
         )
         minerl_action = self._agent_action_to_env(agent_action)
-        return minerl_action
+        return minerl_action
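
For context, here is a minimal sketch (not part of the commit) of how the reworked `_env_action_to_agent` is intended to be used when converting a recorded MineRL action dict into the model's action space; `agent` and `minerl_action` are assumed to exist as in the training script below.

```
# Hypothetical call site; the real one is in behavioural_cloning.py below.
agent_action = agent._env_action_to_agent(minerl_action, to_torch=True, check_if_null=True)
if agent_action is None:
    # Recorded action was a no-op; VPT-style preprocessing skips these frames.
    pass
else:
    # The returned dict has batched "buttons"/"camera" tensors and can be mapped
    # back to a MineRL action dict with the (also updated) _agent_action_to_env.
    env_action = agent._agent_action_to_env(agent_action)
```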

behavioural_cloning.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# Basic behavioural cloning
# Note: this uses gradient accumulation in batches of one
# to perform training.
# This will fit inside even smaller GPUs (tested on 8GB one),
# but is slow.
# NOTE: This is _not_ the original code used for VPT!
# This is merely to illustrate how to fine-tune the models and includes
# the processing steps used.

# This will likely be much worse than what the original VPT did:
# we are not training on full sequences, but only one step at a time to save VRAM.

from argparse import ArgumentParser
import pickle
import time

import gym
import minerl
import torch as th
import numpy as np

from agent import PI_HEAD_KWARGS, MineRLAgent
from data_loader import DataLoader
from lib.tree_util import tree_map

EPOCHS = 2
# Needs to be <= number of videos
BATCH_SIZE = 8
# Ideally more than batch size to create
# variation in datasets (otherwise, you will
# get a bunch of consecutive samples)
# Decrease this (and batch_size) if you run out of memory
N_WORKERS = 12
DEVICE = "cuda"

LOSS_REPORT_RATE = 100

LEARNING_RATE = 0.000181
WEIGHT_DECAY = 0.039428
MAX_GRAD_NORM = 5.0

def load_model_parameters(path_to_model_file):
    agent_parameters = pickle.load(open(path_to_model_file, "rb"))
    policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
    pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
    return policy_kwargs, pi_head_kwargs

def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
    agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)

    # To create model with the right environment.
    # All basalt environments have the same settings, so any of them works here
    env = gym.make("MineRLBasaltFindCave-v0")
    agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
    agent.load_weights(in_weights)
    env.close()

    policy = agent.policy
    # Materialize the parameters into a list so the same collection can be passed
    # to both the optimizer and gradient clipping below (a generator would be
    # exhausted after the first use).
    trainable_parameters = list(policy.parameters())

    # Parameters taken from the OpenAI VPT paper
    optimizer = th.optim.Adam(
        trainable_parameters,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    data_loader = DataLoader(
        dataset_dir=data_dir,
        n_workers=N_WORKERS,
        batch_size=BATCH_SIZE,
        n_epochs=EPOCHS
    )

    start_time = time.time()

    # Keep track of the hidden state per episode/trajectory.
    # DataLoader provides unique id for each episode, which will
    # be different even for the same trajectory when it is loaded
    # up again
    episode_hidden_states = {}
    dummy_first = th.from_numpy(np.array((False,))).to(DEVICE)

    loss_sum = 0
    for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader):
        batch_loss = 0
        for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
            agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
            if agent_action is None:
                # Action was null
                continue

            agent_obs = agent._env_obs_to_agent({"pov": image})
            if episode_id not in episode_hidden_states:
                # TODO need to clean up this hidden state after worker is done with the work item.
                #      Leaks memory, but not tooooo much at these scales (will be a problem later).
                episode_hidden_states[episode_id] = policy.initial_state(1)
            agent_state = episode_hidden_states[episode_id]

            pi_distribution, v_prediction, new_agent_state = policy.get_output_for_observation(
                agent_obs,
                agent_state,
                dummy_first
            )

            log_prob = policy.get_logprob_of_action(pi_distribution, agent_action)

            # Make sure we do not try to backprop through sequence
            # (fails with current accumulation)
            new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
            episode_hidden_states[episode_id] = new_agent_state

            # Finally, update the agent to increase the probability of the
            # taken action.
            # Remember to take mean over batch losses
            loss = -log_prob / BATCH_SIZE
            batch_loss += loss.item()
            loss.backward()

        th.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad()

        loss_sum += batch_loss
        if batch_i % LOSS_REPORT_RATE == 0:
            time_since_start = time.time() - start_time
            print(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")
            loss_sum = 0

    state_dict = policy.state_dict()
    th.save(state_dict, out_weights)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--data-dir", type=str, required=True, help="Path to the directory containing recordings to be trained on")
    parser.add_argument("--in-model", required=True, type=str, help="Path to the .model file to be finetuned")
    parser.add_argument("--in-weights", required=True, type=str, help="Path to the .weights file to be finetuned")
    parser.add_argument("--out-weights", required=True, type=str, help="Path where finetuned weights will be saved")

    args = parser.parse_args()
    behavioural_cloning_train(args.data_dir, args.in_model, args.in_weights, args.out_weights)
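
To sanity-check the result, the saved weights can be loaded back the same way the script loads the foundation weights. The following is a minimal sketch, not part of the commit; the environment name and file paths mirror the examples above and should be adjusted to your setup.

```
# Minimal evaluation sketch (not part of the commit): load the finetuned weights
# back into a MineRLAgent, mirroring the setup in behavioural_cloning_train above.
import gym
import minerl  # noqa: F401  (registers the MineRL environments)

from agent import MineRLAgent
from behavioural_cloning import load_model_parameters

policy_kwargs, pi_head_kwargs = load_model_parameters("foundation-model-1x.model")
env = gym.make("MineRLBasaltFindCave-v0")
agent = MineRLAgent(env, device="cuda", policy_kwargs=policy_kwargs, pi_head_kwargs=pi_head_kwargs)
agent.load_weights("finetuned-1x.weights")

obs = env.reset()
for _ in range(100):
    action = agent.get_action(obs)
    obs, reward, done, info = env.step(action)
```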
