diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index 4588098..10443f8 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -93,7 +93,6 @@ def __init__(
         Warning:
             This is an early proto-type with many limitations:
 
-                - Parameters must be in the same order as the model's parameters.
                 - Only linear layers with bias are supported.
                 - Weights and biases are treated separately.
                 - No weight sharing is supported.
@@ -119,8 +118,6 @@ def __init__(
 
         Raises:
             ValueError: If the loss function is not supported.
-            NotImplementedError: If the parameters are not in the same order as the
-                model's parameters.
             NotImplementedError: If the model contains bias-free linear layers.
             NotImplementedError: If any parameter cannot be identified with a supported
                 layer.
@@ -130,8 +127,8 @@ def __init__(
                 f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
             )
 
+        self.param_ids = [p.data_ptr() for p in params]
         self.hooked_modules: List[str] = []
-        idx = 0
         for name, mod in model_func.named_modules():
             if isinstance(mod, self._SUPPORTED_MODULES):
                 # TODO Support bias-free layers
@@ -139,18 +136,15 @@ def __init__(
                     raise NotImplementedError(
                         "Bias-free linear layers are not yet supported."
                     )
-                # TODO Support arbitrary orders and sub-sets of parameters
-                if (
-                    params[idx].data_ptr() != mod.weight.data_ptr()
-                    or params[idx + 1].data_ptr() != mod.bias.data_ptr()
-                ):
-                    raise NotImplementedError(
-                        "KFAC parameters must be in same order as model parameters "
-                        + "for now."
-                    )
-                idx += 2
                 self.hooked_modules.append(name)
-        if idx != len(params):
+
+        # check that all parameters are in hooked modules
+        hooked_param_ids = {
+            p.data_ptr()
+            for mod in self.hooked_modules
+            for p in model_func.get_submodule(mod).parameters()
+        }
+        if set(self.param_ids) != hooked_param_ids:
             raise NotImplementedError(
                 "Could not identify all parameters with supported layers."
             )
@@ -188,9 +182,12 @@ def _matvec(self, x: ndarray) -> ndarray:
         x_torch = super()._preprocess(x)
         assert len(x_torch) % 2 == 0
 
-        for idx in range(len(x_torch) // 2):
-            idx_weight, idx_bias = 2 * idx, 2 * idx + 1
-            weight, bias = self._params[idx_weight], self._params[idx_bias]
+        for mod_name in self.hooked_modules:
+            mod = self._model_func.get_submodule(mod_name)
+            weight, bias = mod.weight, mod.bias
+
+            idx_weight = self.param_ids.index(weight.data_ptr())
+            idx_bias = self.param_ids.index(bias.data_ptr())
             x_weight, x_bias = x_torch[idx_weight], x_torch[idx_bias]
 
             aaT = self._input_covariances[(weight.data_ptr(), bias.data_ptr())]
diff --git a/test/test_kfac.py b/test/test_kfac.py
index 0f8300e..365ce3e 100644
--- a/test/test_kfac.py
+++ b/test/test_kfac.py
@@ -3,8 +3,9 @@
 from typing import Iterable, List, Tuple
 
 from numpy import eye
+from pytest import mark
 from scipy.linalg import block_diag
-from torch import Tensor
+from torch import Tensor, randperm
 from torch.nn import Module, MSELoss, Parameter
 
 from curvlinops.examples.utils import report_nonclose
@@ -12,13 +13,26 @@
 from curvlinops.kfac import KFACLinearOperator
 
 
+@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
 def test_kfac(
     kfac_expand_exact_case: Tuple[
         Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
-    ]
+    ],
+    shuffle: bool,
 ):
+    """Test the KFAC implementation against the exact GGN.
+
+    Args:
+        kfac_expand_exact_case: A fixture that returns a model, loss function, list of
+            parameters, and data.
+        shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
+    """
     model, loss_func, params, data = kfac_expand_exact_case
+    if shuffle:
+        permutation = randperm(len(params))
+        params = [params[i] for i in permutation]
+
     ggn_blocks = []  # list of per-parameter GGNs
     for param in params:
         ggn = GGNLinearOperator(model, loss_func, [param], data)