diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 0c1155ddf1ab..0ddbd4b50bd0 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -37,6 +37,7 @@ from jax._src import core from jax._src import prng from jax._src import test_util as jtu +from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals from jax._src.mesh import AbstractMesh @@ -2636,6 +2637,8 @@ def fwd(a): @jtu.with_config(jax_use_shardy_partitioner=True) +# TODO(phawkins): enable this test unconditionally once shardy is the default. +@unittest.skipIf(sdy is None, "shardy is not enabled") class SdyIntegrationTest(jtu.JaxTestCase): # Verify we can lower to a `ManualComputationOp`. def test_shardy_collective_permute(self):