Add jax_test configs for shardy and enable it for pjit_test.py and fix any tests.

Tests fixed include:

- `test_globally_sharded_key_array_8x4_multi_device`
  - Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding.
- `test_aot_out_info`
  - Issue was that there was no mesh since there weren't any NamedShardings. Fixed by not asserting that a mesh tuple exists in `lower_jaxpr_to_module` when adding the sdy MeshOp (there won't be any propagation). A minimal sketch of the resulting behavior follows this list.
- `test_concurrent_pjit`
  - In Shardy, if there was a tensor dimension of size 0, we'd emit a verification error if the dimension was sharded on an axis. But if the axis is of size 1, then JAX says this is okay, so Shardy now assumes the same.
- `test_globally_sharded_key_array_result_8x4_single_device`
  - This test adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we now create a mesh named `mesh` with a single device id in case one doesn't exist.
- `testLowerCostAnalysis`
  - This calls into `mlir_module_to_xla_computation`, which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. The SDY dialect needed to be registered in it.
- `testShardingConstraintWithArray`
  - This calls `.compiler_ir(dialect="hlo")`, which calls `PyMlirModuleToXlaComputation`, which converts the MLIR to HLO, but the Sdy dialect is still inside. Fixed by exporting it before converting to HLO.
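
A minimal sketch (not part of this commit) of how the `test_aot_out_info`-style fix can be exercised: with the Shardy partitioner enabled, lowering a function that uses no NamedShardings should still produce a module containing an axis-less `sdy.mesh @mesh` symbol instead of tripping the old assertion. The flag name `jax_use_shardy_partitioner` and the exact printed form of the mesh op are assumptions and may differ across JAX versions.

    import jax
    import jax.numpy as jnp

    # Assumed flag name backing config.use_shardy_partitioner.
    jax.config.update('jax_use_shardy_partitioner', True)

    @jax.jit
    def f(x):
      return x * 2  # no shardings anywhere, so no mesh_shape is propagated

    lowered_text = f.lower(jnp.arange(8.0)).as_text()
    # With this change, the module still carries a mesh symbol (with no axes).
    assert 'sdy.mesh @mesh' in lowered_text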

PiperOrigin-RevId: 666777167
bartchr808 authored and jax authors committed Aug 23, 2024
1 parent f54e220 commit 71b7e78
Showing 3 changed files with 105 additions and 20 deletions.
27 changes: 18 additions & 9 deletions jax/_src/interpreters/mlir.py
@@ -1119,14 +1119,15 @@ def lower_jaxpr_to_module(
# XLA computation preserves the module name.
attrs = ctx.module.operation.attributes
if config.use_shardy_partitioner.value:
assert (isinstance(axis_context, sharding_impls.ShardingContext) and
axis_context.mesh_shape is not None)
ctx.module.body.append(
dialects.sdy.MeshOp(
"mesh",
dialects.sdy.MeshAttr.get(
[dialects.sdy.MeshAxisAttr.get(name, size)
for name, size in axis_context.mesh_shape])))
if (isinstance(axis_context, sharding_impls.ShardingContext) and
axis_context.mesh_shape is not None):
sdy_mesh_attr = dialects.sdy.MeshAttr.get(
[dialects.sdy.MeshAxisAttr.get(name, size)
for name, size in axis_context.mesh_shape])
else:
sdy_mesh_attr = dialects.sdy.MeshAttr.get([])

ctx.module.body.append(dialects.sdy.MeshOp("mesh", sdy_mesh_attr))
module_name = _module_name_regex.sub("_", module_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
@@ -1633,7 +1634,15 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
# then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
# The below custom call achieves the sharding like above example.
return wrap_with_sharding_op(
if config.use_shardy_partitioner.value:
physical_ndim = core.physical_aval(aval).ndim
s = sharding.SdyArraySharding(
mesh_name='mesh',
dimension_shardings=[sharding.SdyDimSharding(axes=[], is_closed=i >= aval.ndim)
for i in range(physical_ndim)])
return wrap_with_sharding_op(ctx, val, aval, s)
else:
return wrap_with_sharding_op(
ctx, val, aval, xc.HloSharding.replicate().to_proto(),
unspecified_dims=set(range(aval.ndim)))

5 changes: 5 additions & 0 deletions tests/BUILD
@@ -234,6 +234,11 @@ jax_test(
"tpu": ["notsan"], # Times out under tsan.
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
enable_configs = [
"gpu_2gpu_shardy",
"tpu_df_2x2_shardy",
"tpu_pf_2x2_shardy",
],
shard_count = {
"cpu": 5,
"gpu": 5,
93 changes: 82 additions & 11 deletions tests/pjit_test.py
@@ -402,6 +402,8 @@ def f(inp1, inp2, inp3):

@jtu.run_on_devices('tpu')
def testBufferDonationWithOutputShardingInferenceAndTokens(self):
if config.use_shardy_partitioner.value:
self.skipTest('b/355263220: Shardy does not support callbacks yet.')
mesh = jtu.create_global_mesh((2,), 'x')
s = NamedSharding(mesh, P('x'))

@@ -453,10 +455,16 @@ def f(x):
check_dtypes=False)

hlo = f.lower(np.ones(shape)).compiler_ir()
# Annotation from with_sharding_constraint
self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo))
# Annotation from pjit
self.assertIn('sharding = "{replicated}"', str(hlo))
if config.use_shardy_partitioner.value:
# Annotation from with_sharding_constraint
self.assertIn('<@mesh, [{"x"}, {"y"}]>', str(hlo))
# Annotation from pjit
self.assertIn('sharding = #sdy.sharding<@mesh, [{}, {}]>}', str(hlo))
else:
# Annotation from with_sharding_constraint
self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo))
# Annotation from pjit
self.assertIn('sharding = "{replicated}"', str(hlo))

def testShardingConstraintWithArray(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
@@ -484,6 +492,8 @@ def f(x):
self.assertIn("sharding={replicated}", hlo.as_hlo_text())

def testShardingConstraintWithArrayOpSharding(self):
if config.use_shardy_partitioner.value:
self.skipTest("Shardy doesn't support PositionalSharding")
shape = (8, 8)
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
s = NamedSharding(mesh, P(None))
@@ -555,8 +565,12 @@ def f(x):
self.assertLen(actual[0]['a'].addressable_shards, 4)

mlir_str = str(f.lower(x).compiler_ir())
self.assertIn("unspecified_dims=[0]", mlir_str)
self.assertIn("unspecified_dims=[1]", mlir_str)
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {"y"}, {}]>', mlir_str)
self.assertIn('<@mesh, [{"x"}, {?}, {}]>', mlir_str)
else:
self.assertIn("unspecified_dims=[0]", mlir_str)
self.assertIn("unspecified_dims=[1]", mlir_str)

@jtu.with_mesh([('x', 2), ('y', 2)])
def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self):
@@ -575,8 +589,12 @@ def f(x):
x = [{'a': v, 'b': v * 2}, v * 3]

mlir_str = str(f.lower(x).compiler_ir())
self.assertIn("unspecified_dims=[0,1]", mlir_str)
self.assertIn("unspecified_dims=[0,2]", mlir_str)
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', mlir_str)
self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', mlir_str)
else:
self.assertIn("unspecified_dims=[0,1]", mlir_str)
self.assertIn("unspecified_dims=[0,2]", mlir_str)

def testCaching(self):
def f(x):
@@ -847,6 +865,9 @@ def f_for_pjit(x):
def testOutfeed(self):
if xla_bridge.using_pjrt_c_api():
raise unittest.SkipTest('outfeed not implemented in PJRT C API')
if config.use_shardy_partitioner.value:
self.skipTest(
'b/355263220: outfeed lowering not supported by Shardy')

devices = np.array(jax.local_devices())
nr_devices = len(devices)
@@ -1280,6 +1301,9 @@ class CustomPartitionerTest(jtu.JaxTestCase):
def skip_if_custom_partitioning_not_supported(self):
if jtu.is_cloud_tpu():
raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
if config.use_shardy_partitioner.value:
self.skipTest(
'Custom partitioning is not supported with Shardy yet.')

@jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU.
@jtu.with_mesh([('x', 4), ('y', 2)])
@@ -1564,6 +1588,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
)
def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,
mesh_axis_names):
if config.use_shardy_partitioner.value:
self.skipTest('Must register auto partitioner for Shardy')
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
input_data = np.arange(
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
@@ -1580,6 +1606,8 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,
self.assertArraysEqual(out._value, input_data)

def test_xla_arr_sharding_mismatch(self):
if config.use_shardy_partitioner.value:
self.skipTest('Must register auto partitioner for Shardy')
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (6, 2)
input_data = np.arange(
@@ -1607,6 +1635,8 @@ def test_xla_arr_sharding_mismatch(self):
compiled(arr)

def test_gda_auto_shardings_len(self):
if config.use_shardy_partitioner.value:
self.skipTest('Must register auto partitioner for Shardy')
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (4, 2)
input_data = np.arange(
@@ -1627,6 +1657,8 @@ def test_gda_auto_shardings_len(self):
)
def test_jit_arr_partial_auto_sharding_array(
self, mesh_shape, mesh_axis_names, pspec):
if config.use_shardy_partitioner.value:
self.skipTest('Must register auto partitioner for Shardy')
mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
global_input_shape = (8, 4)
input_data = np.arange(
@@ -1667,6 +1699,8 @@ def test_jit_auto_sharding_partial_tuple_input_shardings(
self, mesh_shape, mesh_axis_names):
if not jtu.test_device_matches(["tpu"]):
self.skipTest('Parameters are tupled only on TPU if >2000 parameters')
if config.use_shardy_partitioner.value:
self.skipTest('Must register auto partitioner for Shardy')

mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
global_input_shape = (8, 4)
@@ -1838,6 +1872,11 @@ def _checks(out, input_data):
)
def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
s2_shape, s3_shape, s4_shape):
if config.use_shardy_partitioner.value:
self.skipTest(
'TODO(b/355263220) Shardy conflict resolution is not complete. Issue '
'here is that for `a1 @ a1.T` GSPMD gives dim 0 sharded on `x` while '
'Shardy gives it fully replicated.')
global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
global_input_shape = (8, 2)

@@ -2400,6 +2439,10 @@ def test_device_put_sharding_prng(self):
self.assertTrue(jax.dtypes.issubdtype(a.dtype, jax.dtypes.prng_key))
self.assertEqual(a.sharding, out_p.sharding)

if config.use_shardy_partitioner.value:
# OpSharding is not supported in shardy.
return

op = xc.OpSharding()
op.type = xc.OpSharding.Type.OTHER
op.tile_assignment_dimensions = [8]
@@ -3405,6 +3448,8 @@ def g(x):
jtu.check_grads(g, (arr,), order=2)

def test_pjit_out_sharding_preserved(self):
if config.use_shardy_partitioner.value:
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
@@ -3483,6 +3528,8 @@ def test_list_in_pspec(self):
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))

def test_sharding_preserved_trivial(self):
if config.use_shardy_partitioner.value:
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
@@ -3535,6 +3582,8 @@ def test_sharding_on_output_with_vmap(self):
self.assertEqual(count[0], 1)

def test_jit_mul_sum_sharding_preserved(self):
if config.use_shardy_partitioner.value:
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
@@ -3608,6 +3657,8 @@ def test_none_out_sharding(self):
self.assertEqual(out2.sharding.spec, P())

def test_sharding_preserved_apply_primitive(self):
if config.use_shardy_partitioner.value:
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))

@@ -3848,6 +3899,9 @@ def f():
f() # doesn't crash

def test_lowering_cache_hit_different_devices(self):
if config.use_shardy_partitioner.value:
self.skipTest('b/358322664: different axis names results in '
'a cache miss with Shardy.')
if jax.device_count() < 4:
self.skipTest('Requires >=4 devices')

@@ -3945,7 +3999,10 @@ def make_keys(seeds):
self.assertEqual(base_array.sharding, NamedSharding(mesh, P('y', 'x', None)))

lowered_text = make_keys.lower(seeds).as_text()
self.assertIn('unspecified_dims=[0,1]', lowered_text)
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
else:
self.assertIn('unspecified_dims=[0,1]', lowered_text)

def test_prng_sharding_propagation_with_nested_jit(self):
input_shape = (8, 2)
@@ -3971,7 +4028,10 @@ def f():
self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', None)))

lowered_text = make_keys.lower(seeds).as_text()
self.assertIn('unspecified_dims=[0,1]', lowered_text)
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text)
else:
self.assertIn('unspecified_dims=[0,1]', lowered_text)

def test_partial_sharded_prng_key_inp(self):
input_shape = (8, 2, 2)
@@ -3995,7 +4055,10 @@ def make_keys(seeds):
self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

lowered_text = make_keys.lower(seeds).as_text()
self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {?}, {?}, {}]>', lowered_text)
else:
self.assertIn('unspecified_dims=[0,1,2]', lowered_text)

def test_jit_partially_specified_shardings(self):

@@ -4048,6 +4111,8 @@ def f(*args):
f(inps) # doesn't crash

def test_spmd_preserves_input_sharding_vmap_grad(self):
if config.use_shardy_partitioner.value:
self.skipTest("Shardy doesn't support PositionalSharding")
# https://github.com/google/jax/issues/20710
n_devices = jax.device_count()
sharding = PositionalSharding(jax.devices())
@@ -4211,6 +4276,9 @@ def f(x):
self.assertArraysEqual(out2, np.arange(8) * 2)

def test_device_put_efficient_reshard_single_host(self):
if config.use_shardy_partitioner.value:
self.skipTest(
'_different_device_order_reshard is creating a GSPMDSharding')
if jax.device_count() < 4:
self.skipTest('Requires >= 4 devices')

@@ -4235,6 +4303,9 @@ def test_device_put_efficient_reshard_single_host(self):
("8_384", (8, 384)),
)
def test_device_put_efficient_reshard_complex_mesh(self, shape):
if config.use_shardy_partitioner.value:
self.skipTest(
'_different_device_order_reshard is creating a GSPMDSharding')
if jax.device_count() < 8:
self.skipTest('Requires >= 8 devices')

