From f3d4bd92ecb9ab683e138f466c423aa5d6619e6c Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Mon, 17 Jul 2023 15:30:45 -0400 Subject: [PATCH 01/10] [ADD] Implement and test `_adjoint` --- curvlinops/fisher.py | 12 ++++++++++++ curvlinops/ggn.py | 12 ++++++++++++ curvlinops/gradient_moments.py | 12 ++++++++++++ curvlinops/hessian.py | 12 ++++++++++++ curvlinops/outer.py | 12 ++++++++++++ test/cases.py | 2 ++ test/conftest.py | 7 ++++++- test/test_fisher.py | 18 ++++++++++++------ test/test_ggn.py | 24 ++++++++++++++---------- test/test_gradient_moments.py | 24 ++++++++++++++---------- test/test_hessian.py | 24 ++++++++++++++---------- 11 files changed, 122 insertions(+), 37 deletions(-) diff --git a/curvlinops/fisher.py b/curvlinops/fisher.py index adfad2d..86264f9 100644 --- a/curvlinops/fisher.py +++ b/curvlinops/fisher.py @@ -1,5 +1,7 @@ """Contains LinearOperator implementation of the (approximate) Fisher.""" +from __future__ import annotations + from math import sqrt from typing import Callable, Iterable, List, Tuple, Union @@ -270,3 +272,13 @@ def sample_grad_output(self, output: Tensor) -> Tensor: else: raise NotImplementedError(f"Supported losses: {self.supported_losses}") + + def _adjoint(self) -> FisherMCLinearOperator: + """Return the linear operator representing the adjoint. + + The Fisher MC-approximation is real symmetric, and hence self-adjoint. + + Returns: + Self. + """ + return self diff --git a/curvlinops/ggn.py b/curvlinops/ggn.py index 4b542fc..8e901bd 100644 --- a/curvlinops/ggn.py +++ b/curvlinops/ggn.py @@ -1,5 +1,7 @@ """Contains LinearOperator implementation of the GGN.""" +from __future__ import annotations + from typing import List, Tuple from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist @@ -54,3 +56,13 @@ def _matvec_batch( output = self._model_func(X) loss = self._loss_func(output, y) return ggn_vector_product_from_plist(loss, output, self._params, x_list) + + def _adjoint(self) -> GGNLinearOperator: + """Return the linear operator representing the adjoint. + + The GGN is real symmetric, and hence self-adjoint. + + Returns: + Self. + """ + return self diff --git a/curvlinops/gradient_moments.py b/curvlinops/gradient_moments.py index 2b25d72..7dbd0ea 100644 --- a/curvlinops/gradient_moments.py +++ b/curvlinops/gradient_moments.py @@ -1,5 +1,7 @@ """Contains LinearOperator implementation of gradient moment matrices.""" +from __future__ import annotations + from typing import List, Tuple from torch import Tensor, autograd, einsum, zeros_like @@ -76,3 +78,13 @@ def _matvec_batch( raise ValueError("Loss must have reduction 'mean' or 'sum'.") return tuple(r / normalization for r in result_list) + + def _adjoint(self) -> EFLinearOperator: + """Return the linear operator representing the adjoint. + + The empirical Fisher is real symmetric, and hence self-adjoint. + + Returns: + Self. + """ + return self diff --git a/curvlinops/hessian.py b/curvlinops/hessian.py index 6a932c6..de1b44e 100644 --- a/curvlinops/hessian.py +++ b/curvlinops/hessian.py @@ -1,5 +1,7 @@ """Contains LinearOperator implementation of the Hessian.""" +from __future__ import annotations + from typing import List, Tuple from backpack.hessianfree.hvp import hessian_vector_product @@ -45,3 +47,13 @@ def _matvec_batch( """ loss = self._loss_func(self._model_func(X), y) return hessian_vector_product(loss, self._params, x_list) + + def _adjoint(self) -> HessianLinearOperator: + """Return the linear operator representing the adjoint. 
+ + The Hessian is real symmetric, and hence self-adjoint. + + Returns: + Self. + """ + return self diff --git a/curvlinops/outer.py b/curvlinops/outer.py index eec6a3e..a87c676 100644 --- a/curvlinops/outer.py +++ b/curvlinops/outer.py @@ -1,5 +1,7 @@ """Utility linear operators.""" +from __future__ import annotations + from numpy import einsum, einsum_path, ndarray, ones from scipy.sparse.linalg import LinearOperator @@ -42,6 +44,16 @@ def _matvec(self, x: ndarray) -> ndarray: """ return einsum(self._equation, *self._operands, x, optimize=self._path) + def _adjoint(self) -> OuterProductLinearOperator: + """Return the linear operator representing the adjoint. + + An outer product is self-adjoint. + + Returns: + Self. + """ + return self + class Projector(OuterProductLinearOperator): """Linear operator for the projector onto the orthonormal basis ``{ aᵢ }``.""" diff --git a/test/cases.py b/test/cases.py index bc475fd..071b332 100644 --- a/test/cases.py +++ b/test/cases.py @@ -171,3 +171,5 @@ def data(): for device in DEVICES: case_with_device = {**case, "device": device} NON_DETERMINISTIC_CASES.append(case_with_device) + +ADJOINT_CASES = [False, True] diff --git a/test/conftest.py b/test/conftest.py index 5ab5b2b..408074e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,6 @@ """Contains pytest fixtures that are visible by other files.""" -from test.cases import CASES, NON_DETERMINISTIC_CASES +from test.cases import ADJOINT_CASES, CASES, NON_DETERMINISTIC_CASES from typing import Callable, Dict, Iterable, List, Tuple from numpy import random @@ -51,3 +51,8 @@ def non_deterministic_case( ]: case = request.param yield initialize_case(case) + + +@fixture(params=ADJOINT_CASES) +def adjoint(request) -> bool: + return request.param diff --git a/test/test_fisher.py b/test/test_fisher.py index 7fafe09..c0f4e97 100644 --- a/test/test_fisher.py +++ b/test/test_fisher.py @@ -20,11 +20,15 @@ @mark.parametrize( "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS ) -def test_LinearOperator_matvec_expectation(case, max_repeats: int, mc_samples: int): +def test_LinearOperator_matvec_expectation( + case, adjoint: bool, max_repeats: int, mc_samples: int +): F = FisherMCLinearOperator(*case, mc_samples=mc_samples) - x = random.rand(F.shape[1]).astype(F.dtype) - G_functorch = functorch_ggn(*case).detach().cpu().numpy() + if adjoint: + F, G_functorch = F.adjoint(), G_functorch.conj().T + + x = random.rand(F.shape[1]).astype(F.dtype) Gx = G_functorch @ x Fx = zeros_like(x) @@ -49,12 +53,14 @@ def test_LinearOperator_matvec_expectation(case, max_repeats: int, mc_samples: i "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS ) def test_LinearOperator_matmat_expectation( - case, max_repeats: int, mc_samples: int, num_vecs: int = 2 + case, adjoint: bool, max_repeats: int, mc_samples: int, num_vecs: int = 2 ): F = FisherMCLinearOperator(*case, mc_samples=mc_samples) - X = random.rand(F.shape[1], num_vecs).astype(F.dtype) - G_functorch = functorch_ggn(*case).detach().cpu().numpy() + if adjoint: + F, G_functorch = F.adjoint(), G_functorch.conj().T + + X = random.rand(F.shape[1], num_vecs).astype(F.dtype) GX = G_functorch @ X FX = zeros_like(X) diff --git a/test/test_ggn.py b/test/test_ggn.py index 72721f2..3d58f3d 100644 --- a/test/test_ggn.py +++ b/test/test_ggn.py @@ -7,25 +7,29 @@ from curvlinops.examples.utils import report_nonclose -def test_GGNLinearOperator_matvec(case): +def test_GGNLinearOperator_matvec(case, adjoint: bool): model_func, 
loss_func, params, data = case - GGN = GGNLinearOperator(model_func, loss_func, params, data) - GGN_functorch = ( + op = GGNLinearOperator(model_func, loss_func, params, data) + op_functorch = ( functorch_ggn(model_func, loss_func, params, data).detach().cpu().numpy() ) + if adjoint: + op, op_functorch = op.adjoint(), op_functorch.conj().T - x = random.rand(GGN.shape[1]) - report_nonclose(GGN @ x, GGN_functorch @ x) + x = random.rand(op.shape[1]) + report_nonclose(op @ x, op_functorch @ x) -def test_GGNLinearOperator_matmat(case, num_vecs: int = 3): +def test_GGNLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3): model_func, loss_func, params, data = case - GGN = GGNLinearOperator(model_func, loss_func, params, data) - GGN_functorch = ( + op = GGNLinearOperator(model_func, loss_func, params, data) + op_functorch = ( functorch_ggn(model_func, loss_func, params, data).detach().cpu().numpy() ) + if adjoint: + op, op_functorch = op.adjoint(), op_functorch.conj().T - X = random.rand(GGN.shape[1], num_vecs) - report_nonclose(GGN @ X, GGN_functorch @ X) + X = random.rand(op.shape[1], num_vecs) + report_nonclose(op @ X, op_functorch @ X) diff --git a/test/test_gradient_moments.py b/test/test_gradient_moments.py index f123372..0c97da0 100644 --- a/test/test_gradient_moments.py +++ b/test/test_gradient_moments.py @@ -7,17 +7,21 @@ from curvlinops.examples.utils import report_nonclose -def test_EFLinearOperator_matvec(case): - EF = EFLinearOperator(*case) - EF_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy() +def test_EFLinearOperator_matvec(case, adjoint: bool): + op = EFLinearOperator(*case) + op_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy() + if adjoint: + op, op_functorch = op.adjoint(), op_functorch.conj().T - x = random.rand(EF.shape[1]).astype(EF.dtype) - report_nonclose(EF @ x, EF_functorch @ x) + x = random.rand(op.shape[1]).astype(op.dtype) + report_nonclose(op @ x, op_functorch @ x) -def test_EFLinearOperator_matmat(case, num_vecs: int = 3): - EF = EFLinearOperator(*case) - EF_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy() +def test_EFLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3): + op = EFLinearOperator(*case) + op_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy() + if adjoint: + op, op_functorch = op.adjoint(), op_functorch.conj().T - X = random.rand(EF.shape[1], num_vecs).astype(EF.dtype) - report_nonclose(EF @ X, EF_functorch @ X, atol=1e-7, rtol=1e-4) + X = random.rand(op.shape[1], num_vecs).astype(op.dtype) + report_nonclose(op @ X, op_functorch @ X, atol=1e-7, rtol=1e-4) diff --git a/test/test_hessian.py b/test/test_hessian.py index 91bb00e..5012b69 100644 --- a/test/test_hessian.py +++ b/test/test_hessian.py @@ -7,17 +7,21 @@ from curvlinops.examples.utils import report_nonclose -def test_HessianLinearOperator_matvec(case): - H = HessianLinearOperator(*case) - H_functorch = functorch_hessian(*case).detach().cpu().numpy() +def test_HessianLinearOperator_matvec(case, adjoint: bool): + op = HessianLinearOperator(*case) + op_functorch = functorch_hessian(*case).detach().cpu().numpy() + if adjoint: + op, op_functorch = op.adjoint(), op_functorch.conj().T - x = random.rand(H.shape[1]) - report_nonclose(H @ x, H_functorch @ x) + x = random.rand(op.shape[1]) + report_nonclose(op @ x, op_functorch @ x) -def test_HessianLinearOperator_matmat(case, num_vecs: int = 3): - H = HessianLinearOperator(*case) - H_functorch = functorch_hessian(*case).detach().cpu().numpy() +def 
test_HessianLinearOperator_matmat(case, adjoint: bool, num_vecs: int = 3): + op = HessianLinearOperator(*case) + op_functorch = functorch_hessian(*case).detach().cpu().numpy() + if adjoint: + op, op_functorch = op.adjoint(), op_functorch.conj().T - X = random.rand(H.shape[1], num_vecs) - report_nonclose(H @ X, H_functorch @ X, atol=1e-6, rtol=5e-4) + X = random.rand(op.shape[1], num_vecs) + report_nonclose(op @ X, op_functorch @ X, atol=1e-6, rtol=5e-4) From 783e6444428ccdb7d903d145071dda1cd0ef1f57 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 18 Jul 2023 09:34:19 -0400 Subject: [PATCH 02/10] [FIX] Use column dimension to create random vector --- curvlinops/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/curvlinops/_base.py b/curvlinops/_base.py index 102faee..f97bc00 100644 --- a/curvlinops/_base.py +++ b/curvlinops/_base.py @@ -154,7 +154,7 @@ def _check_deterministic(self): self.print_nonclose(grad1, grad2, rtol, atol) raise RuntimeError("Check for deterministic gradient failed.") - v = rand(self.shape[0]).astype(self.dtype) + v = rand(self.shape[1]).astype(self.dtype) mat_v1 = self @ v mat_v2 = self @ v From c68e3eee0bae2b754b8f221c0c662bdd7578d58d Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 18 Jul 2023 09:36:45 -0400 Subject: [PATCH 03/10] [ADD] Implement model Jacobian --- curvlinops/__init__.py | 2 + curvlinops/examples/functorch.py | 66 ++++++++-- curvlinops/jacobian.py | 200 +++++++++++++++++++++++++++++++ docs/rtd/linops.rst | 6 + test/test_jacobian.py | 27 +++++ 5 files changed, 289 insertions(+), 12 deletions(-) create mode 100644 curvlinops/jacobian.py create mode 100644 test/test_jacobian.py diff --git a/curvlinops/__init__.py b/curvlinops/__init__.py index 7f7bd36..cccfca4 100644 --- a/curvlinops/__init__.py +++ b/curvlinops/__init__.py @@ -5,6 +5,7 @@ from curvlinops.gradient_moments import EFLinearOperator from curvlinops.hessian import HessianLinearOperator from curvlinops.inverse import CGInverseLinearOperator, NeumannInverseLinearOperator +from curvlinops.jacobian import JacobianLinearOperator from curvlinops.papyan2020traces.spectrum import ( LanczosApproximateLogSpectrumCached, LanczosApproximateSpectrumCached, @@ -18,6 +19,7 @@ "GGNLinearOperator", "EFLinearOperator", "FisherMCLinearOperator", + "JacobianLinearOperator", "CGInverseLinearOperator", "NeumannInverseLinearOperator", "SubmatrixLinearOperator", diff --git a/curvlinops/examples/functorch.py b/curvlinops/examples/functorch.py index 3e38854..36d43dd 100644 --- a/curvlinops/examples/functorch.py +++ b/curvlinops/examples/functorch.py @@ -5,6 +5,7 @@ from functorch import grad, hessian, jvp, make_functional, vmap from torch import Tensor, cat, einsum +from torch.func import jacrev from torch.nn import Module @@ -55,9 +56,7 @@ def functorch_hessian( model_fn, _ = make_functional(model_func) loss_fn, loss_fn_params = make_functional(loss_func) - # concatenate batches - X, y = list(zip(*list(data))) - X, y = cat(X), cat(y) + X, y = _concatenate_batches(data) def loss(X: Tensor, y: Tensor, params: Tuple[Tensor]) -> Tensor: """Compute the loss given a mini-batch and the neural network parameters. 
@@ -100,9 +99,7 @@ def functorch_ggn( model_fn, _ = make_functional(model_func) loss_fn, loss_fn_params = make_functional(loss_func) - # concatenate batches - X, y = list(zip(*list(data))) - X, y = cat(X), cat(y) + X, y = _concatenate_batches(data) def linearized_model( anchor: Tuple[Tensor], params: Tuple[Tensor], X: Tensor @@ -167,9 +164,7 @@ def functorch_gradient( model_fn, _ = make_functional(model_func) loss_fn, loss_fn_params = make_functional(loss_func) - # concatenate batches - X, y = list(zip(*list(data))) - X, y = cat(X), cat(y) + X, y = _concatenate_batches(data) def loss(X: Tensor, y: Tensor, params: Tuple[Tensor]) -> Tensor: """Compute the loss given a mini-batch and the neural network parameters. @@ -213,9 +208,7 @@ def functorch_empirical_fisher( model_fn, _ = make_functional(model_func) loss_fn, loss_fn_params = make_functional(loss_func) - # concatenate batches - X, y = list(zip(*list(data))) - X, y = cat(X), cat(y) + X, y = _concatenate_batches(data) # compute batched gradients def loss_n(X_n: Tensor, y_n: Tensor, params: List[Tensor]) -> Tensor: @@ -244,3 +237,52 @@ def loss_n(X_n: Tensor, y_n: Tensor, params: List[Tensor]) -> Tensor: raise ValueError("Cannot detect reduction method from loss function.") return 1 / normalization * einsum("ni,nj->ij", batch_grad, batch_grad) + + +def functorch_jacobian( + model_func: Module, + params: List[Tensor], + data: Iterable[Tuple[Tensor, Tensor]], +) -> Tensor: + """Compute the Jacobian with functorch. + + Args: + model_func: A function that maps the mini-batch input X to predictions. + Could be a PyTorch module representing a neural network. + params: List of differentiable parameters used by the prediction function. + data: Source from which mini-batches can be drawn, for instance a list of + mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. + + Returns: + Matrix containing the Jacobian. Has shape ``[N * C, D]`` where ``D`` is the + total number of parameters, ``N`` the total number of data points, and ``C`` + the model's output space dimension. + """ + model_fn, _ = make_functional(model_func) + X, _ = _concatenate_batches(data) + + def model_fn_params_only(params: Tuple[Tensor]) -> Tensor: + return model_fn(params, X) + + # concatenate over flattened parameters and flattened outputs + jac = jacrev(model_fn_params_only)(params) + jac = [j.flatten(start_dim=-p.dim()) for j, p in zip(jac, params)] + jac = cat(jac, dim=-1).flatten(end_dim=-2) + + return jac + + +def _concatenate_batches( + data: Iterable[Tuple[Tensor, Tensor]] +) -> Tuple[Tensor, Tensor]: + """Concatenate all batches in the dataset along the batch dimension. + + Args: + data: A dataloader or iterable of batches. + + Returns: + Concatenated model inputs. + Concatenated targets. 
+ """ + X, y = list(zip(*list(data))) + return cat(X), cat(y) diff --git a/curvlinops/jacobian.py b/curvlinops/jacobian.py new file mode 100644 index 0000000..da490a4 --- /dev/null +++ b/curvlinops/jacobian.py @@ -0,0 +1,200 @@ +"""Implements linear operators for per-sample Jacobians.""" + +from typing import Callable, Iterable, List, Tuple + +from backpack.hessianfree.rop import jacobian_vector_product as jvp +from backpack.utils.convert_parameters import vector_to_parameter_list +from numpy import allclose, column_stack, float32, ndarray +from numpy.random import rand +from scipy.sparse.linalg import LinearOperator +from torch import Tensor, cat +from torch import device as torch_device +from torch import from_numpy, no_grad +from torch.nn import Module, Parameter +from tqdm import tqdm + +from curvlinops._base import _LinearOperator + + +class JacobianLinearOperator(LinearOperator): + """Linear operator for the Jacobian. + + Can be used with SciPy. + """ + + def __init__( + self, + model_func: Callable[[Tensor], Tensor], + params: List[Parameter], + data: Iterable[Tuple[Tensor, Tensor]], + progressbar: bool = False, + check_deterministic: bool = True, + ): + r"""Linear operator for the Jacobian as SciPy linear operator. + + Consider a model :math:`f(\mathbf{x}, \mathbf{\theta}): \mathbb{R}^M + \times \mathbb{R}^D \to \mathbb{R}^C` with parameters + :math:`\mathbf{\theta}` and input :math:`\mathbf{x}`. Assume we are + given a data set :math:`\mathcal{D} = \{ (\mathbf{x}_n, \mathbf{y}_n) + \}_{n=1}^N` of input-target pairs via batches. The model's Jacobian + :math:`\mathbf{J}_\mathbf{\theta}\mathbf{f}` is an :math:`NC \times D` + with elements + + .. math:: + \left[ + \mathbf{J}_\mathbf{\theta}\mathbf{f} + \right]_{(n,c), d} + = + \frac{\partial f(\mathbf{x}_n, \mathbf{\theta})}{\partial \theta_d}\,. + + Note that the data must be supplied in deterministic order. + + Args: + model_func: Neural network function. + params: Neural network parameters. + data: Iterable of batched input-target pairs. + progressbar: Show progress bar. + check_deterministic: Check if model and data are deterministic. + """ + num_data = sum(t.shape[0] for t, _ in data) + x = next(iter(data))[0] + num_outputs = model_func(x).shape[1:].numel() + num_params = sum(p.numel() for p in params) + super().__init__(shape=(num_data * num_outputs, num_params), dtype=float32) + + self._params = params + self._model_func = model_func + self._data = data + self._device = _LinearOperator._infer_device(self._params) + self._progressbar = progressbar + + if check_deterministic: + old_device = self._device + self.to_device(torch_device("cpu")) + try: + self._check_deterministic() + except RuntimeError as e: + raise e + finally: + self.to_device(old_device) + + def _check_deterministic(self): + """Verify that the linear operator is deterministic. + + - Checks that the data is loaded in a deterministic fashion (e.g. shuffling). + - Checks that the model is deterministic (e.g. dropout). + - Checks that matrix-vector multiplication with a single random vector is + deterministic. + + Note: + Deterministic checks are performed on CPU. We noticed that even when it + passes on CPU, it can fail on GPU; probably due to non-deterministic + operations. + + Raises: + RuntimeError: If the linear operator is not deterministic. 
+ """ + print("Performing deterministic checks") + + pred1, y1 = self.predictions_and_targets() + pred1, y1 = pred1.cpu().numpy(), y1.cpu().numpy() + pred2, y2 = self.predictions_and_targets() + pred2, y2 = pred2.cpu().numpy(), y2.cpu().numpy() + + rtol, atol = 5e-5, 1e-6 + + if not allclose(y1, y2, rtol=rtol, atol=atol): + _LinearOperator.print_nonclose(y1, y2, rtol=rtol, atol=atol) + raise RuntimeError( + "Data is not loaded in a deterministic fashion." + + " Make sure shuffling is turned off." + ) + if not allclose(pred1, pred2, rtol=rtol, atol=atol): + _LinearOperator.print_nonclose(pred1, pred2, rtol=rtol, atol=atol) + raise RuntimeError( + "Model predictions are not deterministic." + + " Make sure dropout and batch normalization are in eval mode." + ) + + v = rand(self.shape[1]).astype(self.dtype) + mat_v1 = self @ v + mat_v2 = self @ v + if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol): + _LinearOperator.print_nonclose(mat_v1, mat_v2, rtol, atol) + raise RuntimeError("Check for deterministic matvec failed.") + + def to_device(self, device: torch_device): + """Load linear operator to a device (inplace). + + Args: + device: Target device. + """ + self._device = device + + if isinstance(self._model_func, Module): + self._model_func = self._model_func.to(self._device) + self._params = [p.to(device) for p in self._params] + + def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]: + """Yield batches of the data set, loaded to the correct device. + + Yields: + Mini-batches ``(X, y)``. + """ + data_iter = iter(self._data) + + if self._progressbar: + data_iter = tqdm(data_iter, desc="matvec") + + for X, y in data_iter: + X, y = X.to(self._device), y.to(self._device) + yield (X, y) + + def predictions_and_targets(self) -> Tuple[Tensor, Tensor]: + """Return the batch-concatenated model predictions and labels. + + Returns: + Batch-concatenated model predictions of shape ``[N, *]`` where ``*`` + denotes the model's output shape (for instance ``* = C``). + Batch-concatenated labels of shape ``[N, *]``, where ``*`` denotes + the dimension of a label. + """ + total_pred, total_y = [], [] + + with no_grad(): + for X, y in self._loop_over_data(): + total_pred.append(self._model_func(X)) + total_y.append(y) + assert total_pred and total_y + + return cat(total_pred), cat(total_y) + + def _matvec(self, x: ndarray) -> ndarray: + """Loop over all batches in the data and apply the matrix to vector x. + + Args: + x: Vector for multiplication. Has shape ``[D]``. + + Returns: + Matrix-multiplication result ``mat @ x``. + """ + x_list = vector_to_parameter_list(from_numpy(x).to(self._device), self._params) + out_list = [ + jvp(self._model_func(X), self._params, x_list, retain_graph=False)[ + 0 + ].flatten(start_dim=1) + for X, _ in self._loop_over_data() + ] + + return cat(out_list).cpu().numpy() + + def _matmat(self, X: ndarray) -> ndarray: + """Matrix-matrix multiplication. + + Args: + X: Matrix for multiplication. + + Returns: + Matrix-multiplication result ``mat @ X``. + """ + return column_stack([self @ col for col in X.T]) diff --git a/docs/rtd/linops.rst b/docs/rtd/linops.rst index c630140..fb903e0 100644 --- a/docs/rtd/linops.rst +++ b/docs/rtd/linops.rst @@ -26,6 +26,12 @@ Uncentered gradient covariance (empirical Fisher) .. autoclass:: curvlinops.EFLinearOperator :members: __init__ +Jacobians +--------- + +.. 
autoclass:: curvlinops.JacobianLinearOperator
+   :members: __init__
+
 Inverses
 --------
 
diff --git a/test/test_jacobian.py b/test/test_jacobian.py
new file mode 100644
index 0000000..ecec1ae
--- /dev/null
+++ b/test/test_jacobian.py
@@ -0,0 +1,27 @@
+"""Contains tests for ``curvlinops/jacobian``."""
+
+from numpy import random
+
+from curvlinops import JacobianLinearOperator
+from curvlinops.examples.functorch import functorch_jacobian
+from curvlinops.examples.utils import report_nonclose
+
+
+def test_JacobianLinearOperator_matvec(case):
+    model_func, _, params, data = case
+
+    op = JacobianLinearOperator(model_func, params, data)
+    op_functorch = functorch_jacobian(model_func, params, data).detach().cpu().numpy()
+
+    x = random.rand(op.shape[1])
+    report_nonclose(op @ x, op_functorch @ x)
+
+
+def test_JacobianLinearOperator_matmat(case, num_vecs: int = 3):
+    model_func, _, params, data = case
+
+    op = JacobianLinearOperator(model_func, params, data)
+    op_functorch = functorch_jacobian(model_func, params, data).detach().cpu().numpy()
+
+    X = random.rand(op.shape[1], num_vecs)
+    report_nonclose(op @ X, op_functorch @ X)

From 69789b3ae0dbe48d3fa81e32b2feb13f5db3e54a Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Tue, 18 Jul 2023 09:38:51 -0400
Subject: [PATCH 04/10] [DOC] Mention Jacobian in README

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index eec1ac0..0111d52 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,7 @@ for deep learning matrices, such as
 - the Fisher/generalized Gauss-Newton (GGN)
 - the Monte-Carlo approximated Fisher
 - the uncentered gradient covariance (aka empirical Fisher)
+- the output-parameter Jacobian of a neural net
 
 Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU.
The library supports defining these matrices not only on a mini-batch, but From ff9fdd43ee80712f1c736517d126bedb28f8ca03 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 18 Jul 2023 09:45:12 -0400 Subject: [PATCH 05/10] [REQ] Deprecate python 3.7 to use functorch from inside torch --- .github/workflows/lint-black.yaml | 4 ++-- .github/workflows/lint-darglint.yaml | 4 ++-- .github/workflows/lint-flake8.yaml | 4 ++-- .github/workflows/lint-isort.yaml | 4 ++-- .github/workflows/lint-pydocstyle.yaml | 4 ++-- .github/workflows/test.yaml | 2 +- README.md | 2 +- black.toml | 2 +- setup.cfg | 4 +++- 9 files changed, 16 insertions(+), 14 deletions(-) diff --git a/.github/workflows/lint-black.yaml b/.github/workflows/lint-black.yaml index 60830db..296b889 100644 --- a/.github/workflows/lint-black.yaml +++ b/.github/workflows/lint-black.yaml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint-darglint.yaml b/.github/workflows/lint-darglint.yaml index 7efb71e..5ba0035 100644 --- a/.github/workflows/lint-darglint.yaml +++ b/.github/workflows/lint-darglint.yaml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint-flake8.yaml b/.github/workflows/lint-flake8.yaml index 6d3ed55..dbfa163 100644 --- a/.github/workflows/lint-flake8.yaml +++ b/.github/workflows/lint-flake8.yaml @@ -15,10 +15,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint-isort.yaml b/.github/workflows/lint-isort.yaml index d4a2b65..2b1344f 100644 --- a/.github/workflows/lint-isort.yaml +++ b/.github/workflows/lint-isort.yaml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint-pydocstyle.yaml b/.github/workflows/lint-pydocstyle.yaml index 53397eb..2383898 100644 --- a/.github/workflows/lint-pydocstyle.yaml +++ b/.github/workflows/lint-pydocstyle.yaml @@ -17,10 +17,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 69a4e48..be64df7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -18,7 +18,7 @@ jobs: USING_COVERAGE: '3.8' strategy: matrix: - python-version: ["3.7", "3.8"] + python-version: ["3.8"] steps: - uses: actions/checkout@v1 - uses: actions/setup-python@v1 diff --git a/README.md b/README.md index 0111d52..450c9f1 100644 --- 
a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Logo scipy linear operators of deep learning matrices in PyTorch [![Python -3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/) +3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/) ![tests](https://github.com/f-dangel/curvature-linear-operators/actions/workflows/test.yaml/badge.svg) [![Coveralls](https://coveralls.io/repos/github/f-dangel/curvlinops/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/curvlinops) diff --git a/black.toml b/black.toml index 911dcf1..94efa1b 100644 --- a/black.toml +++ b/black.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 88 -target-version = ['py36', 'py37', 'py38'] +target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' exclude = ''' ( diff --git a/setup.cfg b/setup.cfg index 905bb11..38958e0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,8 @@ classifiers = Operating System :: OS Independent Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 [options] zip_safe = False @@ -37,7 +39,7 @@ install_requires = scipy>=1.7.1,<2.0.0 tqdm>=4.61.0,<5.0.0 # Require a specific Python version, e.g. Python 2.7 or >= 3.4 -python_requires = >=3.7 +python_requires = >=3.8 ############################################################################### # Development dependencies # From 08ca0b89cd0948a8bc4494081c176b7be39643eb Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 18 Jul 2023 09:46:12 -0400 Subject: [PATCH 06/10] [DOC] Use python 3.8 to build the docs --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 0807394..4d7c441 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,7 +7,7 @@ sphinx: configuration: docs/rtd/conf.py python: - version: 3.7 + version: 3.8 install: - method: pip path: . From 8fba1159f3c41716d4a8dc5b56d353fba1c43a3b Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 18 Jul 2023 09:51:15 -0400 Subject: [PATCH 07/10] [FIX] darglint --- curvlinops/jacobian.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/curvlinops/jacobian.py b/curvlinops/jacobian.py index da490a4..bb0e905 100644 --- a/curvlinops/jacobian.py +++ b/curvlinops/jacobian.py @@ -55,6 +55,9 @@ def __init__( data: Iterable of batched input-target pairs. progressbar: Show progress bar. check_deterministic: Check if model and data are deterministic. + + Raises: + RuntimeError: If deterministic checks are enables and fail. """ num_data = sum(t.shape[0] for t, _ in data) x = next(iter(data))[0] From 96c805e32ff33cc795127c8daaf7a87ef0db1b63 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 18 Jul 2023 09:51:22 -0400 Subject: [PATCH 08/10] [REQ] Use torch>=2 for built-in functorch --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 38958e0..14c81aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,7 @@ setup_requires = # Dependencies of the project (semicolon/line-separated): install_requires = backpack-for-pytorch>=1.5.0,<2.0.0 + torch>=2.0 scipy>=1.7.1,<2.0.0 tqdm>=4.61.0,<5.0.0 # Require a specific Python version, e.g. 
Python 2.7 or >= 3.4 From a4d668a8228276529e44ec4cf90b31e73d87b071 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 18 Jul 2023 15:14:12 -0400 Subject: [PATCH 09/10] [REF] Implement Jacobian via base class, remove print statements --- curvlinops/_base.py | 87 +++++++++++++++-------- curvlinops/jacobian.py | 152 ++++++++++------------------------------- 2 files changed, 93 insertions(+), 146 deletions(-) diff --git a/curvlinops/_base.py b/curvlinops/_base.py index f97bc00..7f332ee 100644 --- a/curvlinops/_base.py +++ b/curvlinops/_base.py @@ -1,6 +1,6 @@ """Contains functionality to analyze Hessian & GGN via matrix-free multiplication.""" -from typing import Callable, Iterable, List, Tuple +from typing import Callable, Iterable, List, Optional, Tuple, Union from backpack.utils.convert_parameters import vector_to_parameter_list from numpy import ( @@ -14,17 +14,16 @@ ) from numpy.random import rand from scipy.sparse.linalg import LinearOperator -from torch import Tensor +from torch import Tensor, cat from torch import device as torch_device from torch import from_numpy, tensor, zeros_like from torch.autograd import grad from torch.nn import Module, Parameter -from torch.nn.utils import parameters_to_vector from tqdm import tqdm class _LinearOperator(LinearOperator): - """Base class for linear operators of DNN curvature matrices. + """Base class for linear operators of DNN matrices. Can be used with SciPy. """ @@ -32,13 +31,14 @@ class _LinearOperator(LinearOperator): def __init__( self, model_func: Callable[[Tensor], Tensor], - loss_func: Callable[[Tensor, Tensor], Tensor], + loss_func: Union[Callable[[Tensor, Tensor], Tensor], None], params: List[Parameter], data: Iterable[Tuple[Tensor, Tensor]], progressbar: bool = False, check_deterministic: bool = True, + shape: Optional[Tuple[int, int]] = None, ): - """Linear operator for DNN curvature matrices. + """Linear operator for DNN matrices. Note: f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch @@ -49,10 +49,13 @@ def __init__( model_func: A function that maps the mini-batch input X to predictions. Could be a PyTorch module representing a neural network. loss_func: Loss function criterion. Maps predictions and mini-batch labels - to a scalar value. + to a scalar value. If ``None``, there is no loss function and the + represented matrix is independent of the loss function. params: List of differentiable parameters used by the prediction function. data: Source from which mini-batches can be drawn, for instance a list of mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. + shape: Shape of the represented matrix. If ``None`` assumes ``(D, D)`` + where ``D`` is the total number of parameters progressbar: Show a progressbar during matrix-multiplication. Default: ``False``. check_deterministic: Probe that model and data are deterministic, i.e. @@ -64,8 +67,10 @@ def __init__( Raises: RuntimeError: If the check for deterministic behavior fails. 
""" - dim = sum(p.numel() for p in params) - super().__init__(shape=(dim, dim), dtype=float32) + if shape is None: + dim = sum(p.numel() for p in params) + shape = (dim, dim) + super().__init__(shape=shape, dtype=float32) self._params = params self._model_func = model_func @@ -74,7 +79,7 @@ def __init__( self._device = self._infer_device(self._params) self._progressbar = progressbar - self._N_data = sum(X.shape[0] for (X, _) in self._loop_over_data()) + self._num_data = sum(X.shape[0] for (X, _) in self._loop_over_data()) if check_deterministic: old_device = self._device @@ -129,22 +134,37 @@ def _check_deterministic(self): - Two independent loss/gradient computations yield different results Note: - Deterministic checks are performed on CPU. We noticed that even when it - passes on CPU, it can fail on GPU; probably due to non-deterministic + Deterministic checks should be performed on CPU. We noticed that even when + it passes on CPU, it can fail on GPU; probably due to non-deterministic operations. Raises: RuntimeError: If non-deterministic behavior is detected. """ - print("Performing deterministic checks") + v = rand(self.shape[1]).astype(self.dtype) + mat_v1 = self @ v + mat_v2 = self @ v + + rtol, atol = 5e-5, 1e-6 + if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol): + self.print_nonclose(mat_v1, mat_v2, rtol, atol) + raise RuntimeError("Check for deterministic matvec failed.") + if self._loss_func is None: + return + + # only carried out if there is a loss function grad1, loss1 = self.gradient_and_loss() - grad1, loss1 = parameters_to_vector(grad1).cpu().numpy(), loss1.cpu().numpy() + grad1, loss1 = ( + self.flatten_and_concatenate(grad1).cpu().numpy(), + loss1.cpu().numpy(), + ) grad2, loss2 = self.gradient_and_loss() - grad2, loss2 = parameters_to_vector(grad2).cpu().numpy(), loss2.cpu().numpy() - - rtol, atol = 5e-5, 1e-6 + grad2, loss2 = ( + self.flatten_and_concatenate(grad2).cpu().numpy(), + loss2.cpu().numpy(), + ) if not allclose(loss1, loss2, rtol=rtol, atol=atol): self.print_nonclose(loss1, loss2, rtol, atol) @@ -154,16 +174,6 @@ def _check_deterministic(self): self.print_nonclose(grad1, grad2, rtol, atol) raise RuntimeError("Check for deterministic gradient failed.") - v = rand(self.shape[1]).astype(self.dtype) - mat_v1 = self @ v - mat_v2 = self @ v - - if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol): - self.print_nonclose(mat_v1, mat_v2, rtol, atol) - raise RuntimeError("Check for deterministic matvec failed.") - - print("Deterministic checks passed") - @staticmethod def print_nonclose(array1: ndarray, array2: ndarray, rtol: float, atol: float): """Check if the two arrays are element-wise equal within a tolerance and print @@ -245,8 +255,7 @@ def _preprocess(self, x: ndarray) -> List[Tensor]: x_torch = from_numpy(x).to(self._device) return vector_to_parameter_list(x_torch, self._params) - @staticmethod - def _postprocess(x_list: List[Tensor]) -> ndarray: + def _postprocess(self, x_list: List[Tensor]) -> ndarray: """Convert torch list format to flat numpy array. Args: @@ -255,7 +264,7 @@ def _postprocess(x_list: List[Tensor]) -> ndarray: Returns: Flat vector. """ - return parameters_to_vector([x.contiguous() for x in x_list]).cpu().numpy() + return self.flatten_and_concatenate(x_list).cpu().numpy() def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]: """Yield batches of the data set, loaded to the correct device. @@ -279,7 +288,13 @@ def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]: Returns: Gradient and loss on the data set. 
+ + Raises: + ValueError: If there is no loss function. """ + if self._loss_func is None: + raise ValueError("No loss function specified.") + total_loss = tensor([0.0], device=self._device) total_grad = [zeros_like(p) for p in self._params] @@ -317,3 +332,15 @@ def _get_normalization_factor(self, X: Tensor, y: Tensor) -> float: return X.shape[0] / self._N_data else: raise ValueError("Loss must have reduction 'mean' or 'sum'.") + + @staticmethod + def flatten_and_concatenate(tensors: List[Tensor]) -> Tensor: + """Flatten then concatenate all tensors in a list. + + Args: + tensors: List of tensors. + + Returns: + Concatenated flattened tensors. + """ + return cat([t.flatten() for t in tensors]) diff --git a/curvlinops/jacobian.py b/curvlinops/jacobian.py index bb0e905..58b18e6 100644 --- a/curvlinops/jacobian.py +++ b/curvlinops/jacobian.py @@ -3,20 +3,14 @@ from typing import Callable, Iterable, List, Tuple from backpack.hessianfree.rop import jacobian_vector_product as jvp -from backpack.utils.convert_parameters import vector_to_parameter_list -from numpy import allclose, column_stack, float32, ndarray -from numpy.random import rand -from scipy.sparse.linalg import LinearOperator -from torch import Tensor, cat -from torch import device as torch_device -from torch import from_numpy, no_grad -from torch.nn import Module, Parameter -from tqdm import tqdm +from numpy import allclose, ndarray +from torch import Tensor, no_grad +from torch.nn import Parameter from curvlinops._base import _LinearOperator -class JacobianLinearOperator(LinearOperator): +class JacobianLinearOperator(_LinearOperator): """Linear operator for the Jacobian. Can be used with SciPy. @@ -63,114 +57,51 @@ def __init__( x = next(iter(data))[0] num_outputs = model_func(x).shape[1:].numel() num_params = sum(p.numel() for p in params) - super().__init__(shape=(num_data * num_outputs, num_params), dtype=float32) - - self._params = params - self._model_func = model_func - self._data = data - self._device = _LinearOperator._infer_device(self._params) - self._progressbar = progressbar - - if check_deterministic: - old_device = self._device - self.to_device(torch_device("cpu")) - try: - self._check_deterministic() - except RuntimeError as e: - raise e - finally: - self.to_device(old_device) + super().__init__( + model_func, + None, + params, + data, + progressbar=progressbar, + check_deterministic=check_deterministic, + shape=(num_data * num_outputs, num_params), + ) def _check_deterministic(self): """Verify that the linear operator is deterministic. - - Checks that the data is loaded in a deterministic fashion (e.g. shuffling). - - Checks that the model is deterministic (e.g. dropout). - - Checks that matrix-vector multiplication with a single random vector is - deterministic. + In addition to the checks from the base class, checks that the model + predictions and data are always the same (loaded in the same order, and + only deterministic operations in the network. Note: - Deterministic checks are performed on CPU. We noticed that even when it - passes on CPU, it can fail on GPU; probably due to non-deterministic + Deterministic checks should be performed on CPU. We noticed that even when + it passes on CPU, it can fail on GPU; probably due to non-deterministic operations. Raises: RuntimeError: If the linear operator is not deterministic. 
""" - print("Performing deterministic checks") - - pred1, y1 = self.predictions_and_targets() - pred1, y1 = pred1.cpu().numpy(), y1.cpu().numpy() - pred2, y2 = self.predictions_and_targets() - pred2, y2 = pred2.cpu().numpy(), y2.cpu().numpy() + super()._check_deterministic() rtol, atol = 5e-5, 1e-6 - if not allclose(y1, y2, rtol=rtol, atol=atol): - _LinearOperator.print_nonclose(y1, y2, rtol=rtol, atol=atol) - raise RuntimeError( - "Data is not loaded in a deterministic fashion." - + " Make sure shuffling is turned off." - ) - if not allclose(pred1, pred2, rtol=rtol, atol=atol): - _LinearOperator.print_nonclose(pred1, pred2, rtol=rtol, atol=atol) - raise RuntimeError( - "Model predictions are not deterministic." - + " Make sure dropout and batch normalization are in eval mode." - ) - - v = rand(self.shape[1]).astype(self.dtype) - mat_v1 = self @ v - mat_v2 = self @ v - if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol): - _LinearOperator.print_nonclose(mat_v1, mat_v2, rtol, atol) - raise RuntimeError("Check for deterministic matvec failed.") - - def to_device(self, device: torch_device): - """Load linear operator to a device (inplace). - - Args: - device: Target device. - """ - self._device = device - - if isinstance(self._model_func, Module): - self._model_func = self._model_func.to(self._device) - self._params = [p.to(device) for p in self._params] - - def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]: - """Yield batches of the data set, loaded to the correct device. - - Yields: - Mini-batches ``(X, y)``. - """ - data_iter = iter(self._data) - - if self._progressbar: - data_iter = tqdm(data_iter, desc="matvec") - - for X, y in data_iter: - X, y = X.to(self._device), y.to(self._device) - yield (X, y) - - def predictions_and_targets(self) -> Tuple[Tensor, Tensor]: - """Return the batch-concatenated model predictions and labels. - - Returns: - Batch-concatenated model predictions of shape ``[N, *]`` where ``*`` - denotes the model's output shape (for instance ``* = C``). - Batch-concatenated labels of shape ``[N, *]``, where ``*`` denotes - the dimension of a label. - """ - total_pred, total_y = [], [] - with no_grad(): - for X, y in self._loop_over_data(): - total_pred.append(self._model_func(X)) - total_y.append(y) - assert total_pred and total_y - - return cat(total_pred), cat(total_y) + for (X1, y1), (X2, y2) in zip( + self._loop_over_data(), self._loop_over_data() + ): + pred1, y1 = self._model_func(X1).cpu().numpy(), y1.cpu().numpy() + pred2, y2 = self._model_func(X2).cpu().numpy(), y2.cpu().numpy() + X1, X2 = X1.cpu().numpy(), X2.cpu().numpy() + + if not allclose(X1, X2) or not allclose(y1, y2): + self.print_nonclose(X1, X2, rtol=rtol, atol=atol) + self.print_nonclose(y1, y2, rtol=rtol, atol=atol) + raise RuntimeError("Non-deterministic data loading detected.") + + if not allclose(pred1, pred2): + self.print_nonclose(pred1, pred2, rtol=rtol, atol=atol) + raise RuntimeError("Non-deterministic model detected.") def _matvec(self, x: ndarray) -> ndarray: """Loop over all batches in the data and apply the matrix to vector x. @@ -181,7 +112,7 @@ def _matvec(self, x: ndarray) -> ndarray: Returns: Matrix-multiplication result ``mat @ x``. 
""" - x_list = vector_to_parameter_list(from_numpy(x).to(self._device), self._params) + x_list = self._preprocess(x) out_list = [ jvp(self._model_func(X), self._params, x_list, retain_graph=False)[ 0 @@ -189,15 +120,4 @@ def _matvec(self, x: ndarray) -> ndarray: for X, _ in self._loop_over_data() ] - return cat(out_list).cpu().numpy() - - def _matmat(self, X: ndarray) -> ndarray: - """Matrix-matrix multiplication. - - Args: - X: Matrix for multiplication. - - Returns: - Matrix-multiplication result ``mat @ X``. - """ - return column_stack([self @ col for col in X.T]) + return self._postprocess(out_list) From 98009c1dc5b8ee4c630bbe994ccbc7c22eef63cb Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 18 Jul 2023 16:42:38 -0400 Subject: [PATCH 10/10] [FIX] Documentation and rename --- curvlinops/_base.py | 2 +- curvlinops/jacobian.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/curvlinops/_base.py b/curvlinops/_base.py index 7f332ee..e655eb9 100644 --- a/curvlinops/_base.py +++ b/curvlinops/_base.py @@ -79,7 +79,7 @@ def __init__( self._device = self._infer_device(self._params) self._progressbar = progressbar - self._num_data = sum(X.shape[0] for (X, _) in self._loop_over_data()) + self._N_data = sum(X.shape[0] for (X, _) in self._loop_over_data()) if check_deterministic: old_device = self._device diff --git a/curvlinops/jacobian.py b/curvlinops/jacobian.py index 58b18e6..65d4614 100644 --- a/curvlinops/jacobian.py +++ b/curvlinops/jacobian.py @@ -49,9 +49,6 @@ def __init__( data: Iterable of batched input-target pairs. progressbar: Show progress bar. check_deterministic: Check if model and data are deterministic. - - Raises: - RuntimeError: If deterministic checks are enables and fail. """ num_data = sum(t.shape[0] for t, _ in data) x = next(iter(data))[0]