learn2learn.data

Meta-Datasets

MetaDataset (Dataset)

Description

Wraps a classification dataset to enable fast indexing of samples within classes.

This class exposes two attributes specific to the wrapped dataset:

  • labels_to_indices: maps a class label to a list of sample indices with that label.
  • indices_to_labels: maps a sample index to its corresponding class label.

Those dictionary attributes are often used to quickly create few-shot classification tasks. They can be passed as arguments upon instantiation, or automatically built on-the-fly. If the wrapped dataset has an attribute _bookkeeping_path, then the built attributes will be cached on disk and reloaded upon the next instantiation. This caching strategy is useful for large datasets (e.g. ImageNet-1k) where the first instantiation can take several hours.

Note that if only one of labels_to_indices or indices_to_labels is provided, this class builds the other one from it.

Arguments

  • dataset (Dataset) - A torch Dataset.
  • labels_to_indices (dict, optional, default=None) - A dictionary mapping labels to the indices of their samples.
  • indices_to_labels (dict, optional, default=None) - A dictionary mapping sample indices to their corresponding label.

Example

import torchvision
import learn2learn as l2l

mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True, download=True)
mnist = l2l.data.MetaDataset(mnist)
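
The bookkeeping attributes can then be queried directly; a minimal sketch continuing the example above:

# indices of all samples whose label is 3
threes = mnist.labels_to_indices[3]
# label of the sample at index 0
label = mnist.indices_to_labels[0]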

UnionMetaDataset (MetaDataset)

Description

Takes multiple MetaDatasets and constructs their union.

Note: The labels of all datasets are remapped to be in consecutive order. (i.e. the same label in two datasets will be mapped to two different labels in the union)

Arguments

  • datasets (list of Dataset) - A list of torch Datasets.

Example

import learn2learn as l2l

train = l2l.vision.datasets.CIFARFS(root="/tmp/cifarfs", mode="train", download=True)
train = l2l.data.MetaDataset(train)
valid = l2l.vision.datasets.CIFARFS(root="/tmp/cifarfs", mode="validation", download=True)
valid = l2l.data.MetaDataset(valid)
test = l2l.vision.datasets.CIFARFS(root="/tmp/cifarfs", mode="test", download=True)
test = l2l.data.MetaDataset(test)
union = l2l.data.UnionMetaDataset([train, valid, test])
assert len(union.labels) == 100

FilteredMetaDataset (MetaDataset)

Description

Takes in a MetaDataset and filters it to only include a subset of its labels.

Note: The labels are not remapped. (i.e. the labels from the original dataset are retained)

Arguments

  • dataset (MetaDataset) - A MetaDataset to filter.
  • labels (list of ints) - A list of labels to keep.

Example

import learn2learn as l2l

train = l2l.vision.datasets.CIFARFS(root="/tmp/cifarfs", mode="train", download=True)
train = l2l.data.MetaDataset(train)
filtered = l2l.data.FilteredMetaDataset(train, [4, 8, 2, 1, 9])
assert len(filtered.labels) == 5

TaskDataset (CythonTaskDataset)

[Source]

Description

Creates a set of tasks from a given Dataset.

In addition to the Dataset, TaskDataset accepts a list of task transformations (task_transforms) which define the kind of tasks sampled from the dataset.

The tasks are lazily sampled upon indexing (or calling the .sample() method), and their descriptions cached for later use. If num_tasks is -1, the TaskDataset will not cache task descriptions and instead continuously resample new ones. In this case, the length of the TaskDataset is set to 1.

For more information on tasks and task descriptions, please refer to the documentation of task transforms.

Arguments

  • dataset (Dataset) - Dataset from which tasks are sampled.
  • task_transforms (list, optional, default=None) - List of task transformations.
  • num_tasks (int, optional, default=-1) - Number of tasks to generate.

Example

import learn2learn as l2l

dataset = l2l.data.MetaDataset(MyDataset())  # MyDataset: any torch Dataset of (x, y) pairs
transforms = [
    l2l.data.transforms.NWays(dataset, n=5),
    l2l.data.transforms.KShots(dataset, k=1),
    l2l.data.transforms.LoadData(dataset),
]
taskset = l2l.data.TaskDataset(dataset, transforms, num_tasks=20000)
for task in taskset:
    X, y = task
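
With the default num_tasks=-1 (no caching), a new task is drawn on every access; a minimal sketch reusing dataset and transforms from above:

infinite_taskset = l2l.data.TaskDataset(dataset, transforms)  # num_tasks=-1
assert len(infinite_taskset) == 1
X, y = infinite_taskset.sample()  # a fresh task on every call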

learn2learn.data.transforms

Transforms to help automatically generate tasks.

LoadData (TaskTransform)

[Source]

Description

Loads a sample from the dataset given its index.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
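
Example

On its own, LoadData simply materializes every index in the task description. A minimal sketch on a hypothetical toy dataset (reused in the sketches below):

import torch
import learn2learn as l2l
from torch.utils.data import TensorDataset

# hypothetical toy dataset: 100 samples of dimension 4, 10 classes
toy = l2l.data.MetaDataset(TensorDataset(torch.randn(100, 4), torch.arange(100) % 10))
taskset = l2l.data.TaskDataset(toy, [l2l.data.transforms.LoadData(toy)])
X, y = taskset.sample()
assert X.shape == (100, 4)  # no NWays/KShots: the task contains every sample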

NWays (CythonNWays)

[Source]

Description

Keeps samples from N random labels present in the task description.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • n (int, optional, default=2) - Number of labels to sample from the task description's labels.
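
Example

A minimal sketch, reusing the hypothetical toy dataset from the LoadData example, checking that a sampled task covers exactly n labels:

transforms = [
    l2l.data.transforms.NWays(toy, n=5),
    l2l.data.transforms.LoadData(toy),
]
X, y = l2l.data.TaskDataset(toy, transforms).sample()
assert len(torch.unique(y)) == 5  # all samples of 5 randomly chosen classes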

KShots (CythonKShots)

[Source]

Description

Keeps K samples for each of the labels present in the task description.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • k (int, optional, default=1) - The number of samples per label.
  • replacement (bool, optional, default=False) - Whether to sample with replacement.
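
Example

Without NWays, KShots keeps k samples from every label in the task description; a minimal sketch, reusing the toy dataset from above:

transforms = [
    l2l.data.transforms.KShots(toy, k=2),
    l2l.data.transforms.LoadData(toy),
]
X, y = l2l.data.TaskDataset(toy, transforms).sample()
assert len(y) == 10 * 2  # 2 shots for each of the 10 labels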

FilterLabels (CythonFilterLabels)

[Source]

Description

Removes samples that do not belong to the given set of labels.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • labels (list) - The list of labels to include.
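
Example

A minimal sketch, reusing the toy dataset from above:

transforms = [
    l2l.data.transforms.FilterLabels(toy, [0, 1, 2]),
    l2l.data.transforms.LoadData(toy),
]
X, y = l2l.data.TaskDataset(toy, transforms).sample()
assert set(y.tolist()) == {0, 1, 2}  # only the requested labels survive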

FusedNWaysKShots (CythonFusedNWaysKShots)

[Source]

Description

Efficient implementation of FilterLabels, NWays, and KShots.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • n (int, optional, default=2) - Number of labels to sample from the task description's labels.
  • k (int, optional, default=1) - The number of samples per label.
  • replacement (bool, optional, default=False) - Whether to sample shots with replacement.
  • filter_labels (list, optional, default=None) - The list of labels to include. Defaults to all labels in the dataset.
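
Example

A minimal sketch, reusing the toy dataset from above; the single fused transform replaces a FilterLabels, NWays, and KShots chain:

transforms = [
    l2l.data.transforms.FusedNWaysKShots(toy, n=5, k=2, filter_labels=list(range(8))),
    l2l.data.transforms.LoadData(toy),
]
X, y = l2l.data.TaskDataset(toy, transforms).sample()
assert len(torch.unique(y)) == 5 and len(y) == 5 * 2  # 5 ways, 2 shots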

RemapLabels (TaskTransform)

[Source]

Description

Given samples from K classes, maps the labels to 0, ..., K-1.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
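
Example

RemapLabels is placed after LoadData, so the remapping applies to the loaded labels; a minimal sketch, reusing the toy dataset from above:

transforms = [
    l2l.data.transforms.NWays(toy, n=5),
    l2l.data.transforms.KShots(toy, k=2),
    l2l.data.transforms.LoadData(toy),
    l2l.data.transforms.RemapLabels(toy),
]
X, y = l2l.data.TaskDataset(toy, transforms).sample()
assert set(y.tolist()) == set(range(5))  # whatever the original labels were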

ConsecutiveLabels (TaskTransform)

[Source]

Description

Re-orders the samples in the task description such that samples sharing the same label appear consecutively.

Note: when used before RemapLabels, the labels will be homogeneously clustered, but in no specific order.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
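
Example

A minimal sketch, reusing the toy dataset from above; the check asserts contiguity rather than ascending order, since samples are grouped by their original labels:

transforms = [
    l2l.data.transforms.NWays(toy, n=3),
    l2l.data.transforms.KShots(toy, k=4),
    l2l.data.transforms.LoadData(toy),
    l2l.data.transforms.ConsecutiveLabels(toy),
]
X, y = l2l.data.TaskDataset(toy, transforms).sample()
ys = y.tolist()
# every label occupies one contiguous block of samples
assert all(ys[i + 1] == ys[i] or ys[i + 1] not in ys[:i + 1] for i in range(len(ys) - 1))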

learn2learn.data.utils

Helper functions to work with data and tasks.

OnDeviceDataset (TensorDataset)

[Source]

Description

Converts an entire dataset into a TensorDataset, and optionally puts it on the desired device.

Useful to accelerate training with relatively small datasets. If the device is cpu and cuda is available, the TensorDataset will live in pinned memory.

Arguments

  • dataset (Dataset) - Dataset to put on a device.
  • device (torch.device, optional, default=None) - Device of dataset. Defaults to CPU.
  • transform (transform, optional, default=None) - Transform to apply to the first element (X) of each sample.

Example

from torchvision import transforms
from torchvision.datasets import MNIST
from learn2learn.data import MetaDataset
from learn2learn.data.utils import OnDeviceDataset

mnist_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    lambda x: x.view(1, 28, 28),
])
mnist = MNIST('~/data', download=True)
mnist_ondevice = OnDeviceDataset(mnist, device='cuda', transform=mnist_transforms)
mnist_meta = MetaDataset(mnist_ondevice)

InfiniteIterator

[Source]

Description

Infinitely loops over a given iterator.

Arguments

  • dataloader (iterator) - Iterator to loop over.

Example

from torch.utils.data import DataLoader
from learn2learn.data.utils import InfiniteIterator

dataloader = DataLoader(dataset, shuffle=True, batch_size=32)
inf_dataloader = InfiniteIterator(dataloader)
for iteration in range(10000):  # guaranteed to reach 10,000 regardless of len(dataloader)
    X, y = next(inf_dataloader)

partition_task(data, labels, shots=1)

[Source]

Description

Partitions a classification task into support and query sets.

The support set will contain shots samples per class; the query set contains the remaining samples.

Assumes each class in labels is associated with the same number of samples in data.

Arguments

  • data (Tensor) - Data to be partitioned into support and query.
  • labels (Tensor) - Labels of each data sample, used for partitioning.
  • shots (int, optional, default=1) - Number of data samples per class in the support set.

Example

from learn2learn.data.utils import partition_task

X, y = taskset.sample()
(X_support, y_support), (X_query, y_query) = partition_task(X, y, shots=5)