From f873494741d2a542ee2a47e73be32d08e5d752d3 Mon Sep 17 00:00:00 2001
From: runame
Date: Wed, 15 May 2024 23:04:38 +0100
Subject: [PATCH] Add state dict functionality to (inverse) KFAC linear operator

---
 curvlinops/inverse.py |  69 ++++++++++++++++++-
 curvlinops/kfac.py    | 160 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 226 insertions(+), 3 deletions(-)

diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py
index 9ad4e43..3a357bf 100644
--- a/curvlinops/inverse.py
+++ b/curvlinops/inverse.py
@@ -1,7 +1,7 @@
 """Implements linear operator inverses."""
 
 from math import sqrt
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from warnings import warn
 
 from einops import einsum, rearrange
@@ -695,3 +695,70 @@ def _matmat(self, M: ndarray) -> ndarray:
         M_torch = self._A._preprocess(M)
         M_torch = self.torch_matmat(M_torch)
         return self._A._postprocess(M_torch)
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Return the state of the inverse KFAC linear operator.
+
+        Returns:
+            State dictionary.
+        """
+        return {
+            "A": self._A.state_dict(),
+            # Attributes
+            "damping": self._damping,
+            "use_heuristic_damping": self._use_heuristic_damping,
+            "min_damping": self._min_damping,
+            "use_exact_damping": self._use_exact_damping,
+            "cache": self._cache,
+            "retry_double_precision": self._retry_double_precision,
+            # Inverse Kronecker factors (if computed and cached)
+            "inverse_input_covariances": self._inverse_input_covariances,
+            "inverse_gradient_covariances": self._inverse_gradient_covariances,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        """Load the state of the inverse KFAC linear operator.
+
+        Args:
+            state_dict: State dictionary.
+        """
+        self._A.load_state_dict(state_dict["A"])
+
+        # Set attributes
+        self._damping = state_dict["damping"]
+        self._use_heuristic_damping = state_dict["use_heuristic_damping"]
+        self._min_damping = state_dict["min_damping"]
+        self._use_exact_damping = state_dict["use_exact_damping"]
+        self._cache = state_dict["cache"]
+        self._retry_double_precision = state_dict["retry_double_precision"]
+
+        # Set inverse Kronecker factors (if computed and cached)
+        self._inverse_input_covariances = state_dict["inverse_input_covariances"]
+        self._inverse_gradient_covariances = state_dict["inverse_gradient_covariances"]
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, Any],
+        A: KFACLinearOperator,
+    ) -> "KFACInverseLinearOperator":
+        """Load an inverse KFAC linear operator from a state dictionary.
+
+        Args:
+            state_dict: State dictionary.
+            A: ``KFACLinearOperator`` whose inverse is formed.
+
+        Returns:
+            Linear operator of inverse KFAC approximation.
+        """
+        inv_kfac = cls(
+            A,
+            damping=state_dict["damping"],
+            use_heuristic_damping=state_dict["use_heuristic_damping"],
+            min_damping=state_dict["min_damping"],
+            use_exact_damping=state_dict["use_exact_damping"],
+            cache=state_dict["cache"],
+            retry_double_precision=state_dict["retry_double_precision"],
+        )
+        inv_kfac.load_state_dict(state_dict)
+        return inv_kfac
diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index 4909d8e..a4185cb 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -21,7 +21,7 @@
 from collections.abc import MutableMapping
 from functools import partial
 from math import sqrt
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 from einops import einsum, rearrange, reduce
 from numpy import ndarray
@@ -111,7 +111,7 @@ class KFACLinearOperator(_LinearOperator):
     def __init__(  # noqa: C901
         self,
         model_func: Module,
-        loss_func: MSELoss,
+        loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
         params: List[Parameter],
         data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
         progressbar: bool = False,
@@ -1070,3 +1070,156 @@ def frobenius_norm(self) -> Tensor:
                 )
         self._frobenius_norm.sqrt_()
         return self._frobenius_norm
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Return the state of the KFAC linear operator.
+
+        Returns:
+            State dictionary.
+        """
+        loss_type = {
+            MSELoss: "MSELoss",
+            CrossEntropyLoss: "CrossEntropyLoss",
+            BCEWithLogitsLoss: "BCEWithLogitsLoss",
+        }[type(self._loss_func)]
+        return {
+            # Model and loss function
+            "model_func_state_dict": self._model_func.state_dict(),
+            "loss_type": loss_type,
+            "loss_reduction": self._loss_func.reduction,
+            # Attributes
+            "progressbar": self._progressbar,
+            "shape": self._shape,
+            "seed": self._seed,
+            "fisher_type": self._fisher_type,
+            "mc_samples": self._mc_samples,
+            "kfac_approx": self._kfac_approx,
+            "loss_average": self._loss_average,
+            "num_per_example_loss_terms": self._num_per_example_loss_terms,
+            "separate_weight_and_bias": self._separate_weight_and_bias,
+            "num_data": self._N_data,
+            # Kronecker factors (if computed)
+            "input_covariances": self._input_covariances,
+            "gradient_covariances": self._gradient_covariances,
+            # Properties (not necessarily computed)
+            "trace": self._trace,
+            "det": self._det,
+            "logdet": self._logdet,
+            "frobenius_norm": self._frobenius_norm,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        """Load the state of the KFAC linear operator.
+
+        The loss function type and reduction are validated before anything is
+        modified, so a mismatch leaves the model and this operator untouched.
+
+        Args:
+            state_dict: State dictionary.
+
+        Raises:
+            ValueError: If the loss function does not match the state dict.
+            ValueError: If the loss function reduction does not match the state dict.
+        """
+        # Verify that the loss function and its reduction match the state dict
+        loss_func_type = {
+            "MSELoss": MSELoss,
+            "CrossEntropyLoss": CrossEntropyLoss,
+            "BCEWithLogitsLoss": BCEWithLogitsLoss,
+        }[state_dict["loss_type"]]
+        if not isinstance(self._loss_func, loss_func_type):
+            raise ValueError(
+                f"Loss function mismatch: {loss_func_type} != {type(self._loss_func)}."
+            )
+        if state_dict["loss_reduction"] != self._loss_func.reduction:
+            raise ValueError(
+                "Loss function reduction mismatch: "
+                f"{state_dict['loss_reduction']} != {self._loss_func.reduction}."
+            )
+
+        # Only load the model's parameters once the checks above have passed
+        self._model_func.load_state_dict(state_dict["model_func_state_dict"])
+
+        # Set attributes
+        self._progressbar = state_dict["progressbar"]
+        self._shape = state_dict["shape"]
+        self._seed = state_dict["seed"]
+        self._fisher_type = state_dict["fisher_type"]
+        self._mc_samples = state_dict["mc_samples"]
+        self._kfac_approx = state_dict["kfac_approx"]
+        self._loss_average = state_dict["loss_average"]
+        self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"]
+        self._separate_weight_and_bias = state_dict["separate_weight_and_bias"]
+        self._N_data = state_dict["num_data"]
+
+        # Set Kronecker factors (if computed)
+        self._input_covariances = state_dict["input_covariances"]
+        self._gradient_covariances = state_dict["gradient_covariances"]
+
+        # Set properties (not necessarily computed)
+        self._trace = state_dict["trace"]
+        self._det = state_dict["det"]
+        self._logdet = state_dict["logdet"]
+        self._frobenius_norm = state_dict["frobenius_norm"]
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, Any],
+        model_func: Module,
+        params: List[Parameter],
+        data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
+        check_deterministic: bool = True,
+        batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
+    ) -> KFACLinearOperator:
+        """Load a KFAC linear operator from a state dictionary.
+
+        Args:
+            state_dict: State dictionary.
+            model_func: The model function.
+            params: The model's parameters that KFAC is computed for.
+            data: A data loader containing the data of the Fisher/GGN.
+            check_deterministic: Whether to check that the linear operator is
+                deterministic. Defaults to ``True``. The check runs after the state
+                dict is loaded; the operator is moved to the CPU for it and back to
+                its previous device afterwards.
+            batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
+                needs to be specified. The intended behavior is to consume the first
+                entry of the iterates from ``data`` and return their batch size.
+
+        Returns:
+            Linear operator of KFAC approximation.
+        """
+        loss_func = {
+            "MSELoss": MSELoss,
+            "CrossEntropyLoss": CrossEntropyLoss,
+            "BCEWithLogitsLoss": BCEWithLogitsLoss,
+        }[state_dict["loss_type"]](reduction=state_dict["loss_reduction"])
+        kfac = cls(
+            model_func,
+            loss_func,
+            params,
+            data,
+            batch_size_fn=batch_size_fn,
+            check_deterministic=False,
+            progressbar=state_dict["progressbar"],
+            shape=state_dict["shape"],
+            seed=state_dict["seed"],
+            fisher_type=state_dict["fisher_type"],
+            mc_samples=state_dict["mc_samples"],
+            kfac_approx=state_dict["kfac_approx"],
+            loss_average=state_dict["loss_average"],
+            num_per_example_loss_terms=state_dict["num_per_example_loss_terms"],
+            separate_weight_and_bias=state_dict["separate_weight_and_bias"],
+            num_data=state_dict["num_data"],
+        )
+        kfac.load_state_dict(state_dict)
+
+        # Potentially call `check_deterministic` after the state dict is loaded
+        if check_deterministic:
+            old_device = kfac._device
+            kfac.to_device(device("cpu"))
+            try:
+                kfac._check_deterministic()
+            finally:
+                # Restore the original device even if the check raises
+                kfac.to_device(old_device)
+
+        return kfac