A small, lightweight package for continual learning in PyTorch.
For now, the package is hosted on TestPyPI. To install it, run:
pip install continual-flame
To use the package, import it in your project:
import contflame as cf
At the moment, the package contains only the dataset module, which provides datasets commonly used in continual learning. The main ones are:
- SplitMNIST - the MNIST dataset split by class. It lets you create different subtasks, each containing a custom subset of classes.
- PermutedMNIST - the permuted MNIST dataset. It lets you choose the shape of the applied permutation through the tile parameter.
- SplitCIFAR100
- PermutedCIFAR100
SplitMNIST
In the following example, training consists of five binary classification tasks on consecutive pairs of digits (i.e., task 1: (0, 1), task 2: (2, 3), ...).
from contflame.dataset import SplitMNIST

epochs = 10  # training epochs per task (choose your own value)
valid = []
for i in range(0, 10, 2):
    # Digits i and i + 1 form the current task; valid=0.2 is the validation fraction
    train_dataset = SplitMNIST(classes=[i, i + 1], dset='train', valid=0.2)
    valid.append(SplitMNIST(classes=[i, i + 1], dset='valid', valid=0.2))

    for epoch in range(epochs):
        # train the model on train_dataset
        ...

    for v in valid:
        # test the model on the current and all previous tasks
        ...
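The classes and valid arguments can be read as "keep only these labels, then hold out a fraction of them for validation". The following is a minimal pure-Python sketch of that idea, not the package's implementation: the helper split_by_classes, its seed argument, and the choice to shuffle before holding out the validation fraction are all assumptions made for illustration.

```python
import random
from typing import List, Sequence, Tuple


def split_by_classes(
    labels: Sequence[int],
    classes: Sequence[int],
    valid: float = 0.2,
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper (not part of contflame): return (train_idx, valid_idx)
    for the samples whose label is in `classes`, holding out a `valid` fraction."""
    wanted = set(classes)
    # Keep only the indices of samples that belong to the requested classes.
    idx = [i for i, y in enumerate(labels) if y in wanted]
    # Shuffle with a fixed seed so the train/valid split is reproducible.
    random.Random(seed).shuffle(idx)
    n_valid = int(len(idx) * valid)
    return idx[n_valid:], idx[:n_valid]


if __name__ == "__main__":
    # Toy labels standing in for MNIST targets: ten samples per digit.
    labels = [d for d in range(10) for _ in range(10)]
    # Five binary tasks on consecutive digit pairs: (0, 1), (2, 3), ...
    for i in range(0, 10, 2):
        train_idx, valid_idx = split_by_classes(labels, [i, i + 1], valid=0.2)
        print((i, i + 1), len(train_idx), len(valid_idx))
```

Because the validation indices are drawn from the same filtered pool as the training ones, each task's validation set only contains that task's classes, which is what lets the loop above measure forgetting task by task.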
PermutedMNIST
To get a random permutation, set tile to (1, 1). The permutation is selected by the task id, so the same permutation is applied to every data point of that task.
from contflame.dataset import PermutedMNIST

PermutedMNIST(tile=(1, 1), task=1)
PermutedMNIST(tile=(1, 28), task=1)
PermutedMNIST(tile=(8, 8), task=1)
To get the images without any permutation, set tile to (28, 28) (the default).
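One way to read the tile parameter is that the 28×28 image is cut into tiles of the given shape and the tiles are shuffled with a permutation seeded by the task id: (1, 1) then shuffles single pixels and (28, 28) leaves the image unchanged. The sketch below implements that reading in plain Python; it is an assumption about the semantics, not the package's code. The function name permute_tiles, the seeding via random.Random(task), and the requirement that the tile shape evenly divide the image (so a shape such as (8, 8) on 28×28 images is outside what this sketch handles) are all hypothetical.

```python
import random
from typing import List, Tuple

Image = List[List[int]]  # row-major grid of pixel values


def permute_tiles(img: Image, tile: Tuple[int, int], task: int) -> Image:
    """Hypothetical sketch: shuffle the (th, tw) tiles of `img` with a
    permutation that depends only on `task`."""
    h, w = len(img), len(img[0])
    th, tw = tile
    if h % th or w % tw:
        raise ValueError("this sketch requires the tile shape to divide the image shape")
    cols = w // tw
    n_tiles = (h // th) * cols
    # Same task id -> same seed -> same permutation for every image.
    perm = list(range(n_tiles))
    random.Random(task).shuffle(perm)
    out = [[0] * w for _ in range(h)]
    for dst, src in enumerate(perm):
        # Top-left corners of the source and destination tiles.
        sr, sc = (src // cols) * th, (src % cols) * tw
        dr, dc = (dst // cols) * th, (dst % cols) * tw
        for r in range(th):
            for c in range(tw):
                out[dr + r][dc + c] = img[sr + r][sc + c]
    return out


if __name__ == "__main__":
    img = [[r * 28 + c for c in range(28)] for r in range(28)]
    # A single 28x28 tile: the image comes back unchanged.
    print(permute_tiles(img, (28, 28), task=1) == img)
    # Row-shaped tiles (1, 28): whole rows are shuffled but each row stays intact.
    print(sorted(permute_tiles(img, (1, 28), task=1)) == img)
```

Seeding a private random.Random instance with the task id, rather than the global generator, keeps the permutation for a given task fixed regardless of what else in the program consumes random numbers.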