learn2learn

clone_module(module, memo=None)

[Source]

Description

Creates a copy of a module, whose parameters/buffers/submodules are created using PyTorch's torch.clone().

This implies that the computational graph is kept, and you can compute the derivatives of the new module's parameters w.r.t. the original parameters.

Arguments

  • module (Module) - Module to be cloned.

Return

  • (Module) - The cloned module.

Example

net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
clone = clone_module(net)
error = loss(clone(X), y)
error.backward()  # Gradients are back-propagated all the way to net.
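Because the clone shares the computational graph with the original module, you can take a differentiable inner-loop step on the clone and later back-propagate a meta-loss into net. A minimal sketch of that pattern (using update_module, documented below), assuming loss, X, y, X_val, y_val, and lr are defined elsewhere:

clone = clone_module(net)
error = loss(clone(X), y)
grads = torch.autograd.grad(error, clone.parameters(), create_graph=True)
l2l.update_module(clone, updates=[-lr * g for g in grads])  # differentiable step
meta_error = loss(clone(X_val), y_val)
meta_error.backward()  # meta-gradients reach net.parameters()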

detach_module(module, keep_requires_grad=False)

[Source]

Description

Detaches all parameters/buffers of a previously cloned module from its computational graph.

Note: detach works in-place, so it does not return a copy.

Arguments

  • module (Module) - Module to be detached.
  • keep_requires_grad (bool) - By default, all parameters of the detached module will have requires_grad set to False. If this flag is set to True, the requires_grad field of each parameter will match that of the pre-detached module.

Example

net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
clone = clone_module(net)
detach_module(clone, keep_requires_grad=True)
error = loss(clone(X), y)
error.backward()  # Gradients are back-propagated on clone, not net.
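The keep_requires_grad flag matters when you intend to keep training the detached copy: with the default False the detached parameters are frozen, while True leaves them trainable, as in the example above. A short check of that behavior, reusing the net from the example:

frozen = clone_module(net)
detach_module(frozen)  # default: keep_requires_grad=False
print(all(not p.requires_grad for p in frozen.parameters()))  # True: frozen

trainable = clone_module(net)
detach_module(trainable, keep_requires_grad=True)
print(all(p.requires_grad for p in trainable.parameters()))   # True: still trainable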

update_module(module, updates=None, memo=None)

[Source]

Description

Updates the parameters of a module in-place, in a way that preserves differentiability.

The parameters of the module are swapped with their update values, according to

\[ p \gets p + u, \]

where \( p \) is the parameter, and \( u \) is its corresponding update.

Arguments

  • module (Module) - The module to update.
  • updates (list, optional, default=None) - A list of update tensors, one for each parameter of the module. If None, the tensors stored in the parameters' .update attributes are used.

Example

error = loss(model(X), y)
grads = torch.autograd.grad(
    error,
    model.parameters(),
    create_graph=True,
)
updates = [-lr * g for g in grads]
l2l.update_module(model, updates=updates)
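The example passes the updates explicitly; per the updates=None behavior described in Arguments, the same step can also be written by storing each update on the corresponding parameter's .update attribute. A sketch of that variant, reusing grads and lr from above:

for p, g in zip(model.parameters(), grads):
    p.update = -lr * g
l2l.update_module(model)  # picks up the .update attributes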

magic_box(x)

[Source]

Description

The magic box operator, which evaluates to 1 but whose gradient is \( dx \):

\[ \boxdot (x) = \exp(x - \bot(x)), \]

where \( \bot \) is the stop-gradient (or detach) operator.

This operator is useful when computing higher-order derivatives of stochastic graphs. For more information, please refer to the DiCE paper. (Reference 1)

References

  1. Foerster et al. 2018. "DiCE: The Infinitely Differentiable Monte-Carlo Estimator." arXiv.

Arguments

  • x (Variable) - Variable to transform.

Return

  • (Variable) - A tensor equal to 1, but whose gradient is the gradient of x.

Example

loss = (magic_box(cum_log_probs) * advantages).mean()  # loss is the mean advantage
loss.backward()
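Since the operator evaluates to 1 while its gradient is dx, a quick numerical check on a scalar (a hypothetical stand-alone snippet, separate from the DiCE example above) is:

x = torch.tensor(0.7, requires_grad=True)
box = magic_box(x)      # tensor(1.) in the forward pass
box.backward()
print(x.grad)           # tensor(1.): d magic_box(x) / dx evaluates to 1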

clone_distribution(dist)

detach_distribution(dist)
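
These last two helpers are listed without a description; assuming they mirror clone_module and detach_module for torch.distributions objects, as their names suggest, a hypothetical usage sketch (policy_head and state are made-up names) is:

loc = policy_head(state)                  # differentiable distribution parameters
dist = torch.distributions.Normal(loc, 1.0)
cloned = clone_distribution(dist)         # copy whose parameters keep the graph
old = detach_distribution(dist)           # parameters cut from the graph,
                                          # e.g. the "old" policy in a likelihood ratio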