[REF] Replicate base class, but inherit from PyTorch linear operator #142

Merged: 9 commits, Nov 4, 2024
350 changes: 348 additions & 2 deletions curvlinops/_torch_base.py
@@ -2,11 +2,16 @@

from __future__ import annotations

from typing import Callable, Iterable, List, MutableMapping, Optional, Tuple, Union

import numpy
from scipy.sparse.linalg import LinearOperator
from torch import Size, Tensor, as_tensor, cat, device, dtype, rand, tensor, zeros_like
from torch.autograd import grad
from torch.nn import Module, Parameter
from tqdm import tqdm

from curvlinops.utils import allclose_report


class PyTorchLinearOperator:
@@ -347,3 +352,344 @@ def f_scipy(X: numpy.ndarray) -> numpy.ndarray:
return AX_torch.detach().cpu().numpy().astype(X_dtype)

return f_scipy


class CurvatureLinearOperator(PyTorchLinearOperator):
"""Base class for PyTorch linear operators of deep learning curvature matrices.

To implement a new curvature linear operator, subclass this class and implement
the ``_matmat_batch`` and ``_adjoint`` methods.

Attributes:
SUPPORTS_BLOCKS: Whether the linear operator supports multiplication with
a block-diagonal approximation rather than the full matrix.
Default: ``False``.
"""

SUPPORTS_BLOCKS: bool = False
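
# Illustrative sketch (not part of this PR): a subclass supplies the per-batch
# multiplication and the adjoint. The names below are hypothetical; see the
# concrete operators in this repository for real implementations.
#
#     class MyCurvature(CurvatureLinearOperator):
#         def _matmat_batch(self, X, y, M):
#             ...  # multiply this batch's curvature matrix onto each column of M
#
#         def _adjoint(self):
#             return self  # valid whenever the represented matrix is symmetric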

def __init__(
self,
model_func: Callable[[Union[Tensor, MutableMapping]], Tensor],
loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
params: List[Parameter],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
in_shape: Optional[List[Tuple[int, ...]]] = None,
out_shape: Optional[List[Tuple[int, ...]]] = None,
num_data: Optional[int] = None,
block_sizes: Optional[List[int]] = None,
batch_size_fn: Optional[Callable[[Union[MutableMapping, Tensor]], int]] = None,
):
"""Linear operator for curvature matrices of empirical risks.

Note:
f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch
input X to predictions p. ℓ(p, y) maps the prediction to a loss, using the
mini-batch labels y.

Args:
model_func: A function that maps the mini-batch input X to predictions.
Could be a PyTorch module representing a neural network.
loss_func: Loss function criterion. Maps predictions and mini-batch labels
to a scalar value. If ``None``, there is no loss function and the
represented matrix is independent of the loss function.
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. Note that ``X``
could be a ``dict`` or ``UserDict``; this is useful for custom models.
In this case, you must (i) specify the ``batch_size_fn`` argument, and
(ii) take care of preprocessing like ``X.to(device)`` inside your
``model.forward()`` function.
progressbar: Show a progressbar during matrix-multiplication.
Default: ``False``.
check_deterministic: Probe that model and data are deterministic, i.e.
that the data does not use ``drop_last`` or data augmentation and that
the model's forward pass does not depend on the order in which
mini-batches are presented (BatchNorm, Dropout). Default: ``True``.
This is a safeguard; only disable it if you know what you are doing.
in_shape: Shapes of the linear operator's input tensor product space.
If ``None``, will use the shapes of ``params``.
out_shape: Shapes of the linear operator's output tensor product space.
If ``None``, will use the shapes of ``params``.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
block_sizes: List of integers indicating the number of
``nn.Parameter``s forming a block. Entries must sum to ``len(params)``.
For instance, ``[len(params)]`` considers the full matrix, while
``[1, 1, ...]`` corresponds to a block-diagonal approximation where
each parameter forms its own block. Must be ``None`` if the linear
operator does not support blocks.
batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
needs to be specified. The intended behavior is to consume the first
entry of the iterates from ``data`` and return their batch size.

Raises:
RuntimeError: If the check for deterministic behavior fails.
ValueError: If ``block_sizes`` is specified but the linear operator does not
support blocks.
ValueError: If the sum of blocks does not equal the number of parameters.
ValueError: If any block size is not positive.
ValueError: If ``X`` is not a tensor and ``batch_size_fn`` is not specified.
"""
if isinstance(next(iter(data))[0], MutableMapping) and batch_size_fn is None:
raise ValueError(
"When using dict-like custom data, `batch_size_fn` is required."
)
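
# Illustrative (hypothetical key name): for dict-valued inputs, a suitable
# ``batch_size_fn`` could be ``lambda X: X["input_ids"].shape[0]``.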

in_shape = [tuple(p.shape) for p in params] if in_shape is None else in_shape
out_shape = [tuple(p.shape) for p in params] if out_shape is None else out_shape
super().__init__(in_shape, out_shape)

self._params = params
if block_sizes is not None:
if not self.SUPPORTS_BLOCKS:
raise ValueError(
"Block sizes were specified but operator does not support blocking."
)
if sum(block_sizes) != len(params):
raise ValueError("Sum of blocks must equal the number of parameters.")
if any(s <= 0 for s in block_sizes):
raise ValueError("Block sizes must be positive.")
self._block_sizes = [len(params)] if block_sizes is None else block_sizes
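
# Worked example (illustrative): for a model with four parameter tensors,
# ``block_sizes=[2, 2]`` groups parameters {0, 1} and {2, 3} into two diagonal
# blocks; curvature between the two groups is treated as zero.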

self._model_func = model_func
self._loss_func = loss_func
self._data = data
self._device = self._infer_device()
self._progressbar = progressbar
self._batch_size_fn = (
(lambda X: X.shape[0]) if batch_size_fn is None else batch_size_fn
)

self._N_data = (
sum(
self._batch_size_fn(X)
for (X, _) in self._loop_over_data(desc="_N_data")
)
if num_data is None
else num_data
)

if check_deterministic:
old_device = self._device
self.to_device(device("cpu"))
try:
self._check_deterministic()
finally:
self.to_device(old_device)

def _matmat(self, M: List[Tensor]) -> List[Tensor]:
"""Matrix-matrix multiplication.

Args:
M: Matrix for multiplication in tensor list format. Assume the linear
operator's input tensor product space consists of shapes ``[*N1],
[*N2], ...``. Then, ``M`` is a list of tensors with shapes
``[*N1, K], [*N2, K], ...`` with ``K`` the number of columns.

Returns:
Matrix-multiplication result ``mat @ M`` in tensor list format.
It has the same format as the input matrix, but lives in the linear operator's
output tensor product space.
"""
AM = [zeros_like(m) for m in M]

for X, y in self._loop_over_data(desc="_matmat"):
normalization_factor = self._get_normalization_factor(X, y)
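# Scale this batch's contribution: the factor is 1 for ``sum`` reduction and
# batch_size / N for ``mean``, so accumulating over batches reproduces the
# full-data-set matrix (see ``_get_normalization_factor``).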
for AM_current, current in zip(AM, self._matmat_batch(X, y, M)):
AM_current.add_(current, alpha=normalization_factor)

return AM

def _matmat_batch(
self, X: Union[MutableMapping, Tensor], y: Tensor, M: List[Tensor]
) -> Tuple[Tensor, ...]:
"""Apply this mini-batch's curvature matrix to a matrix in tensor list format.

Args:
X: Input to the DNN.
y: Ground truth.
M: Matrix in list format (same shape as trainable model parameters with
additional trailing dimension of size number of columns).

Returns: # noqa: D402
Result of matrix-multiplication in list format.

Raises:
NotImplementedError: Must be implemented by descendants.
"""
raise NotImplementedError

def _loop_over_data(
self, desc: Optional[str] = None, add_device_to_desc: bool = True
) -> Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]]:
"""Yield batches of the data set, loaded to the correct device.

Args:
desc: Description for the progress bar. Will be ignored if progressbar is
disabled.
add_device_to_desc: Whether to add the device to the description.
Default: ``True``.

Yields:
Mini-batches ``(X, y)``.
"""
data_iter = iter(self._data)

if self._progressbar:
desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
if add_device_to_desc:
desc = f"{desc} (on {str(self._device)})"
data_iter = tqdm(data_iter, desc=desc)
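# e.g. ``desc="GGNLinearOperator._matmat (on cpu)"`` for a hypothetical
# subclass named ``GGNLinearOperator``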

for X, y in data_iter:
# Assume everything is handled by the model
# if `X` is a custom data format
if isinstance(X, Tensor):
X = X.to(self._device)
y = y.to(self._device)
yield (X, y)

def _get_normalization_factor(
self, X: Union[MutableMapping, Tensor], y: Tensor
) -> float:
"""Return the correction factor for correct normalization over the data set.

Args:
X: Input to the DNN.
y: Ground truth.

Returns:
Normalization factor.
"""
return {"sum": 1.0, "mean": self._batch_size_fn(X) / self._N_data}[
self._loss_func.reduction
]
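
# Worked example (illustrative): with ``reduction="mean"``, ``N = 100``, and
# two mini-batches of sizes 60 and 40, the factors are 0.6 and 0.4, so
# 0.6 * mean_loss_1 + 0.4 * mean_loss_2 equals the mean loss over all 100
# samples.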

###############################################################################
# DETERMINISTIC CHECKS #
###############################################################################

def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]:
"""Evaluate the gradient and loss on the data.

(Not really part of the LinearOperator interface.)

Returns:
Gradient and loss on the data set.

Raises:
ValueError: If there is no loss function.
"""
if self._loss_func is None:
raise ValueError("No loss function specified.")

total_loss = tensor([0.0], device=self._device, dtype=self._infer_dtype())
total_grad = [zeros_like(p) for p in self._params]

for X, y in self._loop_over_data(desc="gradient_and_loss"):
loss = self._loss_func(self._model_func(X), y)
normalization_factor = self._get_normalization_factor(X, y)

for grad_param, current in zip(total_grad, grad(loss, self._params)):
grad_param.add_(current, alpha=normalization_factor)
total_loss.add_(loss.detach(), alpha=normalization_factor)

return total_grad, total_loss
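
# Illustrative usage (hypothetical instance ``op`` of a concrete subclass):
#
#     grad_list, loss = op.gradient_and_loss()
#     grad_norm = sum((g**2).sum() for g in grad_list).sqrt()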

def to_device(self, device: device):
"""Load linear operator to a device (inplace).

Args:
device: Target device.
"""
self._device = device

if isinstance(self._model_func, Module):
self._model_func = self._model_func.to(self._device)
self._params = [p.to(device) for p in self._params]

if isinstance(self._loss_func, Module):
self._loss_func = self._loss_func.to(self._device)

def _check_deterministic(self):
"""Check that the linear operator is deterministic.

Non-deterministic behavior is detected if:

- Two independent applications of matvec onto the same vector yield different
results
- Two independent loss/gradient computations yield different results

Note:
Deterministic checks should be performed on CPU. We noticed that even when
it passes on CPU, it can fail on GPU; probably due to non-deterministic
operations.

# TODO This can be impractical if the CPU is less powerful than the GPU.
# Also, it would be desirable to confirm deterministic behavior on the compute
# device that will be used for matvecs. Try refactoring by using device-agnostic
# tolerances. Then remove the ``to_device`` method.

Raises:
RuntimeError: If non-deterministic behavior is detected.
"""
v = rand(self.shape[1], device=self._device, dtype=self._infer_dtype())
Av1 = self @ v
Av2 = self @ v

rtol, atol = 5e-5, 1e-6
if not allclose_report(Av1, Av2, rtol=rtol, atol=atol):
raise RuntimeError("Check for deterministic matvec failed.")

if self._loss_func is None:
return

# only carried out if there is a loss function
grad1, loss1 = self.gradient_and_loss()
grad2, loss2 = self.gradient_and_loss()

if not allclose_report(loss1, loss2, rtol=rtol, atol=atol):
raise RuntimeError("Check for deterministic loss failed.")

if len(grad1) != len(grad2) or any(
not allclose_report(g1, g2, atol=atol, rtol=rtol)
for g1, g2 in zip(grad1, grad2)
):
raise RuntimeError("Check for deterministic gradient failed.")

###############################################################################
# SCIPY EXPORT #
###############################################################################

def _infer_device(self) -> device:
"""Infer the device onto which to load NumPy vectors for the matrix multiply.

Returns:
Inferred device.

Raises:
RuntimeError: If the device cannot be inferred.
"""
devices = {p.device for p in self._params}
if len(devices) != 1:
raise RuntimeError(f"Could not infer device. Parameters live on {devices}.")
return devices.pop()

def _infer_dtype(self) -> dtype:
"""Infer the data type to which to load NumPy vectors for the matrix multiply.

Returns:
Inferred data type.

Raises:
RuntimeError: If the data type cannot be inferred.
"""
dtypes = {p.dtype for p in self._params}
if len(dtypes) != 1:
raise RuntimeError(f"Could not infer data type. Parameters have {dtypes}.")
return dtypes.pop()
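
# End-to-end sketch (illustrative; assumes the ``GGNLinearOperator`` subclass
# exported by this package and the ``@`` support of the PyTorch base class):
#
#     from torch import nn, rand
#     from curvlinops import GGNLinearOperator
#
#     model = nn.Linear(10, 2)
#     loss_func = nn.MSELoss()
#     data = [(rand(8, 10), rand(8, 2))]
#     op = GGNLinearOperator(model, loss_func, list(model.parameters()), data)
#     Gv = op @ rand(op.shape[1])  # matrix-vector product with the GGN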