Skip test_ragged_copy_on_host if xla_extension_version < 290
PiperOrigin-RevId: 683326972
yashk2810 authored and Google-ML-Automation committed Oct 7, 2024
1 parent 28bbbf8 commit ce2b497
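
In brief: test_ragged_copy_on_host exercises a device-to-pinned_host copy path that older XLA runtimes do not support, so the change gates the test on the runtime version (and relocates it from the ComputeOffload class to the preceding test class). A minimal sketch of the guard pattern, assuming a jtu-style test class; the class name below is hypothetical, while xla_extension_version, jtu.run_on_devices, and skipTest are the real symbols used in the diff:

from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version

class MemoriesTest(jtu.JaxTestCase):  # hypothetical name, for illustration only

  @jtu.run_on_devices('tpu')  # the test only runs on TPU backends
  def test_ragged_copy_on_host(self):
    if xla_extension_version < 290:
      self.skipTest('Requires xla_extension_version >= 290')
    ...  # test body as in the diff below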
Showing 1 changed file with 46 additions and 43 deletions.
89 changes: 46 additions & 43 deletions tests/memories_test.py
@@ -25,6 +25,7 @@
from jax import lax
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_extension_version
from jax._src import config
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
@@ -651,6 +652,51 @@ def f():
    out = f()
    self._check_device_put_addressable_shards(out, np_inp * 2, s_dev, 'device')

  @jtu.run_on_devices('tpu')
  def test_ragged_copy_on_host(self):
    if xla_extension_version < 290:
      self.skipTest('Requires xla_extension_version >= 290')
    mesh = jtu.create_mesh((2,), ('x'))
    sharding = jax.sharding.NamedSharding(mesh, P(('x')))
    cpu_sharding = sharding.with_memory_kind('pinned_host')

    num_pages = 512 * 1024
    page_size = 1024

    x = jnp.full((num_pages, page_size), 1, dtype=jnp.bfloat16, device=sharding)

    def write(x):
      return x.at[16 * 1024:].set(0)
    x = shard_map(write, mesh, P(('x'),), P(('x')))(x)

    chunk_size = 8
    def inner(state):
      idx, x, output = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      chunk_host = jax.device_put(chunk, TransferToMemoryKind('pinned_host'))
      output = jax.lax.dynamic_update_slice_in_dim(
          output, chunk_host, idx * chunk_size, axis=0)
      return (idx + 1, x, output)

    def cond(state):
      idx, x, _ = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      return (idx * chunk_size < x.shape[0]) & jnp.any(chunk > 0)

    def foo(x):
      output = jnp.zeros_like(x, device=cpu_sharding)
      _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output))
      return cpu_x

    fn = jax.jit(shard_map(foo, mesh, P(('x'),), P(('x')),
                           check_rep=False),
                 out_shardings=cpu_sharding)
    y = fn(x)
    jax.block_until_ready(y)
    compiled_text = fn.lower(x).compile().as_text()
    if compiled_text is not None:
      self.assertIn('custom_call_target="AllocateBuffer"', compiled_text)


class ComputeOffload(jtu.BufferDonationTestCase):

@@ -1603,49 +1649,6 @@ def g(ys, _):
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

  def test_ragged_copy_on_host(self):
    mesh = jtu.create_mesh((2,), ('x'))
    sharding = jax.sharding.NamedSharding(mesh, P(('x')))
    cpu_sharding = sharding.with_memory_kind('pinned_host')

    num_pages = 512 * 1024
    page_size = 1024

    x = jnp.full((num_pages, page_size), 1, dtype=jnp.bfloat16, device=sharding)

    def write(x):
      return x.at[16 * 1024:].set(0)
    x = shard_map(write, mesh, P(('x'),), P(('x')))(x)

    chunk_size = 8
    def inner(state):
      idx, x, output = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      chunk_host = jax.device_put(chunk, TransferToMemoryKind('pinned_host'))
      output = jax.lax.dynamic_update_slice_in_dim(output, chunk_host, idx * chunk_size, axis=0)
      return (idx + 1, x, output)

    def cond(state):
      idx, x, _ = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      return (idx * chunk_size < x.shape[0]) & jnp.any(chunk > 0)

    def foo(x):
      output = jnp.zeros_like(x, device=cpu_sharding)
      _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output))
      return cpu_x

    fn = jax.jit(shard_map(foo, mesh, P(('x'),), P(('x')),
                           check_rep=False),
                 out_shardings=cpu_sharding)
    y = fn(x)
    jax.block_until_ready(y)
    compiled_f = fn.lower(x).compile()
    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      allocate_buffer_on_host = 'custom_call_target="AllocateBuffer"' in compiled_text
      self.assertEqual(allocate_buffer_on_host, True)

  def test_remat_checkpoint_dots_with_no_batch_dims(self):
    policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
        "device", "pinned_host")

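For readers skimming the diff, the pattern under test, distilled into a standalone sketch: copy a device array to pinned host memory one chunk at a time inside jax.lax.while_loop, stopping at the first all-zero chunk, so only the live ragged prefix is transferred. Assumptions: a backend with pinned_host memory support, illustrative shapes, copy_live_prefix_to_host is a name invented here, and TransferToMemoryKind comes from JAX internals (jax._src.sharding_impls) at this commit:

import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind  # internal API at this commit

CHUNK = 8  # rows moved to the host per loop iteration

def copy_live_prefix_to_host(x):
  # Destination buffer; the test allocates it directly in pinned_host memory
  # by passing device=cpu_sharding to jnp.zeros_like.
  output = jnp.zeros_like(x)

  def cond(state):
    idx, x, _ = state
    chunk = jax.lax.dynamic_slice_in_dim(x, idx * CHUNK, CHUNK)
    # Continue while in bounds and the current chunk still holds live data.
    return (idx * CHUNK < x.shape[0]) & jnp.any(chunk > 0)

  def body(state):
    idx, x, output = state
    chunk = jax.lax.dynamic_slice_in_dim(x, idx * CHUNK, CHUNK)
    # Stage this chunk into pinned host memory.
    chunk_host = jax.device_put(chunk, TransferToMemoryKind('pinned_host'))
    output = jax.lax.dynamic_update_slice_in_dim(
        output, chunk_host, idx * CHUNK, axis=0)
    return (idx + 1, x, output)

  _, _, out = jax.lax.while_loop(cond, body, (0, x, output))
  return out

The assertion on custom_call_target="AllocateBuffer" then inspects the compiled HLO: the test treats the presence of that custom call as evidence that the host output buffer is allocated up front and filled incrementally, rather than produced by copying the whole padded array off the device.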