diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2707c83..a49757d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,11 +29,11 @@ jobs: python -m pip install --upgrade pip make install-test - name: Run test - if: contains('refs/heads/master refs/heads/development', github.ref) + if: contains('refs/heads/main', github.ref) run: | make test - name: Run test-light - if: contains('refs/heads/master refs/heads/development', github.ref) != 1 + if: contains('refs/heads/main', github.ref) != 1 run: | make test-light diff --git a/changelog.md b/changelog.md index d644aaa..92c0f79 100644 --- a/changelog.md +++ b/changelog.md @@ -6,16 +6,46 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added/New + +### Fixed/Removed + +### Internal + +## [2.0.1] - 2024-10-25 + +Minor bug fixes and documentation polishing. + +### Added/New + +- Comparison of `eigsh` with power iteration in [eigenvalue + example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_eigenvalues.html#sphx-glr-basic-usage-example-eigenvalues-py) + ([PR](https://github.com/f-dangel/curvlinops/pull/140)) + ### Fixed/Removed - Deprecate Python 3.8 as it will reach its end of life in October 2024 ([PR](https://github.com/f-dangel/curvlinops/pull/128)) +- Improve `intersphinx` mapping to `curvlinops` objects + ([issue](https://github.com/f-dangel/curvlinops/issues/138), + [PR](https://github.com/f-dangel/curvlinops/pull/141)) + ### Internal - Update Github action versions and cache `pip` ([PR](https://github.com/f-dangel/curvlinops/pull/129)) +- Re-activate Monte-Carlo tests, refactor, and reduce their run time + ([PR](https://github.com/f-dangel/curvlinops/pull/131)) + +- Add more matrices in visual tour code example and prettify plots + ([PR](https://github.com/f-dangel/curvlinops/pull/134)) + +- Prettify visualizations in [spectral density + example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_verification_spectral_density.html) + ([PR](https://github.com/f-dangel/curvlinops/pull/139)) + ## [2.0.0] - 2024-08-15 This major release is almost fully backward compatible with the `1.x.y` release @@ -295,7 +325,8 @@ Adds various new features: Initial release -[Unreleased]: https://github.com/f-dangel/curvlinops/compare/2.0.0...HEAD +[Unreleased]: https://github.com/f-dangel/curvlinops/compare/2.0.1...HEAD +[2.0.1]: https://github.com/f-dangel/curvlinops/releases/tag/2.0.1 [2.0.0]: https://github.com/f-dangel/curvlinops/releases/tag/2.0.0 [1.2.0]: https://github.com/f-dangel/curvlinops/releases/tag/1.2.0 [1.1.0]: https://github.com/f-dangel/curvlinops/releases/tag/1.1.0 diff --git a/curvlinops/_base.py b/curvlinops/_base.py index ca57916..f910b11 100644 --- a/curvlinops/_base.py +++ b/curvlinops/_base.py @@ -8,9 +8,9 @@ from numpy import allclose, argwhere, float32, isclose, logical_not, ndarray from numpy.random import rand from scipy.sparse.linalg import LinearOperator -from torch import Tensor, cat +from torch import Tensor, as_tensor, bfloat16, cat from torch import device as torch_device -from torch import from_numpy, tensor, zeros_like +from torch import tensor, zeros_like from torch.autograd import grad from torch.nn import Module, Parameter from tqdm import tqdm @@ -117,6 +117,7 @@ def __init__( self._loss_func = loss_func self._data = data self._device = self._infer_device(self._params) + (self._torch_dtype,) = {p.dtype for p in self._params} self._progressbar = progressbar self._batch_size_fn 
= ( (lambda X: X.shape[0]) if batch_size_fn is None else batch_size_fn @@ -302,7 +303,7 @@ def _preprocess(self, M: ndarray) -> List[Tensor]: M = M.astype(self.dtype) num_vectors = M.shape[1] - result = from_numpy(M).to(self._device) + result = as_tensor(M, dtype=self._torch_dtype, device=self._device) # split parameter blocks dims = [p.numel() for p in self._params] result = result.split(dims) @@ -324,7 +325,11 @@ def _postprocess(self, M_list: List[Tensor]) -> ndarray: concatenated dimensions over all list entries. """ result = [rearrange(M, "k ... -> (...) k") for M in M_list] - return cat(result, dim=0).cpu().numpy() + result = cat(result) + # calling .numpy() on a BF-16 tensor is not supported, see + # (https://github.com/pytorch/pytorch/issues/90574) + result = result.float() if result.dtype == bfloat16 else result + return result.cpu().numpy().astype(self.dtype) def _loop_over_data( self, desc: Optional[str] = None, add_device_to_desc: bool = True @@ -340,7 +345,7 @@ def _loop_over_data( Yields: Mini-batches ``(X, y)``. """ - data_iter = iter(self._data) + data_iter = self._data if self._progressbar: desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}" diff --git a/curvlinops/_torch_base.py b/curvlinops/_torch_base.py index 97c4aaf..4b52efd 100644 --- a/curvlinops/_torch_base.py +++ b/curvlinops/_torch_base.py @@ -6,7 +6,7 @@ import numpy from scipy.sparse.linalg import LinearOperator -from torch import Size, Tensor, cat, device, dtype, from_numpy, rand, tensor, zeros_like +from torch import Size, Tensor, as_tensor, cat, device, dtype, rand, tensor, zeros_like from torch.autograd import grad from torch.nn import Module, Parameter from tqdm import tqdm @@ -24,7 +24,7 @@ class PyTorchLinearOperator: One main difference is that the linear operators cannot only multiply vectors/matrices specified as single PyTorch tensors, but also vectors/matrices specified in tensor list format. This is common in - PyTorch, where the space a linear operator acts on is a tensor product + PyTorch, where the space a linear operator acts on is a tensor product. Functions that need to be implemented are ``_matmat`` and ``_adjoint``. @@ -35,7 +35,6 @@ class PyTorchLinearOperator: Attributes: SELF_ADJOINT: Whether the linear operator is self-adjoint. If ``True``, ``_adjoint`` does not need to be implemented. Default: ``False``. - """ SELF_ADJOINT: bool = False @@ -114,17 +113,6 @@ def adjoint(self) -> PyTorchLinearOperator: """ return self if self.SELF_ADJOINT else self._adjoint() - def _adjoint(self) -> PyTorchLinearOperator: - """Adjoint of the linear operator. - - Returns: # noqa: D402 - The adjoint of the linear operator. - - Raises: - NotImplementedError: Must be implemented by the subclass. - """ - raise NotImplementedError - def _check_input_and_preprocess( self, X: Union[List[Tensor], Tensor] ) -> Tuple[List[Tensor], bool, bool, int]: @@ -353,7 +341,7 @@ def f_scipy(X: numpy.ndarray) -> numpy.ndarray: The output matrix in NumPy format. 
""" X_dtype = X.dtype - X_torch = from_numpy(X).to(device, dtype) + X_torch = as_tensor(X, dtype=dtype, device=device) AX_torch = f(X_torch) return AX_torch.detach().cpu().numpy().astype(X_dtype) @@ -445,7 +433,7 @@ def __init__( ) in_shape = [tuple(p.shape) for p in params] if in_shape is None else in_shape - out_shape = [tuple(p.shape) for p in params] if in_shape is None else in_shape + out_shape = [tuple(p.shape) for p in params] if out_shape is None else out_shape super().__init__(in_shape, out_shape) self._params = params @@ -544,7 +532,7 @@ def _loop_over_data( Yields: Mini-batches ``(X, y)``. """ - data_iter = iter(self._data) + data_iter = self._data if self._progressbar: desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}" diff --git a/curvlinops/examples/functorch.py b/curvlinops/examples/functorch.py index 77f6ed6..58eefa3 100644 --- a/curvlinops/examples/functorch.py +++ b/curvlinops/examples/functorch.py @@ -151,14 +151,14 @@ def linearized_loss( return blocks_to_matrix(ggn_fn(X, y, anchor_dict, params_dict)) -def functorch_gradient( +def functorch_gradient_and_loss( model_func: Module, loss_func: Module, params: List[Tensor], data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]], input_key: Optional[str] = None, -) -> Tuple[Tensor]: - """Compute the gradient with functorch. +) -> Tuple[List[Tensor], Tensor]: + """Compute the gradient and loss with functorch. Args: model_func: A function that maps the mini-batch input X to predictions. @@ -171,7 +171,7 @@ def functorch_gradient( input_key: Key to obtain the input tensor when ``X`` is a dict-like object. Returns: - Gradient in same format as the parameters. + Loss, and gradient in same format as the parameters. """ (dev,) = {p.device for p in params} X, y = _concatenate_batches(data, input_key, device=dev) @@ -190,8 +190,9 @@ def loss( params_argnum = 2 grad_fn = grad(loss, argnums=params_argnum) + loss_value = loss(X, y, params_dict) - return tuple(grad_fn(X, y, params_dict).values()) + return list(grad_fn(X, y, params_dict).values()), loss_value def functorch_empirical_fisher( diff --git a/curvlinops/examples/utils.py b/curvlinops/examples/utils.py index e22fe4a..08241c2 100644 --- a/curvlinops/examples/utils.py +++ b/curvlinops/examples/utils.py @@ -31,9 +31,13 @@ def report_nonclose( if allclose(array1, array2, rtol=rtol, atol=atol, equal_nan=equal_nan): print("Compared arrays match.") else: + nonclose_entries = 0 for a1, a2 in zip(array1.flatten(), array2.flatten()): if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan): print(f"{a1} ≠ {a2} (ratio {a1 / a2:.5f})") + nonclose_entries += 1 print(f"Max: {array1.max():.5f}, {array2.max():.5f}") print(f"Min: {array1.min():.5f}, {array2.min():.5f}") + print(f"Nonclose entries: {nonclose_entries} / {array1.size}") + print(f"rtol = {rtol}, atol= {atol}") raise ValueError("Compared arrays don't match.") diff --git a/curvlinops/utils.py b/curvlinops/utils.py index 2bda099..698f236 100644 --- a/curvlinops/utils.py +++ b/curvlinops/utils.py @@ -36,15 +36,14 @@ def allclose_report( Args: tensor1: First tensor for comparison. tensor2: Second tensor for comparison. - rtol: Relative tolerance. Default: ``1e-5``. - atol: Absolute tolerance. Default: ``1e-8``. + rtol: Relative tolerance. Default is ``1e-5``. + atol: Absolute tolerance. Default is ``1e-8``. Returns: ``True`` if the tensors are close, ``False`` otherwise. 
""" close = tensor1.allclose(tensor2, rtol=rtol, atol=atol) if not close: - # print non-close values nonclose_idx = tensor1.isclose(tensor2, rtol=rtol, atol=atol).logical_not_() for idx, t1, t2 in zip( nonclose_idx.argwhere(), diff --git a/docs/examples/basic_usage/example_eigenvalues.py b/docs/examples/basic_usage/example_eigenvalues.py index daeeacf..5fdbccd 100644 --- a/docs/examples/basic_usage/example_eigenvalues.py +++ b/docs/examples/basic_usage/example_eigenvalues.py @@ -8,6 +8,10 @@ As always, imports go first. """ +from contextlib import redirect_stderr +from io import StringIO +from typing import List, Tuple + import numpy import scipy import torch @@ -70,7 +74,7 @@ # :math:`k=3` eigenvalues. k = 3 -which = "LA" # largest algebraic +which = "LM" # largest magnitude top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which) print(f"Leading {k} Hessian eigenvalues: {top_k_evals}") @@ -104,3 +108,130 @@ # :func:`scipy.sparse.linalg.eigsh` can also compute other subsets of # eigenvalues, and also their associated eigenvectors. Check out its # documentation for more! + + +# %% +# +# Power iteration versus ``eigsh`` +# -------------------------------- +# +# Here, we compare the query efficiency of :func:`scipy.sparse.linalg.eigsh` with the +# `power iteration `_ method, a simple +# method to compute the leading eigenvalues (in terms of magnitude). We re-use the im- +# plementation from the `PyHessian library `_ +# and adapt it to work with SciPy arrays rather than PyTorch tensors: + + +def power_method( + A: scipy.sparse.linalg.LinearOperator, + max_iterations: int = 100, + tol: float = 1e-3, + k: int = 1, +) -> Tuple[numpy.ndarray, numpy.ndarray]: + """Compute the top-k eigenpairs of a linear operator using power iteration. + + Code modified from PyHessian, see + https://github.com/amirgholami/PyHessian/blob/72e5f0a0d06142387fccdab2226b4c6bae088202/pyhessian/hessian.py#L111-L156 + + Args: + A: Linear operator of dimension ``D`` whose top eigenpairs will be computed. + max_iterations: Maximum number of iterations. Defaults to ``100``. + tol: Relative tolerance between two consecutive iterations that has to be + reached for convergence. Defaults to ``1e-3``. + k: Number of eigenpairs to compute. Defaults to ``1``. + + Returns: + The eigenvalues as array of shape ``[k]`` in descending order, and their + corresponding eigenvectors as array of shape ``[D, k]``. + """ + eigenvalues = [] + eigenvectors = [] + + def normalize(v: numpy.ndarray) -> numpy.ndarray: + return v / numpy.linalg.norm(v) + + def orthonormalize(v: numpy.ndarray, basis: List[numpy.ndarray]) -> numpy.ndarray: + for basis_vector in basis: + v -= numpy.dot(v, basis_vector) * basis_vector + return normalize(v) + + computed_dim = 0 + while computed_dim < k: + eigenvalue = None + v = normalize(numpy.random.randn(A.shape[0])) + + for _ in range(max_iterations): + v = orthonormalize(v, eigenvectors) + Av = A @ v + + tmp_eigenvalue = v.dot(Av) + v = normalize(Av) + + if eigenvalue is None: + eigenvalue = tmp_eigenvalue + else: + if abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6) < tol: + break + else: + eigenvalue = tmp_eigenvalue + + eigenvalues.append(eigenvalue) + eigenvectors.append(v) + computed_dim += 1 + + # sort in ascending order and convert into arrays + eigenvalues = numpy.array(eigenvalues[::-1]) + eigenvectors = numpy.array(eigenvectors[::-1]) + + return eigenvalues, eigenvectors + + +# %% +# +# Let's compute the top-3 eigenvalues via power iteration and verify they roughly match. 
+# Note that we are using a smaller :code:`tol` value than the PyHessian default value +# here to get better convergence, and we have to use relatively large tolerances for the +# comparison (which we didn't do when comparing :code:`eigsh` with :code:`eigh`). + +top_k_evals_power, _ = power_method(H, tol=1e-4, k=k) +print(f"Comparing leading {k} Hessian eigenvalues (eigsh vs. power).") +report_nonclose(top_k_evals_functorch, top_k_evals_power, rtol=2e-2, atol=1e-6) + +# %% +# +# This indicates that the power method achieves poorer accuracy than :code:`eigsh`. But +# does it therefore require fewer matrix-vector products? To answer this, let's turn on +# the linear operator's progress bar, which allows us to count the number of +# matrix-vector products invoked by both eigen-solvers: + +H = HessianLinearOperator( + model, loss_function, params, data, progressbar=True +).to_scipy() + +# determine number of matrix-vector products used by `eigsh` +with StringIO() as buf, redirect_stderr(buf): + top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which) + # The tqdm progressbar will print "matmat" for each batch in a matrix-vector + # product. Therefore, we need to divide by the number of batches + queries_eigsh = buf.getvalue().count("matmat") // len(data) +print(f"eigsh used {queries_eigsh} matrix-vector products.") + +# determine number of matrix-vector products used by power iteration +with StringIO() as buf, redirect_stderr(buf): + top_k_evals_power, _ = power_method(H, k=k, tol=1e-4) + # The tqdm progressbar will print "matmat" for each batch in a matrix-vector + # product. Therefore, we need to divide by the number of batches + queries_power = buf.getvalue().count("matmat") // len(data) +print(f"Power iteration used {queries_power} matrix-vector products.") + +assert queries_power > queries_eigsh + +# %% +# +# Sadly, the power iteration also does not offer computational benefits, consuming +# more matrix-vector products than :code:`eigsh`. While it is elegant and simple, +# it cannot compete with :code:`eigsh`, at least in the comparison provided here. +# +# Therefore, we recommend using :code:`eigsh` for computing eigenvalues. This method +# becomes accessible because :code:`curvlinops` interfaces with SciPy's linear +# operators. diff --git a/docs/examples/basic_usage/example_inverses.py b/docs/examples/basic_usage/example_inverses.py index 74df335..30b6a3d 100644 --- a/docs/examples/basic_usage/example_inverses.py +++ b/docs/examples/basic_usage/example_inverses.py @@ -33,7 +33,7 @@ GGNLinearOperator, NeumannInverseLinearOperator, ) -from curvlinops.examples.functorch import functorch_ggn, functorch_gradient +from curvlinops.examples.functorch import functorch_ggn, functorch_gradient_and_loss from curvlinops.examples.utils import report_nonclose # make deterministic @@ -70,7 +70,7 @@ loss_function = nn.MSELoss(reduction="mean").to(DEVICE) -# % +# %% # # Next, let's compute the ingredients for the natural gradient. 
# @@ -151,7 +151,7 @@ # Next, let's compute the gradient with :code:`functorch`, using a utility # function from :code:`curvlinops.examples`: -gradient_functorch = functorch_gradient(model, loss_function, params, data) +gradient_functorch, _ = functorch_gradient_and_loss(model, loss_function, params, data) # convert to numpy (vector) format gradient_functorch = ( nn.utils.parameters_to_vector(gradient_functorch).detach().cpu().numpy() diff --git a/docs/examples/basic_usage/example_verification_spectral_density.py b/docs/examples/basic_usage/example_verification_spectral_density.py index 13e1bcd..8779481 100644 --- a/docs/examples/basic_usage/example_verification_spectral_density.py +++ b/docs/examples/basic_usage/example_verification_spectral_density.py @@ -9,11 +9,14 @@ Here are the imports: """ +from os import getenv + import matplotlib.pyplot as plt from numpy import e, exp, linspace, log, logspace, matmul, ndarray, zeros from numpy.linalg import eigh from numpy.random import pareto, randn, seed from scipy.sparse.linalg import aslinearoperator, eigsh +from tueplots import bundles from curvlinops.outer import OuterProductLinearOperator from curvlinops.papyan2020traces.spectrum import ( @@ -23,6 +26,10 @@ lanczos_approximate_spectrum, ) +# LaTeX is not available in Github actions. +# Therefore, we are turning it off if the script executes on GHA. +USETEX = not getenv("CI") + seed(0) # %% @@ -85,7 +92,7 @@ def create_matrix(dim: int = 2000) -> ndarray: # and using the same hyperparameters as specified by the paper: # spectral density hyperparameters -num_points = 1024 +num_points = 200 ncv = 128 num_repeats = 10 kappa = 3 @@ -118,20 +125,33 @@ def create_matrix(dim: int = 2000) -> ndarray: # and plot it with a histogram (same number of bins as in the paper) of the # exact density: -plt.figure() -plt.xlabel("Eigenvalue") -plt.ylabel("Spectral density") - -left, right = grid[0], grid[-1] -num_bins = 100 -bins = linspace(left, right, num_bins, endpoint=True) -plt.hist(Y_evals, bins=bins, log=True, density=True, label="Exact") +# use `tueplots` to make the plot look pretty +plot_config = bundles.icml2024(column="half", usetex=USETEX) + +with plt.rc_context(plot_config): + plt.figure() + plt.xlabel(r"Eigenvalue $\lambda$") + plt.ylabel(r"Spectral density $\rho(\lambda)$") + + left, right = grid[0], grid[-1] + num_bins = 40 + bins = linspace(left, right, num_bins, endpoint=True) + plt.hist( + Y_evals, + bins=bins, + log=True, + density=True, + label="Exact", + edgecolor="white", + lw=0.5, + ) -plt.plot(grid, density, label="Approximate") -plt.legend() + plt.plot(grid, density, label="Approximate") + plt.legend() -# same ylimits as in the paper -plt.ylim(bottom=1e-5, top=1e1) + # same ylimits as in the paper + plt.ylim(bottom=1e-5, top=1e1) + plt.savefig("toy_spectrum.pdf", bbox_inches="tight") # %% # @@ -153,27 +173,39 @@ def create_matrix(dim: int = 2000) -> ndarray: # Let's try out different values for :code:`kappa`: kappas = [1.1, 3, 10.0] -fig, ax = plt.subplots(ncols=len(kappas), figsize=(12, 3), sharex=True, sharey=True) - cache = LanczosApproximateSpectrumCached(Y_linop, ncv, boundaries) -for idx, kappa in enumerate(kappas): - grid, density = cache.approximate_spectrum( - num_repeats=num_repeats, num_points=num_points, kappa=kappa, margin=margin - ) - - ax[idx].hist(Y_evals, bins=bins, log=True, density=True, label="Exact") - ax[idx].plot(grid, density, label=rf"$\kappa = {kappa}$") - ax[idx].legend() - - ax[idx].set_xlabel("Eigenvalue") - ax[idx].set_ylabel("Spectral density") - 
ax[idx].set_ylim(bottom=1e-5, top=1e1) +# use `tueplots` to make the plot look pretty +plot_config = bundles.icml2024(column="full", ncols=len(kappas), usetex=USETEX) + +with plt.rc_context(plot_config): + fig, ax = plt.subplots(ncols=len(kappas), sharex=True, sharey=True) + for idx, kappa in enumerate(kappas): + grid, density = cache.approximate_spectrum( + num_repeats=num_repeats, num_points=num_points, kappa=kappa, margin=margin + ) + + ax[idx].hist( + Y_evals, + bins=bins, + log=True, + density=True, + label="Exact", + edgecolor="white", + lw=0.5, + ) + ax[idx].plot(grid, density, label=rf"$\kappa = {kappa}$") + ax[idx].legend() + + ax[idx].set_xlabel(r"Eigenvalue $\lambda$") + if idx == 0: + ax[idx].set_ylabel(r"Spectral density $\rho(\lambda)$") + ax[idx].set_ylim(bottom=1e-5, top=1e1) # %% # -# Wit rank deflation -# ^^^^^^^^^^^^^^^^^^ +# With rank deflation +# ^^^^^^^^^^^^^^^^^^^ # # As you can see in the above plot, the spectrum consists of a bulk and three # outliers. We can project out the three (or in general :code:`k`) outliers to @@ -207,25 +239,37 @@ def create_matrix(dim: int = 2000) -> ndarray: # # Here is the visualization, with outliers marked separately: -plt.figure() -plt.title(f"With rank deflation (top {k})") -plt.xlabel("Eigenvalue") -plt.ylabel("Spectral density") - -plt.hist(Y_evals, bins=bins, log=True, density=True, label="Exact") -plt.plot(grid_no_top, density_no_top, label="Approximate (deflated)") - -plt.plot( - Y_top_evals, - len(Y_top_evals) * [1 / Y_linop.shape[0]], - linestyle="", - marker="o", - label=f"Top {k}", -) +# use `tueplots` to make the plot look pretty +plot_config = bundles.icml2024(column="half", usetex=USETEX) + +with plt.rc_context(plot_config): + plt.figure() + plt.title(f"With rank deflation (top {k})") + plt.xlabel(r"Eigenvalue $\lambda$") + plt.ylabel(r"Spectral density $\rho(\lambda)$") + + plt.hist( + Y_evals, + bins=bins, + log=True, + density=True, + label="Exact", + edgecolor="white", + lw=0.5, + ) + plt.plot(grid_no_top, density_no_top, label="Approximate (deflated)") + + plt.plot( + Y_top_evals, + len(Y_top_evals) * [1 / Y_linop.shape[0]], + linestyle="", + marker="o", + label=f"Top {k}", + ) -# same ylimits as in the paper -plt.ylim(bottom=1e-5, top=1e1) -plt.legend() + # same ylimits as in the paper + plt.ylim(bottom=1e-5, top=1e1) + plt.legend() # %% # @@ -288,7 +332,7 @@ def create_matrix_log_spectrum(dim: int = 500) -> ndarray: # and using the same hyperparameters as specified by the paper: # spectral density hyperparameters -num_points = 1024 +num_points = 200 margin = 0.05 ncv = 256 num_repeats = 10 @@ -319,29 +363,43 @@ def create_matrix_log_spectrum(dim: int = 500) -> ndarray: # # Now we can visualize the results: -plt.figure() -plt.xlabel("Eigenvalue") -plt.ylabel("Spectral density") - -Y_log_abs_evals = log(abs(Y_evals) + epsilon) - -xlimits_no_margin = (Y_log_abs_evals.min(), Y_log_abs_evals.max()) -width_no_margins = xlimits_no_margin[1] - xlimits_no_margin[0] -xlimits = [ - xlimits_no_margin[0] - margin * width_no_margins, - xlimits_no_margin[1] + margin * width_no_margins, -] +# use `tueplots` to make the plot look pretty +plot_config = bundles.icml2024(column="half", usetex=USETEX) + +with plt.rc_context(plot_config): + plt.figure() + plt.xlabel(r"Absolute eigenvalue $\nu = |\lambda| + \epsilon$") + plt.ylabel(r"Spectral density $\rho(\log\nu)$") + + Y_log_abs_evals = log(abs(Y_evals) + epsilon) + + xlimits_no_margin = (Y_log_abs_evals.min(), Y_log_abs_evals.max()) + width_no_margins = xlimits_no_margin[1] - 
xlimits_no_margin[0] + xlimits = [ + xlimits_no_margin[0] - margin * width_no_margins, + xlimits_no_margin[1] + margin * width_no_margins, + ] + + plt.semilogx() + num_bins = 40 + bins = logspace(*xlimits, num=num_bins, endpoint=True, base=e) + plt.hist( + exp(Y_log_abs_evals), + bins=bins, + log=True, + density=True, + label="Exact", + edgecolor="white", + lw=0.5, + ) -plt.semilogx() -num_bins = 100 -bins = logspace(*xlimits, num=num_bins, endpoint=True, base=e) -plt.hist(exp(Y_log_abs_evals), bins=bins, log=True, density=True, label="Exact") + plt.plot(grid, density, label="Approximate") -plt.plot(grid, density, label="Approximate") + # use same ylimits as in the paper + plt.ylim(bottom=1e-14, top=1e-2) + plt.legend() -# use same ylimits as in the paper -plt.ylim(bottom=1e-14, top=1e-2) -plt.legend() + plt.savefig("toy_log_spectrum.pdf", bbox_inches="tight") # %% # @@ -355,26 +413,36 @@ def create_matrix_log_spectrum(dim: int = 500) -> ndarray: # # Let's try out different values for :code:`kappa`: -plt.close() - kappas = [1.01, 1.1, 3] -fig, ax = plt.subplots(ncols=len(kappas), figsize=(12, 3), sharex=True, sharey=True) - cache = LanczosApproximateLogSpectrumCached(Y_linop, ncv, boundaries) -for idx, kappa in enumerate(kappas): - grid, density = cache.approximate_log_spectrum( - num_repeats=num_repeats, - num_points=num_points, - kappa=kappa, - margin=margin, - epsilon=epsilon, - ) - - ax[idx].hist(exp(Y_log_abs_evals), bins=bins, log=True, density=True, label="Exact") - ax[idx].loglog(grid, density, label=rf"$\kappa = {kappa}$") - ax[idx].legend() - - ax[idx].set_xlabel("Eigenvalue") - ax[idx].set_ylabel("Spectral density") - ax[idx].set_ylim(bottom=1e-14, top=1e-2) +# use `tueplots` to make the plot look pretty +plot_config = bundles.icml2024(column="full", ncols=len(kappas), usetex=USETEX) + +with plt.rc_context(plot_config): + fig, ax = plt.subplots(ncols=len(kappas), sharex=True, sharey=True) + for idx, kappa in enumerate(kappas): + grid, density = cache.approximate_log_spectrum( + num_repeats=num_repeats, + num_points=num_points, + kappa=kappa, + margin=margin, + epsilon=epsilon, + ) + + ax[idx].hist( + exp(Y_log_abs_evals), + bins=bins, + log=True, + density=True, + label="Exact", + edgecolor="white", + lw=0.5, + ) + ax[idx].loglog(grid, density, label=rf"$\kappa = {kappa}$") + ax[idx].legend() + + ax[idx].set_xlabel(r"Absolute eigenvalue $\nu = |\lambda| + \epsilon$") + if idx == 0: + ax[idx].set_ylabel(r"Spectral density $\rho(\log \nu)$") + ax[idx].set_ylim(bottom=1e-14, top=1e-2) diff --git a/docs/examples/basic_usage/example_visual_tour.py b/docs/examples/basic_usage/example_visual_tour.py index 0461d27..5947f54 100644 --- a/docs/examples/basic_usage/example_visual_tour.py +++ b/docs/examples/basic_usage/example_visual_tour.py @@ -16,8 +16,15 @@ from matplotlib.axes import Axes from matplotlib.figure import Figure from torch import nn - -from curvlinops import EFLinearOperator, GGNLinearOperator, HessianLinearOperator +from tueplots import bundles + +from curvlinops import ( + EFLinearOperator, + FisherMCLinearOperator, + GGNLinearOperator, + HessianLinearOperator, + KFACLinearOperator, +) # make deterministic torch.manual_seed(0) @@ -61,6 +68,7 @@ num_params_layer = [ sum(p.numel() for p in child.parameters()) for child in model.children() ] +num_tensors_layer = [len(list(child.parameters())) for child in model.children()] loss_function = nn.CrossEntropyLoss(reduction="mean").to(DEVICE) @@ -82,6 +90,17 @@ ).to_scipy() GGN_linop = GGNLinearOperator(model, loss_function, 
params, dataloader).to_scipy() EF_linop = EFLinearOperator(model, loss_function, params, dataloader) +Hessian_blocked_linop = HessianLinearOperator( + model, + loss_function, + params, + dataloader, + block_sizes=[s for s in num_tensors_layer if s != 0], +).to_scipy() +F_linop = FisherMCLinearOperator(model, loss_function, params, dataloader) +KFAC_linop = KFACLinearOperator( + model, loss_function, params, dataloader, separate_weight_and_bias=False +) # %% # @@ -92,6 +111,9 @@ Hessian_mat = Hessian_linop @ identity GGN_mat = GGN_linop @ identity EF_mat = EF_linop @ identity +Hessian_blocked_mat = Hessian_blocked_linop @ identity +F_mat = F_linop @ identity +KFAC_mat = KFAC_linop @ identity # %% # Visualization @@ -99,11 +121,17 @@ # # We will show the matrix entries on a shared domain for better comparability. -matrices = [Hessian_mat, GGN_mat, EF_mat] -titles = ["Hessian", "GGN", "Empirical Fisher"] +matrices = [Hessian_mat, GGN_mat, EF_mat, Hessian_blocked_mat, F_mat, KFAC_mat] +titles = [ + "Hessian", + "Generalized Gauss-Newton", + "Empirical Fisher", + "Block-diagonal Hessian", + "Monte-Carlo Fisher", + "KFAC", +] -rows, columns = 1, 3 -img_width = 7 +rows, columns = 2, 3 def plot( @@ -123,21 +151,34 @@ def plot( min_value = min(transform(mat).min() for mat in matrices) max_value = max(transform(mat).max() for mat in matrices) - fig, axes = plt.subplots( - nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width) - ) + fig, axes = plt.subplots(nrows=rows, ncols=columns, sharex=True, sharey=True) + fig.supxlabel("Layer") + fig.supylabel("Layer") for idx, (ax, mat, title) in enumerate(zip(axes.flat, matrices, titles)): ax.set_title(title) img = ax.imshow(transform(mat), vmin=min_value, vmax=max_value) - # layer structure - for pos in numpy.cumsum(num_params_layer): + # layer blocks + boundaries = [0] + numpy.cumsum(num_params_layer).tolist() + for pos in boundaries: if pos not in [0, num_params]: - style = {"color": "w", "lw": 0.5, "ls": "--"} + style = {"color": "w", "lw": 0.5, "ls": "-"} ax.axhline(y=pos - 1, xmin=0, xmax=num_params - 1, **style) ax.axvline(x=pos - 1, ymin=0, ymax=num_params - 1, **style) + # label positions + label_positions = [ + (boundaries[layer_idx] + boundaries[layer_idx + 1]) / 2 + for layer_idx in range(len(boundaries) - 1) + if boundaries[layer_idx] != boundaries[layer_idx + 1] + ] + labels = [str(i + 1) for i in range(len(label_positions))] + ax.set_xticks(label_positions) + ax.set_xticklabels(labels) + ax.set_yticks(label_positions) + ax.set_yticklabels(labels) + # colorbar last = idx == len(matrices) - 1 if last: @@ -148,16 +189,21 @@ def plot( return fig, axes +# use `tueplots` to make the plot look pretty +plot_config = bundles.icml2024(column="full", nrows=1.5 * rows, ncols=columns) + # %% # # We will show their logarithmic absolute value: -def logabs(mat, epsilon=1e-5): - return numpy.log10(numpy.abs(mat) + epsilon) +def logabs(mat, epsilon=1e-6): + return numpy.log10(numpy.clip(numpy.abs(mat), a_min=epsilon, a_max=None)) -plot(logabs, transform_title="Logarithmic absolute entries") +with plt.rc_context(plot_config): + plot(logabs, transform_title="Logarithmic absolute entries") + plt.savefig("curvature_matrices_log_abs.pdf", bbox_inches="tight") # %% # @@ -168,7 +214,8 @@ def unchanged(mat): return mat -plot(unchanged, transform_title="Unaltered matrix entries") +with plt.rc_context(plot_config): + plot(unchanged, transform_title="Unaltered matrix entries") # %% # diff --git a/docs/rtd/conf.py b/docs/rtd/conf.py index 
bfbc193..53e9768 100644 --- a/docs/rtd/conf.py +++ b/docs/rtd/conf.py @@ -37,7 +37,7 @@ "sphinx.ext.autosectionlabel", "sphinx.ext.intersphinx", "sphinx_gallery.gen_gallery", - 'sphinx.ext.viewcode', # show source code links + "sphinx.ext.viewcode", # show source code links ] # -- Intersphinx config ----------------------------------------------------- @@ -48,6 +48,7 @@ "scipy": ("http://docs.scipy.org/doc/scipy/reference/", None), "numpy": ("http://docs.scipy.org/doc/numpy/", None), "matplotlib": ("https://matplotlib.org/stable/", None), + "curvlinops": ("https://curvlinops.readthedocs.io/en/latest/", None), } # -- Sphinx Gallery config --------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 1164203..4aac13e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ requires-python = ">=3.9" # Dependencies needed to run the tests. test = [ "matplotlib", + "tueplots", "coveralls", "pytest", "pytest-cov", @@ -84,6 +85,7 @@ docs = [ "matplotlib", "sphinx-gallery", "sphinx-rtd-theme", + "tueplots" ] ############################################################################### diff --git a/test/test__torch_base.py b/test/test__torch_base.py index dd4821f..744697a 100644 --- a/test/test__torch_base.py +++ b/test/test__torch_base.py @@ -5,7 +5,9 @@ from pytest import raises from torch import Tensor, zeros -from curvlinops._torch_base import PyTorchLinearOperator +from curvlinops._torch_base import CurvatureLinearOperator, PyTorchLinearOperator +from curvlinops.examples.functorch import functorch_gradient_and_loss +from curvlinops.utils import allclose_report def test_input_formatting(): @@ -80,3 +82,28 @@ def test_preserve_input_format(): X = zeros(26, 6) # matrix in tensor format IdX = Id @ X assert IdX.allclose(X) + + +def test_gradient_and_loss(case): + """Test the gradient and loss computation over a data loader.""" + model, loss_func, params, data, batch_size_fn = case + + linop = CurvatureLinearOperator( + model, + loss_func, + params, + data, + # turn off because this would trigger the un-implemented `matmat` + check_deterministic=False, + batch_size_fn=batch_size_fn, + ) + gradient, loss = linop.gradient_and_loss() + + gradient_functorch, loss_functorch = functorch_gradient_and_loss( + model, loss_func, params, data, input_key="x" + ) + + assert allclose_report(loss, loss_functorch) + assert len(gradient) == len(gradient_functorch) + for g, g_functorch in zip(gradient, gradient_functorch): + assert allclose_report(g, g_functorch) diff --git a/test/test_fisher.py b/test/test_fisher.py index 4fa598e..3bfa332 100644 --- a/test/test_fisher.py +++ b/test/test_fisher.py @@ -1,7 +1,7 @@ """Contains tests for ``curvlinops/fisher.py``.""" from collections.abc import MutableMapping -from contextlib import suppress +from contextlib import redirect_stdout, suppress from numpy import random, zeros_like from pytest import mark, raises @@ -10,11 +10,11 @@ from curvlinops.examples.functorch import functorch_ggn from curvlinops.examples.utils import report_nonclose -MAX_REPEATS_MC_SAMPLES = [(1_000_000, 1), (10_000, 100)] +MAX_REPEATS_MC_SAMPLES = [(10_000, 1), (100, 100)] MAX_REPEATS_MC_SAMPLES_IDS = [ f"max_repeats={n}-mc_samples={m}" for (n, m) in MAX_REPEATS_MC_SAMPLES ] -CHECK_EVERY = 1_000 +CHECK_EVERY = 100 @mark.montecarlo @@ -58,7 +58,8 @@ def test_LinearOperator_matvec_expectation( Gx = G_functorch @ x Fx = zeros_like(x) - atol, rtol = 1e-5, 1e-1 + atol = 5e-3 * max(abs(Gx)) + rtol = 1e-1 for m in range(max_repeats): Fx += F @ x @@ -66,9 
+67,8 @@ def test_LinearOperator_matvec_expectation( total_samples = (m + 1) * mc_samples if total_samples % CHECK_EVERY == 0: - with suppress(ValueError): + with redirect_stdout(None), suppress(ValueError): report_nonclose(Fx / (m + 1), Gx, rtol=rtol, atol=atol) - print(f"Converged after {m} iterations") return report_nonclose(Fx / max_repeats, Gx, rtol=rtol, atol=atol) @@ -104,7 +104,8 @@ def test_LinearOperator_matmat_expectation( GX = G_functorch @ X FX = zeros_like(X) - atol, rtol = 1e-5, 1e-1 + atol = 5e-3 * max(abs(GX.flatten())) + rtol = 1.5e-1 for m in range(max_repeats): FX += F @ X @@ -112,9 +113,8 @@ def test_LinearOperator_matmat_expectation( total_samples = (m + 1) * mc_samples if total_samples % CHECK_EVERY == 0: - with suppress(ValueError): + with redirect_stdout(None), suppress(ValueError): report_nonclose(FX / (m + 1), GX, rtol=rtol, atol=atol) - print(f"Converged after {m} iterations") return - report_nonclose(FX, GX, rtol=rtol, atol=atol) + report_nonclose(FX / max_repeats, GX, rtol=rtol, atol=atol) diff --git a/test/test_hessian.py b/test/test_hessian.py index c6efe03..0727a4d 100644 --- a/test/test_hessian.py +++ b/test/test_hessian.py @@ -24,9 +24,9 @@ def test_HessianLinearOperator( Args: case: Tuple of model, loss function, parameters, data, and batch size getter. adjoint: Whether to test the adjoint operator. - is_vec: Whether to test matrix-vector or matrix-matrix multiplication. block_sizes_fn: The function that generates the block sizes used to define block diagonal approximations from the parameters. + is_vec: Whether to test matrix-vector or matrix-matrix multiplication. """ model_func, loss_func, params, data, batch_size_fn = case block_sizes = block_sizes_fn(params) diff --git a/test/test_kfac.py b/test/test_kfac.py index e7f8d61..8b2a7f7 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -16,13 +16,23 @@ from einops import rearrange from einops.layers.torch import Rearrange -from numpy import eye +from numpy import eye, random from numpy.linalg import det, norm, slogdet from pytest import mark, raises, skip from scipy.linalg import block_diag from torch import Tensor, allclose, cat, cuda, device from torch import eye as torch_eye -from torch import isinf, isnan, load, manual_seed, rand, rand_like, randperm, save +from torch import ( + float64, + isinf, + isnan, + load, + manual_seed, + rand, + rand_like, + randperm, + save, +) from torch.nn import ( BCEWithLogitsLoss, CrossEntropyLoss, @@ -1305,3 +1315,35 @@ def test_string_in_enum(fisher_type: str, kfac_approx: str): fisher_type=fisher_type, kfac_approx=kfac_approx, ) + + +@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) +def test_bug_132_dtype_deterministic_checks(dev: device): + """Test whether the vectors used in the deterministic checks have correct data type. + + This bug was reported in https://github.com/f-dangel/curvlinops/issues/132. + + Args: + dev: The device to run the test on. + """ + # make deterministic + manual_seed(0) + random.seed(0) + + # create a toy problem, load everything to float64 + dt = float64 + N = 4 + D_in = 3 + D_out = 2 + + X = rand(N, D_in, dtype=dt, device=dev) + y = rand(N, D_out, dtype=dt, device=dev) + data = [(X, y)] + + model = Linear(D_in, D_out).to(dev, dt) + params = [p for p in model.parameters() if p.requires_grad] + + loss_func = MSELoss().to(dev, dt) + + # run deterministic checks + KFACLinearOperator(model, loss_func, params, data, check_deterministic=True)
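
Side note on the dtype handling that several hunks above revolve around (the switch to `as_tensor` in `_preprocess`, the bfloat16 cast in `_postprocess`, and the float64 deterministic-check test for issue 132): the following standalone sketch is illustrative only and not part of the diff. The variable names are made up, but the calls it uses (`torch.as_tensor`, `Tensor.float`, `Tensor.numpy`) are exactly the ones the changed code relies on.

import numpy as np
import torch

# SciPy hands the operator NumPy data (float32 here, matching `self.dtype`).
M = np.random.rand(6, 2).astype(np.float32)

# Suppose the model parameters live in a different dtype, e.g. bfloat16 or float64.
param_dtype = torch.bfloat16

# `torch.from_numpy(M)` would keep NumPy's dtype; `as_tensor` casts to the
# parameter dtype in one step, mirroring what the new `_preprocess` does.
M_torch = torch.as_tensor(M, dtype=param_dtype)

# Stand-in for the operator's matrix-matrix product output.
result = M_torch

# `.numpy()` is not implemented for bfloat16 tensors
# (https://github.com/pytorch/pytorch/issues/90574), so upcast first and then
# convert back to the SciPy-side dtype, mirroring what the new `_postprocess` does.
result = result.float() if result.dtype == torch.bfloat16 else result
back = result.cpu().numpy().astype(M.dtype)
print(back.dtype)  # float32 again, as the SciPy interface expects

In the linear operators themselves, the target dtype is the parameters' dtype, which the added `(self._torch_dtype,) = {p.dtype for p in self._params}` line infers.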