Skip to content

Commit

Permalink
Fix tests for FOOF+eigenvalue correction
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Sep 17, 2024
1 parent b309300 commit 725c413
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,14 +511,22 @@ def test_multi_dim_output(

# KFAC for deep linear network with 4d input and output
params = list(model.parameters())
kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type=fisher_type,
correct_eigenvalues=correct_eigenvalues,
)
context = (
raises(ValueError, match="eigenvalues")
if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY
else nullcontext()
) # EKFAC for FOOF is currently not supported
with context:
kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type=fisher_type,
correct_eigenvalues=correct_eigenvalues,
)
if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY:
return
kfac_mat = kfac @ eye(kfac.shape[1])

# KFAC for deep linear network with 4d input and equivalent 2d output
Expand Down Expand Up @@ -598,15 +606,26 @@ def test_expand_setting_scaling(
params = list(model.parameters())

# KFAC with sum reduction
params = list(model.parameters())
loss_func = loss(reduction="sum").to(dev)
kfac_sum = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type=fisher_type,
correct_eigenvalues=correct_eigenvalues,
)

context = (
raises(ValueError, match="eigenvalues")
if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY
else nullcontext()
) # EKFAC for FOOF is currently not supported
with context:
kfac_sum = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type=fisher_type,
correct_eigenvalues=correct_eigenvalues,
)
if correct_eigenvalues and fisher_type == FisherType.FORWARD_ONLY:
return

# FOOF does not scale the gradient covariances, even when using a mean reduction
if fisher_type != FisherType.FORWARD_ONLY:
# Simulate a mean reduction by manually scaling the gradient covariances
Expand Down

0 comments on commit 725c413

Please sign in to comment.