From 04bf01adb88c27704c1bc265596d8c252279b181 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Fri, 27 Oct 2023 11:35:58 -0400 Subject: [PATCH 1/8] [ADD] Prototype for KFAC linear operator --- curvlinops/_base.py | 8 ++ curvlinops/examples/utils.py | 4 +- curvlinops/kfac.py | 211 +++++++++++++++++++++++++++++++++++ test/cases.py | 2 - test/conftest.py | 18 ++- test/kfac_cases.py | 44 ++++++++ test/test_kfac.py | 32 ++++++ 7 files changed, 315 insertions(+), 4 deletions(-) create mode 100644 curvlinops/kfac.py create mode 100644 test/kfac_cases.py create mode 100644 test/test_kfac.py diff --git a/curvlinops/_base.py b/curvlinops/_base.py index 570a3ea..942994f 100644 --- a/curvlinops/_base.py +++ b/curvlinops/_base.py @@ -1,6 +1,7 @@ """Contains functionality to analyze Hessian & GGN via matrix-free multiplication.""" from typing import Callable, Iterable, List, Optional, Tuple, Union +from warnings import warn from backpack.utils.convert_parameters import vector_to_parameter_list from numpy import ( @@ -254,6 +255,13 @@ def _preprocess(self, x: ndarray) -> List[Tensor]: Returns: Vector in list format. """ + if x.dtype != self.dtype: + warn( + f"Input vector is {x.dtype}, while linear operator is {self.dtype}. " + + f"Converting to {self.dtype}." + ) + x = x.astype(self.dtype) + x_torch = from_numpy(x).to(self._device) return vector_to_parameter_list(x_torch, self._params) diff --git a/curvlinops/examples/utils.py b/curvlinops/examples/utils.py index 040ecab..aa0cd19 100644 --- a/curvlinops/examples/utils.py +++ b/curvlinops/examples/utils.py @@ -31,5 +31,7 @@ def report_nonclose( else: for a1, a2 in zip(array1.flatten(), array2.flatten()): if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan): - print(f"{a1} ≠ {a2}") + print(f"{a1} ≠ {a2} (ratio {a1 / a2:.5f})") + print(f"Max: {array1.max():.5f}, {array2.max():.5f}") + print(f"Min: {array1.min():.5f}, {array2.min():.5f}") raise ValueError("Compared arrays don't match.") diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py new file mode 100644 index 0000000..01401d9 --- /dev/null +++ b/curvlinops/kfac.py @@ -0,0 +1,211 @@ +"""Linear operator for the Fisher/GGN's Kronecker-factored approximation. + +Kronecker-factored approximate curvature was originally introduced for MLPs in + +- Martens, J., & Grosse, R. (2015). Optimizing neural networks with Kronecker-factored + approximate curvature. International Conference on Machine Learning (ICML). + +and extended to CNNs in + +- Grosse, R., & Martens, J. (2016). A kronecker-factored approximate Fisher matrix for + convolution layers. International Conference on Machine Learning (ICML). +""" + +from __future__ import annotations + +from math import sqrt +from typing import Dict, Iterable, List, Tuple, Union + +from numpy import ndarray +from torch import Generator, Tensor, einsum, randn +from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter +from torch.utils.hooks import RemovableHandle + +from curvlinops._base import _LinearOperator + + +class KFACLinearOperator(_LinearOperator): + """Linear operator to multiply with Fisher/GGN's KFAC approximation.""" + + _SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss) + + def __init__( + self, + model_func: Module, + loss_func: Union[MSELoss, CrossEntropyLoss], + params: List[Parameter], + data: Iterable[Tuple[Tensor, Tensor]], + progressbar: bool = False, + check_deterministic: bool = True, + shape: Union[Tuple[int, int], None] = None, + seed: int = 2147483647, + mc_samples: int = 1, + ): + if not isinstance(loss_func, self._SUPPORTED_LOSSES): + raise ValueError( + f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}." + ) + + # TODO Check for only linear layers + idx = 0 + for mod in model_func.modules(): + if len(list(mod.modules())) == 1 and list(mod.parameters()): + assert isinstance(mod, Linear) + assert mod.bias is not None + assert params[idx].data_ptr() == mod.weight.data_ptr() + assert params[idx + 1].data_ptr() == mod.bias.data_ptr() + idx += 2 + + self._seed = seed + self._generator: Union[None, Generator] = None + self._mc_samples = mc_samples + self._input_covariances: Dict[Tuple[int, ...], Tensor] = {} + self._gradient_covariances: Dict[Tuple[int, ...], Tensor] = {} + + super().__init__( + model_func, + loss_func, + params, + data, + progressbar=progressbar, + check_deterministic=check_deterministic, + shape=shape, + ) + + def _matvec(self, x: ndarray) -> ndarray: + """Loop over all batches in the data and apply the matrix to vector x. + + Create and seed the random number generator. + + Args: + x: Vector for multiplication. + + Returns: + Matrix-multiplication result ``mat @ x``. + """ + if not self._input_covariances and not self._gradient_covariances: + self._compute_kfac() + + x_torch = super()._preprocess(x) + assert len(x_torch) % 2 == 0 + + for idx in range(len(x_torch) // 2): + idx_weight, idx_bias = 2 * idx, 2 * idx + 1 + weight, bias = self._params[idx_weight], self._params[idx_bias] + x_weight, x_bias = x_torch[idx_weight], x_torch[idx_bias] + + aaT = self._input_covariances[(weight.data_ptr(), bias.data_ptr())] + ggT = self._gradient_covariances[(weight.data_ptr(), bias.data_ptr())] + + x_torch[idx_weight] = ggT @ x_weight @ aaT + x_torch[idx_bias] = ggT @ x_bias + + return super()._postprocess(x_torch) + + def _adjoint(self) -> KFACLinearOperator: + """Return the linear operator representing the adjoint. + + The KFAC approximation is real symmetric, and hence self-adjoint. + + Returns: + Self. + """ + return self + + def _compute_kfac(self): + # install forward and backward hooks on layers + hook_handles: List[RemovableHandle] = [] + + modules = [] + for mod in self._model_func.modules(): + if len(list(mod.modules())) == 1 and list(mod.parameters()): + assert isinstance(mod, Linear) + modules.append(mod) + hook_handles.extend( + mod.register_forward_pre_hook(self._hook_accumulate_input_covariance) + for mod in modules + ) + hook_handles.extend( + mod.register_full_backward_hook(self._hook_accumulate_gradient_covariance) + for mod in modules + ) + + # loop over data set + if self._generator is None: + self._generator = Generator(device=self._device) + self._generator.manual_seed(self._seed) + + for X, _ in self._loop_over_data(desc="Computing KFAC matrices"): + output = self._model_func(X) + + for mc in range(self._mc_samples): + y_sampled = self.draw_label(output) + loss = self._loss_func(output, y_sampled) + loss.backward(retain_graph=mc != self._mc_samples - 1) + + # remove hooks + for handle in hook_handles: + handle.remove() + + def draw_label(self, output: Tensor) -> Tensor: + if isinstance(self._loss_func, MSELoss): + std = { + "sum": sqrt(1.0 / 2.0), + "mean": sqrt(output.shape[1:].numel() / 2.0), + }[self._loss_func.reduction] + perturbation = std * randn( + output.shape, + device=output.device, + dtype=output.dtype, + generator=self._generator, + ) + return output.clone().detach() + perturbation + else: + raise NotImplementedError + + def _hook_accumulate_gradient_covariance( + self, module: Module, grad_input: Tuple[Tensor], grad_output: Tuple[Tensor] + ): + assert len(grad_output) == 1 + + if isinstance(module, Linear): + grad_out = grad_output[0].data.detach() + assert grad_out.ndim == 2 + idx = tuple(p.data_ptr() for p in module.parameters()) + + batch_size = grad_output[0].shape[0] + correction = { + "sum": 1.0 / self._mc_samples, + "mean": batch_size**2 / (self._N_data * self._mc_samples), + }[self._loss_func.reduction] + + covariance = einsum("bi,bj->ij", grad_out, grad_out).mul_(correction) + if idx not in self._gradient_covariances: + self._gradient_covariances[idx] = covariance + else: + self._gradient_covariances[idx].add_(covariance) + else: + raise NotImplementedError + + def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor]): + assert len(inputs) == 1 + + if isinstance(module, Linear): + module_input = inputs[0].data.detach() + assert module_input.ndim == 2 + + idx = tuple(p.data_ptr() for p in module.parameters()) + correction = { + "sum": 1.0 / self._N_data, + "mean": 1.0 / self._N_data, + }[self._loss_func.reduction] + + covariance = einsum("bi,bj->ij", module_input, module_input).mul_( + correction + ) + if idx not in self._input_covariances: + self._input_covariances[idx] = covariance + else: + self._input_covariances[idx].add_(covariance) + else: + raise NotImplementedError diff --git a/test/cases.py b/test/cases.py index 071b332..0b92fa2 100644 --- a/test/cases.py +++ b/test/cases.py @@ -17,8 +17,6 @@ DEVICES = get_available_devices() DEVICES_IDS = [f"dev={d}" for d in DEVICES] -LINOPS = [] - # Add test cases here CASES_NO_DEVICE = [ ############################################################################### diff --git a/test/conftest.py b/test/conftest.py index 408074e..ed6043e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,11 +1,13 @@ """Contains pytest fixtures that are visible by other files.""" from test.cases import ADJOINT_CASES, CASES, NON_DETERMINISTIC_CASES +from test.kfac_cases import KFAC_EXPAND_EXACT_CASES from typing import Callable, Dict, Iterable, List, Tuple from numpy import random from pytest import fixture -from torch import Tensor, manual_seed +from torch import Module, Tensor, manual_seed +from torch.nn import MSELoss def initialize_case( @@ -56,3 +58,17 @@ def non_deterministic_case( @fixture(params=ADJOINT_CASES) def adjoint(request) -> bool: return request.param + + +@fixture(params=KFAC_EXPAND_EXACT_CASES) +def kfac_expand_exact_case( + request, +) -> Tuple[Module, MSELoss, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]: + """Prepare a test case for which KFAC-expand equals the GGN. + + Yields: + A neural network, the mean-squared error function, a list of parameters, and + a data set. + """ + kfac_case = request.param + yield initialize_case(kfac_case) diff --git a/test/kfac_cases.py b/test/kfac_cases.py new file mode 100644 index 0000000..4b96f3c --- /dev/null +++ b/test/kfac_cases.py @@ -0,0 +1,44 @@ +"""Contains test cases for the KFAC linear operator.""" + +from functools import partial +from test.utils import get_available_devices, regression_targets + +from torch import rand +from torch.nn import Linear, MSELoss, Sequential + +# Add test cases here, devices and loss function with different reductions will be +# added automatically below +KFAC_EXPAND_EXACT_CASES_NO_DEVICE_NO_LOSS_FUNC = [ + ############################################################################### + # REGRESSION # + ############################################################################### + # deep linear network with scalar output + { + "model_func": lambda: Sequential(Linear(6, 3), Linear(3, 1)), + "data": lambda: [ + (rand(2, 6), regression_targets((2, 1))), + (rand(5, 6), regression_targets((5, 1))), + ], + "seed": 0, + }, + # deep linear network with vector output + { + "model_func": lambda: Sequential(Linear(5, 4), Linear(4, 3)), + "data": lambda: [ + (rand(1, 5), regression_targets((1, 3))), + (rand(7, 5), regression_targets((7, 3))), + ], + "seed": 0, + }, +] + +KFAC_EXPAND_EXACT_CASES = [] +for case in KFAC_EXPAND_EXACT_CASES_NO_DEVICE_NO_LOSS_FUNC: + for device in get_available_devices(): + for reduction in ["mean", "sum"]: + case_with_device_and_loss_func = { + **case, + "device": device, + "loss_func": partial(MSELoss, reduction=reduction), + } + KFAC_EXPAND_EXACT_CASES.append(case_with_device_and_loss_func) diff --git a/test/test_kfac.py b/test/test_kfac.py new file mode 100644 index 0000000..42ff459 --- /dev/null +++ b/test/test_kfac.py @@ -0,0 +1,32 @@ +"""Contains tests for ``curvlinops.kfac``.""" + +from typing import Iterable, List, Tuple + +from numpy import eye +from scipy.linalg import block_diag +from torch import Tensor +from torch.nn import Module, MSELoss, Parameter + +from curvlinops.examples.utils import report_nonclose +from curvlinops.ggn import GGNLinearOperator +from curvlinops.kfac import KFACLinearOperator + + +def test_kfac( + kfac_case: Tuple[Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]] +): + model, loss_func, params, data = kfac_case + + ggn_blocks = [] # list of per-parameter GGNs + for param in params: + ggn = GGNLinearOperator(model, loss_func, [param], data) + ggn_blocks.append(ggn @ eye(ggn.shape[1])) + ggn = block_diag(*ggn_blocks) + + kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000) + kfac_mat = kfac @ eye(kfac.shape[1]) + + atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction] + rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction] + + report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) From 8b377a788a6806615c8fc9be4c0020162062f65f Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Fri, 27 Oct 2023 12:05:31 -0400 Subject: [PATCH 2/8] [DOC] Progress on documentation --- curvlinops/kfac.py | 108 +++++++++++++++++++++++++++++++++------------ 1 file changed, 80 insertions(+), 28 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 01401d9..13403a9 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -148,6 +148,22 @@ def _compute_kfac(self): handle.remove() def draw_label(self, output: Tensor) -> Tensor: + r"""Draw a sample from the model's predictive distribution. + + The model's distribution is implied by the (negative log likelihood) loss + function. For instance, ``MSELoss`` implies a Gaussian distribution with + constant variance, and ``CrossEntropyLoss`` implies a categorical distribution. + + Args: + output: The model's prediction + :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`. + + Returns: + A sample :math:`\{\mathbf{y}_n\}_{n=1}^N` drawn from the model's predictive + distribution :math:`p(\mathbf{y} \mid \mathbf{x}, \mathbf{\theta})`. Has + the same shape as the labels that would be fed into the loss function + together with ``output``. + """ if isinstance(self._loss_func, MSELoss): std = { "sum": sqrt(1.0 / 2.0), @@ -166,46 +182,82 @@ def draw_label(self, output: Tensor) -> Tensor: def _hook_accumulate_gradient_covariance( self, module: Module, grad_input: Tuple[Tensor], grad_output: Tuple[Tensor] ): - assert len(grad_output) == 1 + """Backward hook that accumulates the output-gradient covariance of a layer. - if isinstance(module, Linear): - grad_out = grad_output[0].data.detach() - assert grad_out.ndim == 2 - idx = tuple(p.data_ptr() for p in module.parameters()) + Updates ``self._gradient_covariances``. - batch_size = grad_output[0].shape[0] + Args: + module: The layer on which the hook is called. + grad_input: The gradient of the loss w.r.t. the layer's inputs. + grad_output: The gradient of the loss w.r.t. the layer's outputs. + + Raises: + ValueError: If ``grad_output`` is not a 1-tuple. + NotImplementedError: If a layer uses weight sharing. + NotImplementedError: If the layer is not supported. + """ + if len(grad_output) != 1: + raise ValueError( + f"Expected grad_output to be a 1-tuple, got {len(grad_output)}." + ) + g = grad_output[0].data.detach() + + if isinstance(module, Linear): + if g.ndim != 2: + # TODO Support weight sharing + raise NotImplementedError( + "Only 2d grad_outputs are supported for linear layers. " + + f"Got {g.ndim}d." + ) + + batch_size = g.shape[0] correction = { "sum": 1.0 / self._mc_samples, "mean": batch_size**2 / (self._N_data * self._mc_samples), }[self._loss_func.reduction] + covariance = einsum("bi,bj->ij", g, g).mul_(correction) + else: + # TODO Support convolutions + raise NotImplementedError(f"Layer of type {type(module)} is unsupported.") - covariance = einsum("bi,bj->ij", grad_out, grad_out).mul_(correction) - if idx not in self._gradient_covariances: - self._gradient_covariances[idx] = covariance - else: - self._gradient_covariances[idx].add_(covariance) + idx = tuple(p.data_ptr() for p in module.parameters()) + if idx not in self._gradient_covariances: + self._gradient_covariances[idx] = covariance else: - raise NotImplementedError + self._gradient_covariances[idx].add_(covariance) def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor]): - assert len(inputs) == 1 + """Pre-forward hook that accumulates the input covariance of a layer. + + Updates ``self._input_covariances``. + + Args: + module: Module on which the hook is called. + inputs: Inputs to the module. + + Raises: + ValueError: If the module has multiple inputs. + NotImplementedError: If a layer uses weight sharing. + NotImplementedError: If a module is not supported. + """ + if len(inputs) != 1: + raise ValueError("Modules with multiple inputs are not supported.") + x = inputs[0].data.detach() if isinstance(module, Linear): - module_input = inputs[0].data.detach() - assert module_input.ndim == 2 + if x.ndim != 2: + # TODO Support weight sharing + raise NotImplementedError( + f"Only 2d inputs are supported for linear layers. Got {x.ndim}d." + ) - idx = tuple(p.data_ptr() for p in module.parameters()) - correction = { - "sum": 1.0 / self._N_data, - "mean": 1.0 / self._N_data, - }[self._loss_func.reduction] + covariance = einsum("bi,bj->ij", x, x).div_(self._N_data) + else: + # TODO Support convolutions + raise NotImplementedError(f"Layer of type {type(module)} is unsupported.") - covariance = einsum("bi,bj->ij", module_input, module_input).mul_( - correction - ) - if idx not in self._input_covariances: - self._input_covariances[idx] = covariance - else: - self._input_covariances[idx].add_(covariance) + idx = tuple(p.data_ptr() for p in module.parameters()) + if idx not in self._input_covariances: + self._input_covariances[idx] = covariance else: - raise NotImplementedError + self._input_covariances[idx].add_(covariance) From 87b044c03192af0cb0d00a38c0722b88694b4a99 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Fri, 27 Oct 2023 15:45:12 -0400 Subject: [PATCH 3/8] [DOC] Describe KFAC and its limitations --- curvlinops/__init__.py | 2 + curvlinops/kfac.py | 141 +++++++++++++++++++++++++++++++++-------- docs/rtd/linops.rst | 3 + test/conftest.py | 4 +- 4 files changed, 123 insertions(+), 27 deletions(-) diff --git a/curvlinops/__init__.py b/curvlinops/__init__.py index 0e6e00a..f59ab6c 100644 --- a/curvlinops/__init__.py +++ b/curvlinops/__init__.py @@ -7,6 +7,7 @@ from curvlinops.hessian import HessianLinearOperator from curvlinops.inverse import CGInverseLinearOperator, NeumannInverseLinearOperator from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator +from curvlinops.kfac import KFACLinearOperator from curvlinops.papyan2020traces.spectrum import ( LanczosApproximateLogSpectrumCached, LanczosApproximateSpectrumCached, @@ -22,6 +23,7 @@ "GGNLinearOperator", "EFLinearOperator", "FisherMCLinearOperator", + "KFACLinearOperator", "JacobianLinearOperator", "TransposedJacobianLinearOperator", "CGInverseLinearOperator", diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 13403a9..5b3a18d 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -18,21 +18,61 @@ from numpy import ndarray from torch import Generator, Tensor, einsum, randn -from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter +from torch.nn import Linear, Module, MSELoss, Parameter from torch.utils.hooks import RemovableHandle from curvlinops._base import _LinearOperator class KFACLinearOperator(_LinearOperator): - """Linear operator to multiply with Fisher/GGN's KFAC approximation.""" + r"""Linear operator to multiply with the Fisher/GGN's KFAC approximation. - _SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss) + KFAC approximates the per-layer Fisher/GGN with a Kronecker product: + Consider a weight matrix :math:`\mathbf{W}` and a bias vector :math:`\mathbf{b}` + in a single layer. The layer's Fisher :math:`\mathbf{F}(\mathbf{\theta})` for + + .. math:: + \mathbf{\theta} + = + \begin{pmatrix} + \mathrm{vec}(\mathbf{W}) \\ \mathbf{b} + \end{pmatrix} + + where :math:`\mathrm{vec}` denotes column-stacking is approximated as + + .. math:: + \mathbf{F}(\mathbf{\theta}) + \approx + \mathbf{A}_{(\text{KFAC})} \otimes \mathbf{B}_{(\text{KFAC})} + + (see :class:`curvlinops.FisherMCLinearOperator` for the Fisher's definition). + Loosely speaking, the first Kronecker factor is the un-centered covariance of the + inputs to a layer. The second Kronecker factor is the un-centered covariance of + 'would-be' gradients w.r.t. the layer's output. Those 'would-be' gradients result + from sampling labels from the model's distribution and computing their gradients. + + The basic version of KFAC for MLPs was introduced in + + - Martens, J., & Grosse, R. (2015). Optimizing neural networks with + Kronecker-factored approximate curvature. ICML. + + and later generalized to convolutions in + + - Grosse, R., & Martens, J. (2016). A kronecker-factored approximate Fisher + matrix for convolution layers. ICML. + + Attributes: + _SUPPORTED_LOSSES: Tuple of supported loss functions. + _SUPPORTED_MODULES: Tuple of supported layers. + """ + + _SUPPORTED_LOSSES = (MSELoss,) + _SUPPORTED_MODULES = (Linear,) def __init__( self, model_func: Module, - loss_func: Union[MSELoss, CrossEntropyLoss], + loss_func: MSELoss, params: List[Parameter], data: Iterable[Tuple[Tensor, Tensor]], progressbar: bool = False, @@ -41,20 +81,71 @@ def __init__( seed: int = 2147483647, mc_samples: int = 1, ): + """Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN. + + Warning: + If the model's parameters change, e.g. during training, you need to + create a fresh instance of this object. This is because, for performance + reasons, the Kronecker factors are computed once and cached during the + first matrix-vector product. They will thus become outdated if the model + changes. + + Warning: + This is an early proto-type with many limitations: + + - Parameters must be in the same order as the model's parameters. + - Only linear layers with bias are supported. + - Weights and biases are treated separately. + - No weight sharing is supported. + - Only the Monte-Carlo sampled version is supported. + - Only the ``'expand'`` setting is supported. + + Args: + model_func: The neural network. Must consist of modules. + loss_func: The loss function. + params: The parameters defining the Fisher/GGN that will be approximated + through KFAC. + data: A data loader containing the data of the Fisher/GGN. + progressbar: Whether to show a progress bar when computing the Kronecker + factors. Defaults to ``False``. + check_deterministic: Whether to check that the linear operator is + deterministic. Defaults to ``True``. + shape: The shape of the linear operator. If ``None``, it will be inferred + from the parameters. Defaults to ``None``. + seed: The seed for the random number generator used to draw labels + from the model's predictive distribution. Defaults to ``2147483647``. + mc_samples: The number of Monte-Carlo samples to use per data point. + Defaults to ``1``. + """ if not isinstance(loss_func, self._SUPPORTED_LOSSES): raise ValueError( f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}." ) - # TODO Check for only linear layers + self.hooked_modules: List[str] = [] idx = 0 - for mod in model_func.modules(): - if len(list(mod.modules())) == 1 and list(mod.parameters()): - assert isinstance(mod, Linear) - assert mod.bias is not None - assert params[idx].data_ptr() == mod.weight.data_ptr() - assert params[idx + 1].data_ptr() == mod.bias.data_ptr() + for name, mod in model_func.named_modules(): + if isinstance(mod, self._SUPPORTED_MODULES): + # TODO Support bias-free layers + if mod.bias is None: + raise NotImplementedError( + "Bias-free linear layers are not yet supported." + ) + # TODO Support arbitrary orders and sub-sets of parameters + if ( + params[idx].data_ptr() != mod.weight.data_ptr() + or params[idx + 1].data_ptr() != mod.bias.data_ptr() + ): + raise NotImplementedError( + "KFAC parameters must be in same order as model parameters " + + "for now." + ) idx += 2 + self.hooked_modules.append(name) + if idx != len(params): + raise NotImplementedError( + "Could not identify all parameters with supported layers." + ) self._seed = seed self._generator: Union[None, Generator] = None @@ -113,29 +204,28 @@ def _adjoint(self) -> KFACLinearOperator: return self def _compute_kfac(self): - # install forward and backward hooks on layers + """Compute and cache KFAC's Kronecker factors for future ``matvec``s.""" + # install forward and backward hooks hook_handles: List[RemovableHandle] = [] - - modules = [] - for mod in self._model_func.modules(): - if len(list(mod.modules())) == 1 and list(mod.parameters()): - assert isinstance(mod, Linear) - modules.append(mod) hook_handles.extend( - mod.register_forward_pre_hook(self._hook_accumulate_input_covariance) - for mod in modules + self._model_func.get_submodule(mod).register_forward_pre_hook( + self._hook_accumulate_input_covariance + ) + for mod in self.hooked_modules ) hook_handles.extend( - mod.register_full_backward_hook(self._hook_accumulate_gradient_covariance) - for mod in modules + self._model_func.get_submodule(mod).register_full_backward_hook( + self._hook_accumulate_gradient_covariance + ) + for mod in self.hooked_modules ) - # loop over data set + # loop over data set, computing the Kronecker factors if self._generator is None: self._generator = Generator(device=self._device) self._generator.manual_seed(self._seed) - for X, _ in self._loop_over_data(desc="Computing KFAC matrices"): + for X, _ in self._loop_over_data(desc="KFAC matrices"): output = self._model_func(X) for mc in range(self._mc_samples): @@ -143,7 +233,8 @@ def _compute_kfac(self): loss = self._loss_func(output, y_sampled) loss.backward(retain_graph=mc != self._mc_samples - 1) - # remove hooks + # clean up + self._model_func.zero_grad() for handle in hook_handles: handle.remove() diff --git a/docs/rtd/linops.rst b/docs/rtd/linops.rst index e85fe09..3fc7166 100644 --- a/docs/rtd/linops.rst +++ b/docs/rtd/linops.rst @@ -20,6 +20,9 @@ Fisher (approximate) .. autoclass:: curvlinops.FisherMCLinearOperator :members: __init__ +.. autoclass:: curvlinops.KFACLinearOperator + :members: __init__ + Uncentered gradient covariance (empirical Fisher) ------------------------------------------------- diff --git a/test/conftest.py b/test/conftest.py index ed6043e..b3f4726 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -6,8 +6,8 @@ from numpy import random from pytest import fixture -from torch import Module, Tensor, manual_seed -from torch.nn import MSELoss +from torch import Tensor, manual_seed +from torch.nn import Module, MSELoss def initialize_case( From 11f60ac8662d79064834127eda7e5ce1e0325d26 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Fri, 27 Oct 2023 16:14:22 -0400 Subject: [PATCH 4/8] [FIX] Name of fixture --- test/test_kfac.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_kfac.py b/test/test_kfac.py index 42ff459..0f8300e 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -13,9 +13,11 @@ def test_kfac( - kfac_case: Tuple[Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]] + kfac_expand_exact_case: Tuple[ + Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] + ] ): - model, loss_func, params, data = kfac_case + model, loss_func, params, data = kfac_expand_exact_case ggn_blocks = [] # list of per-parameter GGNs for param in params: From 26234c94205fb684bae9753dde94d01ee9b5c355 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Fri, 27 Oct 2023 16:16:23 -0400 Subject: [PATCH 5/8] [FIX] Darglint --- curvlinops/kfac.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 5b3a18d..83b9c21 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -116,6 +116,14 @@ def __init__( from the model's predictive distribution. Defaults to ``2147483647``. mc_samples: The number of Monte-Carlo samples to use per data point. Defaults to ``1``. + + Raises: + ValueError: If the loss function is not supported. + NotImplementedError: If the parameters are not in the same order as the + model's parameters. + NotImplementedError: If the model contains bias-free linear layers. + NotImplementedError: If any parameter cannot be identified with a supported + layer. """ if not isinstance(loss_func, self._SUPPORTED_LOSSES): raise ValueError( From 7161f21e3927329e4b06c23b4cb604aeba169d03 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Fri, 27 Oct 2023 16:24:09 -0400 Subject: [PATCH 6/8] [FIX] Darglint See https://github.com/terrencepreilly/darglint/issues/53 --- curvlinops/kfac.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 83b9c21..409b977 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -258,10 +258,14 @@ def draw_label(self, output: Tensor) -> Tensor: :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`. Returns: - A sample :math:`\{\mathbf{y}_n\}_{n=1}^N` drawn from the model's predictive + A sample + :math:`\{\mathbf{y}_n\}_{n=1}^N` drawn from the model's predictive distribution :math:`p(\mathbf{y} \mid \mathbf{x}, \mathbf{\theta})`. Has the same shape as the labels that would be fed into the loss function together with ``output``. + + Raises: + NotImplementedError: If the loss function is not supported. """ if isinstance(self._loss_func, MSELoss): std = { From 5190c9f0463f715f2448c81ca598f0476c3f9251 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Mon, 30 Oct 2023 10:44:48 -0400 Subject: [PATCH 7/8] [DOC] Show supported layers in error message --- curvlinops/kfac.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 409b977..ca65a8c 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -321,7 +321,10 @@ def _hook_accumulate_gradient_covariance( covariance = einsum("bi,bj->ij", g, g).mul_(correction) else: # TODO Support convolutions - raise NotImplementedError(f"Layer of type {type(module)} is unsupported.") + raise NotImplementedError( + f"Layer of type {type(module)} is unsupported. " + + f"Supported layers: {self._SUPPORTED_MODULES}." + ) idx = tuple(p.data_ptr() for p in module.parameters()) if idx not in self._gradient_covariances: From fddfd3b7c4f3ae02a14cb3e90add57bee0330bef Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Mon, 30 Oct 2023 11:46:57 -0400 Subject: [PATCH 8/8] [DOC] Improve correctness --- curvlinops/kfac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index ca65a8c..4588098 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -98,7 +98,7 @@ def __init__( - Weights and biases are treated separately. - No weight sharing is supported. - Only the Monte-Carlo sampled version is supported. - - Only the ``'expand'`` setting is supported. + - Only the ``'expand'`` approximation is supported. Args: model_func: The neural network. Must consist of modules.