Style improvements
runame committed Oct 19, 2023
1 parent 6d1b7f8 commit 3740e30
Showing 3 changed files with 59 additions and 43 deletions.
1 change: 0 additions & 1 deletion singd/optim/optimizer.py
@@ -504,7 +504,6 @@ def _accumulate_H_terms(
grad_input: Gradients w.r.t. the input.
grad_output: Gradients w.r.t. the output.
"""
-        del grad_input
T = self._get_param_group_entry(module, "T")
if self.steps % T != 0:
return
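
The hook above only reads `grad_output` and skips all work except on every `T`-th step. A minimal standalone sketch of that update-interval pattern, assuming PyTorch's full-backward-hook signature and using hypothetical names (not SINGD's actual implementation):

import torch
from torch import nn


class IntervalHook:
    """Backward hook that only does work on every T-th call (illustrative only)."""

    def __init__(self, T: int):
        self.T = T
        self.steps = 0  # in SINGD, a counter like this is advanced by the optimizer

    def __call__(self, module: nn.Module, grad_input, grad_output):
        # grad_input is received but unused; only grad_output is consumed.
        if self.steps % self.T != 0:
            return
        # ... accumulate curvature terms from grad_output here ...


layer = nn.Linear(4, 3)
handle = layer.register_full_backward_hook(IntervalHook(T=2))
layer(torch.randn(8, 4)).sum().backward()  # triggers the hook
handle.remove()
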
65 changes: 36 additions & 29 deletions test/optim/test_kfac.py
@@ -3,10 +3,20 @@
from test.optim.utils import jacobians_naive
from typing import List, Tuple

-import torch
from einops import rearrange, reduce
from pytest import mark
-from torch import Tensor, device
+from torch import (
+    Tensor,
+    allclose,
+    block_diag,
+    cat,
+    device,
+    float64,
+    kron,
+    manual_seed,
+    randn,
+)
+from torch.cuda import is_available
from torch.nn import AdaptiveAvgPool2d, Conv2d, Flatten, Linear, Module, Sequential
from torch.utils.hooks import RemovableHandle

@@ -22,7 +32,7 @@
H_in = W_in = 16
K = 4
# Use double dtype to avoid numerical issues.
-DTYPE = torch.float64
+DTYPE = float64


@mark.parametrize("setting", ["expand", "reduce"])
@@ -45,12 +55,12 @@ def test_kfac_single_linear_module(
Raises:
AssertionError: If the KFAC approximation is not exact.
"""
-    if not torch.cuda.is_available() and device.type == "cuda":
+    if not is_available() and device.type == "cuda":
return
# Fix random seed.
-    torch.manual_seed(711)
+    manual_seed(711)
# Set up inputs x.
-    x = torch.randn((N_SAMPLES, REP_DIM, IN_DIM), dtype=DTYPE, device=device)
+    x = randn((N_SAMPLES, REP_DIM, IN_DIM), dtype=DTYPE, device=device)
n_loss_terms = N_SAMPLES * REP_DIM if setting == "expand" else N_SAMPLES

# Set up one-layer linear network for inputs with additional REP_DIM.
@@ -62,8 +72,7 @@
# Jacobians.
Js, f = jacobians_naive(model, x, setting)
assert f.shape == (n_loss_terms, OUT_DIM)
-    assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(end_dim=-2)
+    assert Js.shape == (n_loss_terms * OUT_DIM, num_params)

# Exact Fisher/GGN.
exact_F = Js.T @ Js # regression
@@ -78,8 +87,8 @@
assert F.shape == (num_params, num_params)

# Compare true Fisher/GGN against K-FAC Fisher/GGN (should be exact).
-    assert torch.allclose(F.diag(), exact_F.diag())  # diagonal comparison
-    assert torch.allclose(F, exact_F)  # full comparison
+    assert allclose(F.diag(), exact_F.diag())  # diagonal comparison
+    assert allclose(F, exact_F)  # full comparison
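
The exactness asserted here follows from the Kronecker structure of a linear layer's per-sample gradient: for y = W @ a, the flattened gradient of W equals kron(g, a) with g the gradient w.r.t. y, so its outer product factorizes as kron(g g^T, a a^T). A quick self-contained check with plain tensors (illustrative names, not part of the test suite):

from torch import allclose, float64, kron, manual_seed, outer, randn

manual_seed(0)
a = randn(3, dtype=float64)  # layer input
g = randn(2, dtype=float64)  # gradient w.r.t. the layer output

grad_W = outer(g, a)         # dL/dW for y = W @ a (no bias)
grad_vec = grad_W.flatten()  # row-major flattening equals kron(g, a)

exact_block = outer(grad_vec, grad_vec)      # exact per-sample Fisher block
kfac_block = kron(outer(g, g), outer(a, a))  # Kronecker-factored form
assert allclose(exact_block, kfac_block)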


@mark.parametrize("setting", ["expand", "reduce"])
@@ -103,12 +112,12 @@ def test_kfac_deep_linear(
AssertionError: If the KFAC approximation is not exact for the block
diagonal.
"""
-    if not torch.cuda.is_available() and device.type == "cuda":
+    if not is_available() and device.type == "cuda":
return
# Fix random seed.
-    torch.manual_seed(711)
+    manual_seed(711)
# Set up inputs x.
-    x = torch.randn((N_SAMPLES, REP_DIM, IN_DIM), dtype=DTYPE, device=device)
+    x = randn((N_SAMPLES, REP_DIM, IN_DIM), dtype=DTYPE, device=device)
n_loss_terms = N_SAMPLES * REP_DIM if setting == "expand" else N_SAMPLES

# Set up two-layer linear network for inputs with additional REP_DIM.
@@ -122,8 +131,7 @@
# Jacobians.
Js, f = jacobians_naive(model, x, setting)
assert f.shape == (n_loss_terms, OUT_DIM)
-    assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(end_dim=-2)
+    assert Js.shape == (n_loss_terms * OUT_DIM, num_params)

# Exact Fisher/GGN.
exact_F = Js.T @ Js # regression
@@ -138,12 +146,12 @@
assert F.shape == (num_params, num_params)

# Compare true Fisher/GGN against K-FAC Fisher/GGN block diagonal (should be exact).
-    assert torch.allclose(F.diag(), exact_F.diag())  # diagonal comparison
-    assert torch.allclose(
+    assert allclose(F.diag(), exact_F.diag())  # diagonal comparison
+    assert allclose(
F[:num_params_layer1, :num_params_layer1],
exact_F[:num_params_layer1, :num_params_layer1],
) # full comparison layer 1.
-    assert torch.allclose(
+    assert allclose(
F[num_params_layer1:, num_params_layer1:],
exact_F[num_params_layer1:, num_params_layer1:],
) # full comparison layer 2.
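
The block-wise slicing reflects the structure of the approximation: per layer, KFAC is a Kronecker product, and the layers are stitched together with block_diag, so all cross-layer blocks are zero and only the layer-diagonal blocks can match the exact Fisher. A small illustration with arbitrary factor sizes (hypothetical values, not the test's data):

from torch import allclose, block_diag, float64, kron, manual_seed, randn, zeros

manual_seed(0)
# Hypothetical Kronecker factors for two layers.
A1, G1 = randn(3, 3, dtype=float64), randn(2, 2, dtype=float64)
A2, G2 = randn(4, 4, dtype=float64), randn(2, 2, dtype=float64)

F_kfac = block_diag(kron(G1, A1), kron(G2, A2))
p1 = G1.shape[0] * A1.shape[0]  # number of parameters in the first layer

assert allclose(F_kfac[:p1, :p1], kron(G1, A1))  # layer-1 block is Kronecker-factored
assert allclose(F_kfac[p1:, :p1], zeros(8, 6, dtype=float64))  # cross-layer block is zero
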
@@ -170,12 +178,12 @@ def test_kfac_conv2d_module(
AssertionError: If the KFAC-reduce approximation is not exact for the
diagonal or the Conv2d layer or if it is exact for KFAC-expand.
"""
-    if not torch.cuda.is_available() and device.type == "cuda":
+    if not is_available() and device.type == "cuda":
return
# Fix random seed.
-    torch.manual_seed(711)
+    manual_seed(711)
# Set up inputs x.
-    x = torch.randn((N_SAMPLES, C_in, H_in, W_in), dtype=DTYPE, device=device)
+    x = randn((N_SAMPLES, C_in, H_in, W_in), dtype=DTYPE, device=device)
n_loss_terms = N_SAMPLES # Only reduce setting.

# Set up model with conv layer, average pooling, and linear output layer.
@@ -191,8 +199,7 @@
# Jacobians.
Js, f = jacobians_naive(model, x, setting)
assert f.shape == (n_loss_terms, OUT_DIM)
-    assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(end_dim=-2)
+    assert Js.shape == (n_loss_terms * OUT_DIM, num_params)

# Exact Fisher/GGN.
exact_F = Js.T @ Js # regression
@@ -209,16 +216,16 @@
if setting == "reduce":
# KFAC-reduce should be exact for this setting.
# Compare true Fisher/GGN against K-FAC Fisher/GGN diagonal.
-        assert torch.allclose(F.diag(), exact_F.diag())
+        assert allclose(F.diag(), exact_F.diag())
# Compare true Fisher/GGN against K-FAC Fisher/GGN for the Conv2d layer.
-        assert torch.allclose(
+        assert allclose(
F[:num_conv_params, :num_conv_params],
exact_F[:num_conv_params, :num_conv_params],
)
else:
# KFAC-expand should not be exact for this setting.
# Compare true Fisher/GGN against K-FAC Fisher/GGN diagonal.
-        assert not torch.allclose(F.diag(), exact_F.diag())
+        assert not allclose(F.diag(), exact_F.diag())
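
The two branches mirror how the two flavours handle weight sharing over the convolution's spatial locations: KFAC-expand treats every location as its own loss term, while KFAC-reduce aggregates over locations before forming the input-side factor. A rough einops sketch of the two aggregations, assuming a tensor `a` of unfolded convolution inputs (illustrative only, not the code under test):

from einops import rearrange, reduce
from torch import float64, manual_seed, randn

manual_seed(0)
# Hypothetical unfolded convolution inputs: (batch, spatial locations, patch features).
a = randn(8, 49, 27, dtype=float64)

# KFAC-expand: every spatial location contributes its own outer product.
a_expand = rearrange(a, "n s d -> (n s) d")
A_expand = a_expand.T @ a_expand

# KFAC-reduce: average over locations first, then form the factor.
a_reduce = reduce(a, "n s d -> n d", "mean")
A_reduce = a_reduce.T @ a_reduce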


class WeightShareModel(Sequential):
@@ -322,9 +329,9 @@ def get_kfac_blocks(self) -> List[Tensor]:
raise ValueError("forward_and_backward() has to be called first.")
# Get Kronecker factor ingredients stored as module attributes.
a: Tensor = module.kfac_a
-            g: Tensor = torch.cat(module.kfac_g)
+            g: Tensor = cat(module.kfac_g)
# Compute Kronecker product of both factors.
-            block = torch.kron(g.T @ g, a.T @ a)
+            block = kron(g.T @ g, a.T @ a)
# When a bias is used we have to reorder the rows and columns of the
# block to match the order of the parameters in the naive Jacobian
# implementation.
@@ -352,7 +359,7 @@ def get_full_kfac_matrix(self) -> Tensor:
The full block-diagonal KFAC approximation matrix.
"""
blocks = self.get_kfac_blocks()
-        return torch.block_diag(*blocks)
+        return block_diag(*blocks)

def _install_hooks(self) -> List[RemovableHandle]:
"""Installs forward and backward hooks to the model.
36 changes: 23 additions & 13 deletions test/optim/utils.py
@@ -2,8 +2,8 @@

from typing import Tuple

-import torch
-from torch import Tensor
+from torch import Tensor, cat, dtype, stack
+from torch.autograd import grad
from torch.nn import Module

from singd.optim.optimizer import SINGD
@@ -41,16 +41,16 @@ def check_preconditioner_dtypes(optim: SINGD):
"""
for module, name in optim.module_names.items():
dtype_K, dtype_C = optim._get_param_group_entry(module, "preconditioner_dtype")
-        dtype_K = dtype_K if isinstance(dtype_K, torch.dtype) else module.weight.dtype
-        dtype_C = dtype_C if isinstance(dtype_C, torch.dtype) else module.weight.dtype
+        dtype_K = dtype_K if isinstance(dtype_K, dtype) else module.weight.dtype
+        dtype_C = dtype_C if isinstance(dtype_C, dtype) else module.weight.dtype

verify_dtype(optim.Ks[name], dtype_K)
verify_dtype(optim.m_Ks[name], dtype_K)
verify_dtype(optim.Cs[name], dtype_C)
verify_dtype(optim.m_Cs[name], dtype_C)


-def verify_dtype(mat: StructuredMatrix, dtype: torch.dtype):
+def verify_dtype(mat: StructuredMatrix, dtype: dtype):
"""Check whether a structured matrix's tensors are of the specified type.
Args:
@@ -82,19 +82,29 @@ def verify_dtype(mat: StructuredMatrix, dtype: torch.dtype):


def jacobians_naive(model: Module, data: Tensor, setting: str) -> Tuple[Tensor, Tensor]:
-    num_params = sum(p.numel() for p in model.parameters())
+    """Compute the Jacobians of a model's output w.r.t. its parameters.
+
+    Args:
+        model: The model.
+        data: The input data.
+        setting: The setting to use for the forward pass of the model
+            (if appropriate). Possible values are `"expand"` and `"reduce"`.
+
+    Returns:
+        A tuple of the Jacobians of the model's output w.r.t. its parameters and
+        the model's output, with shapes `(n_loss_terms * out_dim, num_params)`
+        and `(n_loss_terms, ..., out_dim)` respectively.
+    """
try:
f: Tensor = model(data, setting)
except TypeError:
f: Tensor = model(data)
-    # f: (batch_size/n_loss_terms, ..., out_dim)
-    out_dim = f.size(-1)
+    # f: (n_loss_terms, ..., out_dim)
last_f_dim = f.numel() - 1
jacs = []
for i, f_i in enumerate(f.flatten()):
-        rg = i != last_f_dim
-        jac = torch.autograd.grad(f_i, model.parameters(), retain_graph=rg)
-        jacs.append(torch.cat([j.flatten() for j in jac]))
-    # jacs: (n_loss_terms, out_dim, num_params)
-    jacs = torch.stack(jacs).view(-1, out_dim, num_params)
+        jac = grad(f_i, model.parameters(), retain_graph=i != last_f_dim)
+        jacs.append(cat([j.flatten() for j in jac]))
+    # jacs: (n_loss_terms * out_dim, num_params)
+    jacs = stack(jacs).flatten(end_dim=-2)
return jacs.detach(), f.detach()
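
A possible usage sketch of jacobians_naive on a plain Linear model (which falls back to model(data) via the TypeError branch), assuming the test package is importable; the shapes follow the docstring above:

from test.optim.utils import jacobians_naive
from torch import manual_seed, randn
from torch.nn import Linear

manual_seed(0)
model = Linear(3, 2)
x = randn(5, 3)

jacs, f = jacobians_naive(model, x, "expand")
num_params = sum(p.numel() for p in model.parameters())  # 3 * 2 + 2 = 8
assert f.shape == (5, 2)
assert jacs.shape == (5 * 2, num_params)  # (n_loss_terms * out_dim, num_params)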
