diff --git a/.github/workflows/lint-black.yaml b/.github/workflows/lint-black.yaml index 60830db..296b889 100644 --- a/.github/workflows/lint-black.yaml +++ b/.github/workflows/lint-black.yaml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint-darglint.yaml b/.github/workflows/lint-darglint.yaml index 7efb71e..5ba0035 100644 --- a/.github/workflows/lint-darglint.yaml +++ b/.github/workflows/lint-darglint.yaml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint-flake8.yaml b/.github/workflows/lint-flake8.yaml index 6d3ed55..dbfa163 100644 --- a/.github/workflows/lint-flake8.yaml +++ b/.github/workflows/lint-flake8.yaml @@ -15,10 +15,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint-isort.yaml b/.github/workflows/lint-isort.yaml index d4a2b65..2b1344f 100644 --- a/.github/workflows/lint-isort.yaml +++ b/.github/workflows/lint-isort.yaml @@ -16,10 +16,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint-pydocstyle.yaml b/.github/workflows/lint-pydocstyle.yaml index 53397eb..2383898 100644 --- a/.github/workflows/lint-pydocstyle.yaml +++ b/.github/workflows/lint-pydocstyle.yaml @@ -17,10 +17,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 69a4e48..be64df7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -18,7 +18,7 @@ jobs: USING_COVERAGE: '3.8' strategy: matrix: - python-version: ["3.7", "3.8"] + python-version: ["3.8"] steps: - uses: actions/checkout@v1 - uses: actions/setup-python@v1 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 0807394..4d7c441 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,7 +7,7 @@ sphinx: configuration: docs/rtd/conf.py python: - version: 3.7 + version: 3.8 install: - method: pip path: . 
diff --git a/README.md b/README.md index eec1ac0..450c9f1 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Logo scipy linear operators of deep learning matrices in PyTorch [![Python -3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/) +3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/) ![tests](https://github.com/f-dangel/curvature-linear-operators/actions/workflows/test.yaml/badge.svg) [![Coveralls](https://coveralls.io/repos/github/f-dangel/curvlinops/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/curvlinops) @@ -13,6 +13,7 @@ for deep learning matrices, such as - the Fisher/generalized Gauss-Newton (GGN) - the Monte-Carlo approximated Fisher - the uncentered gradient covariance (aka empirical Fisher) +- the output-parameter Jacobian of a neural net Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU. The library supports defining these matrices not only on a mini-batch, but diff --git a/black.toml b/black.toml index 911dcf1..94efa1b 100644 --- a/black.toml +++ b/black.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 88 -target-version = ['py36', 'py37', 'py38'] +target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' exclude = ''' ( diff --git a/curvlinops/__init__.py b/curvlinops/__init__.py index 7f7bd36..cccfca4 100644 --- a/curvlinops/__init__.py +++ b/curvlinops/__init__.py @@ -5,6 +5,7 @@ from curvlinops.gradient_moments import EFLinearOperator from curvlinops.hessian import HessianLinearOperator from curvlinops.inverse import CGInverseLinearOperator, NeumannInverseLinearOperator +from curvlinops.jacobian import JacobianLinearOperator from curvlinops.papyan2020traces.spectrum import ( LanczosApproximateLogSpectrumCached, LanczosApproximateSpectrumCached, @@ -18,6 +19,7 @@ "GGNLinearOperator", "EFLinearOperator", "FisherMCLinearOperator", + "JacobianLinearOperator", "CGInverseLinearOperator", "NeumannInverseLinearOperator", "SubmatrixLinearOperator", diff --git a/curvlinops/_base.py b/curvlinops/_base.py index 102faee..e655eb9 100644 --- a/curvlinops/_base.py +++ b/curvlinops/_base.py @@ -1,6 +1,6 @@ """Contains functionality to analyze Hessian & GGN via matrix-free multiplication.""" -from typing import Callable, Iterable, List, Tuple +from typing import Callable, Iterable, List, Optional, Tuple, Union from backpack.utils.convert_parameters import vector_to_parameter_list from numpy import ( @@ -14,17 +14,16 @@ ) from numpy.random import rand from scipy.sparse.linalg import LinearOperator -from torch import Tensor +from torch import Tensor, cat from torch import device as torch_device from torch import from_numpy, tensor, zeros_like from torch.autograd import grad from torch.nn import Module, Parameter -from torch.nn.utils import parameters_to_vector from tqdm import tqdm class _LinearOperator(LinearOperator): - """Base class for linear operators of DNN curvature matrices. + """Base class for linear operators of DNN matrices. Can be used with SciPy. """ @@ -32,13 +31,14 @@ class _LinearOperator(LinearOperator): def __init__( self, model_func: Callable[[Tensor], Tensor], - loss_func: Callable[[Tensor, Tensor], Tensor], + loss_func: Union[Callable[[Tensor, Tensor], Tensor], None], params: List[Parameter], data: Iterable[Tuple[Tensor, Tensor]], progressbar: bool = False, check_deterministic: bool = True, + shape: Optional[Tuple[int, int]] = None, ): - """Linear operator for DNN curvature matrices. 
+ """Linear operator for DNN matrices. Note: f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch @@ -49,10 +49,13 @@ def __init__( model_func: A function that maps the mini-batch input X to predictions. Could be a PyTorch module representing a neural network. loss_func: Loss function criterion. Maps predictions and mini-batch labels - to a scalar value. + to a scalar value. If ``None``, there is no loss function and the + represented matrix is independent of the loss function. params: List of differentiable parameters used by the prediction function. data: Source from which mini-batches can be drawn, for instance a list of mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. + shape: Shape of the represented matrix. If ``None`` assumes ``(D, D)`` + where ``D`` is the total number of parameters progressbar: Show a progressbar during matrix-multiplication. Default: ``False``. check_deterministic: Probe that model and data are deterministic, i.e. @@ -64,8 +67,10 @@ def __init__( Raises: RuntimeError: If the check for deterministic behavior fails. """ - dim = sum(p.numel() for p in params) - super().__init__(shape=(dim, dim), dtype=float32) + if shape is None: + dim = sum(p.numel() for p in params) + shape = (dim, dim) + super().__init__(shape=shape, dtype=float32) self._params = params self._model_func = model_func @@ -129,22 +134,37 @@ def _check_deterministic(self): - Two independent loss/gradient computations yield different results Note: - Deterministic checks are performed on CPU. We noticed that even when it - passes on CPU, it can fail on GPU; probably due to non-deterministic + Deterministic checks should be performed on CPU. We noticed that even when + it passes on CPU, it can fail on GPU; probably due to non-deterministic operations. Raises: RuntimeError: If non-deterministic behavior is detected. 
""" - print("Performing deterministic checks") + v = rand(self.shape[1]).astype(self.dtype) + mat_v1 = self @ v + mat_v2 = self @ v + + rtol, atol = 5e-5, 1e-6 + if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol): + self.print_nonclose(mat_v1, mat_v2, rtol, atol) + raise RuntimeError("Check for deterministic matvec failed.") + + if self._loss_func is None: + return + # only carried out if there is a loss function grad1, loss1 = self.gradient_and_loss() - grad1, loss1 = parameters_to_vector(grad1).cpu().numpy(), loss1.cpu().numpy() + grad1, loss1 = ( + self.flatten_and_concatenate(grad1).cpu().numpy(), + loss1.cpu().numpy(), + ) grad2, loss2 = self.gradient_and_loss() - grad2, loss2 = parameters_to_vector(grad2).cpu().numpy(), loss2.cpu().numpy() - - rtol, atol = 5e-5, 1e-6 + grad2, loss2 = ( + self.flatten_and_concatenate(grad2).cpu().numpy(), + loss2.cpu().numpy(), + ) if not allclose(loss1, loss2, rtol=rtol, atol=atol): self.print_nonclose(loss1, loss2, rtol, atol) @@ -154,16 +174,6 @@ def _check_deterministic(self): self.print_nonclose(grad1, grad2, rtol, atol) raise RuntimeError("Check for deterministic gradient failed.") - v = rand(self.shape[0]).astype(self.dtype) - mat_v1 = self @ v - mat_v2 = self @ v - - if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol): - self.print_nonclose(mat_v1, mat_v2, rtol, atol) - raise RuntimeError("Check for deterministic matvec failed.") - - print("Deterministic checks passed") - @staticmethod def print_nonclose(array1: ndarray, array2: ndarray, rtol: float, atol: float): """Check if the two arrays are element-wise equal within a tolerance and print @@ -245,8 +255,7 @@ def _preprocess(self, x: ndarray) -> List[Tensor]: x_torch = from_numpy(x).to(self._device) return vector_to_parameter_list(x_torch, self._params) - @staticmethod - def _postprocess(x_list: List[Tensor]) -> ndarray: + def _postprocess(self, x_list: List[Tensor]) -> ndarray: """Convert torch list format to flat numpy array. Args: @@ -255,7 +264,7 @@ def _postprocess(x_list: List[Tensor]) -> ndarray: Returns: Flat vector. """ - return parameters_to_vector([x.contiguous() for x in x_list]).cpu().numpy() + return self.flatten_and_concatenate(x_list).cpu().numpy() def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]: """Yield batches of the data set, loaded to the correct device. @@ -279,7 +288,13 @@ def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]: Returns: Gradient and loss on the data set. + + Raises: + ValueError: If there is no loss function. """ + if self._loss_func is None: + raise ValueError("No loss function specified.") + total_loss = tensor([0.0], device=self._device) total_grad = [zeros_like(p) for p in self._params] @@ -317,3 +332,15 @@ def _get_normalization_factor(self, X: Tensor, y: Tensor) -> float: return X.shape[0] / self._N_data else: raise ValueError("Loss must have reduction 'mean' or 'sum'.") + + @staticmethod + def flatten_and_concatenate(tensors: List[Tensor]) -> Tensor: + """Flatten then concatenate all tensors in a list. + + Args: + tensors: List of tensors. + + Returns: + Concatenated flattened tensors. 
+ """ + return cat([t.flatten() for t in tensors]) diff --git a/curvlinops/examples/functorch.py b/curvlinops/examples/functorch.py index 3e38854..36d43dd 100644 --- a/curvlinops/examples/functorch.py +++ b/curvlinops/examples/functorch.py @@ -5,6 +5,7 @@ from functorch import grad, hessian, jvp, make_functional, vmap from torch import Tensor, cat, einsum +from torch.func import jacrev from torch.nn import Module @@ -55,9 +56,7 @@ def functorch_hessian( model_fn, _ = make_functional(model_func) loss_fn, loss_fn_params = make_functional(loss_func) - # concatenate batches - X, y = list(zip(*list(data))) - X, y = cat(X), cat(y) + X, y = _concatenate_batches(data) def loss(X: Tensor, y: Tensor, params: Tuple[Tensor]) -> Tensor: """Compute the loss given a mini-batch and the neural network parameters. @@ -100,9 +99,7 @@ def functorch_ggn( model_fn, _ = make_functional(model_func) loss_fn, loss_fn_params = make_functional(loss_func) - # concatenate batches - X, y = list(zip(*list(data))) - X, y = cat(X), cat(y) + X, y = _concatenate_batches(data) def linearized_model( anchor: Tuple[Tensor], params: Tuple[Tensor], X: Tensor @@ -167,9 +164,7 @@ def functorch_gradient( model_fn, _ = make_functional(model_func) loss_fn, loss_fn_params = make_functional(loss_func) - # concatenate batches - X, y = list(zip(*list(data))) - X, y = cat(X), cat(y) + X, y = _concatenate_batches(data) def loss(X: Tensor, y: Tensor, params: Tuple[Tensor]) -> Tensor: """Compute the loss given a mini-batch and the neural network parameters. @@ -213,9 +208,7 @@ def functorch_empirical_fisher( model_fn, _ = make_functional(model_func) loss_fn, loss_fn_params = make_functional(loss_func) - # concatenate batches - X, y = list(zip(*list(data))) - X, y = cat(X), cat(y) + X, y = _concatenate_batches(data) # compute batched gradients def loss_n(X_n: Tensor, y_n: Tensor, params: List[Tensor]) -> Tensor: @@ -244,3 +237,52 @@ def loss_n(X_n: Tensor, y_n: Tensor, params: List[Tensor]) -> Tensor: raise ValueError("Cannot detect reduction method from loss function.") return 1 / normalization * einsum("ni,nj->ij", batch_grad, batch_grad) + + +def functorch_jacobian( + model_func: Module, + params: List[Tensor], + data: Iterable[Tuple[Tensor, Tensor]], +) -> Tensor: + """Compute the Jacobian with functorch. + + Args: + model_func: A function that maps the mini-batch input X to predictions. + Could be a PyTorch module representing a neural network. + params: List of differentiable parameters used by the prediction function. + data: Source from which mini-batches can be drawn, for instance a list of + mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. + + Returns: + Matrix containing the Jacobian. Has shape ``[N * C, D]`` where ``D`` is the + total number of parameters, ``N`` the total number of data points, and ``C`` + the model's output space dimension. + """ + model_fn, _ = make_functional(model_func) + X, _ = _concatenate_batches(data) + + def model_fn_params_only(params: Tuple[Tensor]) -> Tensor: + return model_fn(params, X) + + # concatenate over flattened parameters and flattened outputs + jac = jacrev(model_fn_params_only)(params) + jac = [j.flatten(start_dim=-p.dim()) for j, p in zip(jac, params)] + jac = cat(jac, dim=-1).flatten(end_dim=-2) + + return jac + + +def _concatenate_batches( + data: Iterable[Tuple[Tensor, Tensor]] +) -> Tuple[Tensor, Tensor]: + """Concatenate all batches in the dataset along the batch dimension. + + Args: + data: A dataloader or iterable of batches. + + Returns: + Concatenated model inputs. 
+ Concatenated targets. + """ + X, y = list(zip(*list(data))) + return cat(X), cat(y) diff --git a/curvlinops/jacobian.py b/curvlinops/jacobian.py new file mode 100644 index 0000000..65d4614 --- /dev/null +++ b/curvlinops/jacobian.py @@ -0,0 +1,120 @@ +"""Implements linear operators for per-sample Jacobians.""" + +from typing import Callable, Iterable, List, Tuple + +from backpack.hessianfree.rop import jacobian_vector_product as jvp +from numpy import allclose, ndarray +from torch import Tensor, no_grad +from torch.nn import Parameter + +from curvlinops._base import _LinearOperator + + +class JacobianLinearOperator(_LinearOperator): + """Linear operator for the Jacobian. + + Can be used with SciPy. + """ + + def __init__( + self, + model_func: Callable[[Tensor], Tensor], + params: List[Parameter], + data: Iterable[Tuple[Tensor, Tensor]], + progressbar: bool = False, + check_deterministic: bool = True, + ): + r"""Linear operator for the Jacobian as SciPy linear operator. + + Consider a model :math:`f(\mathbf{x}, \mathbf{\theta}): \mathbb{R}^M + \times \mathbb{R}^D \to \mathbb{R}^C` with parameters + :math:`\mathbf{\theta}` and input :math:`\mathbf{x}`. Assume we are + given a data set :math:`\mathcal{D} = \{ (\mathbf{x}_n, \mathbf{y}_n) + \}_{n=1}^N` of input-target pairs via batches. The model's Jacobian + :math:`\mathbf{J}_\mathbf{\theta}\mathbf{f}` is an :math:`NC \times D` + matrix with elements + + .. math:: + \left[ + \mathbf{J}_\mathbf{\theta}\mathbf{f} + \right]_{(n,c), d} + = + \frac{\partial [f(\mathbf{x}_n, \mathbf{\theta})]_c}{\partial \theta_d}\,. + + Note that the data must be supplied in deterministic order. + + Args: + model_func: Neural network function. + params: Neural network parameters. + data: Iterable of batched input-target pairs. + progressbar: Show progress bar. + check_deterministic: Check if model and data are deterministic. + """ + num_data = sum(t.shape[0] for t, _ in data) + x = next(iter(data))[0] + num_outputs = model_func(x).shape[1:].numel() + num_params = sum(p.numel() for p in params) + super().__init__( + model_func, + None, + params, + data, + progressbar=progressbar, + check_deterministic=check_deterministic, + shape=(num_data * num_outputs, num_params), + ) + + def _check_deterministic(self): + """Verify that the linear operator is deterministic. + + In addition to the checks from the base class, checks that the model + predictions and data are always the same (loaded in the same order, and + only deterministic operations in the network). + + Note: + Deterministic checks should be performed on CPU. We noticed that even when + it passes on CPU, it can fail on GPU; probably due to non-deterministic + operations. + + Raises: + RuntimeError: If the linear operator is not deterministic. 
+ """ + super()._check_deterministic() + + rtol, atol = 5e-5, 1e-6 + + with no_grad(): + for (X1, y1), (X2, y2) in zip( + self._loop_over_data(), self._loop_over_data() + ): + pred1, y1 = self._model_func(X1).cpu().numpy(), y1.cpu().numpy() + pred2, y2 = self._model_func(X2).cpu().numpy(), y2.cpu().numpy() + X1, X2 = X1.cpu().numpy(), X2.cpu().numpy() + + if not allclose(X1, X2) or not allclose(y1, y2): + self.print_nonclose(X1, X2, rtol=rtol, atol=atol) + self.print_nonclose(y1, y2, rtol=rtol, atol=atol) + raise RuntimeError("Non-deterministic data loading detected.") + + if not allclose(pred1, pred2): + self.print_nonclose(pred1, pred2, rtol=rtol, atol=atol) + raise RuntimeError("Non-deterministic model detected.") + + def _matvec(self, x: ndarray) -> ndarray: + """Loop over all batches in the data and apply the matrix to vector x. + + Args: + x: Vector for multiplication. Has shape ``[D]``. + + Returns: + Matrix-multiplication result ``mat @ x``. + """ + x_list = self._preprocess(x) + out_list = [ + jvp(self._model_func(X), self._params, x_list, retain_graph=False)[ + 0 + ].flatten(start_dim=1) + for X, _ in self._loop_over_data() + ] + + return self._postprocess(out_list) diff --git a/docs/rtd/linops.rst b/docs/rtd/linops.rst index c630140..fb903e0 100644 --- a/docs/rtd/linops.rst +++ b/docs/rtd/linops.rst @@ -26,6 +26,12 @@ Uncentered gradient covariance (empirical Fisher) .. autoclass:: curvlinops.EFLinearOperator :members: __init__ +Jacobians +--------- + +.. autoclass:: curvlinops.JacobianLinearOperator + :members: __init__ + Inverses -------- diff --git a/setup.cfg b/setup.cfg index 905bb11..14c81aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,8 @@ classifiers = Operating System :: OS Independent Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 [options] zip_safe = False @@ -34,10 +36,11 @@ setup_requires = # Dependencies of the project (semicolon/line-separated): install_requires = backpack-for-pytorch>=1.5.0,<2.0.0 + torch>=2.0 scipy>=1.7.1,<2.0.0 tqdm>=4.61.0,<5.0.0 # Require a specific Python version, e.g. Python 2.7 or >= 3.4 -python_requires = >=3.7 +python_requires = >=3.8 ############################################################################### # Development dependencies # diff --git a/test/test_jacobian.py b/test/test_jacobian.py new file mode 100644 index 0000000..ecec1ae --- /dev/null +++ b/test/test_jacobian.py @@ -0,0 +1,27 @@ +"""Contains tests for ``curvlinops/jacobian``.""" + +from numpy import random + +from curvlinops import JacobianLinearOperator +from curvlinops.examples.functorch import functorch_jacobian +from curvlinops.examples.utils import report_nonclose + + +def test_JacobianLinearOperator_matvec(case): + model_func, _, params, data = case + + op = JacobianLinearOperator(model_func, params, data) + op_functorch = functorch_jacobian(model_func, params, data).detach().cpu().numpy() + + x = random.rand(op.shape[1]) + report_nonclose(op @ x, op_functorch @ x) + + +def test_JacobianLinearOperator_matmat(case, num_vecs: int = 3): + model_func, _, params, data = case + + op = JacobianLinearOperator(model_func, params, data) + op_functorch = functorch_jacobian(model_func, params, data).detach().cpu().numpy() + + X = random.rand(op.shape[1], num_vecs) + report_nonclose(op @ X, op_functorch @ X)