Skip to content

Commit

Permalink
Add test for one datum KFAC EF exactness
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Oct 31, 2023
1 parent b3d7463 commit 8d3e44c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,17 @@ def kfac_expand_exact_one_datum_case(
"""
case = request.param
yield initialize_case(case)


@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES)
def kfac_ef_exact_one_datum_case(
request,
) -> Tuple[Module, MSELoss, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]:
"""Prepare a test case with one datum for which KFAC with empirical gradients equals the EF.
Yields:
A neural network, the mean-squared error function, a list of parameters, and
a data set.
"""
case = request.param
yield initialize_case(case)
20 changes: 20 additions & 0 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from curvlinops.examples.utils import report_nonclose
from curvlinops.ggn import GGNLinearOperator
from curvlinops.gradient_moments import EFLinearOperator
from curvlinops.kfac import KFACLinearOperator


Expand Down Expand Up @@ -68,3 +69,22 @@ def test_kfac_one_datum(
rtol = {"sum": 3e-2, "mean": 3e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)


def test_kfac_ef_one_datum(
kfac_ef_exact_one_datum_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_ef_exact_one_datum_case

ef_blocks = [] # list of per-parameter EFs
for param in params:
ef = EFLinearOperator(model, loss_func, [param], data)
ef_blocks.append(ef @ eye(ef.shape[1]))
ef = block_diag(*ef_blocks)

kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="empirical")
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ef, kfac_mat)

0 comments on commit 8d3e44c

Please sign in to comment.