Skip to content

Commit

Permalink
[ADD] Change Hessian to purely PyTorch (#145)
Browse files Browse the repository at this point in the history
* [ADD] Minimal linear operator interface for PyTorch

* [FIX] Linters

* [ADD] Replicate current base class but inherit from `PyTorchLinearOperator`

* [DEL] Unused import

* [ADD] Implement Hessian as `CurvatureLinearOperator`

* [REF] Combine full and block-diagonal matrix multiply tests

* [FIX] flake8

* [FIX] RTD examples

* [FIX] Decrease `tol` for power iteration
  • Loading branch information
f-dangel authored Nov 4, 2024
1 parent 2ae34be commit 5e118e9
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 111 deletions.
38 changes: 19 additions & 19 deletions curvlinops/hessian.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""Contains LinearOperator implementation of the Hessian."""
"""Contains a linear operator implementation of the Hessian."""

from __future__ import annotations

from collections.abc import MutableMapping
from typing import List, Tuple, Union
from typing import List, Union

from backpack.hessianfree.hvp import hessian_vector_product
from torch import Tensor, zeros_like
from torch.autograd import grad

from curvlinops._base import _LinearOperator
from curvlinops._torch_base import CurvatureLinearOperator
from curvlinops.utils import split_list


class HessianLinearOperator(_LinearOperator):
r"""Hessian as SciPy linear operator.
class HessianLinearOperator(CurvatureLinearOperator):
r"""Linear operator for the Hessian of an empirical risk.
Consider the empirical risk
Expand Down Expand Up @@ -42,45 +42,45 @@ class HessianLinearOperator(_LinearOperator):
SUPPORTS_BLOCKS: bool = True

def _matmat_batch(
self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
) -> Tuple[Tensor, ...]:
self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
) -> List[Tensor]:
"""Apply the mini-batch Hessian to a matrix.
Args:
X: Input to the DNN.
y: Ground truth.
M_list: Matrix to be multiplied with in list format.
M: Matrix to be multiplied with in tensor list format.
Tensors have same shape as trainable model parameters, and an
additional leading axis for the matrix columns.
additional trailing axis for the matrix columns.
Returns:
Result of Hessian multiplication in list format. Has the same shape as
``M_list``, i.e. each tensor in the list has the shape of a parameter and a
leading dimension of matrix columns.
``M``, i.e. each tensor in the list has the shape of a parameter and a
trailing dimension of matrix columns.
"""
loss = self._loss_func(self._model_func(X), y)

# Re-cycle first backward pass from the HVP's double-backward
grad_params = grad(loss, self._params, create_graph=True)

num_vecs = M_list[0].shape[0]
result = [zeros_like(M) for M in M_list]
(num_vecs,) = {m.shape[-1] for m in M}
AM = [zeros_like(m) for m in M]

# per-block HMP
for M_block, p_block, g_block, res_block in zip(
split_list(M_list, self._block_sizes),
for M_block, p_block, g_block, AM_block in zip(
split_list(M, self._block_sizes),
split_list(self._params, self._block_sizes),
split_list(grad_params, self._block_sizes),
split_list(result, self._block_sizes),
split_list(AM, self._block_sizes),
):
for n in range(num_vecs):
col_n = hessian_vector_product(
loss, p_block, [M[n] for M in M_block], grad_params=g_block
loss, p_block, [M[..., n] for M in M_block], grad_params=g_block
)
for p, col in enumerate(col_n):
res_block[p][n].add_(col)
AM_block[p][..., n].add_(col)

return tuple(result)
return AM

def _adjoint(self) -> HessianLinearOperator:
"""Return the linear operator representing the adjoint.
Expand Down
16 changes: 11 additions & 5 deletions curvlinops/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""General utility functions."""

from typing import List
from typing import List, Tuple, Union

from numpy import cumsum
from torch import Tensor


def split_list(x: List, sizes: List[int]) -> List[List]:
def split_list(x: Union[List, Tuple], sizes: List[int]) -> List[List]:
"""Split a list into multiple lists of specified size.
Args:
x: List to be split.
x: List or tuple to be split.
sizes: Sizes of the resulting lists.
Returns:
Expand All @@ -25,7 +25,7 @@ def split_list(x: List, sizes: List[int]) -> List[List]:
+ f" of {sum(sizes)} entries."
)
boundaries = cumsum([0] + sizes)
return [x[boundaries[i] : boundaries[i + 1]] for i in range(len(sizes))]
return [list(x[boundaries[i] : boundaries[i + 1]]) for i in range(len(sizes))]


