diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index bc1c00948943..75353197ae2b 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1119,14 +1119,15 @@ def lower_jaxpr_to_module(
   # XLA computation preserves the module name.
   attrs = ctx.module.operation.attributes
   if config.use_shardy_partitioner.value:
-    assert (isinstance(axis_context, sharding_impls.ShardingContext) and
-            axis_context.mesh_shape is not None)
-    ctx.module.body.append(
-        dialects.sdy.MeshOp(
-            "mesh",
-            dialects.sdy.MeshAttr.get(
-                [dialects.sdy.MeshAxisAttr.get(name, size)
-                 for name, size in axis_context.mesh_shape])))
+    if (isinstance(axis_context, sharding_impls.ShardingContext) and
+        axis_context.mesh_shape is not None):
+      sdy_mesh_attr = dialects.sdy.MeshAttr.get(
+          [dialects.sdy.MeshAxisAttr.get(name, size)
+           for name, size in axis_context.mesh_shape])
+    else:
+      sdy_mesh_attr = dialects.sdy.MeshAttr.get([])
+
+    ctx.module.body.append(dialects.sdy.MeshOp("mesh", sdy_mesh_attr))
   module_name = _module_name_regex.sub("_", module_name)
   attrs["sym_name"] = ir.StringAttr.get(module_name)
   attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
@@ -1633,7 +1634,15 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
   # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
   # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
   # The below custom call achieves the sharding like above example.
-  return wrap_with_sharding_op(
+  if config.use_shardy_partitioner.value:
+    physical_ndim = core.physical_aval(aval).ndim
+    s = sharding.SdyArraySharding(
+        mesh_name='mesh',
+        dimension_shardings=[sharding.SdyDimSharding(axes=[], is_closed=i >= aval.ndim)
+                             for i in range(physical_ndim)])
+    return wrap_with_sharding_op(ctx, val, aval, s)
+  else:
+    return wrap_with_sharding_op(
       ctx, val, aval, xc.HloSharding.replicate().to_proto(),
       unspecified_dims=set(range(aval.ndim)))
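As an aside (not part of the patch): the new Shardy branch in `replicate_trailing_dims` marks the logical dimensions of an extended-dtype array as open and pins the trailing physical dimensions to replicated. Below is a runnable sketch of that construction, using a plain dataclass as a stand-in for `sharding.SdyDimSharding` so it runs without JAX internals; it prints the `{?}`/`{}` textual form that the updated tests in pjit_test.py assert on.

```python
from dataclasses import dataclass, field

@dataclass
class DimSharding:
  """Stand-in for sharding.SdyDimSharding, just enough to print the form."""
  axes: list = field(default_factory=list)
  is_closed: bool = False

  def __str__(self):
    # Shardy textual form: '{?}' = open (propagation may still shard this
    # dim), '{}' = closed and empty (pinned to replicated).
    return '{}' if self.is_closed else '{?}'

# A PRNG key array: key.shape == (8, 2), key_data(key).shape == (8, 2, 2).
aval_ndim, physical_ndim = 2, 3
dims = [DimSharding(is_closed=(i >= aval_ndim)) for i in range(physical_ndim)]
print('<@mesh, [%s]>' % ', '.join(str(d) for d in dims))
# -> <@mesh, [{?}, {?}, {}]>
```

This mirrors the `is_closed=i >= aval.ndim` comprehension in the hunk above: the first `aval.ndim` dims stay open, the trailing key-payload dims are closed.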
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + enable_configs = [ + "gpu_2gpu_shardy", + "tpu_df_2x2_shardy", + "tpu_pf_2x2_shardy", + ], shard_count = { "cpu": 5, "gpu": 5, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3af5dfe4cd37..392a25f32612 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -402,6 +402,8 @@ def f(inp1, inp2, inp3): @jtu.run_on_devices('tpu') def testBufferDonationWithOutputShardingInferenceAndTokens(self): + if config.use_shardy_partitioner.value: + self.skipTest('b/355263220: Shardy does not support callbacks yet.') mesh = jtu.create_global_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) @@ -453,10 +455,16 @@ def f(x): check_dtypes=False) hlo = f.lower(np.ones(shape)).compiler_ir() - # Annotation from with_sharding_constraint - self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo)) - # Annotation from pjit - self.assertIn('sharding = "{replicated}"', str(hlo)) + if config.use_shardy_partitioner.value: + # Annotation from with_sharding_constraint + self.assertIn('<@mesh, [{"x"}, {"y"}]>', str(hlo)) + # Annotation from pjit + self.assertIn('sharding = #sdy.sharding<@mesh, [{}, {}]>}', str(hlo)) + else: + # Annotation from with_sharding_constraint + self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo)) + # Annotation from pjit + self.assertIn('sharding = "{replicated}"', str(hlo)) def testShardingConstraintWithArray(self): mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) @@ -484,6 +492,8 @@ def f(x): self.assertIn("sharding={replicated}", hlo.as_hlo_text()) def testShardingConstraintWithArrayOpSharding(self): + if config.use_shardy_partitioner.value: + self.skipTest("Shardy doesn't support PositionalSharding") shape = (8, 8) mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) @@ -555,8 +565,12 @@ def f(x): self.assertLen(actual[0]['a'].addressable_shards, 4) mlir_str = str(f.lower(x).compiler_ir()) - self.assertIn("unspecified_dims=[0]", mlir_str) - self.assertIn("unspecified_dims=[1]", mlir_str) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {"y"}, {}]>', mlir_str) + self.assertIn('<@mesh, [{"x"}, {?}, {}]>', mlir_str) + else: + self.assertIn("unspecified_dims=[0]", mlir_str) + self.assertIn("unspecified_dims=[1]", mlir_str) @jtu.with_mesh([('x', 2), ('y', 2)]) def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self): @@ -575,8 +589,12 @@ def f(x): x = [{'a': v, 'b': v * 2}, v * 3] mlir_str = str(f.lower(x).compiler_ir()) - self.assertIn("unspecified_dims=[0,1]", mlir_str) - self.assertIn("unspecified_dims=[0,2]", mlir_str) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', mlir_str) + self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', mlir_str) + else: + self.assertIn("unspecified_dims=[0,1]", mlir_str) + self.assertIn("unspecified_dims=[0,2]", mlir_str) def testCaching(self): def f(x): @@ -847,6 +865,9 @@ def f_for_pjit(x): def testOutfeed(self): if xla_bridge.using_pjrt_c_api(): raise unittest.SkipTest('outfeed not implemented in PJRT C API') + if config.use_shardy_partitioner.value: + self.skipTest( + 'b/355263220: outfeed lowering not supported by Shardy') devices = np.array(jax.local_devices()) nr_devices = len(devices) @@ -1280,6 +1301,9 @@ class CustomPartitionerTest(jtu.JaxTestCase): def skip_if_custom_partitioning_not_supported(self): if jtu.is_cloud_tpu(): raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") + if config.use_shardy_partitioner.value: + 
+      self.skipTest(
+          'Custom partitioning is not supported with Shardy yet.')

   @jtu.skip_on_devices('cpu')  # Collectives don't seem to work on CPU.
   @jtu.with_mesh([('x', 4), ('y', 2)])
@@ -1564,6 +1588,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
   )
   def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,
                                         mesh_axis_names):
+    if config.use_shardy_partitioner.value:
+      self.skipTest('Must register auto partitioner for Shardy')
     global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
     input_data = np.arange(
         math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
@@ -1580,6 +1606,8 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,
     self.assertArraysEqual(out._value, input_data)

   def test_xla_arr_sharding_mismatch(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest('Must register auto partitioner for Shardy')
     global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     global_input_shape = (6, 2)
     input_data = np.arange(
@@ -1607,6 +1635,8 @@ def test_xla_arr_sharding_mismatch(self):
       compiled(arr)

   def test_gda_auto_shardings_len(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest('Must register auto partitioner for Shardy')
     global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     global_input_shape = (4, 2)
     input_data = np.arange(
@@ -1627,6 +1657,8 @@ def test_gda_auto_shardings_len(self):
   )
   def test_jit_arr_partial_auto_sharding_array(
       self, mesh_shape, mesh_axis_names, pspec):
+    if config.use_shardy_partitioner.value:
+      self.skipTest('Must register auto partitioner for Shardy')
     mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
     global_input_shape = (8, 4)
     input_data = np.arange(
@@ -1667,6 +1699,8 @@ def test_jit_auto_sharding_partial_tuple_input_shardings(
       self, mesh_shape, mesh_axis_names):
     if not jtu.test_device_matches(["tpu"]):
       self.skipTest('Parameters are tupled only on TPU if >2000 parameters')
+    if config.use_shardy_partitioner.value:
+      self.skipTest('Must register auto partitioner for Shardy')

     mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
     global_input_shape = (8, 4)
@@ -1838,6 +1872,11 @@ def _checks(out, input_data):
   )
   def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
                                                s2_shape, s3_shape, s4_shape):
+    if config.use_shardy_partitioner.value:
+      self.skipTest(
+          'TODO(b/355263220) Shardy conflict resolution is not complete. Issue '
+          'here is that for `a1 @ a1.T` GSPMD gives dim 0 sharded on `x` while '
+          'Shardy gives it fully replicated.')
     global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
     global_input_shape = (8, 2)
@@ -2400,6 +2439,10 @@ def test_device_put_sharding_prng(self):
     self.assertTrue(jax.dtypes.issubdtype(a.dtype, jax.dtypes.prng_key))
     self.assertEqual(a.sharding, out_p.sharding)

+    if config.use_shardy_partitioner.value:
+      # OpSharding is not supported in shardy.
+      return
+
     op = xc.OpSharding()
     op.type = xc.OpSharding.Type.OTHER
     op.tile_assignment_dimensions = [8]
@@ -3405,6 +3448,8 @@ def g(x):
     jtu.check_grads(g, (arr,), order=2)

   def test_pjit_out_sharding_preserved(self):
+    if config.use_shardy_partitioner.value:
+      raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
     mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
     ns = NamedSharding(mesh, P('x'))
     ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
@@ -3483,6 +3528,8 @@ def test_list_in_pspec(self):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))

   def test_sharding_preserved_trivial(self):
+    if config.use_shardy_partitioner.value:
+      raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
     mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
     ns = NamedSharding(mesh, P('x'))
     ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
@@ -3535,6 +3582,8 @@ def test_sharding_on_output_with_vmap(self):
     self.assertEqual(count[0], 1)

   def test_jit_mul_sum_sharding_preserved(self):
+    if config.use_shardy_partitioner.value:
+      raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
     mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
     ns = NamedSharding(mesh, P('x'))
     ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
@@ -3608,6 +3657,8 @@ def test_none_out_sharding(self):
     self.assertEqual(out2.sharding.spec, P())

   def test_sharding_preserved_apply_primitive(self):
+    if config.use_shardy_partitioner.value:
+      raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
     mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
     ns = NamedSharding(mesh, P('x'))
@@ -3848,6 +3899,9 @@ def f():
     f()  # doesn't crash

   def test_lowering_cache_hit_different_devices(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest('b/358322664: different axis names results in '
+                    'a cache miss with Shardy.')
     if jax.device_count() < 4:
       self.skipTest('Requires >=4 devices')
@@ -3945,7 +3999,10 @@ def make_keys(seeds):
     self.assertEqual(base_array.sharding, NamedSharding(mesh, P('y', 'x', None)))

     lowered_text = make_keys.lower(seeds).as_text()
-    self.assertIn('unspecified_dims=[0,1]', lowered_text)
+    if config.use_shardy_partitioner.value:
+      self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
+    else:
+      self.assertIn('unspecified_dims=[0,1]', lowered_text)

   def test_prng_sharding_propagation_with_nested_jit(self):
     input_shape = (8, 2)
@@ -3971,7 +4028,10 @@ def f():
     self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', None)))

     lowered_text = make_keys.lower(seeds).as_text()
-    self.assertIn('unspecified_dims=[0,1]', lowered_text)
+    if config.use_shardy_partitioner.value:
+      self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
+    else:
+      self.assertIn('unspecified_dims=[0,1]', lowered_text)

   def test_partial_sharded_prng_key_inp(self):
     input_shape = (8, 2, 2)
@@ -3995,7 +4055,10 @@ def make_keys(seeds):
     self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

     lowered_text = make_keys.lower(seeds).as_text()
-    self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
+    if config.use_shardy_partitioner.value:
+      self.assertIn('<@mesh, [{?}, {?}, {?}, {}]>', lowered_text)
+    else:
+      self.assertIn('unspecified_dims=[0,1,2]', lowered_text)

   def test_jit_partially_specified_shardings(self):
@@ -4048,6 +4111,8 @@ def f(*args):
     f(inps)  # doesn't crash

   def test_spmd_preserves_input_sharding_vmap_grad(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest("Shardy doesn't support PositionalSharding")
     # https://github.com/google/jax/issues/20710
     n_devices = jax.device_count()
     sharding = PositionalSharding(jax.devices())
@@ -4211,6 +4276,9 @@ def f(x):
     self.assertArraysEqual(out2, np.arange(8) * 2)

   def test_device_put_efficient_reshard_single_host(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest(
+          '_different_device_order_reshard is creating a GSPMDSharding')
     if jax.device_count() < 4:
       self.skipTest('Requires >= 4 devices')
@@ -4235,6 +4303,9 @@ def test_device_put_efficient_reshard_single_host(self):
       ("8_384", (8, 384)),
   )
   def test_device_put_efficient_reshard_complex_mesh(self, shape):
+    if config.use_shardy_partitioner.value:
+      self.skipTest(
+          '_different_device_order_reshard is creating a GSPMDSharding')
     if jax.device_count() < 8:
       self.skipTest('Requires >= 8 devices')
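A possible follow-up (hypothetical, not part of this patch): the prng-sharding tests now all repeat the same branch on `config.use_shardy_partitioner.value`. A helper along these lines could consolidate the two assertion styles; `assert_trailing_dims_replicated` is an invented name, `config` is the same module the tests already import, and the expected substrings are built to match the literals asserted above.

```python
# Hypothetical helper: consolidates the repeated Shardy/GSPMD assertion
# branch from the prng-sharding tests in this diff.
def assert_trailing_dims_replicated(test, lowered_text, aval_ndim):
  if config.use_shardy_partitioner.value:
    # Shardy: one open dim ({?}) per logical dim, then a closed ({}) trailing
    # dim for the key payload, e.g. '<@mesh, [{?}, {?}, {}]>' for aval_ndim=2.
    expected = '<@mesh, [%s{}]>' % ('{?}, ' * aval_ndim)
  else:
    # GSPMD: the logical dims are listed as unspecified on the sharding
    # custom call, e.g. 'unspecified_dims=[0,1]' for aval_ndim=2.
    expected = 'unspecified_dims=[%s]' % ','.join(
        str(i) for i in range(aval_ndim))
  test.assertIn(expected, lowered_text)
```

For example, `assert_trailing_dims_replicated(self, lowered_text, aval_ndim=2)` reproduces the `'<@mesh, [{?}, {?}, {}]>'` / `'unspecified_dims=[0,1]'` pair checked in `test_prng_sharding_propagation_with_nested_jit` above, and `aval_ndim=3` the pair from `test_partial_sharded_prng_key_inp`.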