From 3084adaa72768e3f3339e17dacd07fc7a861be0c Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Mon, 23 Sep 2024 10:38:31 -0400 Subject: [PATCH] [CI | FIX] Improve and run Fisher-MC tests on main --- .github/workflows/test.yaml | 4 ++-- curvlinops/examples/utils.py | 4 ++++ test/test_fisher.py | 20 ++++++++++---------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2707c83d..a49757d7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,11 +29,11 @@ jobs: python -m pip install --upgrade pip make install-test - name: Run test - if: contains('refs/heads/master refs/heads/development', github.ref) + if: contains('refs/heads/main', github.ref) run: | make test - name: Run test-light - if: contains('refs/heads/master refs/heads/development', github.ref) != 1 + if: contains('refs/heads/main', github.ref) != 1 run: | make test-light diff --git a/curvlinops/examples/utils.py b/curvlinops/examples/utils.py index e22fe4af..08241c2a 100644 --- a/curvlinops/examples/utils.py +++ b/curvlinops/examples/utils.py @@ -31,9 +31,13 @@ def report_nonclose( if allclose(array1, array2, rtol=rtol, atol=atol, equal_nan=equal_nan): print("Compared arrays match.") else: + nonclose_entries = 0 for a1, a2 in zip(array1.flatten(), array2.flatten()): if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan): print(f"{a1} ≠ {a2} (ratio {a1 / a2:.5f})") + nonclose_entries += 1 print(f"Max: {array1.max():.5f}, {array2.max():.5f}") print(f"Min: {array1.min():.5f}, {array2.min():.5f}") + print(f"Nonclose entries: {nonclose_entries} / {array1.size}") + print(f"rtol = {rtol}, atol= {atol}") raise ValueError("Compared arrays don't match.") diff --git a/test/test_fisher.py b/test/test_fisher.py index 4fa598e4..3bfa3326 100644 --- a/test/test_fisher.py +++ b/test/test_fisher.py @@ -1,7 +1,7 @@ """Contains tests for ``curvlinops/fisher.py``.""" from collections.abc import MutableMapping -from contextlib import suppress +from contextlib import redirect_stdout, suppress from numpy import random, zeros_like from pytest import mark, raises @@ -10,11 +10,11 @@ from curvlinops.examples.functorch import functorch_ggn from curvlinops.examples.utils import report_nonclose -MAX_REPEATS_MC_SAMPLES = [(1_000_000, 1), (10_000, 100)] +MAX_REPEATS_MC_SAMPLES = [(10_000, 1), (100, 100)] MAX_REPEATS_MC_SAMPLES_IDS = [ f"max_repeats={n}-mc_samples={m}" for (n, m) in MAX_REPEATS_MC_SAMPLES ] -CHECK_EVERY = 1_000 +CHECK_EVERY = 100 @mark.montecarlo @@ -58,7 +58,8 @@ def test_LinearOperator_matvec_expectation( Gx = G_functorch @ x Fx = zeros_like(x) - atol, rtol = 1e-5, 1e-1 + atol = 5e-3 * max(abs(Gx)) + rtol = 1e-1 for m in range(max_repeats): Fx += F @ x @@ -66,9 +67,8 @@ def test_LinearOperator_matvec_expectation( total_samples = (m + 1) * mc_samples if total_samples % CHECK_EVERY == 0: - with suppress(ValueError): + with redirect_stdout(None), suppress(ValueError): report_nonclose(Fx / (m + 1), Gx, rtol=rtol, atol=atol) - print(f"Converged after {m} iterations") return report_nonclose(Fx / max_repeats, Gx, rtol=rtol, atol=atol) @@ -104,7 +104,8 @@ def test_LinearOperator_matmat_expectation( GX = G_functorch @ X FX = zeros_like(X) - atol, rtol = 1e-5, 1e-1 + atol = 5e-3 * max(abs(GX.flatten())) + rtol = 1.5e-1 for m in range(max_repeats): FX += F @ X @@ -112,9 +113,8 @@ def test_LinearOperator_matmat_expectation( total_samples = (m + 1) * mc_samples if total_samples % CHECK_EVERY == 0: - with suppress(ValueError): + with redirect_stdout(None), suppress(ValueError): report_nonclose(FX / (m + 1), GX, rtol=rtol, atol=atol) - print(f"Converged after {m} iterations") return - report_nonclose(FX, GX, rtol=rtol, atol=atol) + report_nonclose(FX / max_repeats, GX, rtol=rtol, atol=atol)