Merge pull request #15956 from sharadmv:pure-callback-maximal
PiperOrigin-RevId: 531304370
jax authors committed May 11, 2023
2 parents 0037ab6 + 61f2267 commit d8c487b
Showing 3 changed files with 74 additions and 11 deletions.
19 changes: 17 additions & 2 deletions jax/_src/callback.py
@@ -108,12 +108,27 @@ def _callback(*flat_args):

   sharding = None
   axis_context = ctx.module_context.axis_context
-  if isinstance(axis_context, sharding_impls.ShardingContext):
-    if len(axis_context.device_assignment) > 1:
+  if isinstance(axis_context, sharding_impls.SPMDAxisContext):
+    # If we have fully manual sharding during lowering, that means the JAX
+    # program has per-device semantics, so we run the callback on each device.
+    if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
       raise NotImplementedError(
           "pure_callback is only supported in spmd computations when all mesh"
           " axes are partitioned manually (no partial automatic sharding)."
       )
+    sharding = xc.OpSharding()
+    sharding.type = xc.OpSharding.Type.MANUAL
+  elif isinstance(axis_context, sharding_impls.ShardingContext):
+    # If we have fully automatic sharding during lowering, that means the JAX
+    # program has bulk array semantics, so we run the callback with a MAXIMAL
+    # sharding and hence execute it only once on the full logical value.
+    sharding = xc.OpSharding()
+    sharding.type = xc.OpSharding.Type.MAXIMAL
+    sharding.tile_assignment_dimensions = [1]
+    sharding.tile_assignment_devices = [0]
+  else:
+    # When there's no SPMD partitioning going on, don't annotate a sharding.
+    sharding = None
   if isinstance(axis_context, sharding_impls.SPMDAxisContext):
     if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
       raise NotImplementedError(
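The MAXIMAL branch above is what lets a pure_callback appear under ordinary jit-driven automatic sharding without the old multi-device NotImplementedError: the callback is lowered with a MAXIMAL op sharding, so it executes once, over the full logical value. A minimal sketch of that usage (not part of the diff; the doubling host function and the array size are illustrative):

import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
  # Runs on the host exactly once, over the full (unsharded) logical array.
  return np.asarray(x) * 2

@jax.jit
def f(x):
  return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(8.0)))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]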
5 changes: 5 additions & 0 deletions jax/experimental/shard_map.py
@@ -29,6 +29,7 @@
 from jax.sharding import NamedSharding, PartitionSpec, Mesh
 from jax._src import core
 from jax._src import ad_util
+from jax._src import callback
 from jax._src import custom_derivatives
 from jax._src import debugging
 from jax._src import dispatch
@@ -812,6 +813,10 @@ def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
 def _debug_callback_rule(mesh, *in_rep, **_):
   return []

+@register_rule(callback.pure_callback_p)
+def _pure_callback_rule(mesh, *in_rep, result_avals, **_):
+  return [set()] * len(result_avals)
+
 @register_rule(dispatch.device_put_p)
 def _device_put_rep_rule(mesh, x, *, src, device):
   return x
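For orientation (a sketch, not part of the diff; the single-axis mesh named 'x' and the add-one host function are illustrative), this replication rule is what allows pure_callback to be used inside shard_map, where the lowering above takes the MANUAL path and the callback runs once per device on its local shard:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

def host_fn(x):
  # Sees only the shard owned by the device this instance runs for.
  return np.asarray(x) + 1

def per_shard(x):
  return jax.pure_callback(host_fn, x, x)

f = shard_map(per_shard, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
print(jax.jit(f)(jnp.zeros(jax.device_count(), jnp.float32)))  # all ones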
61 changes: 52 additions & 9 deletions tests/python_callback_test.py
@@ -34,6 +34,7 @@
 from jax.experimental import pjit
 from jax.interpreters import mlir
 from jax.experimental.maps import xmap
+from jax.experimental.shard_map import shard_map
 from jax.experimental import io_callback
 import jax.numpy as jnp
 from jax.sharding import Mesh
@@ -699,15 +700,6 @@ def without_xmap_f(x):
       np.testing.assert_allclose(
           out, np.sin(np.arange(jax.local_device_count()))
       )
-
-      if jax.local_device_count() > 1:
-        with self.assertRaisesRegex(
-            NotImplementedError, 'when all mesh axes are partitioned manually'
-        ):
-          pjit.pjit(without_xmap_f, in_shardings=spec, out_shardings=spec)(
-              inp
-          )
-
     finally:
       jtu.restore_spmd_manual_lowering_flag()
       jtu.restore_spmd_lowering_flag()
@@ -875,6 +867,57 @@ def g(x):
     x = np.arange(6, dtype=np.int32).reshape((3, 2))
     np.testing.assert_allclose(g(x), x)

+  def test_can_shard_pure_callback_maximally(self):
+    if xla_bridge.get_backend().runtime_type == 'stream_executor':
+      raise unittest.SkipTest(
+          'Host callback not supported for runtime type: stream_executor.'
+      )
+
+    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
+
+    spec = jax.sharding.PartitionSpec('x')
+    sharding = jax.sharding.NamedSharding(mesh, spec)
+
+    def func(x):
+      return x + np.arange(x.shape[0], dtype=x.dtype)
+
+    def f(x):
+      return jax.pure_callback(func, x, x)
+
+    inp = jnp.arange(float(jax.local_device_count()))
+    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
+    jax.block_until_ready(out)
+    np.testing.assert_allclose(
+        out, np.arange(jax.local_device_count()) * 2
+    )
+
+  def test_can_shard_pure_callback_manually(self):
+    if xla_bridge.get_backend().runtime_type == 'stream_executor':
+      raise unittest.SkipTest(
+          'Host callback not supported for runtime type: stream_executor.'
+      )
+
+    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
+
+    spec = jax.sharding.PartitionSpec('x')
+    sharding = jax.sharding.NamedSharding(mesh, spec)
+
+    def func(x):
+      return x + np.arange(x.shape[0], dtype=x.dtype)
+
+    def f(x):
+      return jax.pure_callback(func, x, x)
+    f = shard_map(f, mesh=mesh, in_specs=(spec,), out_specs=spec)
+
+    inp = jnp.arange(float(jax.local_device_count() * 2))
+    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
+    y = np.tile(np.arange(2, dtype=inp.dtype), jax.local_device_count())
+    jax.block_until_ready(out)
+    np.testing.assert_allclose(
+        out, inp + y
+    )
+
+
 class IOPythonCallbackTest(jtu.JaxTestCase):

   def setUp(self):
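Taken together, the two paths differ in how many times the callback body executes: the MAXIMAL path runs it once over the full logical value, while shard_map's fully-manual path runs it once per device on a local shard. A hedged sketch of observing that difference (not from the diff; the call-recording list and zero-valued inputs are illustrative, and the per-device counts assume a multi-device backend):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

calls = []

def host_fn(x):
  calls.append(np.shape(x))  # record each invocation and the shape it saw
  return np.asarray(x)

def f(x):
  return jax.pure_callback(host_fn, x, x)

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
sharding = NamedSharding(mesh, P('x'))
n = jax.device_count()

# MAXIMAL path: one invocation, over the full logical value.
jax.block_until_ready(
    jax.jit(f, in_shardings=sharding, out_shardings=sharding)(jnp.zeros(n)))

# MANUAL path via shard_map: one invocation per device, each on a size-1 shard.
g = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
jax.block_until_ready(jax.jit(g)(jnp.zeros(n)))

print(calls)  # e.g. [(n,), (1,), (1,), ...] with n devices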
