diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index 3c33702b63e3..4b3f47e6f5c1 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -21,7 +21,6 @@ from absl.testing import absltest import jax from jax import random -from jax._src.lib import xla_extension from jax._src import config from jax._src import test_util as jtu from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr @@ -208,10 +207,8 @@ def sin(x_ref, o_ref): np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) @jtu.skip_on_flag("jax_skip_slow_tests", True) + @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_large_vmap(self): - if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]): - self.skipTest("Test is very slow under TSAN") - # Catches https://github.com/google/jax/issues/18361 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), @@ -229,9 +226,8 @@ def add_one(x_ref, o_ref): np.testing.assert_allclose(out, out_ref) + @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_small_large_vmap(self): - if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]): - self.skipTest("Test is very slow under TSAN") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),