diff --git a/curvlinops/ggn.py b/curvlinops/ggn.py
index 6c6231d..4c588ff 100644
--- a/curvlinops/ggn.py
+++ b/curvlinops/ggn.py
@@ -8,11 +8,11 @@
 from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
 from torch import Tensor, zeros_like
 
-from curvlinops._base import _LinearOperator
+from curvlinops._torch_base import CurvatureLinearOperator
 
 
-class GGNLinearOperator(_LinearOperator):
-    r"""GGN as SciPy linear operator.
+class GGNLinearOperator(CurvatureLinearOperator):
+    r"""Linear operator for the generalized Gauss-Newton matrix of an empirical risk.
 
     Consider the empirical risk
 
@@ -39,47 +39,42 @@ class GGNLinearOperator(_LinearOperator):
         \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}_n)
         \right)\,.
+
+    Attributes:
+        SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for GGNs.
     """
 
+    SELF_ADJOINT: bool = True
+
     def _matmat_batch(
-        self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
-    ) -> Tuple[Tensor, ...]:
+        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
+    ) -> List[Tensor]:
         """Apply the mini-batch GGN to a matrix.
 
         Args:
             X: Input to the DNN.
             y: Ground truth.
-            M_list: Matrix to be multiplied with in list format.
+            M: Matrix to be multiplied with in tensor list format.
                 Tensors have same shape as trainable model parameters, and an
-                additional leading axis for the matrix columns.
+                additional trailing axis for the matrix columns.
 
         Returns:
             Result of GGN multiplication in list format. Has the same shape as
-            ``M_list``, i.e. each tensor in the list has the shape of a parameter and a
-            leading dimension of matrix columns.
+            ``M``, i.e. each tensor in the list has the shape of a parameter and a
+            trailing dimension of matrix columns.
         """
         output = self._model_func(X)
         loss = self._loss_func(output, y)
 
         # collect matrix-matrix products per parameter
-        result_list = [zeros_like(M) for M in M_list]
+        (num_vecs,) = {m.shape[-1] for m in M}
+        GM = [zeros_like(m) for m in M]
 
-        num_vecs = M_list[0].shape[0]
         for n in range(num_vecs):
-            col_n_list = ggn_vector_product_from_plist(
-                loss, output, self._params, [M[n] for M in M_list]
+            col_n = ggn_vector_product_from_plist(
+                loss, output, self._params, [m[..., n] for m in M]
             )
-            for result, col_n in zip(result_list, col_n_list):
-                result[n].add_(col_n)
-
-        return tuple(result_list)
-
-    def _adjoint(self) -> GGNLinearOperator:
-        """Return the linear operator representing the adjoint.
+            for GM_p, col_n_p in zip(GM, col_n):
+                GM_p[..., n].add_(col_n_p)
 
-        The GGN is real symmetric, and hence self-adjoint.
-
-        Returns:
-            Self.
-        """
-        return self
+        return GM
 
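The core convention change above: matrices in tensor-list format now carry their column axis as a trailing rather than a leading dimension. A minimal standalone sketch of this format (toy shapes; the parameter stand-ins are made up for illustration):

from torch import rand, zeros

params = [rand(3, 2), rand(3)]  # stand-ins for trainable model parameters
K = 4  # number of matrix columns

# one tensor per parameter, shaped like the parameter plus a trailing column axis
M = [zeros(*p.shape, K) for p in params]

# column n is a list of parameter-shaped tensors, sliced from the trailing axis,
# which is exactly the ``[m[..., n] for m in M]`` pattern in ``_matmat_batch``
n = 0
col_n = [m[..., n] for m in M]
assert all(c.shape == p.shape for c, p in zip(col_n, params))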
- """ - return self + return GM diff --git a/curvlinops/hessian.py b/curvlinops/hessian.py index 2bae30b..0f93533 100644 --- a/curvlinops/hessian.py +++ b/curvlinops/hessian.py @@ -65,20 +65,20 @@ def _matmat_batch( grad_params = grad(loss, self._params, create_graph=True) (num_vecs,) = {m.shape[-1] for m in M} - AM = [zeros_like(m) for m in M] + HM = [zeros_like(m) for m in M] # per-block HMP - for M_block, p_block, g_block, AM_block in zip( + for M_block, p_block, g_block, HM_block in zip( split_list(M, self._block_sizes), split_list(self._params, self._block_sizes), split_list(grad_params, self._block_sizes), - split_list(AM, self._block_sizes), + split_list(HM, self._block_sizes), ): for n in range(num_vecs): col_n = hessian_vector_product( loss, p_block, [M[..., n] for M in M_block], grad_params=g_block ) for p, col in enumerate(col_n): - AM_block[p][..., n].add_(col) + HM_block[p][..., n].add_(col) - return AM + return HM diff --git a/docs/examples/basic_usage/example_fisher_monte_carlo.py b/docs/examples/basic_usage/example_fisher_monte_carlo.py index 30acfef..4948dc1 100644 --- a/docs/examples/basic_usage/example_fisher_monte_carlo.py +++ b/docs/examples/basic_usage/example_fisher_monte_carlo.py @@ -78,7 +78,7 @@ # Fisher and compute their matrix representations by multiplying them onto the # identity matrix: -GGN = GGNLinearOperator(model, loss_function, params, data) +GGN = GGNLinearOperator(model, loss_function, params, data).to_scipy() F = FisherMCLinearOperator(model, loss_function, params, data) D = GGN.shape[0] diff --git a/docs/examples/basic_usage/example_huggingface.py b/docs/examples/basic_usage/example_huggingface.py index 0674dd1..a673a7f 100644 --- a/docs/examples/basic_usage/example_huggingface.py +++ b/docs/examples/basic_usage/example_huggingface.py @@ -158,7 +158,7 @@ def batch_size_fn(x: MutableMapping): [(data, data["labels"])], # We still need to input a list of "(X, y)" pairs! check_deterministic=False, batch_size_fn=batch_size_fn, # Remember to specify this! -) +).to_scipy() G = ggn @ np.eye(ggn.shape[0]) diff --git a/docs/examples/basic_usage/example_inverses.py b/docs/examples/basic_usage/example_inverses.py index b657713..74df335 100644 --- a/docs/examples/basic_usage/example_inverses.py +++ b/docs/examples/basic_usage/example_inverses.py @@ -80,7 +80,7 @@ # First, we set up a linear operator for the damped GGN/Fisher data = [(X1, y1), (X2, y2)] -GGN = GGNLinearOperator(model, loss_function, params, data) +GGN = GGNLinearOperator(model, loss_function, params, data).to_scipy() delta = 1e-2 damping = aslinearoperator(delta * sparse.eye(GGN.shape[0])) diff --git a/docs/examples/basic_usage/example_matrix_vector_products.py b/docs/examples/basic_usage/example_matrix_vector_products.py index 1b3a96e..b056171 100644 --- a/docs/examples/basic_usage/example_matrix_vector_products.py +++ b/docs/examples/basic_usage/example_matrix_vector_products.py @@ -116,7 +116,7 @@ # # Setting up a linear operator for the Fisher/GGN is identical to the Hessian. 
diff --git a/docs/examples/basic_usage/example_matrix_vector_products.py b/docs/examples/basic_usage/example_matrix_vector_products.py
index 1b3a96e..b056171 100644
--- a/docs/examples/basic_usage/example_matrix_vector_products.py
+++ b/docs/examples/basic_usage/example_matrix_vector_products.py
@@ -116,7 +116,7 @@
 #
 # Setting up a linear operator for the Fisher/GGN is identical to the Hessian.
 
-GGN = GGNLinearOperator(model, loss_function, params, data)
+GGN = GGNLinearOperator(model, loss_function, params, data).to_scipy()
 
 # %%
 #
diff --git a/docs/examples/basic_usage/example_model_merging.py b/docs/examples/basic_usage/example_model_merging.py
index be373ae..ed91e98 100644
--- a/docs/examples/basic_usage/example_model_merging.py
+++ b/docs/examples/basic_usage/example_model_merging.py
@@ -136,7 +136,7 @@ def make_dataset() -> TensorDataset:
         loss_function,
         [p for p in model.parameters() if p.requires_grad],
         data_loader,
-    )
+    ).to_scipy()
     for model, loss_function, data_loader in zip(models, loss_functions, data_loaders)
 ]
 
diff --git a/docs/examples/basic_usage/example_visual_tour.py b/docs/examples/basic_usage/example_visual_tour.py
index 8c2b580..0461d27 100644
--- a/docs/examples/basic_usage/example_visual_tour.py
+++ b/docs/examples/basic_usage/example_visual_tour.py
@@ -80,7 +80,7 @@
 Hessian_linop = HessianLinearOperator(
     model, loss_function, params, dataloader
 ).to_scipy()
-GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader)
+GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader).to_scipy()
 EF_linop = EFLinearOperator(model, loss_function, params, dataloader)
 
 # %%
diff --git a/test/test_ggn.py b/test/test_ggn.py
index 213daef..4e72494 100644
--- a/test/test_ggn.py
+++ b/test/test_ggn.py
@@ -1,53 +1,32 @@
 """Contains tests for ``curvlinops/ggn``."""
 
 from collections.abc import MutableMapping
+from test.utils import compare_matmat
 
-from numpy import random
 from pytest import raises
 
 from curvlinops import GGNLinearOperator
 from curvlinops.examples.functorch import functorch_ggn
-from curvlinops.examples.utils import report_nonclose
 
 
-def test_GGNLinearOperator_matvec(case, adjoint: bool):
+def test_GGNLinearOperator_matvec(case, adjoint: bool, is_vec: bool):
+    """Test matrix-vector and matrix-matrix multiplication with the GGN.
+
+    Args:
+        case: Tuple of model, loss function, parameters, data, and batch size getter.
+        adjoint: Whether to test the adjoint operator.
+        is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
+ """ model_func, loss_func, params, data, batch_size_fn = case # Test when X is dict-like but batch_size_fn = None (default) if isinstance(data[0][0], MutableMapping): with raises(ValueError): - op = GGNLinearOperator(model_func, loss_func, params, data) - - op = GGNLinearOperator( - model_func, loss_func, params, data, batch_size_fn=batch_size_fn - ) - op_functorch = ( - functorch_ggn(model_func, loss_func, params, data, input_key="x") - .detach() - .cpu() - .numpy() - ) - if adjoint: - op, op_functorch = op.adjoint(), op_functorch.conj().T + _ = GGNLinearOperator(model_func, loss_func, params, data) - x = random.rand(op.shape[1]) - report_nonclose(op @ x, op_functorch @ x) - - -def test_GGNLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3): - model_func, loss_func, params, data, batch_size_fn = case - - op = GGNLinearOperator( + G = GGNLinearOperator( model_func, loss_func, params, data, batch_size_fn=batch_size_fn ) - op_functorch = ( - functorch_ggn(model_func, loss_func, params, data, input_key="x") - .detach() - .cpu() - .numpy() - ) - if adjoint: - op, op_functorch = op.adjoint(), op_functorch.conj().T + G_mat = functorch_ggn(model_func, loss_func, params, data, input_key="x") - X = random.rand(op.shape[1], num_vecs) - report_nonclose(op @ X, op_functorch @ X) + compare_matmat(G, G_mat, adjoint, is_vec, atol=1e-7) diff --git a/test/test_hessian.py b/test/test_hessian.py index 0727a4d..c6efe03 100644 --- a/test/test_hessian.py +++ b/test/test_hessian.py @@ -24,9 +24,9 @@ def test_HessianLinearOperator( Args: case: Tuple of model, loss function, parameters, data, and batch size getter. adjoint: Whether to test the adjoint operator. + is_vec: Whether to test matrix-vector or matrix-matrix multiplication. block_sizes_fn: The function that generates the block sizes used to define block diagonal approximations from the parameters. - is_vec: Whether to test matrix-vector or matrix-matrix multiplication. 
""" model_func, loss_func, params, data, batch_size_fn = case block_sizes = block_sizes_fn(params) diff --git a/test/test_inverse.py b/test/test_inverse.py index f064ccc..cfa8309 100644 --- a/test/test_inverse.py +++ b/test/test_inverse.py @@ -34,7 +34,7 @@ def test_CG_inverse_damped_GGN_matvec(inv_case, delta: float = 2e-2): GGN = GGNLinearOperator( model_func, loss_func, params, data, batch_size_fn=batch_size_fn - ) + ).to_scipy() damping = aslinearoperator(delta * sparse.eye(GGN.shape[0])) inv_GGN = CGInverseLinearOperator(GGN + damping) @@ -56,7 +56,7 @@ def test_CG_inverse_damped_GGN_matmat(inv_case, delta: float = 1e-2, num_vecs: i GGN = GGNLinearOperator( model_func, loss_func, params, data, batch_size_fn=batch_size_fn - ) + ).to_scipy() damping = aslinearoperator(delta * sparse.eye(GGN.shape[0])) inv_GGN = CGInverseLinearOperator(GGN + damping) @@ -78,7 +78,7 @@ def test_LSMR_inverse_damped_GGN_matvec(inv_case, delta: float = 2e-2): GGN = GGNLinearOperator( model_func, loss_func, params, data, batch_size_fn=batch_size_fn - ) + ).to_scipy() damping = aslinearoperator(delta * sparse.eye(GGN.shape[0])) inv_GGN = LSMRInverseLinearOperator(GGN + damping) @@ -104,7 +104,7 @@ def test_LSMR_inverse_damped_GGN_matmat( GGN = GGNLinearOperator( model_func, loss_func, params, data, batch_size_fn=batch_size_fn - ) + ).to_scipy() damping = aslinearoperator(delta * sparse.eye(GGN.shape[0])) inv_GGN = LSMRInverseLinearOperator(GGN + damping) @@ -128,7 +128,7 @@ def test_Neumann_inverse_damped_GGN_matvec(inv_case, delta: float = 1e-2): GGN = GGNLinearOperator( model_func, loss_func, params, data, batch_size_fn=batch_size_fn - ) + ).to_scipy() damping = aslinearoperator(delta * sparse.eye(GGN.shape[0])) damped_GGN_functorch = functorch_ggn( diff --git a/test/test_submatrix_on_curvatures.py b/test/test_submatrix_on_curvatures.py index f234a91..a8605eb 100644 --- a/test/test_submatrix_on_curvatures.py +++ b/test/test_submatrix_on_curvatures.py @@ -64,7 +64,7 @@ def setup_submatrix_linear_operator(case, operator_case, submatrix_case): col_idxs = submatrix_case["col_idx_fn"](dim) A = operator_case(model_func, loss_func, params, data, batch_size_fn=batch_size_fn) - if isinstance(A, HessianLinearOperator): + if isinstance(A, (HessianLinearOperator, GGNLinearOperator)): A = A.to_scipy() A_sub = SubmatrixLinearOperator(A, row_idxs, col_idxs) diff --git a/test/utils.py b/test/utils.py index 83f4bec..02317cc 100644 --- a/test/utils.py +++ b/test/utils.py @@ -112,7 +112,9 @@ def ggn_block_diagonal( The block-diagonal GGN. """ # compute the full GGN then zero out the off-diagonal blocks - ggn = GGNLinearOperator(model, loss_func, params, data, batch_size_fn=batch_size_fn) + ggn = GGNLinearOperator( + model, loss_func, params, data, batch_size_fn=batch_size_fn + ).to_scipy() ggn = from_numpy(ggn @ eye(ggn.shape[1])) sizes = [p.numel() for p in params] # ggn_blocks[i, j] corresponds to the block of (params[i], params[j])