Commit

Merge branch 'main' into hessian-linop
f-dangel committed Nov 4, 2024
2 parents 58ab2c2 + 2ae34be commit d63a896
Showing 16 changed files with 501 additions and 143 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
@@ -29,11 +29,11 @@ jobs:
python -m pip install --upgrade pip
make install-test
- name: Run test
if: contains('refs/heads/master refs/heads/development', github.ref)
if: contains('refs/heads/main', github.ref)
run: |
make test
- name: Run test-light
if: contains('refs/heads/master refs/heads/development', github.ref) != 1
if: contains('refs/heads/main', github.ref) != 1
run: |
make test-light
33 changes: 32 additions & 1 deletion changelog.md
@@ -6,16 +6,46 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added/New

### Fixed/Removed

### Internal

## [2.0.1] - 2024-10-25

Minor bug fixes and documentation polishing.

### Added/New

- Comparison of `eigsh` with power iteration in [eigenvalue
example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_eigenvalues.html#sphx-glr-basic-usage-example-eigenvalues-py)
([PR](https://github.com/f-dangel/curvlinops/pull/140))

### Fixed/Removed

- Deprecate Python 3.8 as it will reach its end of life in October 2024
([PR](https://github.com/f-dangel/curvlinops/pull/128))

- Improve `intersphinx` mapping to `curvlinops` objects
([issue](https://github.com/f-dangel/curvlinops/issues/138),
[PR](https://github.com/f-dangel/curvlinops/pull/141))

### Internal

- Update Github action versions and cache `pip`
([PR](https://github.com/f-dangel/curvlinops/pull/129))

- Re-activate Monte-Carlo tests, refactor, and reduce their run time
([PR](https://github.com/f-dangel/curvlinops/pull/131))

- Add more matrices in visual tour code example and prettify plots
([PR](https://github.com/f-dangel/curvlinops/pull/134))

- Prettify visualizations in [spectral density
example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_verification_spectral_density.html)
([PR](https://github.com/f-dangel/curvlinops/pull/139))

## [2.0.0] - 2024-08-15

This major release is almost fully backward compatible with the `1.x.y` release
@@ -295,7 +325,8 @@ Adds various new features:

Initial release

[Unreleased]: https://github.com/f-dangel/curvlinops/compare/2.0.0...HEAD
[Unreleased]: https://github.com/f-dangel/curvlinops/compare/2.0.1...HEAD
[2.0.1]: https://github.com/f-dangel/curvlinops/releases/tag/2.0.1
[2.0.0]: https://github.com/f-dangel/curvlinops/releases/tag/2.0.0
[1.2.0]: https://github.com/f-dangel/curvlinops/releases/tag/1.2.0
[1.1.0]: https://github.com/f-dangel/curvlinops/releases/tag/1.1.0
15 changes: 10 additions & 5 deletions curvlinops/_base.py
@@ -8,9 +8,9 @@
from numpy import allclose, argwhere, float32, isclose, logical_not, ndarray
from numpy.random import rand
from scipy.sparse.linalg import LinearOperator
from torch import Tensor, cat
from torch import Tensor, as_tensor, bfloat16, cat
from torch import device as torch_device
from torch import from_numpy, tensor, zeros_like
from torch import tensor, zeros_like
from torch.autograd import grad
from torch.nn import Module, Parameter
from tqdm import tqdm
@@ -117,6 +117,7 @@ def __init__(
self._loss_func = loss_func
self._data = data
self._device = self._infer_device(self._params)
(self._torch_dtype,) = {p.dtype for p in self._params}
self._progressbar = progressbar
self._batch_size_fn = (
(lambda X: X.shape[0]) if batch_size_fn is None else batch_size_fn
@@ -302,7 +303,7 @@ def _preprocess(self, M: ndarray) -> List[Tensor]:
M = M.astype(self.dtype)
num_vectors = M.shape[1]

result = from_numpy(M).to(self._device)
result = as_tensor(M, dtype=self._torch_dtype, device=self._device)
# split parameter blocks
dims = [p.numel() for p in self._params]
result = result.split(dims)
@@ -324,7 +325,11 @@ def _postprocess(self, M_list: List[Tensor]) -> ndarray:
concatenated dimensions over all list entries.
"""
result = [rearrange(M, "k ... -> (...) k") for M in M_list]
return cat(result, dim=0).cpu().numpy()
result = cat(result)
# calling .numpy() on a BF-16 tensor is not supported, see
# (https://github.com/pytorch/pytorch/issues/90574)
result = result.float() if result.dtype == bfloat16 else result
return result.cpu().numpy().astype(self.dtype)

def _loop_over_data(
self, desc: Optional[str] = None, add_device_to_desc: bool = True
@@ -340,7 +345,7 @@ def _loop_over_data(
Yields:
Mini-batches ``(X, y)``.
"""
data_iter = iter(self._data)
data_iter = self._data

if self._progressbar:
desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
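Note on the ``_postprocess`` change above: tensors in ``bfloat16`` cannot be converted to NumPy directly (see https://github.com/pytorch/pytorch/issues/90574), hence the upcast to ``float32`` before calling ``.numpy()``. A minimal standalone sketch of the limitation and the workaround; the tensor below is purely illustrative and not part of this commit:

import torch

t = torch.ones(3, dtype=torch.bfloat16)
try:
    t.numpy()  # direct conversion of bfloat16 tensors is unsupported
except (TypeError, RuntimeError) as err:
    print(f"Direct conversion fails: {err}")

# Workaround mirroring _postprocess: upcast to float32 before converting.
print(t.float().numpy())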
10 changes: 5 additions & 5 deletions curvlinops/_torch_base.py
@@ -6,7 +6,7 @@

import numpy
from scipy.sparse.linalg import LinearOperator
from torch import Size, Tensor, cat, device, dtype, from_numpy, rand, tensor, zeros_like
from torch import Size, Tensor, as_tensor, cat, device, dtype, rand, tensor, zeros_like
from torch.autograd import grad
from torch.nn import Module, Parameter
from tqdm import tqdm
@@ -24,7 +24,7 @@ class PyTorchLinearOperator:
One main difference is that the linear operators cannot only multiply
vectors/matrices specified as single PyTorch tensors, but also
vectors/matrices specified in tensor list format. This is common in
PyTorch, where the space a linear operator acts on is a tensor product
PyTorch, where the space a linear operator acts on is a tensor product.
Functions that need to be implemented are ``_matmat`` and ``_adjoint``.
@@ -347,7 +347,7 @@ def f_scipy(X: numpy.ndarray) -> numpy.ndarray:
The output matrix in NumPy format.
"""
X_dtype = X.dtype
X_torch = from_numpy(X).to(device, dtype)
X_torch = as_tensor(X, dtype=dtype, device=device)
AX_torch = f(X_torch)
return AX_torch.detach().cpu().numpy().astype(X_dtype)

@@ -439,7 +439,7 @@ def __init__(
)

in_shape = [tuple(p.shape) for p in params] if in_shape is None else in_shape
out_shape = [tuple(p.shape) for p in params] if in_shape is None else in_shape
out_shape = [tuple(p.shape) for p in params] if out_shape is None else out_shape
super().__init__(in_shape, out_shape)

self._params = params
@@ -538,7 +538,7 @@ def _loop_over_data(
Yields:
Mini-batches ``(X, y)``.
"""
data_iter = iter(self._data)
data_iter = self._data

if self._progressbar:
desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
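Note on the ``f_scipy`` change above: ``as_tensor(X, dtype=dtype, device=device)`` converts the incoming NumPy matrix in one step instead of ``from_numpy(X).to(device, dtype)``. Below is a hedged, self-contained sketch of this NumPy-to-PyTorch round trip inside a SciPy ``LinearOperator``; the toy matrix and names are illustrative and not part of this commit:

import numpy
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

# Illustrative symmetric matrix; curvlinops wraps curvature-matrix products instead.
A = torch.randn(4, 4, dtype=torch.float64)
A = A + A.T

def matmat(X: numpy.ndarray) -> numpy.ndarray:
    # Cast the input to the operator's dtype/device, multiply in PyTorch,
    # and convert the result back to the incoming NumPy dtype.
    X_torch = torch.as_tensor(X, dtype=A.dtype, device=A.device)
    return (A @ X_torch).cpu().numpy().astype(X.dtype)

A_scipy = LinearOperator(shape=tuple(A.shape), matvec=matmat, matmat=matmat, dtype=numpy.float64)
print(eigsh(A_scipy, k=1)[0])  # leading eigenvalue of the toy matrix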
11 changes: 6 additions & 5 deletions curvlinops/examples/functorch.py
@@ -151,14 +151,14 @@ def linearized_loss(
return blocks_to_matrix(ggn_fn(X, y, anchor_dict, params_dict))


def functorch_gradient(
def functorch_gradient_and_loss(
model_func: Module,
loss_func: Module,
params: List[Tensor],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
input_key: Optional[str] = None,
) -> Tuple[Tensor]:
"""Compute the gradient with functorch.
) -> Tuple[List[Tensor], Tensor]:
"""Compute the gradient and loss with functorch.
Args:
model_func: A function that maps the mini-batch input X to predictions.
@@ -171,7 +171,7 @@ def functorch_gradient(
input_key: Key to obtain the input tensor when ``X`` is a dict-like object.
Returns:
Gradient in same format as the parameters.
Loss, and gradient in same format as the parameters.
"""
(dev,) = {p.device for p in params}
X, y = _concatenate_batches(data, input_key, device=dev)
Expand All @@ -190,8 +190,9 @@ def loss(

params_argnum = 2
grad_fn = grad(loss, argnums=params_argnum)
loss_value = loss(X, y, params_dict)

return tuple(grad_fn(X, y, params_dict).values())
return list(grad_fn(X, y, params_dict).values()), loss_value


def functorch_empirical_fisher(
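Note on the renamed helper above: ``functorch_gradient_and_loss`` now returns the gradient (a list of tensors in parameter format, per the return statement) together with the loss value. A hypothetical call, assuming the helper is importable from ``curvlinops.examples.functorch`` and using an illustrative toy model and batch:

import torch
from curvlinops.examples.functorch import functorch_gradient_and_loss

model = torch.nn.Linear(4, 2)
loss_func = torch.nn.MSELoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch.randn(8, 4), torch.randn(8, 2))]

gradient, loss = functorch_gradient_and_loss(model, loss_func, params, data)
print(loss, [g.shape for g in gradient])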
4 changes: 4 additions & 0 deletions curvlinops/examples/utils.py
@@ -31,9 +31,13 @@ def report_nonclose(
if allclose(array1, array2, rtol=rtol, atol=atol, equal_nan=equal_nan):
print("Compared arrays match.")
else:
nonclose_entries = 0
for a1, a2 in zip(array1.flatten(), array2.flatten()):
if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan):
print(f"{a1}{a2} (ratio {a1 / a2:.5f})")
nonclose_entries += 1
print(f"Max: {array1.max():.5f}, {array2.max():.5f}")
print(f"Min: {array1.min():.5f}, {array2.min():.5f}")
print(f"Nonclose entries: {nonclose_entries} / {array1.size}")
print(f"rtol = {rtol}, atol= {atol}")
raise ValueError("Compared arrays don't match.")
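Note on the ``report_nonclose`` change above: on mismatch, the helper now also reports the extrema of both arrays, the number of non-close entries, and the tolerances before raising. A small hypothetical usage sketch; the arrays are chosen purely for illustration, and the import path is assumed to match the code examples:

import numpy
from curvlinops.examples.utils import report_nonclose

a = numpy.array([1.0, 2.0, 3.0])
b = numpy.array([1.0, 2.0, 3.1])
try:
    report_nonclose(a, b, rtol=1e-5, atol=1e-8)  # prints diagnostics for mismatches
except ValueError:
    print("Arrays do not match.")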
5 changes: 2 additions & 3 deletions curvlinops/utils.py
@@ -36,15 +36,14 @@ def allclose_report(
Args:
tensor1: First tensor for comparison.
tensor2: Second tensor for comparison.
rtol: Relative tolerance. Default: ``1e-5``.
atol: Absolute tolerance. Default: ``1e-8``.
rtol: Relative tolerance. Default is ``1e-5``.
atol: Absolute tolerance. Default is ``1e-8``.
Returns:
``True`` if the tensors are close, ``False`` otherwise.
"""
close = tensor1.allclose(tensor2, rtol=rtol, atol=atol)
if not close:
# print non-close values
nonclose_idx = tensor1.isclose(tensor2, rtol=rtol, atol=atol).logical_not_()
for idx, t1, t2 in zip(
nonclose_idx.argwhere(),
133 changes: 132 additions & 1 deletion docs/examples/basic_usage/example_eigenvalues.py
@@ -8,6 +8,10 @@
As always, imports go first.
"""

from contextlib import redirect_stderr
from io import StringIO
from typing import List, Tuple

import numpy
import scipy
import torch
@@ -70,7 +74,7 @@
# :math:`k=3` eigenvalues.

k = 3
which = "LA" # largest algebraic
which = "LM" # largest magnitude
top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which)

print(f"Leading {k} Hessian eigenvalues: {top_k_evals}")
@@ -104,3 +108,130 @@
# :func:`scipy.sparse.linalg.eigsh` can also compute other subsets of
# eigenvalues, and also their associated eigenvectors. Check out its
# documentation for more!


# %%
#
# Power iteration versus ``eigsh``
# --------------------------------
#
# Here, we compare the query efficiency of :func:`scipy.sparse.linalg.eigsh` with the
# `power iteration <https://en.wikipedia.org/wiki/Power_iteration>`_ method, a simple
# method to compute the leading eigenvalues (in terms of magnitude). We re-use the im-
# plementation from the `PyHessian library <https://github.com/amirgholami/PyHessian>`_
# and adapt it to work with SciPy arrays rather than PyTorch tensors:


def power_method(
A: scipy.sparse.linalg.LinearOperator,
max_iterations: int = 100,
tol: float = 1e-3,
k: int = 1,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""Compute the top-k eigenpairs of a linear operator using power iteration.
Code modified from PyHessian, see
https://github.com/amirgholami/PyHessian/blob/72e5f0a0d06142387fccdab2226b4c6bae088202/pyhessian/hessian.py#L111-L156
Args:
A: Linear operator of dimension ``D`` whose top eigenpairs will be computed.
max_iterations: Maximum number of iterations. Defaults to ``100``.
tol: Relative tolerance between two consecutive iterations that has to be
reached for convergence. Defaults to ``1e-3``.
k: Number of eigenpairs to compute. Defaults to ``1``.
Returns:
The eigenvalues as array of shape ``[k]`` in ascending order, and their
corresponding eigenvectors as array of shape ``[D, k]``.
"""
eigenvalues = []
eigenvectors = []

def normalize(v: numpy.ndarray) -> numpy.ndarray:
return v / numpy.linalg.norm(v)

def orthonormalize(v: numpy.ndarray, basis: List[numpy.ndarray]) -> numpy.ndarray:
for basis_vector in basis:
v -= numpy.dot(v, basis_vector) * basis_vector
return normalize(v)

computed_dim = 0
while computed_dim < k:
eigenvalue = None
v = normalize(numpy.random.randn(A.shape[0]))

for _ in range(max_iterations):
v = orthonormalize(v, eigenvectors)
Av = A @ v

tmp_eigenvalue = v.dot(Av)
v = normalize(Av)

if eigenvalue is None:
eigenvalue = tmp_eigenvalue
else:
if abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6) < tol:
break
else:
eigenvalue = tmp_eigenvalue

eigenvalues.append(eigenvalue)
eigenvectors.append(v)
computed_dim += 1

# sort in ascending order and convert into arrays
eigenvalues = numpy.array(eigenvalues[::-1])
eigenvectors = numpy.array(eigenvectors[::-1])

return eigenvalues, eigenvectors


# %%
#
# Let's compute the top-3 eigenvalues via power iteration and verify they roughly match.
# Note that we are using a smaller :code:`tol` value than the PyHessian default value
# here to get better convergence, and we have to use relatively large tolerances for the
# comparison (which we didn't do when comparing :code:`eigsh` with :code:`eigh`).

top_k_evals_power, _ = power_method(H, tol=1e-4, k=k)
print(f"Comparing leading {k} Hessian eigenvalues (eigsh vs. power).")
report_nonclose(top_k_evals_functorch, top_k_evals_power, rtol=2e-2, atol=1e-6)

# %%
#
# This indicates that the power method achieves poorer accuracy than :code:`eigsh`. But
# does it therefore require fewer matrix-vector products? To answer this, let's turn on
# the linear operator's progress bar, which allows us to count the number of
# matrix-vector products invoked by both eigen-solvers:

H = HessianLinearOperator(model, loss_function, params, data, progressbar=True)

# determine number of matrix-vector products used by `eigsh`
with StringIO() as buf, redirect_stderr(buf):
top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which)
# The tqdm progressbar will print "matmat" for each batch in a matrix-vector
# product. Therefore, we need to divide by the number of batches
queries_eigsh = buf.getvalue().count("matmat") // len(data)
print(f"eigsh used {queries_eigsh} matrix-vector products.")

# determine number of matrix-vector products used by power iteration
with StringIO() as buf, redirect_stderr(buf):
top_k_evals_power, _ = power_method(H, k=k)
# The tqdm progressbar will print "matmat" for each batch in a matrix-vector
# product. Therefore, we need to divide by the number of batches
queries_power = buf.getvalue().count("matmat") // len(data)
print(f"Power iteration used {queries_power} matrix-vector products.")

assert queries_power > queries_eigsh

# %%
#
# Sadly, the power iteration also does not offer computational benefits, consuming
# more matrix-vector products than :code:`eigsh`. While it is elegant and simple,
# it cannot compete with :code:`eigsh`, at least in the comparison provided here
(note that we used a relatively small tolerance for the power iteration, and it will
# likely deteriorate further if we decrease the tolerance).
#
# Therefore, we recommend using :code:`eigsh` for computing eigenvalues. This method
# becomes accessible because :code:`curvlinops` interfaces with SciPy's linear
# operators.
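Note: since the example above recommends ``eigsh``, here is a condensed, self-contained recap of that workflow, mirroring the calls used in the example; the toy model and data below are illustrative only and not part of this commit:

import torch
from scipy.sparse.linalg import eigsh
from curvlinops import HessianLinearOperator

model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
loss_function = torch.nn.CrossEntropyLoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch.randn(8, 10), torch.randint(0, 2, (8,)))]

H = HessianLinearOperator(model, loss_function, params, data)
top_k_evals, _ = eigsh(H, k=3, which="LM")  # leading eigenvalues by magnitude
print(f"Leading 3 Hessian eigenvalues: {top_k_evals}")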