Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] state_dict functionality to KFACLinearOperator and KFACInverseLinearOperator #114

Merged
merged 8 commits into from
May 23, 2024
69 changes: 68 additions & 1 deletion curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Implements linear operator inverses."""

from math import sqrt
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

from einops import einsum, rearrange
Expand Down Expand Up @@ -695,3 +695,70 @@ def _matmat(self, M: ndarray) -> ndarray:
M_torch = self._A._preprocess(M)
M_torch = self.torch_matmat(M_torch)
return self._A._postprocess(M_torch)

def state_dict(self) -> Dict[str, Any]:
"""Return the state of the inverse KFAC linear operator.

Returns:
State dictionary.
"""
return {
"A": self._A.state_dict(),
# Attributes
"damping": self._damping,
"use_heuristic_damping": self._use_heuristic_damping,
"min_damping": self._min_damping,
"use_exact_damping": self._use_exact_damping,
"cache": self._cache,
"retry_double_precision": self._retry_double_precision,
# Inverse Kronecker factors (if computed and cached)
"inverse_input_covariances": self._inverse_input_covariances,
"inverse_gradient_covariances": self._inverse_gradient_covariances,
}

def load_state_dict(self, state_dict: Dict[str, Any]):
"""Load the state of the inverse KFAC linear operator.

Args:
state_dict: State dictionary.
"""
self._A.load_state_dict(state_dict["A"])

# Set attributes
self._damping = state_dict["damping"]
self._use_heuristic_damping = state_dict["use_heuristic_damping"]
self._min_damping = state_dict["min_damping"]
self._use_exact_damping = state_dict["use_exact_damping"]
self._cache = state_dict["cache"]
self._retry_double_precision = state_dict["retry_double_precision"]

# Set inverse Kronecker factors (if computed and cached)
self._inverse_input_covariances = state_dict["inverse_input_covariances"]
self._inverse_gradient_covariances = state_dict["inverse_gradient_covariances"]

@classmethod
def from_state_dict(
cls,
state_dict: Dict[str, Any],
A: KFACLinearOperator,
runame marked this conversation as resolved.
Show resolved Hide resolved
) -> "KFACInverseLinearOperator":
"""Load an inverse KFAC linear operator from a state dictionary.

Args:
state_dict: State dictionary.
A: ``KFACLinearOperator`` whose inverse is formed.

Returns:
Linear operator of inverse KFAC approximation.
"""
inv_kfac = cls(
A,
damping=state_dict["damping"],
use_heuristic_damping=state_dict["use_heuristic_damping"],
min_damping=state_dict["min_damping"],
use_exact_damping=state_dict["use_exact_damping"],
cache=state_dict["cache"],
retry_double_precision=state_dict["retry_double_precision"],
)
inv_kfac.load_state_dict(state_dict)
return inv_kfac
157 changes: 155 additions & 2 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from collections.abc import MutableMapping
from functools import partial
from math import sqrt
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from einops import einsum, rearrange, reduce
from numpy import ndarray
Expand Down Expand Up @@ -111,7 +111,7 @@ class KFACLinearOperator(_LinearOperator):
def __init__( # noqa: C901
self,
model_func: Module,
loss_func: MSELoss,
loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
params: List[Parameter],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
progressbar: bool = False,
Expand Down Expand Up @@ -1070,3 +1070,156 @@ def frobenius_norm(self) -> Tensor:
)
self._frobenius_norm.sqrt_()
return self._frobenius_norm

def state_dict(self) -> Dict[str, Any]:
"""Return the state of the KFAC linear operator.

Returns:
State dictionary.
"""
runame marked this conversation as resolved.
Show resolved Hide resolved
loss_type = {
MSELoss: "MSELoss",
CrossEntropyLoss: "CrossEntropyLoss",
BCEWithLogitsLoss: "BCEWithLogitsLoss",
}[type(self._loss_func)]
return {
# Model and loss function
"model_func_state_dict": self._model_func.state_dict(),
"loss_type": loss_type,
"loss_reduction": self._loss_func.reduction,
# Attributes
"progressbar": self._progressbar,
"shape": self.shape,
"seed": self._seed,
"fisher_type": self._fisher_type,
"mc_samples": self._mc_samples,
"kfac_approx": self._kfac_approx,
"loss_average": self._loss_average,
"num_per_example_loss_terms": self._num_per_example_loss_terms,
"separate_weight_and_bias": self._separate_weight_and_bias,
"num_data": self._N_data,
# Kronecker factors (if computed)
"input_covariances": self._input_covariances,
"gradient_covariances": self._gradient_covariances,
# Properties (not necessarily computed)
"trace": self._trace,
"det": self._det,
"logdet": self._logdet,
"frobenius_norm": self._frobenius_norm,
f-dangel marked this conversation as resolved.
Show resolved Hide resolved
}

def load_state_dict(self, state_dict: Dict[str, Any]):
"""Load the state of the KFAC linear operator.

