[REF] Combine full and block-diagonal matrix multiply tests

f-dangel committed Sep 22, 2024
1 parent 59c6369 commit 07144a3
Showing 3 changed files with 63 additions and 40 deletions.
13 changes: 13 additions & 0 deletions test/cases.py
@@ -371,3 +371,16 @@ def data():
NON_DETERMINISTIC_CASES.append(case_with_device)

ADJOINT_CASES = [False, True]
ADJOINT_IDS = ["", "adjoint"]

IS_VECS = [False, True]
IS_VEC_IDS = ["matmat", "matvec"]


BLOCK_SIZES_FNS = {
"full": lambda params: None,
"per-parameter-blocks": lambda params: [1 for _ in range(len(params))],
"two-blocks": lambda params: (
[1] if len(params) == 1 else [len(params) // 2, len(params) - len(params) // 2]
),
}
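
For illustration only (not part of the commit), a minimal sketch of what each entry in BLOCK_SIZES_FNS yields; the four-element parameter list is made up, and only its length matters since the lambdas only call len(params). The snippet assumes it is run from the repository root so that test.cases is importable:

from test.cases import BLOCK_SIZES_FNS
from torch import rand
from torch.nn import Parameter

# hypothetical list of four parameters, used only to illustrate the mapping
params = [Parameter(rand(3)) for _ in range(4)]

# "full": no blocking, the linear operator is built without block_sizes
assert BLOCK_SIZES_FNS["full"](params) is None

# "per-parameter-blocks": one block per parameter
assert BLOCK_SIZES_FNS["per-parameter-blocks"](params) == [1, 1, 1, 1]

# "two-blocks": split the parameter list into two roughly equal halves
assert BLOCK_SIZES_FNS["two-blocks"](params) == [2, 2]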
35 changes: 33 additions & 2 deletions test/conftest.py
@@ -4,9 +4,13 @@
from collections.abc import MutableMapping
from test.cases import (
ADJOINT_CASES,
ADJOINT_IDS,
BLOCK_SIZES_FNS,
CASES,
CNN_CASES,
INV_CASES,
IS_VEC_IDS,
IS_VECS,
NON_DETERMINISTIC_CASES,
)
from test.kfac_cases import (
@@ -21,7 +25,7 @@
from numpy import random
from pytest import fixture
from torch import Tensor, manual_seed
from torch.nn import Module, MSELoss
from torch.nn import Module, MSELoss, Parameter


def initialize_case(
@@ -110,11 +114,38 @@ def non_deterministic_case(
yield initialize_case(case)


@fixture(params=ADJOINT_CASES)
@fixture(params=ADJOINT_CASES, ids=ADJOINT_IDS)
def adjoint(request) -> bool:
return request.param


@fixture(params=IS_VECS, ids=IS_VEC_IDS)
def is_vec(request) -> bool:
"""Whether to test matrix-vector or matrix-matrix multiplication.
Args:
request: Pytest request object.
Returns:
``True`` if the test is for matrix-vector multiplication, ``False`` otherwise.
"""
return request.param


@fixture(params=BLOCK_SIZES_FNS.values(), ids=BLOCK_SIZES_FNS.keys())
def block_sizes_fn(request) -> Callable[[List[Parameter]], Optional[List[int]]]:
"""Function to generate the ``block_sizes`` argument for a linear operator.
Args:
request: Pytest request object.
Returns:
A function that generates the block sizes for a linear operator from the
parameters.
"""
return request.param


@fixture(params=KFAC_EXACT_CASES)
def kfac_exact_case(
request,
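As a hypothetical usage sketch (not part of the commit), a test placed under test/ can now request the shared fixtures instead of repeating @mark.parametrize; pytest runs it once per combination of case, adjoint, is_vec, and block-sizes scheme, labelled with the ids defined in test/cases.py:

# hypothetical test relying on the fixtures from the new conftest.py
def test_sketch(case, adjoint, is_vec, block_sizes_fn):
    # unpack the case tuple, mirroring the existing tests
    model_func, loss_func, params, data, batch_size_fn = case
    # None for the "full" scheme, otherwise one entry per parameter block
    block_sizes = block_sizes_fn(params)
    assert block_sizes is None or sum(block_sizes) == len(params)
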
55 changes: 17 additions & 38 deletions test/test_hessian.py
@@ -2,63 +2,40 @@

from collections.abc import MutableMapping
from test.utils import compare_matmat
from typing import Callable, List, Optional

from pytest import mark, raises
from pytest import raises
from torch import block_diag
from torch.nn import Parameter

from curvlinops import HessianLinearOperator
from curvlinops.examples.functorch import functorch_hessian
from curvlinops.utils import split_list


@mark.parametrize("is_vec", [True, False], ids=["matvec", "matmat"])
def test_HessianLinearOperator(case, adjoint: bool, is_vec: bool):
def test_HessianLinearOperator(
case,
adjoint: bool,
is_vec: bool,
block_sizes_fn: Callable[[List[Parameter]], Optional[List[int]]],
):
"""Test matrix-matrix multiplication with the Hessian.
Args:
case: Tuple of model, loss function, parameters, data, and batch size getter.
adjoint: Whether to test the adjoint operator.
block_sizes_fn: The function that generates the block sizes used to define
block diagonal approximations from the parameters.
is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
"""
model_func, loss_func, params, data, batch_size_fn = case
block_sizes = block_sizes_fn(params)

# Test when X is dict-like but batch_size_fn = None (default)
if isinstance(data[0][0], MutableMapping):
with raises(ValueError):
_ = HessianLinearOperator(model_func, loss_func, params, data)

H = HessianLinearOperator(
model_func, loss_func, params, data, batch_size_fn=batch_size_fn
)
H_mat = functorch_hessian(model_func, loss_func, params, data, input_key="x")

compare_matmat(H, H_mat, adjoint, is_vec, rtol=1e-4, atol=1e-6)


BLOCKING_FNS = {
"per-parameter": lambda params: [1 for _ in range(len(params))],
"two-blocks": lambda params: (
[1] if len(params) == 1 else [len(params) // 2, len(params) - len(params) // 2]
),
}


@mark.parametrize("blocking", BLOCKING_FNS.keys(), ids=BLOCKING_FNS.keys())
@mark.parametrize("is_vec", [True, False], ids=["matvec", "matmat"])
def test_blocked_HessianLinearOperator(
case, adjoint: bool, blocking: str, is_vec: bool
):
"""Test matrix-matrix multiplication with the block-diagonal Hessian.
Args:
case: Tuple of model, loss function, parameters, data and batch size getter.
adjoint: Whether to test the adjoint operator.
blocking: Blocking scheme.
is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
"""
model_func, loss_func, params, data, batch_size_fn = case
block_sizes = BLOCKING_FNS[blocking](params)

H = HessianLinearOperator(
model_func,
loss_func,
@@ -69,10 +46,12 @@ def test_blocked_HessianLinearOperator(
)

# compute the blocks with functorch and build the block diagonal matrix
H_mat = [
H_blocks = [
functorch_hessian(model_func, loss_func, params_block, data, input_key="x")
for params_block in split_list(params, block_sizes)
for params_block in split_list(
params, [len(params)] if block_sizes is None else block_sizes
)
]
H_mat = block_diag(*H_mat)
H_mat = block_diag(*H_blocks)

compare_matmat(H, H_mat, adjoint, is_vec, rtol=1e-4, atol=1e-6)
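
To illustrate how the merged test builds its block-diagonal reference, a small stand-alone sketch with toy matrices in place of the functorch Hessian blocks; it assumes curvlinops.utils.split_list partitions a list into consecutive chunks of the given sizes:

from torch import block_diag, ones

from curvlinops.utils import split_list

params = ["w1", "b1", "v1"]  # placeholder "parameters"; only their count matters

# block_sizes is None in the "full" case, so all parameters form a single block
block_sizes = None
chunks = split_list(params, [len(params)] if block_sizes is None else block_sizes)
assert chunks == [["w1", "b1", "v1"]]

# an explicit scheme, e.g. "two-blocks" for three parameters, yields two chunks
assert split_list(params, [1, 2]) == [["w1"], ["b1", "v1"]]

# block_diag stitches the per-block matrices into the block-diagonal reference
H_blocks = [ones(2, 2), ones(3, 3)]  # toy stand-ins for per-block Hessians
H_mat = block_diag(*H_blocks)
assert H_mat.shape == (5, 5)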
