Skip to content

Commit

Permalink
[ADD] Minimal prototype for KFAC (#43)
Browse files Browse the repository at this point in the history
* [ADD] Prototype for KFAC linear operator

* [DOC] Progress on documentation

* [DOC] Describe KFAC and its limitations

* [FIX] Name of fixture

* [FIX] Darglint

* [FIX] Darglint

See terrencepreilly/darglint#53

* [DOC] Show supported layers in error message

* [DOC] Improve correctness
  • Loading branch information
f-dangel authored Oct 30, 2023
1 parent c0f66e1 commit a331b05
Show file tree
Hide file tree
Showing 9 changed files with 479 additions and 3 deletions.
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from curvlinops.hessian import HessianLinearOperator
from curvlinops.inverse import CGInverseLinearOperator, NeumannInverseLinearOperator
from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator
from curvlinops.kfac import KFACLinearOperator
from curvlinops.papyan2020traces.spectrum import (
LanczosApproximateLogSpectrumCached,
LanczosApproximateSpectrumCached,
Expand All @@ -22,6 +23,7 @@
"GGNLinearOperator",
"EFLinearOperator",
"FisherMCLinearOperator",
"KFACLinearOperator",
"JacobianLinearOperator",
"TransposedJacobianLinearOperator",
"CGInverseLinearOperator",
Expand Down
8 changes: 8 additions & 0 deletions curvlinops/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Contains functionality to analyze Hessian & GGN via matrix-free multiplication."""

from typing import Callable, Iterable, List, Optional, Tuple, Union
from warnings import warn

from backpack.utils.convert_parameters import vector_to_parameter_list
from numpy import (
Expand Down Expand Up @@ -254,6 +255,13 @@ def _preprocess(self, x: ndarray) -> List[Tensor]:
Returns:
Vector in list format.
"""
if x.dtype != self.dtype:
warn(
f"Input vector is {x.dtype}, while linear operator is {self.dtype}. "
+ f"Converting to {self.dtype}."
)
x = x.astype(self.dtype)

x_torch = from_numpy(x).to(self._device)
return vector_to_parameter_list(x_torch, self._params)

Expand Down
4 changes: 3 additions & 1 deletion curvlinops/examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ def report_nonclose(
else:
for a1, a2 in zip(array1.flatten(), array2.flatten()):
if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan):
print(f"{a1} ≠ {a2}")
print(f"{a1} ≠ {a2} (ratio {a1 / a2:.5f})")
print(f"Max: {array1.max():.5f}, {array2.max():.5f}")
print(f"Min: {array1.min():.5f}, {array2.min():.5f}")
raise ValueError("Compared arrays don't match.")
Loading

0 comments on commit a331b05

Please sign in to comment.