From 8d3e44c71a632cbff93dd8a4aaa7ac34716f796f Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 31 Oct 2023 21:19:42 +0100 Subject: [PATCH] Add test for one datum KFAC EF exactness --- test/conftest.py | 14 ++++++++++++++ test/test_kfac.py | 20 ++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/test/conftest.py b/test/conftest.py index 10f0617..9097d4c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -86,3 +86,17 @@ def kfac_expand_exact_one_datum_case( """ case = request.param yield initialize_case(case) + + +@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES) +def kfac_ef_exact_one_datum_case( + request, +) -> Tuple[Module, MSELoss, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]: + """Prepare a test case with one datum for which KFAC with empirical gradients equals the EF. + + Yields: + A neural network, the mean-squared error function, a list of parameters, and + a data set. + """ + case = request.param + yield initialize_case(case) diff --git a/test/test_kfac.py b/test/test_kfac.py index b384646..3dc04aa 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -10,6 +10,7 @@ from curvlinops.examples.utils import report_nonclose from curvlinops.ggn import GGNLinearOperator +from curvlinops.gradient_moments import EFLinearOperator from curvlinops.kfac import KFACLinearOperator @@ -68,3 +69,22 @@ def test_kfac_one_datum( rtol = {"sum": 3e-2, "mean": 3e-2}[loss_func.reduction] report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol) + + +def test_kfac_ef_one_datum( + kfac_ef_exact_one_datum_case: Tuple[ + Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] + ] +): + model, loss_func, params, data = kfac_ef_exact_one_datum_case + + ef_blocks = [] # list of per-parameter EFs + for param in params: + ef = EFLinearOperator(model, loss_func, [param], data) + ef_blocks.append(ef @ eye(ef.shape[1])) + ef = block_diag(*ef_blocks) + + kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="empirical") + kfac_mat = kfac @ eye(kfac.shape[1]) + + report_nonclose(ef, kfac_mat)