Skip to content

Commit

Permalink
Tweak tolerances for KFAC-MC tests on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Oct 27, 2024
1 parent fc8aa1e commit 1f44ee3
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_kfac_mc(
).to_scipy()
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
atol = {"sum": 5e-1, "mean": 1e-2}[loss_func.reduction]
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
Expand Down Expand Up @@ -268,7 +268,7 @@ def test_kfac_mc_weight_sharing(
).to_scipy()
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
atol = {"sum": 5e-1, "mean": 1e-2}[loss_func.reduction]
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
Expand Down Expand Up @@ -371,10 +371,10 @@ def test_kfac_inplace_activations(dev: device):
dev: The device to run the test on.
"""
manual_seed(0)
model = Sequential(Linear(6, 3), ReLU(inplace=True), Linear(3, 2)).to(dev)
model = Sequential(Linear(4, 3), ReLU(inplace=True), Linear(3, 2)).to(dev)
loss_func = MSELoss().to(dev)
batch_size = 1
data = [(rand(batch_size, 6), regression_targets((batch_size, 2)))]
data = [(rand(batch_size, 4), regression_targets((batch_size, 2)))]
params = list(model.parameters())

# 1) compare KFAC and GGN
Expand All @@ -385,8 +385,8 @@ def test_kfac_inplace_activations(dev: device):
).to_scipy()
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 2e-3}[loss_func.reduction]
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]
atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
rtol = {"sum": 2e-2, "mean": 4e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)

Expand Down

0 comments on commit 1f44ee3

Please sign in to comment.