From 3740e30bc41f2a031b1accc6fa4f438ffa445281 Mon Sep 17 00:00:00 2001
From: runame
Date: Thu, 19 Oct 2023 17:49:17 +0200
Subject: [PATCH] Style improvements

---
 singd/optim/optimizer.py |  1 -
 test/optim/test_kfac.py  | 65 ++++++++++++++++++++++------------------
 test/optim/utils.py      | 36 ++++++++++++++--------
 3 files changed, 59 insertions(+), 43 deletions(-)

diff --git a/singd/optim/optimizer.py b/singd/optim/optimizer.py
index 7b3c280..f8498ce 100644
--- a/singd/optim/optimizer.py
+++ b/singd/optim/optimizer.py
@@ -504,7 +504,6 @@ def _accumulate_H_terms(
             grad_input: Gradients w.r.t. the input.
             grad_output: Gradients w.r.t. the output.
         """
-        del grad_input
         T = self._get_param_group_entry(module, "T")
         if self.steps % T != 0:
             return
diff --git a/test/optim/test_kfac.py b/test/optim/test_kfac.py
index 4ba7dde..9e08a15 100644
--- a/test/optim/test_kfac.py
+++ b/test/optim/test_kfac.py
@@ -3,10 +3,20 @@
 from test.optim.utils import jacobians_naive
 from typing import List, Tuple
 
-import torch
 from einops import rearrange, reduce
 from pytest import mark
-from torch import Tensor, device
+from torch import (
+    Tensor,
+    allclose,
+    block_diag,
+    cat,
+    device,
+    float64,
+    kron,
+    manual_seed,
+    randn,
+)
+from torch.cuda import is_available
 from torch.nn import AdaptiveAvgPool2d, Conv2d, Flatten, Linear, Module, Sequential
 from torch.utils.hooks import RemovableHandle
 
@@ -22,7 +32,7 @@
 H_in = W_in = 16
 K = 4
 # Use double dtype to avoid numerical issues.
-DTYPE = torch.float64
+DTYPE = float64
 
 
 @mark.parametrize("setting", ["expand", "reduce"])
@@ -45,12 +55,12 @@ def test_kfac_single_linear_module(
     Raises:
         AssertionError: If the KFAC approximation is not exact.
     """
-    if not torch.cuda.is_available() and device.type == "cuda":
+    if not is_available() and device.type == "cuda":
         return
     # Fix random seed.
-    torch.manual_seed(711)
+    manual_seed(711)
     # Set up inputs x.
-    x = torch.randn((N_SAMPLES, REP_DIM, IN_DIM), dtype=DTYPE, device=device)
+    x = randn((N_SAMPLES, REP_DIM, IN_DIM), dtype=DTYPE, device=device)
     n_loss_terms = N_SAMPLES * REP_DIM if setting == "expand" else N_SAMPLES
 
     # Set up one-layer linear network for inputs with additional REP_DIM.
@@ -62,8 +72,7 @@ def test_kfac_single_linear_module(
     # Jacobians.
     Js, f = jacobians_naive(model, x, setting)
     assert f.shape == (n_loss_terms, OUT_DIM)
-    assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(end_dim=-2)
+    assert Js.shape == (n_loss_terms * OUT_DIM, num_params)
 
     # Exact Fisher/GGN.
     exact_F = Js.T @ Js  # regression
@@ -78,8 +87,8 @@ def test_kfac_single_linear_module(
     assert F.shape == (num_params, num_params)
 
     # Compare true Fisher/GGN against K-FAC Fisher/GGN (should be exact).
-    assert torch.allclose(F.diag(), exact_F.diag())  # diagonal comparison
-    assert torch.allclose(F, exact_F)  # full comparison
+    assert allclose(F.diag(), exact_F.diag())  # diagonal comparison
+    assert allclose(F, exact_F)  # full comparison
 
 
 @mark.parametrize("setting", ["expand", "reduce"])
@@ -103,12 +112,12 @@ def test_kfac_deep_linear(
         AssertionError: If the KFAC approximation is not exact for the block
            diagonal.
     """
-    if not torch.cuda.is_available() and device.type == "cuda":
+    if not is_available() and device.type == "cuda":
         return
     # Fix random seed.
-    torch.manual_seed(711)
+    manual_seed(711)
     # Set up inputs x.
-    x = torch.randn((N_SAMPLES, REP_DIM, IN_DIM), dtype=DTYPE, device=device)
+    x = randn((N_SAMPLES, REP_DIM, IN_DIM), dtype=DTYPE, device=device)
     n_loss_terms = N_SAMPLES * REP_DIM if setting == "expand" else N_SAMPLES
 
     # Set up two-layer linear network for inputs with additional REP_DIM.
@@ -122,8 +131,7 @@ def test_kfac_deep_linear(
     # Jacobians.
     Js, f = jacobians_naive(model, x, setting)
     assert f.shape == (n_loss_terms, OUT_DIM)
-    assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(end_dim=-2)
+    assert Js.shape == (n_loss_terms * OUT_DIM, num_params)
 
     # Exact Fisher/GGN.
     exact_F = Js.T @ Js  # regression
@@ -138,12 +146,12 @@ def test_kfac_deep_linear(
     assert F.shape == (num_params, num_params)
 
     # Compare true Fisher/GGN against K-FAC Fisher/GGN block diagonal (should be exact).
-    assert torch.allclose(F.diag(), exact_F.diag())  # diagonal comparison
-    assert torch.allclose(
+    assert allclose(F.diag(), exact_F.diag())  # diagonal comparison
+    assert allclose(
         F[:num_params_layer1, :num_params_layer1],
         exact_F[:num_params_layer1, :num_params_layer1],
     )  # full comparison layer 1.
-    assert torch.allclose(
+    assert allclose(
         F[num_params_layer1:, num_params_layer1:],
         exact_F[num_params_layer1:, num_params_layer1:],
     )  # full comparison layer 2.
@@ -170,12 +178,12 @@ def test_kfac_conv2d_module(
         AssertionError: If the KFAC-reduce approximation is not exact for the
            diagonal or the Conv2d layer or if it is exact for KFAC-expand.
     """
-    if not torch.cuda.is_available() and device.type == "cuda":
+    if not is_available() and device.type == "cuda":
         return
     # Fix random seed.
-    torch.manual_seed(711)
+    manual_seed(711)
    # Set up inputs x.
-    x = torch.randn((N_SAMPLES, C_in, H_in, W_in), dtype=DTYPE, device=device)
+    x = randn((N_SAMPLES, C_in, H_in, W_in), dtype=DTYPE, device=device)
     n_loss_terms = N_SAMPLES  # Only reduce setting.
 
     # Set up model with conv layer, average pooling, and linear output layer.
@@ -191,8 +199,7 @@ def test_kfac_conv2d_module(
     # Jacobians.
     Js, f = jacobians_naive(model, x, setting)
     assert f.shape == (n_loss_terms, OUT_DIM)
-    assert Js.shape == (n_loss_terms, OUT_DIM, num_params)
-    Js = Js.flatten(end_dim=-2)
+    assert Js.shape == (n_loss_terms * OUT_DIM, num_params)
 
     # Exact Fisher/GGN.
     exact_F = Js.T @ Js  # regression
@@ -209,16 +216,16 @@ def test_kfac_conv2d_module(
     if setting == "reduce":
         # KFAC-reduce should be exact for this setting.
         # Compare true Fisher/GGN against K-FAC Fisher/GGN diagonal.
-        assert torch.allclose(F.diag(), exact_F.diag())
+        assert allclose(F.diag(), exact_F.diag())
         # Compare true Fisher/GGN against K-FAC Fisher/GGN for the Conv2d layer.
-        assert torch.allclose(
+        assert allclose(
             F[:num_conv_params, :num_conv_params],
             exact_F[:num_conv_params, :num_conv_params],
         )
     else:
         # KFAC-expand should not be exact for this setting.
         # Compare true Fisher/GGN against K-FAC Fisher/GGN diagonal.
-        assert not torch.allclose(F.diag(), exact_F.diag())
+        assert not allclose(F.diag(), exact_F.diag())
 
 
 class WeightShareModel(Sequential):
@@ -322,9 +329,9 @@ def get_kfac_blocks(self) -> List[Tensor]:
                 raise ValueError("forward_and_backward() has to be called first.")
             # Get Kronecker factor ingredients stored as module attributes.
             a: Tensor = module.kfac_a
-            g: Tensor = torch.cat(module.kfac_g)
+            g: Tensor = cat(module.kfac_g)
             # Compute Kronecker product of both factors.
-            block = torch.kron(g.T @ g, a.T @ a)
+            block = kron(g.T @ g, a.T @ a)
             # When a bias is used we have to reorder the rows and columns of the
             # block to match the order of the parameters in the naive Jacobian
             # implementation.
@@ -352,7 +359,7 @@ def get_full_kfac_matrix(self) -> Tensor:
             The full block-diagonal KFAC approximation matrix.
         """
         blocks = self.get_kfac_blocks()
-        return torch.block_diag(*blocks)
+        return block_diag(*blocks)
 
     def _install_hooks(self) -> List[RemovableHandle]:
         """Installs forward and backward hooks to the model.
diff --git a/test/optim/utils.py b/test/optim/utils.py
index d26ef92..d7c6b72 100644
--- a/test/optim/utils.py
+++ b/test/optim/utils.py
@@ -2,8 +2,8 @@
 
 from typing import Tuple
 
-import torch
-from torch import Tensor
+from torch import Tensor, cat, dtype, stack
+from torch.autograd import grad
 from torch.nn import Module
 
 from singd.optim.optimizer import SINGD
@@ -41,8 +41,8 @@ def check_preconditioner_dtypes(optim: SINGD):
     """
     for module, name in optim.module_names.items():
         dtype_K, dtype_C = optim._get_param_group_entry(module, "preconditioner_dtype")
-        dtype_K = dtype_K if isinstance(dtype_K, torch.dtype) else module.weight.dtype
-        dtype_C = dtype_C if isinstance(dtype_C, torch.dtype) else module.weight.dtype
+        dtype_K = dtype_K if isinstance(dtype_K, dtype) else module.weight.dtype
+        dtype_C = dtype_C if isinstance(dtype_C, dtype) else module.weight.dtype
 
         verify_dtype(optim.Ks[name], dtype_K)
         verify_dtype(optim.m_Ks[name], dtype_K)
@@ -50,7 +50,7 @@ def check_preconditioner_dtypes(optim: SINGD):
         verify_dtype(optim.m_Cs[name], dtype_C)
 
 
-def verify_dtype(mat: StructuredMatrix, dtype: torch.dtype):
+def verify_dtype(mat: StructuredMatrix, dtype: dtype):
     """Check whether a structured matrix's tensors are of the specified type.
 
     Args:
@@ -82,19 +82,29 @@
 
 
 def jacobians_naive(model: Module, data: Tensor, setting: str) -> Tuple[Tensor, Tensor]:
-    num_params = sum(p.numel() for p in model.parameters())
+    """Compute the Jacobians of a model's output w.r.t. its parameters.
+
+    Args:
+        model: The model.
+        data: The input data.
+        setting: The setting to use for the forward pass of the model
+            (if appropriate). Possible values are `"expand"` and `"reduce"`.
+
+    Returns:
+        A tuple of the Jacobians of the model's output w.r.t. its parameters and
+        the model's output, with shapes `(n_loss_terms * out_dim, num_params)`
+        and `(n_loss_terms, ..., out_dim)` respectively.
+    """
     try:
         f: Tensor = model(data, setting)
     except TypeError:
         f: Tensor = model(data)
-    # f: (batch_size/n_loss_terms, ..., out_dim)
-    out_dim = f.size(-1)
+    # f: (n_loss_terms, ..., out_dim)
     last_f_dim = f.numel() - 1
     jacs = []
     for i, f_i in enumerate(f.flatten()):
-        rg = i != last_f_dim
-        jac = torch.autograd.grad(f_i, model.parameters(), retain_graph=rg)
-        jacs.append(torch.cat([j.flatten() for j in jac]))
-    # jacs: (n_loss_terms, out_dim, num_params)
-    jacs = torch.stack(jacs).view(-1, out_dim, num_params)
+        jac = grad(f_i, model.parameters(), retain_graph=i != last_f_dim)
+        jacs.append(cat([j.flatten() for j in jac]))
+    # jacs: (n_loss_terms * out_dim, num_params)
+    jacs = stack(jacs).flatten(end_dim=-2)
    return jacs.detach(), f.detach()
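
Note (illustration, not part of the patch): the tests above check that the block-diagonal KFAC matrix built from kron(g.T @ g, a.T @ a) matches the exact Fisher/GGN assembled from the naive Jacobians returned by jacobians_naive. Below is a minimal self-contained sketch of the underlying identity for a single bias-free Linear layer under a squared-error loss; sizes (N, D_in, D_out) and the closed-form factor kron(I, x.T @ x) are illustrative and independent of the repository's test constants and hook-based KFAC class.

    from torch import allclose, eye, float64, kron, manual_seed, randn, stack
    from torch.autograd import grad
    from torch.nn import Linear

    manual_seed(0)
    N, D_in, D_out = 8, 3, 2  # illustrative sizes, not the repo's test constants
    x = randn(N, D_in, dtype=float64)
    layer = Linear(D_in, D_out, bias=False, dtype=float64)
    f = layer(x)

    # Naive Jacobian: one flattened parameter gradient per output entry,
    # mirroring jacobians_naive's (n_loss_terms * out_dim, num_params) layout.
    rows = []
    for i, f_i in enumerate(f.flatten()):
        (j,) = grad(f_i, layer.weight, retain_graph=i != f.numel() - 1)
        rows.append(j.flatten())
    J = stack(rows)

    # Exact Fisher/GGN for squared-error loss vs. a Kronecker-factored form:
    # for a bias-free linear layer with row-major weight flattening, J.T @ J
    # equals kron(I_out, x.T @ x) exactly. This exact Kronecker structure is
    # the property the single-Linear test verifies, there via hooks and
    # per-sample factors rather than this closed form.
    exact_F = J.T @ J
    kfac_F = kron(eye(D_out, dtype=float64), x.T @ x)
    assert allclose(exact_F, kfac_F)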