learn2learn.data

A set of utilities for data and task loading, preprocessing, and sampling.

MetaDataset

MetaDataset(*args, **kwds)

Description

Wraps a classification dataset to enable fast indexing of samples within classes.

This class exposes two attributes specific to the wrapped dataset:

  • labels_to_indices: maps a class label to a list of sample indices with that label.
  • indices_to_labels: maps a sample index to its corresponding class label.

Those dictionary attributes are often used to quickly create few-shot classification tasks. They can be passed as arguments upon instantiation, or automatically built on-the-fly. If the wrapped dataset has an attribute _bookkeeping_path, then the built attributes will be cached on disk and reloaded upon the next instantiation. This caching strategy is useful for large datasets (e.g. ImageNet-1k) where the first instantiation can take several hours.

Note that if only one of labels_to_indices or indices_to_labels is provided, this class builds the other one from it.
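The relationship between the two dictionaries can be sketched in plain Python (illustrative code, not the library's implementation):

```python
from collections import defaultdict

# Two of the five samples belong to class 0, the other three to class 1.
labels_to_indices = {0: [0, 2], 1: [1, 3, 4]}

# Derive indices_to_labels by inverting the mapping.
indices_to_labels = {
    index: label
    for label, indices in labels_to_indices.items()
    for index in indices
}

# The reverse direction: rebuild labels_to_indices from indices_to_labels.
rebuilt = defaultdict(list)
for index, label in sorted(indices_to_labels.items()):
    rebuilt[label].append(index)

assert dict(rebuilt) == labels_to_indices
```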

Arguments

  • dataset (Dataset) - A torch Dataset.
  • labels_to_indices (dict, optional, default=None) - A dictionary mapping labels to the indices of their samples.
  • indices_to_labels (dict, optional, default=None) - A dictionary mapping sample indices to their corresponding label.

Example

mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)
mnist = l2l.data.MetaDataset(mnist)

UnionMetaDataset

UnionMetaDataset(*args, **kwds)

Description

Takes multiple MetaDatasets and constructs their union.

Note: The labels of all datasets are remapped to be in consecutive order. (i.e., the same label in two datasets will be mapped to two different labels in the union.)
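The consecutive remapping can be sketched with plain dictionaries (an illustrative sketch, not the library's implementation):

```python
# Two toy datasets that both use labels 0 and 1 internally.
dataset_a_labels = [0, 0, 1]
dataset_b_labels = [0, 1, 1]

union_labels = []
offset = 0
for labels in (dataset_a_labels, dataset_b_labels):
    # Each dataset's labels are renumbered past the previous datasets',
    # so label 0 of the second dataset becomes 2 in the union.
    remap = {label: offset + i for i, label in enumerate(sorted(set(labels)))}
    union_labels.extend(remap[label] for label in labels)
    offset += len(remap)

assert union_labels == [0, 0, 1, 2, 3, 3]
```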

Arguments

  • datasets (list of Dataset) - A list of torch Datasets.

Example

train = l2l.vision.datasets.CIFARFS(root="/tmp/cifarfs", mode="train")
train = l2l.data.MetaDataset(train)
valid = l2l.vision.datasets.CIFARFS(root="/tmp/cifarfs", mode="validation")
valid = l2l.data.MetaDataset(valid)
test = l2l.vision.datasets.CIFARFS(root="/tmp/cifarfs", mode="test")
test = l2l.data.MetaDataset(test)
union = UnionMetaDataset([train, valid, test])
assert len(union.labels) == 100

FilteredMetaDataset

FilteredMetaDataset(*args, **kwds)

Description

Takes in a MetaDataset and filters it to only include a subset of its labels.

Note: The labels are not remapped. (i.e., the labels from the original dataset are retained.)

Arguments

  • dataset (Dataset) - A torch Dataset.
  • labels (list of ints) - A list of labels to keep.

Example

train = l2l.vision.datasets.CIFARFS(root="/tmp/cifarfs", mode="train")
train = l2l.data.MetaDataset(train)
filtered = FilteredMetaDataset(train, [4, 8, 2, 1, 9])
assert len(filtered.labels) == 5

TaskDataset

TaskDataset(dataset, task_transforms=None, num_tasks=-1, task_collate=None)


Description

Creates a set of tasks from a given Dataset.

In addition to the Dataset, TaskDataset accepts a list of task transformations (task_transforms) which define the kind of tasks sampled from the dataset.

The tasks are lazily sampled upon indexing (or calling the .sample() method), and their descriptions are cached for later use. If num_tasks is -1, the TaskDataset will not cache task descriptions and will instead continuously resample new ones. In this case, the length of the TaskDataset is set to 1.

For more information on tasks and task descriptions, please refer to the documentation of task transforms.
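The caching behaviour described above can be sketched as follows; ToyTaskDataset and make_description are hypothetical stand-ins, not the library's API:

```python
import random

def make_description():
    # Hypothetical stand-in for sampling a fresh task description.
    return random.sample(range(100), 5)

class ToyTaskDataset:
    # Caches descriptions when num_tasks > 0; resamples when num_tasks == -1.
    def __init__(self, num_tasks=-1):
        self.num_tasks = num_tasks
        self.cache = {}

    def __len__(self):
        # With num_tasks == -1, tasks are resampled forever, so length is 1.
        return 1 if self.num_tasks == -1 else self.num_tasks

    def __getitem__(self, i):
        if self.num_tasks == -1:
            return make_description()           # never cached
        if i not in self.cache:                 # sampled lazily, then cached
            self.cache[i] = make_description()
        return self.cache[i]

fixed = ToyTaskDataset(num_tasks=20)
assert fixed[3] == fixed[3]       # cached: same description both times
assert len(ToyTaskDataset(num_tasks=-1)) == 1
```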

Arguments

  • dataset (Dataset) - Dataset from which to sample tasks.
  • task_transforms (list, optional, default=None) - List of task transformations.
  • num_tasks (int, optional, default=-1) - Number of tasks to generate.
  • task_collate (callable, optional, default=None) - Collate function applied to the samples of a task.

Example

dataset = l2l.data.MetaDataset(MyDataset())
transforms = [
    l2l.data.transforms.NWays(dataset, n=5),
    l2l.data.transforms.KShots(dataset, k=1),
    l2l.data.transforms.LoadData(dataset),
]
taskset = TaskDataset(dataset, transforms, num_tasks=20000)
for task in taskset:
    X, y = task

learn2learn.data.transforms

Description

Collection of general task transformations.

A task transformation is an object that implements the callable interface (either a function, or an object that implements the __call__ special method). Each transformation is called on a task description, which consists of a list of DataDescription objects with attributes index and transforms, where index corresponds to the index of a single data sample in the dataset, and transforms is a list of transformations that will be applied to that sample. Each task transformation must return a new task description.

At first, the task description contains all samples from the dataset. A task transform takes this task description list and modifies it such that a particular task is created. For example, the NWays task transform filters data samples from the task description such that the remaining ones belong to a random subset of all available classes. (The size of the subset is controlled via the class's n argument.) The LoadData task transform, on the other hand, simply appends a call to load the actual data from the dataset to each sample's list of transformations.

To create a task from a task description, the TaskDataset applies each sample's list of transforms in order. Then, all samples are collated via the TaskDataset's collate function.
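The pipeline above can be sketched in plain Python. The n_ways and k_shots helpers below are illustrative stand-ins for the NWays and KShots transforms, operating on a task description reduced to a bare list of sample indices:

```python
import random

# A toy dataset of ten samples whose labels cycle through five classes.
labels = [i % 5 for i in range(10)]

def n_ways(description, n):
    # Illustrative stand-in for NWays: keep samples from n random labels.
    kept_classes = set(random.sample(sorted({labels[i] for i in description}), n))
    return [i for i in description if labels[i] in kept_classes]

def k_shots(description, k):
    # Illustrative stand-in for KShots: keep k samples per remaining label.
    counts, kept = {}, []
    for i in description:
        if counts.get(labels[i], 0) < k:
            kept.append(i)
            counts[labels[i]] = counts.get(labels[i], 0) + 1
    return kept

# Initially the task description contains every sample; each transform
# returns a new, narrower description.
description = list(range(len(labels)))
description = n_ways(description, n=2)
description = k_shots(description, k=1)
assert len(description) == 2   # 2 ways x 1 shot
```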

LoadData

LoadData(dataset)


Description

Loads a sample from the dataset given its index.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.

NWays

NWays(dataset, n=2)


Description

Keeps samples from N random labels present in the task description.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • n (int, optional, default=2) - Number of labels to sample from the task description's labels.

KShots

KShots(dataset, k=1, replacement=False)


Description

Keeps K samples for each label present in the task description.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • k (int, optional, default=1) - The number of samples per label.
  • replacement (bool, optional, default=False) - Whether to sample with replacement.
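The effect of the replacement flag can be illustrated with the standard library's two samplers (a sketch; KShots itself operates on task descriptions rather than raw index lists):

```python
import random

# Indices of the samples belonging to one label.
indices = [10, 11, 12]

without = random.sample(indices, k=2)     # no repeats; k must not exceed len(indices)
with_repl = random.choices(indices, k=5)  # repeats allowed; k may exceed len(indices)

assert len(set(without)) == 2
assert all(i in indices for i in with_repl)
```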

FilterLabels

FilterLabels(dataset, labels)


Description

Removes samples that do not belong to the given set of labels.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • labels (list) - The list of labels to include.

FusedNWaysKShots

FusedNWaysKShots(dataset, n=2, k=1, replacement=False, filter_labels=None)


Description

Efficient implementation of FilterLabels, NWays, and KShots.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • n (int, optional, default=2) - Number of labels to sample from the task description's labels.
  • k (int, optional, default=1) - The number of samples per label.
  • replacement (bool, optional, default=False) - Whether to sample shots with replacement.
  • filter_labels (list, optional, default=None) - The list of labels to include. Defaults to all labels in the dataset.

RemapLabels

RemapLabels(dataset, shuffle=True)


Description

Given samples from K classes, maps their labels to 0, ..., K-1.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.
  • shuffle (bool, optional, default=True) - Whether to assign the new labels in random order.
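A minimal sketch of the remapping, without shuffling (illustrative code, not the library's implementation):

```python
# Samples from K=3 classes with arbitrary labels.
labels = [17, 3, 17, 42]

# Map each distinct label to 0, ..., K-1 in order of first appearance
# (RemapLabels shuffles this assignment by default).
mapping = {label: i for i, label in enumerate(dict.fromkeys(labels))}
remapped = [mapping[label] for label in labels]

assert remapped == [0, 1, 0, 2]
```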

ConsecutiveLabels

ConsecutiveLabels(dataset)


Description

Re-orders the samples in the task description such that samples with the same label appear consecutively.

Note: when used before RemapLabels, the labels will be homogeneously clustered, but in no specific order.

Arguments

  • dataset (Dataset) - The dataset from which to load the sample.