It wraps a torch dataset by creating a map from each target to the indices of the items that carry it. This comes in handy when we want to randomly sample elements with a particular label.
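A minimal sketch of the target-to-indices map, assuming the wrapped dataset is map-style (supports `len()` and integer indexing) and returns `(data, target)` tuples. The helper names `build_labels_to_indices` and `sample_index_for_label` are hypothetical illustrations, not the library's API.

```python
import random
from collections import defaultdict


def build_labels_to_indices(dataset):
    """Map each target to the list of dataset indices that carry it (hypothetical helper)."""
    labels_to_indices = defaultdict(list)
    for i in range(len(dataset)):
        _, target = dataset[i]  # assumes each item is a (data, target) tuple
        labels_to_indices[target].append(i)
    return dict(labels_to_indices)


def sample_index_for_label(labels_to_indices, label, rng=random):
    """Draw one random index among the items whose target is `label`."""
    return rng.choice(labels_to_indices[label])


toy = [("a", 0), ("b", 1), ("c", 0), ("d", 2)]
mapping = build_labels_to_indices(toy)
print(mapping)  # {0: [0, 2], 1: [1], 2: [3]}
print(sample_index_for_label(mapping, 0) in (0, 2))  # True
```

Building the map costs one pass over the dataset, after which drawing an example of a given label is a constant-time choice from a precomputed list.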

Notes: For l2l to work, it's important that the dataset returns a (data, target) tuple. If your dataset doesn't, it should be trivial to wrap it in another class that does. TODO: Add example for wrapping a non-standard l2l dataset.
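One possible shape for such a wrapper is sketched below, assuming the original dataset returns a dict with the hypothetical keys `"image"` and `"label"`. `DictDataset` and `TupleDataset` are illustrative names only. The wrapper implements the map-style protocol (`__getitem__` and `__len__`); in practice you would likely subclass `torch.utils.data.Dataset`, which is omitted here so the sketch runs without PyTorch.

```python
class DictDataset:
    """Stand-in for a dataset whose items are dicts rather than (data, target) tuples."""

    def __init__(self, items):
        self.items = items

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


class TupleDataset:
    """Hypothetical wrapper that turns each item into a (data, target) tuple."""

    def __init__(self, dataset, data_key="image", target_key="label"):
        self.dataset = dataset
        self.data_key = data_key
        self.target_key = target_key

    def __getitem__(self, index):
        item = self.dataset[index]
        # Reorder the dict fields into the (data, target) tuple l2l expects.
        return item[self.data_key], item[self.target_key]

    def __len__(self):
        return len(self.dataset)


raw = DictDataset([{"image": [0.1, 0.2], "label": 3}, {"image": [0.5, 0.4], "label": 7}])
wrapped = TupleDataset(raw)
print(wrapped[1])  # ([0.5, 0.4], 7)
```

Only indexing is rewritten; the length and the underlying storage are delegated to the original dataset, so the wrapper adds no copy of the data.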



mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)
mnist =


TaskGenerator(dataset, classes=None, ways=2, tasks=1, shots=1)



A wrapper to generate few-shot classification tasks.

tasks can indicate either the number of tasks to sample or a list of predefined tasks. If specified as an int, a list of that many tasks is generated once, and later samples are drawn from it. If specified as a list, that list of tasks is always used for sampling.
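A sketch of how an integer tasks argument could be expanded into a fixed list, assuming each task is ways distinct labels drawn uniformly from classes with `random.sample`; the library's actual sampling may differ. `expand_tasks` is a hypothetical name.

```python
import random


def expand_tasks(tasks, classes, ways, rng=random):
    """Return a list of tasks; each task is a list of `ways` distinct labels."""
    if isinstance(tasks, int):
        # Generate `tasks` tasks once; later calls sample from this fixed list.
        return [rng.sample(list(classes), ways) for _ in range(tasks)]
    # A predefined list of tasks is used as-is.
    return [list(task) for task in tasks]


rng = random.Random(0)
generated = expand_tasks(3, classes=range(10), ways=2, rng=rng)
print(len(generated), [len(t) for t in generated])  # 3 [2, 2, 2]
print(expand_tasks([[0, 1], [2, 3]], classes=range(10), ways=2))  # [[0, 1], [2, 3]]
```

Fixing the task list up front means repeated calls revisit the same label combinations, which is the behaviour described above for both the int and the list case.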

The list must have shape n x w, where n is the number of tasks to sample and w is the number of ways.

Each task must contain w distinct labels, all of which must be among the allowed classes (the classes argument).
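The constraints on a predefined task list can be checked as in the sketch below. `validate_tasks` is a hypothetical helper, and raising `ValueError` is an assumption about how violations might be reported, not the library's documented behaviour.

```python
def validate_tasks(tasks, classes, ways):
    """Check that `tasks` is an n x w list of distinct labels drawn from `classes`."""
    allowed = set(classes)
    for i, task in enumerate(tasks):
        if len(task) != ways:
            raise ValueError(f"task {i} has {len(task)} labels, expected {ways}")
        if len(set(task)) != ways:
            raise ValueError(f"task {i} contains duplicate labels: {task}")
        if not set(task) <= allowed:
            raise ValueError(f"task {i} uses labels outside classes: {set(task) - allowed}")
    return len(tasks)  # n, the number of predefined tasks


print(validate_tasks([[0, 1], [2, 3]], classes=range(5), ways=2))  # 2
try:
    validate_tasks([[0, 0]], classes=range(5), ways=2)
except ValueError as err:
    print(err)  # task 0 contains duplicate labels: [0, 0]
```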



TaskGenerator.sample(shots=None, task=None)


Returns a dataset and the labels that were sampled.

The dataset has length shots * ways, and the list of sampled labels has length ways.


shots (int, optional, default=None) - Number of data points to return per class; if None, self.shots is used.
task (list, optional, default=None) - List of labels you want to sample from.
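A minimal sketch of these sample() semantics, written as a free function over a hypothetical labels_to_indices map and an already-expanded task list rather than as the real method. It returns the selected dataset indices instead of a dataset object, and draws each label's shots without replacement via `random.sample`; both choices are assumptions. The names `sample_task` and `default_shots` are illustrative.

```python
import random


def sample_task(labels_to_indices, task_list, shots=None, task=None, default_shots=1, rng=random):
    """Return (indices, labels): `shots` dataset indices for each label of the task."""
    if shots is None:
        shots = default_shots  # plays the role of self.shots
    if task is None:
        task = rng.choice(task_list)  # pick one of the predefined tasks
    indices = []
    for label in task:
        # Without replacement: each label needs at least `shots` examples.
        indices.extend(rng.sample(labels_to_indices[label], shots))
    return indices, list(task)


labels_to_indices = {0: [0, 2, 4], 1: [1, 3, 5], 2: [6, 7, 8]}
indices, labels = sample_task(labels_to_indices, [[0, 1], [1, 2]], shots=2, rng=random.Random(1))
print(len(indices), len(labels))  # 4 2
```

With ways labels and shots examples per label, the returned indices have length shots * ways while the label list has length ways, matching the lengths stated above.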