A.k.a. Continual Learning, Lifelong Learning, Incremental Learning, etc.
from torch.utils.data import DataLoader

from continuum import ClassIncremental
from continuum.datasets import MNIST

# `split_train_val` is used in the loop below; its import location has
# varied between continuum releases, so try both known locations.
try:
    from continuum.tasks import split_train_val
except ImportError:
    from continuum import split_train_val

# Class-incremental scenario over MNIST: the first task covers
# `initial_increment` classes and every following task adds `increment`
# new classes.
clloader = ClassIncremental(
    MNIST("my/data/path", download=True),
    increment=1,
    initial_increment=5,
    train=True,  # build a separate loader with train=False for test data
)

print(f"Number of classes: {clloader.nb_classes}.")
print(f"Number of tasks: {clloader.nb_tasks}.")

# Each iteration yields the training data of one task, in task order.
for task_id, train_dataset in enumerate(clloader):
    # Hold out part of the current task's data for validation.
    train_dataset, val_dataset = split_train_val(train_dataset)
    train_loader = DataLoader(train_dataset)
    val_loader = DataLoader(val_dataset)

    # Do your cool stuff here