learn2learn.nn

Additional torch.nn.Modules frequently used for meta-learning.

Lambda

Lambda(lmb)

Description

Utility class that wraps a lambda function (or any other callable) as a Module.

Arguments

  • lmb (callable) - The function to call in the forward pass.

Example

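import torch
from learn2learn.nn import Lambda

mean23 = Lambda(lambda x: x.mean(dim=[2, 3]))  # mean23 is a Module
x = torch.randn(5, 64, 8, 8)  # stand-in for features(img): (batch, channels, height, width)
x = mean23(x)  # average over the spatial dimensions -> (5, 64)
x = x.flatten()  # -> (320,)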
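To illustrate what such a wrapper does, here is a minimal sketch of an equivalent Module written from scratch. `LambdaSketch` is a hypothetical name, and this is not necessarily how learn2learn implements Lambda. Wrapping a callable this way is mainly useful inside containers such as `torch.nn.Sequential`, which only accept Modules.

```python
import torch


class LambdaSketch(torch.nn.Module):
    """Hypothetical stand-in for Lambda: wraps a callable as a Module."""

    def __init__(self, lmb):
        super().__init__()
        self.lmb = lmb  # any callable; it adds no learnable parameters

    def forward(self, *args, **kwargs):
        # Delegate the forward pass to the wrapped callable.
        return self.lmb(*args, **kwargs)


# A lambda can now sit inside nn.Sequential next to ordinary layers.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    LambdaSketch(lambda x: x.mean(dim=[2, 3])),  # global average pooling
    torch.nn.Linear(8, 5),
)
out = model(torch.randn(4, 3, 16, 16))
print(out.shape)  # torch.Size([4, 5])
```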

Flatten

Flatten()

Description

Utility Module to flatten inputs to (batch_size, -1) shape.

Example

import torch
from learn2learn.nn import Flatten

flatten = Flatten()
x = torch.randn(5, 3, 32, 32)
x = flatten(x)
print(x.shape)  # torch.Size([5, 3072])
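Flattening to (batch_size, -1) keeps the first dimension and merges all the others. The sketch below, using only standard PyTorch, shows equivalent ways to obtain that shape; `FlattenSketch` is a hypothetical name rather than the library's class.

```python
import torch


class FlattenSketch(torch.nn.Module):
    """Hypothetical equivalent of Flatten: reshape inputs to (batch_size, -1)."""

    def forward(self, x):
        # Keep the batch dimension and merge all remaining dimensions.
        # reshape (unlike view) also accepts non-contiguous inputs.
        return x.reshape(x.size(0), -1)


x = torch.randn(5, 3, 32, 32)
a = FlattenSketch()(x)
b = torch.flatten(x, start_dim=1)  # functional equivalent
c = torch.nn.Flatten()(x)          # built-in Module, start_dim=1 by default
print(a.shape)                     # torch.Size([5, 3072])
print(torch.equal(a, b) and torch.equal(a, c))  # True
```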

Scale

Scale(shape, alpha=1.0)

Description

A per-element scaling factor: inputs are multiplied element-wise by a learnable parameter of the given shape.

Arguments

  • shape (int or tuple) - The shape of the scaling matrix.
  • alpha (float, optional, default=1.0) - Initial value for the scaling factor.

Example

import torch
from learn2learn.nn import Scale

x = torch.ones(3)
scale = Scale(x.shape, alpha=0.5)
print(scale(x))  # [.5, .5, .5]
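Because the scaling factor is a learnable parameter, it is trained by backpropagation like any other weight. The following sketch assumes the behaviour described above (element-wise multiplication by a tensor initialized to `alpha`); `ScaleSketch` is hypothetical and is not the library source.

```python
import torch


class ScaleSketch(torch.nn.Module):
    """Hypothetical stand-in for Scale: multiply inputs by a learnable tensor."""

    def __init__(self, shape, alpha=1.0):
        super().__init__()
        if isinstance(shape, int):
            shape = (shape,)
        # One learnable factor per element, all initialized to alpha.
        self.alpha = torch.nn.Parameter(torch.full(tuple(shape), float(alpha)))

    def forward(self, x):
        return x * self.alpha  # broadcasts over leading (e.g. batch) dimensions


scale = ScaleSketch(3, alpha=0.5)
x = torch.tensor([1.0, 2.0, 3.0])
loss = scale(x).sum()
loss.backward()
print(scale.alpha.grad)  # tensor([1., 2., 3.]), since d(sum(x * alpha)) / d(alpha) = x
```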

PrototypicalClassifier

PrototypicalClassifier(support=None,
                       labels=None,
                       distance='euclidean',
                       normalize=False)

Description

A module for the differentiable nearest neighbour classifier of Prototypical Networks.

Arguments

  • support (Tensor, optional, default=None) - Tensor of support features.
  • labels (Tensor, optional, default=None) - Labels corresponding to the support features.
  • distance (str, optional, default='euclidean') - Distance metric between samples; one of 'euclidean' or 'cosine'.
  • normalize (bool, optional, default=False) - Whether to normalize the inputs. Defaults to True when distance='cosine'.

References

  1. Snell et al. 2017. "Prototypical Networks for Few-shot Learning"

Example

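import torch
from learn2learn.nn import PrototypicalClassifier

classifier = PrototypicalClassifier()
support = torch.randn(10, 64)       # stand-in for features(support_data)
labels = torch.arange(5).repeat(2)  # 5 classes, 2 support samples each
classifier.fit_(support, labels)
query = torch.randn(15, 64)         # stand-in for features(query_data)
preds = classifier(query)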
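Following Snell et al. (2017), each class prototype is the mean of its support embeddings, and each query is scored by its (negative) distance to every prototype. The sketch below reproduces that computation in plain PyTorch. The function names are hypothetical, and it assumes negative squared Euclidean distance (or cosine similarity) as the score, which may differ in detail from learn2learn's implementation.

```python
import torch
import torch.nn.functional as F


def compute_prototypes_sketch(support, labels):
    """Mean embedding per class; rows follow the order of labels.unique()."""
    classes = labels.unique()  # sorted unique labels
    return torch.stack([support[labels == c].mean(dim=0) for c in classes])


def prototypical_logits_sketch(query, prototypes, distance="euclidean"):
    """Score each query against each prototype; higher means closer."""
    if distance == "cosine":
        # Cosine similarity between L2-normalized vectors.
        return F.normalize(query, dim=1) @ F.normalize(prototypes, dim=1).t()
    # Negative squared Euclidean distance, shape (n_query, n_classes).
    return -torch.cdist(query, prototypes).pow(2)


support = torch.tensor([[0.0, 0.0], [0.0, 2.0],   # class 0
                        [4.0, 4.0], [6.0, 4.0]])  # class 1
labels = torch.tensor([0, 0, 1, 1])
prototypes = compute_prototypes_sketch(support, labels)  # [[0., 1.], [5., 4.]]
query = torch.tensor([[1.0, 1.0], [5.0, 5.0]])
logits = prototypical_logits_sketch(query, prototypes)
print(logits.argmax(dim=1))  # tensor([0, 1])
```

Every operation here is differentiable, so gradients from a loss on `logits` flow back into the support and query embeddings.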

SVClassifier

SVClassifier(support=None,
             labels=None,
             ways=None,
             normalize=False,
             C_reg=0.1,
             max_iters=15)

Description

A module for the differentiable SVM classifier of MetaOptNet.

Arguments

  • support (Tensor, optional, default=None) - Tensor of support features.
  • labels (Tensor, optional, default=None) - Labels corresponding to the support features.
  • ways (int, optional, default=None) - Number of classes in the task.
  • normalize (bool, optional, default=False) - Whether to normalize the inputs.
  • C_reg (float, optional, default=0.1) - Regularization weight for SVM.
  • max_iters (int, optional, default=15) - Maximum number of iterations for SVM convergence.

References

  1. Lee et al. 2019. "Meta-Learning with Differentiable Convex Optimization"

Example

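from learn2learn.nn import SVClassifier

# features is your embedding network; labels holds the support labels
classifier = SVClassifier()
support = features(support_data)
classifier.fit_(support, labels)
query = features(query_data)
preds = classifier(query)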

KroneckerLinear

KroneckerLinear(n, m, bias=True, psd=False, device=None)

Description

A linear transformation whose parameters are expressed as a Kronecker product.

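This Module maps an input vector $x \in \mathbb{R}^{nm}$ to $x' \in \mathbb{R}^{nm}$ such that:

$$x' = (R^\top \otimes L)\,x + b,$$

where $L \in \mathbb{R}^{n \times n}$ and $R \in \mathbb{R}^{m \times m}$ are the learnable Kronecker factors. This implementation can reduce the memory requirement for a large linear mapping from $\mathcal{O}(n^2 \cdot m^2)$ to $\mathcal{O}(n^2 + m^2)$, but forces the output $x'$ to have the same dimensionality $nm$ as the input.

The matrix $(R^\top \otimes L)$ is initialized as the identity, and the bias $b$ as a zero vector.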

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • psd (bool, optional, default=False) - Forces the matrix to be positive semi-definite if True.
  • device (device, optional, default=None) - The device on which to instantiate the Module.

References

  1. Jose et al. 2018. "Kronecker recurrent units".
  2. Arnold et al. 2019. "When MAML can adapt fast and how to assist when it cannot".

Example

import torch
from learn2learn.nn import KroneckerLinear

m, n = 2, 3
x = torch.randn(6)  # input of size n * m
kronecker = KroneckerLinear(n, m)
y = kronecker(x)
print(y.shape)  # torch.Size([6])

KroneckerRNN

KroneckerRNN(n, m, bias=True, sigma=None)

Description

Implements a recurrent neural network whose matrices are parameterized via their Kronecker factors. (See KroneckerLinear for details.)

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • sigma (callable, optional, default=None) - The activation function.

References

  1. Jose et al. 2018. "Kronecker recurrent units".

Example

import torch
from learn2learn.nn import KroneckerRNN

m, n = 2, 3
x = torch.randn(6)
h = torch.randn(6)
kronecker = KroneckerRNN(n, m)
y, new_h = kronecker(x, h)
print(y.shape)  # torch.Size([6])
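As a rough illustration of the idea, the sketch below builds a vanilla (Elman-style) recurrent cell in which every weight matrix is replaced by a Kronecker-factored map of the form $R^\top \otimes L$: the hidden update is $h' = \sigma(W x + U h + b)$, followed by a factored output map. The class names, the tanh default for `sigma`, and the separate output map are assumptions made for this example, not a description of KroneckerRNN's internals. It handles unbatched vectors only.

```python
import torch


class KronLinearSketch(torch.nn.Module):
    """Hypothetical Kronecker-factored linear map: x -> (R^T kron L) x + b."""

    def __init__(self, n, m, bias=True):
        super().__init__()
        self.n, self.m = n, m
        # Identity initialization, so the full map starts as the identity.
        self.left = torch.nn.Parameter(torch.eye(n))
        self.right = torch.nn.Parameter(torch.eye(m))
        self.bias = torch.nn.Parameter(torch.zeros(n * m)) if bias else None

    def forward(self, x):
        X = x.reshape(self.m, self.n).t()  # column-major unvec: n x m
        y = (self.left @ X @ self.right).t().reshape(-1)
        return y + self.bias if self.bias is not None else y


class KronRNNCellSketch(torch.nn.Module):
    """Hypothetical Elman cell whose weight matrices are Kronecker-factored."""

    def __init__(self, n, m, bias=True, sigma=None):
        super().__init__()
        self.W = KronLinearSketch(n, m, bias)        # input-to-hidden
        self.U = KronLinearSketch(n, m, bias=False)  # hidden-to-hidden
        self.V = KronLinearSketch(n, m, bias)        # hidden-to-output
        self.sigma = sigma if sigma is not None else torch.tanh  # assumed default

    def forward(self, x, h):
        new_h = self.sigma(self.W(x) + self.U(h))
        return self.V(new_h), new_h


n, m = 3, 2
cell = KronRNNCellSketch(n, m)
h = torch.zeros(n * m)
for x in torch.randn(4, n * m):  # unroll over a length-4 sequence
    y, h = cell(x, h)
print(y.shape, h.shape)  # torch.Size([6]) torch.Size([6])
```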

KroneckerLSTM

KroneckerLSTM(n, m, bias=True, sigma=None)

Description

Implements an LSTM using a factorization similar to that of KroneckerLinear.

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • sigma (callable, optional, default=None) - The activation function.

References

  1. Jose et al. 2018. "Kronecker recurrent units".

Example

import torch
from learn2learn.nn import KroneckerLSTM

m, n = 2, 3
x = torch.randn(6)
h = torch.randn(6)
kronecker = KroneckerLSTM(n, m)
y, new_h = kronecker(x, h)
print(y.shape)  # torch.Size([6])