Skip to content

Commit

Permalink
Add fisher_type argument and support empirical Fisher in KFAC
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Oct 31, 2023
1 parent 8d3e44c commit 9722fa8
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
check_deterministic: bool = True,
shape: Union[Tuple[int, int], None] = None,
seed: int = 2147483647,
fisher_type: str = "mc",
mc_samples: int = 1,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
Expand Down Expand Up @@ -114,6 +115,14 @@ def __init__(
from the parameters. Defaults to ``None``.
seed: The seed for the random number generator used to draw labels
from the model's predictive distribution. Defaults to ``2147483647``.
fisher_type: The type of Fisher/GGN to approximate. If ``'type-2'``, the
expectation over the model outputs is computed exactly by
computing the backwward pass for each output dimension. This is
sometimes also called type-2 Fisher. If ``'mc'``, the expectation
is approximated by sampling ``mc_samples`` labels from the model's
predictive distribution. If ``'empirical'``, the empirical gradients
are used which corresponds to the uncentered gradient covariance, or
the empirical Fisher. Defaults to ``'mc'``.
mc_samples: The number of Monte-Carlo samples to use per data point.
Defaults to ``1``.
Expand Down Expand Up @@ -152,6 +161,7 @@ def __init__(

self._seed = seed
self._generator: Union[None, Generator] = None
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._input_covariances: Dict[Tuple[int, ...], Tensor] = {}
self._gradient_covariances: Dict[Tuple[int, ...], Tensor] = {}
Expand Down Expand Up @@ -210,7 +220,14 @@ def _adjoint(self) -> KFACLinearOperator:
return self

def _compute_kfac(self):
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s.
Raises:
NotImplementedError: If ``fisher_type == 'type-2'``.
ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
``'empirical'``.
"""
# install forward and backward hooks
hook_handles: List[RemovableHandle] = []
hook_handles.extend(
Expand All @@ -231,13 +248,27 @@ def _compute_kfac(self):
self._generator = Generator(device=self._device)
self._generator.manual_seed(self._seed)

for X, _ in self._loop_over_data(desc="KFAC matrices"):
for X, y in self._loop_over_data(desc="KFAC matrices"):
output = self._model_func(X)

for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)
loss.backward(retain_graph=mc != self._mc_samples - 1)
if self._fisher_type == "type-2":
raise NotImplementedError(
"Using the exact expectation for computing the KFAC "
"approximation of the Fisher is not yet supported."
)
elif self._fisher_type == "mc":
for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)
loss.backward(retain_graph=mc != self._mc_samples - 1)
elif self._fisher_type == "empirical":
loss = self._loss_func(output, y)
loss.backward()
else:
raise ValueError(
f"Invalid fisher_type: {self._fisher_type}. "
+ "Supported: 'type-2', 'mc', 'empirical'."
)

# clean up
self._model_func.zero_grad()
Expand Down

0 comments on commit 9722fa8

Please sign in to comment.