Args:
state_dict: State dictionary.

Raises:
ValueError: If the loss function does not match the state dict.
ValueError: If the loss function reduction does not match the state dict.
"""
self._model_func.load_state_dict(state_dict["model_func_state_dict"])
# Verify that the loss function and its reduction match the state dict
loss_func_type = {
"MSELoss": MSELoss,
"CrossEntropyLoss": CrossEntropyLoss,
"BCEWithLogitsLoss": BCEWithLogitsLoss,
}[state_dict["loss_type"]]
if not isinstance(self._loss_func, loss_func_type):
raise ValueError(
f"Loss function mismatch: {loss_func_type} != {type(self._loss_func)}."
)
if state_dict["loss_reduction"] != self._loss_func.reduction:
raise ValueError(
"Loss function reduction mismatch: "
f"{state_dict['loss_reduction']} != {self._loss_func.reduction}."
)

# Set attributes
self._progressbar = state_dict["progressbar"]
self.shape = state_dict["shape"]
self._seed = state_dict["seed"]
self._fisher_type = state_dict["fisher_type"]
self._mc_samples = state_dict["mc_samples"]
self._kfac_approx = state_dict["kfac_approx"]
self._loss_average = state_dict["loss_average"]
self._num_per_example_loss_terms = state_dict["num_per_example_loss_terms"]
self._separate_weight_and_bias = state_dict["separate_weight_and_bias"]
self._N_data = state_dict["num_data"]

# Set Kronecker factors (if computed)
runame marked this conversation as resolved.
Show resolved Hide resolved
self._input_covariances = state_dict["input_covariances"]
self._gradient_covariances = state_dict["gradient_covariances"]

# Set properties (not necessarily computed)
self._trace = state_dict["trace"]
self._det = state_dict["det"]
self._logdet = state_dict["logdet"]
self._frobenius_norm = state_dict["frobenius_norm"]

@classmethod
def from_state_dict(
cls,
state_dict: Dict[str, Any],
model_func: Module,
params: List[Parameter],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
check_deterministic: bool = True,
batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
) -> KFACLinearOperator:
"""Load a KFAC linear operator from a state dictionary.

Args:
state_dict: State dictionary.
model_func: The model function.
params: The model's parameters that KFAC is computed for.
data: A data loader containing the data of the Fisher/GGN.
check_deterministic: Whether to check that the linear operator is
deterministic. Defaults to ``True``.
batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
needs to be specified. The intended behavior is to consume the first
entry of the iterates from ``data`` and return their batch size.

Returns:
Linear operator of KFAC approximation.

