It wraps a torch dataset and builds a map from each target to the indices of the samples with that target. This comes in handy when we want to randomly sample elements for a particular label.
For l2l to work, it's important that the dataset returns a (data, target) tuple.
If your dataset doesn't return that, it should be trivial to wrap it
with another class that does.
TODO: Add an example of wrapping a non-standard dataset for l2l.
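A minimal sketch of such a wrapper, assuming a dataset whose samples are dicts rather than tuples. The class name `TupleDataset` and the keys `"image"` and `"label"` are hypothetical and chosen only for illustration; a real wrapper would typically subclass `torch.utils.data.Dataset`.

```python
class TupleDataset:
    """Hypothetical adapter that exposes (data, target) tuples."""

    def __init__(self, dataset, data_key="image", target_key="label"):
        self.dataset = dataset
        self.data_key = data_key      # assumed key holding the input
        self.target_key = target_key  # assumed key holding the label

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        # Re-pack the dict-style sample as the (data, target) tuple l2l expects.
        return sample[self.data_key], sample[self.target_key]


# Toy stand-in for a dataset that returns dicts instead of tuples.
raw = [{"image": [0.0, 0.1], "label": 3}, {"image": [0.5, 0.9], "label": 7}]
wrapped = TupleDataset(raw)
data, target = wrapped[1]
```

The wrapped object can then be handed to the meta-dataset wrapper in place of the original dataset.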
- dataset (Dataset) - A torch dataset.
- labels_to_indices (Dict, optional, default=None) - A dictionary mapping each label to the indices of its samples. If not specified, the mapping is built by looping through all the data points.
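To make the fallback behaviour concrete, the loop over all data points might look roughly like the sketch below. The helper name `build_labels_to_indices` is hypothetical, and the sketch assumes targets are hashable (for tensor targets you would first convert them, for example to Python ints).

```python
from collections import defaultdict


def build_labels_to_indices(dataset):
    """Map each target to the list of indices whose sample has that target."""
    labels_to_indices = defaultdict(list)
    for index in range(len(dataset)):
        _, target = dataset[index]  # relies on the (data, target) convention
        labels_to_indices[target].append(index)
    return dict(labels_to_indices)


# Toy dataset of (data, target) tuples.
toy = [("a", 0), ("b", 1), ("c", 0), ("d", 2)]
mapping = build_labels_to_indices(toy)
# mapping == {0: [0, 2], 1: [1], 2: [3]}
```

Because this pass touches every sample once, passing a precomputed dictionary can save time on large datasets.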
mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)
mnist = l2l.data.MetaDataset(mnist)
TaskGenerator(dataset, ways=2, shots=1, classes=None, tasks=None)
A wrapper to generate few-shot classification tasks.
tasks can either be a list of predefined tasks or just the number of tasks to sample.
If specified as an int, a list of that many tasks is generated, and we sample from it.
If specified as a list, tasks are always sampled from that list.
The list should have shape n * w, with n the number of tasks to sample and w the number of ways.
Each task should have w distinct elements, all of which must be drawn from the available classes.
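The constraints on a predefined task list can be checked with a few lines of Python. This is only a sketch of the rules as described above, not the library's own validation: the function `validate_tasks` is hypothetical, and it assumes each task's labels must come from the available classes.

```python
def validate_tasks(tasks, ways, classes):
    """Check that every task has `ways` distinct labels drawn from `classes`."""
    allowed = set(classes)
    for task in tasks:
        if len(task) != ways:
            raise ValueError(f"task {task} does not have {ways} elements")
        if len(set(task)) != ways:
            raise ValueError(f"task {task} contains repeated labels")
        if not set(task) <= allowed:
            raise ValueError(f"task {task} uses labels outside of classes")
    return True


# n = 2 tasks, w = 2 ways, sampled from four classes.
ok = validate_tasks([[0, 1], [2, 3]], ways=2, classes=[0, 1, 2, 3])
```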
- ways (int, optional, default=2) - Number of labels in each task.
- shots (int, optional, default=1) - Number of data points to sample per class.
- dataset (MetaDataset or Dataset) - The (meta-) dataset to wrap.
- classes (list, optional, default=None) - List of classes to sample from. If None, sample from all classes available in the dataset.
- tasks (int or list, optional, default=None) - Tasks to be generated. If None, all possible permutations of w ways chosen from the n classes are used.
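As an illustration of what "all possible permutations" means when tasks is left unset, the sketch below enumerates every ordered selection of w labels from n classes with `itertools.permutations`. It is an assumption that this matches the library's internal enumeration; the example only shows how quickly the number of tasks grows.

```python
from itertools import permutations

classes = [0, 1, 2, 3]  # n = 4 classes
ways = 2                # w = 2 labels per task

# Every ordered choice of `ways` distinct labels from `classes`.
all_tasks = [list(task) for task in permutations(classes, ways)]
# There are n! / (n - w)! of them: 4 * 3 = 12 here.
```

For datasets with many classes this count becomes very large, which is one reason to pass an int or an explicit list instead.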
Returns a dataset and the labels that we have sampled.
The dataset is of length shots * ways.
The length of labels we have sampled is the same as
- shots (int, optional, default=None) - Number of data points to return per class. If None, self.shots is used.
- task (list, optional, default=None) - List of labels you want to sample from.
- Dataset - Containing the sampled task.
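The relationship between shots, ways, and the size of the returned dataset can be sketched in plain Python. The helper `sample_task` below is hypothetical and assumes sampling without replacement within each class; it is meant to show why the result has shots * ways elements, not to reproduce the library's implementation.

```python
import random


def sample_task(dataset, labels_to_indices, task, shots, rng=random):
    """Draw `shots` samples for each label in `task`."""
    samples = []
    for label in task:
        # Pick `shots` distinct indices belonging to this label.
        indices = rng.sample(labels_to_indices[label], shots)
        samples.extend(dataset[i] for i in indices)
    return samples


# Toy dataset: 12 (data, target) tuples spread over 3 classes.
toy = [(f"x{i}", i % 3) for i in range(12)]
labels_to_indices = {0: [0, 3, 6, 9], 1: [1, 4, 7, 10], 2: [2, 5, 8, 11]}

# A 2-way, 3-shot task over labels 0 and 2 yields 3 * 2 = 6 samples.
task_data = sample_task(toy, labels_to_indices, task=[0, 2], shots=3)
```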