diff --git a/curvlinops/hessian.py b/curvlinops/hessian.py
index 4f45387..e64abdd 100644
--- a/curvlinops/hessian.py
+++ b/curvlinops/hessian.py
@@ -1,20 +1,20 @@
-"""Contains LinearOperator implementation of the Hessian."""
+"""Contains a linear operator implementation of the Hessian."""
 
 from __future__ import annotations
 
 from collections.abc import MutableMapping
-from typing import List, Tuple, Union
+from typing import List, Union
 
 from backpack.hessianfree.hvp import hessian_vector_product
 from torch import Tensor, zeros_like
 from torch.autograd import grad
 
-from curvlinops._base import _LinearOperator
+from curvlinops._torch_base import CurvatureLinearOperator
 from curvlinops.utils import split_list
 
 
-class HessianLinearOperator(_LinearOperator):
-    r"""Hessian as SciPy linear operator.
+class HessianLinearOperator(CurvatureLinearOperator):
+    r"""Linear operator for the Hessian of an empirical risk.
 
     Consider the empirical risk
 
@@ -42,45 +42,45 @@ class HessianLinearOperator(_LinearOperator):
 
     SUPPORTS_BLOCKS: bool = True
 
     def _matmat_batch(
-        self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
-    ) -> Tuple[Tensor, ...]:
+        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
+    ) -> List[Tensor]:
         """Apply the mini-batch Hessian to a matrix.
 
         Args:
             X: Input to the DNN.
             y: Ground truth.
-            M_list: Matrix to be multiplied with in list format.
+            M: Matrix to be multiplied with in tensor list format.
                 Tensors have same shape as trainable model parameters, and an
-                additional leading axis for the matrix columns.
+                additional trailing axis for the matrix columns.
 
         Returns:
             Result of Hessian multiplication in list format. Has the same shape as
-            ``M_list``, i.e. each tensor in the list has the shape of a parameter and a
-            leading dimension of matrix columns.
+            ``M``, i.e. each tensor in the list has the shape of a parameter and a
+            trailing dimension of matrix columns.
         """
         loss = self._loss_func(self._model_func(X), y)
 
         # Re-cycle first backward pass from the HVP's double-backward
         grad_params = grad(loss, self._params, create_graph=True)
 
-        num_vecs = M_list[0].shape[0]
-        result = [zeros_like(M) for M in M_list]
+        (num_vecs,) = {m.shape[-1] for m in M}
+        AM = [zeros_like(m) for m in M]
 
         # per-block HMP
-        for M_block, p_block, g_block, res_block in zip(
-            split_list(M_list, self._block_sizes),
+        for M_block, p_block, g_block, AM_block in zip(
+            split_list(M, self._block_sizes),
             split_list(self._params, self._block_sizes),
             split_list(grad_params, self._block_sizes),
-            split_list(result, self._block_sizes),
+            split_list(AM, self._block_sizes),
         ):
             for n in range(num_vecs):
                 col_n = hessian_vector_product(
-                    loss, p_block, [M[n] for M in M_block], grad_params=g_block
+                    loss, p_block, [M[..., n] for M in M_block], grad_params=g_block
                 )
                 for p, col in enumerate(col_n):
-                    res_block[p][n].add_(col)
+                    AM_block[p][..., n].add_(col)
 
-        return tuple(result)
+        return AM
 
     def _adjoint(self) -> HessianLinearOperator:
         """Return the linear operator representing the adjoint.
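
The hunk above flips the column axis of the tensor list format from leading to trailing. A minimal sketch of the new convention (the parameter shapes are hypothetical, not taken from the library):

from torch import rand

# hypothetical parameters: a (3, 2) weight and a (3,) bias; 5 matrix columns
shapes, num_vecs = [(3, 2), (3,)], 5

# old convention: leading column axis, i.e. shapes (5, 3, 2) and (5, 3)
# new convention: trailing column axis, i.e. shapes (3, 2, 5) and (3, 5)
M = [rand(*shape, num_vecs) for shape in shapes]
assert [tuple(m.shape) for m in M] == [(3, 2, 5), (3, 5)]

Each entry of the returned Hessian-matrix product keeps the shape of its parameter, plus the trailing column axis.
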
diff --git a/curvlinops/utils.py b/curvlinops/utils.py
index ec414cc..698f236 100644
--- a/curvlinops/utils.py
+++ b/curvlinops/utils.py
@@ -1,16 +1,16 @@
 """General utility functions."""
 
-from typing import List
+from typing import List, Tuple, Union
 
 from numpy import cumsum
 from torch import Tensor
 
 
-def split_list(x: List, sizes: List[int]) -> List[List]:
+def split_list(x: Union[List, Tuple], sizes: List[int]) -> List[List]:
     """Split a list into multiple lists of specified size.
 
     Args:
-        x: List to be split.
+        x: List or tuple to be split.
         sizes: Sizes of the resulting lists.
 
     Returns:
@@ -25,7 +25,7 @@
             + f" of {sum(sizes)} entries."
         )
     boundaries = cumsum([0] + sizes)
-    return [x[boundaries[i] : boundaries[i + 1]] for i in range(len(sizes))]
+    return [list(x[boundaries[i] : boundaries[i + 1]]) for i in range(len(sizes))]
 
 
 def allclose_report(
@@ -50,6 +50,12 @@ def allclose_report(
         tensor1[nonclose_idx].flatten(),
         tensor2[nonclose_idx].flatten(),
     ):
-        print(f"at index {idx}: {t1:.5e} ≠ {t2:.5e}, ratio: {t1 / t2:.5e}")
+        print(f"at index {idx.tolist()}: {t1:.5e} ≠ {t2:.5e}, ratio: {t1 / t2:.5e}")
+
+    # print largest and smallest absolute entries
+    amax1, amax2 = tensor1.abs().max().item(), tensor2.abs().max().item()
+    print(f"Abs max: {amax1:.5e} vs. {amax2:.5e}.")
+    amin1, amin2 = tensor1.abs().min().item(), tensor2.abs().min().item()
+    print(f"Abs min: {amin1:.5e} vs. {amin2:.5e}.")
 
     return close
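
With the widened signature, split_list accepts tuples and now returns a genuine list for every chunk; for example:

from curvlinops.utils import split_list

assert split_list([0, 1, 2], [2, 1]) == [[0, 1], [2]]
# tuples are now valid inputs, and each chunk is converted to a list
assert split_list((0, 1, 2), [2, 1]) == [[0, 1], [2]]

Without the added list(...) conversion, slicing a tuple would yield tuples and the second assertion would fail.
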
diff --git a/docs/examples/basic_usage/example_eigenvalues.py b/docs/examples/basic_usage/example_eigenvalues.py
index aa8cf60..5fdbccd 100644
--- a/docs/examples/basic_usage/example_eigenvalues.py
+++ b/docs/examples/basic_usage/example_eigenvalues.py
@@ -63,7 +63,7 @@
 # We are ready to setup the linear operator. In this example, we will use the Hessian.
 
 data = [(X1, y1), (X2, y2)]
-H = HessianLinearOperator(model, loss_function, params, data)
+H = HessianLinearOperator(model, loss_function, params, data).to_scipy()
 
 # %%
 #
@@ -204,7 +204,9 @@ def orthonormalize(v: numpy.ndarray, basis: List[numpy.ndarray]) -> numpy.ndarray:
 # the linear operator's progress bar, which allows us to count the number of
 # matrix-vector products invoked by both eigen-solvers:
 
-H = HessianLinearOperator(model, loss_function, params, data, progressbar=True)
+H = HessianLinearOperator(
+    model, loss_function, params, data, progressbar=True
+).to_scipy()
 
 # determine number of matrix-vector products used by `eigsh`
 with StringIO() as buf, redirect_stderr(buf):
@@ -216,7 +218,7 @@
 
 # determine number of matrix-vector products used by power iteration
 with StringIO() as buf, redirect_stderr(buf):
-    top_k_evals_power, _ = power_method(H, k=k)
+    top_k_evals_power, _ = power_method(H, k=k, tol=1e-4)
 
 # The tqdm progressbar will print "matmat" for each batch in a matrix-vector
 # product. Therefore, we need to divide by the number of batches
 queries_power = buf.getvalue().count("matmat") // len(data)
@@ -228,9 +230,7 @@
 #
 # Sadly, the power iteration also does not offer computational benefits, consuming
 # more matrix-vector products than :code:`eigsh`. While it is elegant and simple,
-# it cannot compete with :code:`eigsh`, at least in the comparison provided here
-# (note that we used a relative small tolerance for the power iteration, and it will
-# likely deteriorate further if we decrease the tolerance).
+# it cannot compete with :code:`eigsh`, at least in the comparison provided here.
 #
 # Therefore, we recommend using :code:`eigsh` for computing eigenvalues. This method
 # becomes accessible because :code:`curvlinops` interfaces with SciPy's linear
diff --git a/docs/examples/basic_usage/example_matrix_vector_products.py b/docs/examples/basic_usage/example_matrix_vector_products.py
index 5dba016..1b3a96e 100644
--- a/docs/examples/basic_usage/example_matrix_vector_products.py
+++ b/docs/examples/basic_usage/example_matrix_vector_products.py
@@ -55,7 +55,7 @@
 # Setting up a linear operator for the Hessian is straightforward.
 
 data = [(X, y)]
-H = HessianLinearOperator(model, loss_function, params, data)
+H = HessianLinearOperator(model, loss_function, params, data).to_scipy()
 
 # %%
 #
diff --git a/docs/examples/basic_usage/example_submatrices.py b/docs/examples/basic_usage/example_submatrices.py
index da01a30..d7380be 100644
--- a/docs/examples/basic_usage/example_submatrices.py
+++ b/docs/examples/basic_usage/example_submatrices.py
@@ -73,7 +73,7 @@
 # its matrix representation through multiplication with the identity matrix,
 # followed by comparison to the Hessian matrix computed via :mod:`functorch`.
 
-H = HessianLinearOperator(model, loss_function, params, data)
+H = HessianLinearOperator(model, loss_function, params, data).to_scipy()
 
 report_nonclose(H_functorch, H @ numpy.eye(H.shape[1]))
 
@@ -122,7 +122,7 @@ def extract_block(
 # We can build a linear operator for this sub-Hessian by only providing the
 # first layer's weight as parameter:
 
-H_param0 = HessianLinearOperator(model, loss_function, [params[i]], data)
+H_param0 = HessianLinearOperator(model, loss_function, [params[i]], data).to_scipy()
 
 # %%
 #
diff --git a/docs/examples/basic_usage/example_visual_tour.py b/docs/examples/basic_usage/example_visual_tour.py
index 2fe8576..af1ec76 100644
--- a/docs/examples/basic_usage/example_visual_tour.py
+++ b/docs/examples/basic_usage/example_visual_tour.py
@@ -85,7 +85,9 @@
 #
 # First, create the linear operators:
 
-Hessian_linop = HessianLinearOperator(model, loss_function, params, dataloader)
+Hessian_linop = HessianLinearOperator(
+    model, loss_function, params, dataloader
+).to_scipy()
 GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader)
 EF_linop = EFLinearOperator(model, loss_function, params, dataloader)
 Hessian_blocked_linop = HessianLinearOperator(
@@ -94,7 +96,7 @@
     params,
     dataloader,
     block_sizes=[s for s in num_tensors_layer if s != 0],
-)
+).to_scipy()
 F_linop = FisherMCLinearOperator(model, loss_function, params, dataloader)
 KFAC_linop = KFACLinearOperator(
     model, loss_function, params, dataloader, separate_weight_and_bias=False
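
All four examples follow the same migration pattern: HessianLinearOperator is now a PyTorch-native linear operator, and .to_scipy() exports a SciPy LinearOperator for SciPy routines such as eigsh. A sketch of the pattern, reusing the model, loss_function, params, and data set up in the examples:

from scipy.sparse.linalg import eigsh

H = HessianLinearOperator(model, loss_function, params, data)  # PyTorch operator
H_scipy = H.to_scipy()  # SciPy LinearOperator that acts on numpy arrays
top_evals, _ = eigsh(H_scipy, k=3, which="LA")  # k largest algebraic eigenvalues
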
diff --git a/test/cases.py b/test/cases.py
index 5f6708b..fca664f 100644
--- a/test/cases.py
+++ b/test/cases.py
@@ -371,3 +371,16 @@ def data():
 NON_DETERMINISTIC_CASES.append(case_with_device)
 
 ADJOINT_CASES = [False, True]
+ADJOINT_IDS = ["", "adjoint"]
+
+IS_VECS = [False, True]
+IS_VEC_IDS = ["matmat", "matvec"]
+
+
+BLOCK_SIZES_FNS = {
+    "full": lambda params: None,
+    "per-parameter-blocks": lambda params: [1 for _ in range(len(params))],
+    "two-blocks": lambda params: (
+        [1] if len(params) == 1 else [len(params) // 2, len(params) - len(params) // 2]
+    ),
+}
diff --git a/test/conftest.py b/test/conftest.py
index f14b9b6..2512213 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -4,9 +4,13 @@
 from collections.abc import MutableMapping
 from test.cases import (
     ADJOINT_CASES,
+    ADJOINT_IDS,
+    BLOCK_SIZES_FNS,
     CASES,
     CNN_CASES,
     INV_CASES,
+    IS_VEC_IDS,
+    IS_VECS,
     NON_DETERMINISTIC_CASES,
 )
 from test.kfac_cases import (
@@ -21,7 +25,7 @@
 from numpy import random
 from pytest import fixture
 from torch import Tensor, manual_seed
-from torch.nn import Module, MSELoss
+from torch.nn import Module, MSELoss, Parameter
 
 
 def initialize_case(
@@ -110,11 +114,38 @@
     yield initialize_case(case)
 
 
-@fixture(params=ADJOINT_CASES)
+@fixture(params=ADJOINT_CASES, ids=ADJOINT_IDS)
 def adjoint(request) -> bool:
     return request.param
 
 
+@fixture(params=IS_VECS, ids=IS_VEC_IDS)
+def is_vec(request) -> bool:
+    """Whether to test matrix-vector or matrix-matrix multiplication.
+
+    Args:
+        request: Pytest request object.
+
+    Returns:
+        ``True`` if the test is for matrix-vector multiplication, ``False`` otherwise.
+    """
+    return request.param
+
+
+@fixture(params=BLOCK_SIZES_FNS.values(), ids=BLOCK_SIZES_FNS.keys())
+def block_sizes_fn(request) -> Callable[[List[Parameter]], Optional[List[int]]]:
+    """Function to generate the ``block_sizes`` argument for a linear operator.
+
+    Args:
+        request: Pytest request object.
+
+    Returns:
+        A function that generates the block sizes for a linear operator from the
+        parameters.
+    """
+    return request.param
+
+
 @fixture(params=KFAC_EXACT_CASES)
 def kfac_exact_case(
     request,
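
The block_sizes_fn fixture sweeps the three schemes defined in BLOCK_SIZES_FNS. For a hypothetical model with three parameter tensors, they evaluate to:

params = 3 * [None]  # stand-in for three parameter tensors

assert BLOCK_SIZES_FNS["full"](params) is None  # no blocking: one full block
assert BLOCK_SIZES_FNS["per-parameter-blocks"](params) == [1, 1, 1]
assert BLOCK_SIZES_FNS["two-blocks"](params) == [1, 2]  # [3 // 2, 3 - 3 // 2]
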
diff --git a/test/test_hessian.py b/test/test_hessian.py
index 9bea4f0..0727a4d 100644
--- a/test/test_hessian.py
+++ b/test/test_hessian.py
@@ -1,84 +1,42 @@
 """Contains tests for ``curvlinops/hessian``."""
 
 from collections.abc import MutableMapping
+from test.utils import compare_matmat
+from typing import Callable, List, Optional
 
-from numpy import random
-from pytest import mark, raises
+from pytest import raises
 from torch import block_diag
+from torch.nn import Parameter
 
 from curvlinops import HessianLinearOperator
 from curvlinops.examples.functorch import functorch_hessian
-from curvlinops.examples.utils import report_nonclose
 from curvlinops.utils import split_list
 
 
-def test_HessianLinearOperator_matvec(case, adjoint: bool):
-    model_func, loss_func, params, data, batch_size_fn = case
-
-    # Test when X is dict-like but batch_size_fn = None (default)
-    if isinstance(data[0][0], MutableMapping):
-        with raises(ValueError):
-            op = HessianLinearOperator(model_func, loss_func, params, data)
-
-    op = HessianLinearOperator(
-        model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
-    op_functorch = (
-        functorch_hessian(model_func, loss_func, params, data, input_key="x")
-        .detach()
-        .cpu()
-        .numpy()
-    )
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
-
-    x = random.rand(op.shape[1])
-    report_nonclose(op @ x, op_functorch @ x, atol=1e-7)
-
-
-def test_HessianLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
-    model_func, loss_func, params, data, batch_size_fn = case
-
-    op = HessianLinearOperator(
-        model_func, loss_func, params, data, batch_size_fn=batch_size_fn
-    )
-    op_functorch = (
-        functorch_hessian(model_func, loss_func, params, data, input_key="x")
-        .detach()
-        .cpu()
-        .numpy()
-    )
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
-
-    X = random.rand(op.shape[1], num_vecs)
-    report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=5e-4)
-
-
-BLOCKING_FNS = {
-    "per-parameter": lambda params: [1 for _ in range(len(params))],
-    "two-blocks": lambda params: (
-        [1] if len(params) == 1 else [len(params) // 2, len(params) - len(params) // 2]
-    ),
-}
-
-
-@mark.parametrize("blocking", BLOCKING_FNS.keys(), ids=BLOCKING_FNS.keys())
-def test_blocked_HessianLinearOperator_matmat(
-    case, adjoint: bool, blocking: str, num_vecs: int = 2
+def test_HessianLinearOperator(
+    case,
+    adjoint: bool,
+    is_vec: bool,
+    block_sizes_fn: Callable[[List[Parameter]], Optional[List[int]]],
 ):
-    """Test matrix-matrix multiplication with the block-diagonal Hessian.
+    """Test matrix-vector and matrix-matrix multiplication with the Hessian.
 
     Args:
-        case: Tuple of model, loss function, parameters, and data.
+        case: Tuple of model, loss function, parameters, data, and batch size getter.
         adjoint: Whether to test the adjoint operator.
-        blocking: Blocking scheme.
-        num_vecs: Number of vectors to multiply with. Default is ``2``.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
+        block_sizes_fn: The function that generates the block sizes used to define
+            block diagonal approximations from the parameters.
     """
     model_func, loss_func, params, data, batch_size_fn = case
-    block_sizes = BLOCKING_FNS[blocking](params)
+    block_sizes = block_sizes_fn(params)
+
+    # Test when X is dict-like but batch_size_fn = None (default)
+    if isinstance(data[0][0], MutableMapping):
+        with raises(ValueError):
+            _ = HessianLinearOperator(model_func, loss_func, params, data)
 
-    op = HessianLinearOperator(
+    H = HessianLinearOperator(
         model_func,
         loss_func,
         params,
@@ -88,16 +46,12 @@ def test_blocked_HessianLinearOperator_matmat(
     )
 
     # compute the blocks with functorch and build the block diagonal matrix
-    op_functorch = [
-        functorch_hessian(
-            model_func, loss_func, params_block, data, input_key="x"
-        ).detach()
-        for params_block in split_list(params, block_sizes)
+    H_blocks = [
+        functorch_hessian(model_func, loss_func, params_block, data, input_key="x")
+        for params_block in split_list(
+            params, [len(params)] if block_sizes is None else block_sizes
+        )
     ]
-    op_functorch = block_diag(*op_functorch).cpu().numpy()
-
-    if adjoint:
-        op, op_functorch = op.adjoint(), op_functorch.conj().T
+    H_mat = block_diag(*H_blocks)
 
-    X = random.rand(op.shape[1], num_vecs)
-    report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=5e-4)
+    compare_matmat(H, H_mat, adjoint, is_vec, rtol=1e-4, atol=1e-6)
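
The rewritten test builds its functorch reference as a block-diagonal matrix, treating block_sizes=None as a single block over all parameters. The assembly relies on torch.block_diag; a small sketch with made-up blocks:

from torch import block_diag, eye

H1, H2 = eye(2), 2 * eye(3)  # two hypothetical Hessian blocks
H_mat = block_diag(H1, H2)  # (5, 5) matrix; cross-block entries are zero
assert tuple(H_mat.shape) == (5, 5)
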
diff --git a/test/test_submatrix_on_curvatures.py b/test/test_submatrix_on_curvatures.py
index 2fdc53e..f234a91 100644
--- a/test/test_submatrix_on_curvatures.py
+++ b/test/test_submatrix_on_curvatures.py
@@ -64,6 +64,8 @@ def setup_submatrix_linear_operator(case, operator_case, submatrix_case):
     col_idxs = submatrix_case["col_idx_fn"](dim)
 
     A = operator_case(model_func, loss_func, params, data, batch_size_fn=batch_size_fn)
+    if isinstance(A, HessianLinearOperator):
+        A = A.to_scipy()
     A_sub = SubmatrixLinearOperator(A, row_idxs, col_idxs)
 
     A_functorch = CURVATURE_IN_FUNCTORCH[operator_case](
diff --git a/test/utils.py b/test/utils.py
index dcc656b..72d7a68 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -34,6 +34,8 @@
 )
 
 from curvlinops import GGNLinearOperator
+from curvlinops._torch_base import PyTorchLinearOperator
+from curvlinops.utils import allclose_report
 
 
 def get_available_devices() -> List[device]:
@@ -404,3 +406,95 @@
         )
     else:
         assert value == value_new
+
+
+def rand_accepted_formats(
+    shapes: List[Tuple[int, ...]],
+    is_vec: bool,
+    dtype: dtype,
+    device: device,
+    num_vecs: int = 1,
+) -> Tuple[List[Tensor], Tensor, ndarray]:
+    """Generate a random vector/matrix in all accepted formats.
+
+    Args:
+        shapes: Sizes of the tensor product space.
+        is_vec: Whether to generate representations of a vector or a matrix.
+        dtype: Data type of the generated tensors.
+        device: Device of the generated tensors.
+        num_vecs: Number of vectors to generate. Ignored if ``is_vec`` is ``True``.
+            Default: ``1``.
+
+    Returns:
+        M_tensor_list: Random vector/matrix in tensor list format.
+        M_tensor: Random vector/matrix in tensor format.
+        M_ndarray: Random vector/matrix in numpy format.
+    """
+    M_tensor_list = [
+        rand(*shape, num_vecs, dtype=dtype, device=device) for shape in shapes
+    ]
+    M_tensor = cat([M.flatten(end_dim=-2) for M in M_tensor_list])
+
+    if is_vec:
+        M_tensor_list = [M.squeeze(-1) for M in M_tensor_list]
+        M_tensor = M_tensor.squeeze(-1)
+
+    M_ndarray = M_tensor.cpu().numpy()
+
+    return M_tensor_list, M_tensor, M_ndarray
+
+
+def compare_matmat(
+    op: PyTorchLinearOperator,
+    mat: Tensor,
+    adjoint: bool,
+    is_vec: bool,
+    num_vecs: int = 2,
+    rtol: float = 1e-5,
+    atol: float = 1e-8,
+):
+    """Test the matrix-vector or matrix-matrix product of a PyTorch linear operator.
+
+    Try all accepted formats for the input, as well as the SciPy-exported operator.
+
+    Args:
+        op: The operator to test.
+        mat: The matrix representation of the linear operator.
+        adjoint: Whether to test the adjoint operator.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
+        num_vecs: Number of vectors to test (ignored if ``is_vec`` is ``True``).
+            Default: ``2``.
+        rtol: Relative tolerance for the comparison. Default: ``1e-5``.
+        atol: Absolute tolerance for the comparison. Default: ``1e-8``.
+    """
+    if adjoint:
+        op, mat = op.adjoint(), mat.conj().T
+
+    num_vecs = 1 if is_vec else num_vecs
+    dt = op._infer_dtype()
+    dev = op._infer_device()
+    x_list, x_tensor, x_numpy = rand_accepted_formats(
+        [tuple(s) for s in op._in_shape], is_vec, dt, dev, num_vecs=num_vecs
+    )
+
+    tol = {"atol": atol, "rtol": rtol}
+
+    # input in tensor format
+    mat_x = mat @ x_tensor
+    assert allclose_report(op @ x_tensor, mat_x, **tol)
+
+    # input in numpy format
+    op_scipy = op.to_scipy()
+    op_x = op_scipy @ x_numpy
+    assert type(op_x) is ndarray
+    assert allclose_report(from_numpy(op_x).to(dev), mat_x, **tol)
+
+    # input in tensor list format
+    mat_x = [
+        m_x.reshape(s if is_vec else (*s, num_vecs))
+        for m_x, s in zip(mat_x.split(op._out_shape_flat), op._out_shape)
+    ]
+    op_x = op @ x_list
+    assert len(op_x) == len(mat_x)
+    for o_x, m_x in zip(op_x, mat_x):
+        assert allclose_report(o_x, m_x, **tol)
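
compare_matmat feeds the operator the same random input in all accepted formats. Assuming parameter shapes (3, 2) and (3,) and two matrix columns, the formats produced by rand_accepted_formats relate as follows (a sketch, not lifted from the test suite):

from torch import cat, rand

X_list = [rand(3, 2, 2), rand(3, 2)]  # tensor list format, trailing column axis
X_tensor = cat([x.flatten(end_dim=-2) for x in X_list])  # flat tensor, shape (9, 2)
X_numpy = X_tensor.cpu().numpy()  # numpy format for the SciPy-exported operator
assert tuple(X_tensor.shape) == (9, 2) and X_numpy.shape == (9, 2)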