[REF] Implement Jacobian via base class, remove print statements
f-dangel committed Jul 18, 2023
1 parent 91424cf commit a4d668a
Showing 2 changed files with 93 additions and 146 deletions.
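For orientation, a minimal usage sketch of the refactored JacobianLinearOperator. The toy model, data sizes, and the exact constructor signature (model_func, params, data, ...) are illustrative assumptions, not part of this commit:

from numpy.random import rand
from torch import manual_seed, nn
from torch import rand as torch_rand

from curvlinops.jacobian import JacobianLinearOperator

# Hypothetical toy setup: a small MLP and two mini-batches as a list.
manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch_rand(5, 4), torch_rand(5, 2)), (torch_rand(5, 4), torch_rand(5, 2))]

# No loss function is involved; the represented matrix is the model's Jacobian
# w.r.t. its parameters, of shape (num_data * num_outputs, num_params).
J = JacobianLinearOperator(model, params, data)
D = sum(p.numel() for p in params)
print(J.shape)  # (20, 23) for this toy setup, i.e. (10 * 2, D)

# SciPy-compatible matrix-vector multiplication.
v = rand(D).astype(J.dtype)
print((J @ v).shape)  # (20,)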
87 changes: 57 additions & 30 deletions curvlinops/_base.py
@@ -1,6 +1,6 @@
"""Contains functionality to analyze Hessian & GGN via matrix-free multiplication."""

from typing import Callable, Iterable, List, Tuple
from typing import Callable, Iterable, List, Optional, Tuple, Union

from backpack.utils.convert_parameters import vector_to_parameter_list
from numpy import (
@@ -14,31 +14,31 @@
)
from numpy.random import rand
from scipy.sparse.linalg import LinearOperator
from torch import Tensor
from torch import Tensor, cat
from torch import device as torch_device
from torch import from_numpy, tensor, zeros_like
from torch.autograd import grad
from torch.nn import Module, Parameter
from torch.nn.utils import parameters_to_vector
from tqdm import tqdm


class _LinearOperator(LinearOperator):
"""Base class for linear operators of DNN curvature matrices.
"""Base class for linear operators of DNN matrices.
Can be used with SciPy.
"""

def __init__(
self,
model_func: Callable[[Tensor], Tensor],
loss_func: Callable[[Tensor, Tensor], Tensor],
loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
params: List[Parameter],
data: Iterable[Tuple[Tensor, Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
shape: Optional[Tuple[int, int]] = None,
):
"""Linear operator for DNN curvature matrices.
"""Linear operator for DNN matrices.
Note:
f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch
@@ -49,10 +49,13 @@ def __init__(
model_func: A function that maps the mini-batch input X to predictions.
Could be a PyTorch module representing a neural network.
loss_func: Loss function criterion. Maps predictions and mini-batch labels
to a scalar value.
to a scalar value. If ``None``, there is no loss function and the
represented matrix is independent of the loss function.
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
shape: Shape of the represented matrix. If ``None`` assumes ``(D, D)``
where ``D`` is the total number of parameters.
progressbar: Show a progressbar during matrix-multiplication.
Default: ``False``.
check_deterministic: Probe that model and data are deterministic, i.e.
@@ -64,8 +67,10 @@
Raises:
RuntimeError: If the check for deterministic behavior fails.
"""
dim = sum(p.numel() for p in params)
super().__init__(shape=(dim, dim), dtype=float32)
if shape is None:
dim = sum(p.numel() for p in params)
shape = (dim, dim)
super().__init__(shape=shape, dtype=float32)

self._params = params
self._model_func = model_func
@@ -74,7 +79,7 @@
self._device = self._infer_device(self._params)
self._progressbar = progressbar

self._N_data = sum(X.shape[0] for (X, _) in self._loop_over_data())
self._num_data = sum(X.shape[0] for (X, _) in self._loop_over_data())

if check_deterministic:
old_device = self._device
@@ -129,22 +134,37 @@ def _check_deterministic(self):
- Two independent loss/gradient computations yield different results
Note:
Deterministic checks are performed on CPU. We noticed that even when it
passes on CPU, it can fail on GPU; probably due to non-deterministic
Deterministic checks should be performed on CPU. We noticed that even when
it passes on CPU, it can fail on GPU; probably due to non-deterministic
operations.
Raises:
RuntimeError: If non-deterministic behavior is detected.
"""
print("Performing deterministic checks")
v = rand(self.shape[1]).astype(self.dtype)
mat_v1 = self @ v
mat_v2 = self @ v

rtol, atol = 5e-5, 1e-6
if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
self.print_nonclose(mat_v1, mat_v2, rtol, atol)
raise RuntimeError("Check for deterministic matvec failed.")

if self._loss_func is None:
return

# only carried out if there is a loss function
grad1, loss1 = self.gradient_and_loss()
grad1, loss1 = parameters_to_vector(grad1).cpu().numpy(), loss1.cpu().numpy()
grad1, loss1 = (
self.flatten_and_concatenate(grad1).cpu().numpy(),
loss1.cpu().numpy(),
)

grad2, loss2 = self.gradient_and_loss()
grad2, loss2 = parameters_to_vector(grad2).cpu().numpy(), loss2.cpu().numpy()

rtol, atol = 5e-5, 1e-6
grad2, loss2 = (
self.flatten_and_concatenate(grad2).cpu().numpy(),
loss2.cpu().numpy(),
)

if not allclose(loss1, loss2, rtol=rtol, atol=atol):
self.print_nonclose(loss1, loss2, rtol, atol)
@@ -154,16 +174,6 @@ def _check_deterministic(self):
self.print_nonclose(grad1, grad2, rtol, atol)
raise RuntimeError("Check for deterministic gradient failed.")

v = rand(self.shape[1]).astype(self.dtype)
mat_v1 = self @ v
mat_v2 = self @ v

if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
self.print_nonclose(mat_v1, mat_v2, rtol, atol)
raise RuntimeError("Check for deterministic matvec failed.")

print("Deterministic checks passed")

@staticmethod
def print_nonclose(array1: ndarray, array2: ndarray, rtol: float, atol: float):
"""Check if the two arrays are element-wise equal within a tolerance and print
@@ -245,8 +255,7 @@ def _preprocess(self, x: ndarray) -> List[Tensor]:
x_torch = from_numpy(x).to(self._device)
return vector_to_parameter_list(x_torch, self._params)

@staticmethod
def _postprocess(x_list: List[Tensor]) -> ndarray:
def _postprocess(self, x_list: List[Tensor]) -> ndarray:
"""Convert torch list format to flat numpy array.
Args:
@@ -255,7 +264,7 @@ def _postprocess(x_list: List[Tensor]) -> ndarray:
Returns:
Flat vector.
"""
return parameters_to_vector([x.contiguous() for x in x_list]).cpu().numpy()
return self.flatten_and_concatenate(x_list).cpu().numpy()

def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]:
"""Yield batches of the data set, loaded to the correct device.
@@ -279,7 +288,13 @@ def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]:
Returns:
Gradient and loss on the data set.
Raises:
ValueError: If there is no loss function.
"""
if self._loss_func is None:
raise ValueError("No loss function specified.")

total_loss = tensor([0.0], device=self._device)
total_grad = [zeros_like(p) for p in self._params]

@@ -317,3 +332,15 @@ def _get_normalization_factor(self, X: Tensor, y: Tensor) -> float:
return X.shape[0] / self._N_data
else:
raise ValueError("Loss must have reduction 'mean' or 'sum'.")

@staticmethod
def flatten_and_concatenate(tensors: List[Tensor]) -> Tensor:
"""Flatten then concatenate all tensors in a list.
Args:
tensors: List of tensors.
Returns:
Concatenated flattened tensors.
"""
return cat([t.flatten() for t in tensors])
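A quick sketch of the new helper (illustrative values, not part of the commit): unlike ``parameters_to_vector``, which relies on ``view(-1)`` and therefore needed the explicit ``.contiguous()`` call in the old ``_postprocess``, ``flatten_and_concatenate`` accepts any list of tensors, e.g. per-batch Jacobian blocks that do not match the parameter shapes.

from torch import rand

from curvlinops._base import _LinearOperator

# Per-batch output blocks of shape [batch_size, num_outputs], not parameter-shaped.
blocks = [rand(5, 2), rand(3, 2)]
print(_LinearOperator.flatten_and_concatenate(blocks).shape)  # torch.Size([16])

# A non-contiguous tensor (here: a transpose) flattens without an explicit
# ``.contiguous()`` call.
t = rand(4, 3).T
print(_LinearOperator.flatten_and_concatenate([t]).shape)  # torch.Size([12])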
152 changes: 36 additions & 116 deletions curvlinops/jacobian.py
@@ -3,20 +3,14 @@
from typing import Callable, Iterable, List, Tuple

from backpack.hessianfree.rop import jacobian_vector_product as jvp
from backpack.utils.convert_parameters import vector_to_parameter_list
from numpy import allclose, column_stack, float32, ndarray
from numpy.random import rand
from scipy.sparse.linalg import LinearOperator
from torch import Tensor, cat
from torch import device as torch_device
from torch import from_numpy, no_grad
from torch.nn import Module, Parameter
from tqdm import tqdm
from numpy import allclose, ndarray
from torch import Tensor, no_grad
from torch.nn import Parameter

from curvlinops._base import _LinearOperator


class JacobianLinearOperator(LinearOperator):
class JacobianLinearOperator(_LinearOperator):
"""Linear operator for the Jacobian.
Can be used with SciPy.
@@ -63,114 +57,51 @@ def __init__(
x = next(iter(data))[0]
num_outputs = model_func(x).shape[1:].numel()
num_params = sum(p.numel() for p in params)
super().__init__(shape=(num_data * num_outputs, num_params), dtype=float32)

self._params = params
self._model_func = model_func
self._data = data
self._device = _LinearOperator._infer_device(self._params)
self._progressbar = progressbar

if check_deterministic:
old_device = self._device
self.to_device(torch_device("cpu"))
try:
self._check_deterministic()
except RuntimeError as e:
raise e
finally:
self.to_device(old_device)
super().__init__(
model_func,
None,
params,
data,
progressbar=progressbar,
check_deterministic=check_deterministic,
shape=(num_data * num_outputs, num_params),
)

def _check_deterministic(self):
"""Verify that the linear operator is deterministic.
- Checks that the data is loaded in a deterministic fashion (e.g. shuffling).
- Checks that the model is deterministic (e.g. dropout).
- Checks that matrix-vector multiplication with a single random vector is
deterministic.
In addition to the checks from the base class, checks that the model
predictions and data are always the same (data loaded in the same order,
and only deterministic operations in the network).
Note:
Deterministic checks are performed on CPU. We noticed that even when it
passes on CPU, it can fail on GPU; probably due to non-deterministic
Deterministic checks should be performed on CPU. We noticed that even when
it passes on CPU, it can fail on GPU; probably due to non-deterministic
operations.
Raises:
RuntimeError: If the linear operator is not deterministic.
"""
print("Performing deterministic checks")

pred1, y1 = self.predictions_and_targets()
pred1, y1 = pred1.cpu().numpy(), y1.cpu().numpy()
pred2, y2 = self.predictions_and_targets()
pred2, y2 = pred2.cpu().numpy(), y2.cpu().numpy()
super()._check_deterministic()

rtol, atol = 5e-5, 1e-6

if not allclose(y1, y2, rtol=rtol, atol=atol):
_LinearOperator.print_nonclose(y1, y2, rtol=rtol, atol=atol)
raise RuntimeError(
"Data is not loaded in a deterministic fashion."
+ " Make sure shuffling is turned off."
)
if not allclose(pred1, pred2, rtol=rtol, atol=atol):
_LinearOperator.print_nonclose(pred1, pred2, rtol=rtol, atol=atol)
raise RuntimeError(
"Model predictions are not deterministic."
+ " Make sure dropout and batch normalization are in eval mode."
)

v = rand(self.shape[1]).astype(self.dtype)
mat_v1 = self @ v
mat_v2 = self @ v
if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
_LinearOperator.print_nonclose(mat_v1, mat_v2, rtol, atol)
raise RuntimeError("Check for deterministic matvec failed.")

def to_device(self, device: torch_device):
"""Load linear operator to a device (inplace).
Args:
device: Target device.
"""
self._device = device

if isinstance(self._model_func, Module):
self._model_func = self._model_func.to(self._device)
self._params = [p.to(device) for p in self._params]

def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]:
"""Yield batches of the data set, loaded to the correct device.
Yields:
Mini-batches ``(X, y)``.
"""
data_iter = iter(self._data)

if self._progressbar:
data_iter = tqdm(data_iter, desc="matvec")

for X, y in data_iter:
X, y = X.to(self._device), y.to(self._device)
yield (X, y)

def predictions_and_targets(self) -> Tuple[Tensor, Tensor]:
"""Return the batch-concatenated model predictions and labels.
Returns:
Batch-concatenated model predictions of shape ``[N, *]`` where ``*``
denotes the model's output shape (for instance ``* = C``).
Batch-concatenated labels of shape ``[N, *]``, where ``*`` denotes
the dimension of a label.
"""
total_pred, total_y = [], []

with no_grad():
for X, y in self._loop_over_data():
total_pred.append(self._model_func(X))
total_y.append(y)
assert total_pred and total_y

return cat(total_pred), cat(total_y)
for (X1, y1), (X2, y2) in zip(
self._loop_over_data(), self._loop_over_data()
):
pred1, y1 = self._model_func(X1).cpu().numpy(), y1.cpu().numpy()
pred2, y2 = self._model_func(X2).cpu().numpy(), y2.cpu().numpy()
X1, X2 = X1.cpu().numpy(), X2.cpu().numpy()

if not allclose(X1, X2) or not allclose(y1, y2):
self.print_nonclose(X1, X2, rtol=rtol, atol=atol)
self.print_nonclose(y1, y2, rtol=rtol, atol=atol)
raise RuntimeError("Non-deterministic data loading detected.")

if not allclose(pred1, pred2):
self.print_nonclose(pred1, pred2, rtol=rtol, atol=atol)
raise RuntimeError("Non-deterministic model detected.")

def _matvec(self, x: ndarray) -> ndarray:
"""Loop over all batches in the data and apply the matrix to vector x.
@@ -181,23 +112,12 @@ def _matvec(self, x: ndarray) -> ndarray:
Returns:
Matrix-multiplication result ``mat @ x``.
"""
x_list = vector_to_parameter_list(from_numpy(x).to(self._device), self._params)
x_list = self._preprocess(x)
out_list = [
jvp(self._model_func(X), self._params, x_list, retain_graph=False)[
0
].flatten(start_dim=1)
for X, _ in self._loop_over_data()
]

return cat(out_list).cpu().numpy()

def _matmat(self, X: ndarray) -> ndarray:
"""Matrix-matrix multiplication.
Args:
X: Matrix for multiplication.
Returns:
Matrix-multiplication result ``mat @ X``.
"""
return column_stack([self @ col for col in X.T])
return self._postprocess(out_list)
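Note that the explicit ``_matmat`` was removed along with the other Jacobian-specific boilerplate. Assuming the base class does not provide its own ``_matmat``, matrix-matrix products fall back to SciPy's ``LinearOperator`` default, which applies the operator column by column. A self-contained sketch (toy model and sizes are illustrative):

from numpy.random import rand
from torch import manual_seed, nn
from torch import rand as torch_rand

from curvlinops.jacobian import JacobianLinearOperator

manual_seed(0)
model = nn.Linear(4, 2)
params = list(model.parameters())
data = [(torch_rand(6, 4), torch_rand(6, 2))]

J = JacobianLinearOperator(model, params, data)  # shape (6 * 2, 10)

# Matrix-matrix product via SciPy's column-by-column fallback.
M = rand(J.shape[1], 3).astype(J.dtype)
print((J @ M).shape)  # (12, 3)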
