diff --git a/jax/_src/callback.py b/jax/_src/callback.py
index ea5412e4f535..b30e073ab798 100644
--- a/jax/_src/callback.py
+++ b/jax/_src/callback.py
@@ -108,12 +108,27 @@ def _callback(*flat_args):
 
   sharding = None
   axis_context = ctx.module_context.axis_context
-  if isinstance(axis_context, sharding_impls.ShardingContext):
-    if len(axis_context.device_assignment) > 1:
+  if isinstance(axis_context, sharding_impls.SPMDAxisContext):
+    # If we have fully manual sharding during lowering, that means the JAX
+    # program has per-device semantics, so we run the callback on each device.
+    if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
       raise NotImplementedError(
           "pure_callback is only supported in spmd computations when all mesh"
           " axes are partitioned manually (no partial automatic sharding)."
       )
+    sharding = xc.OpSharding()
+    sharding.type = xc.OpSharding.Type.MANUAL
+  elif isinstance(axis_context, sharding_impls.ShardingContext):
+    # If we have fully automatic sharding during lowering, that means the JAX
+    # program has bulk array semantics, so we run the callback with a MAXIMAL
+    # sharding and hence execute it only once on the full logical value.
+    sharding = xc.OpSharding()
+    sharding.type = xc.OpSharding.Type.MAXIMAL
+    sharding.tile_assignment_dimensions = [1]
+    sharding.tile_assignment_devices = [0]
+  else:
+    # When there's no SPMD partitioning going on, don't annotate a sharding.
+    sharding = None
   if isinstance(axis_context, sharding_impls.SPMDAxisContext):
     if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
       raise NotImplementedError(
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 41c3946a21b8..dcfe97765519 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -29,6 +29,7 @@
 from jax.sharding import NamedSharding, PartitionSpec, Mesh
 from jax._src import core
 from jax._src import ad_util
+from jax._src import callback
 from jax._src import custom_derivatives
 from jax._src import debugging
 from jax._src import dispatch
@@ -812,6 +813,10 @@ def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
 def _debug_callback_rule(mesh, *in_rep, **_):
   return []
 
+@register_rule(callback.pure_callback_p)
+def _pure_callback_rule(mesh, *in_rep, result_avals, **_):
+  return [set()] * len(result_avals)
+
 @register_rule(dispatch.device_put_p)
 def _device_put_rep_rule(mesh, x, *, src, device):
   return x
diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py
index 34920a3a9eab..8b71c9327c39 100644
--- a/tests/python_callback_test.py
+++ b/tests/python_callback_test.py
@@ -34,6 +34,7 @@
 from jax.experimental import pjit
 from jax.interpreters import mlir
 from jax.experimental.maps import xmap
+from jax.experimental.shard_map import shard_map
 from jax.experimental import io_callback
 import jax.numpy as jnp
 from jax.sharding import Mesh
@@ -699,15 +700,6 @@ def without_xmap_f(x):
       np.testing.assert_allclose(
           out, np.sin(np.arange(jax.local_device_count()))
       )
-
-      if jax.local_device_count() > 1:
-        with self.assertRaisesRegex(
-            NotImplementedError, 'when all mesh axes are partitioned manually'
-        ):
-          pjit.pjit(without_xmap_f, in_shardings=spec, out_shardings=spec)(
-              inp
-          )
-
     finally:
       jtu.restore_spmd_manual_lowering_flag()
       jtu.restore_spmd_lowering_flag()
@@ -875,6 +867,57 @@ def g(x):
     x = np.arange(6, dtype=np.int32).reshape((3, 2))
     np.testing.assert_allclose(g(x), x)
 
+  def test_can_shard_pure_callback_maximally(self):
+    if xla_bridge.get_backend().runtime_type == 'stream_executor':
+      raise unittest.SkipTest(
+          'Host callback not supported for runtime type: stream_executor.'
+      )
+
+    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
+
+    spec = jax.sharding.PartitionSpec('x')
+    sharding = jax.sharding.NamedSharding(mesh, spec)
+
+    def func(x):
+      return x + np.arange(x.shape[0], dtype=x.dtype)
+
+    def f(x):
+      return jax.pure_callback(func, x, x)
+
+    inp = jnp.arange(float(jax.local_device_count()))
+    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
+    jax.block_until_ready(out)
+    np.testing.assert_allclose(
+        out, np.arange(jax.local_device_count()) * 2
+    )
+
+  def test_can_shard_pure_callback_manually(self):
+    if xla_bridge.get_backend().runtime_type == 'stream_executor':
+      raise unittest.SkipTest(
+          'Host callback not supported for runtime type: stream_executor.'
+      )
+
+    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
+
+    spec = jax.sharding.PartitionSpec('x')
+    sharding = jax.sharding.NamedSharding(mesh, spec)
+
+    def func(x):
+      return x + np.arange(x.shape[0], dtype=x.dtype)
+
+    def f(x):
+      return jax.pure_callback(func, x, x)
+    f = shard_map(f, mesh=mesh, in_specs=(spec,), out_specs=spec)
+
+    inp = jnp.arange(float(jax.local_device_count() * 2))
+    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
+    y = np.tile(np.arange(2, dtype=inp.dtype), jax.local_device_count())
+    jax.block_until_ready(out)
+    np.testing.assert_allclose(
+        out, inp + y
+    )
+
+
 class IOPythonCallbackTest(jtu.JaxTestCase):
 
   def setUp(self):
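
For reference, the sketch below is not part of the patch; it only illustrates the two lowering behaviors the change enables. Under ordinary jit with automatic sharding the callback is annotated with a MAXIMAL sharding and sees the full logical array exactly once, while under shard_map with all mesh axes manual it is annotated MANUAL and runs once per device on that device's shard. The single-axis mesh layout and the host_fn helper are illustrative assumptions, not names from the patch.

# Minimal usage sketch (not part of the patch); host_fn is a made-up helper.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Illustrative mesh: one axis spanning all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
sharding = NamedSharding(mesh, P('x'))

def host_fn(x):
  # Plain NumPy code that runs on the host.
  return np.asarray(x) * 2

def f(x):
  return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

x = jnp.arange(float(jax.device_count()))

# Automatic sharding: the callback is lowered with a MAXIMAL sharding and is
# invoked once on the full logical value.
print(jax.jit(f, in_shardings=sharding, out_shardings=sharding)(x))

# Fully manual sharding: the callback is lowered with a MANUAL sharding and is
# invoked per device on each shard.
f_sharded = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')))
print(f_sharded(x))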