
Commit

Merge branch 'main' into ggn-linop
f-dangel committed Nov 4, 2024
2 parents 8b6fed7 + 5e118e9 commit 6ef7099
Showing 17 changed files with 502 additions and 156 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
@@ -29,11 +29,11 @@ jobs:
python -m pip install --upgrade pip
make install-test
- name: Run test
if: contains('refs/heads/master refs/heads/development', github.ref)
if: contains('refs/heads/main', github.ref)
run: |
make test
- name: Run test-light
if: contains('refs/heads/master refs/heads/development', github.ref) != 1
if: contains('refs/heads/main', github.ref) != 1
run: |
make test-light
33 changes: 32 additions & 1 deletion changelog.md
@@ -6,16 +6,46 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added/New

### Fixed/Removed

### Internal

## [2.0.1] - 2024-10-25

Minor bug fixes and documentation polishing.

### Added/New

- Comparison of `eigsh` with power iteration in [eigenvalue
example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_eigenvalues.html#sphx-glr-basic-usage-example-eigenvalues-py)
([PR](https://github.com/f-dangel/curvlinops/pull/140))

### Fixed/Removed

- Deprecate Python 3.8 as it will reach its end of life in October 2024
([PR](https://github.com/f-dangel/curvlinops/pull/128))

- Improve `intersphinx` mapping to `curvlinops` objects
([issue](https://github.com/f-dangel/curvlinops/issues/138),
[PR](https://github.com/f-dangel/curvlinops/pull/141))

### Internal

- Update Github action versions and cache `pip`
([PR](https://github.com/f-dangel/curvlinops/pull/129))

- Re-activate Monte-Carlo tests, refactor, and reduce their run time
([PR](https://github.com/f-dangel/curvlinops/pull/131))

- Add more matrices in visual tour code example and prettify plots
([PR](https://github.com/f-dangel/curvlinops/pull/134))

- Prettify visualizations in [spectral density
example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_verification_spectral_density.html)
([PR](https://github.com/f-dangel/curvlinops/pull/139))

## [2.0.0] - 2024-08-15

This major release is almost fully backward compatible with the `1.x.y` release
@@ -295,7 +325,8 @@ Adds various new features:

Initial release

[Unreleased]: https://github.com/f-dangel/curvlinops/compare/2.0.0...HEAD
[Unreleased]: https://github.com/f-dangel/curvlinops/compare/2.0.1...HEAD
[2.0.1]: https://github.com/f-dangel/curvlinops/releases/tag/2.0.1
[2.0.0]: https://github.com/f-dangel/curvlinops/releases/tag/2.0.0
[1.2.0]: https://github.com/f-dangel/curvlinops/releases/tag/1.2.0
[1.1.0]: https://github.com/f-dangel/curvlinops/releases/tag/1.1.0
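As an aside to the `eigsh` comparison entry in the 2.0.1 changelog above, here is a minimal, self-contained sketch of that comparison idea. It uses a random positive semi-definite NumPy matrix instead of a curvature linear operator, so it only illustrates the two methods, not the linked curvlinops example.

import numpy
from scipy.sparse.linalg import eigsh

rng = numpy.random.default_rng(0)
A = rng.standard_normal((50, 50))
A = A @ A.T  # symmetric PSD, so the dominant eigenvalue is also the largest algebraic one

# Lanczos via SciPy: largest algebraic eigenvalue
(lanczos_eigval,), _ = eigsh(A, k=1, which="LA")

# naive power iteration for comparison
v = rng.standard_normal(50)
for _ in range(200):
    v = A @ v
    v /= numpy.linalg.norm(v)
power_eigval = v @ A @ v  # Rayleigh quotient of the converged iterate

print(lanczos_eigval, power_eigval)  # the two estimates should agree closely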
15 changes: 10 additions & 5 deletions curvlinops/_base.py
@@ -8,9 +8,9 @@
from numpy import allclose, argwhere, float32, isclose, logical_not, ndarray
from numpy.random import rand
from scipy.sparse.linalg import LinearOperator
from torch import Tensor, cat
from torch import Tensor, as_tensor, bfloat16, cat
from torch import device as torch_device
from torch import from_numpy, tensor, zeros_like
from torch import tensor, zeros_like
from torch.autograd import grad
from torch.nn import Module, Parameter
from tqdm import tqdm
@@ -117,6 +117,7 @@ def __init__(
self._loss_func = loss_func
self._data = data
self._device = self._infer_device(self._params)
(self._torch_dtype,) = {p.dtype for p in self._params}
self._progressbar = progressbar
self._batch_size_fn = (
(lambda X: X.shape[0]) if batch_size_fn is None else batch_size_fn
@@ -302,7 +303,7 @@ def _preprocess(self, M: ndarray) -> List[Tensor]:
M = M.astype(self.dtype)
num_vectors = M.shape[1]

result = from_numpy(M).to(self._device)
result = as_tensor(M, dtype=self._torch_dtype, device=self._device)
# split parameter blocks
dims = [p.numel() for p in self._params]
result = result.split(dims)
@@ -324,7 +325,11 @@ def _postprocess(self, M_list: List[Tensor]) -> ndarray:
concatenated dimensions over all list entries.
"""
result = [rearrange(M, "k ... -> (...) k") for M in M_list]
return cat(result, dim=0).cpu().numpy()
result = cat(result)
# calling .numpy() on a BF-16 tensor is not supported, see
# (https://github.com/pytorch/pytorch/issues/90574)
result = result.float() if result.dtype == bfloat16 else result
return result.cpu().numpy().astype(self.dtype)

def _loop_over_data(
self, desc: Optional[str] = None, add_device_to_desc: bool = True
@@ -340,7 +345,7 @@ def _loop_over_data(
Yields:
Mini-batches ``(X, y)``.
"""
data_iter = iter(self._data)
data_iter = self._data

if self._progressbar:
desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
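For context on the `_postprocess` change above: `.numpy()` is unsupported for bfloat16 tensors (see the linked PyTorch issue), which is why the result is cast to float32 first. A standalone sketch of the failure and the workaround, not curvlinops code:

from torch import bfloat16, randn

t = randn(3, dtype=bfloat16)
try:
    t.numpy()  # raises TypeError, since NumPy has no bfloat16 counterpart
except TypeError as err:
    print(err)
arr = t.float().numpy()  # cast to float32 first, then convert as usual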
22 changes: 5 additions & 17 deletions curvlinops/_torch_base.py
@@ -6,7 +6,7 @@

import numpy
from scipy.sparse.linalg import LinearOperator
from torch import Size, Tensor, cat, device, dtype, from_numpy, rand, tensor, zeros_like
from torch import Size, Tensor, as_tensor, cat, device, dtype, rand, tensor, zeros_like
from torch.autograd import grad
from torch.nn import Module, Parameter
from tqdm import tqdm
@@ -24,7 +24,7 @@ class PyTorchLinearOperator:
One main difference is that the linear operators cannot only multiply
vectors/matrices specified as single PyTorch tensors, but also
vectors/matrices specified in tensor list format. This is common in
PyTorch, where the space a linear operator acts on is a tensor product
PyTorch, where the space a linear operator acts on is a tensor product.
Functions that need to be implemented are ``_matmat`` and ``_adjoint``.
@@ -35,7 +35,6 @@
Attributes:
SELF_ADJOINT: Whether the linear operator is self-adjoint. If ``True``,
``_adjoint`` does not need to be implemented. Default: ``False``.
"""

SELF_ADJOINT: bool = False
@@ -114,17 +113,6 @@ def adjoint(self) -> PyTorchLinearOperator:
"""
return self if self.SELF_ADJOINT else self._adjoint()

def _adjoint(self) -> PyTorchLinearOperator:
"""Adjoint of the linear operator.
Returns: # noqa: D402
The adjoint of the linear operator.
Raises:
NotImplementedError: Must be implemented by the subclass.
"""
raise NotImplementedError

def _check_input_and_preprocess(
self, X: Union[List[Tensor], Tensor]
) -> Tuple[List[Tensor], bool, bool, int]:
@@ -353,7 +341,7 @@ def f_scipy(X: numpy.ndarray) -> numpy.ndarray:
The output matrix in NumPy format.
"""
X_dtype = X.dtype
X_torch = from_numpy(X).to(device, dtype)
X_torch = as_tensor(X, dtype=dtype, device=device)
AX_torch = f(X_torch)
return AX_torch.detach().cpu().numpy().astype(X_dtype)

@@ -445,7 +433,7 @@ def __init__(
)

in_shape = [tuple(p.shape) for p in params] if in_shape is None else in_shape
out_shape = [tuple(p.shape) for p in params] if in_shape is None else in_shape
out_shape = [tuple(p.shape) for p in params] if out_shape is None else out_shape
super().__init__(in_shape, out_shape)

self._params = params
@@ -544,7 +532,7 @@ def _loop_over_data(
Yields:
Mini-batches ``(X, y)``.
"""
data_iter = iter(self._data)
data_iter = self._data

if self._progressbar:
desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
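The `f_scipy` change above replaces `from_numpy(X).to(device, dtype)` with `as_tensor(X, dtype=dtype, device=device)`, which handles the dtype and device cast in one step. A rough standalone sketch of the NumPy-to-PyTorch round trip this wrapper performs; the wrapped map `f` and the helper name are placeholders, not curvlinops API:

import numpy
from torch import as_tensor, float32

def wrap_for_scipy(f, device="cpu", dtype=float32):
    def f_scipy(X: numpy.ndarray) -> numpy.ndarray:
        X_dtype = X.dtype  # remember the caller's NumPy dtype
        X_torch = as_tensor(X, dtype=dtype, device=device)
        AX_torch = f(X_torch)
        return AX_torch.detach().cpu().numpy().astype(X_dtype)

    return f_scipy

matmat = wrap_for_scipy(lambda M: 2.0 * M)  # stand-in for a curvature matmat
print(matmat(numpy.ones((4, 2))))  # float64 in, float64 out, computed in float32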
11 changes: 6 additions & 5 deletions curvlinops/examples/functorch.py
@@ -151,14 +151,14 @@ def linearized_loss(
return blocks_to_matrix(ggn_fn(X, y, anchor_dict, params_dict))


def functorch_gradient(
def functorch_gradient_and_loss(
model_func: Module,
loss_func: Module,
params: List[Tensor],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
input_key: Optional[str] = None,
) -> Tuple[Tensor]:
"""Compute the gradient with functorch.
) -> Tuple[List[Tensor], Tensor]:
"""Compute the gradient and loss with functorch.
Args:
model_func: A function that maps the mini-batch input X to predictions.
@@ -171,7 +171,7 @@ def loss(
input_key: Key to obtain the input tensor when ``X`` is a dict-like object.
Returns:
Gradient in same format as the parameters.
Loss, and gradient in same format as the parameters.
"""
(dev,) = {p.device for p in params}
X, y = _concatenate_batches(data, input_key, device=dev)
Expand All @@ -190,8 +190,9 @@ def loss(

params_argnum = 2
grad_fn = grad(loss, argnums=params_argnum)
loss_value = loss(X, y, params_dict)

return tuple(grad_fn(X, y, params_dict).values())
return list(grad_fn(X, y, params_dict).values()), loss_value


def functorch_empirical_fisher(
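The rename to `functorch_gradient_and_loss` above makes the helper return the loss alongside the gradient. A tiny standalone sketch of the same pattern with `torch.func.grad`, using a toy model and data rather than the curvlinops helper:

from torch import randn
from torch.func import grad

def loss(w, X, y):
    return ((X @ w - y) ** 2).mean()

w, X, y = randn(4), randn(8, 4), randn(8)
gradient = grad(loss, argnums=0)(w, X, y)  # d loss / d w
loss_value = loss(w, X, y)  # evaluated separately, as in the diff above

(`torch.func.grad_and_value` would return both in a single call; the sketch mirrors the two-call structure used in the diff.)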
4 changes: 4 additions & 0 deletions curvlinops/examples/utils.py
@@ -31,9 +31,13 @@ def report_nonclose(
if allclose(array1, array2, rtol=rtol, atol=atol, equal_nan=equal_nan):
print("Compared arrays match.")
else:
nonclose_entries = 0
for a1, a2 in zip(array1.flatten(), array2.flatten()):
if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan):
print(f"{a1}{a2} (ratio {a1 / a2:.5f})")
nonclose_entries += 1
print(f"Max: {array1.max():.5f}, {array2.max():.5f}")
print(f"Min: {array1.min():.5f}, {array2.min():.5f}")
print(f"Nonclose entries: {nonclose_entries} / {array1.size}")
print(f"rtol = {rtol}, atol= {atol}")
raise ValueError("Compared arrays don't match.")
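Hypothetical usage of `report_nonclose` as extended above (behavior inferred from the diff): it prints the mismatching entries with their ratio, the max/min of both arrays, and the new nonclose-entry count, then raises `ValueError` if the arrays differ.

import numpy
from curvlinops.examples.utils import report_nonclose

mat = numpy.ones((2, 2))
report_nonclose(mat, mat)  # prints "Compared arrays match."
try:
    report_nonclose(mat, mat + 1e-3, atol=1e-5)  # prints mismatches and the count
except ValueError:
    pass  # raised because the arrays don't match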
5 changes: 2 additions & 3 deletions curvlinops/utils.py
@@ -36,15 +36,14 @@ def allclose_report(
Args:
tensor1: First tensor for comparison.
tensor2: Second tensor for comparison.
rtol: Relative tolerance. Default: ``1e-5``.
atol: Absolute tolerance. Default: ``1e-8``.
rtol: Relative tolerance. Default is ``1e-5``.
atol: Absolute tolerance. Default is ``1e-8``.
Returns:
``True`` if the tensors are close, ``False`` otherwise.
"""
close = tensor1.allclose(tensor2, rtol=rtol, atol=atol)
if not close:
# print non-close values
nonclose_idx = tensor1.isclose(tensor2, rtol=rtol, atol=atol).logical_not_()
for idx, t1, t2 in zip(
nonclose_idx.argwhere(),
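Hypothetical test-style usage of `allclose_report` as documented above: it behaves like `Tensor.allclose`, but prints the non-close entries before returning `False`, which makes failing assertions easier to debug.

from torch import ones
from curvlinops.utils import allclose_report

assert allclose_report(ones(3), ones(3))  # close, returns True silently
assert not allclose_report(ones(3), 2 * ones(3))  # prints non-close values, returns False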