[ADD] Implement model Jacobian
f-dangel committed Jul 18, 2023
1 parent 783e644 commit c68e3ee
Showing 5 changed files with 289 additions and 12 deletions.
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
@@ -5,6 +5,7 @@
from curvlinops.gradient_moments import EFLinearOperator
from curvlinops.hessian import HessianLinearOperator
from curvlinops.inverse import CGInverseLinearOperator, NeumannInverseLinearOperator
from curvlinops.jacobian import JacobianLinearOperator
from curvlinops.papyan2020traces.spectrum import (
    LanczosApproximateLogSpectrumCached,
    LanczosApproximateSpectrumCached,
@@ -18,6 +19,7 @@
"GGNLinearOperator",
"EFLinearOperator",
"FisherMCLinearOperator",
"JacobianLinearOperator",
"CGInverseLinearOperator",
"NeumannInverseLinearOperator",
"SubmatrixLinearOperator",
66 changes: 54 additions & 12 deletions curvlinops/examples/functorch.py
@@ -5,6 +5,7 @@

from functorch import grad, hessian, jvp, make_functional, vmap
from torch import Tensor, cat, einsum
from torch.func import jacrev
from torch.nn import Module


@@ -55,9 +56,7 @@ def functorch_hessian(
    model_fn, _ = make_functional(model_func)
    loss_fn, loss_fn_params = make_functional(loss_func)

-    # concatenate batches
-    X, y = list(zip(*list(data)))
-    X, y = cat(X), cat(y)
+    X, y = _concatenate_batches(data)

    def loss(X: Tensor, y: Tensor, params: Tuple[Tensor]) -> Tensor:
        """Compute the loss given a mini-batch and the neural network parameters.
@@ -100,9 +99,7 @@ def functorch_ggn(
    model_fn, _ = make_functional(model_func)
    loss_fn, loss_fn_params = make_functional(loss_func)

-    # concatenate batches
-    X, y = list(zip(*list(data)))
-    X, y = cat(X), cat(y)
+    X, y = _concatenate_batches(data)

    def linearized_model(
        anchor: Tuple[Tensor], params: Tuple[Tensor], X: Tensor
@@ -167,9 +164,7 @@ def functorch_gradient(
    model_fn, _ = make_functional(model_func)
    loss_fn, loss_fn_params = make_functional(loss_func)

-    # concatenate batches
-    X, y = list(zip(*list(data)))
-    X, y = cat(X), cat(y)
+    X, y = _concatenate_batches(data)

    def loss(X: Tensor, y: Tensor, params: Tuple[Tensor]) -> Tensor:
        """Compute the loss given a mini-batch and the neural network parameters.
@@ -213,9 +208,7 @@ def functorch_empirical_fisher(
    model_fn, _ = make_functional(model_func)
    loss_fn, loss_fn_params = make_functional(loss_func)

-    # concatenate batches
-    X, y = list(zip(*list(data)))
-    X, y = cat(X), cat(y)
+    X, y = _concatenate_batches(data)

    # compute batched gradients
    def loss_n(X_n: Tensor, y_n: Tensor, params: List[Tensor]) -> Tensor:
@@ -244,3 +237,52 @@ def loss_n(X_n: Tensor, y_n: Tensor, params: List[Tensor]) -> Tensor:
        raise ValueError("Cannot detect reduction method from loss function.")

    return 1 / normalization * einsum("ni,nj->ij", batch_grad, batch_grad)


def functorch_jacobian(
    model_func: Module,
    params: List[Tensor],
    data: Iterable[Tuple[Tensor, Tensor]],
) -> Tensor:
    """Compute the Jacobian with functorch.
    Args:
        model_func: A function that maps the mini-batch input X to predictions.
            Could be a PyTorch module representing a neural network.
        params: List of differentiable parameters used by the prediction function.
        data: Source from which mini-batches can be drawn, for instance a list of
            mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
    Returns:
        Matrix containing the Jacobian. Has shape ``[N * C, D]`` where ``D`` is the
        total number of parameters, ``N`` the total number of data points, and ``C``
        the model's output space dimension.
    """
    model_fn, _ = make_functional(model_func)
    X, _ = _concatenate_batches(data)

    def model_fn_params_only(params: Tuple[Tensor]) -> Tensor:
        return model_fn(params, X)

    # concatenate over flattened parameters and flattened outputs
    jac = jacrev(model_fn_params_only)(params)
    jac = [j.flatten(start_dim=-p.dim()) for j, p in zip(jac, params)]
    jac = cat(jac, dim=-1).flatten(end_dim=-2)

    return jac


def _concatenate_batches(
    data: Iterable[Tuple[Tensor, Tensor]]
) -> Tuple[Tensor, Tensor]:
    """Concatenate all batches in the dataset along the batch dimension.
    Args:
        data: A dataloader or iterable of batches.
    Returns:
        Concatenated model inputs.
        Concatenated targets.
    """
    X, y = list(zip(*list(data)))
    return cat(X), cat(y)
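
A minimal usage sketch of ``functorch_jacobian`` (not part of this commit; the toy
model, data, and shapes below are illustrative assumptions):

from torch import manual_seed, rand
from torch.nn import Linear, Sequential, Tanh

from curvlinops.examples.functorch import functorch_jacobian

manual_seed(0)
model = Sequential(Linear(8, 16), Tanh(), Linear(16, 3))  # C = 3 outputs
params = [p for p in model.parameters() if p.requires_grad]
data = [(rand(5, 8), rand(5, 3)), (rand(5, 8), rand(5, 3))]  # N = 10 samples

jac = functorch_jacobian(model, params, data)
D = sum(p.numel() for p in params)
assert jac.shape == (10 * 3, D)  # rows: N * C, columns: D parameters
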
200 changes: 200 additions & 0 deletions curvlinops/jacobian.py
@@ -0,0 +1,200 @@
"""Implements linear operators for per-sample Jacobians."""

from typing import Callable, Iterable, List, Tuple

from backpack.hessianfree.rop import jacobian_vector_product as jvp
from backpack.utils.convert_parameters import vector_to_parameter_list
from numpy import allclose, column_stack, float32, ndarray
from numpy.random import rand
from scipy.sparse.linalg import LinearOperator
from torch import Tensor, cat
from torch import device as torch_device
from torch import from_numpy, no_grad
from torch.nn import Module, Parameter
from tqdm import tqdm

from curvlinops._base import _LinearOperator


class JacobianLinearOperator(LinearOperator):
    """Linear operator for the Jacobian.
    Can be used with SciPy.
    """

    def __init__(
        self,
        model_func: Callable[[Tensor], Tensor],
        params: List[Parameter],
        data: Iterable[Tuple[Tensor, Tensor]],
        progressbar: bool = False,
        check_deterministic: bool = True,
    ):
r"""Linear operator for the Jacobian as SciPy linear operator.
Consider a model :math:`f(\mathbf{x}, \mathbf{\theta}): \mathbb{R}^M
\times \mathbb{R}^D \to \mathbb{R}^C` with parameters
:math:`\mathbf{\theta}` and input :math:`\mathbf{x}`. Assume we are
given a data set :math:`\mathcal{D} = \{ (\mathbf{x}_n, \mathbf{y}_n)
\}_{n=1}^N` of input-target pairs via batches. The model's Jacobian
:math:`\mathbf{J}_\mathbf{\theta}\mathbf{f}` is an :math:`NC \times D`
with elements
.. math::
\left[
\mathbf{J}_\mathbf{\theta}\mathbf{f}
\right]_{(n,c), d}
=
\frac{\partial f(\mathbf{x}_n, \mathbf{\theta})}{\partial \theta_d}\,.
Note that the data must be supplied in deterministic order.
Args:
model_func: Neural network function.
params: Neural network parameters.
data: Iterable of batched input-target pairs.
progressbar: Show progress bar.
check_deterministic: Check if model and data are deterministic.
"""
        num_data = sum(t.shape[0] for t, _ in data)
        x = next(iter(data))[0]
        num_outputs = model_func(x).shape[1:].numel()
        num_params = sum(p.numel() for p in params)
        super().__init__(shape=(num_data * num_outputs, num_params), dtype=float32)

        self._params = params
        self._model_func = model_func
        self._data = data
        self._device = _LinearOperator._infer_device(self._params)
        self._progressbar = progressbar

        if check_deterministic:
            old_device = self._device
            self.to_device(torch_device("cpu"))
            try:
                self._check_deterministic()
            except RuntimeError as e:
                raise e
            finally:
                self.to_device(old_device)

    def _check_deterministic(self):
        """Verify that the linear operator is deterministic.
        - Checks that the data is loaded in a deterministic fashion (e.g. no shuffling).
        - Checks that the model is deterministic (e.g. dropout disabled or in
          evaluation mode).
        - Checks that matrix-vector multiplication with a single random vector is
          deterministic.
        Note:
            Deterministic checks are performed on CPU. We noticed that even when they
            pass on CPU, they can fail on GPU; probably due to non-deterministic
            operations.
        Raises:
            RuntimeError: If the linear operator is not deterministic.
        """
        print("Performing deterministic checks")

        pred1, y1 = self.predictions_and_targets()
        pred1, y1 = pred1.cpu().numpy(), y1.cpu().numpy()
        pred2, y2 = self.predictions_and_targets()
        pred2, y2 = pred2.cpu().numpy(), y2.cpu().numpy()

        rtol, atol = 5e-5, 1e-6

        if not allclose(y1, y2, rtol=rtol, atol=atol):
            _LinearOperator.print_nonclose(y1, y2, rtol=rtol, atol=atol)
            raise RuntimeError(
                "Data is not loaded in a deterministic fashion."
                + " Make sure shuffling is turned off."
            )
        if not allclose(pred1, pred2, rtol=rtol, atol=atol):
            _LinearOperator.print_nonclose(pred1, pred2, rtol=rtol, atol=atol)
            raise RuntimeError(
                "Model predictions are not deterministic."
                + " Make sure dropout and batch normalization are in eval mode."
            )

        v = rand(self.shape[1]).astype(self.dtype)
        mat_v1 = self @ v
        mat_v2 = self @ v
        if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
            _LinearOperator.print_nonclose(mat_v1, mat_v2, rtol, atol)
            raise RuntimeError("Check for deterministic matvec failed.")

    def to_device(self, device: torch_device):
        """Load linear operator to a device (inplace).
        Args:
            device: Target device.
        """
        self._device = device

        if isinstance(self._model_func, Module):
            self._model_func = self._model_func.to(self._device)
        self._params = [p.to(device) for p in self._params]

    def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]:
        """Yield batches of the data set, loaded to the correct device.
        Yields:
            Mini-batches ``(X, y)``.
        """
        data_iter = iter(self._data)

        if self._progressbar:
            data_iter = tqdm(data_iter, desc="matvec")

        for X, y in data_iter:
            X, y = X.to(self._device), y.to(self._device)
            yield (X, y)

    def predictions_and_targets(self) -> Tuple[Tensor, Tensor]:
        """Return the batch-concatenated model predictions and labels.
        Returns:
            Batch-concatenated model predictions of shape ``[N, *]`` where ``*``
            denotes the model's output shape (for instance ``* = C``).
            Batch-concatenated labels of shape ``[N, *]``, where ``*`` denotes
            the dimension of a label.
        """
        total_pred, total_y = [], []

        with no_grad():
            for X, y in self._loop_over_data():
                total_pred.append(self._model_func(X))
                total_y.append(y)
        assert total_pred and total_y

        return cat(total_pred), cat(total_y)

    def _matvec(self, x: ndarray) -> ndarray:
        """Loop over all batches in the data and apply the matrix to vector x.
        Args:
            x: Vector for multiplication. Has shape ``[D]``.
        Returns:
            Matrix-multiplication result ``mat @ x``.
        """
        x_list = vector_to_parameter_list(from_numpy(x).to(self._device), self._params)
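        # apply the Jacobian batch by batch: the JVP of each batch's output w.r.t.
        # the parameters yields a ``[batch_size, *output_shape]`` block, which is
        # flattened to ``[batch_size, C]`` and stacked over batches below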
        out_list = [
            jvp(self._model_func(X), self._params, x_list, retain_graph=False)[
                0
            ].flatten(start_dim=1)
            for X, _ in self._loop_over_data()
        ]

        return cat(out_list).cpu().numpy()

    def _matmat(self, X: ndarray) -> ndarray:
        """Matrix-matrix multiplication.
        Args:
            X: Matrix for multiplication.
        Returns:
            Matrix-multiplication result ``mat @ X``.
        """
        return column_stack([self @ col for col in X.T])
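
A minimal usage sketch of the new operator (not part of this commit; the toy model
and data are illustrative assumptions). Because the operator implements the SciPy
``LinearOperator`` interface, it multiplies NumPy vectors directly and can be checked
against the dense Jacobian from ``curvlinops.examples.functorch``:

from numpy import allclose
from numpy.random import rand
from torch import manual_seed
from torch import rand as torch_rand
from torch.nn import Linear, Sequential, Tanh

from curvlinops import JacobianLinearOperator
from curvlinops.examples.functorch import functorch_jacobian

manual_seed(0)
model = Sequential(Linear(8, 16), Tanh(), Linear(16, 3))  # C = 3 outputs
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch_rand(5, 8), torch_rand(5, 3)), (torch_rand(5, 8), torch_rand(5, 3))]

J = JacobianLinearOperator(model, params, data)  # shape [N * C, D]
J_dense = functorch_jacobian(model, params, data).detach().cpu().numpy()

v = rand(J.shape[1]).astype(J.dtype)  # random vector with matching dtype (float32)
print(allclose(J @ v, J_dense @ v, rtol=5e-5, atol=1e-6))
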
6 changes: 6 additions & 0 deletions docs/rtd/linops.rst
@@ -26,6 +26,12 @@ Uncentered gradient covariance (empirical Fisher)
.. autoclass:: curvlinops.EFLinearOperator
    :members: __init__

Jacobians
---------

.. autoclass:: curvlinops.JacobianLinearOperator
    :members: __init__

Inverses
--------

27 changes: 27 additions & 0 deletions test/test_jacobian.py
@@ -0,0 +1,27 @@
"""Contains tests for ``curvlinops/jacobian``."""

from numpy import random

from curvlinops import JacobianLinearOperator
from curvlinops.examples.functorch import functorch_jacobian
from curvlinops.examples.utils import report_nonclose


def test_JacobianLinearOperator_matvec(case):
    model_func, _, params, data = case

    op = JacobianLinearOperator(model_func, params, data)
    op_functorch = functorch_jacobian(model_func, params, data).detach().cpu().numpy()

    x = random.rand(op.shape[1])
    report_nonclose(op @ x, op_functorch @ x)


def test_JacobianLinearOperator_matmat(case, num_vecs: int = 3):
    model_func, _, params, data = case

    op = JacobianLinearOperator(model_func, params, data)
    op_functorch = functorch_jacobian(model_func, params, data).detach().cpu().numpy()

    X = random.rand(op.shape[1], num_vecs)
    report_nonclose(op @ X, op_functorch @ X)
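
The ``case`` fixture is provided by the test suite's shared ``conftest.py``, which is
not part of this diff. A hypothetical example of the expected
``(model_func, loss_func, params, data)`` structure, for illustration only:

from torch import rand
from torch.nn import Linear, MSELoss, Sequential, Tanh

model_func = Sequential(Linear(8, 16), Tanh(), Linear(16, 3))
loss_func = MSELoss(reduction="mean")  # the loss is unused by the Jacobian tests
params = [p for p in model_func.parameters() if p.requires_grad]
data = [(rand(5, 8), rand(5, 3)), (rand(5, 8), rand(5, 3))]
case = (model_func, loss_func, params, data)
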