Raises:
RuntimeError: If the check for deterministic behavior fails.
"""
loss_func = {
"MSELoss": MSELoss,
"CrossEntropyLoss": CrossEntropyLoss,
"BCEWithLogitsLoss": BCEWithLogitsLoss,
}[state_dict["loss_type"]](reduction=state_dict["loss_reduction"])
kfac = cls(
model_func,
loss_func,
params,
data,
batch_size_fn=batch_size_fn,
check_deterministic=False,
progressbar=state_dict["progressbar"],
shape=state_dict["shape"],
seed=state_dict["seed"],
fisher_type=state_dict["fisher_type"],
mc_samples=state_dict["mc_samples"],
kfac_approx=state_dict["kfac_approx"],
loss_average=state_dict["loss_average"],
num_per_example_loss_terms=state_dict["num_per_example_loss_terms"],
separate_weight_and_bias=state_dict["separate_weight_and_bias"],
num_data=state_dict["num_data"],
)
kfac.load_state_dict(state_dict)

# Potentially call `check_deterministic` after the state dict is loaded
if check_deterministic:
old_device = kfac._device
kfac.to_device(device("cpu"))
try:
kfac._check_deterministic()
except RuntimeError as e:
raise e
finally:
kfac.to_device(old_device)

return kfac
108 changes: 108 additions & 0 deletions test/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,111 @@ def test_KFAC_inverse_damped_torch_matvec(

# Test against _matmat
report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy())


def test_KFAC_inverse_save_and_load_state_dict():
"""Test that KFACInverseLinearOperator can be saved and loaded from state dict."""
torch.manual_seed(0)
batch_size, D_in, D_out = 4, 3, 2
X = torch.rand(batch_size, D_in)
y = torch.rand(batch_size, D_out)
model = torch.nn.Linear(D_in, D_out)
runame marked this conversation as resolved.
Show resolved Hide resolved

params = list(model.parameters())
# create and compute KFAC
kfac = KFACLinearOperator(
model,
MSELoss(reduction="sum"),
params,
[(X, y)],
loss_average=None,
)

# create inverse KFAC
inv_kfac = KFACInverseLinearOperator(
kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False
)
_ = inv_kfac @ eye(kfac.shape[1]) # to trigger inverse computation

# save state dict
state_dict = inv_kfac.state_dict()
runame marked this conversation as resolved.
Show resolved Hide resolved

# create new inverse KFAC with different linop input and try to load state dict
wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)])
inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac)
with raises(ValueError, match="mismatch"):
inv_kfac_wrong.load_state_dict(state_dict)

# create new inverse KFAC and load state dict
inv_kfac_new = KFACInverseLinearOperator(kfac)
inv_kfac_new.load_state_dict(state_dict)

# check that the two inverse KFACs are equal
def compare_state_dicts(state_dict: dict, state_dict_new: dict):
assert len(state_dict) == len(state_dict_new)
for value, value_new in zip(state_dict.values(), state_dict_new.values()):
if isinstance(value, torch.Tensor):
assert torch.allclose(value, value_new)
elif isinstance(value, dict):
compare_state_dicts(value, value_new)
elif isinstance(value, tuple):
assert all(
runame marked this conversation as resolved.
Show resolved Hide resolved
torch.allclose(torch.as_tensor(v), torch.as_tensor(v2))
for v, v2 in zip(value, value_new)
)
else:
assert value == value_new

compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())

test_mat = torch.rand(inv_kfac.shape[1])
runame marked this conversation as resolved.
Show resolved Hide resolved
report_nonclose(inv_kfac @ test_mat, inv_kfac_new @ test_mat)
runame marked this conversation as resolved.
Show resolved Hide resolved


def test_KFAC_inverse_from_state_dict():
"""Test that KFACInverseLinearOperator can be created from state dict."""
torch.manual_seed(0)
batch_size, D_in, D_out = 4, 3, 2
X = torch.rand(batch_size, D_in)
y = torch.rand(batch_size, D_out)
model = torch.nn.Linear(D_in, D_out)

params = list(model.parameters())
# create and compute KFAC
kfac = KFACLinearOperator(
model,
MSELoss(reduction="sum"),
params,
[(X, y)],
loss_average=None,
)

# create inverse KFAC and save state dict
inv_kfac = KFACInverseLinearOperator(
kfac, damping=1e-2, use_heuristic_damping=True, retry_double_precision=False
)
state_dict = inv_kfac.state_dict()

# create new KFAC from state dict
inv_kfac_new = KFACInverseLinearOperator.from_state_dict(state_dict, kfac)

# check that the two inverse KFACs are equal
def compare_state_dicts(state_dict: dict, state_dict_new: dict):
assert len(state_dict) == len(state_dict_new)
for value, value_new in zip(state_dict.values(), state_dict_new.values()):
if isinstance(value, torch.Tensor):
assert torch.allclose(value, value_new)
elif isinstance(value, dict):
compare_state_dicts(value, value_new)
elif isinstance(value, tuple):
assert all(
torch.allclose(torch.as_tensor(v), torch.as_tensor(v2))
for v, v2 in zip(value, value_new)
)
else:
assert value == value_new

compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())

test_mat = torch.rand(kfac.shape[1])
report_nonclose(inv_kfac @ test_mat, inv_kfac_new @ test_mat)
runame marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading