From 23b693200b18df1c17dfcaf0750d5262175450e5 Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Thu, 12 Sep 2024 18:26:06 -0400 Subject: [PATCH] add fraction of mismatches to equivalence checker --- roicat/helpers.py | 7 ++++++- tests/test_integration.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/roicat/helpers.py b/roicat/helpers.py index c0227267..91f2d3d6 100644 --- a/roicat/helpers.py +++ b/roicat/helpers.py @@ -4811,7 +4811,12 @@ def _checker( at = np.abs(true) r_diff = diff / at if np.all(at != 0) else np.inf r_diff_mean, r_diff_max, any_nan = np.nanmean(r_diff), np.nanmax(r_diff), np.any(np.isnan(r_diff)) - reason = f"Equivalence: Relative difference: mean={r_diff_mean}, max={r_diff_max}, any_nan={any_nan}" + ## fraction of mismatches + n_elements = np.prod(test.shape) + n_mismatches = np.sum(diff > 0) + frac_mismatches = n_mismatches / n_elements + ## Use scientific notation and round to 3 decimal places + reason = f"Equivalence: Relative difference: mean={r_diff_mean:.3e}, max={r_diff_max:.3e}, any_nan={any_nan}, n_elements={n_elements}, n_mismatches={n_mismatches}, frac_mismatches={frac_mismatches:.3e}" else: reason = f"Values are not numpy numeric types. types: {test.dtype}, {true.dtype}" else: diff --git a/tests/test_integration.py b/tests/test_integration.py index 5768f856..a5345344 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -94,7 +94,7 @@ def test_pipeline_tracking_simple(dir_data_test): checker = helpers.Equivalence_checker( kwargs_allclose={'rtol': 1e-5, 'equal_nan': True}, assert_mode=False, - verbose=2, + verbose=1, ) checks = checker(test=run_data, true=run_data_true) fails = [key for key, val in helpers.flatten_dict(checks).items() if val[0]==False]