diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 6ca86c0..e31d139 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -19,11 +19,12 @@ from einops import rearrange from numpy import ndarray -from torch import Generator, Tensor, cat, einsum, randn +from torch import Generator, Tensor, cat, einsum, randn, stack from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter from torch.utils.hooks import RemovableHandle from curvlinops._base import _LinearOperator +from curvlinops.kfac_utils import loss_hessian_matrix_sqrt class KFACLinearOperator(_LinearOperator): @@ -125,7 +126,7 @@ def __init__( used which corresponds to the uncentered gradient covariance, or the empirical Fisher. Defaults to ``'mc'``. mc_samples: The number of Monte-Carlo samples to use per data point. - Will be ignored when ``fisher_type`` is not ``'mc'``. + Has to be set to ``1`` when ``fisher_type != 'mc'``. Defaults to ``1``. separate_weight_and_bias: Whether to treat weights and biases separately. Defaults to ``True``. @@ -138,6 +139,11 @@ def __init__( raise ValueError( f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}." ) + if fisher_type != "mc" and mc_samples != 1: + raise ValueError( + f"Invalid mc_samples: {mc_samples}. " + "Only mc_samples=1 is supported for fisher_type != 'mc'." + ) self.param_ids = [p.data_ptr() for p in params] # mapping from tuples of parameter data pointers in a module to its name @@ -231,13 +237,7 @@ def _adjoint(self) -> KFACLinearOperator: return self def _compute_kfac(self): - """Compute and cache KFAC's Kronecker factors for future ``matvec``s. - - Raises: - NotImplementedError: If ``fisher_type == 'type-2'``. - ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or - ``'empirical'``. - """ + """Compute and cache KFAC's Kronecker factors for future ``matvec``s.""" # install forward and backward hooks hook_handles: List[RemovableHandle] = [] @@ -266,31 +266,70 @@ def _compute_kfac(self): for X, y in self._loop_over_data(desc="KFAC matrices"): output = self._model_func(X) - - if self._fisher_type == "type-2": - raise NotImplementedError( - "Using the exact expectation for computing the KFAC " - "approximation of the Fisher is not yet supported." - ) - elif self._fisher_type == "mc": - for mc in range(self._mc_samples): - y_sampled = self.draw_label(output) - loss = self._loss_func(output, y_sampled) - loss.backward(retain_graph=mc != self._mc_samples - 1) - elif self._fisher_type == "empirical": - loss = self._loss_func(output, y) - loss.backward() - else: - raise ValueError( - f"Invalid fisher_type: {self._fisher_type}. " - + "Supported: 'type-2', 'mc', 'empirical'." - ) + self._compute_loss_and_backward(output, y) # clean up self._model_func.zero_grad() for handle in hook_handles: handle.remove() + def _compute_loss_and_backward(self, output: Tensor, y: Tensor): + r"""Compute the loss and the backward pass(es) required for KFAC. + + Args: + output: The model's prediction + :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`. + y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`. + + Raises: + ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or + ``'empirical'``. + NotImplementedError: If ``fisher_type`` is ``'type-1'`` and the + output is not 2d. + """ + if self._fisher_type == "type-2": + if output.ndim != 2: + raise NotImplementedError( + "Type-2 Fisher not implemented for non-2d output." + ) + # Compute per-sample Hessian square root, then concatenate over samples. + # Result has shape `(batch_size, num_classes, num_classes)` + hessian_sqrts = stack( + [ + loss_hessian_matrix_sqrt(out.detach(), self._loss_func) + for out in output.split(1) + ] + ) + + # Fix scaling caused by the batch dimension + batch_size = output.shape[0] + reduction = self._loss_func.reduction + scale = {"sum": 1.0, "mean": 1.0 / batch_size}[reduction] + hessian_sqrts.mul_(scale) + + # For each column `c` of the matrix square root we need to backpropagate, + # but we can do this for all samples in parallel + num_cols = hessian_sqrts.shape[-1] + for c in range(num_cols): + batched_column = hessian_sqrts[:, :, c] + (output * batched_column).sum().backward(retain_graph=c < num_cols - 1) + + elif self._fisher_type == "mc": + for mc in range(self._mc_samples): + y_sampled = self.draw_label(output) + loss = self._loss_func(output, y_sampled) + loss.backward(retain_graph=mc != self._mc_samples - 1) + + elif self._fisher_type == "empirical": + loss = self._loss_func(output, y) + loss.backward() + + else: + raise ValueError( + f"Invalid fisher_type: {self._fisher_type}. " + + "Supported: 'type-2', 'mc', 'empirical'." + ) + def draw_label(self, output: Tensor) -> Tensor: r"""Draw a sample from the model's predictive distribution. @@ -393,6 +432,7 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor): ) batch_size = g.shape[0] + # self._mc_samples will be 1 if fisher_type != "mc" correction = { "sum": 1.0 / self._mc_samples, "mean": batch_size**2 / (self._N_data * self._mc_samples), diff --git a/curvlinops/kfac_utils.py b/curvlinops/kfac_utils.py new file mode 100644 index 0000000..5263246 --- /dev/null +++ b/curvlinops/kfac_utils.py @@ -0,0 +1,92 @@ +"""Utility functions related to KFAC.""" + +from math import sqrt +from typing import Union + +from torch import Tensor, diag, einsum, eye +from torch.nn import CrossEntropyLoss, MSELoss + + +def loss_hessian_matrix_sqrt( + output_one_datum: Tensor, loss_func: Union[MSELoss, CrossEntropyLoss] +) -> Tensor: + r"""Compute the loss function's matrix square root for a sample's output. + + Args: + output_one_datum: The model's prediction on a single datum. Has shape + ``[1, C]`` where ``C`` is the number of classes (outputs of the neural + network). + loss_func: The loss function. + + Returns: + The matrix square root + :math:`\mathbf{S}` of the Hessian. Has shape + ``[C, C]`` and satisfies the relation + + .. math:: + \mathbf{S} \mathbf{S}^\top + = + \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) + \in \mathbb{R}^{C \times C} + + where :math:`\mathbf{f} := f(\mathbf{x}) \in \mathbb{R}^C` is the model's + prediction on a single datum :math:`\mathbf{x}` and :math:`\mathbf{y}` is + the label. + + Note: + For :class:`torch.nn.MSELoss` (with :math:`c = 1` for ``reduction='sum'`` + and :math:`c = 1/C` for ``reduction='mean'``), we have: + + .. math:: + \ell(\mathbf{f}) &= c \sum_{i=1}^C (f_i - y_i)^2 + \\ + \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= 2 c \mathbf{I}_C + \\ + \mathbf{S} &= \sqrt{2 c} \mathbf{I}_C + + Note: + For :class:`torch.nn.CrossEntropyLoss` (with :math:`c = 1` irrespective of the + reduction, :math:`\mathbf{p}:=\mathrm{softmax}(\mathbf{f}) \in \mathbb{R}^C`, + and the element-wise natural logarithm :math:`\log`) we have: + + .. math:: + \ell(\mathbf{f}, y) = - c \log(\mathbf{p})^\top \mathrm{onehot}(y) + \\ + \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y) + = + c \left( + \mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top + \right) + \\ + \mathbf{S} = \sqrt{c} \left( + \mathrm{diag}(\sqrt{\mathbf{p}}) - \sqrt{\mathbf{p}} \mathbf{p}^\top + \right)\,, + + where the square root is applied element-wise. See for instance Example 5.1 of + `this thesis `_ or equations (5) and (6) of + `this paper `_. + + Raises: + ValueError: If the batch size is not one, or the output is not 2d. + NotImplementedError: If the loss function is not supported. + """ + if output_one_datum.ndim != 2 or output_one_datum.shape[0] != 1: + raise ValueError( + f"Expected 'output_one_datum' to be 2d with shape [1, C], got " + f"{output_one_datum.shape}" + ) + output = output_one_datum.squeeze(0) + output_dim = output.numel() + + if isinstance(loss_func, MSELoss): + c = {"sum": 1.0, "mean": 1.0 / output_dim}[loss_func.reduction] + return eye(output_dim, device=output.device, dtype=output.dtype).mul_( + sqrt(2 * c) + ) + elif isinstance(loss_func, CrossEntropyLoss): + c = 1.0 + p = output_one_datum.softmax(dim=1).squeeze() + p_sqrt = p.sqrt() + return (diag(p_sqrt) - einsum("i,j->ij", p, p_sqrt)).mul_(sqrt(c)) + else: + raise NotImplementedError(f"Loss function {loss_func} not supported.") diff --git a/docs/rtd/index.rst b/docs/rtd/index.rst index 4d483e6..87569da 100644 --- a/docs/rtd/index.rst +++ b/docs/rtd/index.rst @@ -40,3 +40,8 @@ Installation linops basic_usage/index + +.. toctree:: + :caption: Internals + + internals diff --git a/docs/rtd/internals.rst b/docs/rtd/internals.rst new file mode 100644 index 0000000..f6d9b4f --- /dev/null +++ b/docs/rtd/internals.rst @@ -0,0 +1,11 @@ +Internals +============ + +This section is for internal purposes only and serves to inform developers about +details; because rendered LaTeX is easier to read than source code. + + +KFAC-related +------------- + +.. autofunction:: curvlinops.kfac_utils.loss_hessian_matrix_sqrt diff --git a/test/conftest.py b/test/conftest.py index 9097d4c..10f0617 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -86,17 +86,3 @@ def kfac_expand_exact_one_datum_case( """ case = request.param yield initialize_case(case) - - -@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES) -def kfac_ef_exact_one_datum_case( - request, -) -> Tuple[Module, MSELoss, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]: - """Prepare a test case with one datum for which KFAC with empirical gradients equals the EF. - - Yields: - A neural network, the mean-squared error function, a list of parameters, and - a data set. - """ - case = request.param - yield initialize_case(case) diff --git a/test/test_kfac.py b/test/test_kfac.py index 0769645..e2d4b18 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -8,7 +8,15 @@ from pytest import mark from scipy.linalg import block_diag from torch import Tensor, device, manual_seed, rand, randperm -from torch.nn import Linear, Module, MSELoss, Parameter, ReLU, Sequential +from torch.nn import ( + CrossEntropyLoss, + Linear, + Module, + MSELoss, + Parameter, + ReLU, + Sequential, +) from curvlinops.examples.utils import report_nonclose from curvlinops.gradient_moments import EFLinearOperator @@ -22,7 +30,7 @@ "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) -def test_kfac( +def test_kfac_type2( kfac_expand_exact_case: Tuple[ Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ], @@ -59,30 +67,71 @@ def test_kfac( data, separate_weight_and_bias=separate_weight_and_bias, ) - kfac = KFACLinearOperator( model, loss_func, params, data, - mc_samples=2_000, + fisher_type="type-2", separate_weight_and_bias=separate_weight_and_bias, ) kfac_mat = kfac @ eye(kfac.shape[1]) - atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction] - rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction] - - report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) + report_nonclose(ggn, kfac_mat) # Check that input covariances were not computed if exclude == "weight": assert len(kfac._input_covariances) == 0 +@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +def test_kfac_mc( + kfac_expand_exact_case: Tuple[ + Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] + ], + shuffle: bool, +): + """Test the KFAC implementation using MC samples against the exact GGN. + + Args: + kfac_expand_exact_case: A fixture that returns a model, loss function, list of + parameters, and data. + shuffle: Whether to shuffle the parameters before computing the KFAC matrix. + """ + model, loss_func, params, data = kfac_expand_exact_case + + if shuffle: + permutation = randperm(len(params)) + params = [params[i] for i in permutation] + + ggn = ggn_block_diagonal(model, loss_func, params, data) + kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000) + + kfac_mat = kfac @ eye(kfac.shape[1]) + + atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction] + rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction] + + report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) + + def test_kfac_one_datum( kfac_expand_exact_one_datum_case: Tuple[ - Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] + Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] + ] +): + model, loss_func, params, data = kfac_expand_exact_one_datum_case + + ggn = ggn_block_diagonal(model, loss_func, params, data) + kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2") + kfac_mat = kfac @ eye(kfac.shape[1]) + + report_nonclose(ggn, kfac_mat) + + +def test_kfac_mc_one_datum( + kfac_expand_exact_one_datum_case: Tuple[ + Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ] ): model, loss_func, params, data = kfac_expand_exact_one_datum_case @@ -98,11 +147,11 @@ def test_kfac_one_datum( def test_kfac_ef_one_datum( - kfac_ef_exact_one_datum_case: Tuple[ - Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] + kfac_expand_exact_one_datum_case: Tuple[ + Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ] ): - model, loss_func, params, data = kfac_ef_exact_one_datum_case + model, loss_func, params, data = kfac_expand_exact_one_datum_case ef_blocks = [] # list of per-parameter EFs for param in params: