From a4d668a8228276529e44ec4cf90b31e73d87b071 Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Tue, 18 Jul 2023 15:14:12 -0400
Subject: [PATCH] [REF] Implement Jacobian via base class, remove print
 statements

---
 curvlinops/_base.py    |  87 +++++++++++++++--------
 curvlinops/jacobian.py | 152 ++++++++++-------------------------------
 2 files changed, 93 insertions(+), 146 deletions(-)

diff --git a/curvlinops/_base.py b/curvlinops/_base.py
index f97bc00..7f332ee 100644
--- a/curvlinops/_base.py
+++ b/curvlinops/_base.py
@@ -1,6 +1,6 @@
 """Contains functionality to analyze Hessian & GGN via matrix-free multiplication."""
 
-from typing import Callable, Iterable, List, Tuple
+from typing import Callable, Iterable, List, Optional, Tuple, Union
 
 from backpack.utils.convert_parameters import vector_to_parameter_list
 from numpy import (
@@ -14,17 +14,16 @@
 )
 from numpy.random import rand
 from scipy.sparse.linalg import LinearOperator
-from torch import Tensor
+from torch import Tensor, cat
 from torch import device as torch_device
 from torch import from_numpy, tensor, zeros_like
 from torch.autograd import grad
 from torch.nn import Module, Parameter
-from torch.nn.utils import parameters_to_vector
 from tqdm import tqdm
 
 
 class _LinearOperator(LinearOperator):
-    """Base class for linear operators of DNN curvature matrices.
+    """Base class for linear operators of DNN matrices.
 
     Can be used with SciPy.
     """
@@ -32,13 +31,14 @@ class _LinearOperator(LinearOperator):
     def __init__(
         self,
         model_func: Callable[[Tensor], Tensor],
-        loss_func: Callable[[Tensor, Tensor], Tensor],
+        loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
         params: List[Parameter],
         data: Iterable[Tuple[Tensor, Tensor]],
         progressbar: bool = False,
         check_deterministic: bool = True,
+        shape: Optional[Tuple[int, int]] = None,
     ):
-        """Linear operator for DNN curvature matrices.
+        """Linear operator for DNN matrices.
 
         Note:
             f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch
@@ -49,10 +49,13 @@ def __init__(
             model_func: A function that maps the mini-batch input X to predictions.
                 Could be a PyTorch module representing a neural network.
             loss_func: Loss function criterion. Maps predictions and mini-batch labels
-                to a scalar value.
+                to a scalar value. If ``None``, there is no loss function and the
+                represented matrix is independent of the loss function.
             params: List of differentiable parameters used by the prediction function.
             data: Source from which mini-batches can be drawn, for instance a list of
                 mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
+            shape: Shape of the represented matrix. If ``None``, assumes ``(D, D)``,
+                where ``D`` is the total number of parameters.
             progressbar: Show a progressbar during matrix-multiplication.
                 Default: ``False``.
             check_deterministic: Probe that model and data are deterministic, i.e.
@@ -64,8 +67,10 @@ def __init__(
         Raises:
             RuntimeError: If the check for deterministic behavior fails.
         """
-        dim = sum(p.numel() for p in params)
-        super().__init__(shape=(dim, dim), dtype=float32)
+        if shape is None:
+            dim = sum(p.numel() for p in params)
+            shape = (dim, dim)
+        super().__init__(shape=shape, dtype=float32)
 
         self._params = params
         self._model_func = model_func
@@ -74,7 +79,7 @@ def __init__(
         self._device = self._infer_device(self._params)
         self._progressbar = progressbar
 
-        self._N_data = sum(X.shape[0] for (X, _) in self._loop_over_data())
+        self._num_data = sum(X.shape[0] for (X, _) in self._loop_over_data())
 
         if check_deterministic:
             old_device = self._device
@@ -129,22 +134,37 @@ def _check_deterministic(self):
         - Two independent loss/gradient computations yield different results
 
         Note:
-            Deterministic checks are performed on CPU. We noticed that even when it
-            passes on CPU, it can fail on GPU; probably due to non-deterministic
+            Deterministic checks should be performed on CPU. We noticed that even when
+            it passes on CPU, it can fail on GPU; probably due to non-deterministic
             operations.
 
         Raises:
             RuntimeError: If non-deterministic behavior is detected.
         """
-        print("Performing deterministic checks")
+        v = rand(self.shape[1]).astype(self.dtype)
+        mat_v1 = self @ v
+        mat_v2 = self @ v
+
+        rtol, atol = 5e-5, 1e-6
+        if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
+            self.print_nonclose(mat_v1, mat_v2, rtol, atol)
+            raise RuntimeError("Check for deterministic matvec failed.")
 
+        if self._loss_func is None:
+            return
+
+        # only carried out if there is a loss function
         grad1, loss1 = self.gradient_and_loss()
-        grad1, loss1 = parameters_to_vector(grad1).cpu().numpy(), loss1.cpu().numpy()
+        grad1, loss1 = (
+            self.flatten_and_concatenate(grad1).cpu().numpy(),
+            loss1.cpu().numpy(),
+        )
 
         grad2, loss2 = self.gradient_and_loss()
-        grad2, loss2 = parameters_to_vector(grad2).cpu().numpy(), loss2.cpu().numpy()
-
-        rtol, atol = 5e-5, 1e-6
+        grad2, loss2 = (
+            self.flatten_and_concatenate(grad2).cpu().numpy(),
+            loss2.cpu().numpy(),
+        )
 
         if not allclose(loss1, loss2, rtol=rtol, atol=atol):
             self.print_nonclose(loss1, loss2, rtol, atol)
@@ -154,16 +174,6 @@ def _check_deterministic(self):
             self.print_nonclose(grad1, grad2, rtol, atol)
             raise RuntimeError("Check for deterministic gradient failed.")
 
-        v = rand(self.shape[1]).astype(self.dtype)
-        mat_v1 = self @ v
-        mat_v2 = self @ v
-
-        if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
-            self.print_nonclose(mat_v1, mat_v2, rtol, atol)
-            raise RuntimeError("Check for deterministic matvec failed.")
-
-        print("Deterministic checks passed")
-
     @staticmethod
     def print_nonclose(array1: ndarray, array2: ndarray, rtol: float, atol: float):
         """Check if the two arrays are element-wise equal within a tolerance and print
@@ -245,8 +255,7 @@ def _preprocess(self, x: ndarray) -> List[Tensor]:
         x_torch = from_numpy(x).to(self._device)
         return vector_to_parameter_list(x_torch, self._params)
 
-    @staticmethod
-    def _postprocess(x_list: List[Tensor]) -> ndarray:
+    def _postprocess(self, x_list: List[Tensor]) -> ndarray:
         """Convert torch list format to flat numpy array.
 
         Args:
@@ -255,7 +264,7 @@ def _postprocess(x_list: List[Tensor]) -> ndarray:
         Returns:
             Flat vector.
         """
-        return parameters_to_vector([x.contiguous() for x in x_list]).cpu().numpy()
+        return self.flatten_and_concatenate(x_list).cpu().numpy()
 
     def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]:
         """Yield batches of the data set, loaded to the correct device.
@@ -279,7 +288,13 @@ def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]:
 
         Returns:
             Gradient and loss on the data set.
+
+        Raises:
+            ValueError: If there is no loss function.
         """
+        if self._loss_func is None:
+            raise ValueError("No loss function specified.")
+
         total_loss = tensor([0.0], device=self._device)
         total_grad = [zeros_like(p) for p in self._params]
 
@@ -317,3 +332,15 @@ def _get_normalization_factor(self, X: Tensor, y: Tensor) -> float:
-            return X.shape[0] / self._N_data
+            return X.shape[0] / self._num_data
         else:
             raise ValueError("Loss must have reduction 'mean' or 'sum'.")
+
+    @staticmethod
+    def flatten_and_concatenate(tensors: List[Tensor]) -> Tensor:
+        """Flatten then concatenate all tensors in a list.
+
+        Args:
+            tensors: List of tensors.
+
+        Returns:
+            Concatenated flattened tensors.
+        """
+        return cat([t.flatten() for t in tensors])
diff --git a/curvlinops/jacobian.py b/curvlinops/jacobian.py
index bb0e905..58b18e6 100644
--- a/curvlinops/jacobian.py
+++ b/curvlinops/jacobian.py
@@ -3,20 +3,14 @@
 from typing import Callable, Iterable, List, Tuple
 
 from backpack.hessianfree.rop import jacobian_vector_product as jvp
-from backpack.utils.convert_parameters import vector_to_parameter_list
-from numpy import allclose, column_stack, float32, ndarray
-from numpy.random import rand
-from scipy.sparse.linalg import LinearOperator
-from torch import Tensor, cat
-from torch import device as torch_device
-from torch import from_numpy, no_grad
-from torch.nn import Module, Parameter
-from tqdm import tqdm
+from numpy import allclose, ndarray
+from torch import Tensor, no_grad
+from torch.nn import Parameter
 
 from curvlinops._base import _LinearOperator
 
 
-class JacobianLinearOperator(LinearOperator):
+class JacobianLinearOperator(_LinearOperator):
     """Linear operator for the Jacobian.
 
     Can be used with SciPy.
@@ -63,114 +57,51 @@ def __init__(
         x = next(iter(data))[0]
         num_outputs = model_func(x).shape[1:].numel()
         num_params = sum(p.numel() for p in params)
-        super().__init__(shape=(num_data * num_outputs, num_params), dtype=float32)
-
-        self._params = params
-        self._model_func = model_func
-        self._data = data
-        self._device = _LinearOperator._infer_device(self._params)
-        self._progressbar = progressbar
-
-        if check_deterministic:
-            old_device = self._device
-            self.to_device(torch_device("cpu"))
-            try:
-                self._check_deterministic()
-            except RuntimeError as e:
-                raise e
-            finally:
-                self.to_device(old_device)
+        super().__init__(
+            model_func,
+            None,
+            params,
+            data,
+            progressbar=progressbar,
+            check_deterministic=check_deterministic,
+            shape=(num_data * num_outputs, num_params),
+        )
 
     def _check_deterministic(self):
         """Verify that the linear operator is deterministic.
 
-        - Checks that the data is loaded in a deterministic fashion (e.g. shuffling).
-        - Checks that the model is deterministic (e.g. dropout).
-        - Checks that matrix-vector multiplication with a single random vector is
-          deterministic.
+        In addition to the checks performed in the base class, this verifies that
+        the data is always loaded in the same order and that the model's predictions
+        are deterministic (i.e. the network only contains deterministic operations).
 
         Note:
-            Deterministic checks are performed on CPU. We noticed that even when it
-            passes on CPU, it can fail on GPU; probably due to non-deterministic
+            Deterministic checks should be performed on CPU. We noticed that even when
+            it passes on CPU, it can fail on GPU; probably due to non-deterministic
             operations.
 
         Raises:
             RuntimeError: If the linear operator is not deterministic.
         """
-        print("Performing deterministic checks")
-
-        pred1, y1 = self.predictions_and_targets()
-        pred1, y1 = pred1.cpu().numpy(), y1.cpu().numpy()
-        pred2, y2 = self.predictions_and_targets()
-        pred2, y2 = pred2.cpu().numpy(), y2.cpu().numpy()
+        super()._check_deterministic()
 
         rtol, atol = 5e-5, 1e-6
-        if not allclose(y1, y2, rtol=rtol, atol=atol):
-            _LinearOperator.print_nonclose(y1, y2, rtol=rtol, atol=atol)
-            raise RuntimeError(
-                "Data is not loaded in a deterministic fashion."
-                + " Make sure shuffling is turned off."
-            )
-        if not allclose(pred1, pred2, rtol=rtol, atol=atol):
-            _LinearOperator.print_nonclose(pred1, pred2, rtol=rtol, atol=atol)
-            raise RuntimeError(
-                "Model predictions are not deterministic."
-                + " Make sure dropout and batch normalization are in eval mode."
-            )
-
-        v = rand(self.shape[1]).astype(self.dtype)
-        mat_v1 = self @ v
-        mat_v2 = self @ v
-        if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
-            _LinearOperator.print_nonclose(mat_v1, mat_v2, rtol, atol)
-            raise RuntimeError("Check for deterministic matvec failed.")
-
-    def to_device(self, device: torch_device):
-        """Load linear operator to a device (inplace).
-
-        Args:
-            device: Target device.
-        """
-        self._device = device
-
-        if isinstance(self._model_func, Module):
-            self._model_func = self._model_func.to(self._device)
-        self._params = [p.to(device) for p in self._params]
-
-    def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]:
-        """Yield batches of the data set, loaded to the correct device.
-
-        Yields:
-            Mini-batches ``(X, y)``.
-        """
-        data_iter = iter(self._data)
-
-        if self._progressbar:
-            data_iter = tqdm(data_iter, desc="matvec")
-
-        for X, y in data_iter:
-            X, y = X.to(self._device), y.to(self._device)
-            yield (X, y)
-
-    def predictions_and_targets(self) -> Tuple[Tensor, Tensor]:
-        """Return the batch-concatenated model predictions and labels.
-
-        Returns:
-            Batch-concatenated model predictions of shape ``[N, *]`` where ``*``
-                denotes the model's output shape (for instance ``* = C``).
-            Batch-concatenated labels of shape ``[N, *]``, where ``*`` denotes
-                the dimension of a label.
-        """
-        total_pred, total_y = [], []
-        with no_grad():
-            for X, y in self._loop_over_data():
-                total_pred.append(self._model_func(X))
-                total_y.append(y)
-        assert total_pred and total_y
-
-        return cat(total_pred), cat(total_y)
+        with no_grad():
+            for (X1, y1), (X2, y2) in zip(
+                self._loop_over_data(), self._loop_over_data()
+            ):
+                pred1, y1 = self._model_func(X1).cpu().numpy(), y1.cpu().numpy()
+                pred2, y2 = self._model_func(X2).cpu().numpy(), y2.cpu().numpy()
+                X1, X2 = X1.cpu().numpy(), X2.cpu().numpy()
+
+                if not allclose(X1, X2) or not allclose(y1, y2):
+                    self.print_nonclose(X1, X2, rtol=rtol, atol=atol)
+                    self.print_nonclose(y1, y2, rtol=rtol, atol=atol)
+                    raise RuntimeError("Non-deterministic data loading detected.")
+
+                if not allclose(pred1, pred2):
+                    self.print_nonclose(pred1, pred2, rtol=rtol, atol=atol)
+                    raise RuntimeError("Non-deterministic model detected.")
 
     def _matvec(self, x: ndarray) -> ndarray:
         """Loop over all batches in the data and apply the matrix to vector x.
@@ -181,7 +112,7 @@ def _matvec(self, x: ndarray) -> ndarray:
         Returns:
             Matrix-multiplication result ``mat @ x``.
         """
-        x_list = vector_to_parameter_list(from_numpy(x).to(self._device), self._params)
+        x_list = self._preprocess(x)
         out_list = [
             jvp(self._model_func(X), self._params, x_list, retain_graph=False)[
                 0
             ]
             for X, _ in self._loop_over_data()
         ]
 
-        return cat(out_list).cpu().numpy()
-
-    def _matmat(self, X: ndarray) -> ndarray:
-        """Matrix-matrix multiplication.
-
-        Args:
-            X: Matrix for multiplication.
-
-        Returns:
-            Matrix-multiplication result ``mat @ X``.
-        """
-        return column_stack([self @ col for col in X.T])
+        return self._postprocess(out_list)
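
Not part of the patch: a minimal usage sketch of the refactored
JacobianLinearOperator. The toy model, data, and sizes below are made up, and
the sketch assumes the class keeps its existing ``(model_func, params, data,
...)`` constructor signature; only the SciPy LinearOperator interface from the
code above is relied upon.

    from numpy.random import rand
    from torch import manual_seed, rand as torch_rand
    from torch.nn import Linear, Sequential

    from curvlinops.jacobian import JacobianLinearOperator

    manual_seed(0)
    N, D_in, D_out = 8, 5, 3  # made-up toy dimensions
    X, y = torch_rand(N, D_in), torch_rand(N, D_out)
    model = Sequential(Linear(D_in, 4), Linear(4, D_out))
    params = [p for p in model.parameters() if p.requires_grad]
    data = [(X, y)]  # deterministic data loading (no shuffling)

    # Jacobian operator of shape (N * D_out, num_params); no loss function involved
    J = JacobianLinearOperator(model, params, data)

    v = rand(J.shape[1]).astype(J.dtype)  # flat NumPy vector in parameter space
    Jv = J @ v  # SciPy-style matrix-vector product, result has shape (N * D_out,)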