learn2learn.nn

Classifiers

PrototypicalClassifier (Module)

[Source]

Description

A module for the differentiable nearest neighbour classifier of Prototypical Networks.

Arguments

  • support (Tensor, optional, default=None) - Tensor of support features.
  • labels (Tensor, optional, default=None) - Labels corresponding to the support features.
  • distance (str, optional, default='euclidean') - Distance metric between samples. ['euclidean', 'cosine']
  • normalize (bool, optional, default=False) - Whether to normalize the inputs. Defaults to True when distance='cosine'.

References

  1. Snell et al. 2017. "Prototypical Networks for Few-shot Learning"

Example

classifier = PrototypicalClassifier()
support = features(support_data)
classifier.fit_(support, labels)
query = features(query_data)
preds = classifier(query)

fit_(self, support, labels)

Description

Computes and updates the prototypes given support embeddings and corresponding labels.
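
A minimal, runnable sketch of the full workflow, using random embeddings in place of a real feature extractor (the 5-way, 2-shot shapes below are illustrative):

import torch
import learn2learn as l2l

ways, shots, dim = 5, 2, 64
support = torch.randn(ways * shots, dim)             # stand-in support embeddings
labels = torch.arange(ways).repeat_interleave(shots)

classifier = l2l.nn.PrototypicalClassifier()
classifier.fit_(support, labels)                     # prototypes = per-class means

query = torch.randn(10, dim)                         # stand-in query embeddings
logits = classifier(query)                           # one score per class
print(logits.shape)                                  # torch.Size([10, 5])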

SVClassifier (Module)

[Source]

Description

A module for the differentiable SVM classifier of MetaOptNet.

Arguments

  • support (Tensor, optional, default=None) - Tensor of support features.
  • labels (Tensor, optional, default=None) - Labels corresponding to the support features.
  • ways (int, optional, default=None) - Number of classes in the task.
  • normalize (bool, optional, default=False) - Whether to normalize the inputs.
  • C_reg (float, optional, default=0.1) - Regularization weight for SVM.
  • max_iters (int, optional, default=15) - Maximum number of iterations for SVM convergence.

References

  1. Lee et al. 2019. "Meta-Learning with Differentiable Convex Optimization"

Example

classifier = SVClassifier()
support = features(support_data)
classifier.fit_(support, labels)
query = features(query_data)
preds = classifier(query)

fit_(self, support, labels, ways=None, C_reg=None, max_iters=None)

Description

Computes and updates the SVM solution given support embeddings and corresponding labels. When given, ways, C_reg, and max_iters override the values set at construction.
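
A minimal sketch of the same workflow with concrete, stand-in embeddings (5-way, 1-shot shapes; assumes learn2learn and its QP-solver dependency are installed):

import torch
import learn2learn as l2l

support = torch.randn(5, 64)               # one stand-in embedding per class
labels = torch.arange(5)
classifier = l2l.nn.SVClassifier(ways=5, C_reg=0.1)
classifier.fit_(support, labels)
scores = classifier(torch.randn(10, 64))   # class scores for 10 query embeddings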

MetaLayers

MetaModule (Module)

[Source]

Description

Takes a module and recursively replaces its submodules with others.

Substitutions are specified via a dictionary (substitutions) that maps module classes to substitution functions. For example, to append a second Linear module after all Linear submodules:

substitutions[torch.nn.Linear] = lambda linear: torch.nn.Sequential(
    linear,
    torch.nn.Linear(linear.out_features, linear.out_features),
)

Optionally, the original module parameters can be frozen (requires_grad = False) by setting freeze_module = True. This is helpful when only the substitution modules need to be updated.

Arguments

  • module (Module) - The model to wrap.
  • substitutions (dict) - Map from module class to substitution function.
  • freeze_module (bool, optional, default=True) - Whether to freeze the original module parameters.

Example

import torch
import learn2learn.nn.metalayers as ml

single_layer = torch.nn.Sequential(
    torch.nn.Linear(768, 10),
    torch.nn.ReLU(),
)

double_layers = ml.MetaModule(
    module=single_layer,
    substitutions={
        torch.nn.Linear: lambda linear: torch.nn.Sequential(
            linear,
            torch.nn.Linear(linear.out_features, linear.out_features),
        )
    },
    freeze_module=True,
)
print(double_layers)

Output:

MetaModule(
  (wrapped_module): Sequential(
    (0): Sequential(
      (0): Linear(in_features=768, out_features=10, bias=True)
      (1): Linear(in_features=10, out_features=10, bias=True)
    )
    (1): ReLU()
  )
)
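
(continued from above) Because freeze_module=True, only the appended layers remain trainable. A quick way to verify which parameters will receive gradients:

for name, param in double_layers.named_parameters():
    print(name, param.requires_grad)
# The original Linear(768, 10) reports False; the inserted Linear(10, 10) reports True.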

module(self)

Description

Returns the original module.

Example

(continued from above)

single_layer = double_layers.module()

ParameterTransform (Module)

[Source]

Description

Calls module after transforming its parameters via transform.

After the forward pass, the parameters of module are reverted to their original values.

Useful to implement learnable (and constrained) updates of module weights (e.g., LoRA). Best used in conjunction with MetaModule.

Arguments

  • module (Module) - The model to wrap.
  • transform (callable) - Function to be called on all parameters of module before its forward pass. Possibly a module itself, which is learnable.

Example

In this example, we only learn a scalar factor of the original weights.

import torch
import learn2learn as l2l
import learn2learn.nn.metalayers as ml

model = torch.nn.Sequential(
    torch.nn.Linear(768, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 10),
)
meta_model = ml.MetaModule(
    module=model,
    substitutions={
        torch.nn.Linear: lambda linear: ml.ParameterTransform(
            module=linear,
            transform=lambda param: l2l.nn.Scale(),
        ),
    },
    freeze_module=True,
)
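
(continued from above) For a LoRA-style update, transform can instead return a module that adds a learnable low-rank shift to each weight. The LowRankShift helper below is hypothetical (not part of learn2learn) and relies on the same convention as the Scale example: transform(param) constructs a module that is applied to param during the forward pass.

class LowRankShift(torch.nn.Module):
    # Hypothetical helper: adds a rank-r update A @ B to 2D weights
    # and leaves other parameters (e.g. biases) untouched.

    def __init__(self, shape, rank=4):
        super().__init__()
        self.factorized = len(shape) == 2
        if self.factorized:
            out_features, in_features = shape
            # A starts at zero so the initial shift is zero, as in LoRA.
            self.A = torch.nn.Parameter(torch.zeros(out_features, rank))
            self.B = torch.nn.Parameter(0.01 * torch.randn(rank, in_features))

    def forward(self, param):
        if not self.factorized:
            return param
        return param + self.A @ self.B

lora_model = ml.MetaModule(
    module=model,
    substitutions={
        torch.nn.Linear: lambda linear: ml.ParameterTransform(
            module=linear,
            transform=lambda param: LowRankShift(param.shape),
        ),
    },
    freeze_module=True,
)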

Kroneckers

KroneckerLinear (Module)

[Source]

Description

A linear transformation whose parameters are expressed as a Kronecker product.

This Module maps an input vector $x \in \mathbb{R}^{nm}$ to $y = Ax + b$ such that:

$$A = L \otimes R,$$

where $L \in \mathbb{R}^{n \times n}$ and $R \in \mathbb{R}^{m \times m}$ are the learnable Kronecker factors. This implementation can reduce the memory requirement of a large linear mapping from $\mathcal{O}(n^2 m^2)$ to $\mathcal{O}(n^2 + m^2)$, but forces $x, y \in \mathbb{R}^{nm}$.

The matrix $A$ is initialized as the identity, and the bias as a zero vector.

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • psd (bool, optional, default=False) - Forces the matrix $A$ to be positive semi-definite if True.
  • device (device, optional, default=None) - The device on which to instantiate the Module.

References

  1. Jose et al. 2018. "Kronecker recurrent units".
  2. Arnold et al. 2019. "When MAML can adapt fast and how to assist when it cannot".

Example

m, n = 2, 3
x = torch.randn(6)
kronecker = KroneckerLinear(n, m)
y = kronecker(x)
y.shape  # (6, )
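
To see the memory savings, compare parameter counts against an equivalent dense layer (the expected count assumes the module stores the two factors plus an n·m bias, which may differ slightly across versions):

import torch
from learn2learn.nn import KroneckerLinear

n, m = 16, 32
kron = KroneckerLinear(n, m)
dense = torch.nn.Linear(n * m, n * m)
print(sum(p.numel() for p in kron.parameters()))   # ~ n**2 + m**2 + n*m = 1792
print(sum(p.numel() for p in dense.parameters()))  # (n*m)**2 + n*m = 262656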

KroneckerRNN (Module)

[Source]

Description

Implements a recurrent neural network whose matrices are parameterized via their Kronecker factors. (See KroneckerLinear for details.)

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • sigma (callable, optional, default=None) - The activation function.

References

  1. Jose et al. 2018. "Kronecker recurrent units".

Example

m, n = 2, 3
x = torch.randn(6)
h = torch.randn(6)
kronecker = KroneckerRNN(n, m)
y, new_h = kronecker(x, h)
y.shape  # (6, )
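
(continued) Unrolling over a sequence follows the usual recurrent pattern, feeding the returned hidden state back in; shapes here are illustrative:

xs = torch.randn(10, 6)   # a sequence of 10 inputs
h = torch.zeros(6)
for x in xs:
    y, h = kronecker(x, h)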

KroneckerLSTM (Module)

[Source]

Description

Implements an LSTM using a factorization similar to that of KroneckerLinear.

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • sigma (callable, optional, default=None) - The activation function.

References

  1. Jose et al. 2018. "Kronecker recurrent units".

Example

m, n = 2, 3
x = torch.randn(6)
h = torch.randn(6)
kronecker = KroneckerLSTM(n, m)
y, new_h = kronecker(x, h)
y.shape  # (6, )

Misc

Lambda (Module)

[Source]

Description

Utility class to create a wrapper based on a lambda function.

Arguments

  • lmb (callable) - The function to call in the forward pass.

Example

mean23 = Lambda(lambda x: x.mean(dim=[2, 3]))  # mean23 is a Module
x = features(img)
x = mean23(x)
x = x.flatten()
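
Since Lambda is a Module, it also composes inside containers such as Sequential. A sketch, assuming 256-channel feature maps:

import torch
from learn2learn.nn import Lambda

head = torch.nn.Sequential(
    Lambda(lambda x: x.mean(dim=[2, 3])),  # global average pooling over H, W
    torch.nn.Linear(256, 10),
)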

Scale (Module)

[Source]

Description

An element-wise scaling factor whose values are learnable.

Arguments

  • shape (int or tuple, optional, default=1) - The shape of the scaling matrix.
  • alpha (float, optional, default=1.0) - Initial value for the scaling factor.

Example

x = torch.ones(3)
scale = Scale(x.shape, alpha=0.5)
print(scale(x))  # [.5, .5, .5]

Flatten (Module)

[Source]

Description

Utility Module to flatten inputs to (batch_size, -1) shape.

Example

flatten = Flatten()
x = torch.randn(5, 3, 32, 32)
x = flatten(x)
print(x.shape)  # (5, 3072)

freeze(module)

[Source]

Description

Prevents all parameters in module from computing gradients.

Note: the module is modified in-place.

Arguments

  • module (Module) - The module to freeze.

Example

linear = torch.nn.Linear(128, 4)
l2l.nn.freeze(linear)
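
(continued) A quick sanity check that freezing took effect:

assert all(not p.requires_grad for p in linear.parameters())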

unfreeze(module)

[Source]

Description

Enables all parameters in module to compute gradients.

Note: the module is modified in-place.

Arguments

  • module (Module) - The module to unfreeze.

Example

linear = torch.nn.Linear(128, 4)
l2l.nn.freeze(linear)
l2l.nn.unfreeze(linear)