learn2learn.optim

A set of utilities to write differentiable optimization algorithms.

LearnableOptimizer

LearnableOptimizer(model, transform, lr=1.0)

Description

A PyTorch Optimizer with a learnable transform, enabling the implementation of meta-descent / hyper-gradient algorithms.

This optimizer takes a Module and a gradient transform. At each step, the gradients of the module are passed through the transform, and the module is updated differentiably, i.e., when the next backward pass is called, gradients of both the module and the transform are computed. In turn, the transform can be updated via your favorite optimizer.
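To make the mechanism concrete, here is a minimal sketch in plain PyTorch that does not use learn2learn. It assumes the simplest possible transform, a single learnable log step size (`log_lr`, a hypothetical name), and toy regression data. It illustrates hyper-gradient descent in general and is not a description of how LearnableOptimizer is implemented internally.

```python
import torch

torch.manual_seed(0)

# Toy regression data (placeholder values, for illustration only).
X = torch.randn(64, 5)
y = X @ torch.randn(5, 1)

def mse(weights):
    return ((X @ weights - y) ** 2).mean()

w = torch.zeros(5, 1, requires_grad=True)        # parameters of the "model"
log_lr = torch.tensor(-3.0, requires_grad=True)  # the learnable "transform"
meta_opt = torch.optim.SGD([log_lr], lr=0.1)     # optimizer for the transform

initial_lr = log_lr.exp().item()
initial_loss = mse(w).item()

for step in range(50):
    (grad,) = torch.autograd.grad(mse(w), w)
    # Differentiable update: new_w is a function of log_lr.
    new_w = w - log_lr.exp() * grad
    # Backpropagating the post-update loss yields the hyper-gradient.
    meta_opt.zero_grad()
    mse(new_w).backward()
    meta_opt.step()
    # Commit the update and cut the graph before the next step.
    w = new_w.detach().requires_grad_()

final_lr = log_lr.exp().item()
final_loss = mse(w).item()
```

Parameterizing the step size through its logarithm keeps it positive; that is a design choice of this sketch, not something the library requires.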

Arguments

  • model (Module) - Module to be updated.
  • transform (Module) - Transform used to compute updates of the model.
  • lr (float) - Learning rate.

References

  1. Sutton. 1992. “Gain Adaptation Beats Least Squares.”
  2. Schraudolph. 1999. “Local Gain Adaptation in Stochastic Gradient Descent.”
  3. Baydin et al. 2017. “Online Learning Rate Adaptation with Hypergradient Descent.”
  4. Majumder et al. 2019. “Learning the Learning Rate for Gradient Descent by Gradient Descent.”
  5. Jacobsen et al. 2019. “Meta-Descent for Online, Continual Prediction.”

Example

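import torch
from torch import nn
import learn2learn as l2l

# Placeholder data and loss function; substitute your own.
X, y = torch.randn(32, 784), torch.randint(0, 10, (32,))
loss = nn.CrossEntropyLoss()

linear = nn.Linear(784, 10)
transform = l2l.optim.ModuleTransform(torch.nn.Linear)
metaopt = l2l.optim.LearnableOptimizer(linear, transform, lr=0.01)
opt = torch.optim.SGD(metaopt.parameters(), lr=0.001)

metaopt.zero_grad()
opt.zero_grad()
error = loss(linear(X), y)
error.backward()
opt.step()  # update metaopt
metaopt.step()  # update linear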

zero_grad

LearnableOptimizer.zero_grad()

Resets only the gradients of the target parameters.

ParameterUpdate

ParameterUpdate(parameters, transform)

Description

Convenience class to implement custom update functions.

Objects instantiated from this class behave similarly to torch.autograd.grad, but return parameter updates as opposed to gradients. Concretely, the gradients are first computed, then fed to their respective transform whose output is finally returned to the user.

Additionally, this class supports parameters that might not require updates by setting the allow_nograd flag to True. In this case, the returned update is None.
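The behavior described above (compute gradients, pass each through its own transform, return None for parameters that do not require gradients) can be sketched in plain PyTorch. The helper `compute_updates` is hypothetical and written only for illustration; it is not the library's implementation, and `torch.nn.Identity` stands in for a real learnable transform.

```python
import torch

def compute_updates(loss, parameters, transforms, create_graph=False):
    """Hypothetical helper: gradients -> per-parameter transforms -> updates."""
    parameters = list(parameters)
    # Differentiate only w.r.t. parameters that require gradients.
    trainable = [p for p in parameters if p.requires_grad]
    grads = iter(torch.autograd.grad(loss, trainable, create_graph=create_graph))
    updates = []
    for param, transform in zip(parameters, transforms):
        if param.requires_grad:
            updates.append(transform(next(grads)))
        else:
            updates.append(None)  # frozen parameter: no update
    return updates

model = torch.nn.Linear(3, 2)
model.bias.requires_grad_(False)  # freeze the bias to show the None case
params = list(model.parameters())  # [weight, bias]
transforms = [torch.nn.Identity() for _ in params]  # stand-in transforms

loss = model(torch.randn(4, 3)).pow(2).mean()
updates = compute_updates(loss, params, transforms)
```

Here `updates[0]` has the shape of the weight and `updates[1]` is None because the bias is frozen.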

Arguments

  • parameters (list) - Parameters of the model to update.
  • transform (callable) - A callable that returns an instantiated transform given a parameter.

Example

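import torch
import learn2learn as l2l

# X, y, and loss (a loss function) are assumed to be defined by the user.
model = torch.nn.Linear(784, 10)  # layer sizes chosen for illustration
transform = l2l.optim.KroneckerTransform(l2l.nn.KroneckerLinear)
get_update = l2l.optim.ParameterUpdate(model.parameters(), transform)
# parameters() returns generators, so convert to lists before concatenating.
opt = torch.optim.SGD(
    list(model.parameters()) + list(get_update.parameters()),
    lr=0.01,  # illustrative learning rate
)

for iteration in range(10):
    opt.zero_grad()
    error = loss(model(X), y)
    updates = get_update(
        error,
        model.parameters(),
        create_graph=True,
    )
    l2l.update_module(model, updates)
    opt.step()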

forward

ParameterUpdate.forward(loss,
                        parameters,
                        create_graph=False,
                        retain_graph=False,
                        allow_unused=False,
                        allow_nograd=False)

Description

Similar to torch.autograd.grad, but passes the gradients through the provided transform.

Arguments

  • loss (Tensor) - The loss to differentiate.
  • parameters (iterable) - Parameters w.r.t. which we want to compute the update.
  • create_graph (bool, optional, default=False) - Same as torch.autograd.grad.
  • retain_graph (bool, optional, default=False) - Same as torch.autograd.grad.
  • allow_unused (bool, optional, default=False) - Same as torch.autograd.grad.
  • allow_nograd (bool, optional, default=False) - Properly handles parameters that do not require gradients. (Their update will be None.)

DifferentiableSGD

DifferentiableSGD(lr)

Description

A callable object that applies a list of updates to the parameters of a torch.nn.Module in a differentiable manner.

For each parameter $p$ and corresponding gradient $g$, calling an instance of this class results in updating parameters:

$$p \leftarrow p - \alpha g,$$

where $\alpha$ is the learning rate.

Note: The module is updated in-place.
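What "differentiable" means for such an update can be checked on a scalar with plain PyTorch. The sketch below does not call DifferentiableSGD; it assumes a gradient step of the form `p - lr * g` and a toy loss `p ** 2`, and shows that the post-update loss can be differentiated with respect to the pre-update parameter.

```python
import torch

lr = 0.1
p = torch.tensor(2.0, requires_grad=True)

# Inner loss and its gradient; create_graph=True keeps g differentiable.
loss = p ** 2
(g,) = torch.autograd.grad(loss, p, create_graph=True)  # g = 2p

# Differentiable gradient step: p_new = p - 0.1 * 2p = 0.8 * p
p_new = p - lr * g

# Loss after the update, differentiated w.r.t. the original parameter:
# (0.8 * p)^2 = 0.64 * p^2, so the gradient is 1.28 * p.
post_update_loss = p_new ** 2
(meta_grad,) = torch.autograd.grad(post_update_loss, p)

print(p_new.item())      # approximately 1.6
print(meta_grad.item())  # approximately 2.56
```

Without `create_graph=True`, `g` would be treated as a constant and the second gradient would be `2 * p_new = 3.2` instead, ignoring how the step itself depends on `p`.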

Arguments

  • lr (float) - The learning rate used to update the model.

Example

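import torch
from learn2learn.optim import DifferentiableSGD

# loss (a scalar Tensor) and model (a Module) are assumed to be defined.
sgd = DifferentiableSGD(0.1)
gradients = torch.autograd.grad(
    loss,
    model.parameters(),
    create_graph=True)
sgd(model, gradients)  # model is updated in-place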

forward

DifferentiableSGD.forward(module, gradients=None)

Arguments

  • module (Module) - The module to update.
  • gradients (list, optional, default=None) - A list of gradients for each parameter of the module. If None, will use the gradients in .grad attributes.

learn2learn.optim.transforms

Optimization transforms are special modules that take gradients as inputs and output model updates. Transforms are usually parameterized, and those parameters can be learned by gradient descent, allowing you to learn optimization functions from data.
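As an illustration of the idea, a transform can be as small as a learnable elementwise scaling of the gradient. The class `DiagonalTransform` below is hypothetical and not part of learn2learn; it is a sketch of what "a module that maps gradients to updates" can look like in plain PyTorch.

```python
import torch

class DiagonalTransform(torch.nn.Module):
    """Hypothetical transform: learnable elementwise scaling of a gradient."""

    def __init__(self, param):
        super().__init__()
        # One learnable scale per entry of the parameter, initialized to 1.
        self.scale = torch.nn.Parameter(torch.ones_like(param))

    def forward(self, grad):
        return self.scale * grad

weight = torch.randn(10, 4, requires_grad=True)
transform = DiagonalTransform(weight)

loss = weight.pow(2).sum()
(grad,) = torch.autograd.grad(loss, weight)
update = transform(grad)  # same shape as the gradient
```

Because the scale is initialized to ones, the update initially equals the raw gradient; training the scale would then reweight individual coordinates.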

ModuleTransform

ModuleTransform(module_cls)

Description

The ModuleTransform creates an optimization transform based on any nn.Module.

ModuleTransform automatically instantiates a module from its class, based on a given parameter. The input and output shapes of the module are set to (1, param.numel()).

When optimizing large layers, this type of transform can quickly run out of memory. See KroneckerTransform for a scalable alternative.
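The memory cost follows from the flattening: a linear module over a flattened gradient has `numel ** 2` weights, which for a 784 x 10 layer is 7840 x 7840 = 61,465,600 entries. The sketch below reproduces the reshaping idea in plain PyTorch on a small parameter; `FlattenedTransform` is a hypothetical name and this is not the library's code.

```python
import torch

class FlattenedTransform(torch.nn.Module):
    """Hypothetical sketch: apply a module to the gradient flattened to (1, numel)."""

    def __init__(self, module_cls, param):
        super().__init__()
        n = param.numel()
        self.module = module_cls(n, n)  # input and output size are both numel

    def forward(self, grad):
        flat = grad.reshape(1, -1)                    # (1, numel)
        return self.module(flat).reshape(grad.shape)  # back to the gradient's shape

param = torch.randn(4, 3)
transform = FlattenedTransform(torch.nn.Linear, param)
update = transform(torch.randn(4, 3))

# 12 x 12 weight matrix plus 12 biases.
n_params = sum(p.numel() for p in transform.parameters())
print(n_params)  # 156
```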

Arguments

  • module_cls (callable) - A callable that instantiates the module used to transform gradients.

Example

import torch
from learn2learn.optim.transforms import ModuleTransform

# X, y, loss (a loss function), and lr are assumed to be defined by the user.
classifier = torch.nn.Linear(784, 10, bias=False)
linear_transform = ModuleTransform(torch.nn.Linear)
linear_update = linear_transform(classifier.weight)  # maps gradients to updates, both of shape (1, 7840)
loss(classifier(X), y).backward()
update = linear_update(classifier.weight.grad)
classifier.weight.data.add_(update, alpha=-lr)  # Not a differentiable update. See l2l.optim.DifferentiableSGD.

KroneckerTransform

KroneckerTransform(kronecker_cls, bias=False, psd=True)

Description

The KroneckerTransform creates an optimization transform based on nn.Modules that admit a Kronecker factorization (see l2l.nn.Kronecker*).

Akin to the ModuleTransform, this class of transform instantiates a module from its class, based on a given parameter. But, instead of reshaping the gradients to shape (1, param.numel()), this class assumes a Kronecker factorization of the weights for memory and computational efficiency.

The specific dimension of the Kronecker factorization depends on the parameter's shape. For a weight of shape (n, m), a KroneckerLinear transform consists of two weights with shapes (n, n) and (m, m) rather than a single weight of shape (nm, nm). Refer to Arnold et al., 2019 for more details.

Arguments

  • kronecker_cls (callable) - A callable that instantiates the Kronecker module used to transform gradients.

References

  1. Arnold et al. 2019. "When MAML can adapt fast and how to assist when it cannot".

Example

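import torch
import learn2learn as l2l
from learn2learn.optim.transforms import KroneckerTransform

# X, y, loss (a loss function), and lr are assumed to be defined by the user.
classifier = torch.nn.Linear(784, 10, bias=False)
kronecker_transform = KroneckerTransform(l2l.nn.KroneckerLinear)
kronecker_update = kronecker_transform(classifier.weight)
loss(classifier(X), y).backward()
update = kronecker_update(classifier.weight.grad)
classifier.weight.data.add_(update, alpha=-lr)  # Not a differentiable update. See l2l.optim.DifferentiableSGD.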

MetaCurvatureTransform

MetaCurvatureTransform(param, lr=1.0)

Description

Implements the Meta-Curvature transform of Park and Oliva, 2019.

Unlike ModuleTransform and KroneckerTransform, this class does not wrap other Modules but is directly called on a weight to instantiate the transform.

Arguments

  • param (Tensor) - The weight whose gradients will be transformed.
  • lr (float, optional, default=1.0) - Scaling factor of the update. (non-learnable)

References

  1. Park & Oliva. 2019. Meta-curvature.

Example

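import torch
from learn2learn.optim.transforms import MetaCurvatureTransform

# X, y, loss (a loss function), and lr are assumed to be defined by the user.
classifier = torch.nn.Linear(784, 10, bias=False)
metacurvature_update = MetaCurvatureTransform(classifier.weight)
loss(classifier(X), y).backward()
update = metacurvature_update(classifier.weight.grad)
classifier.weight.data.add_(update, alpha=-lr)  # Not a differentiable update. See l2l.optim.DifferentiableSGD.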