learn2learn.nn¶
Classifiers¶
PrototypicalClassifier (Module)¶
Description
A module for the differentiable nearest neighbour classifier of Prototypical Networks.
Arguments
- support (Tensor, optional, default=None) - Tensor of support features.
- labels (Tensor, optional, default=None) - Labels corresponding to the support features.
- distance (str, optional, default='euclidean') - Distance metric between samples. ['euclidean', 'cosine']
- normalize (bool, optional, default=False) - Whether to normalize the inputs. Defaults to True when distance='cosine'.
References
- Snell et al. 2017. "Prototypical Networks for Few-shot Learning"
Example
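A minimal sketch; the feature extractor and the random support/query tensors below are placeholders:

```python
import torch
import learn2learn as l2l

features = torch.nn.Linear(784, 64)          # placeholder embedding network
support, labels = torch.randn(25, 784), torch.arange(5).repeat(5)
query = torch.randn(10, 784)

classifier = l2l.nn.PrototypicalClassifier(distance='euclidean')
classifier.fit_(features(support), labels)   # compute one prototype per class
predictions = classifier(features(query))    # scores of each query against the prototypes
```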
fit_(self, support, labels)¶
Description
Computes and updates the prototypes given support embeddings and corresponding labels.
SVClassifier (Module)¶
Description
A module for the differentiable SVM classifier of MetaOptNet.
Arguments
- support (Tensor, optional, default=None) - Tensor of support features.
- labels (Tensor, optional, default=None) - Labels corresponding to the support features.
- ways (int, optional, default=None) - Number of classes in the task.
- normalize (bool, optional, default=False) - Whether to normalize the inputs.
- C_reg (float, optional, default=0.1) - Regularization weight for SVM.
- max_iters (int, optional, default=15) - Maximum number of iterations for SVM convergence.
References
- Lee et al. 2019. "Meta-Learning with Differentiable Convex Optimization"
Example
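A minimal sketch along the same lines as the PrototypicalClassifier example; the embedding network and task tensors are placeholders:

```python
import torch
import learn2learn as l2l

features = torch.nn.Linear(784, 64)                 # placeholder embedding network
support, labels = torch.randn(25, 784), torch.arange(5).repeat(5)
query = torch.randn(10, 784)

classifier = l2l.nn.SVClassifier(C_reg=0.1)
classifier.fit_(features(support), labels, ways=5)  # solve the dual SVM on the support set
predictions = classifier(features(query))
```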
fit_(self, support, labels, ways=None, C_reg=None, max_iters=None)¶
MetaLayers¶
MetaModule (Module)¶
Description
Takes a module and recursively replaces its submodules with others.
The substitutions are specified by a dictionary (substitutions) which maps module classes to substitution functions.
For example, to append a second Linear module after all Linear submodules:
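One possible mapping (a sketch: the function receives the existing Linear and returns its replacement):

```python
import torch

substitutions = {
    torch.nn.Linear: lambda linear: torch.nn.Sequential(
        linear,
        torch.nn.Linear(linear.out_features, linear.out_features),
    ),
}
```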
Optionally, the original module parameters can be frozen (requires_grad = False) by setting freeze_module = True.
This is helpful when only the substitution modules need to be updated.
Arguments
- module (Module) - The model to wrap.
- substitutions (dict) - Map of class -> construction substitutions.
- freeze_module (bool, optional, default=True) - Whether to freeze the original module parameters.
Example
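A sketch; the layer sizes are arbitrary and the printed structure is indicative only:

```python
import torch
import learn2learn as l2l

single_layer = torch.nn.Sequential(
    torch.nn.Linear(768, 10),
    torch.nn.ReLU(),
)

double_layers = l2l.nn.MetaModule(
    module=single_layer,
    substitutions={
        torch.nn.Linear: lambda linear: torch.nn.Sequential(
            linear,
            torch.nn.Linear(linear.out_features, linear.out_features),
        ),
    },
    freeze_module=True,
)
print(double_layers)  # each Linear is now a Sequential: the frozen original followed by a fresh Linear
```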
module(self)¶
Description
Returns the original module.
Example
(continued from above)
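Continuing the MetaModule sketch above:

```python
single_layer = double_layers.module()   # recover the original, unwrapped Sequential
```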
ParameterTransform (Module)¶
Description
Calls module after transforming its parameters via transform.
After the forward pass, the parameters of module are reverted to their original values.
Useful to implement learnable (and constrained) updates of module weights (e.g., LoRA).
Best used in conjunction with MetaModule.
Arguments
- module (Module) - The model to wrap.
- transform (callable) - Function to be called on all parameters of module before its forward pass. Possibly a module itself, which is learnable.
Example
Here, we only learn a scalar factor of the original weights.
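A sketch combining ParameterTransform with MetaModule, assuming the transform callable is invoked once per parameter and returns a learnable module applied to that parameter; here each frozen weight is scaled by a learnable scalar (Scale):

```python
import torch
import learn2learn as l2l

model = torch.nn.Sequential(
    torch.nn.Linear(768, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 10),
)

meta_model = l2l.nn.MetaModule(
    module=model,
    substitutions={
        torch.nn.Linear: lambda linear: l2l.nn.ParameterTransform(
            module=linear,
            transform=lambda param: l2l.nn.Scale(),  # assumption: one learnable scalar per parameter tensor
        ),
    },
    freeze_module=True,   # only the Scale factors receive gradients
)
```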
Kroneckers¶
KroneckerLinear (Module)¶
Description
A linear transformation whose parameters are expressed as a Kronecker product.
This Module maps an input vector \(x \in \mathbb{R}^{nm}\) to \(y = Ax + b\) such that:
\[A = R^\top \otimes L,\]
where \(L \in \mathbb{R}^{n \times n}\) and \(R \in \mathbb{R}^{m \times m}\) are the learnable Kronecker factors. This implementation can reduce the memory requirement for large linear mappings from \(\mathcal{O}(n^2 m^2)\) to \(\mathcal{O}(n^2 + m^2)\), but forces \(y \in \mathbb{R}^{nm}\).
The matrix \(A\) is initialized as the identity, and the bias as a zero vector.
Arguments
- n (int) - Dimensionality of the left Kronecker factor.
- m (int) - Dimensionality of the right Kronecker factor.
- bias (bool, optional, default=True) - Whether to include the bias term.
- psd (bool, optional, default=False) - Forces the matrix \(A\) to be positive semi-definite if True.
- device (device, optional, default=None) - The device on which to instantiate the Module.
References
- Jose et al. 2018. "Kronecker recurrent units".
- Arnold et al. 2019. "When MAML can adapt fast and how to assist when it cannot".
Example
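A minimal sketch, applying the layer to a vector of dimension n*m:

```python
import torch
import learn2learn as l2l

n, m = 2, 3
x = torch.randn(n * m)                    # input vector of dimension n*m
kronecker = l2l.nn.KroneckerLinear(n, m)
y = kronecker(x)                          # output also has dimension n*m
```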
KroneckerRNN (Module)¶
Description
Implements a recurrent neural network whose matrices are parameterized via their Kronecker factors.
(See KroneckerLinear for details.)
Arguments
- n (int) - Dimensionality of the left Kronecker factor.
- m (int) - Dimensionality of the right Kronecker factor.
- bias (bool, optional, default=True) - Whether to include the bias term.
- sigma (callable, optional, default=None) - The activation function.
References
- Jose et al. 2018. "Kronecker recurrent units".
Example
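A sketch, assuming the cell is called with the current input and previous hidden state, both of Kronecker shape (n, m):

```python
import torch
import learn2learn as l2l

n, m = 2, 3
rnn = l2l.nn.KroneckerRNN(n, m)
x = torch.randn(n, m)        # input at one time step
h = torch.zeros(n, m)        # initial hidden state
y, h = rnn(x, h)             # assumed (input, hidden) -> (output, new hidden) convention
```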
KroneckerLSTM (Module)¶
Description
Implements an LSTM using a factorization similar to that of KroneckerLinear.
Arguments
- n (int) - Dimensionality of the left Kronecker factor.
- m (int) - Dimensionality of the right Kronecker factor.
- bias (bool, optional, default=True) - Whether to include the bias term.
- sigma (callable, optional, default=None) - The activation function.
References
- Jose et al. 2018. "Kronecker recurrent units".
Example
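A sketch under the same assumptions as the KroneckerRNN example, with an (h, c) hidden-state pair:

```python
import torch
import learn2learn as l2l

n, m = 2, 3
lstm = l2l.nn.KroneckerLSTM(n, m)
x = torch.randn(n, m)
hidden = (torch.zeros(n, m), torch.zeros(n, m))   # assumed (h, c) pair
y, hidden = lstm(x, hidden)
```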
Misc¶
Lambda (Module)¶
Description
Utility class to create a wrapper based on a lambda function.
Arguments
- lmb (callable) - The function to call in the forward pass.
Example
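A minimal sketch that wraps a mean-pooling function:

```python
import torch
import learn2learn as l2l

mean23 = l2l.nn.Lambda(lambda x: x.mean(dim=[2, 3]))   # mean-pool the spatial dimensions
x = torch.randn(5, 32, 10, 10)
x = mean23(x)                                          # shape (5, 32)
```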
Scale (Module)¶
Description
A per-parameter scaling factor with a learnable parameter.
Arguments
- shape (int or tuple, optional, default=1) - The shape of the scaling matrix.
- alpha (float, optional, default=1.0) - Initial value for the scaling factor.
Example
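A minimal sketch:

```python
import torch
import learn2learn as l2l

x = torch.ones(3)
scale = l2l.nn.Scale(x.shape, alpha=0.5)
print(scale(x))   # a tensor of 0.5s (with grad_fn, since the factor is learnable)
```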
Flatten (Module)¶
Description
Utility Module to flatten inputs to (batch_size, -1) shape.
Example
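A minimal sketch:

```python
import torch
import learn2learn as l2l

flatten = l2l.nn.Flatten()
x = torch.randn(5, 3, 32, 32)
x = flatten(x)    # shape (5, 3072)
```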
freeze(module)¶
Description
Prevents all parameters in module from receiving gradients.
Note: the module is modified in-place.
Arguments
- module (Module) - The module to freeze.
Example
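A minimal sketch:

```python
import torch
import learn2learn as l2l

linear = torch.nn.Linear(128, 4)
l2l.nn.freeze(linear)   # linear.weight and linear.bias no longer require grad
```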