
Commit

[REF] Extract testing matrix multiplies in expectation
f-dangel committed Sep 23, 2024
1 parent eba5976 commit e678ea2
Showing 2 changed files with 76 additions and 92 deletions.
108 changes: 16 additions & 92 deletions test/test_fisher.py
@@ -1,14 +1,11 @@
"""Contains tests for ``curvlinops/fisher.py``."""

from collections.abc import MutableMapping
from contextlib import redirect_stdout, suppress
from test.utils import compare_matmat_expectation

from numpy import random, zeros_like
from pytest import mark, raises
from pytest import mark

from curvlinops import FisherMCLinearOperator
from curvlinops.examples.functorch import functorch_ggn
from curvlinops.examples.utils import report_nonclose

MAX_REPEATS_MC_SAMPLES = [(10_000, 1), (100, 100)]
MAX_REPEATS_MC_SAMPLES_IDS = [
@@ -21,102 +18,29 @@
 @mark.parametrize(
     "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS
 )
-def test_LinearOperator_matvec_expectation(
-    case, adjoint: bool, max_repeats: int, mc_samples: int
+def test_FisherMCLinearOperator_expectation(
+    case, adjoint: bool, is_vec: bool, max_repeats: int, mc_samples: int
 ):
-    model_func, loss_func, params, data, batch_size_fn = case
-
-    # Test when X is dict-like but batch_size_fn = None (default)
-    if isinstance(data[0][0], MutableMapping):
-        with raises(ValueError):
-            F = FisherMCLinearOperator(
-                model_func,
-                loss_func,
-                params,
-                data,
-                mc_samples=mc_samples,
-            )
-
-    F_torch = FisherMCLinearOperator(
-        model_func,
-        loss_func,
-        params,
-        data,
-        batch_size_fn=batch_size_fn,
-        mc_samples=mc_samples,
-    )
-    F = F_torch.to_scipy()
-    G_functorch = (
-        functorch_ggn(model_func, loss_func, params, data, input_key="x")
-        .detach()
-        .cpu()
-        .numpy()
-    )
-    if adjoint:
-        F, G_functorch = F.adjoint(), G_functorch.conj().T
-
-    x = random.rand(F.shape[1]).astype(F.dtype)
-    Gx = G_functorch @ x
-
-    Fx = zeros_like(x)
-    atol = 5e-3 * max(abs(Gx))
-    rtol = 1e-1
-
-    for m in range(max_repeats):
-        Fx += F @ x
-        F_torch._seed += 1
-
-        total_samples = (m + 1) * mc_samples
-        if total_samples % CHECK_EVERY == 0:
-            with redirect_stdout(None), suppress(ValueError):
-                report_nonclose(Fx / (m + 1), Gx, rtol=rtol, atol=atol)
-                return
-
-    report_nonclose(Fx / max_repeats, Gx, rtol=rtol, atol=atol)
-
-
-@mark.montecarlo
-@mark.parametrize(
-    "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS
-)
-def test_LinearOperator_matmat_expectation(
-    case, adjoint: bool, max_repeats: int, mc_samples: int, num_vecs: int = 2
-):
+    """Test matrix-matrix multiplication with the Monte-Carlo Fisher.
+
+    Args:
+        case: Tuple of model, loss function, parameters, data, and batch size getter.
+        adjoint: Whether to test the adjoint operator.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
+    """
     model_func, loss_func, params, data, batch_size_fn = case
 
-    F_torch = FisherMCLinearOperator(
+    F = FisherMCLinearOperator(
         model_func,
         loss_func,
         params,
         data,
         batch_size_fn=batch_size_fn,
         mc_samples=mc_samples,
     )
-    F = F_torch.to_scipy()
-    G_functorch = (
-        functorch_ggn(model_func, loss_func, params, data, input_key="x")
-        .detach()
-        .cpu()
-        .numpy()
-    )
-    if adjoint:
-        F, G_functorch = F.adjoint(), G_functorch.conj().T
-
-    X = random.rand(F.shape[1], num_vecs).astype(F.dtype)
-    GX = G_functorch @ X
-
-    FX = zeros_like(X)
-    atol = 5e-3 * max(abs(GX.flatten()))
-    rtol = 1.5e-1
-
-    for m in range(max_repeats):
-        FX += F @ X
-        F_torch._seed += 1
-
-        total_samples = (m + 1) * mc_samples
-        if total_samples % CHECK_EVERY == 0:
-            with redirect_stdout(None), suppress(ValueError):
-                report_nonclose(FX / (m + 1), GX, rtol=rtol, atol=atol)
-                return
-
-    report_nonclose(FX / max_repeats, GX, rtol=rtol, atol=atol)
+    G_mat = functorch_ggn(model_func, loss_func, params, data, input_key="x")
+
+    rtol = 1e-1 if is_vec else 1.5e-1
+    compare_matmat_expectation(
+        F, G_mat, adjoint, is_vec, max_repeats, CHECK_EVERY, rtol=rtol, atol=5e-3
+    )
60 changes: 60 additions & 0 deletions test/utils.py
@@ -1,6 +1,7 @@
"""Utility functions to test ``curvlinops``."""

from collections.abc import MutableMapping
from contextlib import redirect_stdout, suppress
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple, Union

@@ -18,6 +19,7 @@
     from_numpy,
     rand,
     randint,
+    zeros_like,
 )
 from torch.nn import (
     AdaptiveAvgPool2d,
@@ -35,6 +37,7 @@
 
 from curvlinops import GGNLinearOperator
 from curvlinops._torch_base import PyTorchLinearOperator
+from curvlinops.fisher import FisherMCLinearOperator
 from curvlinops.utils import allclose_report
 
 
@@ -500,3 +503,60 @@ def compare_matmat(
     assert len(op_x) == len(mat_x)
     for o_x, m_x in zip(op_x, mat_x):
         assert allclose_report(o_x, m_x, **tol)
+
+
+def compare_matmat_expectation(
+    op: FisherMCLinearOperator,
+    mat: Tensor,
+    adjoint: bool,
+    is_vec: bool,
+    max_repeats: int,
+    check_every: int,
+    num_vecs: int = 2,
+    rtol: float = 1e-5,
+    atol: float = 1e-8,
+):
+    """Test the matrix-vector product of a PyTorch linear operator in expectation.
+
+    Args:
+        op: The operator to test.
+        mat: The matrix representation of the linear operator.
+        adjoint: Whether to test the adjoint operator.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
+        max_repeats: Maximum number of matrix-vector product within which the
+            expectation must converge.
+        check_every: Check the expectation every ``check_every`` iterations for
+            convergence.
+        num_vecs: Number of vectors to test (ignored if ``is_vec`` is ``True``).
+            Default: ``2``.
+        rtol: Relative tolerance for the comparison. Default: ``1e-5``.
+        atol: Absolute tolerance for the comparison. Will be multiplied by the maximum
+            absolute value of the ground truth. Default: ``1e-8``.
+    """
+    if adjoint:
+        op, mat = op.adjoint(), mat.conj().T
+
+    num_vecs = 1 if is_vec else num_vecs
+    dt = op._infer_dtype()
+    dev = op._infer_device()
+    _, x, _ = rand_accepted_formats(
+        [tuple(s) for s in op._in_shape], is_vec, dt, dev, num_vecs=num_vecs
+    )
+
+    op_x = zeros_like(x)
+    mat_x = mat @ x
+
+    atol *= mat_x.flatten().abs().max().item()
+    tol = {"atol": atol, "rtol": rtol}
+
+    for m in range(max_repeats):
+        op_x += op @ x
+        op._seed += 1
+
+        total_samples = (m + 1) * op._mc_samples
+        if total_samples % check_every == 0:
+            with redirect_stdout(None), suppress(ValueError), suppress(AssertionError):
+                assert allclose_report(op_x / (m + 1), mat_x, **tol)
+                return
+
+    assert allclose_report(op_x / max_repeats, mat_x, **tol)
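As context for the refactor, here is a minimal, self-contained sketch of the convergence check that compare_matmat_expectation performs, written against a toy noisy matrix-vector product instead of FisherMCLinearOperator; all names, sizes, and tolerances below are illustrative assumptions, not curvlinops API.

from numpy import allclose, zeros_like
from numpy.random import default_rng

rng = default_rng(0)
A = rng.standard_normal((4, 4))
A = A @ A.T  # exact ground-truth matrix, playing the role of the GGN


def noisy_matvec(v, mc_samples=10):
    """Return an unbiased but noisy Monte-Carlo estimate of A @ v."""
    samples = [A @ v + rng.standard_normal(v.shape) for _ in range(mc_samples)]
    return sum(samples) / mc_samples


x = rng.standard_normal(4)
exact = A @ x
accumulated = zeros_like(x)
max_repeats, check_every = 1_000, 100

for m in range(max_repeats):
    accumulated += noisy_matvec(x)
    # every ``check_every`` repeats, check whether the running mean has converged
    if (m + 1) % check_every == 0 and allclose(accumulated / (m + 1), exact, atol=1e-1):
        print(f"converged after {m + 1} repeats")
        break

The real helper additionally reseeds the operator between repeats (op._seed += 1) so that successive products draw fresh Monte-Carlo samples, and it scales atol by the largest absolute entry of the exact product before comparing.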
