From fd812f7fd32adcc9412c333e86871d168ec76ed5 Mon Sep 17 00:00:00 2001
From: runame
Date: Wed, 7 Aug 2024 10:48:20 -0400
Subject: [PATCH 01/15] Add compute_eigendecomposition

---
 curvlinops/kfac.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index fbee73b..6cb6a6b 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -27,6 +27,7 @@
 from einops import einsum, rearrange, reduce
 from numpy import ndarray
 from torch import Generator, Tensor, cat, device, eye, randn, stack
+from torch.linalg import eigh
 from torch.autograd import grad
 from torch.nn import (
     BCEWithLogitsLoss,
@@ -144,6 +145,7 @@ def __init__(
         fisher_type: str = FisherType.MC,
         mc_samples: int = 1,
         kfac_approx: str = KFACType.EXPAND,
+        correct_eigenvalues: bool = False,
         num_per_example_loss_terms: Optional[int] = None,
         separate_weight_and_bias: bool = True,
         num_data: Optional[int] = None,
@@ -204,6 +206,11 @@ def __init__(
                 See `Eschenhagen et al., 2023 `_
                 for an explanation of the two approximations.
                 Defaults to ``KFACType.EXPAND``.
+            correct_eigenvalues: Whether to correct the eigenvalues in the KFAC
+                eigenbasis, as proposed in
+                `George et al., 2018 `_. If true,
+                will only store the eigendecomposition of the KFAC approximation.
+                Defaults to ``False``.
             num_per_example_loss_terms: Number of per-example loss terms, e.g., the
                 number of tokens in a sequence. The model outputs will have
                 ``num_data * num_per_example_loss_terms * C`` entries, where ``C`` is
@@ -254,10 +261,20 @@ def __init__(
         self._fisher_type = fisher_type
         self._mc_samples = mc_samples
         self._kfac_approx = kfac_approx
+        self._correct_eigenvalues = correct_eigenvalues
         self._input_covariances: Dict[str, Tensor] = {}
         self._gradient_covariances: Dict[str, Tensor] = {}
         self._mapping = self.compute_parameter_mapping(params, model_func)
 
+        # Initialize the eigenvectors and eigenvalues of the Kronecker factors
+        self._input_covariances_eigenvectors: Dict[str, Tensor] = {}
+        self._input_covariances_eigenvalues: Dict[str, Tensor] = {}
+        self._gradient_covariances_eigenvectors: Dict[str, Tensor] = {}
+        self._gradient_covariances_eigenvalues: Dict[str, Tensor] = {}
+
+        # Initialize the corrected eigenvalues for EKFAC
+        self._corrected_eigenvalues: Dict[str, Tensor] = {}
+
         # Properties of the full matrix KFAC approximation are initialized to `None`
         self._reset_matrix_properties()
 
@@ -893,6 +910,35 @@ compute_parameter_mapping(
 
         return positions
 
+    def compute_eigendecomposition(self, keep_kronecker_factors: bool = False) -> None:
+        """Compute the eigendecomposition of the KFAC approximation.
+
+        Args:
+            keep_kronecker_factors: Whether to keep the Kronecker factors. If ``False``,
+                will free the memory used by the Kronecker factors.
+                Defaults to ``False``.
+ """ + if not self._input_covariances and not self._gradient_covariances: + self._compute_kfac() + + for mod_name in self._mapping.keys(): + aaT = self._input_covariances[mod_name] + ggT = self._gradient_covariances[mod_name] + if not keep_kronecker_factors: + del self._input_covariances[mod_name] + del self._gradient_covariances[mod_name] + + # Compute eigendecomposition of the Kronecker factors + aaT_eigvals, aaT_eigvecs = eigh(aaT) + self._input_covariances_eigenvectors[mod_name] = aaT_eigvecs + self._input_covariances_eigenvalues[mod_name] = aaT_eigvals + del aaT + + ggT_eigvals, ggT_eigvecs = eigh(ggT) + self._gradient_covariances_eigenvectors[mod_name] = ggT_eigvecs + self._gradient_covariances_eigenvalues[mod_name] = ggT_eigvals + del ggT + @property def trace(self) -> Tensor: r"""Trace of the KFAC approximation. From 290fd1b56f71255620c1fca631c80ddb1d1f0e70 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 16 Sep 2024 23:12:39 -0400 Subject: [PATCH 02/15] Add EKFAC test coverage --- test/test_kfac.py | 316 ++++++++++++++++++++++++++++++++++++---------- test/utils.py | 30 +++-- 2 files changed, 267 insertions(+), 79 deletions(-) diff --git a/test/test_kfac.py b/test/test_kfac.py index e7f8d61..e56d8dc 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1,15 +1,16 @@ """Contains tests for ``curvlinops.kfac``.""" import os +from contextlib import nullcontext from test.cases import DEVICES, DEVICES_IDS from test.utils import ( Conv2dModel, UnetModel, WeightShareModel, binary_classification_targets, + block_diagonal, classification_targets, compare_state_dicts, - ggn_block_diagonal, regression_targets, ) from typing import Dict, Iterable, List, Tuple, Union @@ -19,7 +20,6 @@ from numpy import eye from numpy.linalg import det, norm, slogdet from pytest import mark, raises, skip -from scipy.linalg import block_diag from torch import Tensor, allclose, cat, cuda, device from torch import eye as torch_eye from torch import isinf, isnan, load, manual_seed, rand, rand_like, randperm, save @@ -35,6 +35,7 @@ Sequential, ) +from curvlinops import EFLinearOperator, GGNLinearOperator from curvlinops.examples.utils import report_nonclose from curvlinops.gradient_moments import EFLinearOperator from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType @@ -47,6 +48,9 @@ "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_type2( kfac_exact_case: Tuple[ Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] @@ -54,6 +58,7 @@ def test_kfac_type2( shuffle: bool, exclude: str, separate_weight_and_bias: bool, + correct_eigenvalues: bool, ): """Test the KFAC implementation against the exact GGN. @@ -65,6 +70,7 @@ def test_kfac_type2( or ``None``. separate_weight_and_bias: Whether to treat weight and bias as separate blocks in the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. 
""" assert exclude in [None, "weight", "bias"] model, loss_func, params, data, batch_size_fn = kfac_exact_case @@ -77,7 +83,8 @@ def test_kfac_type2( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, @@ -93,10 +100,11 @@ def test_kfac_type2( batch_size_fn=batch_size_fn, fisher_type=FisherType.TYPE2, separate_weight_and_bias=separate_weight_and_bias, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) - report_nonclose(ggn, kfac_mat) + report_nonclose(ggn, kfac_mat, atol=1e-6) # Check that input covariances were not computed if exclude == "weight": @@ -111,6 +119,9 @@ def test_kfac_type2( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_type2_weight_sharing( kfac_weight_sharing_exact_case: Tuple[ Union[WeightShareModel, Conv2dModel], @@ -122,6 +133,7 @@ def test_kfac_type2_weight_sharing( shuffle: bool, exclude: str, separate_weight_and_bias: bool, + correct_eigenvalues: bool, ): """Test KFAC for linear weight-sharing layers against the exact GGN. @@ -135,6 +147,7 @@ def test_kfac_type2_weight_sharing( or ``None``. separate_weight_and_bias: Whether to treat weight and bias as separate blocks in the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. """ assert exclude in [None, "weight", "bias"] model, loss_func, params, data, batch_size_fn = kfac_weight_sharing_exact_case @@ -152,7 +165,8 @@ def test_kfac_type2_weight_sharing( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, @@ -169,6 +183,7 @@ def test_kfac_type2_weight_sharing( fisher_type=FisherType.TYPE2, kfac_approx=setting, # choose KFAC approximation consistent with setting separate_weight_and_bias=separate_weight_and_bias, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -180,11 +195,15 @@ def test_kfac_type2_weight_sharing( @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_mc( kfac_exact_case: Tuple[ Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ], shuffle: bool, + correct_eigenvalues: bool, ): """Test the KFAC implementation using MC samples against the exact GGN. @@ -192,6 +211,7 @@ def test_kfac_mc( kfac_exact_case: A fixture that returns a model, loss function, list of parameters, and data. shuffle: Whether to shuffle the parameters before computing the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. 
""" model, loss_func, params, data, batch_size_fn = kfac_exact_case @@ -199,8 +219,8 @@ def test_kfac_mc( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( - model, loss_func, params, data, batch_size_fn=batch_size_fn + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, data, batch_size_fn=batch_size_fn ) kfac = KFACLinearOperator( model, @@ -208,7 +228,9 @@ def test_kfac_mc( params, data, batch_size_fn=batch_size_fn, - mc_samples=2_000, + fisher_type=FisherType.MC, + mc_samples=3_000, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -220,6 +242,9 @@ def test_kfac_mc( @mark.parametrize("setting", [KFACType.EXPAND, KFACType.REDUCE]) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_mc_weight_sharing( kfac_weight_sharing_exact_case: Tuple[ Union[WeightShareModel, Conv2dModel], @@ -229,6 +254,7 @@ def test_kfac_mc_weight_sharing( ], setting: str, shuffle: bool, + correct_eigenvalues: bool, ): """Test KFAC-MC for linear layers with weight sharing against the exact GGN. @@ -238,6 +264,7 @@ def test_kfac_mc_weight_sharing( setting: The weight-sharing setting to use. Can be ``KFACType.EXPAND`` or ``KFACType.REDUCE``. shuffle: Whether to shuffle the parameters before computing the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. """ model, loss_func, params, data, batch_size_fn = kfac_weight_sharing_exact_case model.setting = setting @@ -250,8 +277,8 @@ def test_kfac_mc_weight_sharing( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( - model, loss_func, params, data, batch_size_fn=batch_size_fn + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, data, batch_size_fn=batch_size_fn ) kfac = KFACLinearOperator( model, @@ -260,8 +287,9 @@ def test_kfac_mc_weight_sharing( data, batch_size_fn=batch_size_fn, fisher_type=FisherType.MC, - mc_samples=2_000, + mc_samples=4_000, kfac_approx=setting, # choose KFAC approximation consistent with setting + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -271,6 +299,9 @@ def test_kfac_mc_weight_sharing( report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_one_datum( kfac_exact_one_datum_case: Tuple[ Module, @@ -278,11 +309,12 @@ def test_kfac_one_datum( List[Parameter], Iterable[Tuple[Tensor, Tensor]], ], + correct_eigenvalues: bool, ): model, loss_func, params, data, batch_size_fn = kfac_exact_one_datum_case - ggn = ggn_block_diagonal( - model, loss_func, params, data, batch_size_fn=batch_size_fn + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, data, batch_size_fn=batch_size_fn ) kfac = KFACLinearOperator( model, @@ -291,12 +323,16 @@ def test_kfac_one_datum( data, batch_size_fn=batch_size_fn, fisher_type=FisherType.TYPE2, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) report_nonclose(ggn, kfac_mat) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_mc_one_datum( kfac_exact_one_datum_case: Tuple[ Module, @@ -304,11 +340,12 @@ def test_kfac_mc_one_datum( List[Parameter], Iterable[Tuple[Tensor, Tensor]], ], + correct_eigenvalues: bool, ): model, loss_func, params, data, batch_size_fn = 
kfac_exact_one_datum_case - ggn = ggn_block_diagonal( - model, loss_func, params, data, batch_size_fn=batch_size_fn + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, data, batch_size_fn=batch_size_fn ) kfac = KFACLinearOperator( model, @@ -316,7 +353,9 @@ def test_kfac_mc_one_datum( params, data, batch_size_fn=batch_size_fn, + fisher_type=FisherType.MC, mc_samples=11_000, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -326,6 +365,15 @@ def test_kfac_mc_one_datum( report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) +@mark.parametrize( + "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] +) +@mark.parametrize( + "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] +) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_kfac_ef_one_datum( kfac_exact_one_datum_case: Tuple[ Module, @@ -333,16 +381,25 @@ def test_kfac_ef_one_datum( List[Parameter], Iterable[Tuple[Tensor, Tensor]], ], + separate_weight_and_bias: bool, + exclude: str, + correct_eigenvalues: bool, ): model, loss_func, params, data, batch_size_fn = kfac_exact_one_datum_case - ef_blocks = [] # list of per-parameter EFs - for param in params: - ef = EFLinearOperator( - model, loss_func, [param], data, batch_size_fn=batch_size_fn - ) - ef_blocks.append(ef @ eye(ef.shape[1])) - ef = block_diag(*ef_blocks) + if exclude is not None: + names = {p.data_ptr(): name for name, p in model.named_parameters()} + params = [p for p in params if exclude not in names[p.data_ptr()]] + + ef = block_diagonal( + EFLinearOperator, + model, + loss_func, + params, + data, + batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, + ) kfac = KFACLinearOperator( model, @@ -350,7 +407,9 @@ def test_kfac_ef_one_datum( params, data, batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, fisher_type=FisherType.EMPIRICAL, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -375,7 +434,7 @@ def test_kfac_inplace_activations(dev: device): params = list(model.parameters()) # 1) compare KFAC and GGN - ggn = ggn_block_diagonal(model, loss_func, params, data) + ggn = block_diagonal(GGNLinearOperator, model, loss_func, params, data) kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -389,7 +448,7 @@ def test_kfac_inplace_activations(dev: device): for mod in model.modules(): if hasattr(mod, "inplace"): mod.inplace = False - ggn_no_inplace = ggn_block_diagonal(model, loss_func, params, data) + ggn_no_inplace = block_diagonal(GGNLinearOperator, model, loss_func, params, data) report_nonclose(ggn, ggn_no_inplace) @@ -400,11 +459,15 @@ def test_kfac_inplace_activations(dev: device): ) @mark.parametrize("reduction", ["mean", "sum"]) @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_multi_dim_output( fisher_type: str, loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], reduction: str, dev: device, + correct_eigenvalues: bool, ): """Test the KFAC implementation for >2d outputs (using a 3d and 4d output). @@ -413,6 +476,7 @@ def test_multi_dim_output( loss: The loss function to use. reduction: The reduction to use for the loss function. dev: The device to run the test on. + correct_eigenvalues: Whether EKFAC should be used. 
""" manual_seed(0) # set up loss function, data, and model @@ -454,6 +518,7 @@ def test_multi_dim_output( params, data, fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, ) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -479,6 +544,7 @@ def test_multi_dim_output( params_flat, data_flat, fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, ) kfac_flat_mat = kfac_flat @ eye(kfac_flat.shape[1]) @@ -489,11 +555,15 @@ def test_multi_dim_output( @mark.parametrize( "loss", [MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], ids=["mse", "ce", "bce"] ) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) def test_expand_setting_scaling( fisher_type: str, loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], dev: device, + correct_eigenvalues: bool, ): """Test KFAC for correct scaling for expand setting with mean reduction loss. @@ -503,6 +573,7 @@ def test_expand_setting_scaling( fisher_type: The type of Fisher matrix to use. loss: The loss function to use. dev: The device to run the test on. + correct_eigenvalues: Whether EKFAC should be used. """ manual_seed(0) @@ -535,6 +606,7 @@ def test_expand_setting_scaling( params, data, fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, ) # FOOF does not scale the gradient covariances, even when using a mean reduction if fisher_type != FisherType.FORWARD_ONLY: @@ -544,8 +616,17 @@ def test_expand_setting_scaling( output_random_variable_size = 3 # MSE loss averages over number of output channels loss_term_factor *= output_random_variable_size - for ggT in kfac_sum._gradient_covariances.values(): - ggT /= kfac_sum._N_data * loss_term_factor + correction = kfac_sum._N_data * loss_term_factor + if correct_eigenvalues: + for eigenvalues in kfac_sum._corrected_eigenvalues.values(): + if isinstance(eigenvalues, dict): + for eigenvals in eigenvalues.values(): + eigenvals /= correction + else: + eigenvalues /= correction + else: + for ggT in kfac_sum._gradient_covariances.values(): + ggT /= correction kfac_simulated_mean_mat = kfac_sum @ eye(kfac_sum.shape[1]) # KFAC with mean reduction @@ -556,10 +637,11 @@ def test_expand_setting_scaling( params, data, fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, ) kfac_mean_mat = kfac_mean @ eye(kfac_mean.shape[1]) - report_nonclose(kfac_simulated_mean_mat, kfac_mean_mat) + report_nonclose(kfac_simulated_mean_mat, kfac_mean_mat, atol=1e-7) def test_bug_device_change_invalidates_parameter_mapping(): @@ -595,16 +677,31 @@ def test_bug_device_change_invalidates_parameter_mapping(): report_nonclose(kfac_x_gpu, kfac_x_cpu) -def test_torch_matmat(case): +@mark.parametrize( + "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] +) +@mark.parametrize( + "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] +) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_torch_matmat(case, separate_weight_and_bias, exclude, correct_eigenvalues): """Test that the torch_matmat method of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case + if exclude is not None: + names = {p.data_ptr(): name for name, p in model.named_parameters()} + params = [p for p in params if exclude not in names[p.data_ptr()]] + kfac = KFACLinearOperator( model, loss_func, params, data, batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, + 
correct_eigenvalues=correct_eigenvalues, ) device = kfac._device # KFAC.dtype is a numpy data type @@ -635,16 +732,31 @@ def test_torch_matmat(case): report_nonclose(kfac_x, kfac_x_numpy, rtol=1e-4) -def test_torch_matvec(case): +@mark.parametrize( + "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] +) +@mark.parametrize( + "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] +) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_torch_matvec(case, separate_weight_and_bias, exclude, correct_eigenvalues): """Test that the torch_matvec method of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case + if exclude is not None: + names = {p.data_ptr(): name for name, p in model.named_parameters()} + params = [p for p in params if exclude not in names[p.data_ptr()]] + kfac = KFACLinearOperator( model, loss_func, params, data, batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, + correct_eigenvalues=correct_eigenvalues, ) device = kfac._device # KFAC.dtype is a numpy data type @@ -683,7 +795,10 @@ def test_torch_matvec(case): report_nonclose(kfac_x, kfac_x_numpy) -def test_torch_matvec_list_output_shapes(cnn_case): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_torch_matvec_list_output_shapes(cnn_case, correct_eigenvalues): """Test output shapes with list input format (issue #124).""" model, loss_func, params, data, batch_size_fn = cnn_case kfac = KFACLinearOperator( @@ -692,6 +807,7 @@ def test_torch_matvec_list_output_shapes(cnn_case): params, data, batch_size_fn=batch_size_fn, + correct_eigenvalues=correct_eigenvalues, ) vec = [rand_like(p) for p in kfac._params] out_list = kfac.torch_matvec(vec) @@ -711,7 +827,12 @@ def test_torch_matvec_list_output_shapes(cnn_case): @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) -def test_trace(case, exclude, separate_weight_and_bias, check_deterministic): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_trace( + case, exclude, separate_weight_and_bias, check_deterministic, correct_eigenvalues +): """Test that the trace property of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case @@ -727,6 +848,7 @@ def test_trace(case, exclude, separate_weight_and_bias, check_deterministic): batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=check_deterministic, + correct_eigenvalues=correct_eigenvalues, ) # Check for equivalence of trace property and naive trace computation @@ -751,7 +873,12 @@ def test_trace(case, exclude, separate_weight_and_bias, check_deterministic): @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) -def test_frobenius_norm(case, exclude, separate_weight_and_bias, check_deterministic): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_frobenius_norm( + case, exclude, separate_weight_and_bias, check_deterministic, correct_eigenvalues +): """Test that the Frobenius norm property of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case @@ -767,6 +894,7 @@ def test_frobenius_norm(case, exclude, separate_weight_and_bias, check_determini batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, 
check_deterministic=check_deterministic, + correct_eigenvalues=correct_eigenvalues, ) # Check for equivalence of frobenius_norm property and the naive computation @@ -791,7 +919,12 @@ def test_frobenius_norm(case, exclude, separate_weight_and_bias, check_determini @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) -def test_det(case, exclude, separate_weight_and_bias, check_deterministic): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_det( + case, exclude, separate_weight_and_bias, check_deterministic, correct_eigenvalues +): """Test that the determinant property of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case @@ -807,21 +940,32 @@ def test_det(case, exclude, separate_weight_and_bias, check_deterministic): batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=check_deterministic, + correct_eigenvalues=correct_eigenvalues, ) # add damping manually to avoid singular matrices if not check_deterministic: kfac._compute_kfac() - assert kfac._input_covariances or kfac._gradient_covariances + delta = 1.0 # requires much larger damping value compared to ``logdet`` - for aaT in kfac._input_covariances.values(): - aaT.add_( - torch_eye(aaT.shape[0], dtype=aaT.dtype, device=aaT.device), alpha=delta - ) - for ggT in kfac._gradient_covariances.values(): - ggT.add_( - torch_eye(ggT.shape[0], dtype=ggT.dtype, device=ggT.device), alpha=delta - ) + if correct_eigenvalues: + assert kfac._corrected_eigenvalues + for eigenvalues in kfac._corrected_eigenvalues.values(): + if isinstance(eigenvalues, dict): + for eigenvals in eigenvalues.values(): + eigenvals.add_(delta) + else: + eigenvalues.add_(delta) + else: + assert kfac._input_covariances or kfac._gradient_covariances + for aaT in kfac._input_covariances.values(): + aaT.add_( + torch_eye(aaT.shape[0], dtype=aaT.dtype, device=aaT.device), alpha=delta + ) + for ggT in kfac._gradient_covariances.values(): + ggT.add_( + torch_eye(ggT.shape[0], dtype=ggT.dtype, device=ggT.device), alpha=delta + ) # Check for equivalence of the det property and naive determinant computation determinant = kfac.det @@ -847,7 +991,12 @@ def test_det(case, exclude, separate_weight_and_bias, check_deterministic): @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) -def test_logdet(case, exclude, separate_weight_and_bias, check_deterministic): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_logdet( + case, exclude, separate_weight_and_bias, check_deterministic, correct_eigenvalues +): """Test that the log determinant property of KFACLinearOperator works.""" model, loss_func, params, data, batch_size_fn = case @@ -863,21 +1012,32 @@ def test_logdet(case, exclude, separate_weight_and_bias, check_deterministic): batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=check_deterministic, + correct_eigenvalues=correct_eigenvalues, ) # add damping manually to avoid singular matrices if not check_deterministic: kfac._compute_kfac() - assert kfac._input_covariances or kfac._gradient_covariances + delta = 1e-3 # only requires much smaller damping value compared to ``det`` - for aaT in kfac._input_covariances.values(): - aaT.add_( - torch_eye(aaT.shape[0], dtype=aaT.dtype, device=aaT.device), alpha=delta - ) - for ggT in kfac._gradient_covariances.values(): - 
ggT.add_( - torch_eye(ggT.shape[0], dtype=ggT.dtype, device=ggT.device), alpha=delta - ) + if correct_eigenvalues: + assert kfac._corrected_eigenvalues + for eigenvalues in kfac._corrected_eigenvalues.values(): + if isinstance(eigenvalues, dict): + for eigenvals in eigenvalues.values(): + eigenvals.add_(delta) + else: + eigenvalues.add_(delta) + else: + assert kfac._input_covariances or kfac._gradient_covariances + for aaT in kfac._input_covariances.values(): + aaT.add_( + torch_eye(aaT.shape[0], dtype=aaT.dtype, device=aaT.device), alpha=delta + ) + for ggT in kfac._gradient_covariances.values(): + ggT.add_( + torch_eye(ggT.shape[0], dtype=ggT.dtype, device=ggT.device), alpha=delta + ) # Check for equivalence of the logdet property and naive log determinant computation log_det = kfac.logdet @@ -900,11 +1060,15 @@ def test_logdet(case, exclude, separate_weight_and_bias, check_deterministic): "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_forward_only_fisher_type( case: Tuple[Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]], shuffle: bool, exclude: str, separate_weight_and_bias: bool, + correct_eigenvalues: bool, ): """Test the KFAC with forward-only Fisher (used for FOOF) implementation. @@ -916,6 +1080,7 @@ def test_forward_only_fisher_type( or ``None``. separate_weight_and_bias: Whether to treat weight and bias as separate blocks in the KFAC matrix. + correct_eigenvalues: Whether EKFAC should be used. """ assert exclude in [None, "weight", "bias"] model, loss_func, params, data, batch_size_fn = case @@ -946,16 +1111,25 @@ def test_forward_only_fisher_type( ) simulated_foof_mat = foof_simulated @ eye(foof_simulated.shape[1]) - # Compute KFAC with `fisher_type=FisherType.FORWARD_ONLY` - foof = KFACLinearOperator( - model, - loss_func, - params, - data, - batch_size_fn=batch_size_fn, - separate_weight_and_bias=separate_weight_and_bias, - fisher_type=FisherType.FORWARD_ONLY, - ) + # Compute KFAC with `fisher_type=FisherType.FORWARD_ONLY + context = ( + raises(ValueError, match="eigenvalues") + if correct_eigenvalues + else nullcontext() + ) # EKFAC for FOOF is currently not supported + with context: + foof = KFACLinearOperator( + model, + loss_func, + params, + data, + batch_size_fn=batch_size_fn, + separate_weight_and_bias=separate_weight_and_bias, + fisher_type=FisherType.FORWARD_ONLY, + correct_eigenvalues=correct_eigenvalues, + ) + if correct_eigenvalues: + return foof_mat = foof @ eye(foof.shape[1]) # Check for equivalence @@ -1014,7 +1188,8 @@ def test_forward_only_fisher_type_exact_case( params = [params[i] for i in permutation] # Compute exact block-diagonal GGN - ggn = ggn_block_diagonal( + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, @@ -1118,7 +1293,8 @@ def test_forward_only_fisher_type_exact_weight_sharing_case( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn = ggn_block_diagonal( + ggn = block_diagonal( + GGNLinearOperator, model, loss_func, params, @@ -1189,7 +1365,10 @@ def test_kfac_does_affect_grad(): assert allclose(grad_before, p.grad) -def test_save_and_load_state_dict(): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_save_and_load_state_dict(correct_eigenvalues): """Test that KFACLinearOperator can be saved and loaded from state dict.""" 
manual_seed(0) batch_size, D_in, D_out = 4, 3, 2 @@ -1204,6 +1383,7 @@ def test_save_and_load_state_dict(): MSELoss(reduction="sum"), params, [(X, y)], + correct_eigenvalues=correct_eigenvalues, ) # save state dict @@ -1260,7 +1440,10 @@ def test_save_and_load_state_dict(): report_nonclose(kfac @ test_vec, kfac_new @ test_vec) -def test_from_state_dict(): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_from_state_dict(correct_eigenvalues): """Test that KFACLinearOperator can be created from state dict.""" manual_seed(0) batch_size, D_in, D_out = 4, 3, 2 @@ -1275,6 +1458,7 @@ def test_from_state_dict(): MSELoss(reduction="sum"), params, [(X, y)], + correct_eigenvalues=correct_eigenvalues, ) # save state dict diff --git a/test/utils.py b/test/utils.py index dcc656b..0975345 100644 --- a/test/utils.py +++ b/test/utils.py @@ -33,7 +33,7 @@ Upsample, ) -from curvlinops import GGNLinearOperator +from curvlinops._base import _LinearOperator def get_available_devices() -> List[device]: @@ -87,7 +87,8 @@ def regression_targets(size: Tuple[int]) -> Tensor: return rand(*size) -def ggn_block_diagonal( +def block_diagonal( + linear_operator: _LinearOperator, model: Module, loss_func: Module, params: List[Parameter], @@ -95,26 +96,29 @@ def ggn_block_diagonal( batch_size_fn: Optional[Callable[[MutableMapping], int]] = None, separate_weight_and_bias: bool = True, ) -> ndarray: - """Compute the block-diagonal GGN. + """Compute the block-diagonal of the matrix induced by a linear operator. Args: + linear_operator: The linear operator. model: The neural network. loss_func: The loss function. - params: The parameters w.r.t. which the GGN block-diagonals will be computed. + params: The parameters w.r.t. which the block-diagonal will be computed for. data: A data loader. batch_size_fn: A function that returns the batch size given a dict-like ``X``. separate_weight_and_bias: Whether to treat weight and bias of a layer as - separate blocks in the block-diagonal GGN. Default: ``True``. + separate blocks in the block-diagonal. Default: ``True``. Returns: - The block-diagonal GGN. + The block-diagonal matrix. 
""" - # compute the full GGN then zero out the off-diagonal blocks - ggn = GGNLinearOperator(model, loss_func, params, data, batch_size_fn=batch_size_fn) - ggn = from_numpy(ggn @ eye(ggn.shape[1])) + # compute the full matrix then zero out the off-diagonal blocks + linop = linear_operator(model, loss_func, params, data, batch_size_fn=batch_size_fn) + linop = from_numpy(linop @ eye(linop.shape[1])) sizes = [p.numel() for p in params] - # ggn_blocks[i, j] corresponds to the block of (params[i], params[j]) - ggn_blocks = [list(block.split(sizes, dim=1)) for block in ggn.split(sizes, dim=0)] + # matrix_blocks[i, j] corresponds to the block of (params[i], params[j]) + matrix_blocks = [ + list(block.split(sizes, dim=1)) for block in linop.split(sizes, dim=0) + ] # find out which blocks to keep num_params = len(params) @@ -142,10 +146,10 @@ def ggn_block_diagonal( for i, j in product(range(num_params), range(num_params)): if (i, j) not in keep: - ggn_blocks[i][j].zero_() + matrix_blocks[i][j].zero_() # concatenate all blocks - return cat([cat(row_blocks, dim=1) for row_blocks in ggn_blocks], dim=0).numpy() + return cat([cat(row_blocks, dim=1) for row_blocks in matrix_blocks], dim=0).numpy() class WeightShareModel(Sequential): From 7c72d97fbdea5a9de01ace1de0e518541536e810 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 16 Sep 2024 23:25:13 -0400 Subject: [PATCH 03/15] Implement EKFAC --- curvlinops/kfac.py | 582 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 441 insertions(+), 141 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index c333a0a..0d1d1e7 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -27,8 +27,8 @@ from einops import einsum, rearrange, reduce from numpy import ndarray from torch import Generator, Tensor, cat, device, eye, randn, stack -from torch.linalg import eigh from torch.autograd import grad +from torch.linalg import eigh from torch.nn import ( BCEWithLogitsLoss, Conv2d, @@ -51,6 +51,9 @@ # shape as the parameters, or a single matrix/vector of shape `[D, D]`/`[D]` where `D` # is the number of parameters. ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor]) +KFACType = TypeVar( + "KFACType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]] +) class MetaEnum(EnumMeta): @@ -249,6 +252,10 @@ def __init__( f"Invalid mc_samples: {mc_samples}. " "Only mc_samples=1 is supported for `fisher_type != FisherType.MC`." ) + if fisher_type == FisherType.FORWARD_ONLY and correct_eigenvalues: + raise ValueError( + "Correcting eigenvalues is not supported for FisherType.FORWARD_ONLY." + ) if kfac_approx not in self._SUPPORTED_KFAC_APPROX: raise ValueError( f"Invalid kfac_approx: {kfac_approx}. 
" @@ -262,6 +269,7 @@ def __init__( self._mc_samples = mc_samples self._kfac_approx = kfac_approx self._correct_eigenvalues = correct_eigenvalues + self._compute_eigenvalue_correction_flag = False self._input_covariances: Dict[str, Tensor] = {} self._gradient_covariances: Dict[str, Tensor] = {} self._mapping = self.compute_parameter_mapping(params, model_func) @@ -272,6 +280,8 @@ def __init__( self._gradient_covariances_eigenvectors: Dict[str, Tensor] = {} self._gradient_covariances_eigenvalues: Dict[str, Tensor] = {} + # Initialize the cache for activations + self._cached_activations: Dict[str, Tensor] = {} # Initialize the corrected eigenvalues for EKFAC self._corrected_eigenvalues: Dict[str, Tensor] = {} @@ -425,6 +435,86 @@ def _check_input_type_and_preprocess( M_torch = self._torch_preprocess(M_torch) return return_tensor, M_torch + @staticmethod + def _left_and_right_multiply( + M_joint: Tensor, + aaT: KFACType, + ggT: KFACType, + eigenvalues: Optional[Tensor], + ) -> Tensor: + """Left and right multiply matrix with Kronecker factors. + + Args: + M_joint: Matrix for multiplication. + aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for + biases. + ggT: Gradient covariance Kronecker factor or its eigenvectors. + eigenvalues: Corrected eigenvalues for the EKFAC approximation. + + Returns: + Matrix-multiplication result ``KFAC @ M_joint``. + """ + if eigenvalues is None: + M_joint = einsum(ggT, M_joint, aaT, "i j, m j k, k l -> m i l") + else: + # Perform preconditioning in KFE, e.g. see equation (21) in + # https://arxiv.org/abs/2308.03296. + aaT_eigvecs = aaT + ggT_eigvecs = ggT + # Transform in eigenbasis. + M_joint = einsum( + ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l" + ) + # Multiply by eigenvalues. + M_joint.mul_(eigenvalues) + # Transform back to standard basis. + M_joint = einsum( + ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m j k, l k -> m i l" + ) + return M_joint + + @staticmethod + def _separate_left_and_right_multiply( + M_torch: Tensor, + param_pos: Dict[str, int], + aaT: KFACType, + ggT: KFACType, + eigenvalues: Optional[Tensor], + ) -> Tensor: + """Multiply matrix with Kronecker factors for separated weight and bias. + + Args: + M_torch: Matrix for multiplication. + param_pos: Dictionary with positions of the weight and bias parameters. + aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for + biases. + ggT: Gradient covariance Kronecker factor or its eigenvectors. + eigenvalues: Corrected eigenvalues for the EKFAC approximation. + + Returns: + Matrix-multiplication result ``KFAC @ M_torch``. + """ + for p_name, pos in param_pos.items(): + # for weights we need to multiply from the right with aaT + # for weights and biases we need to multiply from the left with ggT + if p_name == "weight": + M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)") + # If `eigenvalues` is not `None`, we transform to eigenbasis here + M_torch[pos] = einsum(M_w, aaT, "m i j, j k -> m i k") + + dims = "m j ... -> m i ..." if eigenvalues is None else "m i ... -> m j ..." + # If `eigenvalues` is not `None`, we transform to eigenbasis here + M_torch[pos] = einsum(ggT, M_torch[pos], f"i j, {dims}") + + if eigenvalues is not None: + # Multiply by eigenvalues and transform back to standard basis + M_torch[pos].mul_(eigenvalues[pos]) + if p_name == "weight": + M_torch[pos] = einsum(M_torch[pos], aaT, "m i j, k j -> m i k") + M_torch[pos] = einsum(ggT, M_torch[pos], "i j, m j ... 
-> m i ...") + + return M_torch + def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply KFAC to a matrix (multiple vectors) in PyTorch. @@ -446,7 +536,12 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: ``[D, K]`` with some ``K``. """ return_tensor, M_torch = self._check_input_type_and_preprocess(M_torch) - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._input_covariances_eigenvectors + and not self._gradient_covariances_eigenvectors + ): self._compute_kfac() for mod_name, param_pos in self._mapping.items(): @@ -454,6 +549,16 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: if "weight" in param_pos: weight_shape = M_torch[param_pos["weight"]].shape + # get the Kronecker factors for the current module + if self._correct_eigenvalues: + aaT = self._input_covariances_eigenvectors.get(mod_name) + ggT = self._gradient_covariances_eigenvectors.get(mod_name) + eigenvalues = self._corrected_eigenvalues[mod_name] + else: + aaT = self._input_covariances.get(mod_name) + ggT = self._gradient_covariances.get(mod_name) + eigenvalues = None + # bias and weights are treated jointly if ( not self._separate_weight_and_bias @@ -461,33 +566,15 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: and "bias" in param_pos.keys() ): w_pos, b_pos = param_pos["weight"], param_pos["bias"] - # v denotes the free dimension for treating multiple vectors in parallel - M_w = rearrange(M_torch[w_pos], "v c_out ... -> v c_out (...)") - M_joint = cat([M_w, M_torch[b_pos].unsqueeze(-1)], dim=2) - aaT = self._input_covariances[mod_name] - ggT = self._gradient_covariances[mod_name] - M_joint = einsum(ggT, M_joint, aaT, "i j,v j k,k l -> v i l") - + M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)") + M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2) + M_joint = self._left_and_right_multiply(M_joint, aaT, ggT, eigenvalues) w_cols = M_w.shape[2] M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2) - - # for weights we need to multiply from the right with aaT - # for weights and biases we need to multiply from the left with ggT else: - for p_name, pos in param_pos.items(): - if p_name == "weight": - M_w = rearrange(M_torch[pos], "v c_out ... -> v c_out (...)") - M_torch[pos] = einsum( - M_w, - self._input_covariances[mod_name], - "v c_out j,j k -> v c_out k", - ) - - M_torch[pos] = einsum( - self._gradient_covariances[mod_name], - M_torch[pos], - "j k,v k ... 
-> v j ...", - ) + M_torch = self._separate_left_and_right_multiply( + M_torch, param_pos, aaT, ggT, eigenvalues + ) # restore original shapes if "weight" in param_pos: @@ -593,6 +680,24 @@ def _compute_kfac(self): output = self._model_func(X) self._compute_loss_and_backward(output, y) + if self._correct_eigenvalues: + # Compute the eigenvalue decomposition of the KFAC approximation + if not ( + self._input_covariances_eigenvalues + or self._gradient_covariances_eigenvalues + ): + self._compute_eigendecomposition() + + # Compute the corrected eigenvalues for the EKFAC approximation + self._compute_eigenvalue_correction_flag = True + for X, y in self._loop_over_data(desc="Eigenvalue correction"): + output = self._model_func(X) + self._compute_loss_and_backward(output, y) + self._compute_eigenvalue_correction_flag = False + + # Delete the cached activations + self._cached_activations.clear() + # clean up for handle in hook_handles: handle.remove() @@ -815,12 +920,96 @@ def _accumulate_gradient_covariance( / (self._N_data * self._mc_samples * self._num_per_example_loss_terms), }[self._loss_func.reduction] - covariance = einsum(g, g, "b i,b j->i j").mul_(correction) + if self._compute_eigenvalue_correction_flag: + # Compute the eigenvalue correction for the EKFAC approximation + self._compute_eigenvalue_correction(module_name, g, correction) + else: + # Compute and accumulate the gradient covariance + covariance = einsum(g, g, "b i, b j -> i j").mul_(correction) + self._gradient_covariances = self._set_or_add_( + self._gradient_covariances, module_name, covariance + ) + + def _compute_eigenvalue_correction( + self, module_name: str, g: Tensor, correction: int + ): + """Compute the corrected eigenvalues for the EKFAC approximation. + + The corrected eigenvalues are computed as + :math:`\lambda_{\text{corrected}} = (Q_g^T G Q_a)^2`, where + :math:`Q_a` and :math:`Q_g` are the eigenvectors of the input and gradient + covariances, respectively, and ``G`` is the gradient matrix. The corrected + eigenvalues are used to correct the eigenvalues of the KFAC approximation + (EKFAC). + + Args: + module_name: Name of the module in the neural network. + g: The gradient w.r.t. the layer output. + correction: Correction factor for the eigenvalues. + """ + param_pos = self._mapping[module_name] + aaT_eigenvectors = self._input_covariances_eigenvectors.get(module_name) + ggT_eigenvectors = self._gradient_covariances_eigenvectors.get(module_name) - if module_name not in self._gradient_covariances: - self._gradient_covariances[module_name] = covariance + # Compute corrected eigenvalues for EKFAC. + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + # Compute per-example gradient using the cached activations + per_example_gradient = einsum( + g, + self._cached_activations[module_name], + "shared d_out, shared d_in -> shared d_out d_in", + ) + # Transform the per-example gradient to the eigenbasis and square it + self._corrected_eigenvalues = self._set_or_add_( + self._corrected_eigenvalues, + module_name, + einsum( + ggT_eigenvectors, + per_example_gradient, + aaT_eigenvectors, + "d_out1 d_out2, ... d_out1 d_in1, d_in1 d_in2 -> ... 
d_out2 d_in2", + ) + .square_() + .sum(dim=0) + .mul_(correction), + ) else: - self._gradient_covariances[module_name].add_(covariance) + if module_name not in self._corrected_eigenvalues: + self._corrected_eigenvalues[module_name] = {} + for p_name, pos in param_pos.items(): + # Compute per-example gradient using the cached activations + per_example_gradient = ( + einsum( + g, + self._cached_activations[module_name], + "shared d_out, shared d_in -> shared d_out d_in", + ) + if p_name == "weight" + else g + ) + # Transform the per-example gradient to the eigenbasis and square it + if p_name == "weight": + per_example_gradient = einsum( + per_example_gradient, + aaT_eigenvectors, + "batch d_out d_in1, d_in1 d_in2 -> batch d_out d_in2", + ) + self._corrected_eigenvalues[module_name] = self._set_or_add_( + self._corrected_eigenvalues[module_name], + pos, + einsum( + ggT_eigenvectors, + per_example_gradient, + "d_out1 d_out2, batch d_out1 ... -> batch d_out2 ...", + ) + .square_() + .sum(dim=0) + .mul_(correction), + ) def _hook_accumulate_input_covariance( self, module: Module, inputs: Tuple[Tensor], module_name: str @@ -857,7 +1046,7 @@ def _hook_accumulate_input_covariance( if self._kfac_approx == KFACType.EXPAND: # KFAC-expand approximation - scale = x.shape[1:-1].numel() # sequence length + scale = x.shape[1:-1].numel() # weight sharing dimensions size x = rearrange(x, "batch ... d_in -> (batch ...) d_in") else: # KFAC-reduce approximation @@ -872,12 +1061,36 @@ def _hook_accumulate_input_covariance( ): x = cat([x, x.new_ones(x.shape[0], 1)], dim=1) - covariance = einsum(x, x, "b i,b j -> i j").div_(self._N_data * scale) + if self._compute_eigenvalue_correction_flag: + self._cached_activations[module_name] = x + else: + # Compute and accumulate the input covariance + covariance = einsum(x, x, "b i, b j -> i j").div_(self._N_data * scale) + self._input_covariances = self._set_or_add_( + self._input_covariances, module_name, covariance + ) + + @staticmethod + def _set_or_add_( + dictionary: Dict[str, Tensor], key: str, value: Tensor + ) -> Dict[str, Tensor]: + """Set or add a value to a dictionary entry. - if module_name not in self._input_covariances: - self._input_covariances[module_name] = covariance + Args: + dictionary: The dictionary to update. + key: The key to update. + value: The value to add. + + Returns: + The updated dictionary. + """ + if key not in dictionary: + dictionary[key] = value + elif isinstance(dictionary[key], Tensor) and isinstance(value, Tensor): + dictionary[key].add_(value) else: - self._input_covariances[module_name].add_(covariance) + raise ValueError("Incompatible types for addition.") + return dictionary @classmethod def compute_parameter_mapping( @@ -920,34 +1133,27 @@ def compute_parameter_mapping( return positions - def compute_eigendecomposition(self, keep_kronecker_factors: bool = False) -> None: - """Compute the eigendecomposition of the KFAC approximation. - - Args: - keep_kronecker_factors: Whether to keep the Kronecker factors. If ``False``, - will free the memory used by the Kronecker factors. - Defaults to ``False``. 
- """ + def _compute_eigendecomposition(self) -> None: + """Compute the eigendecomposition of the KFAC approximation.""" if not self._input_covariances and not self._gradient_covariances: self._compute_kfac() for mod_name in self._mapping.keys(): - aaT = self._input_covariances[mod_name] - ggT = self._gradient_covariances[mod_name] - if not keep_kronecker_factors: - del self._input_covariances[mod_name] - del self._gradient_covariances[mod_name] + # Free up memory by deleting the Kronecker factors + aaT = self._input_covariances.pop(mod_name, None) + ggT = self._gradient_covariances.pop(mod_name, None) # Compute eigendecomposition of the Kronecker factors - aaT_eigvals, aaT_eigvecs = eigh(aaT) - self._input_covariances_eigenvectors[mod_name] = aaT_eigvecs - self._input_covariances_eigenvalues[mod_name] = aaT_eigvals - del aaT - - ggT_eigvals, ggT_eigvecs = eigh(ggT) - self._gradient_covariances_eigenvectors[mod_name] = ggT_eigvecs - self._gradient_covariances_eigenvalues[mod_name] = ggT_eigvals - del ggT + if aaT is not None: + aaT_eigvals, aaT_eigvecs = eigh(aaT) + self._input_covariances_eigenvectors[mod_name] = aaT_eigvecs + self._input_covariances_eigenvalues[mod_name] = aaT_eigvals + del aaT + if ggT is not None: + ggT_eigvals, ggT_eigvecs = eigh(ggT) + self._gradient_covariances_eigenvectors[mod_name] = ggT_eigvecs + self._gradient_covariances_eigenvalues[mod_name] = ggT_eigvals + del ggT @property def trace(self) -> Tensor: @@ -964,25 +1170,41 @@ def trace(self) -> Tensor: if self._trace is not None: return self._trace - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._corrected_eigenvalues + ): self._compute_kfac() + # Initialize the trace self._trace = 0.0 - for mod_name, param_pos in self._mapping.items(): - tr_ggT = self._gradient_covariances[mod_name].trace() - if ( - not self._separate_weight_and_bias - and "weight" in param_pos.keys() - and "bias" in param_pos.keys() - ): - self._trace += self._input_covariances[mod_name].trace() * tr_ggT - else: - for p_name in param_pos.keys(): - self._trace += tr_ggT * ( - self._input_covariances[mod_name].trace() - if p_name == "weight" - else 1 - ) + + if self._correct_eigenvalues: + for corrected_eigenvalues in self._corrected_eigenvalues.values(): + if isinstance(corrected_eigenvalues, dict): + for val in corrected_eigenvalues.values(): + self._trace += val.sum() + else: + self._trace += corrected_eigenvalues.sum() + else: + # TODO: Also support the trace for eigendecomposition of KFAC + for mod_name, param_pos in self._mapping.items(): + tr_ggT = self._gradient_covariances[mod_name].trace() + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + self._trace += self._input_covariances[mod_name].trace() * tr_ggT + else: + for p_name in param_pos.keys(): + self._trace += tr_ggT * ( + self._input_covariances[mod_name].trace() + if p_name == "weight" + else 1 + ) + return self._trace @property @@ -1001,33 +1223,49 @@ def det(self) -> Tensor: if self._det is not None: return self._det - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._corrected_eigenvalues + ): self._compute_kfac() + # Initialize the determinant self._det = 1.0 - for mod_name, param_pos in self._mapping.items(): - m = self._gradient_covariances[mod_name].shape[0] - det_ggT = 
self._gradient_covariances[mod_name].det() - if ( - not self._separate_weight_and_bias - and "weight" in param_pos.keys() - and "bias" in param_pos.keys() - ): - n = self._input_covariances[mod_name].shape[0] - det_aaT = self._input_covariances[mod_name].det() - self._det *= det_aaT.pow(m) * det_ggT.pow(n) - else: - for p_name in param_pos.keys(): - n = ( - self._input_covariances[mod_name].shape[0] - if p_name == "weight" - else 1 - ) - self._det *= det_ggT.pow(n) * ( - self._input_covariances[mod_name].det().pow(m) - if p_name == "weight" - else 1 - ) + + if self._correct_eigenvalues: + for corrected_eigenvalues in self._corrected_eigenvalues.values(): + if isinstance(corrected_eigenvalues, dict): + for val in corrected_eigenvalues.values(): + self._det *= val.prod() + else: + self._det *= corrected_eigenvalues.prod() + else: + # TODO: Also support the det for eigendecomposition of KFAC + for mod_name, param_pos in self._mapping.items(): + m = self._gradient_covariances[mod_name].shape[0] + det_ggT = self._gradient_covariances[mod_name].det() + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + n = self._input_covariances[mod_name].shape[0] + det_aaT = self._input_covariances[mod_name].det() + self._det *= det_aaT.pow(m) * det_ggT.pow(n) + else: + for p_name in param_pos.keys(): + n = ( + self._input_covariances[mod_name].shape[0] + if p_name == "weight" + else 1 + ) + self._det *= det_ggT.pow(n) * ( + self._input_covariances[mod_name].det().pow(m) + if p_name == "weight" + else 1 + ) + return self._det @property @@ -1047,33 +1285,49 @@ def logdet(self) -> Tensor: if self._logdet is not None: return self._logdet - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._corrected_eigenvalues + ): self._compute_kfac() + # Initialize the log determinant self._logdet = 0.0 - for mod_name, param_pos in self._mapping.items(): - m = self._gradient_covariances[mod_name].shape[0] - logdet_ggT = self._gradient_covariances[mod_name].logdet() - if ( - not self._separate_weight_and_bias - and "weight" in param_pos.keys() - and "bias" in param_pos.keys() - ): - n = self._input_covariances[mod_name].shape[0] - logdet_aaT = self._input_covariances[mod_name].logdet() - self._logdet += m * logdet_aaT + n * logdet_ggT - else: - for p_name in param_pos.keys(): - n = ( - self._input_covariances[mod_name].shape[0] - if p_name == "weight" - else 1 - ) - self._logdet += n * logdet_ggT + ( - m * self._input_covariances[mod_name].logdet() - if p_name == "weight" - else 0 - ) + + if self._correct_eigenvalues: + for corrected_eigenvalues in self._corrected_eigenvalues.values(): + if isinstance(corrected_eigenvalues, dict): + for val in corrected_eigenvalues.values(): + self._logdet += val.log().sum() + else: + self._logdet += corrected_eigenvalues.log().sum() + else: + # TODO: Also support the log det for eigendecomposition of KFAC + for mod_name, param_pos in self._mapping.items(): + m = self._gradient_covariances[mod_name].shape[0] + logdet_ggT = self._gradient_covariances[mod_name].logdet() + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + n = self._input_covariances[mod_name].shape[0] + logdet_aaT = self._input_covariances[mod_name].logdet() + self._logdet += m * logdet_aaT + n * logdet_ggT + else: + for p_name in param_pos.keys(): + n = ( + 
self._input_covariances[mod_name].shape[0] + if p_name == "weight" + else 1 + ) + self._logdet += n * logdet_ggT + ( + m * self._input_covariances[mod_name].logdet() + if p_name == "weight" + else 0 + ) + return self._logdet @property @@ -1090,28 +1344,43 @@ def frobenius_norm(self) -> Tensor: if self._frobenius_norm is not None: return self._frobenius_norm - if not self._input_covariances and not self._gradient_covariances: + if ( + not self._input_covariances + and not self._gradient_covariances + and not self._corrected_eigenvalues + ): self._compute_kfac() + # Initialize the Frobenius norm self._frobenius_norm = 0.0 - for mod_name, param_pos in self._mapping.items(): - squared_frob_ggT = self._gradient_covariances[mod_name].square().sum() - if ( - not self._separate_weight_and_bias - and "weight" in param_pos.keys() - and "bias" in param_pos.keys() - ): - squared_frob_aaT = self._input_covariances[mod_name].square().sum() - self._frobenius_norm += squared_frob_aaT * squared_frob_ggT - else: - for p_name in param_pos.keys(): - self._frobenius_norm += squared_frob_ggT * ( - self._input_covariances[mod_name].square().sum() - if p_name == "weight" - else 1 - ) - self._frobenius_norm.sqrt_() - return self._frobenius_norm + + if self._correct_eigenvalues: + for corrected_eigenvalues in self._corrected_eigenvalues.values(): + if isinstance(corrected_eigenvalues, dict): + for val in corrected_eigenvalues.values(): + self._frobenius_norm += val.square().sum() + else: + self._frobenius_norm += corrected_eigenvalues.square().sum() + else: + # TODO: Also support the Frobenius norm for eigendecomposition of KFAC + for mod_name, param_pos in self._mapping.items(): + squared_frob_ggT = self._gradient_covariances[mod_name].square().sum() + if ( + not self._separate_weight_and_bias + and "weight" in param_pos.keys() + and "bias" in param_pos.keys() + ): + squared_frob_aaT = self._input_covariances[mod_name].square().sum() + self._frobenius_norm += squared_frob_aaT * squared_frob_ggT + else: + for p_name in param_pos.keys(): + self._frobenius_norm += squared_frob_ggT * ( + self._input_covariances[mod_name].square().sum() + if p_name == "weight" + else 1 + ) + + return self._frobenius_norm.sqrt_() def state_dict(self) -> Dict[str, Any]: """Return the state of the KFAC linear operator. 
@@ -1136,13 +1405,22 @@ def state_dict(self) -> Dict[str, Any]: "fisher_type": self._fisher_type, "mc_samples": self._mc_samples, "kfac_approx": self._kfac_approx, + "correct_eigenvalues": self._correct_eigenvalues, "num_per_example_loss_terms": self._num_per_example_loss_terms, "separate_weight_and_bias": self._separate_weight_and_bias, "num_data": self._N_data, # Kronecker factors (if computed) "input_covariances": self._input_covariances, "gradient_covariances": self._gradient_covariances, - # Properties (not necessarily computed) + # Kronecker factors eigendecomposition (if computed) + "input_covariances_eigenvectors": self._input_covariances_eigenvectors, + "input_covariances_eigenvalues": self._input_covariances_eigenvalues, + "gradient_covariances_eigenvectors": self._gradient_covariances_eigenvectors, + "gradient_covariances_eigenvalues": self._gradient_covariances_eigenvalues, + # Quantities for eigenvalue correction (if computed) + "cached_activations": self._cached_activations, + "corrected_eigenvalues": self._corrected_eigenvalues, + # Properties (if computed) "trace": self._trace, "det": self._det, "logdet": self._logdet, @@ -1183,6 +1461,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self._fisher_type = state_dict["fisher_type"] self._mc_samples = state_dict["mc_samples"] self._kfac_approx = state_dict["kfac_approx"] + self._correct_eigenvalues = state_dict["correct_eigenvalues"] self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"] self._separate_weight_and_bias = state_dict["separate_weight_and_bias"] self._N_data = state_dict["num_data"] @@ -1208,6 +1487,26 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self._input_covariances = state_dict["input_covariances"] self._gradient_covariances = state_dict["gradient_covariances"] + # Set Kronecker factors eigendecomposition (if computed) + # TODO: should we check if the keys match the mapping keys? + self._input_covariances_eigenvectors = state_dict[ + "input_covariances_eigenvectors" + ] + self._input_covariances_eigenvalues = state_dict[ + "input_covariances_eigenvalues" + ] + self._gradient_covariances_eigenvectors = state_dict[ + "gradient_covariances_eigenvectors" + ] + self._gradient_covariances_eigenvalues = state_dict[ + "gradient_covariances_eigenvalues" + ] + + # Set quantities for eigenvalue correction (if computed) + # TODO: should we check if the keys match the mapping keys? 
+ self._cached_activations = state_dict["cached_activations"] + self._corrected_eigenvalues = state_dict["corrected_eigenvalues"] + # Set properties (not necessarily computed) self._trace = state_dict["trace"] self._det = state_dict["det"] @@ -1261,6 +1560,7 @@ def from_state_dict( fisher_type=state_dict["fisher_type"], mc_samples=state_dict["mc_samples"], kfac_approx=state_dict["kfac_approx"], + correct_eigenvalues=state_dict["correct_eigenvalues"], num_per_example_loss_terms=state_dict["num_per_example_loss_terms"], separate_weight_and_bias=state_dict["separate_weight_and_bias"], num_data=state_dict["num_data"], From 18062efbdae4fd91b95b512bf52f7c75accfabd1 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 16 Sep 2024 23:29:05 -0400 Subject: [PATCH 04/15] Add inverse EKFAC test coverage --- test/test_inverse.py | 67 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 10 deletions(-) diff --git a/test/test_inverse.py b/test/test_inverse.py index f064ccc..150abc8 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -440,6 +440,9 @@ def test_KFAC_inverse_heuristically_damped_matmat( # noqa: C901 @mark.parametrize( "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] ) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_KFAC_inverse_exactly_damped_matmat( case: Tuple[ Module, @@ -450,6 +453,7 @@ def test_KFAC_inverse_exactly_damped_matmat( cache: bool, exclude: str, separate_weight_and_bias: bool, + correct_eigenvalues: bool, delta: float = 1e-2, ): """Test matrix-matrix multiplication by an inverse (exactly) damped KFAC approximation.""" @@ -479,6 +483,7 @@ def test_KFAC_inverse_exactly_damped_matmat( batch_size_fn=batch_size_fn, separate_weight_and_bias=separate_weight_and_bias, check_deterministic=False, + correct_eigenvalues=correct_eigenvalues, ) KFAC.dtype = float64 @@ -511,7 +516,7 @@ def test_KFAC_inverse_exactly_damped_matmat( report_nonclose(inv_KFAC @ X, inv_KFAC_naive @ X) assert inv_KFAC._cache == cache - if cache: + if cache and not correct_eigenvalues: # test that the cache is not empty assert len(inv_KFAC._inverse_input_covariances) > 0 assert len(inv_KFAC._inverse_gradient_covariances) > 0 @@ -521,6 +526,9 @@ def test_KFAC_inverse_exactly_damped_matmat( assert len(inv_KFAC._inverse_gradient_covariances) == 0 +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_KFAC_inverse_damped_torch_matmat( case: Tuple[ Module, @@ -528,6 +536,7 @@ def test_KFAC_inverse_damped_torch_matmat( List[Parameter], Iterable[Tuple[torch.Tensor, torch.Tensor]], ], + correct_eigenvalues: bool, delta: float = 1e-2, ): """Test torch matrix-matrix multiplication by an inverse damped KFAC approximation.""" @@ -552,9 +561,14 @@ def test_KFAC_inverse_damped_torch_matmat( data, batch_size_fn=batch_size_fn, check_deterministic=False, + correct_eigenvalues=correct_eigenvalues, ) KFAC.dtype = float64 - inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta)) + inv_KFAC = KFACInverseLinearOperator( + KFAC, + damping=delta if correct_eigenvalues else (delta, delta), + use_exact_damping=True if correct_eigenvalues else False, + ) device = KFAC._device num_vectors = 2 @@ -584,6 +598,9 @@ def test_KFAC_inverse_damped_torch_matmat( report_nonclose(inv_KFAC_X, kfac_x_numpy) +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) def test_KFAC_inverse_damped_torch_matvec( case: Tuple[ Module, @@ -591,6 
+608,7 @@ def test_KFAC_inverse_damped_torch_matvec( List[Parameter], Iterable[Tuple[torch.Tensor, torch.Tensor]], ], + correct_eigenvalues: bool, delta: float = 1e-2, ): """Test torch matrix-vector multiplication by an inverse damped KFAC approximation.""" @@ -615,9 +633,14 @@ def test_KFAC_inverse_damped_torch_matvec( data, batch_size_fn=batch_size_fn, check_deterministic=False, + correct_eigenvalues=correct_eigenvalues, ) KFAC.dtype = float64 - inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta)) + inv_KFAC = KFACInverseLinearOperator( + KFAC, + damping=delta if correct_eigenvalues else (delta, delta), + use_exact_damping=True if correct_eigenvalues else False, + ) device = KFAC._device x = torch.rand(KFAC.shape[1], dtype=dtype, device=device) @@ -647,7 +670,10 @@ def test_KFAC_inverse_damped_torch_matvec( report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy()) -def test_KFAC_inverse_save_and_load_state_dict(): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_KFAC_inverse_save_and_load_state_dict(correct_eigenvalues): """Test that KFACInverseLinearOperator can be saved and loaded from state dict.""" torch.manual_seed(0) batch_size, D_in, D_out = 4, 3, 2 @@ -662,11 +688,16 @@ def test_KFAC_inverse_save_and_load_state_dict(): MSELoss(reduction="sum"), params, [(X, y)], + correct_eigenvalues=correct_eigenvalues, ) # create inverse KFAC inv_kfac = KFACInverseLinearOperator( - kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False + kfac, + damping=1e-2, + use_exact_damping=True if correct_eigenvalues else False, + use_heuristic_damping=False if correct_eigenvalues else True, + retry_double_precision=False, ) _ = inv_kfac @ eye(kfac.shape[1]) # to trigger inverse computation @@ -681,7 +712,9 @@ def test_KFAC_inverse_save_and_load_state_dict(): inv_kfac_wrong.load_state_dict(torch.load("inv_kfac_state_dict.pt")) # create new inverse KFAC and load state dict - inv_kfac_new = KFACInverseLinearOperator(kfac) + inv_kfac_new = KFACInverseLinearOperator( + kfac, use_exact_damping=True if correct_eigenvalues else False + ) inv_kfac_new.load_state_dict(torch.load("inv_kfac_state_dict.pt")) # clean up os.remove("inv_kfac_state_dict.pt") @@ -692,7 +725,10 @@ def test_KFAC_inverse_save_and_load_state_dict(): report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) -def test_KFAC_inverse_from_state_dict(): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def test_KFAC_inverse_from_state_dict(correct_eigenvalues): """Test that KFACInverseLinearOperator can be created from state dict.""" torch.manual_seed(0) batch_size, D_in, D_out = 4, 3, 2 @@ -707,11 +743,16 @@ def test_KFAC_inverse_from_state_dict(): MSELoss(reduction="sum"), params, [(X, y)], + correct_eigenvalues=correct_eigenvalues, ) # create inverse KFAC and save state dict inv_kfac = KFACInverseLinearOperator( - kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False + kfac, + damping=1e-2, + use_exact_damping=True if correct_eigenvalues else False, + use_heuristic_damping=False if correct_eigenvalues else True, + retry_double_precision=False, ) state_dict = inv_kfac.state_dict() @@ -724,7 +765,10 @@ def test_KFAC_inverse_from_state_dict(): report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) -def test_torch_matvec_list_output_shapes(cnn_case): +@mark.parametrize( + "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] +) +def 
test_torch_matvec_list_output_shapes(cnn_case, correct_eigenvalues): """Test output shapes with list input format (issue #124).""" model, loss_func, params, data, batch_size_fn = cnn_case kfac = KFACLinearOperator( @@ -733,8 +777,11 @@ def test_torch_matvec_list_output_shapes(cnn_case): params, data, batch_size_fn=batch_size_fn, + correct_eigenvalues=correct_eigenvalues, + ) + inv_kfac = KFACInverseLinearOperator( + kfac, damping=1e-2, use_exact_damping=True if correct_eigenvalues else False ) - inv_kfac = KFACInverseLinearOperator(kfac, damping=1e-2) vec = [torch.rand_like(p) for p in kfac._params] out_list = inv_kfac.torch_matvec(vec) assert len(out_list) == len(kfac._params) From b09672650b69d4494f87be19eb5219cb6bf5cc07 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 16 Sep 2024 23:30:13 -0400 Subject: [PATCH 05/15] Add inverse EKFAC support --- curvlinops/inverse.py | 195 ++++++++++++++++-------------------------- 1 file changed, 72 insertions(+), 123 deletions(-) diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index cd5482b..4b563a5 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -10,11 +10,7 @@ from torch import Tensor, cat, cholesky_inverse, eye, float64, outer from torch.linalg import cholesky, eigh -from curvlinops.kfac import KFACLinearOperator, ParameterMatrixType - -KFACInvType = TypeVar( - "KFACInvType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]] -) +from curvlinops.kfac import KFACLinearOperator, KFACType, ParameterMatrixType class _InverseLinearOperator(LinearOperator): @@ -355,6 +351,8 @@ def __init__( raise ValueError( "Heuristic and exact damping require a single damping value." ) + if self._A._correct_eigenvalues and not use_exact_damping: + raise ValueError("Only exact damping is supported for EKFAC.") self._damping = damping self._use_heuristic_damping = use_heuristic_damping @@ -362,8 +360,8 @@ def __init__( self._use_exact_damping = use_exact_damping self._cache = cache self._retry_double_precision = retry_double_precision - self._inverse_input_covariances: Dict[str, KFACInvType] = {} - self._inverse_gradient_covariances: Dict[str, KFACInvType] = {} + self._inverse_input_covariances: Dict[str, KFACType] = {} + self._inverse_gradient_covariances: Dict[str, KFACType] = {} def _compute_damping( self, aaT: Optional[Tensor], ggT: Optional[Tensor] @@ -408,18 +406,20 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor: ) def _compute_inverse_factors( - self, aaT: Optional[Tensor], ggT: Optional[Tensor] - ) -> Tuple[KFACInvType, KFACInvType]: + self, aaT: Optional[Tensor], ggT: Optional[Tensor], name: str + ) -> Tuple[KFACType, KFACType, Optional[Tensor]]: """Compute the inverses of the Kronecker factors for a given layer. Args: aaT: Input covariance matrix. ``None`` for biases. ggT: Gradient covariance matrix. + name: Name of the layer for which to invert Kronecker factors. Returns: Tuple of inverses (or eigendecompositions) of the input and gradient - covariance Kronecker factors. Can be ``None`` if the input or gradient - covariance is ``None`` (e.g. the input covariances for biases). + covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if + the input or gradient covariance is ``None`` (e.g. the input covariances for + biases). Raises: RuntimeError: If a Cholesky decomposition (and optionally the retry in @@ -430,7 +430,27 @@ def _compute_inverse_factors( # Kronecker-factored eigenbasis (KFE). 
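Editorial note: the tests above always choose the damping arguments based on whether eigenvalue correction is enabled, and the constructor check added to ``curvlinops/inverse.py`` above enforces that EKFAC only supports a single, exact damping value, while plain KFAC accepts one value per Kronecker factor. A condensed helper capturing that convention (the function name ``damped_inverse`` is made up for illustration):

```python
from curvlinops.inverse import KFACInverseLinearOperator
from curvlinops.kfac import KFACLinearOperator


def damped_inverse(
    kfac: KFACLinearOperator, delta: float = 1e-2
) -> KFACInverseLinearOperator:
    """Build the damped inverse with the convention used in the tests above."""
    if kfac._correct_eigenvalues:
        # EKFAC: one damping value, added directly to the corrected eigenvalues.
        return KFACInverseLinearOperator(kfac, damping=delta, use_exact_damping=True)
    # Plain KFAC: one damping value per Kronecker factor.
    return KFACInverseLinearOperator(kfac, damping=(delta, delta))
```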
aaT_eigvals, aaT_eigvecs = (None, None) if aaT is None else eigh(aaT) ggT_eigvals, ggT_eigvecs = (None, None) if ggT is None else eigh(ggT) - return (aaT_eigvecs, aaT_eigvals), (ggT_eigvecs, ggT_eigvals) + param_pos = self._A._mapping[name] + if ( + not self._A._separate_weight_and_bias + and "weight" in param_pos + and "bias" in param_pos + ): + inv_damped_eigenvalues = ( + outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) + ) + else: + inv_damped_eigenvalues = {} + for p_name, pos in param_pos.items(): + if p_name == "weight": + inv_damped_eigenvalues[pos] = ( + outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) + ) + else: + inv_damped_eigenvalues[pos] = ggT_eigvals.add( + self._damping + ).pow_(-1) + return aaT_eigvecs, ggT_eigvecs, inv_damped_eigenvalues else: damping_aaT, damping_ggT = self._compute_damping(aaT, ggT) @@ -476,11 +496,11 @@ def _compute_inverse_factors( raise error ggT_inv = None if ggT_chol is None else cholesky_inverse(ggT_chol) - return aaT_inv, ggT_inv + return aaT_inv, ggT_inv, None def _compute_or_get_cached_inverse( self, name: str - ) -> Tuple[KFACInvType, KFACInvType]: + ) -> Tuple[KFACType, KFACType, Optional[Tensor]]: """Invert the Kronecker factors of the KFACLinearOperator or retrieve them. Args: @@ -488,117 +508,37 @@ def _compute_or_get_cached_inverse( Returns: Tuple of inverses (or eigendecompositions) of the input and gradient - covariance Kronecker factors. Can be ``None`` if the input or gradient - covariance is ``None`` (e.g. the input covariances for biases). + covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if + the input or gradient covariance is ``None`` (e.g. the input covariances for + biases). """ if name in self._inverse_input_covariances: aaT_inv = self._inverse_input_covariances.get(name) ggT_inv = self._inverse_gradient_covariances.get(name) - return aaT_inv, ggT_inv - - aaT = self._A._input_covariances.get(name) - ggT = self._A._gradient_covariances.get(name) - aaT_inv, ggT_inv = self._compute_inverse_factors(aaT, ggT) - - if self._cache: - self._inverse_input_covariances[name] = aaT_inv - self._inverse_gradient_covariances[name] = ggT_inv - - return aaT_inv, ggT_inv - - def _left_and_right_multiply( - self, M_joint: Tensor, aaT_inv: KFACInvType, ggT_inv: KFACInvType - ) -> Tensor: - """Left and right multiply matrix with inverse Kronecker factors. - - Args: - M_joint: Matrix for multiplication. - aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for - biases. - ggT_inv: Inverse of the gradient covariance Kronecker factor. - - Returns: - Matrix-multiplication result ``KFAC⁻¹ @ M_joint``. - """ - if self._use_exact_damping: - # Perform damped preconditioning in KFE, e.g. see equation (21) in - # https://arxiv.org/abs/2308.03296. - aaT_eigvecs, aaT_eigvals = aaT_inv - ggT_eigvecs, ggT_eigvals = ggT_inv - # Transform in eigenbasis. - M_joint = einsum( - ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l" - ) - # Divide by damped eigenvalues to perform the inversion. - M_joint.div_(outer(ggT_eigvals, aaT_eigvals).add_(self._damping)) - # Transform back to standard basis. 
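Editorial note: the exact-damping path above rests on the spectral identity behind equation (21) of https://arxiv.org/abs/2308.03296: for a symmetric block ``K = Q diag(lam) Q^T``, ``(K + delta*I)^{-1} = Q diag(1/(lam + delta)) Q^T``. For a Kronecker block the eigenvalues are exactly the pairwise products ``outer(ggT_eigvals, aaT_eigvals)``, which is why only that outer product has to be damped and inverted. A small float64 sanity check of the identity, independent of curvlinops:

```python
import torch
from torch.linalg import eigh

torch.manual_seed(0)
D, delta = 5, 1e-2
A = torch.randn(D, D, dtype=torch.float64)
K = A @ A.T  # symmetric PSD stand-in for a curvature block
lam, Q = eigh(K)

inv_exact = torch.linalg.inv(K + delta * torch.eye(D, dtype=torch.float64))
inv_spectral = Q @ torch.diag(1.0 / (lam + delta)) @ Q.T
assert torch.allclose(inv_exact, inv_spectral)
```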
- M_joint = einsum( - ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m j k, l k -> m i l" - ) + return aaT_inv, ggT_inv, None + + if self._A._correct_eigenvalues: + aaT_inv = self._A._input_covariances_eigenvectors.get(name) + ggT_inv = self._A._gradient_covariances_eigenvectors.get(name) + eigenvalues = self._A._corrected_eigenvalues.get(name) + if isinstance(eigenvalues, dict): + inv_damped_eigenvalues = {} + for key, val in eigenvalues.items(): + inv_damped_eigenvalues[key] = val.add(self._damping).pow_(-1) + elif isinstance(eigenvalues, Tensor): + inv_damped_eigenvalues = eigenvalues.add(self._damping).pow_(-1) else: - M_joint = einsum(ggT_inv, M_joint, aaT_inv, "i j, m j k, k l -> m i l") - return M_joint - - def _separate_left_and_right_multiply( - self, - M_torch: Tensor, - param_pos: Dict[str, int], - aaT_inv: KFACInvType, - ggT_inv: KFACInvType, - ) -> Tensor: - """Multiply matrix with inverse Kronecker factors for separated weight and bias. - - Args: - M_torch: Matrix for multiplication. - param_pos: Dictionary with positions of the weight and bias parameters. - aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for - biases. - ggT_inv: Inverse of the gradient covariance Kronecker factor. - - Returns: - Matrix-multiplication result ``KFAC⁻¹ @ M_torch``. - """ - if self._use_exact_damping: - # Perform damped preconditioning in KFE, e.g. see equation (21) in - # https://arxiv.org/abs/2308.03296. - aaT_eigvecs, aaT_eigvals = aaT_inv - ggT_eigvecs, ggT_eigvals = ggT_inv - - for p_name, pos in param_pos.items(): - # for weights we need to multiply from the right with aaT - # for weights and biases we need to multiply from the left with ggT - if p_name == "weight": - M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)") - aaT_fac = aaT_eigvecs if self._use_exact_damping else aaT_inv - # If `use_exact_damping` is `True`, we transform to eigenbasis - M_torch[pos] = einsum(M_w, aaT_fac, "m i j, j k -> m i k") - - ggT_fac = ggT_eigvecs if self._use_exact_damping else ggT_inv - dims = ( - "m i ... -> m j ..." - if self._use_exact_damping - else " m j ... -> m i ..." + aaT = self._A._input_covariances.get(name) + ggT = self._A._gradient_covariances.get(name) + aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors( + aaT, ggT, name ) - # If `use_exact_damping` is `True`, we transform to eigenbasis - M_torch[pos] = einsum(ggT_fac, M_torch[pos], f"i j, {dims}") - - if self._use_exact_damping: - # Divide by damped eigenvalues to perform the inversion and transform - # back to standard basis. - if p_name == "weight": - M_torch[pos].div_( - outer(ggT_eigvals, aaT_eigvals).add_(self._damping) - ) - M_torch[pos] = einsum( - M_torch[pos], aaT_eigvecs, "m i j, k j -> m i k" - ) - else: - M_torch[pos].div_(ggT_eigvals.add_(self._damping)) - M_torch[pos] = einsum( - ggT_eigvecs, M_torch[pos], "i j, m j ... -> m i ..." - ) - return M_torch + if self._cache: + self._inverse_input_covariances[name] = aaT_inv + self._inverse_gradient_covariances[name] = ggT_inv + + return aaT_inv, ggT_inv, inv_damped_eigenvalues def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: """Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch. @@ -621,12 +561,19 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: ``[D, K]`` with some ``K``. 
""" return_tensor, M_torch = self._A._check_input_type_and_preprocess(M_torch) - if not self._A._input_covariances and not self._A._gradient_covariances: + if ( + not self._A._input_covariances + and not self._A._gradient_covariances + and not self._A._input_covariances_eigenvectors + and not self._A._gradient_covariances_eigenvectors + ): self._A._compute_kfac() for mod_name, param_pos in self._A._mapping.items(): # retrieve the inverses of the Kronecker factors from cache or invert them - aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name) + aaT_inv, ggT_inv, inv_damped_eigenvalues = ( + self._compute_or_get_cached_inverse(mod_name) + ) # cache the weight shape to ensure correct shapes are returned if "weight" in param_pos: weight_shape = M_torch[param_pos["weight"]].shape @@ -640,12 +587,14 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType: w_pos, b_pos = param_pos["weight"], param_pos["bias"] M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)") M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2) - M_joint = self._left_and_right_multiply(M_joint, aaT_inv, ggT_inv) + M_joint = self._A._left_and_right_multiply( + M_joint, aaT_inv, ggT_inv, inv_damped_eigenvalues + ) w_cols = M_w.shape[2] M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2) else: - M_torch = self._separate_left_and_right_multiply( - M_torch, param_pos, aaT_inv, ggT_inv + M_torch = self._A._separate_left_and_right_multiply( + M_torch, param_pos, aaT_inv, ggT_inv, inv_damped_eigenvalues ) # restore original shapes From a2dec74ecd8dc1d0e4c907edc2bb871970380c1c Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 16 Sep 2024 23:46:50 -0400 Subject: [PATCH 06/15] Fix flake8 --- curvlinops/inverse.py | 72 +++++++++++++++++++++++++++---------------- curvlinops/kfac.py | 18 ++++++----- test/test_kfac.py | 1 - 3 files changed, 56 insertions(+), 35 deletions(-) diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index 4b563a5..6473a8e 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -1,16 +1,16 @@ """Implements linear operator inverses.""" from math import sqrt -from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union from warnings import warn -from einops import einsum, rearrange +from einops import rearrange from numpy import allclose, column_stack, ndarray from scipy.sparse.linalg import LinearOperator, cg, lsmr from torch import Tensor, cat, cholesky_inverse, eye, float64, outer from torch.linalg import cholesky, eigh -from curvlinops.kfac import KFACLinearOperator, KFACType, ParameterMatrixType +from curvlinops.kfac import FactorType, KFACLinearOperator, ParameterMatrixType class _InverseLinearOperator(LinearOperator): @@ -360,8 +360,8 @@ def __init__( self._use_exact_damping = use_exact_damping self._cache = cache self._retry_double_precision = retry_double_precision - self._inverse_input_covariances: Dict[str, KFACType] = {} - self._inverse_gradient_covariances: Dict[str, KFACType] = {} + self._inverse_input_covariances: Dict[str, FactorType] = {} + self._inverse_gradient_covariances: Dict[str, FactorType] = {} def _compute_damping( self, aaT: Optional[Tensor], ggT: Optional[Tensor] @@ -405,9 +405,44 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor: M.add(eye(M.shape[0], dtype=M.dtype, device=M.device), alpha=damping) ) + def _compute_inv_damped_eigenvalues( + self, aaT_eigvals: Tensor, ggT_eigvals: Tensor, name: str + ) -> Union[Tensor, 
Dict[str, Tensor]]: + """Compute the inverses of the damped eigenvalues for a given layer. + + Args: + aaT_eigvals: Eigenvalues of the input covariance matrix. + ggT_eigvals: Eigenvalues of the gradient covariance matrix. + name: Name of the layer for which to damp and invert eigenvalues. + + Returns: + Inverses of the damped eigenvalues. + """ + param_pos = self._A._mapping[name] + if ( + not self._A._separate_weight_and_bias + and "weight" in param_pos + and "bias" in param_pos + ): + inv_damped_eigenvalues = ( + outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) + ) + else: + inv_damped_eigenvalues = {} + for p_name, pos in param_pos.items(): + if p_name == "weight": + inv_damped_eigenvalues[pos] = ( + outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) + ) + else: + inv_damped_eigenvalues[pos] = ggT_eigvals.add(self._damping).pow_( + -1 + ) + return inv_damped_eigenvalues + def _compute_inverse_factors( self, aaT: Optional[Tensor], ggT: Optional[Tensor], name: str - ) -> Tuple[KFACType, KFACType, Optional[Tensor]]: + ) -> Tuple[FactorType, FactorType, Optional[Tensor]]: """Compute the inverses of the Kronecker factors for a given layer. Args: @@ -430,26 +465,9 @@ def _compute_inverse_factors( # Kronecker-factored eigenbasis (KFE). aaT_eigvals, aaT_eigvecs = (None, None) if aaT is None else eigh(aaT) ggT_eigvals, ggT_eigvecs = (None, None) if ggT is None else eigh(ggT) - param_pos = self._A._mapping[name] - if ( - not self._A._separate_weight_and_bias - and "weight" in param_pos - and "bias" in param_pos - ): - inv_damped_eigenvalues = ( - outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) - ) - else: - inv_damped_eigenvalues = {} - for p_name, pos in param_pos.items(): - if p_name == "weight": - inv_damped_eigenvalues[pos] = ( - outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) - ) - else: - inv_damped_eigenvalues[pos] = ggT_eigvals.add( - self._damping - ).pow_(-1) + inv_damped_eigenvalues = self._compute_inv_damped_eigenvalues( + aaT_eigvals, ggT_eigvals, name + ) return aaT_eigvecs, ggT_eigvecs, inv_damped_eigenvalues else: damping_aaT, damping_ggT = self._compute_damping(aaT, ggT) @@ -500,7 +518,7 @@ def _compute_inverse_factors( def _compute_or_get_cached_inverse( self, name: str - ) -> Tuple[KFACType, KFACType, Optional[Tensor]]: + ) -> Tuple[FactorType, FactorType, Optional[Tensor]]: """Invert the Kronecker factors of the KFACLinearOperator or retrieve them. Args: diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 0d1d1e7..a598c5d 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -51,8 +51,8 @@ # shape as the parameters, or a single matrix/vector of shape `[D, D]`/`[D]` where `D` # is the number of parameters. ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor]) -KFACType = TypeVar( - "KFACType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]] +FactorType = TypeVar( + "FactorType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]] ) @@ -438,8 +438,8 @@ def _check_input_type_and_preprocess( @staticmethod def _left_and_right_multiply( M_joint: Tensor, - aaT: KFACType, - ggT: KFACType, + aaT: FactorType, + ggT: FactorType, eigenvalues: Optional[Tensor], ) -> Tensor: """Left and right multiply matrix with Kronecker factors. 
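Editorial note: ``_left_and_right_multiply`` applies a Kronecker-factored block as two ordinary matrix products instead of ever materializing the Kronecker product. The underlying identity (row-major flattening, symmetric factors) is ``kron(ggT, aaT) @ vec(X) == vec(ggT @ X @ aaT)``. A quick verification against ``torch.kron``:

```python
import torch

torch.manual_seed(0)
d_out, d_in = 3, 4
ggT = torch.randn(d_out, d_out, dtype=torch.float64)
ggT = ggT @ ggT.T  # symmetric gradient covariance stand-in
aaT = torch.randn(d_in, d_in, dtype=torch.float64)
aaT = aaT @ aaT.T  # symmetric input covariance stand-in
X = torch.randn(d_out, d_in, dtype=torch.float64)

via_kron = torch.kron(ggT, aaT) @ X.flatten()
via_matmul = (ggT @ X @ aaT).flatten()
assert torch.allclose(via_kron, via_matmul)
```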
@@ -477,8 +477,8 @@ def _left_and_right_multiply( def _separate_left_and_right_multiply( M_torch: Tensor, param_pos: Dict[str, int], - aaT: KFACType, - ggT: KFACType, + aaT: FactorType, + ggT: FactorType, eigenvalues: Optional[Tensor], ) -> Tensor: """Multiply matrix with Kronecker factors for separated weight and bias. @@ -933,7 +933,7 @@ def _accumulate_gradient_covariance( def _compute_eigenvalue_correction( self, module_name: str, g: Tensor, correction: int ): - """Compute the corrected eigenvalues for the EKFAC approximation. + r"""Compute the corrected eigenvalues for the EKFAC approximation. The corrected eigenvalues are computed as :math:`\lambda_{\text{corrected}} = (Q_g^T G Q_a)^2`, where @@ -1083,6 +1083,10 @@ def _set_or_add_( Returns: The updated dictionary. + + Raises: + ValueError: If the types of the value and the dictionary entry are + incompatible. """ if key not in dictionary: dictionary[key] = value diff --git a/test/test_kfac.py b/test/test_kfac.py index e56d8dc..13d8f63 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -37,7 +37,6 @@ from curvlinops import EFLinearOperator, GGNLinearOperator from curvlinops.examples.utils import report_nonclose -from curvlinops.gradient_moments import EFLinearOperator from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType From d60c87630f7e0a7e3caec99a5080dc78a15b339c Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 16 Sep 2024 23:51:24 -0400 Subject: [PATCH 07/15] Fix docstring and lower test numerical threshold --- curvlinops/kfac.py | 1 + test/test_kfac.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index a598c5d..42c5bf9 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -60,6 +60,7 @@ class MetaEnum(EnumMeta): """Metaclass for the Enum class for desired behavior of the `in` operator.""" def __contains__(cls, item): + """Check if an item is a valid member of the Enum.""" try: cls(item) except ValueError: diff --git a/test/test_kfac.py b/test/test_kfac.py index 13d8f63..98426d6 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -412,7 +412,7 @@ def test_kfac_ef_one_datum( ) kfac_mat = kfac @ eye(kfac.shape[1]) - report_nonclose(ef, kfac_mat) + report_nonclose(ef, kfac_mat, atol=1e-7) @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) From b30930043541c762dc0c3207061fcb7fbc97b02c Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 17 Sep 2024 00:09:34 -0400 Subject: [PATCH 08/15] Fix MetaEnum docstring --- curvlinops/kfac.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 42c5bf9..48ecda2 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -59,8 +59,15 @@ class MetaEnum(EnumMeta): """Metaclass for the Enum class for desired behavior of the `in` operator.""" - def __contains__(cls, item): - """Check if an item is a valid member of the Enum.""" + def __contains__(cls, item: str) -> bool: + """Check if an item is a valid member of the Enum. + + Args: + item: The item to check. + + Returns: + ``True`` if the item is a valid member of the Enum, ``False`` otherwise. 
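Editorial note: the ``__contains__`` override documented here exists so that plain strings (e.g. a user-supplied ``fisher_type``) can be validated with the ``in`` operator against the string-valued enums. A minimal illustration, assuming ``FisherType`` is built on this metaclass:

```python
from curvlinops.kfac import FisherType

assert FisherType.MC in FisherType            # members are valid
assert FisherType.MC.value in FisherType      # so are their underlying string values
assert "not-a-fisher-type" not in FisherType  # anything else is rejected, without raising
```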
+ """ try: cls(item) except ValueError: From 725c413ace484ae1aae1620a82f85748e7e97fc6 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 17 Sep 2024 00:09:57 -0400 Subject: [PATCH 09/15] Fix tests for FOOF+eigenvalue correction --- test/test_kfac.py | 51 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/test/test_kfac.py b/test/test_kfac.py index 98426d6..02861b5 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -511,14 +511,22 @@ def test_multi_dim_output( # KFAC for deep linear network with 4d input and output params = list(model.parameters()) - kfac = KFACLinearOperator( - model, - loss_func, - params, - data, - fisher_type=fisher_type, - correct_eigenvalues=correct_eigenvalues, - ) + context = ( + raises(ValueError, match="eigenvalues") + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY + else nullcontext() + ) # EKFAC for FOOF is currently not supported + with context: + kfac = KFACLinearOperator( + model, + loss_func, + params, + data, + fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, + ) + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY: + return kfac_mat = kfac @ eye(kfac.shape[1]) # KFAC for deep linear network with 4d input and equivalent 2d output @@ -598,15 +606,26 @@ def test_expand_setting_scaling( params = list(model.parameters()) # KFAC with sum reduction + params = list(model.parameters()) loss_func = loss(reduction="sum").to(dev) - kfac_sum = KFACLinearOperator( - model, - loss_func, - params, - data, - fisher_type=fisher_type, - correct_eigenvalues=correct_eigenvalues, - ) + + context = ( + raises(ValueError, match="eigenvalues") + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY + else nullcontext() + ) # EKFAC for FOOF is currently not supported + with context: + kfac_sum = KFACLinearOperator( + model, + loss_func, + params, + data, + fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, + ) + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY: + return + # FOOF does not scale the gradient covariances, even when using a mean reduction if fisher_type != FisherType.FORWARD_ONLY: # Simulate a mean reduction by manually scaling the gradient covariances From 1e4769dff75f4dd827bbc3e77dadec004541e114 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 17 Sep 2024 00:12:55 -0400 Subject: [PATCH 10/15] Fix black --- curvlinops/kfac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 48ecda2..a747b10 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -61,7 +61,7 @@ class MetaEnum(EnumMeta): def __contains__(cls, item: str) -> bool: """Check if an item is a valid member of the Enum. - + Args: item: The item to check. 
From 31cab8ab3d92477d47e4f4fb71b7cfa1adfd2773 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 17 Sep 2024 00:20:44 -0400 Subject: [PATCH 11/15] Ignore flake8 too complex error --- test/test_kfac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_kfac.py b/test/test_kfac.py index 02861b5..8d72972 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -566,7 +566,7 @@ def test_multi_dim_output( "correct_eigenvalues", [False, True], ids=["", "eigenvalue_corrected"] ) @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) -def test_expand_setting_scaling( +def test_expand_setting_scaling( # noqa: C901 fisher_type: str, loss: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], dev: device, From 89c814f6acc6f2229cfe469fe01a7017f81a8c69 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 21 Sep 2024 14:34:21 -0400 Subject: [PATCH 12/15] Refactor inverse --- curvlinops/inverse.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index 6473a8e..1ccacb4 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -430,14 +430,12 @@ def _compute_inv_damped_eigenvalues( else: inv_damped_eigenvalues = {} for p_name, pos in param_pos.items(): - if p_name == "weight": - inv_damped_eigenvalues[pos] = ( - outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1) - ) - else: - inv_damped_eigenvalues[pos] = ggT_eigvals.add(self._damping).pow_( - -1 - ) + inv_damped_eigenvalues[pos] = ( + outer(ggT_eigvals, aaT_eigvals) + if p_name == "weight" + else ggT_eigvals + ) + inv_damped_eigenvalues[pos].add_(self._damping).pow_(-1) return inv_damped_eigenvalues def _compute_inverse_factors( @@ -536,8 +534,8 @@ def _compute_or_get_cached_inverse( return aaT_inv, ggT_inv, None if self._A._correct_eigenvalues: - aaT_inv = self._A._input_covariances_eigenvectors.get(name) - ggT_inv = self._A._gradient_covariances_eigenvectors.get(name) + aaT_eigenvecs = self._A._input_covariances_eigenvectors.get(name) + ggT_eigenvecs = self._A._gradient_covariances_eigenvectors.get(name) eigenvalues = self._A._corrected_eigenvalues.get(name) if isinstance(eigenvalues, dict): inv_damped_eigenvalues = {} @@ -545,16 +543,17 @@ def _compute_or_get_cached_inverse( inv_damped_eigenvalues[key] = val.add(self._damping).pow_(-1) elif isinstance(eigenvalues, Tensor): inv_damped_eigenvalues = eigenvalues.add(self._damping).pow_(-1) - else: - aaT = self._A._input_covariances.get(name) - ggT = self._A._gradient_covariances.get(name) - aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors( - aaT, ggT, name - ) + return aaT_eigenvecs, ggT_eigenvecs, inv_damped_eigenvalues + + aaT = self._A._input_covariances.get(name) + ggT = self._A._gradient_covariances.get(name) + aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors( + aaT, ggT, name + ) - if self._cache: - self._inverse_input_covariances[name] = aaT_inv - self._inverse_gradient_covariances[name] = ggT_inv + if self._cache: + self._inverse_input_covariances[name] = aaT_inv + self._inverse_gradient_covariances[name] = ggT_inv return aaT_inv, ggT_inv, inv_damped_eigenvalues From cf45ef5676247ad8af5bffb4d0eb55d05db638a3 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 21 Sep 2024 14:52:21 -0400 Subject: [PATCH 13/15] Address KFAC refactor suggestions --- curvlinops/kfac.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index a747b10..d41ee0e 
100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -261,7 +261,7 @@ def __init__( "Only mc_samples=1 is supported for `fisher_type != FisherType.MC`." ) if fisher_type == FisherType.FORWARD_ONLY and correct_eigenvalues: - raise ValueError( + raise NotImplementedError( "Correcting eigenvalues is not supported for FisherType.FORWARD_ONLY." ) if kfac_approx not in self._SUPPORTED_KFAC_APPROX: @@ -277,6 +277,8 @@ def __init__( self._mc_samples = mc_samples self._kfac_approx = kfac_approx self._correct_eigenvalues = correct_eigenvalues + # Initialize flag which determines whether to compute the KFAC factors or the + # eigenvalue correction in the forward-backward pass(es) self._compute_eigenvalue_correction_flag = False self._input_covariances: Dict[str, Tensor] = {} self._gradient_covariances: Dict[str, Tensor] = {} @@ -457,7 +459,9 @@ def _left_and_right_multiply( aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for biases. ggT: Gradient covariance Kronecker factor or its eigenvectors. - eigenvalues: Corrected eigenvalues for the EKFAC approximation. + eigenvalues: Eigenvalues of the (E)KFAC approximation when multiplying with + the eigendecomposition of the KFAC approximation. ``None`` for the + non-decomposed KFAC approximation. Returns: Matrix-multiplication result ``KFAC @ M_joint``. @@ -473,7 +477,7 @@ def _left_and_right_multiply( M_joint = einsum( ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l" ) - # Multiply by eigenvalues. + # Multiply (broadcasted) by eigenvalues. M_joint.mul_(eigenvalues) # Transform back to standard basis. M_joint = einsum( @@ -497,7 +501,9 @@ def _separate_left_and_right_multiply( aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for biases. ggT: Gradient covariance Kronecker factor or its eigenvectors. - eigenvalues: Corrected eigenvalues for the EKFAC approximation. + eigenvalues: Eigenvalues of the (E)KFAC approximation when multiplying with + the eigendecomposition of the KFAC approximation. ``None`` for the + non-decomposed KFAC approximation. Returns: Matrix-multiplication result ``KFAC @ M_torch``. @@ -510,12 +516,13 @@ def _separate_left_and_right_multiply( # If `eigenvalues` is not `None`, we transform to eigenbasis here M_torch[pos] = einsum(M_w, aaT, "m i j, j k -> m i k") - dims = "m j ... -> m i ..." if eigenvalues is None else "m i ... -> m j ..." - # If `eigenvalues` is not `None`, we transform to eigenbasis here - M_torch[pos] = einsum(ggT, M_torch[pos], f"i j, {dims}") + # If `eigenvalues` is not `None`, we convert to eigenbasis here + M_torch[pos] = einsum( + ggT.T if eigenvalues else ggT, M_torch[pos], "i j, m j ... -> m i ..." + ) if eigenvalues is not None: - # Multiply by eigenvalues and transform back to standard basis + # Multiply (broadcasted) by eigenvalues, convert back to original basis M_torch[pos].mul_(eigenvalues[pos]) if p_name == "weight": M_torch[pos] = einsum(M_torch[pos], aaT, "m i j, k j -> m i k") @@ -1101,7 +1108,10 @@ def _set_or_add_( elif isinstance(dictionary[key], Tensor) and isinstance(value, Tensor): dictionary[key].add_(value) else: - raise ValueError("Incompatible types for addition.") + raise ValueError( + "Incompatible types for addition: dictionary value of type " + f"{type(dictionary[key])} and value to be added of type {type(value)}." 
+ ) return dictionary @classmethod From 7422b0044749857d066d1897c8ca54f1a51187f7 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 21 Sep 2024 15:06:51 -0400 Subject: [PATCH 14/15] Fix docstring and error catching in test --- curvlinops/kfac.py | 4 ++++ test/test_kfac.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index d41ee0e..33d0d74 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -243,7 +243,11 @@ def __init__( Raises: RuntimeError: If the check for deterministic behavior fails. ValueError: If the loss function is not supported. + ValueError: If the Fisher type is not supported. + ValueError: If the KFAC approximation type is not supported. ValueError: If ``fisher_type != FisherType.MC`` and ``mc_samples != 1``. + NotImplementedError: If ``correct_eigenvalues`` and ``fisher_type == + FisherType.FORWARD_ONLY``. ValueError: If ``X`` is not a tensor and ``batch_size_fn`` is not specified. """ if not isinstance(loss_func, self._SUPPORTED_LOSSES): diff --git a/test/test_kfac.py b/test/test_kfac.py index 8d72972..33d25dc 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -512,7 +512,7 @@ def test_multi_dim_output( # KFAC for deep linear network with 4d input and output params = list(model.parameters()) context = ( - raises(ValueError, match="eigenvalues") + raises(NotImplementedError, match="eigenvalues") if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY else nullcontext() ) # EKFAC for FOOF is currently not supported From 358307020c10e0698bc5fd83112dbf71772072a3 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 21 Sep 2024 15:26:32 -0400 Subject: [PATCH 15/15] Fix test --- test/test_kfac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_kfac.py b/test/test_kfac.py index 33d25dc..99f8b88 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -610,7 +610,7 @@ def test_expand_setting_scaling( # noqa: C901 loss_func = loss(reduction="sum").to(dev) context = ( - raises(ValueError, match="eigenvalues") + raises(NotImplementedError, match="eigenvalues") if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY else nullcontext() ) # EKFAC for FOOF is currently not supported @@ -1131,7 +1131,7 @@ def test_forward_only_fisher_type( # Compute KFAC with `fisher_type=FisherType.FORWARD_ONLY context = ( - raises(ValueError, match="eigenvalues") + raises(NotImplementedError, match="eigenvalues") if correct_eigenvalues else nullcontext() ) # EKFAC for FOOF is currently not supported
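Editorial note: taken together, the series enables the following end-to-end workflow. A hypothetical sketch with a toy model; the data and shapes are made up, while the API calls mirror the tests above:

```python
import torch
from torch.nn import Linear, MSELoss

from curvlinops.inverse import KFACInverseLinearOperator
from curvlinops.kfac import KFACLinearOperator

torch.manual_seed(0)
X, y = torch.rand(8, 5), torch.rand(8, 3)
model = Linear(5, 3)
params = [p for p in model.parameters() if p.requires_grad]

# Eigenvalue-corrected KFAC (EKFAC): stores the KFAC eigenbasis and corrected eigenvalues.
ekfac = KFACLinearOperator(
    model, MSELoss(reduction="sum"), params, [(X, y)], correct_eigenvalues=True
)

# EKFAC inverses require exact damping with a single damping value.
inv_ekfac = KFACInverseLinearOperator(ekfac, damping=1e-2, use_exact_damping=True)

# Precondition a vector given in parameter-list format (one tensor per parameter).
vec = [torch.rand_like(p) for p in params]
preconditioned = inv_ekfac.torch_matvec(vec)
assert len(preconditioned) == len(params)
```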