Skip to content

Commit

Permalink
[ADD] Option to specify number of data points
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Feb 4, 2024
1 parent 9caa6bb commit 4d1c85b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
6 changes: 5 additions & 1 deletion curvlinops/fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from math import sqrt
from typing import Callable, Iterable, List, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union

from numpy import ndarray
from torch import (
Expand Down Expand Up @@ -120,6 +120,7 @@ def __init__(
check_deterministic: bool = True,
seed: int = 2147483647,
mc_samples: int = 1,
num_data: Optional[int] = None,
):
"""Linear operator for the MC approximation of the Fisher.
Expand Down Expand Up @@ -149,6 +150,8 @@ def __init__(
draw samples at the beginning of each matrix-vector product.
Default: ``2147483647``
mc_samples: Number of samples to use. Default: ``1``.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
Raises:
NotImplementedError: If the loss function differs from ``MSELoss`` or
Expand All @@ -168,6 +171,7 @@ def __init__(
data,
progressbar=progressbar,
check_deterministic=check_deterministic,
num_data=num_data,
)

def _matvec(self, x: ndarray) -> ndarray:
Expand Down
14 changes: 11 additions & 3 deletions curvlinops/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Callable, Iterable, List, Tuple
from typing import Callable, Iterable, List, Optional, Tuple

from backpack.hessianfree.lop import transposed_jacobian_vector_product as vjp
from backpack.hessianfree.rop import jacobian_vector_product as jvp
Expand All @@ -26,6 +26,7 @@ def __init__(
data: Iterable[Tuple[Tensor, Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
num_data: Optional[int] = None,
):
r"""Linear operator for the Jacobian as SciPy linear operator.
Expand All @@ -52,8 +53,10 @@ def __init__(
data: Iterable of batched input-target pairs.
progressbar: Show progress bar.
check_deterministic: Check if model and data are deterministic.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
"""
num_data = sum(t.shape[0] for t, _ in data)
num_data = sum(t.shape[0] for t, _ in data) if num_data is None else num_data
x = next(iter(data))[0].to(self._infer_device(params))
num_outputs = model_func(x).shape[1:].numel()
num_params = sum(p.numel() for p in params)
Expand All @@ -65,6 +68,7 @@ def __init__(
progressbar=progressbar,
check_deterministic=check_deterministic,
shape=(num_data * num_outputs, num_params),
num_data=num_data,
)

def _check_deterministic(self):
Expand Down Expand Up @@ -151,6 +155,7 @@ def __init__(
data: Iterable[Tuple[Tensor, Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
num_data: Optional[int] = None,
):
r"""Linear operator for the transpose Jacobian as SciPy linear operator.
Expand All @@ -177,8 +182,10 @@ def __init__(
data: Iterable of batched input-target pairs.
progressbar: Show progress bar.
check_deterministic: Check if model and data are deterministic.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
"""
num_data = sum(t.shape[0] for t, _ in data)
num_data = sum(t.shape[0] for t, _ in data) if num_data is None else num_data
x = next(iter(data))[0].to(self._infer_device(params))
num_outputs = model_func(x).shape[1:].numel()
num_params = sum(p.numel() for p in params)
Expand All @@ -190,6 +197,7 @@ def __init__(
progressbar=progressbar,
check_deterministic=check_deterministic,
shape=(num_params, num_data * num_outputs),
num_data=num_data,
)

def _check_deterministic(self):
Expand Down
6 changes: 5 additions & 1 deletion curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from functools import partial
from math import sqrt
from typing import Dict, Iterable, List, Set, Tuple, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

from einops import rearrange, reduce
from numpy import ndarray
Expand Down Expand Up @@ -107,6 +107,7 @@ def __init__(
kfac_approx: str = "expand",
loss_average: Union[None, str] = "batch",
separate_weight_and_bias: bool = True,
num_data: Optional[int] = None,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
Expand Down Expand Up @@ -165,6 +166,8 @@ def __init__(
consistently with the loss and the gradient. Default: ``"batch"``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
Raises:
ValueError: If the loss function is not supported.
Expand Down Expand Up @@ -241,6 +244,7 @@ def __init__(
progressbar=progressbar,
check_deterministic=check_deterministic,
shape=shape,
num_data=num_data,
)

def _matvec(self, x: ndarray) -> ndarray:
Expand Down

0 comments on commit 4d1c85b

Please sign in to comment.