learn2learn.nn
Classifiers
PrototypicalClassifier (Module)
Description
A module for the differentiable nearest neighbour classifier of Prototypical Networks.
Arguments
- support (Tensor, optional, default=None) - Tensor of support features.
- labels (Tensor, optional, default=None) - Labels corresponding to the support features.
- distance (str, optional, default='euclidean') - Distance metric between samples, either 'euclidean' or 'cosine'.
- normalize (bool, optional, default=False) - Whether to normalize the inputs. Defaults to True when distance='cosine'.
References
- Snell et al. 2017. "Prototypical Networks for Few-shot Learning"
Example
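A minimal plain-Python sketch of the classification rule from Snell et al.: each query is scored against every class prototype, negative squared Euclidean distance (or, for the cosine option, cosine similarity) serves as the logit, and a softmax turns logits into class probabilities. The helpers `prototype_logits`, `softmax`, and `predict` are hypothetical and work on lists rather than tensors; the learn2learn module may scale or compute the distances differently.

```python
import math


def squared_euclidean(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine_similarity(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def prototype_logits(query, prototypes, distance="euclidean"):
    """One logit per prototype; a larger logit means a closer prototype."""
    if distance == "euclidean":
        return [-squared_euclidean(query, p) for p in prototypes]
    if distance == "cosine":
        return [cosine_similarity(query, p) for p in prototypes]
    raise ValueError(f"unknown distance: {distance}")


def softmax(logits):
    """Turn logits into probabilities (shifted by the max for stability)."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(query, prototypes, distance="euclidean"):
    """Index of the prototype with the largest logit."""
    logits = prototype_logits(query, prototypes, distance)
    return max(range(len(logits)), key=logits.__getitem__)


prototypes = [[0.0, 0.0], [4.0, 4.0]]
print(prototype_logits([1.0, 0.0], prototypes))  # [-1.0, -25.0]
print(predict([3.0, 5.0], prototypes))  # 1
```

Because the logits are a smooth function of the embeddings, the softmax output can be trained end to end, which is what makes this nearest-neighbour rule differentiable.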
fit_(self, support, labels)
Description
Computes and updates the prototypes given support embeddings and corresponding labels.
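In Prototypical Networks, the prototype of a class is the mean of its support embeddings. A plain-Python sketch of that averaging step follows; `compute_prototypes` is a hypothetical helper working on lists instead of tensors, and returning the classes in sorted order is an assumption of the sketch, not documented behavior of `fit_`.

```python
def compute_prototypes(support, labels):
    """Mean support embedding per class.

    Returns (classes, prototypes) with classes sorted, so prototypes[i]
    is the average of all support embeddings labeled classes[i].
    """
    sums, counts = {}, {}
    for embedding, label in zip(support, labels):
        if label not in sums:
            sums[label] = [0.0] * len(embedding)
            counts[label] = 0
        sums[label] = [s + e for s, e in zip(sums[label], embedding)]
        counts[label] += 1
    classes = sorted(sums)
    prototypes = [[s / counts[c] for s in sums[c]] for c in classes]
    return classes, prototypes


support = [[0.0, 0.0], [2.0, 0.0], [4.0, 4.0], [6.0, 4.0]]
labels = [0, 0, 1, 1]
print(compute_prototypes(support, labels))  # ([0, 1], [[1.0, 0.0], [5.0, 4.0]])
```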
SVClassifier (Module)
Description
A module for the differentiable SVM classifier of MetaOptNet.
Arguments
- support (Tensor, optional, default=None) - Tensor of support features.
- labels (Tensor, optional, default=None) - Labels corresponding to the support features.
- ways (int, optional, default=None) - Number of classes in the task.
- normalize (bool, optional, default=False) - Whether to normalize the inputs.
- C_reg (float, optional, default=0.1) - Regularization weight for the SVM.
- max_iters (int, optional, default=15) - Maximum number of iterations for SVM convergence.
References
- Lee et al. 2019. "Meta-Learning with Differentiable Convex Optimization"
Example
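MetaOptNet (Lee et al. 2019) fits a multi-class linear SVM in the Crammer-Singer form on the support features and differentiates through the solution of the resulting quadratic program. The sketch below does not solve that program; it only evaluates the primal objective $\frac{1}{2}\sum_k \lVert w_k \rVert^2 + C \sum_n \xi_n$ and the argmax prediction rule, to show what the solver optimizes. Treating `C_reg` as $C$ and `ways` as the number of weight vectors is an assumption, and the function names are hypothetical.

```python
def scores(weights, x):
    """Linear class scores: one dot product per class weight vector."""
    return [sum(w_i * x_i for w_i, x_i in zip(w, x)) for w in weights]


def crammer_singer_objective(weights, support, labels, c):
    """Primal multi-class SVM objective.

    0.5 * sum_k ||w_k||^2 + c * sum_n xi_n, where the slack xi_n is the
    largest margin violation of sample n over all classes k:
        xi_n = max_k (1[k != y_n] + w_k . x_n - w_{y_n} . x_n)
    The k == y_n term is 0, so xi_n >= 0.
    """
    reg = 0.5 * sum(w_i * w_i for w in weights for w_i in w)
    slack = 0.0
    for x, y in zip(support, labels):
        s = scores(weights, x)
        slack += max((k != y) + s[k] - s[y] for k in range(len(weights)))
    return reg + c * slack


def predict(weights, x):
    """Assign x to the class with the highest linear score."""
    s = scores(weights, x)
    return max(range(len(s)), key=s.__getitem__)


# Two classes (ways = 2) in two dimensions.
support = [[1.0, 0.0], [0.0, 1.0]]
labels = [0, 1]
print(crammer_singer_objective([[0.0, 0.0], [0.0, 0.0]], support, labels, c=0.1))
print(crammer_singer_objective([[2.0, 0.0], [0.0, 2.0]], support, labels, c=0.1))  # 4.0
```

With zero weights every sample violates the margin by exactly 1, so the objective is $C$ times the number of support samples; the second set of weights separates both samples with margin, leaving only the norm term.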
fit_(self, support, labels, ways=None, C_reg=None, max_iters=None)
Kroneckers
KroneckerLinear (Module)
Description
A linear transformation whose parameters are expressed as a Kronecker product.
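This Module maps an input vector $x \in \mathbb{R}^{nm}$ to $x' \in \mathbb{R}^{nm}$ such that:
$$x' = (R^\top \otimes L)\,x + b,$$
where $L \in \mathbb{R}^{n \times n}$ and $R \in \mathbb{R}^{m \times m}$ are the learnable Kronecker factors. This implementation can reduce the memory requirement for a large linear mapping from $\mathcal{O}(n^2 m^2)$ to $\mathcal{O}(n^2 + m^2)$, but forces $x' \in \mathbb{R}^{nm}$.
The matrix $R^\top \otimes L$ is initialized as the identity, and the bias $b$ as a zero vector.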
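The memory saving comes from never materializing the $nm \times nm$ matrix. With $L \in \mathbb{R}^{n \times n}$, $R \in \mathbb{R}^{m \times m}$, and $x = \operatorname{vec}(X)$ for the column-major reshape $X \in \mathbb{R}^{n \times m}$ of the input, the standard vec-Kronecker identity gives:

```latex
(R^\top \otimes L)\,\operatorname{vec}(X) = \operatorname{vec}(L X R)
```

so the product takes two small matrix multiplications, about $\mathcal{O}(n^2 m + n m^2)$ operations, instead of the $\mathcal{O}(n^2 m^2)$ of a dense matrix-vector product. This is the usual way to apply such a map; whether learn2learn evaluates it exactly this way is not stated here.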
Arguments
- n (int) - Dimensionality of the left Kronecker factor.
- m (int) - Dimensionality of the right Kronecker factor.
- bias (bool, optional, default=True) - Whether to include the bias term.
- psd (bool, optional, default=False) - Forces the matrix to be positive semi-definite if True.
- device (device, optional, default=None) - The device on which to instantiate the Module.
References
- Jose et al. 2018. "Kronecker recurrent units".
- Arnold et al. 2019. "When MAML can adapt fast and how to assist when it cannot".
Example
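A plain-Python sketch of the forward map $x' = (R^\top \otimes L)x + b$ for a single vector, computed through the $n \times m$ reshape of $x$ so that the $nm \times nm$ matrix is never built, and checked against the dense Kronecker product. The helper names and the column-major reshape convention are assumptions of this sketch; the `psd` and `device` options are not modeled.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def kron(a, b):
    """Dense Kronecker product a ⊗ b (used here only for checking)."""
    return [[a[i][j] * b[k][l] for j in range(len(a[0])) for l in range(len(b[0]))]
            for i in range(len(a)) for k in range(len(b))]


def vec(mat):
    """Stack the columns of mat into one vector (column-major)."""
    return [mat[i][j] for j in range(len(mat[0])) for i in range(len(mat))]


def unvec(x, n, m):
    """Inverse of vec: rebuild an n x m matrix from a column-major vector."""
    return [[x[j * n + i] for j in range(m)] for i in range(n)]


def kronecker_linear(x, left, right, bias):
    """x' = (R^T ⊗ L) x + b, computed as vec(L X R) + b."""
    n, m = len(left), len(right)
    y = vec(matmul(matmul(left, unvec(x, n, m)), right))
    return [y_i + b_i for y_i, b_i in zip(y, bias)]


left = [[1, 2], [3, 4]]                     # L, with n = 2
right = [[0, 1, 0], [2, 0, 1], [1, 1, 1]]   # R, with m = 3
x = [1, 2, 3, 4, 5, 6]                      # input of length n * m
factored = kronecker_linear(x, left, right, [0] * 6)
dense = [sum(a * v for a, v in zip(row, x)) for row in kron(transpose(right), left)]
print(factored)           # [39, 89, 22, 50, 28, 64]
print(factored == dense)  # True
```

With identity factors and a zero bias the map returns `x` unchanged, which matches the identity initialization described above.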
KroneckerRNN (Module)
Description
Implements a recurrent neural network whose weight matrices are parameterized via their Kronecker factors. (See KroneckerLinear for details.)
Arguments
- n (int) - Dimensionality of the left Kronecker factor.
- m (int) - Dimensionality of the right Kronecker factor.
- bias (bool, optional, default=True) - Whether to include the bias term.
- sigma (callable, optional, default=None) - The activation function.
References
- Jose et al. 2018. "Kronecker recurrent units".
Example
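A plain-Python sketch of one recurrent step, assuming an Elman-style update $h' = \sigma(W_x x + W_h h + b)$ in which each weight matrix is a Kronecker product $R^\top \otimes L$ applied through the reshape trick. The function names, the tanh fallback when `sigma` is None, and the absence of an output projection are assumptions; the exact set of matrices used by the library is not specified here.

```python
import math


def kron_matvec(left, right, x):
    """(R^T ⊗ L) x computed as vec(L X R), X being the column-major n x m reshape of x."""
    n, m = len(left), len(right)
    X = [[x[j * n + i] for j in range(m)] for i in range(n)]
    LX = [[sum(left[i][k] * X[k][j] for k in range(n)) for j in range(m)] for i in range(n)]
    LXR = [[sum(LX[i][k] * right[k][j] for k in range(m)) for j in range(m)] for i in range(n)]
    return [LXR[i][j] for j in range(m) for i in range(n)]


def kronecker_rnn_step(x, hidden, w_x, w_h, bias, sigma=None):
    """h' = sigma(W_x x + W_h h + b), where w_x and w_h are (L, R) factor pairs."""
    sigma = math.tanh if sigma is None else sigma  # assumed default activation
    pre = [a + b + c for a, b, c in zip(kron_matvec(*w_x, x), kron_matvec(*w_h, hidden), bias)]
    return [sigma(v) for v in pre]


# n = 1, m = 2, so inputs and hidden states live in R^2.
identity = ([[1.0]], [[1.0, 0.0], [0.0, 1.0]])
zero = ([[0.0]], [[0.0, 0.0], [0.0, 0.0]])
hidden = [0.0, 0.0]
for x in ([0.5, -1.0], [1.0, 1.0]):
    hidden = kronecker_rnn_step(x, hidden, identity, identity, [0.0, 0.0])
print(hidden)
```

Each factored weight matrix stores $n^2 + m^2$ numbers instead of $n^2 m^2$, which is where the memory saving carries over to the recurrent setting.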
KroneckerLSTM (Module)
Description
Implements an LSTM using a factorization similar to that of KroneckerLinear.
Arguments
- n (int) - Dimensionality of the left Kronecker factor.
- m (int) - Dimensionality of the right Kronecker factor.
- bias (bool, optional, default=True) - Whether to include the bias term.
- sigma (callable, optional, default=None) - The activation function.
References
- Jose et al. 2018. "Kronecker recurrent units".
Example
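A plain-Python sketch of a single LSTM step in which the input and hidden weight matrices of every gate are Kronecker products applied through the reshape trick. The standard LSTM gate equations are used; applying `sigma` to both the candidate values and the cell output, falling back to tanh when `sigma` is None, and the dictionary layout of `factors` and `biases` are assumptions of this sketch rather than the library's interface.

```python
import math


def kron_matvec(left, right, x):
    """(R^T ⊗ L) x computed as vec(L X R), X being the column-major n x m reshape of x."""
    n, m = len(left), len(right)
    X = [[x[j * n + i] for j in range(m)] for i in range(n)]
    LX = [[sum(left[i][k] * X[k][j] for k in range(n)) for j in range(m)] for i in range(n)]
    LXR = [[sum(LX[i][k] * right[k][j] for k in range(m)) for j in range(m)] for i in range(n)]
    return [LXR[i][j] for j in range(m) for i in range(n)]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def kronecker_lstm_step(x, hidden, cell, factors, biases, sigma=None):
    """One LSTM step whose weight matrices are Kronecker products.

    factors[name] = ((L_x, R_x), (L_h, R_h)) and biases[name] is a vector,
    for each gate name in "ifog" (input, forget, output, candidate).
    """
    sigma = math.tanh if sigma is None else sigma  # assumed default activation

    def pre_activation(name):
        (w_x, w_h), bias = factors[name], biases[name]
        return [a + b + c for a, b, c in
                zip(kron_matvec(*w_x, x), kron_matvec(*w_h, hidden), bias)]

    i = [sigmoid(v) for v in pre_activation("i")]  # input gate
    f = [sigmoid(v) for v in pre_activation("f")]  # forget gate
    o = [sigmoid(v) for v in pre_activation("o")]  # output gate
    g = [sigma(v) for v in pre_activation("g")]    # candidate cell values
    new_cell = [f_k * c_k + i_k * g_k for f_k, c_k, i_k, g_k in zip(f, cell, i, g)]
    new_hidden = [o_k * sigma(c_k) for o_k, c_k in zip(o, new_cell)]
    return new_hidden, new_cell


# With all-zero weights every gate is sigmoid(0) = 0.5 and the candidate is
# tanh(0) = 0, so the cell state is simply halved.
zero = ([[0.0]], [[0.0, 0.0], [0.0, 0.0]])  # n = 1, m = 2
factors = {name: (zero, zero) for name in "ifog"}
biases = {name: [0.0, 0.0] for name in "ifog"}
hidden, cell = kronecker_lstm_step([5.0, 5.0], [0.3, 0.3], [1.0, -2.0], factors, biases)
print(cell)  # [0.5, -1.0]
```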
Misc
Lambda (Module)
Description
Utility class that wraps a lambda function (or any other callable) as a Module.
Arguments
- lmb (callable) - The function to call in the forward pass.
Example
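A framework-free sketch of what such a wrapper does: `forward` passes all of its arguments to the wrapped callable, which lets a parameter-free function sit in a sequence of layers. The real class is a PyTorch Module; this stand-in only mirrors its call behavior, and `flatten` is a hypothetical helper.

```python
class Lambda:
    """Stand-in for a Module whose forward pass calls `lmb`."""

    def __init__(self, lmb):
        self.lmb = lmb

    def forward(self, *args, **kwargs):
        return self.lmb(*args, **kwargs)

    # Modules are invoked by calling them, which dispatches to forward().
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def flatten(rows):
    return [v for row in rows for v in row]


# Neither step has parameters, so wrapping plain functions is enough
# to slot them into a pipeline of layers.
pipeline = [Lambda(flatten), Lambda(lambda xs: [2 * v for v in xs])]
out = [[1, 2], [3, 4]]
for layer in pipeline:
    out = layer(out)
print(out)  # [2, 4, 6, 8]
```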
Scale (Module)
Description
A per-parameter scaling factor with a learnable parameter.
Arguments
- shape (int or tuple, optional, default=1) - The shape of the scaling matrix.
- alpha (float, optional, default=1.0) - Initial value for the scaling factor.
Example
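A plain-Python sketch of a scale stored per entry: `alpha` starts as `shape` copies of the initial value and multiplies the input elementwise, with a single entry broadcast over the whole input. Only integer shapes and 1-D inputs are handled and no gradients are computed; in the real Module the factor is a learnable parameter updated during training, and the class here is a hypothetical stand-in.

```python
class Scale:
    """Elementwise scale whose factors start at the same initial value."""

    def __init__(self, shape=1, alpha=1.0):
        self.alpha = [float(alpha)] * shape

    def __call__(self, x):
        if len(self.alpha) == 1:  # broadcast a single factor
            return [v * self.alpha[0] for v in x]
        if len(self.alpha) != len(x):
            raise ValueError("input length must match the scale shape")
        return [v * a for v, a in zip(x, self.alpha)]


scale = Scale(shape=3, alpha=0.5)
print(scale([2.0, 4.0, 6.0]))  # [1.0, 2.0, 3.0]
scale.alpha[2] = 2.0  # stand-in for a learned update to one entry
print(scale([2.0, 4.0, 6.0]))  # [1.0, 2.0, 12.0]
```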