Skip to content

Commit

Permalink
[ADD] Support modules with inplace activations in KFAC
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 8, 2023
1 parent 2eb654f commit 6106ad6
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 16 deletions.
46 changes: 32 additions & 14 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from __future__ import annotations

from functools import partial
from math import sqrt
from typing import Dict, Iterable, List, Set, Tuple, Union

Expand Down Expand Up @@ -237,8 +238,8 @@ def _compute_kfac(self):

# gradient covariance required for weights and biases
hook_handles.append(
module.register_full_backward_hook(
self._hook_accumulate_gradient_covariance
module.register_forward_hook(
self._register_tensor_hook_on_output_to_accumulate_gradient_covariance
)
)

Expand Down Expand Up @@ -327,28 +328,45 @@ def draw_label(self, output: Tensor) -> Tensor:
else:
raise NotImplementedError

def _hook_accumulate_gradient_covariance(
self, module: Module, grad_input: Tuple[Tensor], grad_output: Tuple[Tensor]
def _register_tensor_hook_on_output_to_accumulate_gradient_covariance(
self, module: Module, inputs: Tuple[Tensor], output: Tensor
):
"""Backward hook that accumulates the output-gradient covariance of a layer.
"""Register tensor hook on layer's output to accumulate the grad. covariance.
Note:
The easier way to compute the gradient covariance would be via a full
backward hook on the module itself which performs the computation.
However, this approach breaks down if the output of a layer feeds into an
activation with `inplace=True` (see
https://github.com/pytorch/pytorch/issues/61519). Hence we use the
workaround
https://github.com/pytorch/pytorch/issues/61519#issuecomment-883524237, and
install a module hook which installs a tensor hook on the module's output
tensor, which performs the accumulation of the gradient covariance.
Args:
module: Layer onto whose output a tensor hook to accumulate the gradient
covariance will be installed.
inputs: The layer's input tensors.
output: The layer's output tensor.
"""
tensor_hook = partial(self._accumulate_gradient_covariance, module)
output.register_hook(tensor_hook)

def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
"""Accumulate the gradient covariance for a layer's output.
Updates ``self._gradient_covariances``.
Args:
module: The layer on which the hook is called.
grad_input: The gradient of the loss w.r.t. the layer's inputs.
grad_output: The gradient of the loss w.r.t. the layer's outputs.
module: The layer whose output's gradient covariance will be accumulated.
grad_output: The gradient w.r.t. the output.
Raises:
ValueError: If ``grad_output`` is not a 1-tuple.
NotImplementedError: If a layer uses weight sharing.
NotImplementedError: If the layer is not supported.
"""
if len(grad_output) != 1:
raise ValueError(
f"Expected grad_output to be a 1-tuple, got {len(grad_output)}."
)
g = grad_output[0].data.detach()
g = grad_output.data.detach()

if isinstance(module, Linear):
if g.ndim != 2:
Expand Down
52 changes: 50 additions & 2 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Contains tests for ``curvlinops.kfac``."""

from test.cases import DEVICES, DEVICES_IDS
from test.utils import regression_targets
from typing import Iterable, List, Tuple

from numpy import eye
from pytest import mark
from scipy.linalg import block_diag
from torch import Tensor, randperm
from torch.nn import Module, MSELoss, Parameter
from torch import Tensor, device, manual_seed, rand, randperm
from torch.nn import Linear, Module, MSELoss, Parameter, ReLU, Sequential

from curvlinops.examples.utils import report_nonclose
from curvlinops.ggn import GGNLinearOperator
Expand Down Expand Up @@ -103,3 +105,49 @@ def test_kfac_ef_one_datum(
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ef, kfac_mat)


@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_kfac_inplace_activations(dev: device):
    """Test that KFAC works if the network has in-place activations.

    We use a test case with a single datum as KFAC becomes exact as the number of
    MC samples increases.

    Args:
        dev: The device to run the test on.
    """
    manual_seed(0)
    # Place model, loss, and data on ``dev``; otherwise every device
    # parametrization would silently run the identical CPU-only test.
    model = Sequential(Linear(6, 3), ReLU(inplace=True), Linear(3, 2)).to(dev)
    loss_func = MSELoss().to(dev)
    batch_size = 1
    # Draw on the CPU first and move afterwards, so the sampled values do not
    # depend on the device's random number generator.
    data = [
        (
            rand(batch_size, 6).to(dev),
            regression_targets((batch_size, 2)).to(dev),
        )
    ]
    params = list(model.parameters())

    def block_diagonal_ggn():
        """Return the dense GGN with one diagonal block per parameter.

        Reads ``model`` at call time, so toggling ``inplace`` on its modules
        between calls is reflected in the result.
        """
        blocks = []
        for param in params:
            ggn_op = GGNLinearOperator(model, loss_func, [param], data)
            blocks.append(ggn_op @ eye(ggn_op.shape[1]))
        return block_diag(*blocks)

    # 1) compare KFAC and GGN
    ggn = block_diagonal_ggn()

    kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)
    kfac_mat = kfac @ eye(kfac.shape[1])

    atol = {"sum": 5e-1, "mean": 2e-3}[loss_func.reduction]
    rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

    report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)

    # 2) Compare GGN (inplace=True) and GGN (inplace=False)
    for mod in model.modules():
        if hasattr(mod, "inplace"):
            mod.inplace = False

    ggn2 = block_diagonal_ggn()

    report_nonclose(ggn, ggn2)

0 comments on commit 6106ad6

Please sign in to comment.