[ADD] Support for empirical Fisher in KFAC (#54)
* Add test for one datum KFAC EF exactness

* Add fisher_type argument and support empirical Fisher in KFAC

* Minor docstring fixes

* Clarify fisher_type docstring
runame authored Nov 7, 2023
1 parent b3d7463 commit 238a7e4
Showing 3 changed files with 72 additions and 6 deletions.
44 changes: 38 additions & 6 deletions curvlinops/kfac.py
@@ -80,6 +80,7 @@ def __init__(
check_deterministic: bool = True,
shape: Union[Tuple[int, int], None] = None,
seed: int = 2147483647,
fisher_type: str = "mc",
mc_samples: int = 1,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
@@ -114,7 +115,17 @@ def __init__(
from the parameters. Defaults to ``None``.
seed: The seed for the random number generator used to draw labels
from the model's predictive distribution. Defaults to ``2147483647``.
fisher_type: The type of Fisher/GGN to approximate. If ``'type-2'``, the
exact Hessian of the loss w.r.t. the model outputs is used. This
requires as many backward passes as the output dimension, i.e. the
number of classes for classification, and is sometimes called the
type-2 Fisher. If ``'mc'``, the expectation is approximated by
sampling ``mc_samples`` labels from the model's predictive
distribution. If ``'empirical'``, the gradients w.r.t. the labels
from the data are used, which corresponds to the uncentered gradient
covariance, also known as the empirical Fisher. Defaults to ``'mc'``.
mc_samples: The number of Monte-Carlo samples to use per data point.
Will be ignored when ``fisher_type`` is not ``'mc'``.
Defaults to ``1``.
Raises:
@@ -152,6 +163,7 @@ def __init__(

self._seed = seed
self._generator: Union[None, Generator] = None
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._input_covariances: Dict[Tuple[int, ...], Tensor] = {}
self._gradient_covariances: Dict[Tuple[int, ...], Tensor] = {}
@@ -210,7 +222,13 @@ def _adjoint(self) -> KFACLinearOperator:
return self

def _compute_kfac(self):
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s.
Raises:
NotImplementedError: If ``fisher_type == 'type-2'``.
ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
``'empirical'``.
"""
# install forward and backward hooks
hook_handles: List[RemovableHandle] = []
hook_handles.extend(
@@ -231,13 +249,27 @@ def _compute_kfac(self):
self._generator = Generator(device=self._device)
self._generator.manual_seed(self._seed)

for X, _ in self._loop_over_data(desc="KFAC matrices"):
for X, y in self._loop_over_data(desc="KFAC matrices"):
output = self._model_func(X)

for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)
loss.backward(retain_graph=mc != self._mc_samples - 1)
if self._fisher_type == "type-2":
raise NotImplementedError(
"Using the exact expectation for computing the KFAC "
"approximation of the Fisher is not yet supported."
)
elif self._fisher_type == "mc":
for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)
loss.backward(retain_graph=mc != self._mc_samples - 1)
elif self._fisher_type == "empirical":
loss = self._loss_func(output, y)
loss.backward()
else:
raise ValueError(
f"Invalid fisher_type: {self._fisher_type}. "
+ "Supported: 'type-2', 'mc', 'empirical'."
)

# clean up
self._model_func.zero_grad()
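For orientation, here is a minimal usage sketch of the new argument (not part of the commit). The toy model, data, and loss below are hypothetical stand-ins; the constructor call mirrors the one in the new test further down.

from numpy import eye
from torch import manual_seed, rand
from torch.nn import Linear, MSELoss, ReLU, Sequential

from curvlinops.kfac import KFACLinearOperator

manual_seed(0)

# hypothetical toy regression problem: 8 data points, 5 features, 3 outputs
data = [(rand(8, 5), rand(8, 3))]
model = Sequential(Linear(5, 4), ReLU(), Linear(4, 3))
loss_func = MSELoss()
params = [p for p in model.parameters() if p.requires_grad]

# default: Monte-Carlo approximation, sampling labels from the
# model's predictive distribution
kfac_mc = KFACLinearOperator(
    model, loss_func, params, data, fisher_type="mc", mc_samples=1
)

# new: empirical Fisher, using gradients w.r.t. the data set's labels
kfac_ef = KFACLinearOperator(
    model, loss_func, params, data, fisher_type="empirical"
)

# materialize as a dense matrix by multiplying onto an identity
kfac_ef_mat = kfac_ef @ eye(kfac_ef.shape[1])

Passing fisher_type="type-2" currently raises a NotImplementedError, and any other value raises a ValueError, as implemented in _compute_kfac above.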
14 changes: 14 additions & 0 deletions test/conftest.py
@@ -86,3 +86,17 @@ def kfac_expand_exact_one_datum_case(
"""
case = request.param
yield initialize_case(case)


@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES)
def kfac_ef_exact_one_datum_case(
request,
) -> Tuple[Module, MSELoss, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]:
"""Prepare a test case with one datum for which KFAC with empirical gradients equals the EF.
Yields:
A neural network, the mean-squared error function, a list of parameters, and
a data set.
"""
case = request.param
yield initialize_case(case)
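For reference (not part of the commit), the quantity the fixture calls the EF can be written as follows, up to the normalization implied by the loss reduction. The empirical Fisher is the uncentered covariance of per-datum gradients evaluated at the data set's labels, whereas fisher_type='mc' evaluates them at labels sampled from the model's predictive distribution:

F_{\mathrm{emp}} = \frac{1}{N} \sum_{n=1}^{N}
    \nabla_\theta \ell(f_\theta(x_n), y_n)\,
    \nabla_\theta \ell(f_\theta(x_n), y_n)^\top

F_{\mathrm{mc}} = \frac{1}{N M} \sum_{n=1}^{N} \sum_{m=1}^{M}
    \nabla_\theta \ell(f_\theta(x_n), \tilde{y}_{n,m})\,
    \nabla_\theta \ell(f_\theta(x_n), \tilde{y}_{n,m})^\top,
\qquad \tilde{y}_{n,m} \sim p(y \mid f_\theta(x_n))

Per the fixture's docstring, for a single datum the KFAC approximation with empirical gradients equals the EF per parameter block, which the new test below checks against EFLinearOperator.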
20 changes: 20 additions & 0 deletions test/test_kfac.py
@@ -10,6 +10,7 @@

from curvlinops.examples.utils import report_nonclose
from curvlinops.ggn import GGNLinearOperator
from curvlinops.gradient_moments import EFLinearOperator
from curvlinops.kfac import KFACLinearOperator


@@ -68,3 +69,22 @@ def test_kfac_one_datum(
rtol = {"sum": 3e-2, "mean": 3e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)


def test_kfac_ef_one_datum(
kfac_ef_exact_one_datum_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_ef_exact_one_datum_case

ef_blocks = [] # list of per-parameter EFs
for param in params:
ef = EFLinearOperator(model, loss_func, [param], data)
ef_blocks.append(ef @ eye(ef.shape[1]))
ef = block_diag(*ef_blocks)

kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="empirical")
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ef, kfac_mat)
