# Basic behavioural cloning
# Note: this uses gradient accumulation in batches of one
# to perform training.
# This will fit inside even smaller GPUs (tested on an 8GB GPU),
# but is slow.
# NOTE: This is _not_ the original code used for VPT!
# It is merely meant to illustrate how to fine-tune the models and includes
# the processing steps used.

# This will likely perform much worse than the original VPT training:
# we are not training on full sequences, but only one step at a time to save VRAM.

from argparse import ArgumentParser
import pickle
import time

import gym
import minerl
import torch as th
import numpy as np

from agent import PI_HEAD_KWARGS, MineRLAgent
from data_loader import DataLoader
from lib.tree_util import tree_map

EPOCHS = 2
# Needs to be <= number of videos
BATCH_SIZE = 8
# Ideally more than the batch size, to create
# variation between the loaded trajectories (otherwise
# you will get a bunch of consecutive samples).
# Decrease this (and BATCH_SIZE) if you run out of memory
N_WORKERS = 12
DEVICE = "cuda"

LOSS_REPORT_RATE = 100

LEARNING_RATE = 0.000181
WEIGHT_DECAY = 0.039428
MAX_GRAD_NORM = 5.0

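# The ".model" file is assumed to be a pickled dict of agent hyperparameters;
# the nested keys below match the files shipped with the OpenAI VPT releases.
# Adjust them if your .model file is structured differently.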
def load_model_parameters(path_to_model_file):
    with open(path_to_model_file, "rb") as f:
        agent_parameters = pickle.load(f)
    policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
    pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
    return policy_kwargs, pi_head_kwargs

def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
    agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)

    # An environment is only needed to create the agent with the right settings.
    # All BASALT environments have the same settings, so any of them works here.
    env = gym.make("MineRLBasaltFindCave-v0")
    agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
    agent.load_weights(in_weights)
    env.close()

    policy = agent.policy
    # Materialize the parameters into a list so the same collection can be
    # passed to the optimizer and reused for gradient clipping below
    # (a generator from .parameters() would be exhausted after its first use).
    trainable_parameters = list(policy.parameters())

    # Parameters taken from the OpenAI VPT paper
    optimizer = th.optim.Adam(
        trainable_parameters,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    data_loader = DataLoader(
        dataset_dir=data_dir,
        n_workers=N_WORKERS,
        batch_size=BATCH_SIZE,
        n_epochs=EPOCHS
    )
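    # Note: DataLoader here is the loader defined in this repository's
    # data_loader.py, not torch.utils.data.DataLoader. Each iteration is
    # expected to yield lists of frames, env actions and episode ids
    # (one entry per batch element), as unpacked in the loop below.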

    start_time = time.time()

    # Keep track of the hidden state per episode/trajectory.
    # DataLoader provides a unique id for each episode, which will
    # be different even for the same trajectory when it is loaded
    # up again.
    episode_hidden_states = {}
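    # The "first" flag would normally mark the first frame of an episode for the
    # recurrent policy. Here it is always False: new episodes instead get a fresh
    # hidden state explicitly, via policy.initial_state() in the loop below.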
    dummy_first = th.from_numpy(np.array((False,))).to(DEVICE)

    loss_sum = 0
    for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader):
        batch_loss = 0
        for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
            agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
            if agent_action is None:
                # Action was null
                continue

            agent_obs = agent._env_obs_to_agent({"pov": image})
            if episode_id not in episode_hidden_states:
                # TODO need to clean up this hidden state after worker is done with the work item.
                # Leaks memory, but not tooooo much at these scales (will be a problem later).
                episode_hidden_states[episode_id] = policy.initial_state(1)
            agent_state = episode_hidden_states[episode_id]

            pi_distribution, v_prediction, new_agent_state = policy.get_output_for_observation(
                agent_obs,
                agent_state,
                dummy_first
            )
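            # v_prediction is the value-head output; it is not needed for
            # behavioural cloning and is simply ignored here.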

            log_prob = policy.get_logprob_of_action(pi_distribution, agent_action)

            # Make sure we do not try to backprop through the sequence
            # (fails with the current accumulation setup)
            new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
            episode_hidden_states[episode_id] = new_agent_state

            # Finally, update the agent to increase the probability of the
            # taken action. Dividing by BATCH_SIZE means the accumulated
            # gradient equals the mean over the batch losses.
            loss = -log_prob / BATCH_SIZE
            batch_loss += loss.item()
            loss.backward()

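        # At this point gradients from every (non-null) sample in the batch have
        # accumulated on the parameters, so one clipped optimizer step applies
        # the whole batch at once.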
        th.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad()

        loss_sum += batch_loss
        if batch_i % LOSS_REPORT_RATE == 0:
            time_since_start = time.time() - start_time
            print(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")
            loss_sum = 0

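    # Save only the policy weights; the resulting file should be loadable the
    # same way as the original .weights file (via agent.load_weights above).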
    state_dict = policy.state_dict()
    th.save(state_dict, out_weights)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--data-dir", type=str, required=True, help="Path to the directory containing recordings to be trained on")
    parser.add_argument("--in-model", required=True, type=str, help="Path to the .model file to be finetuned")
    parser.add_argument("--in-weights", required=True, type=str, help="Path to the .weights file to be finetuned")
    parser.add_argument("--out-weights", required=True, type=str, help="Path where finetuned weights will be saved")

    args = parser.parse_args()
    behavioural_cloning_train(args.data_dir, args.in_model, args.in_weights, args.out_weights)
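
# Example invocation (script and file names below are placeholders, not files
# shipped with this code; substitute your own paths):
#   python behavioural_cloning.py \
#       --data-dir data/MineRLBasaltFindCave-v0 \
#       --in-model foundation-model-1x.model \
#       --in-weights foundation-model-1x.weights \
#       --out-weights finetuned-1x.weights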