Skip to content

Commit

Permalink
Fixed an error in the tests.
Browse files Browse the repository at this point in the history
It seems JAX has updated the `make_jaxpr()` function and now that thing caches itself.
This is now accounted for.
  • Loading branch information
philip-paul-mueller committed Sep 25, 2024
1 parent e667854 commit 6fe78b0
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/unit_tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,15 @@ def wrapped(a: np.ndarray) -> np.ndarray:
res_f = lower_f.compile()(array_f)

assert res_c is not res_f
assert np.allclose(res_f, res_c)
assert lower_f is not lower_c
assert lower_cnt[0] == 2
assert np.allclose(res_f, res_c)

# In previous versions JAX did not cached the result of the tracing,
# but in newer version the tracing itself is also cached
if lower_c._jaxpr is lower_f._jaxpr:
assert lower_cnt[0] == 1
else:
assert lower_cnt[0] == 2


def test_caching_jax_numpy_array() -> None:
Expand Down

0 comments on commit 6fe78b0

Please sign in to comment.