From 07ffeb22d1693ae6c3b649873d3fc6dc52f6870f Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 15 May 2024 22:54:04 +0100 Subject: [PATCH 1/7] Add tests for state dict functionality --- test/test_inverse.py | 108 ++++++++++++++++++++++++++++++++++++++ test/test_kfac.py | 121 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 229 insertions(+) diff --git a/test/test_inverse.py b/test/test_inverse.py index 136a65d..8762bc8 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -654,3 +654,111 @@ def test_KFAC_inverse_damped_torch_matvec( # Test against _matmat report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy()) + + +def test_KFAC_inverse_save_and_load_state_dict(): + """Test that KFACInverseLinearOperator can be saved and loaded from state dict.""" + torch.manual_seed(0) + batch_size, D_in, D_out = 4, 3, 2 + X = torch.rand(batch_size, D_in) + y = torch.rand(batch_size, D_out) + model = torch.nn.Linear(D_in, D_out) + + params = list(model.parameters()) + # create and compute KFAC + kfac = KFACLinearOperator( + model, + MSELoss(reduction="sum"), + params, + [(X, y)], + loss_average=None, + ) + + # create inverse KFAC + inv_kfac = KFACInverseLinearOperator( + kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False + ) + _ = inv_kfac @ eye(kfac.shape[1]) # to trigger inverse computation + + # save state dict + state_dict = inv_kfac.state_dict() + + with raises(ValueError, match="mismatch"): + # create new inverse KFAC with different linop input and try to load state dict + wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)]) + inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac) + inv_kfac_wrong.load_state_dict(state_dict) + + # create new inverse KFAC and load state dict + inv_kfac_new = KFACInverseLinearOperator(kfac) + inv_kfac_new.load_state_dict(state_dict) + + # check that the two inverse KFACs are equal + def compare_state_dicts(state_dict: dict, state_dict_new: dict): + assert len(state_dict) == len(state_dict_new) + for value, value_new in zip(state_dict.values(), state_dict_new.values()): + if isinstance(value, torch.Tensor): + assert torch.allclose(value, value_new) + elif isinstance(value, dict): + compare_state_dicts(value, value_new) + elif isinstance(value, tuple): + assert all( + torch.allclose(torch.as_tensor(v), torch.as_tensor(v2)) + for v, v2 in zip(value, value_new) + ) + else: + assert value == value_new + + compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) + + test_mat = torch.rand(inv_kfac.shape[1]) + report_nonclose(inv_kfac @ test_mat, inv_kfac_new @ test_mat) + + +def test_KFAC_inverse_from_state_dict(): + """Test that KFACInverseLinearOperator can be created from state dict.""" + torch.manual_seed(0) + batch_size, D_in, D_out = 4, 3, 2 + X = torch.rand(batch_size, D_in) + y = torch.rand(batch_size, D_out) + model = torch.nn.Linear(D_in, D_out) + + params = list(model.parameters()) + # create and compute KFAC + kfac = KFACLinearOperator( + model, + MSELoss(reduction="sum"), + params, + [(X, y)], + loss_average=None, + ) + + # create inverse KFAC and save state dict + inv_kfac = KFACInverseLinearOperator( + kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False + ) + state_dict = inv_kfac.state_dict() + + # create new KFAC from state dict + inv_kfac_new = KFACInverseLinearOperator.from_state_dict(state_dict, kfac) + + # check that the two inverse KFACs are equal + def compare_state_dicts(state_dict: dict, state_dict_new: dict): + assert len(state_dict) == len(state_dict_new) + for value, value_new in zip(state_dict.values(), state_dict_new.values()): + if isinstance(value, torch.Tensor): + assert torch.allclose(value, value_new) + elif isinstance(value, dict): + compare_state_dicts(value, value_new) + elif isinstance(value, tuple): + assert all( + torch.allclose(torch.as_tensor(v), torch.as_tensor(v2)) + for v, v2 in zip(value, value_new) + ) + else: + assert value == value_new + + compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) + + test_mat = torch.rand(kfac.shape[1]) + report_nonclose(inv_kfac @ test_mat, inv_kfac_new @ test_mat) diff --git a/test/test_kfac.py b/test/test_kfac.py index 06394dd..3d08ba4 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1220,3 +1220,124 @@ def test_kfac_does_affect_grad(): # make sure gradients are unchanged for grad_before, p in zip(grads_before, params): assert allclose(grad_before, p.grad) + + +def test_save_and_load_state_dict(): + """Test that KFACLinearOperator can be saved and loaded from state dict.""" + manual_seed(0) + batch_size, D_in, D_out = 4, 3, 2 + X = rand(batch_size, D_in) + y = rand(batch_size, D_out) + model = Linear(D_in, D_out) + + params = list(model.parameters()) + # create and compute KFAC + kfac = KFACLinearOperator( + model, + MSELoss(reduction="sum"), + params, + [(X, y)], + loss_average=None, + ) + + # save state dict + state_dict = kfac.state_dict() + + with raises(ValueError, match="loss"): + # create new KFAC with different loss function and try to load state dict + kfac_new = KFACLinearOperator( + model, + CrossEntropyLoss(), + params, + [(X, y)], + ) + kfac_new.load_state_dict(state_dict) + + with raises(ValueError, match="reduction"): + # create new KFAC with different loss reduction and try to load state dict + kfac_new = KFACLinearOperator( + model, + MSELoss(), + params, + [(X, y)], + ) + kfac_new.load_state_dict(state_dict) + + with raises(RuntimeError, match="loading state_dict"): + # create new KFAC with different model and try to load state dict + wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out)) + wrong_params = list(wrong_model.parameters()) + kfac_new = KFACLinearOperator( + wrong_model, + MSELoss(reduction="sum"), + wrong_params, + [(X, y)], + loss_average=None, + ) + kfac_new.load_state_dict(state_dict) + + # create new KFAC and load state dict + kfac_new = KFACLinearOperator( + model, + MSELoss(reduction="sum"), + params, + [(X, y)], + loss_average=None, + ) + + # check that the two KFACs are equal + assert len(kfac.state_dict()) == len(kfac_new.state_dict()) + for value, value_new in zip( + kfac.state_dict().values(), kfac_new.state_dict().values() + ): + if isinstance(value, Tensor): + assert allclose(value, value_new) + elif isinstance(value, dict): + for key, val in value.items(): + assert allclose(val, value_new[key]) + else: + assert value == value_new + + test_mat = rand(kfac.shape[1]) + report_nonclose(kfac @ test_mat, kfac_new @ test_mat) + + +def test_from_state_dict(): + """Test that KFACLinearOperator can be created from state dict.""" + manual_seed(0) + batch_size, D_in, D_out = 4, 3, 2 + X = rand(batch_size, D_in) + y = rand(batch_size, D_out) + model = Linear(D_in, D_out) + + params = list(model.parameters()) + # create and compute KFAC + kfac = KFACLinearOperator( + model, + MSELoss(reduction="sum"), + params, + [(X, y)], + loss_average=None, + ) + + # save state dict + state_dict = kfac.state_dict() + + # create new KFAC from state dict + kfac_new = KFACLinearOperator.from_state_dict(state_dict, model, params, [(X, y)]) + + # check that the two KFACs are equal + assert len(kfac.state_dict()) == len(kfac_new.state_dict()) + for value, value_new in zip( + kfac.state_dict().values(), kfac_new.state_dict().values() + ): + if isinstance(value, Tensor): + assert allclose(value, value_new) + elif isinstance(value, dict): + for key, val in value.items(): + assert allclose(val, value_new[key]) + else: + assert value == value_new + + test_mat = rand(kfac.shape[1]) + report_nonclose(kfac @ test_mat, kfac_new @ test_mat) From f873494741d2a542ee2a47e73be32d08e5d752d3 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 15 May 2024 23:04:38 +0100 Subject: [PATCH 2/7] Add state dict functionality to (inverse) KFAC linear operator --- curvlinops/inverse.py | 69 ++++++++++++++++++- curvlinops/kfac.py | 154 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 220 insertions(+), 3 deletions(-) diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index 9ad4e43..3a357bf 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -1,7 +1,7 @@ """Implements linear operator inverses.""" from math import sqrt -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from warnings import warn from einops import einsum, rearrange @@ -695,3 +695,70 @@ def _matmat(self, M: ndarray) -> ndarray: M_torch = self._A._preprocess(M) M_torch = self.torch_matmat(M_torch) return self._A._postprocess(M_torch) + + def state_dict(self) -> Dict[str, Any]: + """Return the state of the inverse KFAC linear operator. + + Returns: + State dictionary. + """ + return { + "A": self._A.state_dict(), + # Attributes + "damping": self._damping, + "use_heuristic_damping": self._use_heuristic_damping, + "min_damping": self._min_damping, + "use_exact_damping": self._use_exact_damping, + "cache": self._cache, + "retry_double_precision": self._retry_double_precision, + # Inverse Kronecker factors (if computed and cached) + "inverse_input_covariances": self._inverse_input_covariances, + "inverse_gradient_covariances": self._inverse_gradient_covariances, + } + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Load the state of the inverse KFAC linear operator. + + Args: + state_dict: State dictionary. + """ + self._A.load_state_dict(state_dict["A"]) + + # Set attributes + self._damping = state_dict["damping"] + self._use_heuristic_damping = state_dict["use_heuristic_damping"] + self._min_damping = state_dict["min_damping"] + self._use_exact_damping = state_dict["use_exact_damping"] + self._cache = state_dict["cache"] + self._retry_double_precision = state_dict["retry_double_precision"] + + # Set inverse Kronecker factors (if computed and cached) + self._inverse_input_covariances = state_dict["inverse_input_covariances"] + self._inverse_gradient_covariances = state_dict["inverse_gradient_covariances"] + + @classmethod + def from_state_dict( + cls, + state_dict: Dict[str, Any], + A: KFACLinearOperator, + ) -> "KFACInverseLinearOperator": + """Load an inverse KFAC linear operator from a state dictionary. + + Args: + state_dict: State dictionary. + A: ``KFACLinearOperator`` whose inverse is formed. + + Returns: + Linear operator of inverse KFAC approximation. + """ + inv_kfac = cls( + A, + damping=state_dict["damping"], + use_heuristic_damping=state_dict["use_heuristic_damping"], + min_damping=state_dict["min_damping"], + use_exact_damping=state_dict["use_exact_damping"], + cache=state_dict["cache"], + retry_double_precision=state_dict["retry_double_precision"], + ) + inv_kfac.load_state_dict(state_dict) + return inv_kfac diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 4909d8e..a4185cb 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -21,7 +21,7 @@ from collections.abc import MutableMapping from functools import partial from math import sqrt -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from einops import einsum, rearrange, reduce from numpy import ndarray @@ -111,7 +111,7 @@ class KFACLinearOperator(_LinearOperator): def __init__( # noqa: C901 self, model_func: Module, - loss_func: MSELoss, + loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss], params: List[Parameter], data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]], progressbar: bool = False, @@ -1070,3 +1070,153 @@ def frobenius_norm(self) -> Tensor: ) self._frobenius_norm.sqrt_() return self._frobenius_norm + + def state_dict(self) -> Dict[str, Any]: + """Return the state of the KFAC linear operator. + + Returns: + State dictionary. + """ + loss_type = { + MSELoss: "MSELoss", + CrossEntropyLoss: "CrossEntropyLoss", + BCEWithLogitsLoss: "BCEWithLogitsLoss", + }[type(self._loss_func)] + return { + # Model and loss function + "model_func_state_dict": self._model_func.state_dict(), + "loss_type": loss_type, + "loss_reduction": self._loss_func.reduction, + # Attributes + "progressbar": self._progressbar, + "shape": self._shape, + "seed": self._seed, + "fisher_type": self._fisher_type, + "mc_samples": self._mc_samples, + "kfac_approx": self._kfac_approx, + "loss_average": self._loss_average, + "num_per_example_loss_terms": self._num_per_example_loss_terms, + "separate_weight_and_bias": self._separate_weight_and_bias, + "num_data": self._N_data, + # Kronecker factors (if computed) + "input_covariances": self._input_covariances, + "gradient_covariances": self._gradient_covariances, + # Properties (not necessarily computed) + "trace": self._trace, + "det": self._det, + "logdet": self._logdet, + "frobenius_norm": self._frobenius_norm, + } + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Load the state of the KFAC linear operator. + + Args: + state_dict: State dictionary. + + Raises: + ValueError: If the loss function does not match the state dict. + ValueError: If the loss function reduction does not match the state dict. + """ + self._model_func.load_state_dict(state_dict["model_func_state_dict"]) + # Verify that the loss function and its reduction match the state dict + loss_func_type = { + "MSELoss": MSELoss, + "CrossEntropyLoss": CrossEntropyLoss, + "BCEWithLogitsLoss": BCEWithLogitsLoss, + }[state_dict["loss_type"]] + if not isinstance(self._loss_func, loss_func_type): + raise ValueError( + f"Loss function mismatch: {loss_func_type} != {type(self._loss_func)}." + ) + if state_dict["loss_reduction"] != self._loss_func.reduction: + raise ValueError( + "Loss function reduction mismatch: " + f"{state_dict['loss_reduction']} != {self._loss_func.reduction}." + ) + + # Set attributes + self._progressbar = state_dict["progressbar"] + self._shape = state_dict["shape"] + self._seed = state_dict["seed"] + self._fisher_type = state_dict["fisher_type"] + self._mc_samples = state_dict["mc_samples"] + self._kfac_approx = state_dict["kfac_approx"] + self._loss_average = state_dict["loss_average"] + self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"] + self._separate_weight_and_bias = state_dict["separate_weight_and_bias"] + self._N_data = state_dict["num_data"] + + # Set Kronecker factors (if computed) + self._input_covariances = state_dict["input_covariances"] + self._gradient_covariances = state_dict["gradient_covariances"] + + # Set properties (not necessarily computed) + self._trace = state_dict["trace"] + self._det = state_dict["det"] + self._logdet = state_dict["logdet"] + self._frobenius_norm = state_dict["frobenius_norm"] + + @classmethod + def from_state_dict( + cls, + state_dict: Dict[str, Any], + model_func: Module, + params: List[Parameter], + data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]], + check_deterministic: bool = True, + batch_size_fn: Optional[Callable[[MutableMapping], int]] = None, + ) -> KFACLinearOperator: + """Load a KFAC linear operator from a state dictionary. + + Args: + state_dict: State dictionary. + model_func: The model function. + params: The model's parameters that KFAC is computed for. + data: A data loader containing the data of the Fisher/GGN. + check_deterministic: Whether to check that the linear operator is + deterministic. Defaults to ``True``. + batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this + needs to be specified. The intended behavior is to consume the first + entry of the iterates from ``data`` and return their batch size. + + Returns: + Linear operator of KFAC approximation. + """ + loss_func = { + "MSELoss": MSELoss, + "CrossEntropyLoss": CrossEntropyLoss, + "BCEWithLogitsLoss": BCEWithLogitsLoss, + }[state_dict["loss_type"]](reduction=state_dict["loss_reduction"]) + kfac = cls( + model_func, + loss_func, + params, + data, + batch_size_fn=batch_size_fn, + check_deterministic=False, + progressbar=state_dict["progressbar"], + shape=state_dict["shape"], + seed=state_dict["seed"], + fisher_type=state_dict["fisher_type"], + mc_samples=state_dict["mc_samples"], + kfac_approx=state_dict["kfac_approx"], + loss_average=state_dict["loss_average"], + num_per_example_loss_terms=state_dict["num_per_example_loss_terms"], + separate_weight_and_bias=state_dict["separate_weight_and_bias"], + num_data=state_dict["num_data"], + ) + kfac.load_state_dict(state_dict) + + # Potentially call `check_deterministic` after the state dict is loaded + if check_deterministic: + old_device = kfac._device + kfac.to_device(device("cpu")) + try: + kfac._check_deterministic() + except RuntimeError as e: + raise e + finally: + kfac.to_device(old_device) + + return kfac From dfeb2573d42d20c107dc4bfe746b276062800f04 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 15 May 2024 23:20:44 +0100 Subject: [PATCH 3/7] Fix tests --- curvlinops/kfac.py | 7 +++++-- test/test_inverse.py | 6 +++--- test/test_kfac.py | 48 ++++++++++++++++++++++---------------------- 3 files changed, 32 insertions(+), 29 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index a4185cb..f5069a9 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -1089,7 +1089,7 @@ def state_dict(self) -> Dict[str, Any]: "loss_reduction": self._loss_func.reduction, # Attributes "progressbar": self._progressbar, - "shape": self._shape, + "shape": self.shape, "seed": self._seed, "fisher_type": self._fisher_type, "mc_samples": self._mc_samples, @@ -1137,7 +1137,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]): # Set attributes self._progressbar = state_dict["progressbar"] - self._shape = state_dict["shape"] + self.shape = state_dict["shape"] self._seed = state_dict["seed"] self._fisher_type = state_dict["fisher_type"] self._mc_samples = state_dict["mc_samples"] @@ -1182,6 +1182,9 @@ def from_state_dict( Returns: Linear operator of KFAC approximation. + + Raises: + RuntimeError: If the check for deterministic behavior fails. """ loss_func = { "MSELoss": MSELoss, diff --git a/test/test_inverse.py b/test/test_inverse.py index 8762bc8..8ee26ac 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -683,10 +683,10 @@ def test_KFAC_inverse_save_and_load_state_dict(): # save state dict state_dict = inv_kfac.state_dict() + # create new inverse KFAC with different linop input and try to load state dict + wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)]) + inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac) with raises(ValueError, match="mismatch"): - # create new inverse KFAC with different linop input and try to load state dict - wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)]) - inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac) inv_kfac_wrong.load_state_dict(state_dict) # create new inverse KFAC and load state dict diff --git a/test/test_kfac.py b/test/test_kfac.py index 3d08ba4..997c83a 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1243,37 +1243,37 @@ def test_save_and_load_state_dict(): # save state dict state_dict = kfac.state_dict() + # create new KFAC with different loss function and try to load state dict + kfac_new = KFACLinearOperator( + model, + CrossEntropyLoss(), + params, + [(X, y)], + ) with raises(ValueError, match="loss"): - # create new KFAC with different loss function and try to load state dict - kfac_new = KFACLinearOperator( - model, - CrossEntropyLoss(), - params, - [(X, y)], - ) kfac_new.load_state_dict(state_dict) + # create new KFAC with different loss reduction and try to load state dict + kfac_new = KFACLinearOperator( + model, + MSELoss(), + params, + [(X, y)], + ) with raises(ValueError, match="reduction"): - # create new KFAC with different loss reduction and try to load state dict - kfac_new = KFACLinearOperator( - model, - MSELoss(), - params, - [(X, y)], - ) kfac_new.load_state_dict(state_dict) + # create new KFAC with different model and try to load state dict + wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out)) + wrong_params = list(wrong_model.parameters()) + kfac_new = KFACLinearOperator( + wrong_model, + MSELoss(reduction="sum"), + wrong_params, + [(X, y)], + loss_average=None, + ) with raises(RuntimeError, match="loading state_dict"): - # create new KFAC with different model and try to load state dict - wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out)) - wrong_params = list(wrong_model.parameters()) - kfac_new = KFACLinearOperator( - wrong_model, - MSELoss(reduction="sum"), - wrong_params, - [(X, y)], - loss_average=None, - ) kfac_new.load_state_dict(state_dict) # create new KFAC and load state dict From 04a32594228201fc8cccbd16799fbda4e9369855 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 23 May 2024 15:00:03 +0100 Subject: [PATCH 4/7] Address review comments on tests --- curvlinops/inverse.py | 4 +-- test/test_inverse.py | 59 +++++++++++++++++-------------------------- test/test_kfac.py | 12 ++++----- 3 files changed, 30 insertions(+), 45 deletions(-) diff --git a/curvlinops/inverse.py b/curvlinops/inverse.py index 3a357bf..8b49b24 100644 --- a/curvlinops/inverse.py +++ b/curvlinops/inverse.py @@ -738,9 +738,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]): @classmethod def from_state_dict( - cls, - state_dict: Dict[str, Any], - A: KFACLinearOperator, + cls, state_dict: Dict[str, Any], A: KFACLinearOperator ) -> "KFACInverseLinearOperator": """Load an inverse KFAC linear operator from a state dictionary. diff --git a/test/test_inverse.py b/test/test_inverse.py index 8ee26ac..a9aeee6 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -656,6 +656,25 @@ def test_KFAC_inverse_damped_torch_matvec( report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy()) +def compare_state_dicts(state_dict: dict, state_dict_new: dict): + """Compare two state dicts recursively.""" + assert len(state_dict) == len(state_dict_new) + for value, value_new in zip(state_dict.values(), state_dict_new.values()): + if isinstance(value, torch.Tensor): + assert torch.allclose(value, value_new) + elif isinstance(value, dict): + compare_state_dicts(value, value_new) + elif isinstance(value, tuple): + assert len(value) == len(value_new) + assert all([isinstance(v, type(v2)) for v, v2 in zip(value, value_new)]) + assert all( + torch.allclose(torch.as_tensor(v), torch.as_tensor(v2)) + for v, v2 in zip(value, value_new) + ) + else: + assert value == value_new + + def test_KFAC_inverse_save_and_load_state_dict(): """Test that KFACInverseLinearOperator can be saved and loaded from state dict.""" torch.manual_seed(0) @@ -694,26 +713,10 @@ def test_KFAC_inverse_save_and_load_state_dict(): inv_kfac_new.load_state_dict(state_dict) # check that the two inverse KFACs are equal - def compare_state_dicts(state_dict: dict, state_dict_new: dict): - assert len(state_dict) == len(state_dict_new) - for value, value_new in zip(state_dict.values(), state_dict_new.values()): - if isinstance(value, torch.Tensor): - assert torch.allclose(value, value_new) - elif isinstance(value, dict): - compare_state_dicts(value, value_new) - elif isinstance(value, tuple): - assert all( - torch.allclose(torch.as_tensor(v), torch.as_tensor(v2)) - for v, v2 in zip(value, value_new) - ) - else: - assert value == value_new - + test_vec = torch.rand(inv_kfac.shape[1]) + report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) - test_mat = torch.rand(inv_kfac.shape[1]) - report_nonclose(inv_kfac @ test_mat, inv_kfac_new @ test_mat) - def test_KFAC_inverse_from_state_dict(): """Test that KFACInverseLinearOperator can be created from state dict.""" @@ -743,22 +746,6 @@ def test_KFAC_inverse_from_state_dict(): inv_kfac_new = KFACInverseLinearOperator.from_state_dict(state_dict, kfac) # check that the two inverse KFACs are equal - def compare_state_dicts(state_dict: dict, state_dict_new: dict): - assert len(state_dict) == len(state_dict_new) - for value, value_new in zip(state_dict.values(), state_dict_new.values()): - if isinstance(value, torch.Tensor): - assert torch.allclose(value, value_new) - elif isinstance(value, dict): - compare_state_dicts(value, value_new) - elif isinstance(value, tuple): - assert all( - torch.allclose(torch.as_tensor(v), torch.as_tensor(v2)) - for v, v2 in zip(value, value_new) - ) - else: - assert value == value_new - + test_vec = torch.rand(kfac.shape[1]) + report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) - - test_mat = torch.rand(kfac.shape[1]) - report_nonclose(inv_kfac @ test_mat, inv_kfac_new @ test_mat) diff --git a/test/test_kfac.py b/test/test_kfac.py index 997c83a..4e0500c 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1286,6 +1286,9 @@ def test_save_and_load_state_dict(): ) # check that the two KFACs are equal + test_vec = rand(kfac.shape[1]) + report_nonclose(kfac @ test_vec, kfac_new @ test_vec) + assert len(kfac.state_dict()) == len(kfac_new.state_dict()) for value, value_new in zip( kfac.state_dict().values(), kfac_new.state_dict().values() @@ -1298,9 +1301,6 @@ def test_save_and_load_state_dict(): else: assert value == value_new - test_mat = rand(kfac.shape[1]) - report_nonclose(kfac @ test_mat, kfac_new @ test_mat) - def test_from_state_dict(): """Test that KFACLinearOperator can be created from state dict.""" @@ -1327,6 +1327,9 @@ def test_from_state_dict(): kfac_new = KFACLinearOperator.from_state_dict(state_dict, model, params, [(X, y)]) # check that the two KFACs are equal + test_vec = rand(kfac.shape[1]) + report_nonclose(kfac @ test_vec, kfac_new @ test_vec) + assert len(kfac.state_dict()) == len(kfac_new.state_dict()) for value, value_new in zip( kfac.state_dict().values(), kfac_new.state_dict().values() @@ -1338,6 +1341,3 @@ def test_from_state_dict(): assert allclose(val, value_new[key]) else: assert value == value_new - - test_mat = rand(kfac.shape[1]) - report_nonclose(kfac @ test_mat, kfac_new @ test_mat) From d5cecfc18949555b3eb5e7a2721898f829589035 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 23 May 2024 15:18:22 +0100 Subject: [PATCH 5/7] Test torch.save/load as well and fix order of equivalence checks --- test/test_inverse.py | 15 ++++++++++----- test/test_kfac.py | 27 +++++++++++++++++---------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/test/test_inverse.py b/test/test_inverse.py index a9aeee6..205e643 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -1,5 +1,6 @@ """Contains tests for ``curvlinops/inverse``.""" +import os from math import sqrt from test.utils import cast_input from typing import Iterable, List, Tuple, Union @@ -666,7 +667,7 @@ def compare_state_dicts(state_dict: dict, state_dict_new: dict): compare_state_dicts(value, value_new) elif isinstance(value, tuple): assert len(value) == len(value_new) - assert all([isinstance(v, type(v2)) for v, v2 in zip(value, value_new)]) + assert all(isinstance(v, type(v2)) for v, v2 in zip(value, value_new)) assert all( torch.allclose(torch.as_tensor(v), torch.as_tensor(v2)) for v, v2 in zip(value, value_new) @@ -701,21 +702,25 @@ def test_KFAC_inverse_save_and_load_state_dict(): # save state dict state_dict = inv_kfac.state_dict() + torch.save(state_dict, "inv_kfac_state_dict.pt") # create new inverse KFAC with different linop input and try to load state dict wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)]) inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac) with raises(ValueError, match="mismatch"): - inv_kfac_wrong.load_state_dict(state_dict) + inv_kfac_wrong.load_state_dict(torch.load(state_dict)) # create new inverse KFAC and load state dict inv_kfac_new = KFACInverseLinearOperator(kfac) - inv_kfac_new.load_state_dict(state_dict) + inv_kfac_new.load_state_dict(torch.load(state_dict)) # check that the two inverse KFACs are equal + compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) test_vec = torch.rand(inv_kfac.shape[1]) report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) - compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) + + # clean up + os.remove("inv_kfac_state_dict.pt") def test_KFAC_inverse_from_state_dict(): @@ -746,6 +751,6 @@ def test_KFAC_inverse_from_state_dict(): inv_kfac_new = KFACInverseLinearOperator.from_state_dict(state_dict, kfac) # check that the two inverse KFACs are equal + compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) test_vec = torch.rand(kfac.shape[1]) report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) - compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) diff --git a/test/test_kfac.py b/test/test_kfac.py index 4e0500c..de52a4b 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1,5 +1,6 @@ """Contains tests for ``curvlinops.kfac``.""" +import os from test.cases import DEVICES, DEVICES_IDS from test.utils import ( Conv2dModel, @@ -20,7 +21,7 @@ from scipy.linalg import block_diag from torch import Tensor, allclose, cat, cuda, device from torch import eye as torch_eye -from torch import isinf, isnan, manual_seed, rand, rand_like, randperm +from torch import isinf, isnan, load, manual_seed, rand, rand_like, randperm, save from torch.nn import ( BCEWithLogitsLoss, CrossEntropyLoss, @@ -1242,6 +1243,7 @@ def test_save_and_load_state_dict(): # save state dict state_dict = kfac.state_dict() + save(state_dict, "kfac_state_dict.pt") # create new KFAC with different loss function and try to load state dict kfac_new = KFACLinearOperator( @@ -1251,7 +1253,7 @@ def test_save_and_load_state_dict(): [(X, y)], ) with raises(ValueError, match="loss"): - kfac_new.load_state_dict(state_dict) + kfac_new.load_state_dict(load(state_dict)) # create new KFAC with different loss reduction and try to load state dict kfac_new = KFACLinearOperator( @@ -1261,7 +1263,7 @@ def test_save_and_load_state_dict(): [(X, y)], ) with raises(ValueError, match="reduction"): - kfac_new.load_state_dict(state_dict) + kfac_new.load_state_dict(load(state_dict)) # create new KFAC with different model and try to load state dict wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out)) @@ -1274,7 +1276,7 @@ def test_save_and_load_state_dict(): loss_average=None, ) with raises(RuntimeError, match="loading state_dict"): - kfac_new.load_state_dict(state_dict) + kfac_new.load_state_dict(load(state_dict)) # create new KFAC and load state dict kfac_new = KFACLinearOperator( @@ -1283,12 +1285,11 @@ def test_save_and_load_state_dict(): params, [(X, y)], loss_average=None, + check_deterministic=False, # turn off to avoid computing KFAC again ) + kfac_new.load_state_dict(load(state_dict)) # check that the two KFACs are equal - test_vec = rand(kfac.shape[1]) - report_nonclose(kfac @ test_vec, kfac_new @ test_vec) - assert len(kfac.state_dict()) == len(kfac_new.state_dict()) for value, value_new in zip( kfac.state_dict().values(), kfac_new.state_dict().values() @@ -1301,6 +1302,12 @@ def test_save_and_load_state_dict(): else: assert value == value_new + test_vec = rand(kfac.shape[1]) + report_nonclose(kfac @ test_vec, kfac_new @ test_vec) + + # clean up + os.remove("kfac_state_dict.pt") + def test_from_state_dict(): """Test that KFACLinearOperator can be created from state dict.""" @@ -1327,9 +1334,6 @@ def test_from_state_dict(): kfac_new = KFACLinearOperator.from_state_dict(state_dict, model, params, [(X, y)]) # check that the two KFACs are equal - test_vec = rand(kfac.shape[1]) - report_nonclose(kfac @ test_vec, kfac_new @ test_vec) - assert len(kfac.state_dict()) == len(kfac_new.state_dict()) for value, value_new in zip( kfac.state_dict().values(), kfac_new.state_dict().values() @@ -1341,3 +1345,6 @@ def test_from_state_dict(): assert allclose(val, value_new[key]) else: assert value == value_new + + test_vec = rand(kfac.shape[1]) + report_nonclose(kfac @ test_vec, kfac_new @ test_vec) From fb6ac4bbc3018ebd41d3d4a27e6c709060f2c2db Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 23 May 2024 15:35:45 +0100 Subject: [PATCH 6/7] Check if covariance and mapping keys match when loading state dict --- curvlinops/kfac.py | 17 +++++++++++++++++ test/test_inverse.py | 4 ++-- test/test_kfac.py | 8 ++++---- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index f5069a9..419ba79 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -1148,6 +1148,23 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self._N_data = state_dict["num_data"] # Set Kronecker factors (if computed) + if self._input_covariances or self._gradient_covariances: + # If computed, check if the keys match the mapping keys + input_covariances_keys = set(self._input_covariances.keys()) + gradient_covariances_keys = set(self._gradient_covariances.keys()) + mapping_keys = set(self._mapping.keys()) + if ( + input_covariances_keys != mapping_keys + or gradient_covariances_keys != mapping_keys + ): + raise ValueError( + "Input or gradient covariance keys in state dict do not match " + "mapping keys of linear operator. " + "Difference between input covariance and mapping keys: " + f"{input_covariances_keys - mapping_keys}. " + "Difference between gradient covariance and mapping keys: " + f"{gradient_covariances_keys - mapping_keys}." + ) self._input_covariances = state_dict["input_covariances"] self._gradient_covariances = state_dict["gradient_covariances"] diff --git a/test/test_inverse.py b/test/test_inverse.py index 205e643..1e6fb8c 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -708,11 +708,11 @@ def test_KFAC_inverse_save_and_load_state_dict(): wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)]) inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac) with raises(ValueError, match="mismatch"): - inv_kfac_wrong.load_state_dict(torch.load(state_dict)) + inv_kfac_wrong.load_state_dict(torch.load("inv_kfac_state_dict.pt")) # create new inverse KFAC and load state dict inv_kfac_new = KFACInverseLinearOperator(kfac) - inv_kfac_new.load_state_dict(torch.load(state_dict)) + inv_kfac_new.load_state_dict(torch.load("inv_kfac_state_dict.pt")) # check that the two inverse KFACs are equal compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) diff --git a/test/test_kfac.py b/test/test_kfac.py index de52a4b..6cbc4d9 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1253,7 +1253,7 @@ def test_save_and_load_state_dict(): [(X, y)], ) with raises(ValueError, match="loss"): - kfac_new.load_state_dict(load(state_dict)) + kfac_new.load_state_dict(load("kfac_state_dict.pt")) # create new KFAC with different loss reduction and try to load state dict kfac_new = KFACLinearOperator( @@ -1263,7 +1263,7 @@ def test_save_and_load_state_dict(): [(X, y)], ) with raises(ValueError, match="reduction"): - kfac_new.load_state_dict(load(state_dict)) + kfac_new.load_state_dict(load("kfac_state_dict.pt")) # create new KFAC with different model and try to load state dict wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out)) @@ -1276,7 +1276,7 @@ def test_save_and_load_state_dict(): loss_average=None, ) with raises(RuntimeError, match="loading state_dict"): - kfac_new.load_state_dict(load(state_dict)) + kfac_new.load_state_dict(load("kfac_state_dict.pt")) # create new KFAC and load state dict kfac_new = KFACLinearOperator( @@ -1287,7 +1287,7 @@ def test_save_and_load_state_dict(): loss_average=None, check_deterministic=False, # turn off to avoid computing KFAC again ) - kfac_new.load_state_dict(load(state_dict)) + kfac_new.load_state_dict(load("kfac_state_dict.pt")) # check that the two KFACs are equal assert len(kfac.state_dict()) == len(kfac_new.state_dict()) From 1cf85bdb53a394b2bad8f805d120eaafef0a081b Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 23 May 2024 21:28:40 +0100 Subject: [PATCH 7/7] Use compare_state_dicts everywhere --- test/test_inverse.py | 26 +++----------------------- test/test_kfac.py | 32 +++++--------------------------- test/utils.py | 39 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 46 insertions(+), 51 deletions(-) diff --git a/test/test_inverse.py b/test/test_inverse.py index 1e6fb8c..0a5e68c 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -2,7 +2,7 @@ import os from math import sqrt -from test.utils import cast_input +from test.utils import cast_input, compare_state_dicts from typing import Iterable, List, Tuple, Union import torch @@ -657,25 +657,6 @@ def test_KFAC_inverse_damped_torch_matvec( report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy()) -def compare_state_dicts(state_dict: dict, state_dict_new: dict): - """Compare two state dicts recursively.""" - assert len(state_dict) == len(state_dict_new) - for value, value_new in zip(state_dict.values(), state_dict_new.values()): - if isinstance(value, torch.Tensor): - assert torch.allclose(value, value_new) - elif isinstance(value, dict): - compare_state_dicts(value, value_new) - elif isinstance(value, tuple): - assert len(value) == len(value_new) - assert all(isinstance(v, type(v2)) for v, v2 in zip(value, value_new)) - assert all( - torch.allclose(torch.as_tensor(v), torch.as_tensor(v2)) - for v, v2 in zip(value, value_new) - ) - else: - assert value == value_new - - def test_KFAC_inverse_save_and_load_state_dict(): """Test that KFACInverseLinearOperator can be saved and loaded from state dict.""" torch.manual_seed(0) @@ -713,15 +694,14 @@ def test_KFAC_inverse_save_and_load_state_dict(): # create new inverse KFAC and load state dict inv_kfac_new = KFACInverseLinearOperator(kfac) inv_kfac_new.load_state_dict(torch.load("inv_kfac_state_dict.pt")) + # clean up + os.remove("inv_kfac_state_dict.pt") # check that the two inverse KFACs are equal compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict()) test_vec = torch.rand(inv_kfac.shape[1]) report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec) - # clean up - os.remove("inv_kfac_state_dict.pt") - def test_KFAC_inverse_from_state_dict(): """Test that KFACInverseLinearOperator can be created from state dict.""" diff --git a/test/test_kfac.py b/test/test_kfac.py index 6cbc4d9..a2b7a2f 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -8,6 +8,7 @@ WeightShareModel, binary_classification_targets, classification_targets, + compare_state_dicts, ggn_block_diagonal, regression_targets, ) @@ -1288,26 +1289,14 @@ def test_save_and_load_state_dict(): check_deterministic=False, # turn off to avoid computing KFAC again ) kfac_new.load_state_dict(load("kfac_state_dict.pt")) + # clean up + os.remove("kfac_state_dict.pt") # check that the two KFACs are equal - assert len(kfac.state_dict()) == len(kfac_new.state_dict()) - for value, value_new in zip( - kfac.state_dict().values(), kfac_new.state_dict().values() - ): - if isinstance(value, Tensor): - assert allclose(value, value_new) - elif isinstance(value, dict): - for key, val in value.items(): - assert allclose(val, value_new[key]) - else: - assert value == value_new - + compare_state_dicts(kfac.state_dict(), kfac_new.state_dict()) test_vec = rand(kfac.shape[1]) report_nonclose(kfac @ test_vec, kfac_new @ test_vec) - # clean up - os.remove("kfac_state_dict.pt") - def test_from_state_dict(): """Test that KFACLinearOperator can be created from state dict.""" @@ -1334,17 +1323,6 @@ def test_from_state_dict(): kfac_new = KFACLinearOperator.from_state_dict(state_dict, model, params, [(X, y)]) # check that the two KFACs are equal - assert len(kfac.state_dict()) == len(kfac_new.state_dict()) - for value, value_new in zip( - kfac.state_dict().values(), kfac_new.state_dict().values() - ): - if isinstance(value, Tensor): - assert allclose(value, value_new) - elif isinstance(value, dict): - for key, val in value.items(): - assert allclose(val, value_new[key]) - else: - assert value == value_new - + compare_state_dicts(kfac.state_dict(), kfac_new.state_dict()) test_vec = rand(kfac.shape[1]) report_nonclose(kfac @ test_vec, kfac_new @ test_vec) diff --git a/test/utils.py b/test/utils.py index 07e6b82..dcc656b 100644 --- a/test/utils.py +++ b/test/utils.py @@ -7,7 +7,18 @@ from einops import rearrange, reduce from einops.layers.torch import Rearrange from numpy import eye, ndarray -from torch import Tensor, cat, cuda, device, dtype, from_numpy, rand, randint +from torch import ( + Tensor, + allclose, + as_tensor, + cat, + cuda, + device, + dtype, + from_numpy, + rand, + randint, +) from torch.nn import ( AdaptiveAvgPool2d, BCEWithLogitsLoss, @@ -367,3 +378,29 @@ def batch_size_fn(X: MutableMapping) -> int: batch_size: The first dimension size of the tensor. """ return X["x"].shape[0] + + +def compare_state_dicts(state_dict: dict, state_dict_new: dict): + """Compare two state dicts recursively. + + Args: + state_dict (dict): The first state dict to compare. + state_dict_new (dict): The second state dict to compare. + + Raises: + AssertionError: If the state dicts are not equal. + """ + assert len(state_dict) == len(state_dict_new) + for value, value_new in zip(state_dict.values(), state_dict_new.values()): + if isinstance(value, Tensor): + assert allclose(value, value_new) + elif isinstance(value, dict): + compare_state_dicts(value, value_new) + elif isinstance(value, tuple): + assert len(value) == len(value_new) + assert all(isinstance(v, type(v2)) for v, v2 in zip(value, value_new)) + assert all( + allclose(as_tensor(v), as_tensor(v2)) for v, v2 in zip(value, value_new) + ) + else: + assert value == value_new