def allclose_report(
Expand All @@ -50,6 +50,12 @@ def allclose_report(
tensor1[nonclose_idx].flatten(),
tensor2[nonclose_idx].flatten(),
):
print(f"at index {idx}: {t1:.5e}{t2:.5e}, ratio: {t1 / t2:.5e}")
print(f"at index {idx.tolist()}: {t1:.5e}{t2:.5e}, ratio: {t1 / t2:.5e}")

# print largest and smallest absolute entries
amax1, amax2 = tensor1.abs().max().item(), tensor2.abs().max().item()
print(f"Abs max: {amax1:.5e} vs. {amax2:.5e}.")
amin1, amin2 = tensor1.abs().min().item(), tensor2.abs().min().item()
print(f"Abs min: {amin1:.5e} vs. {amin2:.5e}.")

return close
12 changes: 6 additions & 6 deletions docs/examples/basic_usage/example_eigenvalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
# We are ready to setup the linear operator. In this example, we will use the Hessian.

data = [(X1, y1), (X2, y2)]
H = HessianLinearOperator(model, loss_function, params, data)
H = HessianLinearOperator(model, loss_function, params, data).to_scipy()

# %%
#
Expand Down Expand Up @@ -204,7 +204,9 @@ def orthonormalize(v: numpy.ndarray, basis: List[numpy.ndarray]) -> numpy.ndarra
# the linear operator's progress bar, which allows us to count the number of
# matrix-vector products invoked by both eigen-solvers:

H = HessianLinearOperator(model, loss_function, params, data, progressbar=True)
H = HessianLinearOperator(
model, loss_function, params, data, progressbar=True
).to_scipy()

# determine number of matrix-vector products used by `eigsh`
with StringIO() as buf, redirect_stderr(buf):
Expand All @@ -216,7 +218,7 @@ def orthonormalize(v: numpy.ndarray, basis: List[numpy.ndarray]) -> numpy.ndarra

# determine number of matrix-vector products used by power iteration
with StringIO() as buf, redirect_stderr(buf):
top_k_evals_power, _ = power_method(H, k=k)
top_k_evals_power, _ = power_method(H, k=k, tol=1e-4)
# The tqdm progressbar will print "matmat" for each batch in a matrix-vector
# product. Therefore, we need to divide by the number of batches
queries_power = buf.getvalue().count("matmat") // len(data)
Expand All @@ -228,9 +230,7 @@ def orthonormalize(v: numpy.ndarray, basis: List[numpy.ndarray]) -> numpy.ndarra
#
# Sadly, the power iteration also does not offer computational benefits, consuming
# more matrix-vector products than :code:`eigsh`. While it is elegant and simple,
# it cannot compete with :code:`eigsh`, at least in the comparison provided here
# (note that we used a relative small tolerance for the power iteration, and it will
# likely deteriorate further if we decrease the tolerance).
# it cannot compete with :code:`eigsh`, at least in the comparison provided here.
#
# Therefore, we recommend using :code:`eigsh` for computing eigenvalues. This method
# becomes accessible because :code:`curvlinops` interfaces with SciPy's linear
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
# Setting up a linear operator for the Hessian is straightforward.

data = [(X, y)]
H = HessianLinearOperator(model, loss_function, params, data)
H = HessianLinearOperator(model, loss_function, params, data).to_scipy()

# %%
#
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/basic_usage/example_submatrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
# its matrix representation through multiplication with the identity matrix,
# followed by comparison to the Hessian matrix computed via :mod:`functorch`.

H = HessianLinearOperator(model, loss_function, params, data)
H = HessianLinearOperator(model, loss_function, params, data).to_scipy()

report_nonclose(H_functorch, H @ numpy.eye(H.shape[1]))

Expand Down Expand Up @@ -122,7 +122,7 @@ def extract_block(
# We can build a linear operator for this sub-Hessian by only providing the
# first layer's weight as parameter:

H_param0 = HessianLinearOperator(model, loss_function, [params[i]], data)
H_param0 = HessianLinearOperator(model, loss_function, [params[i]], data).to_scipy()

# %%
#
Expand Down
6 changes: 4 additions & 2 deletions docs/examples/basic_usage/example_visual_tour.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@
#
# First, create the linear operators:

Hessian_linop = HessianLinearOperator(model, loss_function, params, dataloader)
Hessian_linop = HessianLinearOperator(
model, loss_function, params, dataloader
).to_scipy()
GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader)
EF_linop = EFLinearOperator(model, loss_function, params, dataloader)
Hessian_blocked_linop = HessianLinearOperator(
Expand All @@ -94,7 +96,7 @@
params,
dataloader,
block_sizes=[s for s in num_tensors_layer if s != 0],
)
).to_scipy()
F_linop = FisherMCLinearOperator(model, loss_function, params, dataloader)
KFAC_linop = KFACLinearOperator(
model, loss_function, params, dataloader, separate_weight_and_bias=False
Expand Down
13 changes: 13 additions & 0 deletions test/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,16 @@ def data():
NON_DETERMINISTIC_CASES.append(case_with_device)

ADJOINT_CASES = [False, True]
ADJOINT_IDS = ["", "adjoint"]

IS_VECS = [False, True]
IS_VEC_IDS = ["matvec", "matmat"]


BLOCK_SIZES_FNS = {
"full": lambda params: None,
"per-parameter-blocks": lambda params: [1 for _ in range(len(params))],
"two-blocks": lambda params: (
[1] if len(params) == 1 else [len(params) // 2, len(params) - len(params) // 2]
),
}
35 changes: 33 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
from collections.abc import MutableMapping
from test.cases import (
ADJOINT_CASES,
ADJOINT_IDS,
BLOCK_SIZES_FNS,
CASES,
CNN_CASES,
INV_CASES,
IS_VEC_IDS,
IS_VECS,
NON_DETERMINISTIC_CASES,
)
from test.kfac_cases import (
Expand All @@ -21,7 +25,7 @@
from numpy import random
from pytest import fixture
from torch import Tensor, manual_seed
from torch.nn import Module, MSELoss
from torch.nn import Module, MSELoss, Parameter


def initialize_case(
Expand Down Expand Up @@ -110,11 +114,38 @@ def non_deterministic_case(
yield initialize_case(case)


@fixture(params=ADJOINT_CASES)
@fixture(params=ADJOINT_CASES, ids=ADJOINT_IDS)
def adjoint(request) -> bool:
return request.param


@fixture(params=IS_VECS, ids=IS_VEC_IDS)
def is_vec(request) -> bool:
"""Whether to test matrix-vector or matrix-matrix multiplication.
Args:
request: Pytest request object.
Returns:
``True`` if the test is for matrix-vector multiplication, ``False`` otherwise.
"""
return request.param


@fixture(params=BLOCK_SIZES_FNS.values(), ids=BLOCK_SIZES_FNS.keys())
def block_sizes_fn(request) -> Callable[[List[Parameter]], Optional[List[int]]]:
"""Function to generate the ``block_sizes`` argument for a linear operator.
Args:
request: Pytest request object.
Returns:
A function that generates the block sizes for a linear operator from the
parameters.
"""
return request.param


@fixture(params=KFAC_EXACT_CASES)
def kfac_exact_case(
request,
Expand Down
Loading

0 comments on commit 5e118e9

Please sign in to comment.