[ADD] Output-parameter Jacobian of an NN, deprecate python 3.7 (#32)
* [ADD] Implement and test `_adjoint`

* [FIX] Use column dimension to create random vector

* [ADD] Implement model Jacobian

* [DOC] Mention Jacobian in README

* [REQ] Deprecate python 3.7 to use functorch from inside torch

* [DOC] Use python 3.8 to build the docs

* [FIX] darglint

* [REQ] Use torch>=2 for built-in functorch

* [REF] Implement Jacobian via base class, remove print statements

* [FIX] Documentation and rename

---------

Co-authored-by: Felix Dangel <felix.dangel@vectorinstitute.ai>
f-dangel authored Jul 18, 2023
1 parent eac0711 commit 2ed6c6c
Showing 16 changed files with 284 additions and 56 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint-black.yaml
@@ -16,10 +16,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
4 changes: 2 additions & 2 deletions .github/workflows/lint-darglint.yaml
@@ -16,10 +16,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
4 changes: 2 additions & 2 deletions .github/workflows/lint-flake8.yaml
@@ -15,10 +15,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
4 changes: 2 additions & 2 deletions .github/workflows/lint-isort.yaml
@@ -16,10 +16,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
4 changes: 2 additions & 2 deletions .github/workflows/lint-pydocstyle.yaml
@@ -17,10 +17,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -18,7 +18,7 @@ jobs:
USING_COVERAGE: '3.8'
strategy:
matrix:
python-version: ["3.7", "3.8"]
python-version: ["3.8"]
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -7,7 +7,7 @@ sphinx:
configuration: docs/rtd/conf.py

python:
version: 3.7
version: 3.8
install:
- method: pip
path: .
3 changes: 2 additions & 1 deletion README.md
@@ -1,7 +1,7 @@
# <img alt="Logo" src="./docs/rtd/assets/logo.svg" height="90"> scipy linear operators of deep learning matrices in PyTorch

[![Python
3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)
3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/)
![tests](https://github.com/f-dangel/curvature-linear-operators/actions/workflows/test.yaml/badge.svg)
[![Coveralls](https://coveralls.io/repos/github/f-dangel/curvlinops/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/curvlinops)

@@ -13,6 +13,7 @@ for deep learning matrices, such as
- the Fisher/generalized Gauss-Newton (GGN)
- the Monte-Carlo approximated Fisher
- the uncentered gradient covariance (aka empirical Fisher)
- the output-parameter Jacobian of a neural net

Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU.
The library supports defining these matrices not only on a mini-batch, but
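
To make the new README entry concrete, here is a hedged usage sketch of the output-parameter Jacobian operator added in this commit. The exact `JacobianLinearOperator` constructor signature is an assumption inferred from the other operators (it presumably needs no loss function); treat this as an illustration rather than the library's documented API.

```python
import torch
from numpy.random import rand
from torch.nn import Linear, ReLU, Sequential

from curvlinops import JacobianLinearOperator

torch.manual_seed(0)
# toy regression setup: 8 samples, 5 features, 3 outputs
X, y = torch.rand(8, 5), torch.rand(8, 3)
data = [(X, y)]
model = Sequential(Linear(5, 4), ReLU(), Linear(4, 3))
params = [p for p in model.parameters() if p.requires_grad]

# assumed signature: the Jacobian needs no loss function
J = JacobianLinearOperator(model, params, data)

# SciPy-style matrix-vector product with a NumPy vector of size J.shape[1]
v = rand(J.shape[1]).astype(J.dtype)
Jv = J @ v  # one entry per flattened model output, shape (J.shape[0],)
```

Since the operator inherits from `scipy.sparse.linalg.LinearOperator`, it can be handed to any SciPy routine that only needs matrix-vector products.
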
2 changes: 1 addition & 1 deletion black.toml
@@ -1,6 +1,6 @@
[tool.black]
line-length = 88
target-version = ['py36', 'py37', 'py38']
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'
exclude = '''
(
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
@@ -5,6 +5,7 @@
from curvlinops.gradient_moments import EFLinearOperator
from curvlinops.hessian import HessianLinearOperator
from curvlinops.inverse import CGInverseLinearOperator, NeumannInverseLinearOperator
from curvlinops.jacobian import JacobianLinearOperator
from curvlinops.papyan2020traces.spectrum import (
LanczosApproximateLogSpectrumCached,
LanczosApproximateSpectrumCached,
@@ -18,6 +19,7 @@
"GGNLinearOperator",
"EFLinearOperator",
"FisherMCLinearOperator",
"JacobianLinearOperator",
"CGInverseLinearOperator",
"NeumannInverseLinearOperator",
"SubmatrixLinearOperator",
85 changes: 56 additions & 29 deletions curvlinops/_base.py
@@ -1,6 +1,6 @@
"""Contains functionality to analyze Hessian & GGN via matrix-free multiplication."""

from typing import Callable, Iterable, List, Tuple
from typing import Callable, Iterable, List, Optional, Tuple, Union

from backpack.utils.convert_parameters import vector_to_parameter_list
from numpy import (
@@ -14,31 +14,31 @@
)
from numpy.random import rand
from scipy.sparse.linalg import LinearOperator
from torch import Tensor
from torch import Tensor, cat
from torch import device as torch_device
from torch import from_numpy, tensor, zeros_like
from torch.autograd import grad
from torch.nn import Module, Parameter
from torch.nn.utils import parameters_to_vector
from tqdm import tqdm


class _LinearOperator(LinearOperator):
"""Base class for linear operators of DNN curvature matrices.
"""Base class for linear operators of DNN matrices.
Can be used with SciPy.
"""

def __init__(
self,
model_func: Callable[[Tensor], Tensor],
loss_func: Callable[[Tensor, Tensor], Tensor],
loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
params: List[Parameter],
data: Iterable[Tuple[Tensor, Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
shape: Optional[Tuple[int, int]] = None,
):
"""Linear operator for DNN curvature matrices.
"""Linear operator for DNN matrices.
Note:
f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch
@@ -49,10 +49,13 @@ def __init__(
model_func: A function that maps the mini-batch input X to predictions.
Could be a PyTorch module representing a neural network.
loss_func: Loss function criterion. Maps predictions and mini-batch labels
to a scalar value.
to a scalar value. If ``None``, there is no loss function and the
represented matrix is independent of the loss function.
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
shape: Shape of the represented matrix. If ``None``, the shape is assumed to be
``(D, D)``, where ``D`` is the total number of parameters.
progressbar: Show a progressbar during matrix-multiplication.
Default: ``False``.
check_deterministic: Probe that model and data are deterministic, i.e.
@@ -64,8 +67,10 @@ def __init__(
Raises:
RuntimeError: If the check for deterministic behavior fails.
"""
dim = sum(p.numel() for p in params)
super().__init__(shape=(dim, dim), dtype=float32)
if shape is None:
dim = sum(p.numel() for p in params)
shape = (dim, dim)
super().__init__(shape=shape, dtype=float32)

self._params = params
self._model_func = model_func
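
The two new arguments are what allow loss-independent, non-square operators such as the Jacobian to reuse this base class. Below is a hypothetical subclass sketch showing only the constructor call; the class name and internals are assumptions, not the actual `JacobianLinearOperator` implementation, and a real subclass must additionally implement the matrix-vector product.

```python
from typing import Callable, Iterable, List, Tuple

from torch import Tensor
from torch.nn import Parameter

from curvlinops._base import _LinearOperator


class _ToyLossFreeOperator(_LinearOperator):
    """Hypothetical sketch: a loss-independent, non-square operator."""

    def __init__(
        self,
        model_func: Callable[[Tensor], Tensor],
        params: List[Parameter],
        data: Iterable[Tuple[Tensor, Tensor]],
    ):
        num_params = sum(p.numel() for p in params)
        # rows: total number of model outputs over the data set
        num_outputs = sum(model_func(X).numel() for X, _ in data)
        super().__init__(
            model_func,
            None,  # no loss function: the represented matrix ignores the loss
            params,
            data,
            check_deterministic=False,  # would need the matvec, omitted here
            shape=(num_outputs, num_params),
        )
```
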
@@ -129,22 +134,37 @@ def _check_deterministic(self):
- Two independent loss/gradient computations yield different results
Note:
Deterministic checks are performed on CPU. We noticed that even when it
passes on CPU, it can fail on GPU; probably due to non-deterministic
Deterministic checks should be performed on CPU. We noticed that even when
it passes on CPU, it can fail on GPU; probably due to non-deterministic
operations.
Raises:
RuntimeError: If non-deterministic behavior is detected.
"""
print("Performing deterministic checks")
v = rand(self.shape[1]).astype(self.dtype)
mat_v1 = self @ v
mat_v2 = self @ v

rtol, atol = 5e-5, 1e-6
if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
self.print_nonclose(mat_v1, mat_v2, rtol, atol)
raise RuntimeError("Check for deterministic matvec failed.")

if self._loss_func is None:
return

# only carried out if there is a loss function
grad1, loss1 = self.gradient_and_loss()
grad1, loss1 = parameters_to_vector(grad1).cpu().numpy(), loss1.cpu().numpy()
grad1, loss1 = (
self.flatten_and_concatenate(grad1).cpu().numpy(),
loss1.cpu().numpy(),
)

grad2, loss2 = self.gradient_and_loss()
grad2, loss2 = parameters_to_vector(grad2).cpu().numpy(), loss2.cpu().numpy()

rtol, atol = 5e-5, 1e-6
grad2, loss2 = (
self.flatten_and_concatenate(grad2).cpu().numpy(),
loss2.cpu().numpy(),
)

if not allclose(loss1, loss2, rtol=rtol, atol=atol):
self.print_nonclose(loss1, loss2, rtol, atol)
@@ -154,16 +174,6 @@ def _check_deterministic(self):
self.print_nonclose(grad1, grad2, rtol, atol)
raise RuntimeError("Check for deterministic gradient failed.")

v = rand(self.shape[0]).astype(self.dtype)
mat_v1 = self @ v
mat_v2 = self @ v

if not allclose(mat_v1, mat_v2, rtol=rtol, atol=atol):
self.print_nonclose(mat_v1, mat_v2, rtol, atol)
raise RuntimeError("Check for deterministic matvec failed.")

print("Deterministic checks passed")

@staticmethod
def print_nonclose(array1: ndarray, array2: ndarray, rtol: float, atol: float):
"""Check if the two arrays are element-wise equal within a tolerance and print
@@ -245,8 +255,7 @@ def _preprocess(self, x: ndarray) -> List[Tensor]:
x_torch = from_numpy(x).to(self._device)
return vector_to_parameter_list(x_torch, self._params)

@staticmethod
def _postprocess(x_list: List[Tensor]) -> ndarray:
def _postprocess(self, x_list: List[Tensor]) -> ndarray:
"""Convert torch list format to flat numpy array.
Args:
@@ -255,7 +264,7 @@ def _postprocess(x_list: List[Tensor]) -> ndarray:
Returns:
Flat vector.
"""
return parameters_to_vector([x.contiguous() for x in x_list]).cpu().numpy()
return self.flatten_and_concatenate(x_list).cpu().numpy()

def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]:
"""Yield batches of the data set, loaded to the correct device.
Expand All @@ -279,7 +288,13 @@ def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]:
Returns:
Gradient and loss on the data set.
Raises:
ValueError: If there is no loss function.
"""
if self._loss_func is None:
raise ValueError("No loss function specified.")

total_loss = tensor([0.0], device=self._device)
total_grad = [zeros_like(p) for p in self._params]
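
With this guard, requesting the gradient from a loss-free operator fails loudly. A short hedged sketch, reusing the assumed `JacobianLinearOperator` from the README example above (assuming it stores no loss function internally):

```python
# assumption: `J` was constructed without a loss function, as sketched earlier
try:
    J.gradient_and_loss()
except ValueError as err:
    print(err)  # -> "No loss function specified."
```
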

@@ -317,3 +332,15 @@ def _get_normalization_factor(self, X: Tensor, y: Tensor) -> float:
return X.shape[0] / self._N_data
else:
raise ValueError("Loss must have reduction 'mean' or 'sum'.")

@staticmethod
def flatten_and_concatenate(tensors: List[Tensor]) -> Tensor:
"""Flatten then concatenate all tensors in a list.
Args:
tensors: List of tensors.
Returns:
Concatenated flattened tensors.
"""
return cat([t.flatten() for t in tensors])
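
For reference, a minimal usage sketch of the new helper, assuming the module layout shown above (`curvlinops/_base.py`). Unlike `parameters_to_vector`, which the old `_postprocess` had to feed with contiguous tensors, `flatten_and_concatenate` also accepts non-contiguous inputs:

```python
from torch import allclose, rand
from torch.nn.utils import parameters_to_vector

from curvlinops._base import _LinearOperator

tensors = [rand(2, 3), rand(4)]
flat = _LinearOperator.flatten_and_concatenate(tensors)
assert flat.shape == (10,)
# for contiguous inputs this agrees with parameters_to_vector
assert allclose(flat, parameters_to_vector(tensors))
```
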