
Commit 6ff649d

Fix clone_module and MAML for RNN modules. (#140)
* Fix clone_module and MAML for RNN modules.
* Version bump.
1 parent a5c1ef2 commit 6ff649d

File tree

* CHANGELOG.md
* learn2learn/_version.py
* learn2learn/algorithms/maml.py
* learn2learn/utils.py
* tests/unit/utils_test.py

5 files changed: +75 -2 lines

CHANGELOG.md (+11)

@@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ### Added
 
+### Changed
+
+### Fixed
+
+
+## v0.1.1
+
+### Added
+
 * New tutorial: 'Feature Reuse with ANIL'. (@ewinapun)
 
 ### Changed
@@ -18,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ### Fixed
 
+* `MAML()` and `clone_module` support for RNN modules.
+
 
 ## v0.1.0.1
 

learn2learn/_version.py (+1, -1)

@@ -1 +1 @@
-__version__ = '0.1.0.1'
+__version__ = '0.1.1'

learn2learn/algorithms/maml.py (+5)

@@ -61,6 +61,11 @@ def maml_update(model, lr, grads=None):
         model._modules[module_key] = maml_update(model._modules[module_key],
                                                  lr=lr,
                                                  grads=None)
+
+    # Finally, rebuild the flattened parameters for RNNs
+    # See this issue for more details:
+    # https://github.com/learnables/learn2learn/issues/139
+    model._apply(lambda x: x)
     return model
 
 
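Why the identity `_apply` call helps: PyTorch's recurrent modules keep an internal cache of their weight tensors (named `_flat_weights` in recent releases) for the fused/cuDNN kernels, and maml_update swaps entries of `_parameters` without going through that cache. Below is a minimal sketch of the failure mode and the fix; the cache name and `_apply` behavior are PyTorch internals and may differ across versions.

import torch

rnn = torch.nn.RNN(input_size=2, hidden_size=1)

# maml_update-style update: replace a parameter entry with a plain,
# updated tensor instead of mutating the Parameter in place.
old = rnn._parameters['weight_hh_l0']
rnn._parameters['weight_hh_l0'] = old - 0.1 * torch.ones_like(old)

# The RNN's flattened-weight cache still references `old`, so a forward
# pass at this point can silently use the stale weights. An identity
# `_apply` changes no tensor values, but RNNBase overrides `_apply` to
# rebuild the cache, which is what the patched maml_update relies on.
rnn._apply(lambda x: x)

out, _ = rnn(torch.randn(5, 3, 2))  # computed with the updated weights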

learn2learn/utils.py (+6, -1)

@@ -85,7 +85,7 @@ def clone_module(module):
     # TODO: This function might require that module.forward()
     #       was called in order to work properly, if forward() instanciates
     #       new variables.
-    # TODO: deepcopy is expensive. We can probably get away with a shallowcopy.
+    # TODO: We can probably get away with a shallowcopy.
     #       However, since shallow copy does not recurse, we need to write a
     #       recursive version of shallow copy.
     # NOTE: This can probably be implemented more cleanly with
@@ -119,6 +119,11 @@ def clone_module(module):
     if hasattr(clone, '_modules'):
         for module_key in clone._modules:
             clone._modules[module_key] = clone_module(module._modules[module_key])
+
+    # Finally, rebuild the flattened parameters for RNNs
+    # See this issue for more details:
+    # https://github.com/learnables/learn2learn/issues/139
+    clone = clone._apply(lambda x: x)
     return clone
 
 
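The rewritten TODO above gestures at a recursive shallow copy rather than deepcopy. As orientation only, here is one hypothetical shape that idea could take; `shallow_clone_module` is illustrative and not learn2learn's API, and the real clone_module handles cases this sketch omits.

import copy

def shallow_clone_module(module):
    # Shallow-copy the module object itself; tensors are still shared.
    clone = copy.copy(module)
    # Re-clone parameters and buffers so the clone remains differentiable
    # with respect to the originals (the property MAML needs).
    clone._parameters = {
        k: (p.clone() if p is not None else None)
        for k, p in module._parameters.items()
    }
    clone._buffers = {
        k: (b.clone() if b is not None else None)
        for k, b in module._buffers.items()
    }
    # copy.copy does not recurse, so clone submodules explicitly.
    clone._modules = {
        k: (shallow_clone_module(m) if m is not None else None)
        for k, m in module._modules.items()
    }
    # Same RNN fix as in the patch above: rebuild flattened weights.
    clone._apply(lambda x: x)
    return clone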

tests/unit/utils_test.py (+52)

@@ -7,6 +7,13 @@
 
 
 def ref_clone_module(module):
+    """
+    Note: This implementation does not work for RNNs.
+    It requires calling learner.rnn._apply(lambda x: x) before
+    each forward call.
+    See this issue for more details:
+    https://github.com/learnables/learn2learn/issues/139
+    """
     # First, create a copy of the module.
     clone = copy.deepcopy(module)
 
@@ -120,6 +127,51 @@ def test_clone_module_models(self):
         for ref_p, l2l_p in zip(ref_model.parameters(), l2l_model.parameters()):
             self.assertTrue(torch.equal(ref_p, l2l_p))
 
+    def test_rnn_clone(self):
+        # Tests: https://github.com/learnables/learn2learn/issues/139
+        # The test is mainly about whether we can clone and adapt RNNs.
+        # See issue for details.
+        N_STEPS = 3
+        for rnn_class in [
+            torch.nn.RNN,
+            torch.nn.LSTM,
+            torch.nn.GRU,
+        ]:
+            torch.manual_seed(1234)
+            model = rnn_class(2, 1)
+            maml = l2l.algorithms.MAML(model, lr=1e-3, allow_unused=False)
+            optim = torch.optim.SGD(maml.parameters(), lr=0.001)
+            data = torch.randn(30, 500, 2)
+
+            # Adapt and measure loss
+            learner = maml.clone()
+            for step in range(N_STEPS):
+                pred, hidden = learner(data)
+                loss = pred.norm(p=2)
+                learner.adapt(loss)
+            pred, _ = learner(data)
+            first_loss = pred.norm(p=2)
+
+            # Take an optimization step
+            optim.zero_grad()
+            first_loss.backward()
+            optim.step()
+            first_loss = first_loss.item()
+
+            # Adapt a second time
+            learner = maml.clone()
+            for step in range(N_STEPS):
+                pred, hidden = learner(data)
+                loss = pred.norm(p=2)
+                learner.adapt(loss)
+            pred, _ = learner(data)
+            second_loss = pred.norm(p=2)
+            second_loss = second_loss.item()
+
+            # Ensure we did better
+            self.assertTrue(first_loss > second_loss)
+
+
     def test_module_detach(self):
         original_output = self.model(self.input)
         original_loss = self.loss_func(original_output, torch.tensor([[0., 0.]]))
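A usage sketch distilled from the test above (the shapes and the L2-norm "loss" are illustrative): after this commit, recurrent modules can be wrapped in MAML, cloned, and adapted like any feed-forward module.

import torch
import learn2learn as l2l

lstm = torch.nn.LSTM(input_size=2, hidden_size=1)
maml = l2l.algorithms.MAML(lstm, lr=1e-3)
learner = maml.clone()              # differentiable clone of the LSTM

data = torch.randn(30, 500, 2)      # (seq_len, batch, input_size)
pred, _ = learner(data)
learner.adapt(pred.norm(p=2))       # one inner-loop adaptation step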
