From e86518479cf898161e4ba3763fb605f106b5a7a9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 27 Sep 2024 08:31:49 -0700 Subject: [PATCH] Fix test failure when shardy is not enabled. PiperOrigin-RevId: 679601133 --- tests/shard_map_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 0c1155ddf1ab..0ddbd4b50bd0 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -37,6 +37,7 @@ from jax._src import core from jax._src import prng from jax._src import test_util as jtu +from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals from jax._src.mesh import AbstractMesh @@ -2636,6 +2637,8 @@ def fwd(a): @jtu.with_config(jax_use_shardy_partitioner=True) +# TODO(phawkins): enable this test unconditionally once shardy is the default. +@unittest.skipIf(sdy is None, "shardy is not enabled") class SdyIntegrationTest(jtu.JaxTestCase): # Verify we can lower to a `ManualComputationOp`. def test_shardy_collective_permute(self):