Commit 4c2e3f6

[ADD] Introduce pure PyTorch base class and implement Hessian

f-dangel committed Sep 18, 2024
1 parent 9a6bb6d commit 4c2e3f6

Showing 12 changed files with 665 additions and 64 deletions.
541 changes: 529 additions & 12 deletions curvlinops/_base.py

Large diffs are not rendered by default.
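Since the curvlinops/_base.py diff is collapsed, the following is a minimal sketch of the interface the new PyTorch-native base class appears to expose, reconstructed from how HessianLinearOperator and the tests below use it. Only the names visible elsewhere in this commit (to_scipy, dtype, device, shape, adjoint, SUPPORTS_BLOCKS, _matmat_batch, _in_blocks/_out_blocks) are taken from the diff; the method bodies, the default blocking, and the handling of normalization, progress bars, and num_data are assumptions, not the actual implementation.

from math import prod
from typing import List

from numpy import ndarray
from scipy.sparse.linalg import LinearOperator
from torch import Tensor, cat, from_numpy, zeros


class CurvatureLinearOperator:
    """Hypothetical sketch: curvature matrices as PyTorch-native linear operators."""

    SUPPORTS_BLOCKS: bool = False

    def __init__(self, model_func, loss_func, params, data, in_shape, out_shape,
                 dt, dev, progressbar=False, num_data=None, in_blocks=None,
                 out_blocks=None, batch_size_fn=None):
        self._model_func, self._loss_func = model_func, loss_func
        self._params, self._data = params, data
        self._in_shape, self._out_shape = in_shape, out_shape
        self.dtype, self.device = dt, dev
        # assumed default: a single block containing all parameters
        self._in_blocks = [len(params)] if in_blocks is None else in_blocks
        self._out_blocks = [len(params)] if out_blocks is None else out_blocks
        self.shape = (sum(prod(s) for s in out_shape), sum(prod(s) for s in in_shape))

    def _matmat_batch(self, X, y, M: List[Tensor]) -> List[Tensor]:
        raise NotImplementedError  # supplied by subclasses, e.g. the Hessian below

    def __matmul__(self, M: Tensor) -> Tensor:
        is_vec = M.ndim == 1
        M_mat = M.unsqueeze(-1) if is_vec else M
        num_vecs = M_mat.shape[1]
        # unflatten the input into per-parameter tensors with trailing columns
        M_list, offset = [], 0
        for s in self._in_shape:
            M_list.append(M_mat[offset: offset + prod(s)].reshape(*s, num_vecs))
            offset += prod(s)
        # accumulate over mini-batches (normalization omitted in this sketch)
        result = [zeros(*s, num_vecs, dtype=self.dtype, device=self.device)
                  for s in self._out_shape]
        for X, y in self._data:
            for res, res_batch in zip(result, self._matmat_batch(X, y, M_list)):
                res.add_(res_batch)
        out = cat([r.reshape(-1, num_vecs) for r in result])
        return out.squeeze(-1) if is_vec else out

    def to_scipy(self) -> LinearOperator:
        """Wrap the operator so it consumes and produces NumPy arrays."""

        def matmat(X: ndarray) -> ndarray:
            X_torch = from_numpy(X).to(self.device, self.dtype)
            return (self @ X_torch).detach().cpu().numpy()

        return LinearOperator(self.shape, matvec=matmat, matmat=matmat)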

23 changes: 20 additions & 3 deletions curvlinops/examples/utils.py
@@ -1,11 +1,18 @@
"""Utility functions for the examples in the documentation."""

from numpy import allclose, isclose, ndarray
from typing import Union

from numpy import allclose as numpy_allclose
from numpy import isclose as numpy_isclose
from numpy import ndarray
from torch import Tensor
from torch import allclose as torch_allclose
from torch import isclose as torch_isclose


def report_nonclose(
array1: ndarray,
array2: ndarray,
array1: Union[ndarray, Tensor],
array2: Union[ndarray, Tensor],
rtol: float = 1e-5,
atol: float = 1e-8,
equal_nan: bool = False,
@@ -28,6 +35,16 @@ def report_nonclose(
f"Arrays shapes don't match: {array1.shape} vs. {array2.shape}."
)

if isinstance(array1, Tensor) and isinstance(array2, Tensor):
allclose, isclose = torch_allclose, torch_isclose
elif isinstance(array1, ndarray) and isinstance(array2, ndarray):
allclose, isclose = numpy_allclose, numpy_isclose
else:
raise ValueError(
"Both arrays should be either tensors or ndarrays."
f" Got {type(array1)} and {type(array2)}."
)

if allclose(array1, array2, rtol=rtol, atol=atol, equal_nan=equal_nan):
print("Compared arrays match.")
else:
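After this change, report_nonclose accepts either two NumPy arrays or two PyTorch tensors, and mixing the two raises a ValueError. A small usage sketch:

from numpy import ones
from torch import ones as torch_ones

from curvlinops.examples.utils import report_nonclose

report_nonclose(ones(3), ones(3))              # NumPy pair -> "Compared arrays match."
report_nonclose(torch_ones(3), torch_ones(3))  # PyTorch pair -> "Compared arrays match."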
2 changes: 1 addition & 1 deletion curvlinops/experimental/activation_hessian.py
@@ -184,7 +184,7 @@ def _preprocess(self, M: ndarray) -> List[Tensor]:
num_vectors = M.shape[1]
return [
from_numpy(M.T)
.to(self._device)
.to(self.device)
.reshape(num_vectors, *self._activation_shape)
]

69 changes: 52 additions & 17 deletions curvlinops/hessian.py
@@ -3,18 +3,19 @@
from __future__ import annotations

from collections.abc import MutableMapping
from typing import List, Tuple, Union
from typing import Callable, Iterable, List, Tuple, Union

from backpack.hessianfree.hvp import hessian_vector_product
from torch import Tensor, zeros_like
from torch.autograd import grad
from torch.nn import Parameter

from curvlinops._base import _LinearOperator
from curvlinops._base import CurvatureLinearOperator
from curvlinops.utils import split_list


class HessianLinearOperator(_LinearOperator):
r"""Hessian as SciPy linear operator.
class HessianLinearOperator(CurvatureLinearOperator):
r"""Linear operator for the Hessian of an empirical risk in PyTorch.
Consider the empirical risk
@@ -41,15 +42,46 @@ class HessianLinearOperator(_LinearOperator):

SUPPORTS_BLOCKS: bool = True

def __init__(
self,
model_func: Callable[[Tensor | MutableMapping], Tensor],
loss_func: Callable[[Tensor, Tensor], Tensor] | None,
params: List[Tensor | Parameter],
data: Iterable[Tuple[Tensor | MutableMapping, Tensor]],
progressbar: bool = False,
num_data: int | None = None,
in_blocks: List[int] | None = None,
out_blocks: List[int] | None = None,
batch_size_fn: Callable[[Tensor | MutableMapping], int] | None = None,
):
in_shape = out_shape = [tuple(p.shape) for p in params]
(dt,) = {p.dtype for p in params}
(dev,) = {p.device for p in params}
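        # the set-unpacking above doubles as a cheap uniformity check: if the
        # parameters mix dtypes or devices, the one-element unpacking raises
        # ``ValueError: too many values to unpack``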
super().__init__(
model_func,
loss_func,
params,
data,
in_shape,
out_shape,
dt,
dev,
progressbar=progressbar,
num_data=num_data,
in_blocks=in_blocks,
out_blocks=out_blocks,
batch_size_fn=batch_size_fn,
)

def _matmat_batch(
self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
) -> Tuple[Tensor, ...]:
self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
) -> List[Tensor]:
"""Apply the mini-batch Hessian to a matrix.
Args:
X: Input to the DNN.
y: Ground truth.
M_list: Matrix to be multiplied with in list format.
M: Matrix to be multiplied with in list format.
Tensors have the same shape as the trainable model parameters, with an
additional trailing axis for the matrix columns.
@@ -58,29 +90,32 @@ def _matmat_batch(
``M``, i.e. each tensor in the list has the shape of a parameter and a
trailing dimension of matrix columns.
"""
assert self._loss_func is not None
loss = self._loss_func(self._model_func(X), y)

# Re-cycle first backward pass from the HVP's double-backward
grad_params = grad(loss, self._params, create_graph=True)
grad_params = list(grad(loss, self._params, create_graph=True))

(num_vecs,) = {m.shape[-1] for m in M}
result = [zeros_like(m) for m in M]

num_vecs = M_list[0].shape[0]
result = [zeros_like(M) for M in M_list]
assert self._in_blocks == self._out_blocks

# per-block HMP
for M_block, p_block, g_block, res_block in zip(
split_list(M_list, self._block_sizes),
split_list(self._params, self._block_sizes),
split_list(grad_params, self._block_sizes),
split_list(result, self._block_sizes),
split_list(M, self._in_blocks),
split_list(self._params, self._in_blocks),
split_list(grad_params, self._in_blocks),
split_list(result, self._in_blocks),
):
for n in range(num_vecs):
col_n = hessian_vector_product(
loss, p_block, [M[n] for M in M_block], grad_params=g_block
loss, p_block, [m[..., n] for m in M_block], grad_params=g_block
)
for p, col in enumerate(col_n):
res_block[p][n].add_(col)
res_block[p][..., n].add_(col)

return tuple(result)
return result

def _adjoint(self) -> HessianLinearOperator:
"""Return the linear operator representing the adjoint.
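The per-block multiplication in _matmat_batch relies on curvlinops.utils.split_list to partition the parameter, gradient, and matrix lists by block sizes. A hypothetical re-implementation, shown here only to illustrate the contract (the real helper lives in curvlinops/utils.py):

from typing import List


def split_list(lst: List, sizes: List[int]) -> List[List]:
    """Partition ``lst`` into consecutive chunks of the given sizes."""
    assert sum(sizes) == len(lst), "sizes must sum to the list length"
    chunks, start = [], 0
    for size in sizes:
        chunks.append(lst[start: start + size])
        start += size
    return chunks


# e.g. split_list(["W1", "b1", "W2"], [2, 1]) -> [["W1", "b1"], ["W2"]]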
2 changes: 1 addition & 1 deletion docs/examples/basic_usage/example_eigenvalues.py
@@ -59,7 +59,7 @@
# We are ready to setup the linear operator. In this example, we will use the Hessian.

data = [(X1, y1), (X2, y2)]
H = HessianLinearOperator(model, loss_function, params, data)
H = HessianLinearOperator(model, loss_function, params, data).to_scipy()

# %%
#
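Since to_scipy() returns a SciPy LinearOperator, the remainder of the example can keep using SciPy's iterative eigensolvers unchanged. A sketch (k and which are chosen for illustration, not taken from the example):

from scipy.sparse.linalg import eigsh

evals, evecs = eigsh(H, k=3, which="LA")  # three largest eigenvalues of the Hessian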
(file name not rendered)
Expand Up @@ -55,7 +55,7 @@
# Setting up a linear operator for the Hessian is straightforward.

data = [(X, y)]
H = HessianLinearOperator(model, loss_function, params, data)
H = HessianLinearOperator(model, loss_function, params, data).to_scipy()

# %%
#
4 changes: 2 additions & 2 deletions docs/examples/basic_usage/example_submatrices.py
@@ -73,7 +73,7 @@
# its matrix representation through multiplication with the identity matrix,
# followed by comparison to the Hessian matrix computed via :mod:`functorch`.

H = HessianLinearOperator(model, loss_function, params, data)
H = HessianLinearOperator(model, loss_function, params, data).to_scipy()

report_nonclose(H_functorch, H @ numpy.eye(H.shape[1]))

@@ -122,7 +122,7 @@ def extract_block(
# We can build a linear operator for this sub-Hessian by only providing the
# first layer's weight as parameter:

H_param0 = HessianLinearOperator(model, loss_function, [params[i]], data)
H_param0 = HessianLinearOperator(model, loss_function, [params[i]], data).to_scipy()

# %%
#
4 changes: 3 additions & 1 deletion docs/examples/basic_usage/example_visual_tour.py
@@ -77,7 +77,9 @@
#
# First, create the linear operators:

Hessian_linop = HessianLinearOperator(model, loss_function, params, dataloader)
Hessian_linop = HessianLinearOperator(
model, loss_function, params, dataloader
).to_scipy()
GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader)
EF_linop = EFLinearOperator(model, loss_function, params, dataloader)

2 changes: 2 additions & 0 deletions test/cases.py
@@ -371,3 +371,5 @@ def data():
NON_DETERMINISTIC_CASES.append(case_with_device)

ADJOINT_CASES = [False, True]

SCIPY_FRONTEND_CASES = [False, True]
6 changes: 6 additions & 0 deletions test/conftest.py
@@ -8,6 +8,7 @@
CNN_CASES,
INV_CASES,
NON_DETERMINISTIC_CASES,
SCIPY_FRONTEND_CASES,
)
from test.kfac_cases import (
KFAC_EXACT_CASES,
@@ -115,6 +116,11 @@ def adjoint(request) -> bool:
return request.param


@fixture(params=SCIPY_FRONTEND_CASES)
def scipy_frontend(request) -> bool:
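    """Whether to run the test through the operator's SciPy frontend."""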
return request.param


@fixture(params=KFAC_EXACT_CASES)
def kfac_exact_case(
request,
70 changes: 45 additions & 25 deletions test/test_hessian.py
@@ -4,15 +4,16 @@

from numpy import random
from pytest import mark, raises
from torch import block_diag
from torch import block_diag, rand

from curvlinops import HessianLinearOperator
from curvlinops.examples.functorch import functorch_hessian
from curvlinops.examples.utils import report_nonclose
from curvlinops.hessian import HessianLinearOperator
from curvlinops.utils import split_list


def test_HessianLinearOperator_matvec(case, adjoint: bool):
def test_HessianLinearOperator_matvec(case, adjoint: bool, scipy_frontend: bool):
model_func, loss_func, params, data, batch_size_fn = case

# Test when X is dict-like but batch_size_fn = None (default)
@@ -23,35 +24,45 @@ def test_HessianLinearOperator_matvec(case, adjoint: bool):
op = HessianLinearOperator(
model_func, loss_func, params, data, batch_size_fn=batch_size_fn
)
op_functorch = (
functorch_hessian(model_func, loss_func, params, data, input_key="x")
.detach()
.cpu()
.numpy()
)
op_functorch = functorch_hessian(model_func, loss_func, params, data, input_key="x")

if scipy_frontend:
op = op.to_scipy()
op_functorch = op_functorch.detach().cpu().numpy()

if adjoint:
op, op_functorch = op.adjoint(), op_functorch.conj().T

x = random.rand(op.shape[1])
report_nonclose(op @ x, op_functorch @ x, atol=1e-7)
x = (
random.rand(op.shape[1])
if scipy_frontend
else rand(op.shape[1], dtype=op.dtype, device=op.device)
)
report_nonclose(op @ x, op_functorch @ x, atol=1e-7, rtol=1e-4)


def test_HessianLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):
def test_HessianLinearOperator_matmat(
case, adjoint: bool, scipy_frontend: bool, num_vecs: int = 3
):
model_func, loss_func, params, data, batch_size_fn = case

op = HessianLinearOperator(
model_func, loss_func, params, data, batch_size_fn=batch_size_fn
)
op_functorch = (
functorch_hessian(model_func, loss_func, params, data, input_key="x")
.detach()
.cpu()
.numpy()
)
op_functorch = functorch_hessian(model_func, loss_func, params, data, input_key="x")

if scipy_frontend:
op = op.to_scipy()
op_functorch = op_functorch.detach().cpu().numpy()

if adjoint:
op, op_functorch = op.adjoint(), op_functorch.conj().T

X = random.rand(op.shape[1], num_vecs)
X = (
random.rand(op.shape[1], num_vecs)
if scipy_frontend
else rand(op.shape[1], num_vecs, dtype=op.dtype, device=op.device)
)
report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=5e-4)


@@ -65,14 +76,16 @@ def test_HessianLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3):

@mark.parametrize("blocking", BLOCKING_FNS.keys(), ids=BLOCKING_FNS.keys())
def test_blocked_HessianLinearOperator_matmat(
case, adjoint: bool, blocking: str, num_vecs: int = 2
case, adjoint: bool, blocking: str, scipy_frontend: bool, num_vecs: int = 2
):
"""Test matrix-matrix multiplication with the block-diagonal Hessian.
Args:
case: Tuple of model, loss function, parameters, data, and batch size getter.
adjoint: Whether to test the adjoint operator.
blocking: Blocking scheme.
scipy_frontend: Whether to use the SciPy frontend and feed NumPy arrays.
If ``False``, PyTorch tensors are used.
num_vecs: Number of vectors to multiply with. Default is ``2``.
"""
model_func, loss_func, params, data, batch_size_fn = case
Expand All @@ -84,20 +97,27 @@ def test_blocked_HessianLinearOperator_matmat(
params,
data,
batch_size_fn=batch_size_fn,
block_sizes=block_sizes,
in_blocks=block_sizes,
out_blocks=block_sizes,
)

# compute the blocks with functorch and build the block diagonal matrix
op_functorch = [
functorch_hessian(
model_func, loss_func, params_block, data, input_key="x"
).detach()
functorch_hessian(model_func, loss_func, params_block, data, input_key="x")
for params_block in split_list(params, block_sizes)
]
op_functorch = block_diag(*op_functorch).cpu().numpy()
op_functorch = block_diag(*op_functorch)

if scipy_frontend:
op = op.to_scipy()
op_functorch = op_functorch.detach().cpu().numpy()

if adjoint:
op, op_functorch = op.adjoint(), op_functorch.conj().T

X = random.rand(op.shape[1], num_vecs)
X = (
random.rand(op.shape[1], num_vecs)
if scipy_frontend
else rand(op.shape[1], num_vecs, dtype=op.dtype, device=op.device)
)
report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=5e-4)
4 changes: 3 additions & 1 deletion test/test_submatrix_on_curvatures.py
Expand Up @@ -64,6 +64,8 @@ def setup_submatrix_linear_operator(case, operator_case, submatrix_case):
col_idxs = submatrix_case["col_idx_fn"](dim)

A = operator_case(model_func, loss_func, params, data, batch_size_fn=batch_size_fn)
if isinstance(A, HessianLinearOperator):
A = A.to_scipy()
A_sub = SubmatrixLinearOperator(A, row_idxs, col_idxs)

A_functorch = CURVATURE_IN_FUNCTORCH[operator_case](
@@ -86,7 +88,7 @@ def test_SubmatrixLinearOperator_on_curvatures_matvec(
A_sub_x = A_sub @ x

assert A_sub_x.shape == (len(row_idxs),)
report_nonclose(A_sub_x, A_sub_functorch @ x, atol=2e-7)
report_nonclose(A_sub_x, A_sub_functorch @ x, atol=2e-7, rtol=1e-4)


@mark.parametrize("operator_case", CURVATURE_CASES)
