Skip to content

Commit

Permalink
[pallas] Disable two very slow tests in pallas_vmap_test on CPU.
Browse files Browse the repository at this point in the history
These take over a minute each, causing timeouts in CI.

PiperOrigin-RevId: 668473770
  • Loading branch information
hawkinsp authored and jax authors committed Aug 28, 2024
1 parent f0a7266 commit 41f06f8
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions tests/pallas/pallas_vmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from absl.testing import absltest
import jax
from jax import random
from jax._src.lib import xla_extension
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
Expand Down Expand Up @@ -208,10 +207,8 @@ def sin(x_ref, o_ref):
np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.skip_on_devices("cpu") # Test is very slow on CPU
def test_small_large_vmap(self):
if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]):
self.skipTest("Test is very slow under TSAN")

# Catches https://github.com/google/jax/issues/18361
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
Expand All @@ -229,9 +226,8 @@ def add_one(x_ref, o_ref):

np.testing.assert_allclose(out, out_ref)

@jtu.skip_on_devices("cpu") # Test is very slow on CPU
def test_small_small_large_vmap(self):
if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]):
self.skipTest("Test is very slow under TSAN")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
Expand Down

0 comments on commit 41f06f8

Please sign in to comment.