From 9722fa80d0741c0561bc1aa073928dfc1c47b3e4 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 31 Oct 2023 21:20:13 +0100 Subject: [PATCH] Add fisher_type argument and support empirical Fisher in KFAC --- curvlinops/kfac.py | 43 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index aeb7550..3ae02f2 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -80,6 +80,7 @@ def __init__( check_deterministic: bool = True, shape: Union[Tuple[int, int], None] = None, seed: int = 2147483647, + fisher_type: str = "mc", mc_samples: int = 1, ): """Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN. @@ -114,6 +115,14 @@ def __init__( from the parameters. Defaults to ``None``. seed: The seed for the random number generator used to draw labels from the model's predictive distribution. Defaults to ``2147483647``. + fisher_type: The type of Fisher/GGN to approximate. If ``'type-2'``, the + expectation over the model outputs is computed exactly by + computing the backwward pass for each output dimension. This is + sometimes also called type-2 Fisher. If ``'mc'``, the expectation + is approximated by sampling ``mc_samples`` labels from the model's + predictive distribution. If ``'empirical'``, the empirical gradients + are used which corresponds to the uncentered gradient covariance, or + the empirical Fisher. Defaults to ``'mc'``. mc_samples: The number of Monte-Carlo samples to use per data point. Defaults to ``1``. @@ -152,6 +161,7 @@ def __init__( self._seed = seed self._generator: Union[None, Generator] = None + self._fisher_type = fisher_type self._mc_samples = mc_samples self._input_covariances: Dict[Tuple[int, ...], Tensor] = {} self._gradient_covariances: Dict[Tuple[int, ...], Tensor] = {} @@ -210,7 +220,14 @@ def _adjoint(self) -> KFACLinearOperator: return self def _compute_kfac(self): - """Compute and cache KFAC's Kronecker factors for future ``matvec``s.""" + """Compute and cache KFAC's Kronecker factors for future ``matvec``s. + + Raises: + NotImplementedError: If ``fisher_type == 'type-2'``. + ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or + ``'empirical'``. + + """ # install forward and backward hooks hook_handles: List[RemovableHandle] = [] hook_handles.extend( @@ -231,13 +248,27 @@ def _compute_kfac(self): self._generator = Generator(device=self._device) self._generator.manual_seed(self._seed) - for X, _ in self._loop_over_data(desc="KFAC matrices"): + for X, y in self._loop_over_data(desc="KFAC matrices"): output = self._model_func(X) - for mc in range(self._mc_samples): - y_sampled = self.draw_label(output) - loss = self._loss_func(output, y_sampled) - loss.backward(retain_graph=mc != self._mc_samples - 1) + if self._fisher_type == "type-2": + raise NotImplementedError( + "Using the exact expectation for computing the KFAC " + "approximation of the Fisher is not yet supported." + ) + elif self._fisher_type == "mc": + for mc in range(self._mc_samples): + y_sampled = self.draw_label(output) + loss = self._loss_func(output, y_sampled) + loss.backward(retain_graph=mc != self._mc_samples - 1) + elif self._fisher_type == "empirical": + loss = self._loss_func(output, y) + loss.backward() + else: + raise ValueError( + f"Invalid fisher_type: {self._fisher_type}. " + + "Supported: 'type-2', 'mc', 'empirical'." + ) # clean up self._model_func.zero_grad()