From 0e61b6eb8a9eff265a963b8ad430731f24d9ceae Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 27 Sep 2024 13:49:10 -0400 Subject: [PATCH] Add cache tests --- examples/ffi/tests/attrs_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/ffi/tests/attrs_test.py b/examples/ffi/tests/attrs_test.py index ca8b5b0a5f2d..62f084de960a 100644 --- a/examples/ffi/tests/attrs_test.py +++ b/examples/ffi/tests/attrs_test.py @@ -28,6 +28,16 @@ def test_array_attr(self): self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum()) self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum()) + def test_array_attr_jit_cache(self): + jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,)) + with jtu.count_jit_and_pmap_lowerings() as count: + jit_array_attr(5) + self.assertEqual(count[0], 1) # compiles once the first time + with jtu.count_jit_and_pmap_lowerings() as count: + jit_array_attr(5) + self.assertEqual(count[0], 0) # cache hit + self.assertNotIn("_HashableByObjectId", jit_array_attr.lower(5).as_text()) + def test_dictionary_attr(self): secret, count = attrs.dictionary_attr(secret=5) self.assertEqual(secret, 5)