From c6fe6845104d9926209e6a5f15513334230b0dc5 Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Mon, 23 Sep 2024 15:50:56 -0400
Subject: [PATCH] [ADD] Implement empirical Fisher as `CurvatureLinearOperator`

---
 curvlinops/gradient_moments.py           | 59 +++++++----------
 .../basic_usage/example_visual_tour.py   |  2 +-
 test/test_gradient_moments.py            | 63 ++++---------------
 test/test_kfac.py                        |  2 +-
 test/test_submatrix_on_curvatures.py     |  6 +-
 5 files changed, 41 insertions(+), 91 deletions(-)

diff --git a/curvlinops/gradient_moments.py b/curvlinops/gradient_moments.py
index af746cb..a420e36 100644
--- a/curvlinops/gradient_moments.py
+++ b/curvlinops/gradient_moments.py
@@ -1,6 +1,4 @@
-"""Contains LinearOperator implementation of gradient moment matrices."""
-
-from __future__ import annotations
+"""Contains linear operator implementation of gradient moment matrices."""
 
 from collections.abc import MutableMapping
 from typing import Callable, Iterable, List, Optional, Tuple, Union
@@ -11,11 +9,11 @@
 from torch.autograd import grad
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, Parameter
 
-from curvlinops._base import _LinearOperator
+from curvlinops._torch_base import CurvatureLinearOperator
 
 
-class EFLinearOperator(_LinearOperator):
-    r"""Uncentered gradient covariance as SciPy linear operator.
+class EFLinearOperator(CurvatureLinearOperator):
+    r"""Uncentered gradient covariance as PyTorch linear operator.
 
     The uncentered gradient covariance is often called 'empirical Fisher' (EF).
 
@@ -41,23 +39,24 @@ class EFLinearOperator(_LinearOperator):
             \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)
         \right)^\top\,.
 
-    .. note::
-        Multiplication with the empirical Fisher is currently implemented with an
-        inefficient for-loop.
+    Attributes:
+        SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for
+            empirical Fisher.
     """
 
     supported_losses = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)
+    SELF_ADJOINT: bool = True
 
     def __init__(
         self,
-        model_func: Callable[[Tensor], Tensor],
+        model_func: Callable[[Union[MutableMapping, Tensor]], Tensor],
         loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
         params: List[Parameter],
         data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
         progressbar: bool = False,
         check_deterministic: bool = True,
         num_data: Optional[int] = None,
-        batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
+        batch_size_fn: Optional[Callable[[Union[Tensor, MutableMapping]], int]] = None,
     ):
         """Linear operator for the uncentered gradient covariance/empirical Fisher (EF).
@@ -95,7 +94,7 @@ def __init__(
 
         Raises:
             NotImplementedError: If the loss function differs from ``MSELoss``,
-                BCEWithLogitsLoss, or ``CrossEntropyLoss``.
+                ``BCEWithLogitsLoss``, or ``CrossEntropyLoss``.
         """
         if not isinstance(loss_func, self.supported_losses):
             raise NotImplementedError(
@@ -113,21 +112,21 @@ def __init__(
         )
 
     def _matmat_batch(
-        self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
-    ) -> Tuple[Tensor, ...]:
-        """Apply the mini-batch empirical Fisher to a matrix.
+        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
+    ) -> List[Tensor]:
+        """Apply the mini-batch empirical Fisher to a matrix in tensor list format.
 
         Args:
             X: Input to the DNN.
             y: Ground truth.
-            M_list: Matrix to be multiplied with in list format.
+            M: Matrix to be multiplied with in tensor list format.
                 Tensors have same shape as trainable model parameters, and an
-                additional leading axis for the matrix columns.
+                additional trailing axis for the matrix columns.
 
         Returns:
-            Result of EF multiplication in list format. Has the same shape as
-            ``M_list``, i.e. each tensor in the list has the shape of a parameter and a
-            leading dimension of matrix columns.
+            Result of EF multiplication in tensor list format. Has the same shape as
+            ``M``, i.e. each tensor in the list has the shape of a parameter and a
+            trailing dimension of matrix columns.
         """
         output = self._model_func(X)
         # If >2d output we convert to an equivalent 2d output
@@ -163,24 +162,14 @@ def _matmat_batch(
         )
 
         # Multiply the EF onto each vector in the input matrix
-        result_list = [zeros_like(M) for M in M_list]
-        num_vectors = M_list[0].shape[0]
+        EM = [zeros_like(m) for m in M]
+        (num_vectors,) = {m.shape[-1] for m in M}
         for v in range(num_vectors):
             for idx, ggnvp in enumerate(
                 ggn_vector_product_from_plist(
-                    loss, output, self._params, [M[v] for M in M_list]
+                    loss, output, self._params, [m[..., v] for m in M]
                 )
             ):
-                result_list[idx][v].add_(ggnvp.detach())
-
-        return tuple(result_list)
-
-    def _adjoint(self) -> EFLinearOperator:
-        """Return the linear operator representing the adjoint.
+                EM[idx][..., v].add_(ggnvp.detach())
 
-        The empirical Fisher is real symmetric, and hence self-adjoint.
-
-        Returns:
-            Self.
-        """
-        return self
+        return EM
diff --git a/docs/examples/basic_usage/example_visual_tour.py b/docs/examples/basic_usage/example_visual_tour.py
index 0461d27..b79e03a 100644
--- a/docs/examples/basic_usage/example_visual_tour.py
+++ b/docs/examples/basic_usage/example_visual_tour.py
@@ -81,7 +81,7 @@
     model, loss_function, params, dataloader
 ).to_scipy()
 GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader).to_scipy()
-EF_linop = EFLinearOperator(model, loss_function, params, dataloader)
+EF_linop = EFLinearOperator(model, loss_function, params, dataloader).to_scipy()
 
 # %%
 #
diff --git a/test/test_gradient_moments.py b/test/test_gradient_moments.py
index 3ea5d5d..da3b6bf 100644
--- a/test/test_gradient_moments.py
+++ b/test/test_gradient_moments.py
@@ -1,65 +1,26 @@
 """Contains tests for ``curvlinops/gradient_moments.py``."""
 
-from collections.abc import MutableMapping
-
-from numpy import random
-from pytest import raises
+from test.utils import compare_matmat
 
 from curvlinops import EFLinearOperator
 from curvlinops.examples.functorch import functorch_empirical_fisher
-from curvlinops.examples.utils import report_nonclose
-
-
-def test_EFLinearOperator_matvec(case, adjoint: bool):
-    model_func, loss_func, params, data, batch_size_fn = case
-
-    # Test when X is dict-like but batch_size_fn = None (default)
-    if isinstance(data[0][0], MutableMapping):
-        with raises(ValueError):
-            op = EFLinearOperator(model_func, loss_func, params, data)
-
-    op = EFLinearOperator(
-        model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
-    op_functorch = (
-        functorch_empirical_fisher(
-            model_func,
-            loss_func,
-            params,
-            data,
-            input_key="x",
-        )
-        .detach()
-        .cpu()
-        .numpy()
-    )
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
-
-    x = random.rand(op.shape[1]).astype(op.dtype)
-    report_nonclose(op @ x, op_functorch @ x, atol=1e-5)
 
+def test_EFLinearOperator(case, adjoint: bool, is_vec: bool):
+    """Test matrix-matrix multiplication with the empirical Fisher.
 
-def test_EFLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
+    Args:
+        case: Tuple of model, loss function, parameters, data, and batch size getter.
+        adjoint: Whether to test the adjoint operator.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
+    """
     model_func, loss_func, params, data, batch_size_fn = case
-    op = EFLinearOperator(
+    E = EFLinearOperator(
         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
     )
-    op_functorch = (
-        functorch_empirical_fisher(
-            model_func,
-            loss_func,
-            params,
-            data,
-            input_key="x",
-        )
-        .detach()
-        .cpu()
-        .numpy()
+    E_mat = functorch_empirical_fisher(
+        model_func, loss_func, params, data, input_key="x"
     )
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    X = random.rand(op.shape[1], num_vecs).astype(op.dtype)
-    report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=1e-4)
+    compare_matmat(E, E_mat, adjoint, is_vec, rtol=1e-4, atol=1e-7)
diff --git a/test/test_kfac.py b/test/test_kfac.py
index e7f8d61..e6f4311 100644
--- a/test/test_kfac.py
+++ b/test/test_kfac.py
@@ -340,7 +340,7 @@ def test_kfac_ef_one_datum(
     for param in params:
         ef = EFLinearOperator(
             model, loss_func, [param], data, batch_size_fn=batch_size_fn
-        )
+        ).to_scipy()
         ef_blocks.append(ef @ eye(ef.shape[1]))
     ef = block_diag(*ef_blocks)
diff --git a/test/test_submatrix_on_curvatures.py b/test/test_submatrix_on_curvatures.py
index a8605eb..74c91b6 100644
--- a/test/test_submatrix_on_curvatures.py
+++ b/test/test_submatrix_on_curvatures.py
@@ -63,9 +63,9 @@ def setup_submatrix_linear_operator(case, operator_case, submatrix_case):
     row_idxs = submatrix_case["row_idx_fn"](dim)
     col_idxs = submatrix_case["col_idx_fn"](dim)
 
-    A = operator_case(model_func, loss_func, params, data, batch_size_fn=batch_size_fn)
-    if isinstance(A, (HessianLinearOperator, GGNLinearOperator)):
-        A = A.to_scipy()
+    A = operator_case(
+        model_func, loss_func, params, data, batch_size_fn=batch_size_fn
+    ).to_scipy()
 
     A_sub = SubmatrixLinearOperator(A, row_idxs, col_idxs)
     A_functorch = CURVATURE_IN_FUNCTORCH[operator_case](
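
Usage sketch (illustration only, not part of the patch): after this change,
``EFLinearOperator`` follows the ``CurvatureLinearOperator`` interface, so it
multiplies PyTorch tensors directly, and SciPy interoperability goes through
``to_scipy()``, which is why the example and tests above gained ``.to_scipy()``
calls. The model, data, and shapes below are invented to make the snippet
self-contained.

    from torch import manual_seed, rand
    from torch.nn import Linear, MSELoss, ReLU, Sequential

    from curvlinops import EFLinearOperator

    # Toy regression setup, invented for illustration
    manual_seed(0)
    X, y = rand(8, 4), rand(8, 2)
    model = Sequential(Linear(4, 3), ReLU(), Linear(3, 2))
    params = [p for p in model.parameters() if p.requires_grad]
    data = [(X, y)]

    # MSELoss is one of the operator's supported losses
    EF = EFLinearOperator(model, MSELoss(), params, data)

    # PyTorch-native multiplication: flat vectors, or matrices whose
    # columns live in a trailing axis (the convention of `_matmat_batch`)
    D = EF.shape[1]  # total number of parameters
    v, M = rand(D), rand(D, 3)
    EFv, EFM = EF @ v, EF @ M  # shapes (D,) and (D, 3)

    # Export for code that expects a SciPy LinearOperator, as in the
    # updated visual tour and tests
    EF_scipy = EF.to_scipy()
    EFv_scipy = EF_scipy @ v.numpy()

Internally, ``_matmat_batch`` receives and returns the matrix in tensor list
format, one tensor per parameter with a trailing column axis, which is what the
``EM``/``num_vectors`` bookkeeping in the hunk above implements.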