From 07144a38a00bcc519dcda74a54300a8b14627d4e Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Sun, 22 Sep 2024 12:12:36 -0400 Subject: [PATCH] [REF] Combine full and block-diagonal matrix multiply tests --- test/cases.py | 13 +++++++++++ test/conftest.py | 35 ++++++++++++++++++++++++++-- test/test_hessian.py | 55 ++++++++++++++------------------------------ 3 files changed, 63 insertions(+), 40 deletions(-) diff --git a/test/cases.py b/test/cases.py index 5f6708b..fca664f 100644 --- a/test/cases.py +++ b/test/cases.py @@ -371,3 +371,16 @@ def data(): NON_DETERMINISTIC_CASES.append(case_with_device) ADJOINT_CASES = [False, True] +ADJOINT_IDS = ["", "adjoint"] + +IS_VECS = [False, True] +IS_VEC_IDS = ["matvec", "matmat"] + + +BLOCK_SIZES_FNS = { + "full": lambda params: None, + "per-parameter-blocks": lambda params: [1 for _ in range(len(params))], + "two-blocks": lambda params: ( + [1] if len(params) == 1 else [len(params) // 2, len(params) - len(params) // 2] + ), +} diff --git a/test/conftest.py b/test/conftest.py index f14b9b6..2512213 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,9 +4,13 @@ from collections.abc import MutableMapping from test.cases import ( ADJOINT_CASES, + ADJOINT_IDS, + BLOCK_SIZES_FNS, CASES, CNN_CASES, INV_CASES, + IS_VEC_IDS, + IS_VECS, NON_DETERMINISTIC_CASES, ) from test.kfac_cases import ( @@ -21,7 +25,7 @@ from numpy import random from pytest import fixture from torch import Tensor, manual_seed -from torch.nn import Module, MSELoss +from torch.nn import Module, MSELoss, Parameter def initialize_case( @@ -110,11 +114,38 @@ def non_deterministic_case( yield initialize_case(case) -@fixture(params=ADJOINT_CASES) +@fixture(params=ADJOINT_CASES, ids=ADJOINT_IDS) def adjoint(request) -> bool: return request.param +@fixture(params=IS_VECS, ids=IS_VEC_IDS) +def is_vec(request) -> bool: + """Whether to test matrix-vector or matrix-matrix multiplication. + + Args: + request: Pytest request object. + + Returns: + ``True`` if the test is for matrix-vector multiplication, ``False`` otherwise. + """ + return request.param + + +@fixture(params=BLOCK_SIZES_FNS.values(), ids=BLOCK_SIZES_FNS.keys()) +def block_sizes_fn(request) -> Callable[[List[Parameter]], Optional[List[int]]]: + """Function to generate the ``block_sizes`` argument for a linear operator. + + Args: + request: Pytest request object. + + Returns: + A function that generates the block sizes for a linear operator from the + parameters. + """ + return request.param + + @fixture(params=KFAC_EXACT_CASES) def kfac_exact_case( request, diff --git a/test/test_hessian.py b/test/test_hessian.py index 5b6aecd..0727a4d 100644 --- a/test/test_hessian.py +++ b/test/test_hessian.py @@ -2,63 +2,40 @@ from collections.abc import MutableMapping from test.utils import compare_matmat +from typing import Callable, List, Optional -from pytest import mark, raises +from pytest import raises from torch import block_diag +from torch.nn import Parameter from curvlinops import HessianLinearOperator from curvlinops.examples.functorch import functorch_hessian from curvlinops.utils import split_list -@mark.parametrize("is_vec", [True, False], ids=["matvec", "matmat"]) -def test_HessianLinearOperator(case, adjoint: bool, is_vec: bool): +def test_HessianLinearOperator( + case, + adjoint: bool, + is_vec: bool, + block_sizes_fn: Callable[[List[Parameter]], Optional[List[int]]], +): """Test matrix-matrix multiplication with the Hessian. Args: case: Tuple of model, loss function, parameters, data, and batch size getter. 
         adjoint: Whether to test the adjoint operator.
         is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
+        block_sizes_fn: Function that maps the parameters to the ``block_sizes``
+            argument of the linear operator (``None`` requests the full Hessian).
     """
     model_func, loss_func, params, data, batch_size_fn = case
+    block_sizes = block_sizes_fn(params)
 
     # Test when X is dict-like but batch_size_fn = None (default)
     if isinstance(data[0][0], MutableMapping):
         with raises(ValueError):
             _ = HessianLinearOperator(model_func, loss_func, params, data)
 
-    H = HessianLinearOperator(
-        model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
-    H_mat = functorch_hessian(model_func, loss_func, params, data, input_key="x")
-
-    compare_matmat(H, H_mat, adjoint, is_vec, rtol=1e-4, atol=1e-6)
-
-
-BLOCKING_FNS = {
-    "per-parameter": lambda params: [1 for _ in range(len(params))],
-    "two-blocks": lambda params: (
-        [1] if len(params) == 1 else [len(params) // 2, len(params) - len(params) // 2]
-    ),
-}
-
-
-@mark.parametrize("blocking", BLOCKING_FNS.keys(), ids=BLOCKING_FNS.keys())
-@mark.parametrize("is_vec", [True, False], ids=["matvec", "matmat"])
-def test_blocked_HessianLinearOperator(
-    case, adjoint: bool, blocking: str, is_vec: bool
-):
-    """Test matrix-matrix multiplication with the block-diagonal Hessian.
-
-    Args:
-        case: Tuple of model, loss function, parameters, data and batch size getter.
-        adjoint: Whether to test the adjoint operator.
-        blocking: Blocking scheme.
-        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
-    """
-    model_func, loss_func, params, data, batch_size_fn = case
-    block_sizes = BLOCKING_FNS[blocking](params)
-
     H = HessianLinearOperator(
         model_func,
         loss_func,
@@ -69,10 +46,12 @@
     )
 
     # compute the blocks with functorch and build the block diagonal matrix
-    H_mat = [
+    H_blocks = [
         functorch_hessian(model_func, loss_func, params_block, data, input_key="x")
-        for params_block in split_list(params, block_sizes)
+        for params_block in split_list(
+            params, [len(params)] if block_sizes is None else block_sizes
+        )
     ]
-    H_mat = block_diag(*H_mat)
+    H_mat = block_diag(*H_blocks)
 
     compare_matmat(H, H_mat, adjoint, is_vec, rtol=1e-4, atol=1e-6)
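
For readers of the patch, the effect of the new ``BLOCK_SIZES_FNS`` schemes and of the ``[len(params)] if block_sizes is None else block_sizes`` fallback in the combined test can be reproduced in isolation. The sketch below is illustrative only and not part of the patch: the four-parameter toy setup and the ``split_sizes`` helper (a stand-in for ``curvlinops.utils.split_list``, whose behavior is assumed here) are invented for the example, and identity matrices replace the functorch Hessian blocks; only ``BLOCK_SIZES_FNS`` and ``torch.block_diag`` are taken from the patch.

from typing import List, Optional

from torch import block_diag, eye, rand
from torch.nn import Parameter

# Copied from the patch (test/cases.py): ``None`` requests the full matrix, the
# other two schemes request block-diagonal approximations.
BLOCK_SIZES_FNS = {
    "full": lambda params: None,
    "per-parameter-blocks": lambda params: [1 for _ in range(len(params))],
    "two-blocks": lambda params: (
        [1] if len(params) == 1 else [len(params) // 2, len(params) - len(params) // 2]
    ),
}


def split_sizes(items: List, sizes: List[int]) -> List[List]:
    """Toy stand-in for ``curvlinops.utils.split_list``: chunk ``items`` by ``sizes``."""
    out, start = [], 0
    for size in sizes:
        out.append(items[start : start + size])
        start += size
    return out


# Hypothetical model with four parameters holding 2 + 3 + 1 + 4 = 10 entries in total.
params = [Parameter(rand(n)) for n in (2, 3, 1, 4)]

for name, fn in BLOCK_SIZES_FNS.items():
    block_sizes: Optional[List[int]] = fn(params)
    # Same fallback as in the combined test: ``None`` becomes one block of all parameters.
    effective = [len(params)] if block_sizes is None else block_sizes
    # Identity matrices play the role of the per-block Hessians from functorch.
    blocks = [
        eye(sum(p.numel() for p in group))
        for group in split_sizes(params, effective)
    ]
    reference = block_diag(*blocks)
    # full: one 10x10 block; per-parameter-blocks: 2x2, 3x3, 1x1, 4x4; two-blocks: 5x5, 5x5.
    print(f"{name}: block_sizes={block_sizes}, reference shape={tuple(reference.shape)}")

In the combined test, ``functorch_hessian`` computes one such block per parameter group, and ``compare_matmat`` checks the resulting ``block_diag`` matrix against the ``HessianLinearOperator`` built with the same ``block_sizes``.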