From adaf54a4bbe10ce05edcfeb29039c6948444c641 Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Mon, 23 Sep 2024 12:54:32 -0700 Subject: [PATCH] enable the activation offloading test --- tests/memories_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 3e0f444a1e66..63b21e2d3e6d 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1567,8 +1567,6 @@ def g(ys, _): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_layout_change_offloadable(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Remat scan does not work on GPU backend.") mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -1602,6 +1600,10 @@ def g(ys, _): self.assertIn('S(5)', compiled_text) self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)") compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: