Skip to content

Commit

Permalink
Merge branch 'kfac-inv' of github.com:f-dangel/curvlinops into kfac-inv
Browse files: browse the repository at this point in the history
  • Loading branch information
runame committed Feb 6, 2024
2 parents 72cab2f + 41a6cf9 commit 49014e9
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 7 deletions.
9 changes: 7 additions & 2 deletions curvlinops/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
progressbar: bool = False,
check_deterministic: bool = True,
shape: Optional[Tuple[int, int]] = None,
num_data: Optional[int] = None,
):
"""Linear operator for DNN matrices.
Expand All @@ -64,6 +65,8 @@ def __init__(
safeguard, only turn it off if you know what you are doing.
shape: Shape of the represented matrix. If ``None`` assumes ``(D, D)``
where ``D`` is the total number of parameters
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
Raises:
RuntimeError: If the check for deterministic behavior fails.
Expand All @@ -80,8 +83,10 @@ def __init__(
self._device = self._infer_device(self._params)
self._progressbar = progressbar

- self._N_data = sum(
- X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data")
+ self._N_data = (
+ sum(X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data"))
+ if num_data is None
+ else num_data
)

if check_deterministic:
Expand Down
6 changes: 5 additions & 1 deletion curvlinops/fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from math import sqrt
- from typing import Callable, Iterable, List, Tuple, Union
+ from typing import Callable, Iterable, List, Optional, Tuple, Union

from numpy import ndarray
from torch import (
Expand Down Expand Up @@ -121,6 +121,7 @@ def __init__(
check_deterministic: bool = True,
seed: int = 2147483647,
mc_samples: int = 1,
num_data: Optional[int] = None,
):
"""Linear operator for the MC approximation of the Fisher.
Expand Down Expand Up @@ -150,6 +151,8 @@ def __init__(
draw samples at the beginning of each matrix-vector product.
Default: ``2147483647``
mc_samples: Number of samples to use. Default: ``1``.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
Raises:
NotImplementedError: If the loss function differs from ``MSELoss`` or
Expand All @@ -169,6 +172,7 @@ def __init__(
data,
progressbar=progressbar,
check_deterministic=check_deterministic,
num_data=num_data,
)

def _matvec(self, x: ndarray) -> ndarray:
Expand Down
14 changes: 11 additions & 3 deletions curvlinops/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

- from typing import Callable, Iterable, List, Tuple
+ from typing import Callable, Iterable, List, Optional, Tuple

from backpack.hessianfree.lop import transposed_jacobian_vector_product as vjp
from backpack.hessianfree.rop import jacobian_vector_product as jvp
Expand All @@ -26,6 +26,7 @@ def __init__(
data: Iterable[Tuple[Tensor, Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
num_data: Optional[int] = None,
):
r"""Linear operator for the Jacobian as SciPy linear operator.
Expand All @@ -52,8 +53,10 @@ def __init__(
data: Iterable of batched input-target pairs.
progressbar: Show progress bar.
check_deterministic: Check if model and data are deterministic.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
"""
- num_data = sum(t.shape[0] for t, _ in data)
+ num_data = sum(t.shape[0] for t, _ in data) if num_data is None else num_data
x = next(iter(data))[0].to(self._infer_device(params))
num_outputs = model_func(x).shape[1:].numel()
num_params = sum(p.numel() for p in params)
Expand All @@ -65,6 +68,7 @@ def __init__(
progressbar=progressbar,
check_deterministic=check_deterministic,
shape=(num_data * num_outputs, num_params),
num_data=num_data,
)

def _check_deterministic(self):
Expand Down Expand Up @@ -151,6 +155,7 @@ def __init__(
data: Iterable[Tuple[Tensor, Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
num_data: Optional[int] = None,
):
r"""Linear operator for the transpose Jacobian as SciPy linear operator.
Expand All @@ -177,8 +182,10 @@ def __init__(
data: Iterable of batched input-target pairs.
progressbar: Show progress bar.
check_deterministic: Check if model and data are deterministic.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
"""
- num_data = sum(t.shape[0] for t, _ in data)
+ num_data = sum(t.shape[0] for t, _ in data) if num_data is None else num_data
x = next(iter(data))[0].to(self._infer_device(params))
num_outputs = model_func(x).shape[1:].numel()
num_params = sum(p.numel() for p in params)
Expand All @@ -190,6 +197,7 @@ def __init__(
progressbar=progressbar,
check_deterministic=check_deterministic,
shape=(num_params, num_data * num_outputs),
num_data=num_data,
)

def _check_deterministic(self):
Expand Down
6 changes: 5 additions & 1 deletion curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from functools import partial
from math import sqrt
- from typing import Dict, Iterable, List, Set, Tuple, Union
+ from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

from einops import rearrange, reduce
from numpy import ndarray
Expand Down Expand Up @@ -107,6 +107,7 @@ def __init__(
kfac_approx: str = "expand",
loss_average: Union[None, str] = "batch",
separate_weight_and_bias: bool = True,
num_data: Optional[int] = None,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
Expand Down Expand Up @@ -165,6 +166,8 @@ def __init__(
consistently with the loss and the gradient. Default: ``"batch"``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
Raises:
ValueError: If the loss function is not supported.
Expand Down Expand Up @@ -241,6 +244,7 @@ def __init__(
progressbar=progressbar,
check_deterministic=check_deterministic,
shape=shape,
num_data=num_data,
)

def _matvec(self, x: ndarray) -> ndarray:
Expand Down

0 comments on commit 49014e9

Please sign in to comment.