Skip to content

Commit

Permalink
Add cache tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Sep 27, 2024
1 parent e170da3 commit 0e61b6e
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions examples/ffi/tests/attrs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ def test_array_attr(self):
self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum())
self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum())

def test_array_attr_jit_cache(self):
jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,))
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 1) # compiles once the first time
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 0) # cache hit
self.assertNotIn("_HashableByObjectId", jit_array_attr.lower(5).as_text())

def test_dictionary_attr(self):
secret, count = attrs.dictionary_attr(secret=5)
self.assertEqual(secret, 5)
Expand Down

0 comments on commit 0e61b6e

Please sign in to comment.