Commit 26bfee2
Add Meta-World example (#144)
* Add MetaWorld wrapper to set done=True at end of horizon
* Add MAML-TRPO example on the 'ML' MetaWorld benchmarks
* Update CHANGELOG with Meta-World example
* Use tanh instead of ReLU (as in the paper)
* Add documentation to wrapper
* Fix lint issues

Co-authored-by: Séb Arnold <[email protected]>
1 parent a8c8985 commit 26bfee2
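The MetaWorld wrapper mentioned in the first bullet is added in a separate file of this commit and does not appear in the diff excerpt below. As a rough sketch of the idea it describes (forcing `done=True` once the horizon is reached), assuming the classic `gym.Wrapper` API and an illustrative 150-step horizon rather than the wrapper's actual code:

```python
# Illustrative sketch only -- not the wrapper shipped in this commit.
# Assumes the classic gym API (4-tuple step) and an assumed fixed horizon.
import gym


class HorizonDoneWrapper(gym.Wrapper):
    """Force done=True once a fixed number of steps has elapsed."""

    def __init__(self, env, horizon=150):  # horizon value is an assumption
        super().__init__(env)
        self.horizon = horizon
        self._elapsed = 0

    def reset(self, **kwargs):
        self._elapsed = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._elapsed += 1
        if self._elapsed >= self.horizon:
            done = True  # end of horizon: the sampler sees a finished episode
        return obs, reward, done, info
```

Terminating at a fixed horizon matters because episode-based samplers such as `ch.envs.Runner` in the example rely on the `done` flag to cut trajectories; that is presumably what motivates a wrapper that sets `done=True` at the end of the horizon.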

File tree: 4 files changed, +420 −0 lines changed


CHANGELOG.md  (+1)

@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+* New example: MAML-TRPO on the [Meta-World](https://github.com/rlworkgroup/metaworld) benchmarks, with its own env wrapper. (@[Kostis-S-Z](https://github.com/Kostis-S-Z))
 * Add l2l.vision.benchmarks interface.

 ### Changed
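As a usage note for the new example and its wrapper, here is a minimal sketch of the task-sampling pattern the example is built around; the method names (`get_train_tasks`, `sample_tasks`, `set_task`) and the `'push-v1'` task come from the example's `make_env` and training loop below, while the rest of the loop is an illustrative assumption rather than documented API:

```python
# Minimal sketch of the benchmark/task interface the new example relies on.
from learn2learn.gym.envs.metaworld import MetaWorldML1 as ML1

env = ML1.get_train_tasks('push-v1')   # single-task ML1 benchmark, as in make_env below
for task in env.sample_tasks(5):       # sample a handful of goal configurations
    env.set_task(task)                 # switch the env to that configuration
    obs = env.reset()
    done = False
    while not done:                    # the new wrapper guarantees done=True at horizon end
        obs, reward, done, info = env.step(env.action_space.sample())
```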

examples/rl/maml_trpo_metaworld.py  (+275)

@@ -0,0 +1,275 @@
#!/usr/bin/env python3

"""
Trains a 2-layer MLP with MAML-TRPO on the 'ML' benchmarks of Meta-World.
For more information on the benchmark, check out https://github.com/rlworkgroup/metaworld

Usage:

python examples/rl/maml_trpo_metaworld.py
"""
import random
from copy import deepcopy

import cherry as ch
import numpy as np
import torch
from cherry.algorithms import a2c, trpo
from cherry.models.robotics import LinearValue
from torch import autograd
from torch.distributions.kl import kl_divergence
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from tqdm import tqdm

import learn2learn as l2l

from learn2learn.gym.envs.metaworld import MetaWorldML1 as ML1
from learn2learn.gym.envs.metaworld import MetaWorldML10 as ML10
from learn2learn.gym.envs.metaworld import MetaWorldML45 as ML45
from policies import DiagNormalPolicy

def compute_advantages(baseline, tau, gamma, rewards, dones, states, next_states):
    # Update baseline
    returns = ch.td.discount(gamma, rewards, dones)

    baseline.fit(states, returns)
    values = baseline(states)
    next_values = baseline(next_states)
    bootstraps = values * (1.0 - dones) + next_values * dones
    next_value = torch.zeros(1, device=values.device)

    return ch.pg.generalized_advantage(tau=tau,
                                       gamma=gamma,
                                       rewards=rewards,
                                       dones=dones,
                                       values=bootstraps,
                                       next_value=next_value)

def maml_a2c_loss(train_episodes, learner, baseline, gamma, tau):
    # Update policy and baseline
    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()
    log_probs = learner.log_prob(states, actions)

    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                    dones, states, next_states)
    advantages = ch.normalize(advantages).detach()
    return a2c.policy_loss(log_probs, advantages)


def fast_adapt_a2c(clone, train_episodes, adapt_lr, baseline, gamma, tau, first_order=False):
    second_order = not first_order
    loss = maml_a2c_loss(train_episodes, clone, baseline, gamma, tau)
    gradients = autograd.grad(loss,
                              clone.parameters(),
                              retain_graph=second_order,
                              create_graph=second_order)
    return l2l.algorithms.maml.maml_update(clone, adapt_lr, gradients)

def meta_surrogate_loss(iteration_replays, iteration_policies, policy, baseline, tau, gamma, adapt_lr):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in tqdm(zip(iteration_replays, iteration_policies),
                                         total=len(iteration_replays),
                                         desc='Surrogate Loss',
                                         leave=False):
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = l2l.clone_module(policy)

        # Fast Adapt
        for train_episodes in train_replays:
            new_policy = fast_adapt_a2c(new_policy, train_episodes, adapt_lr,
                                        baseline, gamma, tau, first_order=False)

        # Useful values
        states = valid_episodes.state()
        actions = valid_episodes.action()
        next_states = valid_episodes.next_state()
        rewards = valid_episodes.reward()
        dones = valid_episodes.done()

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl = kl_divergence(new_densities, old_densities).mean()
        mean_kl += kl

        # Compute Surrogate Loss
        advantages = compute_advantages(baseline, tau, gamma, rewards, dones, states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1, keepdim=True)
        mean_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)
    mean_kl /= len(iteration_replays)
    mean_loss /= len(iteration_replays)
    return mean_loss, mean_kl

def make_env(benchmark, seed, num_workers, test=False):
    # Set a specific task, or leave it empty to train on all available tasks
    task = 'push-v1' if benchmark == ML1 else False  # In this case, False corresponds to the sample_all argument

    def init_env():
        if test:
            env = benchmark.get_test_tasks(task)
        else:
            env = benchmark.get_train_tasks(task)

        env = ch.envs.ActionSpaceScaler(env)
        return env

    env = l2l.gym.AsyncVectorEnv([init_env for _ in range(num_workers)])
    env.seed(seed)
    env.set_task(env.sample_tasks(1)[0])
    env = ch.envs.Torch(env)
    return env

def main(
        benchmark=ML10,  # Choose between ML1, ML10, ML45
        adapt_lr=0.1,
        meta_lr=0.1,
        adapt_steps=1,
        num_iterations=1000,
        meta_bsz=20,
        adapt_bsz=10,  # Number of episodes to sample per task
        tau=1.00,
        gamma=0.99,
        seed=42,
        num_workers=10,  # Currently tasks are distributed evenly so adapt_bsz should be divisible by num_workers
        cuda=0):
    env = make_env(benchmark, seed, num_workers)

    cuda = bool(cuda)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    policy = DiagNormalPolicy(env.state_size, env.action_size, activation='tanh')
    if cuda:
        policy.to('cuda')
    baseline = LinearValue(env.state_size, env.action_size)

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        for task_config in tqdm(env.sample_tasks(meta_bsz), leave=False, desc='Data'):  # Samples a new config
            clone = deepcopy(policy)
            env.set_task(task_config)
            env.reset()
            task = ch.envs.Runner(env)
            task_replay = []

            # Fast Adapt
            for step in range(adapt_steps):
                train_episodes = task.run(clone, episodes=adapt_bsz)
                clone = fast_adapt_a2c(clone, train_episodes, adapt_lr, baseline, gamma, tau, first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            valid_episodes = task.run(clone, episodes=adapt_bsz)
            task_replay.append(valid_episodes)

            iteration_reward += valid_episodes.reward().sum().item() / adapt_bsz
            iteration_replays.append(task_replay)
            iteration_policies.append(clone)

        # Print statistics
        print('\nIteration', iteration)
        validation_reward = iteration_reward / meta_bsz
        print('validation_reward', validation_reward)

        # TRPO meta-optimization
        backtrack_factor = 0.5
        ls_max_steps = 15
        max_kl = 0.01
        if cuda:
            policy.to('cuda', non_blocking=True)
            baseline.to('cuda', non_blocking=True)
            iteration_replays = [[r.to('cuda', non_blocking=True) for r in task_replays]
                                 for task_replays in iteration_replays]

        # Compute CG step direction
        old_loss, old_kl = meta_surrogate_loss(iteration_replays, iteration_policies, policy,
                                               baseline, tau, gamma, adapt_lr)
        grad = autograd.grad(old_loss,
                             policy.parameters(),
                             retain_graph=True)
        grad = parameters_to_vector([g.detach() for g in grad])
        Fvp = trpo.hessian_vector_product(old_kl, policy.parameters())
        step = trpo.conjugate_gradient(Fvp, grad)
        shs = 0.5 * torch.dot(step, Fvp(step))
        lagrange_multiplier = torch.sqrt(shs / max_kl)
        step = step / lagrange_multiplier
        step_ = [torch.zeros_like(p.data) for p in policy.parameters()]
        vector_to_parameters(step, step_)
        step = step_
        del old_kl, Fvp, grad
        old_loss.detach_()

        # Line-search
        for ls_step in range(ls_max_steps):
            stepsize = backtrack_factor ** ls_step * meta_lr
            clone = deepcopy(policy)
            for p, u in zip(clone.parameters(), step):
                p.data.add_(-stepsize, u.data)
            new_loss, kl = meta_surrogate_loss(iteration_replays, iteration_policies, clone,
                                               baseline, tau, gamma, adapt_lr)
            if new_loss < old_loss and kl < max_kl:
                for p, u in zip(policy.parameters(), step):
                    p.data.add_(-stepsize, u.data)
                break

    # Evaluate on a set of unseen tasks
    evaluate(benchmark, policy, baseline, adapt_lr, gamma, tau, num_workers, seed)

def evaluate(benchmark, policy, baseline, adapt_lr, gamma, tau, n_workers, seed):
    # Parameters
    adapt_steps = 3
    adapt_bsz = 10
    n_eval_tasks = 10

    tasks_reward = 0.

    env = make_env(benchmark, seed, n_workers, test=True)
    eval_task_list = env.sample_tasks(n_eval_tasks)

    for i, task in enumerate(eval_task_list):
        clone = deepcopy(policy)
        env.set_task(task)
        env.reset()
        task = ch.envs.Runner(env)

        # Adapt
        for step in range(adapt_steps):
            adapt_episodes = task.run(clone, episodes=adapt_bsz)
            clone = fast_adapt_a2c(clone, adapt_episodes, adapt_lr, baseline, gamma, tau, first_order=True)

        eval_episodes = task.run(clone, episodes=adapt_bsz)

        task_reward = eval_episodes.reward().sum().item() / adapt_bsz
        print(f"Reward for task {i} : {task_reward}")
        tasks_reward += task_reward

    final_eval_reward = tasks_reward / n_eval_tasks

    print(f"Average reward over {n_eval_tasks} test tasks: {final_eval_reward}")

    return final_eval_reward


if __name__ == '__main__':
    main()
learn2learn/gym/envs/metaworld/__init__.py  (+3)

@@ -0,0 +1,3 @@
#!/usr/bin/env python3

from .metaworld import MetaWorldML1, MetaWorldML10, MetaWorldML45
