diff --git a/tests/unit_tests/test_caching.py b/tests/unit_tests/test_caching.py index 60be130..85e5b3b 100644 --- a/tests/unit_tests/test_caching.py +++ b/tests/unit_tests/test_caching.py @@ -402,9 +402,15 @@ def wrapped(a: np.ndarray) -> np.ndarray: res_f = lower_f.compile()(array_f) assert res_c is not res_f - assert np.allclose(res_f, res_c) assert lower_f is not lower_c - assert lower_cnt[0] == 2 + assert np.allclose(res_f, res_c) + + # In previous versions JAX did not cached the result of the tracing, + # but in newer version the tracing itself is also cached + if lower_c._jaxpr is lower_f._jaxpr: + assert lower_cnt[0] == 1 + else: + assert lower_cnt[0] == 2 def test_caching_jax_numpy_array() -> None: