diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2707c83..a49757d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,11 +29,11 @@ jobs: python -m pip install --upgrade pip make install-test - name: Run test - if: contains('refs/heads/master refs/heads/development', github.ref) + if: contains('refs/heads/main', github.ref) run: | make test - name: Run test-light - if: contains('refs/heads/master refs/heads/development', github.ref) != 1 + if: contains('refs/heads/main', github.ref) != 1 run: | make test-light diff --git a/changelog.md b/changelog.md index d644aaa..4065df1 100644 --- a/changelog.md +++ b/changelog.md @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Update Github action versions and cache `pip` ([PR](https://github.com/f-dangel/curvlinops/pull/129)) +- Re-activate Monte-Carlo tests, refactor, and reduce their run time + ([PR](https://github.com/f-dangel/curvlinops/pull/131)) + ## [2.0.0] - 2024-08-15 This major release is almost fully backward compatible with the `1.x.y` release diff --git a/curvlinops/examples/utils.py b/curvlinops/examples/utils.py index e22fe4a..08241c2 100644 --- a/curvlinops/examples/utils.py +++ b/curvlinops/examples/utils.py @@ -31,9 +31,13 @@ def report_nonclose( if allclose(array1, array2, rtol=rtol, atol=atol, equal_nan=equal_nan): print("Compared arrays match.") else: + nonclose_entries = 0 for a1, a2 in zip(array1.flatten(), array2.flatten()): if not isclose(a1, a2, atol=atol, rtol=rtol, equal_nan=equal_nan): print(f"{a1} ≠ {a2} (ratio {a1 / a2:.5f})") + nonclose_entries += 1 print(f"Max: {array1.max():.5f}, {array2.max():.5f}") print(f"Min: {array1.min():.5f}, {array2.min():.5f}") + print(f"Nonclose entries: {nonclose_entries} / {array1.size}") + print(f"rtol = {rtol}, atol= {atol}") raise ValueError("Compared arrays don't match.") diff --git a/test/test_fisher.py b/test/test_fisher.py index 4fa598e..3bfa332 100644 --- a/test/test_fisher.py +++ b/test/test_fisher.py @@ -1,7 +1,7 @@ """Contains tests for ``curvlinops/fisher.py``.""" from collections.abc import MutableMapping -from contextlib import suppress +from contextlib import redirect_stdout, suppress from numpy import random, zeros_like from pytest import mark, raises @@ -10,11 +10,11 @@ from curvlinops.examples.functorch import functorch_ggn from curvlinops.examples.utils import report_nonclose -MAX_REPEATS_MC_SAMPLES = [(1_000_000, 1), (10_000, 100)] +MAX_REPEATS_MC_SAMPLES = [(10_000, 1), (100, 100)] MAX_REPEATS_MC_SAMPLES_IDS = [ f"max_repeats={n}-mc_samples={m}" for (n, m) in MAX_REPEATS_MC_SAMPLES ] -CHECK_EVERY = 1_000 +CHECK_EVERY = 100 @mark.montecarlo @@ -58,7 +58,8 @@ def test_LinearOperator_matvec_expectation( Gx = G_functorch @ x Fx = zeros_like(x) - atol, rtol = 1e-5, 1e-1 + atol = 5e-3 * max(abs(Gx)) + rtol = 1e-1 for m in range(max_repeats): Fx += F @ x @@ -66,9 +67,8 @@ def test_LinearOperator_matvec_expectation( total_samples = (m + 1) * mc_samples if total_samples % CHECK_EVERY == 0: - with suppress(ValueError): + with redirect_stdout(None), suppress(ValueError): report_nonclose(Fx / (m + 1), Gx, rtol=rtol, atol=atol) - print(f"Converged after {m} iterations") return report_nonclose(Fx / max_repeats, Gx, rtol=rtol, atol=atol) @@ -104,7 +104,8 @@ def test_LinearOperator_matmat_expectation( GX = G_functorch @ X FX = zeros_like(X) - atol, rtol = 1e-5, 1e-1 + atol = 5e-3 * max(abs(GX.flatten())) + rtol = 1.5e-1 for m in range(max_repeats): FX += F @ X @@ -112,9 +113,8 @@ def test_LinearOperator_matmat_expectation( total_samples = (m + 1) * mc_samples if total_samples % CHECK_EVERY == 0: - with suppress(ValueError): + with redirect_stdout(None), suppress(ValueError): report_nonclose(FX / (m + 1), GX, rtol=rtol, atol=atol) - print(f"Converged after {m} iterations") return - report_nonclose(FX, GX, rtol=rtol, atol=atol) + report_nonclose(FX / max_repeats, GX, rtol=rtol, atol=atol)