diff --git a/tests/memories_test.py b/tests/memories_test.py index 3e0f444a1e66..6959aa7535b8 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -35,6 +35,7 @@ TransferToMemoryKind, PartitionSpec as P) from jax.experimental.compute_on import compute_on from jax.experimental.shard_map import shard_map +from jax._src.lib import xla_extension_version import numpy as np config.parse_flags_with_absl() @@ -1567,8 +1568,8 @@ def g(ys, _): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_layout_change_offloadable(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Remat scan does not work on GPU backend.") + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: + self.skipTest("Requires xla_extension_version >= 289") mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -1602,6 +1603,10 @@ def g(ys, _): self.assertIn('S(5)', compiled_text) self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)") compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: