From f3d4bd92ecb9ab683e138f466c423aa5d6619e6c Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Mon, 17 Jul 2023 15:30:45 -0400
Subject: [PATCH] [ADD] Implement and test `_adjoint`

---
 curvlinops/fisher.py           | 12 ++++++++++++
 curvlinops/ggn.py              | 12 ++++++++++++
 curvlinops/gradient_moments.py | 12 ++++++++++++
 curvlinops/hessian.py          | 12 ++++++++++++
 curvlinops/outer.py            | 12 ++++++++++++
 test/cases.py                  |  2 ++
 test/conftest.py               |  7 ++++++-
 test/test_fisher.py            | 18 ++++++++++++------
 test/test_ggn.py               | 24 ++++++++++++++----------
 test/test_gradient_moments.py  | 24 ++++++++++++++----------
 test/test_hessian.py           | 24 ++++++++++++++----------
 11 files changed, 122 insertions(+), 37 deletions(-)

diff --git a/curvlinops/fisher.py b/curvlinops/fisher.py
index adfad2d..86264f9 100644
--- a/curvlinops/fisher.py
+++ b/curvlinops/fisher.py
@@ -1,5 +1,7 @@
 """Contains LinearOperator implementation of the (approximate) Fisher."""
 
+from __future__ import annotations
+
 from math import sqrt
 from typing import Callable, Iterable, List, Tuple, Union
 
@@ -270,3 +272,13 @@ def sample_grad_output(self, output: Tensor) -> Tensor:
 
         else:
             raise NotImplementedError(f"Supported losses: {self.supported_losses}")
+
+    def _adjoint(self) -> FisherMCLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        The Fisher MC-approximation is real symmetric, and hence self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
diff --git a/curvlinops/ggn.py b/curvlinops/ggn.py
index 4b542fc..8e901bd 100644
--- a/curvlinops/ggn.py
+++ b/curvlinops/ggn.py
@@ -1,5 +1,7 @@
 """Contains LinearOperator implementation of the GGN."""
 
+from __future__ import annotations
+
 from typing import List, Tuple
 
 from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
@@ -54,3 +56,13 @@ def _matvec_batch(
         output = self._model_func(X)
         loss = self._loss_func(output, y)
         return ggn_vector_product_from_plist(loss, output, self._params, x_list)
+
+    def _adjoint(self) -> GGNLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        The GGN is real symmetric, and hence self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
diff --git a/curvlinops/gradient_moments.py b/curvlinops/gradient_moments.py
index 2b25d72..7dbd0ea 100644
--- a/curvlinops/gradient_moments.py
+++ b/curvlinops/gradient_moments.py
@@ -1,5 +1,7 @@
 """Contains LinearOperator implementation of gradient moment matrices."""
 
+from __future__ import annotations
+
 from typing import List, Tuple
 
 from torch import Tensor, autograd, einsum, zeros_like
@@ -76,3 +78,13 @@ def _matvec_batch(
             raise ValueError("Loss must have reduction 'mean' or 'sum'.")
 
         return tuple(r / normalization for r in result_list)
+
+    def _adjoint(self) -> EFLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        The empirical Fisher is real symmetric, and hence self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
diff --git a/curvlinops/hessian.py b/curvlinops/hessian.py
index 6a932c6..de1b44e 100644
--- a/curvlinops/hessian.py
+++ b/curvlinops/hessian.py
@@ -1,5 +1,7 @@
 """Contains LinearOperator implementation of the Hessian."""
 
+from __future__ import annotations
+
 from typing import List, Tuple
 
 from backpack.hessianfree.hvp import hessian_vector_product
@@ -45,3 +47,13 @@ def _matvec_batch(
         """
         loss = self._loss_func(self._model_func(X), y)
         return hessian_vector_product(loss, self._params, x_list)
+
+    def _adjoint(self) -> HessianLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        The Hessian is real symmetric, and hence self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
diff --git a/curvlinops/outer.py b/curvlinops/outer.py
index eec6a3e..a87c676 100644
--- a/curvlinops/outer.py
+++ b/curvlinops/outer.py
@@ -1,5 +1,7 @@
 """Utility linear operators."""
 
+from __future__ import annotations
+
 from numpy import einsum, einsum_path, ndarray, ones
 from scipy.sparse.linalg import LinearOperator
 
@@ -42,6 +44,16 @@ def _matvec(self, x: ndarray) -> ndarray:
         """
         return einsum(self._equation, *self._operands, x, optimize=self._path)
 
+    def _adjoint(self) -> OuterProductLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        An outer product is self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
+
 
 class Projector(OuterProductLinearOperator):
     """Linear operator for the projector onto the orthonormal basis ``{ aᵢ }``."""
diff --git a/test/cases.py b/test/cases.py
index bc475fd..071b332 100644
--- a/test/cases.py
+++ b/test/cases.py
@@ -171,3 +171,5 @@ def data():
     for device in DEVICES:
         case_with_device = {**case, "device": device}
         NON_DETERMINISTIC_CASES.append(case_with_device)
+
+ADJOINT_CASES = [False, True]
diff --git a/test/conftest.py b/test/conftest.py
index 5ab5b2b..408074e 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,6 +1,6 @@
 """Contains pytest fixtures that are visible by other files."""
 
-from test.cases import CASES, NON_DETERMINISTIC_CASES
+from test.cases import ADJOINT_CASES, CASES, NON_DETERMINISTIC_CASES
 from typing import Callable, Dict, Iterable, List, Tuple
 
 from numpy import random
@@ -51,3 +51,8 @@ def non_deterministic_case(
 ]:
     case = request.param
     yield initialize_case(case)
+
+
+@fixture(params=ADJOINT_CASES)
+def adjoint(request) -> bool:
+    return request.param
diff --git a/test/test_fisher.py b/test/test_fisher.py
index 7fafe09..c0f4e97 100644
--- a/test/test_fisher.py
+++ b/test/test_fisher.py
@@ -20,11 +20,15 @@
 @mark.parametrize(
     "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS
 )
-def test_LinearOperator_matvec_expectation(case, max_repeats: int, mc_samples: int):
+def test_LinearOperator_matvec_expectation(
+    case, adjoint: bool, max_repeats: int, mc_samples: int
+):
     F = FisherMCLinearOperator(*case, mc_samples=mc_samples)
-    x = random.rand(F.shape[1]).astype(F.dtype)
-
     G_functorch = functorch_ggn(*case).detach().cpu().numpy()
+    if adjoint:
+        F, G_functorch = F.adjoint(), G_functorch.conj().T
+
+    x = random.rand(F.shape[1]).astype(F.dtype)
     Gx = G_functorch @ x
 
     Fx = zeros_like(x)
@@ -49,12 +53,14 @@
     "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS
 )
 def test_LinearOperator_matmat_expectation(
-    case, max_repeats: int, mc_samples: int, num_vecs: int = 2
+    case, adjoint: bool, max_repeats: int, mc_samples: int, num_vecs: int = 2
 ):
     F = FisherMCLinearOperator(*case, mc_samples=mc_samples)
-    X = random.rand(F.shape[1], num_vecs).astype(F.dtype)
-
     G_functorch = functorch_ggn(*case).detach().cpu().numpy()
+    if adjoint:
+        F, G_functorch = F.adjoint(), G_functorch.conj().T
+
+    X = random.rand(F.shape[1], num_vecs).astype(F.dtype)
     GX = G_functorch @ X
 
     FX = zeros_like(X)
diff --git a/test/test_ggn.py b/test/test_ggn.py
index 72721f2..3d58f3d 100644
--- a/test/test_ggn.py
+++ b/test/test_ggn.py
@@ -7,25 +7,29 @@
 from curvlinops.examples.utils import report_nonclose
 
 
-def test_GGNLinearOperator_matvec(case):
+def test_GGNLinearOperator_matvec(case, adjoint: bool):
     model_func, loss_func, params, data = case
 
-    GGN = GGNLinearOperator(model_func, loss_func, params, data)
-    GGN_functorch = (
+    op = GGNLinearOperator(model_func, loss_func, params, data)
+    op_functorch = (
         functorch_ggn(model_func, loss_func, params, data).detach().cpu().numpy()
     )
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    x = random.rand(GGN.shape[1])
-    report_nonclose(GGN @ x, GGN_functorch @ x)
+    x = random.rand(op.shape[1])
+    report_nonclose(op @ x, op_functorch @ x)
 
 
-def test_GGNLinearOperator_matmat(case, num_vecs: int = 3):
+def test_GGNLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
     model_func, loss_func, params, data = case
 
-    GGN = GGNLinearOperator(model_func, loss_func, params, data)
-    GGN_functorch = (
+    op = GGNLinearOperator(model_func, loss_func, params, data)
+    op_functorch = (
         functorch_ggn(model_func, loss_func, params, data).detach().cpu().numpy()
     )
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    X = random.rand(GGN.shape[1], num_vecs)
-    report_nonclose(GGN @ X, GGN_functorch @ X)
+    X = random.rand(op.shape[1], num_vecs)
+    report_nonclose(op @ X, op_functorch @ X)
diff --git a/test/test_gradient_moments.py b/test/test_gradient_moments.py
index f123372..0c97da0 100644
--- a/test/test_gradient_moments.py
+++ b/test/test_gradient_moments.py
@@ -7,17 +7,21 @@
 from curvlinops.examples.utils import report_nonclose
 
 
-def test_EFLinearOperator_matvec(case):
-    EF = EFLinearOperator(*case)
-    EF_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()
+def test_EFLinearOperator_matvec(case, adjoint: bool):
+    op = EFLinearOperator(*case)
+    op_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    x = random.rand(EF.shape[1]).astype(EF.dtype)
-    report_nonclose(EF @ x, EF_functorch @ x)
+    x = random.rand(op.shape[1]).astype(op.dtype)
+    report_nonclose(op @ x, op_functorch @ x)
 
 
-def test_EFLinearOperator_matmat(case, num_vecs: int = 3):
-    EF = EFLinearOperator(*case)
-    EF_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()
+def test_EFLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
+    op = EFLinearOperator(*case)
+    op_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    X = random.rand(EF.shape[1], num_vecs).astype(EF.dtype)
-    report_nonclose(EF @ X, EF_functorch @ X, atol=1e-7, rtol=1e-4)
+    X = random.rand(op.shape[1], num_vecs).astype(op.dtype)
+    report_nonclose(op @ X, op_functorch @ X, atol=1e-7, rtol=1e-4)
diff --git a/test/test_hessian.py b/test/test_hessian.py
index 91bb00e..5012b69 100644
--- a/test/test_hessian.py
+++ b/test/test_hessian.py
@@ -7,17 +7,21 @@
 from curvlinops.examples.utils import report_nonclose
 
 
-def test_HessianLinearOperator_matvec(case):
-    H = HessianLinearOperator(*case)
-    H_functorch = functorch_hessian(*case).detach().cpu().numpy()
+def test_HessianLinearOperator_matvec(case, adjoint: bool):
+    op = HessianLinearOperator(*case)
+    op_functorch = functorch_hessian(*case).detach().cpu().numpy()
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    x = random.rand(H.shape[1])
-    report_nonclose(H @ x, H_functorch @ x)
+    x = random.rand(op.shape[1])
+    report_nonclose(op @ x, op_functorch @ x)
 
 
-def test_HessianLinearOperator_matmat(case, num_vecs: int = 3):
-    H = HessianLinearOperator(*case)
-    H_functorch = functorch_hessian(*case).detach().cpu().numpy()
+def test_HessianLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
+    op = HessianLinearOperator(*case)
+    op_functorch = functorch_hessian(*case).detach().cpu().numpy()
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    X = random.rand(H.shape[1], num_vecs)
-    report_nonclose(H @ X, H_functorch @ X, atol=1e-6, rtol=5e-4)
+    X = random.rand(op.shape[1], num_vecs)
+    report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=5e-4)
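
Usage sketch (not part of the commit): with `_adjoint` implemented, SciPy's `LinearOperator.adjoint()` and the `.H` property resolve to the operator itself instead of going through the base class's `_rmatvec`-based adjoint wrapper. A minimal sketch of the resulting behavior against the public API; the toy model, sizes, and seed below are illustrative and not a case from `test/cases.py`:

    from numpy import random
    from torch import manual_seed, nn, rand

    from curvlinops import HessianLinearOperator
    from curvlinops.examples.utils import report_nonclose

    manual_seed(0)

    # illustrative toy setup, not taken from the test suite
    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
    loss_func = nn.MSELoss()
    params = [p for p in model.parameters() if p.requires_grad]
    data = [(rand(8, 10), rand(8, 2))]

    H = HessianLinearOperator(model, loss_func, params, data)
    x = random.rand(H.shape[1])

    # the Hessian is real symmetric, so its adjoint is the operator itself
    # and H.adjoint() @ x must match H @ x
    report_nonclose(H.adjoint() @ x, H @ x)
    assert H.H is H  # .H resolves through _adjoint as well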