#sdy add JAX Shardy support for shard_map.
For example, the following JAX program:
```py
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x',))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(jax.jit(fwd).lower(a).as_text())
```

prints:

```mlir
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=8]>
  func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
  func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
      %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
      sdy.return %1 : tensor<1x8xi32>
    } : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
}
```
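
The `sdy.manual_computation` op above is only emitted when the Shardy partitioner is enabled. As a minimal sketch (not part of this change), the `jax_use_shardy_partitioner` flag exercised by the new tests can be toggled before lowering; this reuses `fwd` and `a` from the example above, and the `lowered` name is purely illustrative:

```py
# Minimal sketch, assuming the `fwd` and `a` definitions above: enable the
# Shardy partitioner so shard_map lowers to an sdy.manual_computation op.
jax.config.update('jax_use_shardy_partitioner', True)

lowered = jax.jit(fwd).lower(a)
assert 'sdy.manual_computation' in lowered.as_text()
```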

PiperOrigin-RevId: 679165100
bartchr808 authored and Google-ML-Automation committed Sep 26, 2024
1 parent 7b53c2f commit e62a50c
Showing 3 changed files with 132 additions and 15 deletions.
64 changes: 64 additions & 0 deletions jax/experimental/shard_map.py
@@ -51,6 +51,8 @@
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
special, control_flow, ann)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
as_hashable_function, memoize, partition_list,
merge_lists, split_list, subs_list2)
@@ -643,9 +645,71 @@ def _rule_missing(prim: core.Primitive, *_, **__):

# Lowering

def _shardy_shard_map_sharding(
ctx: mlir.LoweringRuleContext, mesh, names, aval_in
) -> ir.Attribute:
axes = {name: i for i, ns in names.items() for name in ns}
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
ns = sharding_impls.physical_sharding(aval_in, ns)
aval_in = core.physical_aval(aval_in)
return ns._to_sdy_sharding(aval_in.ndim).build()


def _shard_map_lowering_shardy(
ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto):
in_avals_ = [v.aval for v in jaxpr.invars]
if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
# Nested `ManualComputationOp`s cannot refer to axes that are already
# manual. So figure out what axes are free thus far and get the new axis
# context.
free_axis = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes
new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axis - auto)
else:
new_axis_context = sharding_impls.SPMDAxisContext(
mesh, frozenset(mesh.axis_names) - auto)
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
args = (*ctx.dim_var_values, *in_nodes)

manual_axes = sub_ctx.axis_context.manual_axes
mesh_shape = mesh.shape
manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes])
if manual_axes_size == 1:
# No need for a `ManualComputationOp` if all manual axes are size 1.
out_nodes, _ = mlir.jaxpr_subcomp(
sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args,
dim_var_values=ctx.dim_var_values)
return out_nodes

in_shardings = sdy.TensorShardingPerValueAttr.get(map(
partial(_shardy_shard_map_sharding, ctx, mesh),
in_names, ctx.avals_in))
out_shardings = sdy.TensorShardingPerValueAttr.get(map(
partial(_shardy_shard_map_sharding, ctx, mesh),
out_names, ctx.avals_out))
output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
manual_computation_op = sdy.ManualComputationOp(
output_types, args, in_shardings, out_shardings,
sdy.ManualAxesAttr.get(
ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes])))
block = ir.Block.create_at_start(
manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_))
with ir.InsertionPoint(block), core.extend_axis_env_nd(
tuple(mesh.shape.items())):
out_nodes_, _ = mlir.jaxpr_subcomp(
sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments,
dim_var_values=ctx.dim_var_values)
sdy.ReturnOp([ir.Value(x) for x in out_nodes_])

return manual_computation_op.results


def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto):
del check_rep, rewrite
if config.use_shardy_partitioner.value:
return _shard_map_lowering_shardy(
ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto)
in_avals_ = [v.aval for v in jaxpr.invars]
out_avals_ = [x.aval for x in jaxpr.outvars]
in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
5 changes: 5 additions & 0 deletions tests/BUILD
@@ -1346,6 +1346,11 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "shard_map_test",
srcs = ["shard_map_test.py"],
enable_configs = [
"gpu_2gpu_shardy",
"tpu_v3_2x2_shardy",
"tpu_v4_2x2_shardy",
],
shard_count = {
"cpu": 50,
"gpu": 10,
78 changes: 63 additions & 15 deletions tests/shard_map_test.py
@@ -848,6 +848,10 @@ def test_shmap_abstract_mesh_errors(self):
@parameterized.parameters([True, False])
@jtu.run_on_devices('cpu', 'gpu', 'tpu')
def test_debug_print_jit(self, jit):
if config.use_shardy_partitioner.value:
self.skipTest(
'TODO(b/364547005): debug prints not supported by Shardy yet'
)
mesh = Mesh(jax.devices(), ('i',))

@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
@@ -1229,13 +1233,18 @@ def foo(x):
return x

hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
self.assertIn("call @shmap_body", hlo_str)
self.assertIn("call @shmap_body_0", hlo_str)
self.assertIn("%arg0: tensor<1xf32>", hlo_str)
self.assertIn("\"[None]\"", hlo_str)
self.assertIn("%arg1: tensor<1xf32>", hlo_str)
self.assertIn("\"[('i',)]\"", hlo_str)
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", hlo_str)
if config.use_shardy_partitioner.value:
self.assertEqual(2, hlo_str.count('sdy.manual_computation'))
else:
self.assertIn('call @shmap_body', hlo_str)
self.assertIn('call @shmap_body_0', hlo_str)
self.assertIn('%arg0: tensor<1xf32>', hlo_str)
self.assertIn('"[None]"', hlo_str)
self.assertIn('%arg1: tensor<1xf32>', hlo_str)
self.assertIn('"[(\'i\',)]"', hlo_str)
self.assertIn(
'-> (tensor<1xf32> {jax.result_info = "[(\'i\',)]"})', hlo_str
)

def test_rewrite_process_call(self):
def f(x):
@@ -1759,10 +1768,18 @@ def f(x):

v = jnp.arange(32.).reshape(4, 8)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
self.assertIn(
'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}',
f.lower(v).as_text('hlo'),
)
if config.use_shardy_partitioner.value:
self.assertIn(
'in_shardings=[<@mesh, [{"i"}, {}]>] out_shardings=[<@mesh, [{"i"},'
' {}]>] manual_axes={"i"}',
f.lower(v).as_text(),
)
else:
self.assertIn(
'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual,'
' replicated}}',
f.lower(v).as_text('hlo'),
)
self.assertAllClose(v*v, f(v), check_dtypes=False)

def test_sharded_prng_with_abstract_mesh(self):
@@ -1909,6 +1926,11 @@ def f():
self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))

def test_partial_auto_of_pjit_different_mesh(self):
if config.use_shardy_partitioner.value:
self.skipTest(
'Shardy requires the mesh axis names to be the same across '
'the entire computation.'
)
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l'))

@@ -1977,10 +1999,14 @@ def f(x):
xs = jnp.arange(16.)

ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs)
self.assertIn(
'{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}',
ir.as_text()
)
if config.use_shardy_partitioner.value:
self.assertIn(
'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text()
)
else:
self.assertIn(
"{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text()
)

def test_vmap_spmd_axis_name_error(self):
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
@@ -2609,5 +2635,27 @@ def fwd(a):
self.assertEqual(c.addressable_data(0).shape, (4, 2))


@jtu.with_config(jax_use_shardy_partitioner=True)
class SdyIntegrationTest(jtu.JaxTestCase):
# Verify we can lower to a `ManualComputationOp`.
def test_shardy_collective_permute(self):
mesh = jtu.create_mesh((2,), ('x',))
a = jax.device_put(
jnp.arange(8 * 8).reshape((8, 8)),
jax.sharding.NamedSharding(mesh, P('x', None)),
)

@jax.jit
@partial(
shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
axis_size = lax.psum(1, 'x')
perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
return lax.ppermute(a, 'x', perm=perm)

self.assertIn('sdy.manual_computation', jax.jit(fwd).lower(a).as_text())


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
