Skip to content

Commit

Permalink
Merge pull request #23853 from zhenying-liu:remat-scan
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679365040
  • Loading branch information
Google-ML-Automation committed Sep 27, 2024
2 parents 9f4e8d0 + adaf54a commit 5a1549c
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TransferToMemoryKind, PartitionSpec as P)
from jax.experimental.compute_on import compute_on
from jax.experimental.shard_map import shard_map
from jax._src.lib import xla_extension_version
import numpy as np

config.parse_flags_with_absl()
Expand Down Expand Up @@ -1567,8 +1568,8 @@ def g(ys, _):
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

def test_remat_scan_layout_change_offloadable(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Remat scan does not work on GPU backend.")
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
self.skipTest("Requires xla_extension_version >= 289")
mesh = jtu.create_mesh((2,), ("x",))
shape = (256, 128)
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
Expand Down Expand Up @@ -1602,6 +1603,10 @@ def g(ys, _):
self.assertIn('S(5)', compiled_text)
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)")
self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)")
self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")
self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)")

compiled_stats = compiled_f.memory_analysis()
if compiled_stats is not None:
Expand Down

0 comments on commit 5a1549c

Please sign in to comment.