[ADD] Implement GGN as CurvatureLinearOperator
f-dangel committed Sep 22, 2024
1 parent 9c538ea commit 3a8ac95
Showing 13 changed files with 55 additions and 79 deletions.
47 changes: 21 additions & 26 deletions curvlinops/ggn.py
@@ -8,11 +8,11 @@
 from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
 from torch import Tensor, zeros_like
 
-from curvlinops._base import _LinearOperator
+from curvlinops._torch_base import CurvatureLinearOperator
 
 
-class GGNLinearOperator(_LinearOperator):
-    r"""GGN as SciPy linear operator.
+class GGNLinearOperator(CurvatureLinearOperator):
+    r"""Linear operator for the generalized Gauss-Newton matrix of an empirical risk.
 
     Consider the empirical risk
@@ -39,47 +39,42 @@ class GGNLinearOperator(_LinearOperator):
             \mathbf{J}_{\mathbf{\theta}}
             f_{\mathbf{\theta}}(\mathbf{x}_n)
         \right)\,.
+
+    Attributes:
+        SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for GGNs.
     """
 
+    SELF_ADJOINT: bool = True
+
     def _matmat_batch(
-        self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
-    ) -> Tuple[Tensor, ...]:
+        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
+    ) -> List[Tensor]:
         """Apply the mini-batch GGN to a matrix.
 
         Args:
             X: Input to the DNN.
             y: Ground truth.
-            M_list: Matrix to be multiplied with in list format.
+            M: Matrix to be multiplied with in tensor list format.
                 Tensors have same shape as trainable model parameters, and an
-                additional leading axis for the matrix columns.
+                additional trailing axis for the matrix columns.
 
         Returns:
             Result of GGN multiplication in list format. Has the same shape as
-            ``M_list``, i.e. each tensor in the list has the shape of a parameter and a
-            leading dimension of matrix columns.
+            ``M``, i.e. each tensor in the list has the shape of a parameter and a
+            trailing dimension of matrix columns.
        """
         output = self._model_func(X)
         loss = self._loss_func(output, y)
 
-        # collect matrix-matrix products per parameter
-        result_list = [zeros_like(M) for M in M_list]
-
-        num_vecs = M_list[0].shape[0]
+        (num_vecs,) = {m.shape[-1] for m in M}
+        GM = [zeros_like(m) for m in M]
+
         for n in range(num_vecs):
-            col_n_list = ggn_vector_product_from_plist(
-                loss, output, self._params, [M[n] for M in M_list]
+            col_n = ggn_vector_product_from_plist(
+                loss, output, self._params, [m[..., n] for m in M]
             )
-            for result, col_n in zip(result_list, col_n_list):
-                result[n].add_(col_n)
-
-        return tuple(result_list)
-
-    def _adjoint(self) -> GGNLinearOperator:
-        """Return the linear operator representing the adjoint.
-
-        The GGN is real symmetric, and hence self-adjoint.
-
-        Returns:
-            Self.
-        """
-        return self
+            for GM_p, col_n_p in zip(GM, col_n):
+                GM_p[..., n].add_(col_n_p)
+
+        return GM
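
Note: the core semantic change above is the matrix format, which moves from a leading to a trailing column axis. Below is a minimal standalone sketch of that tensor-list convention with a hypothetical toy model; only the `(num_vecs,) = {...}` idiom and the `[m[..., n] for m in M]` column slicing are taken from the diff.

from torch import nn, rand

# One tensor per parameter, shaped like that parameter plus a trailing
# axis for the matrix columns (the convention assumed by _matmat_batch).
model = nn.Linear(4, 3)
params = [p for p in model.parameters() if p.requires_grad]
num_vecs = 2

M = [rand(*p.shape, num_vecs) for p in params]  # shapes (3, 4, 2) and (3, 2)

# The set-unpacking from the diff infers the column count and implicitly
# asserts that it is consistent across all tensors in the list.
(inferred,) = {m.shape[-1] for m in M}
assert inferred == num_vecs

# Column n in list format, as handed to the per-column GGN-vector product.
col_0 = [m[..., 0] for m in M]
assert all(c.shape == p.shape for c, p in zip(col_0, params))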
10 changes: 5 additions & 5 deletions curvlinops/hessian.py
@@ -65,20 +65,20 @@ def _matmat_batch(
         grad_params = grad(loss, self._params, create_graph=True)
 
         (num_vecs,) = {m.shape[-1] for m in M}
-        AM = [zeros_like(m) for m in M]
+        HM = [zeros_like(m) for m in M]
 
         # per-block HMP
-        for M_block, p_block, g_block, AM_block in zip(
+        for M_block, p_block, g_block, HM_block in zip(
             split_list(M, self._block_sizes),
             split_list(self._params, self._block_sizes),
             split_list(grad_params, self._block_sizes),
-            split_list(AM, self._block_sizes),
+            split_list(HM, self._block_sizes),
         ):
             for n in range(num_vecs):
                 col_n = hessian_vector_product(
                     loss, p_block, [M[..., n] for M in M_block], grad_params=g_block
                 )
                 for p, col in enumerate(col_n):
-                    AM_block[p][..., n].add_(col)
+                    HM_block[p][..., n].add_(col)
 
-        return AM
+        return HM
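
The per-block loop above relies on `split_list` to partition a tensor list into consecutive blocks. A hedged re-implementation of the semantics this code assumes (the real helper ships with curvlinops and may differ in details):

from typing import List

def split_list_sketch(lst: List, sizes: List[int]) -> List[List]:
    """Split ``lst`` into consecutive chunks with the given lengths."""
    assert sum(sizes) == len(lst)
    out, start = [], 0
    for size in sizes:
        out.append(lst[start : start + size])
        start += size
    return out

# Two blocks: the first covering two parameter tensors, the second one.
assert split_list_sketch(["W1", "b1", "W2"], [2, 1]) == [["W1", "b1"], ["W2"]]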
2 changes: 1 addition & 1 deletion docs/examples/basic_usage/example_fisher_monte_carlo.py
@@ -78,7 +78,7 @@
 # Fisher and compute their matrix representations by multiplying them onto the
 # identity matrix:
 
-GGN = GGNLinearOperator(model, loss_function, params, data)
+GGN = GGNLinearOperator(model, loss_function, params, data).to_scipy()
 F = FisherMCLinearOperator(model, loss_function, params, data)
 
 D = GGN.shape[0]
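
The pattern in this and the following example files is the same: the GGN operator is now torch-native, and SciPy-facing code appends `.to_scipy()`. A minimal sketch of the post-commit workflow with hypothetical toy data; only the constructor signature and the `.to_scipy()` call are taken from this commit.

import numpy as np
from torch import nn, rand

from curvlinops import GGNLinearOperator

X, y = rand(8, 4), rand(8, 1)  # hypothetical toy data
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
loss_function = nn.MSELoss()
params = [p for p in model.parameters() if p.requires_grad]

# torch-native linear operator, converted for SciPy interoperability
GGN = GGNLinearOperator(model, loss_function, params, [(X, y)]).to_scipy()

D = GGN.shape[0]  # total number of parameters
G = GGN @ np.eye(D)  # dense matrix representation of the GGN
assert G.shape == (D, D)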
2 changes: 1 addition & 1 deletion docs/examples/basic_usage/example_huggingface.py
@@ -158,7 +158,7 @@ def batch_size_fn(x: MutableMapping):
     [(data, data["labels"])],  # We still need to input a list of "(X, y)" pairs!
     check_deterministic=False,
     batch_size_fn=batch_size_fn,  # Remember to specify this!
-)
+).to_scipy()
 
 G = ggn @ np.eye(ggn.shape[0])
 
2 changes: 1 addition & 1 deletion docs/examples/basic_usage/example_inverses.py
@@ -80,7 +80,7 @@
 # First, we set up a linear operator for the damped GGN/Fisher
 
 data = [(X1, y1), (X2, y2)]
-GGN = GGNLinearOperator(model, loss_function, params, data)
+GGN = GGNLinearOperator(model, loss_function, params, data).to_scipy()
 
 delta = 1e-2
 damping = aslinearoperator(delta * sparse.eye(GGN.shape[0]))
2 changes: 1 addition & 1 deletion (file name not preserved in this capture)
@@ -116,7 +116,7 @@
 #
 # Setting up a linear operator for the Fisher/GGN is identical to the Hessian.
 
-GGN = GGNLinearOperator(model, loss_function, params, data)
+GGN = GGNLinearOperator(model, loss_function, params, data).to_scipy()
 
 # %%
 #
2 changes: 1 addition & 1 deletion docs/examples/basic_usage/example_model_merging.py
@@ -136,7 +136,7 @@ def make_dataset() -> TensorDataset:
         loss_function,
         [p for p in model.parameters() if p.requires_grad],
         data_loader,
-    )
+    ).to_scipy()
     for model, loss_function, data_loader in zip(models, loss_functions, data_loaders)
 ]
 
2 changes: 1 addition & 1 deletion docs/examples/basic_usage/example_visual_tour.py
@@ -80,7 +80,7 @@
 Hessian_linop = HessianLinearOperator(
     model, loss_function, params, dataloader
 ).to_scipy()
-GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader)
+GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader).to_scipy()
 EF_linop = EFLinearOperator(model, loss_function, params, dataloader)
 
 # %%
47 changes: 13 additions & 34 deletions test/test_ggn.py
@@ -1,53 +1,32 @@
 """Contains tests for ``curvlinops/ggn``."""
 
 from collections.abc import MutableMapping
+from test.utils import compare_matmat
 
-from numpy import random
 from pytest import raises
 
 from curvlinops import GGNLinearOperator
 from curvlinops.examples.functorch import functorch_ggn
-from curvlinops.examples.utils import report_nonclose
 
 
-def test_GGNLinearOperator_matvec(case, adjoint: bool):
+def test_GGNLinearOperator_matvec(case, adjoint: bool, is_vec: bool):
     """Test matrix-matrix multiplication with the GGN.
 
     Args:
         case: Tuple of model, loss function, parameters, data, and batch size getter.
         adjoint: Whether to test the adjoint operator.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
     """
     model_func, loss_func, params, data, batch_size_fn = case
 
     # Test when X is dict-like but batch_size_fn = None (default)
     if isinstance(data[0][0], MutableMapping):
         with raises(ValueError):
-            op = GGNLinearOperator(model_func, loss_func, params, data)
+            _ = GGNLinearOperator(model_func, loss_func, params, data)
 
-    op = GGNLinearOperator(
+    G = GGNLinearOperator(
         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
     )
-    op_functorch = (
-        functorch_ggn(model_func, loss_func, params, data, input_key="x")
-        .detach()
-        .cpu()
-        .numpy()
-    )
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
-
-    x = random.rand(op.shape[1])
-    report_nonclose(op @ x, op_functorch @ x)
-
-
-def test_GGNLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
-    model_func, loss_func, params, data, batch_size_fn = case
-
-    op = GGNLinearOperator(
-        model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
-    op_functorch = (
-        functorch_ggn(model_func, loss_func, params, data, input_key="x")
-        .detach()
-        .cpu()
-        .numpy()
-    )
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
-
-    X = random.rand(op.shape[1], num_vecs)
-    report_nonclose(op @ X, op_functorch @ X)
+    G_mat = functorch_ggn(model_func, loss_func, params, data, input_key="x")
+
+    compare_matmat(G, G_mat, adjoint, is_vec, atol=1e-7)
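
The rewritten test delegates the matrix-vector/matrix-matrix and adjoint cases to `compare_matmat` from `test/utils.py`, whose implementation is not part of this diff. Purely as a hypothetical sketch, such a helper could reuse the logic of the deleted test bodies:

import numpy as np
from numpy import random

def compare_matmat_sketch(op, mat, adjoint, is_vec, atol=1e-7):
    """Hypothetical: compare a linear operator against a dense matrix.

    Assumes ``op`` is a torch-native operator with ``.to_scipy()`` and
    ``mat`` a dense torch tensor; the real helper may differ.
    """
    dense = mat.detach().cpu().numpy()
    scipy_op = op.to_scipy()
    if adjoint:
        scipy_op, dense = scipy_op.adjoint(), dense.conj().T
    # one column for the matvec case, several for the matmat case
    num_vecs = 1 if is_vec else 3
    X = random.rand(scipy_op.shape[1], num_vecs)
    assert np.allclose(scipy_op @ X, dense @ X, atol=atol)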
2 changes: 1 addition & 1 deletion test/test_hessian.py
@@ -24,9 +24,9 @@ def test_HessianLinearOperator(
     Args:
         case: Tuple of model, loss function, parameters, data, and batch size getter.
         adjoint: Whether to test the adjoint operator.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
         block_sizes_fn: The function that generates the block sizes used to define
             block diagonal approximations from the parameters.
-        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
     """
     model_func, loss_func, params, data, batch_size_fn = case
     block_sizes = block_sizes_fn(params)
10 changes: 5 additions & 5 deletions test/test_inverse.py
@@ -34,7 +34,7 @@ def test_CG_inverse_damped_GGN_matvec(inv_case, delta: float = 2e-2):
 
     GGN = GGNLinearOperator(
         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
+    ).to_scipy()
     damping = aslinearoperator(delta * sparse.eye(GGN.shape[0]))
 
     inv_GGN = CGInverseLinearOperator(GGN + damping)
@@ -56,7 +56,7 @@ def test_CG_inverse_damped_GGN_matmat(inv_case, delta: float = 1e-2, num_vecs: i
 
     GGN = GGNLinearOperator(
         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
+    ).to_scipy()
     damping = aslinearoperator(delta * sparse.eye(GGN.shape[0]))
 
     inv_GGN = CGInverseLinearOperator(GGN + damping)
@@ -78,7 +78,7 @@ def test_LSMR_inverse_damped_GGN_matvec(inv_case, delta: float = 2e-2):
 
     GGN = GGNLinearOperator(
         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
+    ).to_scipy()
     damping = aslinearoperator(delta * sparse.eye(GGN.shape[0]))
 
     inv_GGN = LSMRInverseLinearOperator(GGN + damping)
@@ -104,7 +104,7 @@ def test_LSMR_inverse_damped_GGN_matmat(
 
     GGN = GGNLinearOperator(
         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
+    ).to_scipy()
     damping = aslinearoperator(delta * sparse.eye(GGN.shape[0]))
 
     inv_GGN = LSMRInverseLinearOperator(GGN + damping)
@@ -128,7 +128,7 @@ def test_Neumann_inverse_damped_GGN_matvec(inv_case, delta: float = 1e-2):
 
     GGN = GGNLinearOperator(
         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
+    ).to_scipy()
     damping = aslinearoperator(delta * sparse.eye(GGN.shape[0]))
 
     damped_GGN_functorch = functorch_ggn(
2 changes: 1 addition & 1 deletion test/test_submatrix_on_curvatures.py
@@ -64,7 +64,7 @@ def setup_submatrix_linear_operator(case, operator_case, submatrix_case):
     col_idxs = submatrix_case["col_idx_fn"](dim)
 
     A = operator_case(model_func, loss_func, params, data, batch_size_fn=batch_size_fn)
-    if isinstance(A, HessianLinearOperator):
+    if isinstance(A, (HessianLinearOperator, GGNLinearOperator)):
         A = A.to_scipy()
     A_sub = SubmatrixLinearOperator(A, row_idxs, col_idxs)
 
4 changes: 3 additions & 1 deletion test/utils.py
@@ -112,7 +112,9 @@ def ggn_block_diagonal(
         The block-diagonal GGN.
     """
     # compute the full GGN then zero out the off-diagonal blocks
-    ggn = GGNLinearOperator(model, loss_func, params, data, batch_size_fn=batch_size_fn)
+    ggn = GGNLinearOperator(
+        model, loss_func, params, data, batch_size_fn=batch_size_fn
+    ).to_scipy()
     ggn = from_numpy(ggn @ eye(ggn.shape[1]))
     sizes = [p.numel() for p in params]
     # ggn_blocks[i, j] corresponds to the block of (params[i], params[j])
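
`ggn_block_diagonal` builds the dense GGN and then zeroes its off-diagonal parameter blocks. A hedged sketch of that zeroing step under the per-parameter block sizes computed above (the function name and standalone framing are illustrative, not from the repository):

import numpy as np

def zero_off_diagonal_blocks(mat: np.ndarray, sizes: list) -> np.ndarray:
    """Keep only the diagonal blocks of ``mat`` induced by ``sizes``."""
    boundaries = np.cumsum([0] + list(sizes))
    out = np.zeros_like(mat)
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        out[start:end, start:end] = mat[start:end, start:end]
    return out

# e.g. two parameters with 2 and 1 entries -> a 2x2 and a 1x1 diagonal block
full = np.arange(9.0).reshape(3, 3)
block_diag = zero_off_diagonal_blocks(full, [2, 1])
assert block_diag[0, 2] == 0.0 and block_diag[2, 2] == full[2, 2]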
