From 6fe78b07446fda85889c593bea98b296c7f70ad7 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 25 Sep 2024 15:44:14 +0200 Subject: [PATCH] Fixed an error in the tests. It seems JAX has updated the `make_jaxpr()` function and now that thing caches itself. This is now accounted for. --- tests/unit_tests/test_caching.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_caching.py b/tests/unit_tests/test_caching.py index 60be130..85e5b3b 100644 --- a/tests/unit_tests/test_caching.py +++ b/tests/unit_tests/test_caching.py @@ -402,9 +402,15 @@ def wrapped(a: np.ndarray) -> np.ndarray: res_f = lower_f.compile()(array_f) assert res_c is not res_f - assert np.allclose(res_f, res_c) assert lower_f is not lower_c - assert lower_cnt[0] == 2 + assert np.allclose(res_f, res_c) + + # In previous versions JAX did not cached the result of the tracing, + # but in newer version the tracing itself is also cached + if lower_c._jaxpr is lower_f._jaxpr: + assert lower_cnt[0] == 1 + else: + assert lower_cnt[0] == 2 def test_caching_jax_numpy_array() -> None: