Skip to content

Commit

Permalink
[CI | FIX] Improve and run Fisher-MC tests on main
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Sep 23, 2024
1 parent e30f713 commit 3084ada
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ jobs:
python -m pip install --upgrade pip
make install-test
- name: Run test
if: contains('refs/heads/master refs/heads/development', github.ref)
if: contains('refs/heads/main', github.ref)
run: |
make test
- name: Run test-light
if: contains('refs/heads/master refs/heads/development', github.ref) != 1
if: contains('refs/heads/main', github.ref) != 1
run: |
make test-light
Expand Down
4 changes: 4 additions & 0 deletions curvlinops/examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ def report_nonclose(
if allclose(array1, array2, rtol=rtol, atol=atol, equal_nan=equal_nan):
print("Compared arrays match.")
else:
nonclose_entries = 0
for a1, a2 in zip(array1.flatten(), array2.flatten()):
if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan):
print(f"{a1}{a2} (ratio {a1 / a2:.5f})")
nonclose_entries += 1
print(f"Max: {array1.max():.5f}, {array2.max():.5f}")
print(f"Min: {array1.min():.5f}, {array2.min():.5f}")
print(f"Nonclose entries: {nonclose_entries} / {array1.size}")
print(f"rtol = {rtol}, atol= {atol}")
raise ValueError("Compared arrays don't match.")
20 changes: 10 additions & 10 deletions test/test_fisher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Contains tests for ``curvlinops/fisher.py``."""

from collections.abc import MutableMapping
from contextlib import suppress
from contextlib import redirect_stdout, suppress

from numpy import random, zeros_like
from pytest import mark, raises
Expand All @@ -10,11 +10,11 @@
from curvlinops.examples.functorch import functorch_ggn
from curvlinops.examples.utils import report_nonclose

MAX_REPEATS_MC_SAMPLES = [(1_000_000, 1), (10_000, 100)]
MAX_REPEATS_MC_SAMPLES = [(10_000, 1), (100, 100)]
MAX_REPEATS_MC_SAMPLES_IDS = [
f"max_repeats={n}-mc_samples={m}" for (n, m) in MAX_REPEATS_MC_SAMPLES
]
CHECK_EVERY = 1_000
CHECK_EVERY = 100


@mark.montecarlo
Expand Down Expand Up @@ -58,17 +58,17 @@ def test_LinearOperator_matvec_expectation(
Gx = G_functorch @ x

Fx = zeros_like(x)
atol, rtol = 1e-5, 1e-1
atol = 5e-3 * max(abs(Gx))
rtol = 1e-1

for m in range(max_repeats):
Fx += F @ x
F._seed += 1

total_samples = (m + 1) * mc_samples
if total_samples % CHECK_EVERY == 0:
with suppress(ValueError):
with redirect_stdout(None), suppress(ValueError):
report_nonclose(Fx / (m + 1), Gx, rtol=rtol, atol=atol)
print(f"Converged after {m} iterations")
return

report_nonclose(Fx / max_repeats, Gx, rtol=rtol, atol=atol)
Expand Down Expand Up @@ -104,17 +104,17 @@ def test_LinearOperator_matmat_expectation(
GX = G_functorch @ X

FX = zeros_like(X)
atol, rtol = 1e-5, 1e-1
atol = 5e-3 * max(abs(GX.flatten()))
rtol = 1.5e-1

for m in range(max_repeats):
FX += F @ X
F._seed += 1

total_samples = (m + 1) * mc_samples
if total_samples % CHECK_EVERY == 0:
with suppress(ValueError):
with redirect_stdout(None), suppress(ValueError):
report_nonclose(FX / (m + 1), GX, rtol=rtol, atol=atol)
print(f"Converged after {m} iterations")
return

report_nonclose(FX, GX, rtol=rtol, atol=atol)
report_nonclose(FX / max_repeats, GX, rtol=rtol, atol=atol)

0 comments on commit 3084ada

Please sign in to comment.