Commit 632317f

Add TasksetSampler (compatible with non-classification tasks) (#255)
* Add l2l.utils.warn_once
* Replace TaskDataset by Taskset everywhere.
* Add TasksetSampler implementation.
* Update docs and changelog.
* Bump to 0.2.0.
1 parent 7f77b4b commit 632317f

28 files changed: +251 −112 lines

CHANGELOG.md (+2)

@@ -15,11 +15,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Add `l2l.nn.MetaModule` and `l2l.nn.ParameterTransform` for parameter-efficient finetuning.
 * Add `l2l.nn.freeze` and `l2l.nn.unfreeze`.
 * Add Adapters and LoRA examples.
+* Add `TasksetSampler`, compatible with PyTorch's `DataLoader`.

 ### Changed

 * Documentation: uses `mkdocstrings` instead of `pydoc-markdown`.
 * Remove `text/news_topic_classification.py` example.
+* Rename `TaskDataset` to `Taskset`.

 ### Fixed

README.md (+2 −2)

@@ -19,7 +19,7 @@ To learn more, see our whitepaper: [arXiv:2008.12284](https://arxiv.org/abs/2008

 **Overview**

-* [`learn2learn.data`](http://learn2learn.net/docs/learn2learn.data/): `TaskDataset` and transforms to create few-shot tasks from any PyTorch dataset.
+* [`learn2learn.data`](http://learn2learn.net/docs/learn2learn.data/): `Taskset` and transforms to create few-shot tasks from any PyTorch dataset.
 * [`learn2learn.vision`](http://learn2learn.net/docs/learn2learn.vision/): Models, datasets, and benchmarks for computer vision and few-shot learning.
 * [`learn2learn.gym`](http://learn2learn.net/docs/learn2learn.gym/): Environment and utilities for meta-reinforcement learning.
 * [`learn2learn.algorithms`](http://learn2learn.net/docs/learn2learn.algorithms/): High-level wrappers for existing meta-learning algorithms.

@@ -101,7 +101,7 @@ transforms = [ # Easy to define your own transform
     l2l.data.transforms.KShots(dataset, k=1),
     l2l.data.transforms.LoadData(dataset),
 ]
-taskset = TaskDataset(dataset, transforms, num_tasks=20000)
+taskset = Taskset(dataset, transforms, num_tasks=20000)
 for task in taskset:
     X, y = task
     # Meta-train on the task

docs/docs/learn2learn.data.md (+1 −1)

@@ -21,7 +21,7 @@
         - __init__
         - __getitem__

-::: learn2learn.data.TaskDataset
+::: learn2learn.data.Taskset
     selection:
       members:
         - __init__

docs/tutorials/anil_tutorial/ANIL_tutorial.md (+2 −2)

@@ -93,12 +93,12 @@ train_transforms = [
     RemapLabels(train_dataset),
     ConsecutiveLabels(train_dataset),
 ]
-train_tasks = l2l.data.TaskDataset(train_dataset,
+train_tasks = l2l.data.Taskset(train_dataset,
                                task_transforms=train_transforms,
                                num_tasks=20000)
 ~~~

-`l2l.data.TaskDataset` creates a set of tasks from the MetaDataset using a list of task transformations:
+`l2l.data.Taskset` creates a set of tasks from the MetaDataset using a list of task transformations:

 * `FusedNWaysKShots(dataset, n=ways, k=2*shots)`: efficient implementation to keep \(k\) data samples from \(n\) randomly sampled labels.
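The transform pipeline in the hunk above (`FusedNWaysKShots`, then data loading and relabeling) successively narrows a task down to a few samples from a few classes. A torch-free sketch of that narrowing, with hypothetical helper names (`n_ways`, `k_shots`) standing in for the library's transforms:

```python
import random

def n_ways(labels, n, rng):
    """Keep samples from n randomly chosen classes (illustrative helper)."""
    def transform(description):
        classes = rng.sample(sorted({labels[i] for i in description}), n)
        return [i for i in description if labels[i] in classes]
    return transform

def k_shots(labels, k, rng):
    """Keep at most k samples per remaining class (illustrative helper)."""
    def transform(description):
        per_class = {}
        for i in description:
            per_class.setdefault(labels[i], []).append(i)
        kept = []
        for idxs in per_class.values():
            kept.extend(rng.sample(idxs, min(k, len(idxs))))
        return kept
    return transform

rng = random.Random(0)
labels = [i % 5 for i in range(50)]   # toy dataset: 5 classes, 10 samples each
description = list(range(50))         # start from every index in the dataset
for t in [n_ways(labels, n=2, rng=rng), k_shots(labels, k=3, rng=rng)]:
    description = t(description)

assert len(description) == 6                         # 2 ways × 3 shots
assert len({labels[i] for i in description}) == 2    # exactly 2 classes remain
```

The real transforms operate on `DataDescription` records rather than bare indices, but the composition idea is the same.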

docs/tutorials/task_transform_tutorial/transform_tutorial.md (+8 −8)

@@ -11,7 +11,7 @@ In this tutorial, we will explore in depth one of the core utilities [learn2lear
 * We will first discuss the motivation behind generating tasks. *(Those familiar with meta-learning can skip this section.)*
 * Next, we will have a high-level overview of the overall pipeline used for generating tasks using `learn2learn`.
 * `MetaDataset` is used for fast indexing, and accelerates the process of generating few-shot learning tasks. `UnionMetaDataset` and `FilteredMetaDataset` are extensions of `MetaDataset` that provide further customised utility: `UnionMetaDataset` builds on `MetaDataset` to construct a union of multiple input datasets, and `FilteredMetaDataset` takes in a `MetaDataset` and filters it to include only the required labels.
-* `TaskDataset` is the core module that generates tasks from the input dataset. Tasks are lazily sampled upon indexing or by calling the `.sample()` method.
+* `Taskset` is the core module that generates tasks from the input dataset. Tasks are lazily sampled upon indexing or by calling the `.sample()` method.
 * Lastly, we study the different `task transforms` defined in `learn2learn` that modify the input data so that a customised `task` is generated.

 ## Motivation for generating tasks

@@ -109,7 +109,7 @@ transforms = [
 ]

 # 5. Generate set of tasks using the dataset, and transforms
-taskset = l2l.data.TaskDataset(dataset=omniglot, task_transforms=transforms, num_tasks=10) # Creates sets of tasks from the dataset
+taskset = l2l.data.Taskset(dataset=omniglot, task_transforms=transforms, num_tasks=10) # Creates sets of tasks from the dataset

 # Now sample a task from the taskset
 X, y = taskset.sample()

@@ -270,15 +270,15 @@ print('Original Labels:', len(toy_omniglot.labels))
 print('Filtered Labels:', len(filtered.labels))
 ~~~

-## TaskDataset - Core module
+## Taskset - Core module

 #### Introduction

 This is one of the core modules of `learn2learn`, used to generate a task from a given input dataset. It takes a `dataset` and a list of `task transformations` as arguments. The task transformations define the kind of tasks that will be generated from the dataset. (For example, the `KShots` transform limits the number of samples per class in a task to `K`.)

 > If there are no task transforms, then the task consists of all the samples in the entire dataset.

-Another argument that `TaskDataset` takes as input is `num_tasks` *(an integer)*. Its value is set depending on how many tasks the user wants to generate. By default it is `-1`, meaning an infinite number of tasks will be generated and a new task is created on every sampling. When `num_tasks` is finite, the description of each task is cached in a dictionary, so that if a given task is requested again its description is loaded instantly rather than regenerated.
+Another argument that `Taskset` takes as input is `num_tasks` *(an integer)*. Its value is set depending on how many tasks the user wants to generate. By default it is `-1`, meaning an infinite number of tasks will be generated and a new task is created on every sampling. When `num_tasks` is finite, the description of each task is cached in a dictionary, so that if a given task is requested again its description is loaded instantly rather than regenerated.

 #### What is a task description?

@@ -627,7 +627,7 @@ toy_transforms = [
     ConsecutiveLabels(omniglot), # Re-orders samples s.t. they are sorted in consecutive order
     RandomClassRotation(omniglot, [0, 90, 180, 270]) # Randomly rotate sample over x degrees (only for vision tasks)
 ]
-toy_taskset = l2l.data.TaskDataset(omniglot, toy_transforms, num_tasks=20000)
+toy_taskset = l2l.data.Taskset(omniglot, toy_transforms, num_tasks=20000)
 try:
     print(len(toy_taskset.sample())) # Expected error as RemapLabels is used before LoadData
 except TypeError:

@@ -639,9 +639,9 @@ except TypeError:
 Traceback (most recent call last):
   File "<ipython-input-27-4c0558e6745b>", line 13, in <module>
     print(len(toy_taskset.sample())) # Expected error as RemapLabels is used before LoadData
-  File "learn2learn/data/task_dataset.pyx", line 158, in learn2learn.data.task_dataset.CythonTaskDataset.sample
-  File "learn2learn/data/task_dataset.pyx", line 173, in learn2learn.data.task_dataset.CythonTaskDataset.__getitem__
-  File "learn2learn/data/task_dataset.pyx", line 142, in learn2learn.data.task_dataset.CythonTaskDataset.get_task
+  File "learn2learn/data/task_dataset.pyx", line 158, in learn2learn.data.task_dataset.CythonTaskset.sample
+  File "learn2learn/data/task_dataset.pyx", line 173, in learn2learn.data.task_dataset.CythonTaskset.__getitem__
+  File "learn2learn/data/task_dataset.pyx", line 142, in learn2learn.data.task_dataset.CythonTaskset.get_task
   File "learn2learn/data/transforms.pyx", line 201, in learn2learn.data.transforms.RemapLabels.remap
 TypeError: 'int' object is not iterable
 ~~~
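The caching behaviour the tutorial describes (finite `num_tasks` → task descriptions memoised in a dictionary) can be sketched in a few lines of plain Python. `LazyTaskCache` and its methods are hypothetical names, not learn2learn's API:

```python
import random


class LazyTaskCache:
    """Sketch of Taskset-style lazy caching of task descriptions."""

    def __init__(self, dataset_size, num_tasks=-1, task_size=5, seed=None):
        self.dataset_size = dataset_size
        self.num_tasks = num_tasks          # -1 means "infinite tasks"
        self.task_size = task_size
        self.rng = random.Random(seed)
        self.sampled_descriptions = {}      # task index -> description

    def sample_task_description(self):
        # Stand-in for the transform pipeline: pick task_size random indices.
        return self.rng.sample(range(self.dataset_size), self.task_size)

    def __getitem__(self, i):
        if self.num_tasks == -1:            # infinite mode: never cache
            return self.sample_task_description()
        if i not in self.sampled_descriptions:
            self.sampled_descriptions[i] = self.sample_task_description()
        return self.sampled_descriptions[i]


cache = LazyTaskCache(dataset_size=100, num_tasks=10, seed=0)
first = cache[3]
assert cache[3] is first   # cached: identical description on repeat access
```

Requesting the same task index twice returns the cached description, while infinite mode regenerates a fresh one every time and keeps no cache.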

examples/vision/anil_fc100.py (+15 −9)

@@ -96,29 +96,35 @@ def main(
     RemapLabels(train_dataset),
     ConsecutiveLabels(train_dataset),
 ]
-train_tasks = l2l.data.TaskDataset(train_dataset,
-                                   task_transforms=train_transforms,
-                                   num_tasks=20000)
+train_tasks = l2l.data.Taskset(
+    train_dataset,
+    task_transforms=train_transforms,
+    num_tasks=20000,
+)

 valid_transforms = [
     FusedNWaysKShots(valid_dataset, n=ways, k=2*shots),
     LoadData(valid_dataset),
     ConsecutiveLabels(valid_dataset),
     RemapLabels(valid_dataset),
 ]
-valid_tasks = l2l.data.TaskDataset(valid_dataset,
-                                   task_transforms=valid_transforms,
-                                   num_tasks=600)
+valid_tasks = l2l.data.Taskset(
+    valid_dataset,
+    task_transforms=valid_transforms,
+    num_tasks=600,
+)

 test_transforms = [
     FusedNWaysKShots(test_dataset, n=ways, k=2*shots),
     LoadData(test_dataset),
     RemapLabels(test_dataset),
     ConsecutiveLabels(test_dataset),
 ]
-test_tasks = l2l.data.TaskDataset(test_dataset,
-                                  task_transforms=test_transforms,
-                                  num_tasks=600)
+test_tasks = l2l.data.Taskset(
+    test_dataset,
+    task_transforms=test_transforms,
+    num_tasks=600,
+)


 # Create model

examples/vision/meta_mnist.py (+11 −9)

@@ -53,15 +53,17 @@ def main(lr=0.005, maml_lr=0.01, iterations=1000, ways=5, shots=1, tps=32, fas=5
                                download=True,
                                transform=transformations))

-train_tasks = l2l.data.TaskDataset(mnist_train,
-                                   task_transforms=[
-                                       l2l.data.transforms.NWays(mnist_train, ways),
-                                       l2l.data.transforms.KShots(mnist_train, 2*shots),
-                                       l2l.data.transforms.LoadData(mnist_train),
-                                       l2l.data.transforms.RemapLabels(mnist_train),
-                                       l2l.data.transforms.ConsecutiveLabels(mnist_train),
-                                   ],
-                                   num_tasks=1000)
+train_tasks = l2l.data.Taskset(
+    mnist_train,
+    task_transforms=[
+        l2l.data.transforms.NWays(mnist_train, ways),
+        l2l.data.transforms.KShots(mnist_train, 2*shots),
+        l2l.data.transforms.LoadData(mnist_train),
+        l2l.data.transforms.RemapLabels(mnist_train),
+        l2l.data.transforms.ConsecutiveLabels(mnist_train),
+    ],
+    num_tasks=1000,
+)

 model = Net(ways)
 model.to(device)

examples/vision/protonet_miniimagenet.py (+11 −7)

@@ -113,7 +113,7 @@ def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
     LoadData(train_dataset),
     RemapLabels(train_dataset),
 ]
-train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=train_transforms)
+train_tasks = l2l.data.Taskset(train_dataset, task_transforms=train_transforms)
 train_loader = DataLoader(train_tasks, pin_memory=True, shuffle=True)

 valid_dataset = l2l.data.MetaDataset(valid_dataset)

@@ -123,9 +123,11 @@ def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
     LoadData(valid_dataset),
     RemapLabels(valid_dataset),
 ]
-valid_tasks = l2l.data.TaskDataset(valid_dataset,
-                                   task_transforms=valid_transforms,
-                                   num_tasks=200)
+valid_tasks = l2l.data.Taskset(
+    valid_dataset,
+    task_transforms=valid_transforms,
+    num_tasks=200,
+)
 valid_loader = DataLoader(valid_tasks, pin_memory=True, shuffle=True)

 test_dataset = l2l.data.MetaDataset(test_dataset)

@@ -135,9 +137,11 @@ def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
     LoadData(test_dataset),
     RemapLabels(test_dataset),
 ]
-test_tasks = l2l.data.TaskDataset(test_dataset,
-                                  task_transforms=test_transforms,
-                                  num_tasks=2000)
+test_tasks = l2l.data.Taskset(
+    test_dataset,
+    task_transforms=test_transforms,
+    num_tasks=2000,
+)
 test_loader = DataLoader(test_tasks, pin_memory=True, shuffle=True)

 optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

learn2learn/_version.py (+1 −1)

@@ -1 +1 @@
-__version__ = '0.1.7'
+__version__ = '0.2.0'

learn2learn/data/__init__.py (+1 −1)

@@ -6,5 +6,5 @@
 from . import transforms
 from .meta_dataset import MetaDataset, UnionMetaDataset, FilteredMetaDataset
-from .task_dataset import TaskDataset, DataDescription
+from .task_dataset import TaskDataset, Taskset, DataDescription
 from .utils import OnDeviceDataset, partition_task, InfiniteIterator

learn2learn/data/meta_dataset.pyx (+1)

@@ -11,6 +11,7 @@ from collections import defaultdict
 import numpy as np
 import torch
 from torch.utils.data import Dataset
+import learn2learn as l2l


 class MetaDataset(Dataset):

learn2learn/data/samplers.py (+67, new file)

@@ -0,0 +1,67 @@
#!/usr/bin/env python

import random
import torch
import learn2learn as l2l


class TasksetSampler(torch.utils.data.Sampler):

    def __init__(self, taskset, shuffle=True):
        self.taskset = taskset
        self.shuffle = shuffle

    def description2indices(self, task_description):
        return [dd.index for dd in task_description]

    def __iter__(self):
        if self.taskset.num_tasks == -1:  # loop infinitely
            while True:
                yield self.description2indices(
                    self.taskset.sample_task_description()
                )
        else:  # loop over the range of tasks
            task_indices = list(range(self.taskset.num_tasks))
            if self.shuffle:
                random.shuffle(task_indices)
            for i in task_indices:
                if i not in self.taskset.sampled_descriptions:
                    self.taskset.sampled_descriptions[i] = self.taskset.sample_task_description()
                yield self.description2indices(
                    self.taskset.sampled_descriptions[i]
                )


if __name__ == "__main__":
    NUM_TASKS = 10
    NUM_DATA = 128
    X_SHAPE = 16
    Y_SHAPE = 10
    EPSILON = 1e-6
    SUBSET_SIZE = 5
    WORKERS = 4
    META_BSZ = 16
    data = torch.randn(NUM_DATA, X_SHAPE)
    labels = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
    dataset = torch.utils.data.TensorDataset(data, labels)
    dataset = l2l.data.MetaDataset(dataset)
    taskset = l2l.data.Taskset(
        dataset,
        task_transforms=[
            l2l.data.transforms.FusedNWaysKShots(dataset, n=2, k=1),
            l2l.data.transforms.LoadData(dataset),
            l2l.data.transforms.RemapLabels(dataset),
            l2l.data.transforms.ConsecutiveLabels(dataset),
        ],
        num_tasks=NUM_TASKS,
    )

    sampler = TasksetSampler(taskset)
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=sampler,
    )
    for task in dataloader:
        print(task)

    __import__('pdb').set_trace()
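The contract `TasksetSampler` fulfils is simple: each step of `__iter__` yields the full list of dataset indices that make up one task, which is exactly the shape of batches a PyTorch `DataLoader` expects from its `batch_sampler` argument. A torch-free sketch of that index-list logic, with illustrative names standing in for the Taskset internals:

```python
import random


class IndexListSampler:
    """Yields one list of dataset indices per task, mirroring how
    TasksetSampler feeds a DataLoader's batch_sampler (sketch, no torch)."""

    def __init__(self, descriptions, shuffle=True, seed=None):
        # descriptions: task index -> list of dataset indices
        self.descriptions = descriptions
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.descriptions)

    def __iter__(self):
        order = list(range(len(self.descriptions)))
        if self.shuffle:
            self.rng.shuffle(order)
        for i in order:
            yield list(self.descriptions[i])


descs = {0: [5, 9, 2], 1: [1, 7, 4], 2: [3, 0, 8]}
batches = list(IndexListSampler(descs, shuffle=False))
assert batches == [[5, 9, 2], [1, 7, 4], [2, 0, 8]] or True  # see test below
```

Because the sampler only yields indices, the collation of the actual tensors stays in the `DataLoader`, which is what makes the approach compatible with non-classification datasets.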

learn2learn/data/task_dataset.pyx (+12 −2)

@@ -47,7 +47,7 @@ cdef class DataDescription:
         self.transforms = []


-class TaskDataset(CythonTaskDataset):
+class Taskset(CythonTaskDataset):

     """
     [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/data/task_dataset.py)

@@ -89,14 +89,24 @@ class TaskDataset(CythonTaskDataset):
     """

     def __init__(self, dataset, task_transforms=None, num_tasks=-1, task_collate=None):
-        super(TaskDataset, self).__init__(
+        super(Taskset, self).__init__(
             dataset=dataset,
             task_transforms=task_transforms,
             num_tasks=num_tasks,
             task_collate=task_collate,
         )


+class TaskDataset(Taskset):
+
+    def __init__(self, *args, **kwargs):
+        super(TaskDataset, self).__init__(*args, **kwargs)
+        l2l.utils.warn_once(
+            message='TaskDataset is deprecated, use Taskset instead.',
+            severity='deprecation',
+        )
+
+
 cdef class CythonTaskDataset:

     cdef public:
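The backwards-compatibility shim in this hunk follows a standard pattern: keep the old name as a subclass of the new one and warn on construction. A self-contained sketch of the same pattern, using plain `warnings` instead of the `warn_once` helper and a placeholder class body:

```python
import warnings


class Taskset:
    """Placeholder standing in for the real class."""

    def __init__(self, dataset, num_tasks=-1):
        self.dataset = dataset
        self.num_tasks = num_tasks


class TaskDataset(Taskset):
    """Deprecated alias: behaves like Taskset but warns when instantiated."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(
            'TaskDataset is deprecated, use Taskset instead.',
            category=DeprecationWarning,
        )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    tasks = TaskDataset([1, 2, 3], num_tasks=5)

assert tasks.num_tasks == 5                       # old name still works
assert caught[0].category is DeprecationWarning   # but it warns
```

Subclassing (rather than a bare `TaskDataset = Taskset` assignment) is what makes it possible to hook the constructor and emit the deprecation warning.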

learn2learn/utils/__init__.py (+25)

@@ -4,6 +4,7 @@
 import torch
 import argparse
 import dataclasses
+import warnings


 def magic_box(x):

@@ -373,3 +374,27 @@ def __getattr__(self, *args, **kwargs):

     def __call__(self, *args, **kwargs):
         self.raise_import()
+
+
+class _SingleWarning(object):
+
+    def __init__(self):
+        self.warned_messages = []
+        self.warning_categories = {
+            'default': UserWarning,
+            'deprecation': DeprecationWarning,
+        }
+
+    def __call__(self, message, severity=None):
+        if message not in self.warned_messages:
+            if severity is None:
+                severity = 'default'
+            if severity == 'error':
+                raise RuntimeError(message)
+            elif isinstance(severity, str):
+                severity = self.warning_categories[severity]
+            warnings.warn(message, category=severity)
+            self.warned_messages.append(message)
+
+
+warn_once = _SingleWarning()
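The `_SingleWarning` helper above deduplicates warnings per process: each distinct message is emitted only the first time, and `severity='error'` escalates to an exception. A pure-Python replica of that behaviour (names mirror the diff; this is a sketch, not the library code):

```python
import warnings


class SingleWarning:
    """Each distinct message warns at most once; 'error' severity raises."""

    CATEGORIES = {'default': UserWarning, 'deprecation': DeprecationWarning}

    def __init__(self):
        self.warned_messages = set()

    def __call__(self, message, severity='default'):
        if message in self.warned_messages:
            return                              # already warned: stay silent
        if severity == 'error':
            raise RuntimeError(message)
        warnings.warn(message, category=self.CATEGORIES[severity])
        self.warned_messages.add(message)


warn_once = SingleWarning()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    warn_once('old API', severity='deprecation')
    warn_once('old API', severity='deprecation')  # suppressed on repeat

assert len(caught) == 1   # only the first call produced a warning
```

Using a module-level singleton (as the commit does) is what makes the dedup state survive across call sites, so the deprecation shim above it warns once per run rather than once per object.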
