From 41f06f8ac802cb532471311a6131a08418bddb34 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 28 Aug 2024 08:53:30 -0700 Subject: [PATCH] [pallas] Disable two very slow tests in pallas_vmap_test on CPU. These take over a minute each, causing timeouts in CI. PiperOrigin-RevId: 668473770 --- tests/pallas/pallas_vmap_test.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index 3c33702b63e3..4b3f47e6f5c1 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -21,7 +21,6 @@ from absl.testing import absltest import jax from jax import random -from jax._src.lib import xla_extension from jax._src import config from jax._src import test_util as jtu from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr @@ -208,10 +207,8 @@ def sin(x_ref, o_ref): np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) @jtu.skip_on_flag("jax_skip_slow_tests", True) + @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_large_vmap(self): - if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]): - self.skipTest("Test is very slow under TSAN") - # Catches https://github.com/google/jax/issues/18361 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), @@ -229,9 +226,8 @@ def add_one(x_ref, o_ref): np.testing.assert_allclose(out, out_ref) + @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_small_large_vmap(self): - if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]): - self.skipTest("Test is very slow under TSAN") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),