Skip to content

Commit

Permalink
add fraction of mismatches to equivalence checker
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Sep 12, 2024
1 parent 4334a19 commit 23b6932
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
7 changes: 6 additions & 1 deletion roicat/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4811,7 +4811,12 @@ def _checker(
at = np.abs(true)
r_diff = diff / at if np.all(at != 0) else np.inf
r_diff_mean, r_diff_max, any_nan = np.nanmean(r_diff), np.nanmax(r_diff), np.any(np.isnan(r_diff))
reason = f"Equivalence: Relative difference: mean={r_diff_mean}, max={r_diff_max}, any_nan={any_nan}"
## fraction of mismatches
n_elements = np.prod(test.shape)
n_mismatches = np.sum(diff > 0)
frac_mismatches = n_mismatches / n_elements
## Use scientific notation and round to 3 decimal places
reason = f"Equivalence: Relative difference: mean={r_diff_mean:.3e}, max={r_diff_max:.3e}, any_nan={any_nan}, n_elements={n_elements}, n_mismatches={n_mismatches}, frac_mismatches={frac_mismatches:.3e}"
else:
reason = f"Values are not numpy numeric types. types: {test.dtype}, {true.dtype}"
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_pipeline_tracking_simple(dir_data_test):
checker = helpers.Equivalence_checker(
kwargs_allclose={'rtol': 1e-5, 'equal_nan': True},
assert_mode=False,
verbose=2,
verbose=1,
)
checks = checker(test=run_data, true=run_data_true)
fails = [key for key, val in helpers.flatten_dict(checks).items() if val[0]==False]
Expand Down

0 comments on commit 23b6932

Please sign in to comment.