
[ADD] Implement empirical Fisher as CurvatureLinearOperator
f-dangel committed Sep 23, 2024
1 parent 87c6eba commit c6fe684
Showing 5 changed files with 41 additions and 91 deletions.
59 changes: 24 additions & 35 deletions curvlinops/gradient_moments.py
@@ -1,6 +1,4 @@
-"""Contains LinearOperator implementation of gradient moment matrices."""
-
-from __future__ import annotations
+"""Contains linear operator implementation of gradient moment matrices."""
 
 from collections.abc import MutableMapping
 from typing import Callable, Iterable, List, Optional, Tuple, Union
@@ -11,11 +9,11 @@
 from torch.autograd import grad
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, Parameter
 
-from curvlinops._base import _LinearOperator
+from curvlinops._torch_base import CurvatureLinearOperator
 
 
-class EFLinearOperator(_LinearOperator):
-    r"""Uncentered gradient covariance as SciPy linear operator.
+class EFLinearOperator(CurvatureLinearOperator):
+    r"""Uncentered gradient covariance as PyTorch linear operator.
 
     The uncentered gradient covariance is often called 'empirical Fisher' (EF).
@@ -41,23 +39,24 @@ class EFLinearOperator(_LinearOperator):
             \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)
         \right)^\top\,.
 
     .. note::
         Multiplication with the empirical Fisher is currently implemented with an
         inefficient for-loop.
+
+    Attributes:
+        SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for
+            empirical Fisher.
     """
 
     supported_losses = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)
+    SELF_ADJOINT: bool = True
 
     def __init__(
         self,
-        model_func: Callable[[Tensor], Tensor],
+        model_func: Callable[[Union[MutableMapping, Tensor]], Tensor],
         loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
         params: List[Parameter],
         data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
         progressbar: bool = False,
         check_deterministic: bool = True,
         num_data: Optional[int] = None,
-        batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
+        batch_size_fn: Optional[Callable[[Union[Tensor, MutableMapping]], int]] = None,
     ):
         """Linear operator for the uncentered gradient covariance/empirical Fisher (EF).
@@ -95,7 +94,7 @@ def __init__(
         Raises:
             NotImplementedError: If the loss function differs from ``MSELoss``,
-                BCEWithLogitsLoss, or ``CrossEntropyLoss``.
+                ``BCEWithLogitsLoss``, or ``CrossEntropyLoss``.
         """
         if not isinstance(loss_func, self.supported_losses):
             raise NotImplementedError(
@@ -113,21 +112,21 @@ def __init__(
         )
 
     def _matmat_batch(
-        self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
-    ) -> Tuple[Tensor, ...]:
-        """Apply the mini-batch empirical Fisher to a matrix.
+        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
+    ) -> List[Tensor]:
+        """Apply the mini-batch empirical Fisher to a matrix in tensor list format.
 
         Args:
             X: Input to the DNN.
             y: Ground truth.
-            M_list: Matrix to be multiplied with in list format.
+            M: Matrix to be multiplied with in tensor list format.
                 Tensors have same shape as trainable model parameters, and an
-                additional leading axis for the matrix columns.
+                additional trailing axis for the matrix columns.
 
         Returns:
-            Result of EF multiplication in list format. Has the same shape as
-            ``M_list``, i.e. each tensor in the list has the shape of a parameter and a
-            leading dimension of matrix columns.
+            Result of EF multiplication in tensor list format. Has the same shape as
+            ``M``, i.e. each tensor in the list has the shape of a parameter and a
+            trailing dimension of matrix columns.
         """
         output = self._model_func(X)
         # If >2d output we convert to an equivalent 2d output
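
(Aside: the switch from a leading to a trailing column axis refers to the "tensor list format" named in the new docstring. A minimal sketch of that layout, with illustrative shapes and names that are not taken from the diff:

from torch import cat, randn

# Two illustrative parameters and K = 3 matrix columns.
params = [randn(4, 2), randn(4)]
K = 3
# Tensor list format: one tensor per parameter, plus a trailing column axis.
M = [randn(*p.shape, K) for p in params]  # shapes (4, 2, 3) and (4, 3)
# The equivalent flat matrix stacks the flattened per-parameter blocks.
M_flat = cat([m.reshape(-1, K) for m in M])  # shape (12, 3), since D = 4 * 2 + 4

Each column of M_flat corresponds to slicing index v along the trailing axis of every tensor in M.)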
@@ -163,24 +162,14 @@ def _matmat_batch(
         )
 
         # Multiply the EF onto each vector in the input matrix
-        result_list = [zeros_like(M) for M in M_list]
-        num_vectors = M_list[0].shape[0]
+        EM = [zeros_like(m) for m in M]
+        (num_vectors,) = {m.shape[-1] for m in M}
         for v in range(num_vectors):
             for idx, ggnvp in enumerate(
                 ggn_vector_product_from_plist(
-                    loss, output, self._params, [M[v] for M in M_list]
+                    loss, output, self._params, [m[..., v] for m in M]
                 )
             ):
-                result_list[idx][v].add_(ggnvp.detach())
-
-        return tuple(result_list)
-
-    def _adjoint(self) -> EFLinearOperator:
-        """Return the linear operator representing the adjoint.
-
-        The empirical Fisher is real symmetric, and hence self-adjoint.
-
-        Returns:
-            Self.
-        """
-        return self
+                EM[idx][..., v].add_(ggnvp.detach())
+
+        return EM
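
(Aside: one detail worth noting in the new loop is `(num_vectors,) = {m.shape[-1] for m in M}`, which unpacks a one-element set. This both reads the column count and asserts that it agrees across all tensors in the list. The idiom in isolation:

from torch import randn

M = [randn(4, 2, 3), randn(4, 3)]
# The set comprehension collects all trailing dimensions; unpacking into a
# single variable raises ValueError unless every entry agrees.
(num_vectors,) = {m.shape[-1] for m in M}  # num_vectors == 3

The explicit `_adjoint` method becomes redundant because `SELF_ADJOINT = True` now declares the symmetry to the `CurvatureLinearOperator` base class.)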
2 changes: 1 addition & 1 deletion docs/examples/basic_usage/example_visual_tour.py
@@ -81,7 +81,7 @@
     model, loss_function, params, dataloader
 ).to_scipy()
 GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader).to_scipy()
-EF_linop = EFLinearOperator(model, loss_function, params, dataloader)
+EF_linop = EFLinearOperator(model, loss_function, params, dataloader).to_scipy()
 
 # %%
 #
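
(Aside: since `EFLinearOperator` is now a PyTorch operator, the visual tour converts it with `.to_scipy()` like the Hessian and GGN operators. The payoff, sketched here for illustration under the example's context, is that the converted operator is a `scipy.sparse.linalg.LinearOperator` and plugs into SciPy's matrix-free routines:

from scipy.sparse.linalg import eigsh

# Illustrative only: three largest EF eigenvalues, computed matrix-free.
top_eigvals = eigsh(EF_linop, k=3, which="LM", return_eigenvectors=False)
)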
63 changes: 12 additions & 51 deletions test/test_gradient_moments.py
@@ -1,65 +1,26 @@
 """Contains tests for ``curvlinops/gradient_moments.py``."""
 
-from collections.abc import MutableMapping
-
-from numpy import random
-from pytest import raises
+from test.utils import compare_matmat
 
 from curvlinops import EFLinearOperator
 from curvlinops.examples.functorch import functorch_empirical_fisher
-from curvlinops.examples.utils import report_nonclose
 
 
-def test_EFLinearOperator_matvec(case, adjoint: bool):
-    model_func, loss_func, params, data, batch_size_fn = case
-
-    # Test when X is dict-like but batch_size_fn = None (default)
-    if isinstance(data[0][0], MutableMapping):
-        with raises(ValueError):
-            op = EFLinearOperator(model_func, loss_func, params, data)
-
-    op = EFLinearOperator(
-        model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
-    op_functorch = (
-        functorch_empirical_fisher(
-            model_func,
-            loss_func,
-            params,
-            data,
-            input_key="x",
-        )
-        .detach()
-        .cpu()
-        .numpy()
-    )
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
-
-    x = random.rand(op.shape[1]).astype(op.dtype)
-    report_nonclose(op @ x, op_functorch @ x, atol=1e-5)
-
-
-def test_EFLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
+def test_EFLinearOperator(case, adjoint: bool, is_vec: bool):
+    """Test matrix-matrix multiplication with the empirical Fisher.
+
+    Args:
+        case: Tuple of model, loss function, parameters, data, and batch size getter.
+        adjoint: Whether to test the adjoint operator.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
+    """
     model_func, loss_func, params, data, batch_size_fn = case
 
-    op = EFLinearOperator(
+    E = EFLinearOperator(
         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
     )
-    op_functorch = (
-        functorch_empirical_fisher(
-            model_func,
-            loss_func,
-            params,
-            data,
-            input_key="x",
-        )
-        .detach()
-        .cpu()
-        .numpy()
+    E_mat = functorch_empirical_fisher(
+        model_func, loss_func, params, data, input_key="x"
     )
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
-
-    X = random.rand(op.shape[1], num_vecs).astype(op.dtype)
-    report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=1e-4)
+    compare_matmat(E, E_mat, adjoint, is_vec, rtol=1e-4, atol=1e-7)
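
(Aside: the two near-duplicate tests are consolidated behind `compare_matmat` from `test.utils`. Its real implementation lives in the repository; as a rough mental model only, with names and behavior assumed rather than shown in this diff, such a helper might do:

from torch import allclose, rand

def compare_matmat_sketch(op, mat, adjoint, is_vec, rtol=1e-5, atol=1e-8):
    """Hypothetical stand-in for test.utils.compare_matmat."""
    if adjoint:
        op, mat = op.adjoint(), mat.conj().T
    # Draw a random vector or matrix and compare both multiplication paths.
    X = rand(op.shape[1]) if is_vec else rand(op.shape[1], 3)
    assert allclose(op @ X, mat @ X, rtol=rtol, atol=atol)
)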
2 changes: 1 addition & 1 deletion test/test_kfac.py
@@ -340,7 +340,7 @@ def test_kfac_ef_one_datum(
     for param in params:
         ef = EFLinearOperator(
             model, loss_func, [param], data, batch_size_fn=batch_size_fn
-        )
+        ).to_scipy()
         ef_blocks.append(ef @ eye(ef.shape[1]))
     ef = block_diag(*ef_blocks)
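
(Aside: the surrounding test, unchanged here, materializes each per-parameter EF block by multiplying the operator with the identity, then assembles the blocks into one block-diagonal matrix. The pattern in isolation, with a hypothetical list of operators:

from numpy import eye
from scipy.linalg import block_diag

def densify(op):
    # Column i of (op @ I) is op applied to the i-th standard basis vector,
    # so multiplying with the identity materializes a matrix-free operator.
    return op @ eye(op.shape[1])

# full = block_diag(*[densify(block) for block in blocks])  # blocks: hypothetical
)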
6 changes: 3 additions & 3 deletions test/test_submatrix_on_curvatures.py
@@ -63,9 +63,9 @@ def setup_submatrix_linear_operator(case, operator_case, submatrix_case):
     row_idxs = submatrix_case["row_idx_fn"](dim)
     col_idxs = submatrix_case["col_idx_fn"](dim)
 
-    A = operator_case(model_func, loss_func, params, data, batch_size_fn=batch_size_fn)
-    if isinstance(A, (HessianLinearOperator, GGNLinearOperator)):
-        A = A.to_scipy()
+    A = operator_case(
+        model_func, loss_func, params, data, batch_size_fn=batch_size_fn
+    ).to_scipy()
     A_sub = SubmatrixLinearOperator(A, row_idxs, col_idxs)
 
     A_functorch = CURVATURE_IN_FUNCTORCH[operator_case](
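
(Aside: `SubmatrixLinearOperator`, untouched by this commit, wraps an operator and exposes `A[row_idxs][:, col_idxs]` without densifying `A`. A conceptual sketch of how such a wrapper can realize products, with assumed logic rather than the repo's actual code:

from numpy import zeros

def submatrix_matvec(A, row_idxs, col_idxs, x):
    """Hypothetical sketch: apply A[row_idxs, :][:, col_idxs] @ x lazily."""
    v = zeros(A.shape[1])
    v[col_idxs] = x             # embed x into the full column space
    return (A @ v)[row_idxs]    # apply the full operator, keep selected rows
)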
