From 5ac209f62113cfadf6cc0c22b569756bb79103b8 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Tue, 10 Sep 2024 09:26:01 -0700 Subject: [PATCH] #sdy add JAX Shardy support for shard_map. For example the following JAX program: ```py devices = np.array(jax.devices()[:8]) mesh = Mesh(devices, axis_names=('x')) a = jax.device_put( jnp.arange(8 * 8).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P('x', None))) @jax.jit @partial( shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): axis_size = lax.psum(1, 'x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) print(jax.jit(fwd).lower(a).as_text()) ``` prints: ```cpp module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { sdy.mesh @mesh = <["x"=8]> func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32> return %0 : tensor<8x8xi32> } func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) { %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) { %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32> sdy.return %1 : tensor<1x8xi32> } : (tensor<8x8xi32>) -> tensor<8x8xi32> return %0 : tensor<8x8xi32> } } ``` PiperOrigin-RevId: 672993069 --- jax/experimental/shard_map.py | 64 ++++++++++++++++++++++++++++ tests/BUILD | 5 +++ tests/shard_map_test.py | 78 ++++++++++++++++++++++++++++------- 3 files changed, 132 insertions(+), 15 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 10d4874d7329..0dace1977dc0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -51,6 +51,8 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, special, control_flow, ann) +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo, sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, as_hashable_function, memoize, partition_list, merge_lists, split_list, subs_list2) @@ -643,9 +645,71 @@ def _rule_missing(prim: core.Primitive, *_, **__): # Lowering +def _shardy_shard_map_sharding( + ctx: mlir.LoweringRuleContext, mesh, names, aval_in + ) -> ir.Attribute: + axes = {name: i for i, ns in names.items() for name in ns} + ns = _make_scoped_manual_sharding(ctx, mesh, axes) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + return ns._to_sdy_sharding(aval_in.ndim).build() + + +def _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto): + in_avals_ = [v.aval for v in jaxpr.invars] + if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): + # Nested `ManualComputationOp`s cannot refer to axes that are already + # manual. So figure out what axes are free thus far and get the new axis + # context. 
+ free_axis = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes + new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axis - auto) + else: + new_axis_context = sharding_impls.SPMDAxisContext( + mesh, frozenset(mesh.axis_names) - auto) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + args = (*ctx.dim_var_values, *in_nodes) + + manual_axes = sub_ctx.axis_context.manual_axes + mesh_shape = mesh.shape + manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes]) + if manual_axes_size == 1: + # No need for a `ManualComputationOp` if all manual axes are size 1. + out_nodes, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args, + dim_var_values=ctx.dim_var_values) + return out_nodes + + in_shardings = sdy.TensorShardingPerValueAttr.get(map( + partial(_shardy_shard_map_sharding, ctx, mesh), + in_names, ctx.avals_in)) + out_shardings = sdy.TensorShardingPerValueAttr.get(map( + partial(_shardy_shard_map_sharding, ctx, mesh), + out_names, ctx.avals_out)) + output_types = map(mlir.aval_to_ir_type, ctx.avals_out) + manual_computation_op = sdy.ManualComputationOp( + output_types, args, in_shardings, out_shardings, + sdy.ManualAxesAttr.get( + ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) + block = ir.Block.create_at_start( + manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) + with ir.InsertionPoint(block), core.extend_axis_env_nd( + tuple(mesh.shape.items())): + out_nodes_, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, + dim_var_values=ctx.dim_var_values) + sdy.ReturnOp([ir.Value(x) for x in out_nodes_]) + + return manual_computation_op.results + + def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): del check_rep, rewrite + if config.use_shardy_partitioner.value: + return _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto) in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, diff --git a/tests/BUILD b/tests/BUILD index 46345b6475d9..e580fb0ae363 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1346,6 +1346,11 @@ jax_multiplatform_test( jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], + enable_configs = [ + "gpu_2gpu_shardy", + "tpu_v3_2x2_shardy", + "tpu_v4_2x2_shardy", + ], shard_count = { "cpu": 50, "gpu": 10, diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 397f2d94c7f7..0c1155ddf1ab 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -848,6 +848,10 @@ def test_shmap_abstract_mesh_errors(self): @parameterized.parameters([True, False]) @jtu.run_on_devices('cpu', 'gpu', 'tpu') def test_debug_print_jit(self, jit): + if config.use_shardy_partitioner.value: + self.skipTest( + 'TODO(b/364547005): debug prints not supported by Shardy yet' + ) mesh = Mesh(jax.devices(), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) @@ -1229,13 +1233,18 @@ def foo(x): return x hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo')) - self.assertIn("call @shmap_body", hlo_str) - self.assertIn("call @shmap_body_0", hlo_str) - self.assertIn("%arg0: tensor<1xf32>", hlo_str) - self.assertIn("\"[None]\"", hlo_str) - self.assertIn("%arg1: tensor<1xf32>", hlo_str) - self.assertIn("\"[('i',)]\"", hlo_str) - self.assertIn("-> (tensor<1xf32> 
{jax.result_info = \"[('i',)]\"})", hlo_str) + if config.use_shardy_partitioner.value: + self.assertEqual(2, hlo_str.count('sdy.manual_computation')) + else: + self.assertIn('call @shmap_body', hlo_str) + self.assertIn('call @shmap_body_0', hlo_str) + self.assertIn('%arg0: tensor<1xf32>', hlo_str) + self.assertIn('"[None]"', hlo_str) + self.assertIn('%arg1: tensor<1xf32>', hlo_str) + self.assertIn('"[(\'i\',)]"', hlo_str) + self.assertIn( + '-> (tensor<1xf32> {jax.result_info = "[(\'i\',)]"})', hlo_str + ) def test_rewrite_process_call(self): def f(x): @@ -1759,10 +1768,18 @@ def f(x): v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertIn( - 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}', - f.lower(v).as_text('hlo'), - ) + if config.use_shardy_partitioner.value: + self.assertIn( + 'in_shardings=[<@mesh, [{"i"}, {}]>] out_shardings=[<@mesh, [{"i"},' + ' {}]>] manual_axes={"i"}', + f.lower(v).as_text(), + ) + else: + self.assertIn( + 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual,' + ' replicated}}', + f.lower(v).as_text('hlo'), + ) self.assertAllClose(v*v, f(v), check_dtypes=False) def test_sharded_prng_with_abstract_mesh(self): @@ -1909,6 +1926,11 @@ def f(): self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) def test_partial_auto_of_pjit_different_mesh(self): + if config.use_shardy_partitioner.value: + self.skipTest( + 'Shardy requires the mesh axis names to be the same across ' + 'the entire computation.' + ) mesh = jtu.create_mesh((2, 2), ('i', 'j')) mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l')) @@ -1977,10 +1999,14 @@ def f(x): xs = jnp.arange(16.) ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs) - self.assertIn( - '{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}', - ir.as_text() - ) + if config.use_shardy_partitioner.value: + self.assertIn( + 'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text() + ) + else: + self.assertIn( + "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text() + ) def test_vmap_spmd_axis_name_error(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) @@ -2609,5 +2635,27 @@ def fwd(a): self.assertEqual(c.addressable_data(0).shape, (4, 2)) +@jtu.with_config(jax_use_shardy_partitioner=True) +class SdyIntegrationTest(jtu.JaxTestCase): + # Verify we can lower to a `ManualComputationOp`. + def test_shardy_collective_permute(self): + mesh = jtu.create_mesh((2,), ('x',)) + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None)), + ) + + @jax.jit + @partial( + shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + axis_size = lax.psum(1, 'x') + perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] + return lax.ppermute(a, 'x', perm=perm) + + self.assertIn('sdy.manual_computation', jax.jit(fwd).lower(a).as_text()) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())
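
A usage sketch to complement the commit-message example: the new lowering path is taken only when the Shardy partitioner is enabled (the added `SdyIntegrationTest` opts in via `@jtu.with_config(jax_use_shardy_partitioner=True)`). The snippet below is a minimal sketch of the same flow using the global config flag; it assumes a runtime exposing at least two devices.

```py
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Opt in to the Shardy partitioner so shard_map lowers to an
# sdy.manual_computation instead of the pre-existing custom-call path.
jax.config.update('jax_use_shardy_partitioner', True)

# Assumes at least two devices are available in this runtime.
mesh = Mesh(np.array(jax.devices()[:2]), axis_names=('x',))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x', None),),
         out_specs=P('x', None))
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

# The lowered module should now contain an sdy.manual_computation op.
assert 'sdy.manual_computation' in fwd.lower(a).as_text()
```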
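
Relatedly, the `manual_axes_size == 1` early return in `_shard_map_lowering_shardy` means that when every manual axis has size 1 the body is inlined and no `ManualComputationOp` is emitted. A sketch of that fast path, continuing from the snippet above (imports and flag already set up) and assuming a single-device mesh:

```py
# Size-1 fast path: with a one-device mesh every manual axis has size 1,
# so the Shardy lowering inlines the body and emits no ManualComputationOp
# (the manual_axes_size == 1 early return in _shard_map_lowering_shardy).
mesh1 = Mesh(np.array(jax.devices()[:1]), axis_names=('x',))

@jax.jit
@partial(shard_map, mesh=mesh1, in_specs=(P('x', None),),
         out_specs=P('x', None))
def ident(x):
  return x

b = jnp.arange(8.0).reshape((1, 8))
assert 'sdy.manual_computation' not in ident.lower(b).as_text()
```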