[REF] Replace full backward hook with tensor hook on module output (#63)
* [ADD] Test to reproduce #56

* [REF] Replace module full backward hook with tensor hook on output

See pytorch/pytorch#61519 (comment) for details

* [FIX] Incorporate suggestions
f-dangel authored Nov 8, 2023
1 parent 298ea86 commit 5857af8
Showing 2 changed files with 146 additions and 11 deletions.
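
For context, the sketch below illustrates the failure mode this commit works around and the pattern it adopts instead. It is a minimal, hypothetical example (layer sizes and hook names are illustrative, not taken from SINGD): a full backward hook on a module whose output feeds an `inplace=True` activation hits the problem reported in pytorch/pytorch#61519, whereas a tensor hook registered on the module's output from inside a forward hook does not.

from torch import rand
from torch.nn import Conv2d, ReLU, Sequential


def full_backward_hook(module, grad_input, grad_output):
    print("full backward hook fired")


def forward_hook(module, inputs, output):
    # Workaround: hook the output tensor instead of the module itself.
    output.register_hook(lambda grad_output: print("tensor hook fired"))


layer = Conv2d(1, 1, 3)
model = Sequential(layer, ReLU(inplace=True))

# Variant 1 (breaks): a module full backward hook combined with the in-place
# ReLU triggers the error described in pytorch/pytorch#61519, hence it stays
# commented out here.
# layer.register_full_backward_hook(full_backward_hook)

# Variant 2 (works): a forward hook that installs a tensor hook on the output.
layer.register_forward_hook(forward_hook)

model(rand(2, 1, 5, 5)).sum().backward()  # prints "tensor hook fired"
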
47 changes: 36 additions & 11 deletions singd/optim/optimizer.py
@@ -1,5 +1,6 @@
"""Implements structured inverse-free KFAC."""

from functools import partial
from math import sqrt
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, Union
from warnings import simplefilter, warn
@@ -553,25 +554,47 @@ def _update_preconditioner(self, module: Module):
        self.Ks[module_name].add_(K @ new_m_K, alpha=-beta1_K)
        self.Cs[module_name].add_(C @ new_m_C, alpha=-beta1_C)

    def _accumulate_H_terms(
        self, module: Module, grad_input: Tuple[Tensor], grad_output: Tuple[Tensor]
    def _register_tensor_hook_on_output_to_accumulate_H_terms(
        self, module: Module, inputs: Tuple[Tensor], output: Tensor
    ):
        """Register a tensor hook on the module's output that accumulates the H terms.

        This function can be used as a `forward_hook`.

        Only installs the hook for steps matching the specified update frequency.

        Note:
            The easier way to compute `H_K` and `H_C` would be via a full backward hook
            on the module itself which performs the computation. However, this approach
            breaks down if the output of a layer feeds into an activation with
            `inplace=True` (see https://github.com/pytorch/pytorch/issues/61519). Hence
            we use the workaround
            https://github.com/pytorch/pytorch/issues/61519#issuecomment-883524237, and
            install a module hook which installs a tensor hook on the module's output
            tensor, which performs the accumulation of `H_K` and `H_C`.

        Args:
            module: Layer onto whose output a tensor hook to compute `H_K` and `H_C`
                will be installed.
            inputs: The layer's input tensors.
            output: The layer's output tensor.
        """
        T = self._get_param_group_entry(module, "T")
        if self.steps % T == 0:
            tensor_hook = partial(self._accumulate_H_terms, module)
            output.register_hook(tensor_hook)

    def _accumulate_H_terms(self, module: Module, grad_output: Tensor):
        """Accumulate the current mini-batch's contribution to `H_K, H_C` for a layer.

        Updates the `H_K, H_C` buffers for the module.

        Only updates for steps matched by the specified update frequency.
        Requires that the layer inputs have been stored in `self.inputs`.

        Args:
            module: Layer whose pre-conditioner is updated.
            grad_input: Gradients w.r.t. the input.
            grad_output: Gradients w.r.t. the output.
            grad_output: The gradient w.r.t. the output.
        """
        T = self._get_param_group_entry(module, "T")
        if self.steps % T != 0:
            return

        loss_average = self._get_param_group_entry(module, "loss_average")
        kfac_approx = self._get_param_group_entry(module, "kfac_approx")
        module_name = self.module_names[module]
@@ -582,7 +605,7 @@ def _accumulate_H_terms(
        # For convolutions, unfold the input, for modules with bias terms, append a 1
        a = process_input(a, module, kfac_approx)

        g = grad_output[0].data
        g = grad_output.data
        # Process into matrix according to kfac_approx, add scaling from batch average
        g = process_grad_output(g, module, loss_average, kfac_approx)

@@ -642,7 +665,9 @@ def _install_hooks(
            handles.extend(
                (
                    module.register_forward_pre_hook(self._save_input),
                    module.register_full_backward_hook(self._accumulate_H_terms),
                    module.register_forward_hook(
                        self._register_tensor_hook_on_output_to_accumulate_H_terms
                    ),
                )
            )
        return module_names, handles
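
To see the new hook chain in isolation, here is a rough, self-contained sketch of the pattern the diff above implements: a forward pre-hook stores the layer input, a forward hook defers to a tensor hook on the layer's output, and that tensor hook receives the output gradient. The names `saved_inputs`, `save_input`, `install_tensor_hook`, and `accumulate` are illustrative stand-ins for SINGD's internals, and the update-frequency check is omitted.

from functools import partial

from torch import Tensor, rand
from torch.nn import Linear, Module, ReLU, Sequential

saved_inputs = {}


def save_input(module: Module, inputs):
    # Forward pre-hook: stash the layer input (analogous to SINGD's `_save_input`).
    saved_inputs[module] = inputs[0].detach()


def accumulate(module: Module, grad_output: Tensor):
    # Tensor hook: sees both the saved input and the output gradient, which is
    # all that is needed to accumulate the Kronecker factors.
    a, g = saved_inputs[module], grad_output
    print(a.shape, g.shape)


def install_tensor_hook(module: Module, inputs, output: Tensor):
    # Forward hook: register a hook on the output tensor instead of a full
    # backward hook on the module (robust to in-place activations).
    output.register_hook(partial(accumulate, module))


layer = Linear(4, 3)
model = Sequential(layer, ReLU(inplace=True))
layer.register_forward_pre_hook(save_input)
layer.register_forward_hook(install_tensor_hook)
model(rand(2, 4)).sum().backward()  # prints torch.Size([2, 4]) torch.Size([2, 3])
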
110 changes: 110 additions & 0 deletions test/optim/test_inplace_activations.py
@@ -0,0 +1,110 @@
"""SINGD with a model that uses in-place activations."""

from copy import deepcopy
from test.utils import REDUCTION_IDS, REDUCTIONS, compare_optimizers

from pytest import mark, skip
from torch import manual_seed, rand
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from singd.optim.optimizer import SINGD


@mark.parametrize("inplace", [True, False], ids=["inplace=True", "inplace=False"])
def test_hooks_support_inplace_activations(inplace: bool):
"""Test that SINGD's hooks support in in-place activations.
See https://github.com/f-dangel/singd/issues/56.
Args:
inplace: Whether to use in-place activations.
"""
manual_seed(0)

X = rand(2, 1, 5, 5)
model = Sequential(Conv2d(1, 1, 3), ReLU(inplace=inplace))
SINGD(model) # install hooks

model(X)


@mark.parametrize("reduction", REDUCTIONS, ids=REDUCTION_IDS)
def test_compare_training_using_inplace_activations(reduction: str):
"""Compare training w/o in-place activations.
Args:
reduction: Reduction used for the loss function.
"""
if reduction == "sum":
skip("Need to fix https://github.com/f-dangel/singd/issues/43 first.")

    manual_seed(0)
    MAX_STEPS = 100
    batch_size = 32

    train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True
    )

    # _inplace indicates that the trained net has in-place activations

    # NOTE All parameters of this net are supported by SINGD
    model = Sequential(
        Conv2d(1, 3, kernel_size=5, stride=2),
        ReLU(),
        Flatten(),
        Linear(432, 50),
        ReLU(),
        Linear(50, 10),
    )
    model_inplace = deepcopy(model)
    # activate in-place option
    for mod in model_inplace.modules():
        if hasattr(mod, "inplace"):
            mod.inplace = True

    loss_func = CrossEntropyLoss(reduction=reduction)
    loss_func_inplace = deepcopy(loss_func)

    loss_average = {"mean": "batch", "sum": None}[reduction]
    optim_hyperparams = {
        "lr": 5e-4,
        "damping": 1e-4,
        "momentum": 0.9,
        "weight_decay": 1e-2,
        "lr_cov": 1e-2,
        "loss_average": loss_average,
        "T": 1,
        "alpha1": 0.5,
        "structures": ("dense", "dense"),
    }

    optim = SINGD(model, **optim_hyperparams)
    optim_inplace = SINGD(model_inplace, **optim_hyperparams)

    model.train()
    model_inplace.train()

    # Loop over each batch from the training set
    for batch_idx, (inputs, target) in enumerate(train_loader):
        print(f"Step {optim.steps}")

        # Zero gradient buffers
        optim.zero_grad()
        optim_inplace.zero_grad()

        # Take a step
        loss_func(model(inputs), target).backward()
        optim.step()

        loss_func_inplace(model_inplace(inputs), target).backward()
        optim_inplace.step()

        compare_optimizers(optim, optim_inplace, rtol=1e-5, atol=1e-7)

        if batch_idx >= MAX_STEPS:
            break
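
The `compare_optimizers` helper used above is defined in `test/utils.py` and is not part of this diff. Purely as an illustration of the kind of equivalence being asserted, a check in the same spirit could compare the trained parameters of the two models directly; the helper name `assert_same_training` below is hypothetical.

from torch.nn import Module
from torch.testing import assert_close


def assert_same_training(model: Module, model_inplace: Module, rtol=1e-5, atol=1e-7):
    # If training with and without in-place activations took identical steps,
    # the two models' parameters must agree up to numerical tolerance.
    for p, q in zip(model.parameters(), model_inplace.parameters()):
        assert_close(p, q, rtol=rtol, atol=atol)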
