learn2learn.nn

Additional torch.nn.Modules frequently used for meta-learning.

Lambda

Lambda(lmb)

Description

Utility class that wraps a lambda function (or any other callable) in a Module.

Arguments

  • lmb (callable) - The function to call in the forward pass.
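
The wrapping behavior can be sketched as a Module that stores the callable and invokes it in its forward pass. The class name `LambdaSketch` is hypothetical. This is a minimal illustration under that assumption, not learn2learn's own source:

```python
import torch


class LambdaSketch(torch.nn.Module):
    """Hypothetical minimal Module that wraps an arbitrary callable."""

    def __init__(self, lmb):
        super().__init__()
        self.lmb = lmb  # a plain function is stored as an ordinary attribute, adding no parameters

    def forward(self, *args, **kwargs):
        # The forward pass simply delegates to the wrapped callable.
        return self.lmb(*args, **kwargs)


model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),  # (B, 8, 16, 16)
    LambdaSketch(lambda x: x.mean(dim=[2, 3])),       # global average pool -> (B, 8)
    torch.nn.Linear(8, 4),                            # (B, 4)
)
out = model(torch.randn(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 4])
```

Wrapping the function in a Module is what lets it sit inside `torch.nn.Sequential` alongside ordinary layers.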

Example

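import torch
from learn2learn.nn import Lambda
features = torch.nn.Conv2d(3, 16, kernel_size=3)  # hypothetical feature extractor
img = torch.randn(1, 3, 32, 32)  # hypothetical input batch
mean23 = Lambda(lambda x: x.mean(dim=[2, 3]))  # mean23 is a Module
x = features(img)  # (1, 16, 30, 30)
x = mean23(x)  # (1, 16)
x = x.flatten()  # (16,)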

Flatten

Flatten()

Description

Utility Module to flatten inputs to (batch_size, -1) shape.
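
A minimal sketch of this reshaping, assuming it keeps the first dimension and merges all the others (the class name `FlattenSketch` is hypothetical):

```python
import torch


class FlattenSketch(torch.nn.Module):
    """Hypothetical minimal Module that reshapes inputs to (batch_size, -1)."""

    def forward(self, x):
        # Keep the batch dimension and merge all remaining ones;
        # reshape (unlike view) also accepts non-contiguous inputs.
        return x.reshape(x.size(0), -1)


flatten = FlattenSketch()
x = torch.randn(5, 3, 32, 32)
y = flatten(x)
print(y.shape)  # torch.Size([5, 3072])
```

Recent PyTorch releases also provide `torch.nn.Flatten`, whose default `start_dim=1` produces the same `(batch_size, -1)` shape.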

Example

import torch
from learn2learn.nn import Flatten
flatten = Flatten()
x = torch.randn(5, 3, 32, 32)
x = flatten(x)
print(x.shape)  # (5, 3072)

Scale

Scale(shape, alpha=1.0)

Description

A per-parameter scaling factor with learnable parameters: the input is multiplied element-wise by a learnable tensor of the given shape.

Arguments

  • shape (int or tuple) - The shape of the scaling matrix.
  • alpha (float, optional, default=1.0) - Initial value for the scaling factor.
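
A minimal sketch of such a module, assuming the input is multiplied element-wise by a learnable tensor of the given shape filled with `alpha` (the class name `ScaleSketch` is hypothetical):

```python
import torch


class ScaleSketch(torch.nn.Module):
    """Hypothetical minimal module: multiplies inputs by learnable factors."""

    def __init__(self, shape, alpha=1.0):
        super().__init__()
        if isinstance(shape, int):
            shape = (shape,)
        # One learnable factor per entry, all initialized to alpha.
        self.alpha = torch.nn.Parameter(torch.full(tuple(shape), float(alpha)))

    def forward(self, x):
        # Element-wise product; standard broadcasting rules apply.
        return x * self.alpha


x = torch.ones(3)
scale = ScaleSketch(x.shape, alpha=0.5)
y = scale(x)
print(y)  # every entry equals 0.5

# The factors receive gradients: d(sum(x * alpha)) / d(alpha) = x.
y.sum().backward()
print(scale.alpha.grad)  # equals x, i.e. all ones here
```

Because the factors are a `torch.nn.Parameter`, they appear in `scale.parameters()` and can be adapted by an optimizer like any other weight.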

Example

import torch
from learn2learn.nn import Scale
x = torch.ones(3)
scale = Scale(x.shape, alpha=0.5)
print(scale(x))  # [.5, .5, .5]

KroneckerLinear

KroneckerLinear(n, m, bias=True, psd=False, device=None)

Description

A linear transformation whose parameters are expressed as a Kronecker product.

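This Module maps an input vector $x \in \mathbb{R}^{nm}$ to $y = Ax + b$ such that:

$$A = R^\top \otimes L,$$

where $L \in \mathbb{R}^{n \times n}$ and $R \in \mathbb{R}^{m \times m}$ are the learnable Kronecker factors. This implementation can reduce the memory requirement for large linear mappings from $\mathcal{O}(n^2 \cdot m^2)$ to $\mathcal{O}(n^2 + m^2)$, but forces $y \in \mathbb{R}^{nm}$.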

The matrix is initialized as the identity, and the bias as a zero vector.

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • psd (bool, optional, default=False) - Forces the matrix to be positive semi-definite if True.
  • device (device, optional, default=None) - The device on which to instantiate the Module.

References

  1. Jose et al. 2018. "Kronecker recurrent units".
  2. Arnold et al. 2019. "When MAML can adapt fast and how to assist when it cannot".

Example

import torch
from learn2learn.nn import KroneckerLinear
m, n = 2, 3
x = torch.randn(6)  # input of size n * m
kronecker = KroneckerLinear(n, m)
y = kronecker(x)
print(y.shape)  # (6, )

KroneckerRNN

KroneckerRNN(n, m, bias=True, sigma=None)

Description

Implements a recurrent neural network whose matrices are parameterized via their Kronecker factors. (See KroneckerLinear for details.)

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • sigma (callable, optional, default=None) - The activation function.
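
A single recurrent step can be sketched as an Elman-style update in which every weight matrix is Kronecker-factored: $h' = \sigma(Wx + Uh)$ and $y = Vh'$. These equations, the `tanh` default for `sigma`, and the class names `KronLinearSketch` and `KronRNNCellSketch` are assumptions for illustration; the library's cell may differ in detail. Each factored map computes $(R^\top \otimes L)x$ as $\mathrm{vec}(L X R)$ and only handles 1-D inputs:

```python
import torch


class KronLinearSketch(torch.nn.Module):
    """Hypothetical Kronecker-factored map y = (R^T kron L) x + b for 1-D x."""

    def __init__(self, n, m, bias=True):
        super().__init__()
        self.n, self.m = n, m
        self.left = torch.nn.Parameter(torch.eye(n))   # L, shape (n, n)
        self.right = torch.nn.Parameter(torch.eye(m))  # R, shape (m, m)
        self.bias = torch.nn.Parameter(torch.zeros(n * m)) if bias else None

    def forward(self, x):
        X = x.reshape(self.m, self.n).t()                 # inverse of column-major vec
        y = (self.left @ X @ self.right).t().reshape(-1)  # vec(L X R)
        return y + self.bias if self.bias is not None else y


class KronRNNCellSketch(torch.nn.Module):
    """Hypothetical Elman-style cell: h' = sigma(W x + U h), y = V h'."""

    def __init__(self, n, m, bias=True, sigma=None):
        super().__init__()
        self.W = KronLinearSketch(n, m, bias)  # input-to-hidden
        self.U = KronLinearSketch(n, m, bias)  # hidden-to-hidden
        self.V = KronLinearSketch(n, m, bias)  # hidden-to-output
        self.sigma = sigma if sigma is not None else torch.tanh  # assumed default

    def forward(self, x, h):
        new_h = self.sigma(self.W(x) + self.U(h))
        return self.V(new_h), new_h


m, n = 2, 3
cell = KronRNNCellSketch(n, m)
x, h = torch.randn(6), torch.randn(6)
y, new_h = cell(x, h)
print(y.shape, new_h.shape)  # torch.Size([6]) torch.Size([6])
```

With identity factors and zero biases at initialization, the first step reduces to `tanh(x + h)`, which makes the sketch easy to sanity-check.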

References

  1. Jose et al. 2018. "Kronecker recurrent units".

Example

import torch
from learn2learn.nn import KroneckerRNN
m, n = 2, 3
x = torch.randn(6)
h = torch.randn(6)
kronecker = KroneckerRNN(n, m)
y, new_h = kronecker(x, h)
print(y.shape)  # (6, )

KroneckerLSTM

KroneckerLSTM(n, m, bias=True, sigma=None)

Description

Implements an LSTM using a factorization similar to that of KroneckerLinear.

Arguments

  • n (int) - Dimensionality of the left Kronecker factor.
  • m (int) - Dimensionality of the right Kronecker factor.
  • bias (bool, optional, default=True) - Whether to include the bias term.
  • sigma (callable, optional, default=None) - The activation function.
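
The sketch below applies the textbook LSTM equations with every weight matrix Kronecker-factored: sigmoid input, forget, and output gates, a tanh candidate $g$, and the updates $c' = f \odot c + i \odot g$ and $h' = o \odot \tanh(c')$. The class name `KronLSTMCellSketch`, the hard-coded nonlinearities (the `sigma` argument is not modeled), and the explicit `(h, c)` state tuple are assumptions of this sketch; it does not attempt to reproduce how learn2learn packages its hidden state:

```python
import torch


class KronLSTMCellSketch(torch.nn.Module):
    """Hypothetical LSTM cell whose eight weight matrices are Kronecker-factored."""

    def __init__(self, n, m):
        super().__init__()
        self.n, self.m = n, m
        # Factors for 4 gates (i, f, g, o) x 2 sources (x, h), stored at
        # index 2 * gate + source; all initialized to the identity.
        self.left = torch.nn.Parameter(torch.eye(n).repeat(8, 1, 1))   # (8, n, n)
        self.right = torch.nn.Parameter(torch.eye(m).repeat(8, 1, 1))  # (8, m, m)
        self.bias = torch.nn.Parameter(torch.zeros(4, n * m))          # one bias per gate

    def kron(self, k, v):
        # (R_k^T kron L_k) v, computed as vec(L_k V R_k) where vec(V) = v (column-major).
        V = v.reshape(self.m, self.n).t()
        return (self.left[k] @ V @ self.right[k]).t().reshape(-1)

    def forward(self, x, state):
        h, c = state
        # Pre-activations for gates i, f, g, o (in that order).
        pre = [self.kron(2 * k, x) + self.kron(2 * k + 1, h) + self.bias[k] for k in range(4)]
        i, f, o = torch.sigmoid(pre[0]), torch.sigmoid(pre[1]), torch.sigmoid(pre[3])
        g = torch.tanh(pre[2])  # candidate cell update
        new_c = f * c + i * g
        new_h = o * torch.tanh(new_c)
        return new_h, (new_h, new_c)


m, n = 2, 3
cell = KronLSTMCellSketch(n, m)
x, h, c = torch.randn(6), torch.randn(6), torch.zeros(6)
y, (new_h, new_c) = cell(x, (h, c))
print(y.shape)  # torch.Size([6])
```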

References

  1. Jose et al. 2018. "Kronecker recurrent units".

Example

import torch
from learn2learn.nn import KroneckerLSTM
m, n = 2, 3
x = torch.randn(6)
h = torch.randn(6)
kronecker = KroneckerLSTM(n, m)
y, new_h = kronecker(x, h)
print(y.shape)  # (6, )