From 725c413ace484ae1aae1620a82f85748e7e97fc6 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 17 Sep 2024 00:09:57 -0400 Subject: [PATCH] Fix tests for FOOF+eigenvalue correction --- test/test_kfac.py | 51 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/test/test_kfac.py b/test/test_kfac.py index 98426d6..02861b5 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -511,14 +511,22 @@ def test_multi_dim_output( # KFAC for deep linear network with 4d input and output params = list(model.parameters()) - kfac = KFACLinearOperator( - model, - loss_func, - params, - data, - fisher_type=fisher_type, - correct_eigenvalues=correct_eigenvalues, - ) + context = ( + raises(ValueError, match="eigenvalues") + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY + else nullcontext() + ) # EKFAC for FOOF is currently not supported + with context: + kfac = KFACLinearOperator( + model, + loss_func, + params, + data, + fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, + ) + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY: + return kfac_mat = kfac @ eye(kfac.shape[1]) # KFAC for deep linear network with 4d input and equivalent 2d output @@ -598,15 +606,26 @@ def test_expand_setting_scaling( params = list(model.parameters()) # KFAC with sum reduction + params = list(model.parameters()) loss_func = loss(reduction="sum").to(dev) - kfac_sum = KFACLinearOperator( - model, - loss_func, - params, - data, - fisher_type=fisher_type, - correct_eigenvalues=correct_eigenvalues, - ) + + context = ( + raises(ValueError, match="eigenvalues") + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY + else nullcontext() + ) # EKFAC for FOOF is currently not supported + with context: + kfac_sum = KFACLinearOperator( + model, + loss_func, + params, + data, + fisher_type=fisher_type, + correct_eigenvalues=correct_eigenvalues, + ) + if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY: + return + # FOOF does not scale the gradient covariances, even when using a mean reduction if fisher_type != FisherType.FORWARD_ONLY: # Simulate a mean reduction by manually scaling the gradient covariances