Skip to content

Commit

Permalink
Fix test failure when shardy is not enabled.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679601133
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Sep 27, 2024
1 parent 26632fd commit e865184
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from jax._src import core
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.ad_checkpoint import saved_residuals
from jax._src.mesh import AbstractMesh
Expand Down Expand Up @@ -2636,6 +2637,8 @@ def fwd(a):


@jtu.with_config(jax_use_shardy_partitioner=True)
# TODO(phawkins): enable this test unconditionally once shardy is the default.
@unittest.skipIf(sdy is None, "shardy is not enabled")
class SdyIntegrationTest(jtu.JaxTestCase):
# Verify we can lower to a `ManualComputationOp`.
def test_shardy_collective_permute(self):
Expand Down

0 comments on commit e865184

Please sign in to comment.