[ADD] Implement and test _adjoint

f-dangel committed Jul 17, 2023
1 parent 00733c2 commit f3d4bd9

Showing 11 changed files with 122 additions and 37 deletions.
12 changes: 12 additions & 0 deletions curvlinops/fisher.py
@@ -1,5 +1,7 @@
 """Contains LinearOperator implementation of the (approximate) Fisher."""
 
+from __future__ import annotations
+
 from math import sqrt
 from typing import Callable, Iterable, List, Tuple, Union
 
@@ -270,3 +272,13 @@ def sample_grad_output(self, output: Tensor) -> Tensor:
 
         else:
             raise NotImplementedError(f"Supported losses: {self.supported_losses}")
+
+    def _adjoint(self) -> FisherMCLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        The Fisher MC-approximation is real symmetric, and hence self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
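
For readers wondering where this hook lives: SciPy's LinearOperator exposes the adjoint through op.adjoint() and the shorthand property op.H, both of which dispatch to _adjoint. A minimal, self-contained sketch of that mechanism with a toy operator (not part of curvlinops):

from numpy import ndarray
from scipy.sparse.linalg import LinearOperator


class SelfAdjointIdentity(LinearOperator):
    """Toy self-adjoint operator mirroring the override above."""

    def __init__(self):
        super().__init__(dtype=float, shape=(3, 3))

    def _matvec(self, x: ndarray) -> ndarray:
        return x  # identity map: trivially symmetric

    def _adjoint(self) -> "SelfAdjointIdentity":
        return self


op = SelfAdjointIdentity()
assert op.adjoint() is op  # both .adjoint() and .H call _adjoint()
assert op.H is op

Returning self means taking the adjoint costs nothing and shares the original operator's state.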
12 changes: 12 additions & 0 deletions curvlinops/ggn.py
@@ -1,5 +1,7 @@
 """Contains LinearOperator implementation of the GGN."""
 
+from __future__ import annotations
+
 from typing import List, Tuple
 
 from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
@@ -54,3 +56,13 @@ def _matvec_batch(
         output = self._model_func(X)
         loss = self._loss_func(output, y)
         return ggn_vector_product_from_plist(loss, output, self._params, x_list)
+
+    def _adjoint(self) -> GGNLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        The GGN is real symmetric, and hence self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
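
Without such an override, SciPy's documented fallback defers to _rmatvec and wraps the operator in a new adjoint object on every call. A sketch of that default for contrast (toy operator; behavior as documented for scipy.sparse.linalg.LinearOperator):

from numpy import ndarray
from scipy.sparse.linalg import LinearOperator


class NoAdjointOverride(LinearOperator):
    """Identity operator that keeps SciPy's default _adjoint."""

    def __init__(self):
        super().__init__(dtype=float, shape=(2, 2))

    def _matvec(self, x: ndarray) -> ndarray:
        return x

    def _rmatvec(self, x: ndarray) -> ndarray:  # consulted by the default adjoint
        return x


op = NoAdjointOverride()
assert op.adjoint() is not op  # a fresh wrapper object, unlike the overrides in this commit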
12 changes: 12 additions & 0 deletions curvlinops/gradient_moments.py
@@ -1,5 +1,7 @@
 """Contains LinearOperator implementation of gradient moment matrices."""
 
+from __future__ import annotations
+
 from typing import List, Tuple
 
 from torch import Tensor, autograd, einsum, zeros_like
@@ -76,3 +78,13 @@ def _matvec_batch(
             raise ValueError("Loss must have reduction 'mean' or 'sum'.")
 
         return tuple(r / normalization for r in result_list)
+
+    def _adjoint(self) -> EFLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        The empirical Fisher is real symmetric, and hence self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
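
For context on the docstring's claim: the empirical Fisher is an average of per-sample gradient outer products g gᵀ, and each such term equals its own transpose. A small NumPy sketch with random stand-ins for the gradients:

from numpy import allclose, outer, zeros
from numpy.random import rand

grads = [rand(4) for _ in range(8)]  # stand-ins for per-sample gradients
F = zeros((4, 4))
for g in grads:
    F += outer(g, g) / len(grads)  # average of symmetric rank-1 terms

assert allclose(F, F.T)  # symmetric, hence self-adjoint over the reals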
12 changes: 12 additions & 0 deletions curvlinops/hessian.py
@@ -1,5 +1,7 @@
 """Contains LinearOperator implementation of the Hessian."""
 
+from __future__ import annotations
+
 from typing import List, Tuple
 
 from backpack.hessianfree.hvp import hessian_vector_product
@@ -45,3 +47,13 @@ def _matvec_batch(
         """
         loss = self._loss_func(self._model_func(X), y)
         return hessian_vector_product(loss, self._params, x_list)
+
+    def _adjoint(self) -> HessianLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        The Hessian is real symmetric, and hence self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
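
The justification recurring in these docstrings — real symmetric, hence self-adjoint — amounts to ⟨Hx, y⟩ = ⟨x, Hy⟩ holding for all x and y whenever H equals its transpose. A quick check with a generic symmetric matrix (unrelated to the repository):

from numpy import allclose
from numpy.random import rand

A = rand(4, 4)
H = A + A.T  # symmetrize so that H equals its transpose
x, y = rand(4), rand(4)

# <Hx, y> == <x, Hy> is the defining property of a self-adjoint real operator
assert allclose((H @ x) @ y, x @ (H @ y))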
12 changes: 12 additions & 0 deletions curvlinops/outer.py
@@ -1,5 +1,7 @@
 """Utility linear operators."""
 
+from __future__ import annotations
+
 from numpy import einsum, einsum_path, ndarray, ones
 from scipy.sparse.linalg import LinearOperator
 
@@ -42,6 +44,16 @@ def _matvec(self, x: ndarray) -> ndarray:
         """
         return einsum(self._equation, *self._operands, x, optimize=self._path)
 
+    def _adjoint(self) -> OuterProductLinearOperator:
+        """Return the linear operator representing the adjoint.
+
+        An outer product is self-adjoint.
+
+        Returns:
+            Self.
+        """
+        return self
+
 
 class Projector(OuterProductLinearOperator):
     """Linear operator for the projector onto the orthonormal basis ``{ aᵢ }``."""
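The Projector class at the end of this hunk inherits the new _adjoint, which is consistent: the projector onto an orthonormal basis { aᵢ } is Σᵢ aᵢaᵢᵀ, a symmetric and idempotent matrix. A standalone NumPy check with a random orthonormal basis (not using the class itself):

from numpy import allclose
from numpy.linalg import qr
from numpy.random import rand

Q, _ = qr(rand(5, 2))  # two orthonormal basis vectors as columns
P = Q @ Q.T  # projector onto their span

assert allclose(P, P.T)  # self-adjoint, matching the inherited _adjoint
assert allclose(P @ P, P)  # idempotent, as a projector should be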
2 changes: 2 additions & 0 deletions test/cases.py
@@ -171,3 +171,5 @@ def data():
     for device in DEVICES:
         case_with_device = {**case, "device": device}
         NON_DETERMINISTIC_CASES.append(case_with_device)
+
+ADJOINT_CASES = [False, True]
7 changes: 6 additions & 1 deletion test/conftest.py
@@ -1,6 +1,6 @@
 """Contains pytest fixtures that are visible by other files."""
 
-from test.cases import CASES, NON_DETERMINISTIC_CASES
+from test.cases import ADJOINT_CASES, CASES, NON_DETERMINISTIC_CASES
 from typing import Callable, Dict, Iterable, List, Tuple
 
 from numpy import random
@@ -51,3 +51,8 @@ def non_deterministic_case(
 ]:
     case = request.param
     yield initialize_case(case)
+
+
+@fixture(params=ADJOINT_CASES)
+def adjoint(request) -> bool:
+    return request.param
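
Because the new fixture is parametrized, every test that accepts an adjoint argument is now collected twice, once per entry of ADJOINT_CASES. A minimal standalone illustration of the mechanism (hypothetical test, not from this commit):

from pytest import fixture


@fixture(params=[False, True])
def adjoint(request) -> bool:
    """Provide one parameter per run, doubling each dependent test."""
    return request.param


def test_example(adjoint: bool):  # collected as test_example[False] and test_example[True]
    assert adjoint in (False, True)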
18 changes: 12 additions & 6 deletions test/test_fisher.py
@@ -20,11 +20,15 @@
 @mark.parametrize(
     "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS
 )
-def test_LinearOperator_matvec_expectation(case, max_repeats: int, mc_samples: int):
+def test_LinearOperator_matvec_expectation(
+    case, adjoint: bool, max_repeats: int, mc_samples: int
+):
     F = FisherMCLinearOperator(*case, mc_samples=mc_samples)
-    x = random.rand(F.shape[1]).astype(F.dtype)
-
     G_functorch = functorch_ggn(*case).detach().cpu().numpy()
+    if adjoint:
+        F, G_functorch = F.adjoint(), G_functorch.conj().T
+
+    x = random.rand(F.shape[1]).astype(F.dtype)
     Gx = G_functorch @ x
 
     Fx = zeros_like(x)
@@ -49,12 +53,14 @@ def test_LinearOperator_matvec_expectation(case, max_repeats: int, mc_samples: int):
     "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS
 )
 def test_LinearOperator_matmat_expectation(
-    case, max_repeats: int, mc_samples: int, num_vecs: int = 2
+    case, adjoint: bool, max_repeats: int, mc_samples: int, num_vecs: int = 2
 ):
     F = FisherMCLinearOperator(*case, mc_samples=mc_samples)
-    X = random.rand(F.shape[1], num_vecs).astype(F.dtype)
-
     G_functorch = functorch_ggn(*case).detach().cpu().numpy()
+    if adjoint:
+        F, G_functorch = F.adjoint(), G_functorch.conj().T
+
+    X = random.rand(F.shape[1], num_vecs).astype(F.dtype)
     GX = G_functorch @ X
 
     FX = zeros_like(X)
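A note on the reference used in these tests: the adjoint of a matrix is its conjugate transpose, so G_functorch.conj().T is the exact counterpart of F.adjoint(). Since these operators are real-valued, conj() is a no-op and the adjoint case checks against the same symmetric matrix, which is precisely what self-adjointness predicts. For instance:

from numpy import allclose
from numpy.random import rand

G = rand(3, 3)  # real-valued, like the test matrices
assert allclose(G.conj().T, G.T)  # conj() does nothing on real arrays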
24 changes: 14 additions & 10 deletions test/test_ggn.py
@@ -7,25 +7,29 @@
 from curvlinops.examples.utils import report_nonclose
 
 
-def test_GGNLinearOperator_matvec(case):
+def test_GGNLinearOperator_matvec(case, adjoint: bool):
     model_func, loss_func, params, data = case
 
-    GGN = GGNLinearOperator(model_func, loss_func, params, data)
-    GGN_functorch = (
+    op = GGNLinearOperator(model_func, loss_func, params, data)
+    op_functorch = (
         functorch_ggn(model_func, loss_func, params, data).detach().cpu().numpy()
     )
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    x = random.rand(GGN.shape[1])
-    report_nonclose(GGN @ x, GGN_functorch @ x)
+    x = random.rand(op.shape[1])
+    report_nonclose(op @ x, op_functorch @ x)
 
 
-def test_GGNLinearOperator_matmat(case, num_vecs: int = 3):
+def test_GGNLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
     model_func, loss_func, params, data = case
 
-    GGN = GGNLinearOperator(model_func, loss_func, params, data)
-    GGN_functorch = (
+    op = GGNLinearOperator(model_func, loss_func, params, data)
+    op_functorch = (
         functorch_ggn(model_func, loss_func, params, data).detach().cpu().numpy()
     )
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    X = random.rand(GGN.shape[1], num_vecs)
-    report_nonclose(GGN @ X, GGN_functorch @ X)
+    X = random.rand(op.shape[1], num_vecs)
+    report_nonclose(op @ X, op_functorch @ X)
24 changes: 14 additions & 10 deletions test/test_gradient_moments.py
@@ -7,17 +7,21 @@
 from curvlinops.examples.utils import report_nonclose
 
 
-def test_EFLinearOperator_matvec(case):
-    EF = EFLinearOperator(*case)
-    EF_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()
+def test_EFLinearOperator_matvec(case, adjoint: bool):
+    op = EFLinearOperator(*case)
+    op_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    x = random.rand(EF.shape[1]).astype(EF.dtype)
-    report_nonclose(EF @ x, EF_functorch @ x)
+    x = random.rand(op.shape[1]).astype(op.dtype)
+    report_nonclose(op @ x, op_functorch @ x)
 
 
-def test_EFLinearOperator_matmat(case, num_vecs: int = 3):
-    EF = EFLinearOperator(*case)
-    EF_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()
+def test_EFLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
+    op = EFLinearOperator(*case)
+    op_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    X = random.rand(EF.shape[1], num_vecs).astype(EF.dtype)
-    report_nonclose(EF @ X, EF_functorch @ X, atol=1e-7, rtol=1e-4)
+    X = random.rand(op.shape[1], num_vecs).astype(op.dtype)
+    report_nonclose(op @ X, op_functorch @ X, atol=1e-7, rtol=1e-4)
24 changes: 14 additions & 10 deletions test/test_hessian.py
@@ -7,17 +7,21 @@
 from curvlinops.examples.utils import report_nonclose
 
 
-def test_HessianLinearOperator_matvec(case):
-    H = HessianLinearOperator(*case)
-    H_functorch = functorch_hessian(*case).detach().cpu().numpy()
+def test_HessianLinearOperator_matvec(case, adjoint: bool):
+    op = HessianLinearOperator(*case)
+    op_functorch = functorch_hessian(*case).detach().cpu().numpy()
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    x = random.rand(H.shape[1])
-    report_nonclose(H @ x, H_functorch @ x)
+    x = random.rand(op.shape[1])
+    report_nonclose(op @ x, op_functorch @ x)
 
 
-def test_HessianLinearOperator_matmat(case, num_vecs: int = 3):
-    H = HessianLinearOperator(*case)
-    H_functorch = functorch_hessian(*case).detach().cpu().numpy()
+def test_HessianLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
+    op = HessianLinearOperator(*case)
+    op_functorch = functorch_hessian(*case).detach().cpu().numpy()
+    if adjoint:
+        op, op_functorch = op.adjoint(), op_functorch.conj().T
 
-    X = random.rand(H.shape[1], num_vecs)
-    report_nonclose(H @ X, H_functorch @ X, atol=1e-6, rtol=5e-4)
+    X = random.rand(op.shape[1], num_vecs)
+    report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=5e-4)
