Commit 63ff92e

Add differentiable optimization module. (Meta-Descent, KFO, Meta-Curvature) (#151)
* Ported hypergrad example.
* Add meta-curvature example with GBML wrapper.
* GBML support for nograd, unused, first_order and tests.
* Add ANIL+KFO low-level example.
* Add misc nn layers.
* Update maml_update.
* Change download path for mini-imagenet tests.
* Add docs for differentiable sgd.
* Update docs, incl. for MetaWorld.
* KroneckerTranform docs.
* Docs for meta-curvature.
* Add docs for l2l.nn.misc.
* Add docs for kroneckers.
* Fix lint, add more docs.
* Add docs for GBML.
* Completes GBML docs.
* Rename meta_update -> update_module, and write docs.
* Fix lint, add docs for ParameterUpdate.
* Add docs for LearnableOptimizer
* Update changelog
* Update to readme, part 1
* Update README, part 2.
* Fix readme links
* Version bump.
1 parent 26bfee2 commit 63ff92e

36 files changed: +2079 −111 lines

CHANGELOG.md

+14 −1

@@ -10,8 +10,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

+### Changed
+
+### Fixed
+
+
+## v0.1.2
+
+### Added
+
* New example: [Meta-World](https://github.com/rlworkgroup/metaworld) example with MAML-TRPO and its own env wrapper. (@[Kostis-S-Z](https://github.com/Kostis-S-Z))
-* Add l2l.vision.benchmarks interface.
+* `l2l.vision.benchmarks` interface.
+* Differentiable optimization utilities in `l2l.optim` (including `l2l.optim.LearnableOptimizer` for meta-descent).
+* General gradient-based meta-learning wrapper in `l2l.algorithms.GBML`.
+* Various `nn.Modules` in `l2l.nn`.
+* `l2l.update_module` as a more general alternative to `l2l.algorithms.maml_update`.

### Changed

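For illustration (not part of the commit's diff), here is a rough sketch of a `maml_update`-style step written with the new `l2l.update_module`, based on the `updates=` keyword shown in the README snippet further down; `model`, `loss`, `X`, and `y` are placeholders:

~~~python
import torch
import learn2learn as l2l

# Rough sketch: one MAML-style adaptation step via l2l.update_module (placeholder model/loss/data).
clone = l2l.clone_module(model)            # differentiable copy of a placeholder model
error = loss(clone(X), y)
grads = torch.autograd.grad(error, clone.parameters(), create_graph=True)
updates = [-0.1 * g for g in grads]        # the update maml_update would compute with lr=0.1
l2l.update_module(clone, updates=updates)  # applies the updates in-place, keeping the graph
~~~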

README.md

+117 −56

@@ -4,81 +4,142 @@

[![Build Status](https://travis-ci.com/learnables/learn2learn.svg?branch=master)](https://travis-ci.com/learnables/learn2learn)

-learn2learn is a PyTorch library for meta-learning implementations.
+learn2learn is a software library for meta-learning research.

-The goal of meta-learning is to enable agents to *learn how to learn*.
-That is, we would like our agents to become better learners as they solve more and more tasks.
-For example, the animation below shows an agent that learns to run after only one parameter update.
+learn2learn builds on top of PyTorch to accelerate two aspects of the meta-learning research cycle:

-<p align="center"><img src="http://learn2learn.net/assets/img/halfcheetah.gif" height="250px" /></p>
+* *fast prototyping*, essential in letting researchers quickly try new ideas, and
+* *correct reproducibility*, ensuring that these ideas are evaluated fairly.

-**Features**
+learn2learn provides low-level utilities and a unified interface to create new algorithms and domains, together with high-quality implementations of existing algorithms and standardized benchmarks.
+It retains compatibility with [torchvision](https://pytorch.org/vision/), [torchaudio](https://pytorch.org/audio/), [torchtext](https://pytorch.org/text/), [cherry](http://cherry-rl.net/), and any other PyTorch-based library you might be using.

-learn2learn provides high- and low-level utilities for meta-learning.
-The high-level utilities allow arbitrary users to take advantage of existing meta-learning algorithms.
-The low-level utilities enable researchers to develop new and better meta-learning algorithms.
+**Overview**

-Some features of learn2learn include:
+* [`learn2learn.data`](http://learn2learn.net/docs/learn2learn.data/): `TaskDataset` and transforms to create few-shot tasks from any PyTorch dataset.
+* [`learn2learn.vision`](http://learn2learn.net/docs/learn2learn.vision/): Models, datasets, and benchmarks for computer vision and few-shot learning.
+* [`learn2learn.gym`](http://learn2learn.net/docs/learn2learn.gym/): Environments and utilities for meta-reinforcement learning.
+* [`learn2learn.algorithms`](http://learn2learn.net/docs/learn2learn.algorithms/): High-level wrappers for existing meta-learning algorithms.
+* [`learn2learn.optim`](http://learn2learn.net/docs/learn2learn.optim/): Utilities and algorithms for differentiable optimization and meta-descent.

-* Modular API: implement your own training loops with our low-level utilities.
-* Provides various meta-learning algorithms (e.g. MAML, FOMAML, MetaSGD, ProtoNets, DiCE)
-* Task generator with unified API, compatible with torchvision, torchtext, torchaudio, and cherry.
-* Provides standardized meta-learning tasks for vision (Omniglot, mini-ImageNet), reinforcement learning (Particles, Mujoco), and even text (news classification).
-* 100% compatible with PyTorch -- use your own modules, datasets, or libraries!
+**Resources**
+
+* Website: [http://learn2learn.net/](http://learn2learn.net/)
+* Documentation: [http://learn2learn.net/docs/](http://learn2learn.net/docs/)
+* Tutorials: [http://learn2learn.net/tutorials/getting_started/](http://learn2learn.net/tutorials/getting_started/)
+* Examples: [https://github.com/learnables/learn2learn/tree/master/examples](https://github.com/learnables/learn2learn/tree/master/examples)
+* GitHub: [https://github.com/learnables/learn2learn/](https://github.com/learnables/learn2learn/)
+* Slack: [http://slack.learn2learn.net/](http://slack.learn2learn.net/)

## Installation

~~~bash
pip install learn2learn
~~~

-## API Demo
+## Snippets & Examples
+
+The following snippets provide a sneak peek at the functionalities of learn2learn.
+
+### High-level Wrappers

-The following is an example of using the high-level MAML implementation on MNIST.
-For more algorithms and lower-level utilities, please refer to the [documentation](http://learn2learn.net/docs/learn2learn/) or the [examples](https://github.com/learnables/learn2learn/tree/master/examples).
+**Few-Shot Learning with MAML**

+For more algorithms (ProtoNets, ANIL, Meta-SGD, Reptile, Meta-Curvature, KFO) refer to the [examples](https://github.com/learnables/learn2learn/tree/master/examples/vision) folder.
+Most of them can be implemented with the `GBML` wrapper ([documentation](http://learn2learn.net/docs/learn2learn.algorithms/#gbml)).
~~~python
-import learn2learn as l2l
-
-mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)
-
-mnist = l2l.data.MetaDataset(mnist)
-train_tasks = l2l.data.TaskDataset(mnist,
-                                   task_transforms=[
-                                        NWays(mnist, n=3),
-                                        KShots(mnist, k=1),
-                                        LoadData(mnist),
-                                   ],
-                                   num_tasks=10)
-model = Net()
-maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
-opt = optim.Adam(maml.parameters(), lr=4e-3)
-
-for iteration in range(num_iterations):
-    learner = maml.clone()  # Creates a clone of model
-    for task in train_tasks:
-        # Split task in adaptation_task and evaluation_task
-        # Fast adapt
-        for step in range(adaptation_steps):
-            error = compute_loss(adaptation_task)
-            learner.adapt(error)
-
-        # Compute evaluation loss
-        evaluation_error = compute_loss(evaluation_task)
-
-        # Meta-update the model parameters
-        opt.zero_grad()
-        evaluation_error.backward()
-        opt.step()
+maml = l2l.algorithms.MAML(model, lr=0.1)
+opt = torch.optim.SGD(maml.parameters(), lr=0.001)
+for iteration in range(10):
+    opt.zero_grad()
+    task_model = maml.clone()  # torch.clone() for nn.Modules
+    adaptation_loss = compute_loss(task_model)
+    task_model.adapt(adaptation_loss)  # computes gradient, updates task_model in-place
+    evaluation_loss = compute_loss(task_model)
+    evaluation_loss.backward()  # gradients w.r.t. maml.parameters()
+    opt.step()
~~~
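The `GBML` wrapper mentioned above gets no snippet in this README revision; the following is a rough sketch (not part of the commit's diff) of how it mirrors the `MAML` loop, assuming the `GBML(module, transform=..., lr=...)` signature and the `MetaCurvatureTransform` introduced by this commit, with `module` and `compute_loss` as placeholders:

~~~python
import torch
import learn2learn as l2l

# Rough sketch: Meta-Curvature via the GBML wrapper (signature assumed from this commit).
model = l2l.algorithms.GBML(
    module,                                      # placeholder nn.Module to adapt
    transform=l2l.optim.MetaCurvatureTransform,  # learnable per-parameter gradient transform
    lr=0.1,                                      # fast-adaptation learning rate
)
opt = torch.optim.SGD(model.parameters(), lr=0.001)

opt.zero_grad()
task_model = model.clone()                  # differentiable copy, as with MAML
task_model.adapt(compute_loss(task_model))  # inner step, gradient reshaped by the transform
compute_loss(task_model).backward()         # gradients w.r.t. model.parameters()
opt.step()
~~~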

-## Changelog
+**Meta-Descent with Hypergradient**

-A human-readable changelog is available in the [CHANGELOG.md](CHANGELOG.md) file.
+Learn any kind of optimization algorithm with the `LearnableOptimizer`. ([example](https://github.com/learnables/learn2learn/tree/master/examples/optimization) and [documentation](http://learn2learn.net/docs/learn2learn.optim/#learnableoptimizer))
+~~~python
+linear = nn.Linear(784, 10)
+transform = l2l.optim.ModuleTransform(l2l.nn.Scale)
+metaopt = l2l.optim.LearnableOptimizer(linear, transform, lr=0.01)  # metaopt has .step()
+opt = torch.optim.SGD(metaopt.parameters(), lr=0.001)  # metaopt also has .parameters()
+
+metaopt.zero_grad()
+opt.zero_grad()
+error = loss(linear(X), y)
+error.backward()
+opt.step()  # update metaopt
+metaopt.step()  # update linear
+~~~
+
+### Learning Domains
+
+**Custom Few-Shot Dataset**
+
+Many standardized datasets (Omniglot, mini-/tiered-ImageNet, FC100, CIFAR-FS) are readily available in `learn2learn.vision.datasets`.
+([documentation](http://learn2learn.net/docs/learn2learn.vision/#learn2learnvisiondatasets))
+~~~python
+dataset = l2l.data.MetaDataset(MyDataset())  # any PyTorch dataset
+transforms = [  # Easy to define your own transform
+    l2l.data.transforms.NWays(dataset, n=5),
+    l2l.data.transforms.KShots(dataset, k=1),
+    l2l.data.transforms.LoadData(dataset),
+]
+taskset = TaskDataset(dataset, transforms, num_tasks=20000)
+for task in taskset:
+    X, y = task
+    # Meta-train on the task
+~~~

-## Documentation
+**Environments and Utilities for Meta-RL**

-Documentation and tutorials are available on learn2learn’s website: [http://learn2learn.net](http://learn2learn.net).
+Parallelize your own meta-environments with `AsyncVectorEnv`, or use the standardized ones.
+([documentation](http://learn2learn.net/docs/learn2learn.gym/#metaenv))
+~~~python
+def make_env():
+    env = l2l.gym.HalfCheetahForwardBackwardEnv()
+    env = cherry.envs.ActionSpaceScaler(env)
+    return env
+
+env = l2l.gym.AsyncVectorEnv([make_env for _ in range(16)])  # uses 16 threads
+for task_config in env.sample_tasks(20):
+    env.set_task(task_config)  # all threads receive the same task
+    state = env.reset()  # use standard Gym API
+    action = my_policy(env)
+    env.step(action)
+~~~
+
+### Low-Level Utilities
+
+**Differentiable Optimization**
+
+Learn and differentiate through updates of PyTorch Modules.
+([documentation](http://learn2learn.net/docs/learn2learn.optim/#parameterupdate))
+~~~python
+
+model = MyModel()
+transform = l2l.optim.KroneckerTransform(l2l.nn.KroneckerLinear)
+learned_update = l2l.optim.ParameterUpdate(  # learnable update function
+    model.parameters(), transform)
+clone = l2l.clone_module(model)  # torch.clone() for nn.Modules
+error = loss(clone(X), y)
+updates = learned_update(  # similar API as torch.autograd.grad
+    error,
+    clone.parameters(),
+    create_graph=True,
+)
+l2l.update_module(clone, updates=updates)
+loss(clone(X), y).backward()  # gradients w.r.t. model.parameters() and learned_update.parameters()
+~~~
+
+## Changelog
+
+A human-readable changelog is available in the [CHANGELOG.md](CHANGELOG.md) file.

## Citation

@@ -101,5 +162,5 @@ You can also use the following Bibtex entry.
### Acknowledgements & Friends

1. The RL environments are adapted from Tristan Deleu's [implementations](https://github.com/tristandeleu/pytorch-maml-rl) and from the ProMP [repository](https://github.com/jonasrothfuss/ProMP/). Both shared with permission, under the MIT License.
-2. [TorchMeta](https://github.com/tristandeleu/pytorch-meta) is a similar library, with a focus on supervised meta-learning. If learn2learn were missing a particular functionality, we would go check if TorchMeta has it. But we would also open an issue ;)
-3. [higher](https://github.com/facebookresearch/higher) is a PyTorch library that also enables differentiating through optimization inner-loops. Their approach is different from learn2learn in that they monkey-patch nn.Module to be stateless. For more information, refer to [their ArXiv paper](https://arxiv.org/abs/1910.01727).
+2. [TorchMeta](https://github.com/tristandeleu/pytorch-meta) is a similar library, with a focus on datasets for supervised meta-learning.
+3. [higher](https://github.com/facebookresearch/higher) is a PyTorch library that enables differentiating through optimization inner-loops. While they monkey-patch `nn.Module` to be stateless, learn2learn retains the stateful PyTorch look-and-feel. For more information, refer to [their ArXiv paper](https://arxiv.org/abs/1910.01727).

docs/pydocmd.yml

+29 −5

@@ -6,9 +6,10 @@ site_name: "learn2learn"
# documented. Higher indentation leads to smaller header size.
generate:
- docs/learn2learn.md:
-    - learn2learn.utils:
+    - learn2learn:
        - learn2learn.clone_module
        - learn2learn.detach_module
+        - learn2learn.update_module
        - learn2learn.magic_box
- docs/learn2learn.data.md:
    - learn2learn.data:
@@ -25,9 +26,8 @@ generate:
- docs/learn2learn.algorithms.md:
    - learn2learn.algorithms:
        - learn2learn.algorithms.MAML++
-        - learn2learn.algorithms.maml_update
        - learn2learn.algorithms.MetaSGD++
-        - learn2learn.algorithms.meta_sgd_update
+        - learn2learn.algorithms.GBML++
- docs/learn2learn.gym.md:
    - learn2learn.gym++:
        - learn2learn.gym.MetaEnv
@@ -40,6 +40,27 @@ generate:
        - learn2learn.gym.envs.mujoco.HumanoidDirectionEnv
        - learn2learn.gym.envs.particles:
            - learn2learn.gym.envs.particles.Particles2DEnv
+        - learn2learn.gym.envs.metaworld:
+            - learn2learn.gym.envs.metaworld.MetaWorldML1++
+            - learn2learn.gym.envs.metaworld.MetaWorldML10++
+            - learn2learn.gym.envs.metaworld.MetaWorldML45++
+- docs/learn2learn.optim.md:
+    - learn2learn.optim++:
+        - learn2learn.optim.LearnableOptimizer++
+        - learn2learn.optim.ParameterUpdate++
+        - learn2learn.optim.DifferentiableSGD++
+        - learn2learn.optim.transforms:
+            - learn2learn.optim.transforms.ModuleTransform++
+            - learn2learn.optim.transforms.KroneckerTransform++
+            - learn2learn.optim.transforms.MetaCurvatureTransform++
+- docs/learn2learn.nn.md:
+    - learn2learn.nn++:
+        - learn2learn.nn.Lambda
+        - learn2learn.nn.Flatten
+        - learn2learn.nn.Scale
+        - learn2learn.nn.KroneckerLinear
+        - learn2learn.nn.KroneckerRNN
+        - learn2learn.nn.KroneckerLSTM
- docs/learn2learn.vision.md:
    - learn2learn.vision++:
        - learn2learn.vision.models:
@@ -73,13 +94,16 @@ pages:
    - Feature Reuse with ANIL: tutorials/anil_tutorial/ANIL_tutorial.md
- Documentation:
    - learn2learn: docs/learn2learn.md
-    - learn2learn.algorithms: docs/learn2learn.algorithms.md
    - learn2learn.data: docs/learn2learn.data.md
-    - learn2learn.gym: docs/learn2learn.gym.md
+    - learn2learn.algorithms: docs/learn2learn.algorithms.md
+    - learn2learn.optim: docs/learn2learn.optim.md
+    - learn2learn.nn: docs/learn2learn.nn.md
    - learn2learn.vision: docs/learn2learn.vision.md
+    - learn2learn.gym: docs/learn2learn.gym.md
- Examples:
    - Computer Vision: examples.vision.md << ../examples/vision/README.md
    - Reinforcement Learning: examples.rl.md << ../examples/rl/README.md
+    - Optimization: examples.optim.md << ../examples/optimization/README.md
- Changelog: changelog.md << ../CHANGELOG.md
- GitHub: https://github.com/learnables/learn2learn/

examples/optimization/README.md

+22 −0

@@ -0,0 +1,22 @@
+# Meta-Optimization
+
+This directory contains examples of using learn2learn for meta-optimization or meta-descent.
+
+# Hypergradient
+
+The script `hypergrad_mnist.py` demonstrates how to implement a slightly modified version of "[Online Learning Rate Adaptation with Hypergradient Descent](https://arxiv.org/abs/1703.04782)".
+The implementation departs from the algorithm presented in the paper in two ways.
+
+1. We forgo the analytical formulation of the learning rate's gradient to demonstrate the capability of the `LearnableOptimizer` class.
+2. We adapt per-parameter learning rates instead of updating a single learning rate shared by all parameters.
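A rough sketch (not part of the committed file) of what the second departure looks like in code, reusing the `LearnableOptimizer` and `l2l.nn.Scale` pattern from this commit's README; the actual model and training loop in `hypergrad_mnist.py` may differ:

~~~python
import torch
from torch import nn
import torch.nn.functional as F
import learn2learn as l2l

# Rough sketch: per-parameter learning rates via a learnable Scale transform.
model = nn.Linear(784, 10)                           # placeholder MNIST model
transform = l2l.optim.ModuleTransform(l2l.nn.Scale)  # a learnable scale per parameter's gradient
metaopt = l2l.optim.LearnableOptimizer(model, transform, lr=0.01)  # steps the model
opt = torch.optim.SGD(metaopt.parameters(), lr=0.001)              # steps the learnable scales

for X, y in dataloader:  # placeholder MNIST loader
    metaopt.zero_grad()
    opt.zero_grad()
    error = F.cross_entropy(model(X.view(X.size(0), -1)), y)
    error.backward()
    opt.step()      # hypergradient update of the per-parameter rates
    metaopt.step()  # parameter update using the learned rates
~~~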
+
+**Usage**
+
+!!! warning
+    The parameters for this script were not carefully tuned.
+
+Manually edit the script and run:
+
+~~~shell
+python examples/optimization/hypergrad_mnist.py
+~~~
