learn2learn.data
Meta-Datasets
MetaDataset (Dataset)
Description
Wraps a classification dataset to enable fast indexing of samples within classes.
This class exposes two attributes specific to the wrapped dataset:
- labels_to_indices: maps a class label to a list of sample indices with that label.
- indices_to_labels: maps a sample index to its corresponding class label.
These dictionary attributes are often used to quickly create few-shot classification tasks.
They can be passed as arguments upon instantiation, or built automatically on the fly.
If the wrapped dataset has a _bookkeeping_path attribute, the built attributes are cached on disk and reloaded upon the next instantiation.
This caching strategy is useful for large datasets (e.g. ImageNet-1k) where the first instantiation can take several hours.
Note that if only one of labels_to_indices or indices_to_labels is provided, this class builds the other one from it.
Arguments
- dataset (Dataset) - A torch Dataset.
- labels_to_indices (dict, optional, default=None) - A dictionary mapping labels to the indices of their samples.
- indices_to_labels (dict, optional, default=None) - A dictionary mapping sample indices to their corresponding label.
Example
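The sketch below is not the library's implementation. It is a minimal pure-Python illustration of the two bookkeeping dictionaries described above, and of how one can be rebuilt from the other. The toy dataset and the helper names `build_bookkeeping` and `invert_indices_to_labels` are assumptions made for this illustration.

```python
from collections import defaultdict

# Hypothetical toy dataset: a list of (sample, label) pairs.
toy_dataset = [("a", 0), ("b", 1), ("c", 0), ("d", 2), ("e", 1)]


def build_bookkeeping(dataset):
    """Build both mappings in a single pass over the dataset."""
    labels_to_indices = defaultdict(list)
    indices_to_labels = {}
    for index, (_, label) in enumerate(dataset):
        labels_to_indices[label].append(index)
        indices_to_labels[index] = label
    return dict(labels_to_indices), indices_to_labels


def invert_indices_to_labels(indices_to_labels):
    """Rebuild labels_to_indices when only indices_to_labels is available."""
    labels_to_indices = defaultdict(list)
    for index, label in indices_to_labels.items():
        labels_to_indices[label].append(index)
    return dict(labels_to_indices)


labels_to_indices, indices_to_labels = build_bookkeeping(toy_dataset)
print(labels_to_indices)  # {0: [0, 2], 1: [1, 4], 2: [3]}
print(indices_to_labels)  # {0: 0, 1: 1, 2: 0, 3: 2, 4: 1}
```

Building the mappings requires reading every label once, which is why caching them pays off for large datasets.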
UnionMetaDataset (MetaDataset)
Description
Takes multiple MetaDatasets and constructs their union.
Note: The labels of all datasets are remapped to be in consecutive order (i.e. the same label in two datasets will be mapped to two different labels in the union).
Arguments
- datasets (list of Dataset) - A list of torch Datasets.
Example
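To make the remapping note concrete, here is a hedged pure-Python sketch of one way labels from several datasets could be assigned fresh consecutive labels. The function `union_label_maps` and the order in which labels are assigned are assumptions of this illustration; the library may order labels differently.

```python
def union_label_maps(label_sets):
    """Assign a fresh consecutive label to every (dataset_id, label) pair."""
    remap = {}
    next_label = 0
    for dataset_id, labels in enumerate(label_sets):
        for label in sorted(labels):
            remap[(dataset_id, label)] = next_label
            next_label += 1
    return remap


# Two hypothetical datasets that both use the labels 0 and 1.
remap = union_label_maps([{0, 1}, {0, 1, 2}])
print(remap[(0, 0)], remap[(1, 0)])  # 0 2
print(len(remap))  # 5
```

Label 0 of the first dataset and label 0 of the second end up as distinct classes, so the union has as many classes as the datasets have in total.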
FilteredMetaDataset (MetaDataset)
Description
Takes in a MetaDataset and filters it to only include a subset of its labels.
Note: The labels are not remapped (i.e. the labels from the original dataset are retained).
Arguments
- dataset (Dataset) - A torch Dataset.
- labels (list of ints) - A list of labels to keep.
Example
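As an illustration only, the following pure-Python sketch filters a labels_to_indices dictionary down to a subset of labels while leaving the label values untouched. The helper `filter_bookkeeping` is hypothetical, and the sketch keeps the original sample indices for simplicity; how the library indexes the remaining samples is not covered here.

```python
def filter_bookkeeping(labels_to_indices, keep_labels):
    """Keep only the requested labels; label values are not remapped."""
    keep = set(keep_labels)
    filtered_l2i = {
        label: list(indices)
        for label, indices in labels_to_indices.items()
        if label in keep
    }
    filtered_i2l = {
        index: label
        for label, indices in filtered_l2i.items()
        for index in indices
    }
    return filtered_l2i, filtered_i2l


labels_to_indices = {0: [0, 2], 1: [1, 4], 2: [3]}
kept_l2i, kept_i2l = filter_bookkeeping(labels_to_indices, [0, 2])
print(kept_l2i)  # {0: [0, 2], 2: [3]}
print(kept_i2l)  # {0: 0, 2: 0, 3: 2}
```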
TaskDataset (CythonTaskDataset)
Description
Creates a set of tasks from a given Dataset.
In addition to the Dataset, TaskDataset accepts a list of task transformations (task_transforms) which define the kind of tasks sampled from the dataset.
The tasks are lazily sampled upon indexing (or calling the .sample() method), and their descriptions cached for later use.
If num_tasks is -1, the TaskDataset will not cache task descriptions and instead continuously resample new ones.
In this case, the length of the TaskDataset is set to 1.
For more information on tasks and task descriptions, please refer to the documentation of task transforms.
Arguments
- dataset (Dataset) - Dataset from which tasks are generated.
- task_transforms (list, optional, default=None) - List of task transformations.
- num_tasks (int, optional, default=-1) - Number of tasks to generate.
Example
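The caching behaviour described above can be sketched with a small stand-in class. This is not the library's TaskDataset: `LazyTaskCache` and `sample_description` are hypothetical names, and the "task description" here is just a list holding a counter value so that resampling is easy to observe.

```python
import itertools


class LazyTaskCache:
    """Toy stand-in for lazy sampling with cached task descriptions."""

    def __init__(self, sample_description, num_tasks=-1):
        self.sample_description = sample_description
        self.num_tasks = num_tasks
        self.cache = {}

    def __len__(self):
        # With num_tasks == -1 nothing is cached and the length is reported as 1.
        return 1 if self.num_tasks == -1 else self.num_tasks

    def __getitem__(self, i):
        if self.num_tasks == -1:
            return self.sample_description()  # always resample
        if i not in self.cache:
            self.cache[i] = self.sample_description()  # sample lazily, then reuse
        return self.cache[i]


counter = itertools.count()


def sample_description():
    # Hypothetical sampler: each call yields a new, distinct description.
    return [next(counter)]


cached = LazyTaskCache(sample_description, num_tasks=3)
fresh = LazyTaskCache(sample_description, num_tasks=-1)

print(cached[0] == cached[0])   # True: the description for index 0 is reused
print(fresh[0] == fresh[0])     # False: every access draws a new description
print(len(cached), len(fresh))  # 3 1
```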
learn2learn.data.transforms
Transforms to help automatically generate tasks.
LoadData (TaskTransform)
Description
Loads a sample from the dataset given its index.
Arguments
- dataset (Dataset) - The dataset from which to load the sample.
NWays (CythonNWays)
Description
Keeps samples from N random labels present in the task description.
Arguments
- dataset (Dataset) - The dataset from which to load the sample.
- n (int, optional, default=2) - Number of labels to sample from the task description's labels.
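A rough sketch of the N-ways idea follows, assuming a simplified task description that is just a list of sample indices plus an indices_to_labels lookup. The function `n_ways` and its `rng` parameter are hypothetical; the library's transform operates on its own task description objects.

```python
import random


def n_ways(task_description, indices_to_labels, n=2, rng=random):
    """Keep only the samples whose label is among n randomly chosen labels."""
    present = sorted({indices_to_labels[i] for i in task_description})
    chosen = set(rng.sample(present, k=n))
    return [i for i in task_description if indices_to_labels[i] in chosen]


# Hypothetical bookkeeping: three labels with two samples each.
indices_to_labels = {0: 0, 1: 1, 2: 0, 3: 2, 4: 1, 5: 2}
task = n_ways(list(indices_to_labels), indices_to_labels, n=2, rng=random.Random(0))
kept_labels = {indices_to_labels[i] for i in task}
print(len(kept_labels))  # 2
print(len(task))         # 4
```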
KShots (CythonKShots)
Description
Keeps K samples for each label present in the task description.
Arguments
- dataset (Dataset) - The dataset from which to load the sample.
- k (int, optional, default=1) - The number of samples per label.
- replacement (bool, optional, default=False) - Whether to sample with replacement.
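The effect of the k and replacement arguments can be illustrated with the following sketch, which again assumes a task description that is a plain list of sample indices. `k_shots` is a hypothetical helper, not the library's transform.

```python
import random
from collections import defaultdict


def k_shots(task_description, indices_to_labels, k=1, replacement=False, rng=random):
    """Keep k samples per label, optionally sampling with replacement."""
    by_label = defaultdict(list)
    for i in task_description:
        by_label[indices_to_labels[i]].append(i)
    kept = []
    for indices in by_label.values():
        if replacement:
            kept.extend(rng.choices(indices, k=k))  # duplicates are possible
        else:
            kept.extend(rng.sample(indices, k=k))  # ValueError if fewer than k samples
    return kept


# Hypothetical bookkeeping: three labels with two samples each.
indices_to_labels = {0: 0, 1: 1, 2: 0, 3: 2, 4: 1, 5: 2}
all_indices = list(indices_to_labels)

one_shot = k_shots(all_indices, indices_to_labels, k=1, rng=random.Random(0))
print(len(one_shot))  # 3

# Three shots from classes holding only two samples requires replacement.
three_shot = k_shots(all_indices, indices_to_labels, k=3, replacement=True, rng=random.Random(0))
print(len(three_shot))  # 9
```

Without replacement, asking for more shots than a class contains fails in this sketch, which is the situation the replacement flag is meant to handle.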
FilterLabels (CythonFilterLabels)
Description
Removes samples that do not belong to the given set of labels.
Arguments
- dataset (Dataset) - The dataset from which to load the sample.
- labels (list) - The list of labels to include.
FusedNWaysKShots (CythonFusedNWaysKShots)
Description
Efficient implementation of FilterLabels, NWays, and KShots.
Arguments
- dataset (Dataset) - The dataset from which to load the sample.
- n (int, optional, default=2) - Number of labels to sample from the task description's labels.
- k (int, optional, default=1) - The number of samples per label.
- replacement (bool, optional, default=False) - Whether to sample shots with replacement.
- filter_labels (list, optional, default=None) - The list of labels to include. Defaults to all labels in the dataset.
RemapLabels (TaskTransform)
Description
Given samples from K classes, maps the labels to 0, ..., K-1.
Arguments
- dataset (Dataset) - The dataset from which to load the sample.
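For illustration, a minimal sketch of remapping arbitrary class labels onto a contiguous range starting at 0. The helper `remap_labels` is hypothetical, and assigning new labels in order of first appearance is an assumption of this sketch rather than documented library behaviour.

```python
def remap_labels(labels):
    """Map each distinct label to 0, 1, 2, ... in order of first appearance."""
    mapping = {}
    remapped = []
    for label in labels:
        if label not in mapping:
            mapping[label] = len(mapping)
        remapped.append(mapping[label])
    return remapped, mapping


remapped, mapping = remap_labels([17, 4, 17, 42, 4])
print(remapped)  # [0, 1, 0, 2, 1]
print(mapping)   # {17: 0, 4: 1, 42: 2}
```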
ConsecutiveLabels (TaskTransform)
Description
Re-orders the samples in the task description such that their labels are sorted in consecutive order.
Note: when used before RemapLabels, the labels will be homogeneously clustered, but in no specific order.
Arguments
- dataset (Dataset) - The dataset from which to load the sample.
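The reordering can be pictured with a stable sort on the label of each sample. The sketch below assumes a task description that is a plain list of sample indices; `consecutive_labels` is a hypothetical helper and not the library's transform.

```python
def consecutive_labels(task_description, indices_to_labels):
    """Reorder sample indices so that samples sharing a label are adjacent."""
    return sorted(task_description, key=lambda i: indices_to_labels[i])


# Hypothetical bookkeeping with non-contiguous labels.
indices_to_labels = {0: 7, 1: 3, 2: 7, 3: 5, 4: 3}
ordered = consecutive_labels([0, 1, 2, 3, 4], indices_to_labels)
print(ordered)                                  # [1, 4, 3, 0, 2]
print([indices_to_labels[i] for i in ordered])  # [3, 3, 5, 7, 7]
```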
learn2learn.data.utils
Helper functions to work with data and tasks.
OnDeviceDataset (TensorDataset)
Description
Converts an entire dataset into a TensorDataset, and optionally puts it on the desired device.
Useful to accelerate training with relatively small datasets. If the device is cpu and cuda is available, the TensorDataset will live in pinned memory.
Arguments
- dataset (Dataset) - Dataset to put on a device.
- device (torch.device, optional, default=None) - Device of dataset. Defaults to CPU.
- transform (transform, optional, default=None) - Transform to apply on the first variate of the dataset's samples X.
Example
InfiniteIterator
Description
Infinitely loops over a given iterator.
Arguments
- dataloader (iterator) - Iterator to loop over.
Example
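As a hedged illustration of the idea (not the library class), the sketch below restarts the underlying iterable whenever it is exhausted. `InfiniteLoop` is a hypothetical name, and a plain list stands in for a dataloader.

```python
class InfiniteLoop:
    """Toy iterator that restarts its source whenever it runs out."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # Start over from the beginning of the source.
            self.iterator = iter(self.iterable)
            return next(self.iterator)


batches = InfiniteLoop([1, 2, 3])
first_five = [next(batches) for _ in range(5)]
print(first_five)  # [1, 2, 3, 1, 2]
```

This pattern only works when the source can be iterated more than once, which is the case for a list or a dataloader but not for a one-shot generator.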
partition_task(data, labels, shots=1)
Description
Partitions a classification task into support and query sets.
The support set will contain shots samples per class, the query will take the remaining samples.
Assumes each class in labels is associated with the same number of samples in data.
Arguments
- data (Tensor) - Data to be partitioned into support and query.
- labels (Tensor) - Labels of each data sample, used for partitioning.
- shots (int, optional, default=1) - Number of data samples per class in the support set.
Example
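The support/query split can be sketched with plain Python lists in place of tensors. The function `partition` below is a hypothetical stand-in, and taking the first shots samples of each class for the support set is an assumption of this sketch; the source does not say which samples the library selects.

```python
from collections import defaultdict


def partition(data, labels, shots=1):
    """Split samples into (support, query), with `shots` samples per class in support."""
    seen = defaultdict(int)
    support, query = ([], []), ([], [])
    for x, y in zip(data, labels):
        target = support if seen[y] < shots else query
        target[0].append(x)
        target[1].append(y)
        seen[y] += 1
    return support, query


# Hypothetical task: two classes with three samples each.
data = ["a0", "b0", "a1", "b1", "a2", "b2"]
labels = [0, 1, 0, 1, 0, 1]
(support_x, support_y), (query_x, query_y) = partition(data, labels, shots=1)
print(support_x)  # ['a0', 'b0']
print(query_x)    # ['a1', 'b1', 'a2', 'b2']
```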