diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 7bba9f01bec9..91a34a530fd2 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -707,17 +707,11 @@ def run(interpret=False): for value in values ) def test_sign(self, dtype, value): - if ( - not jax.config.x64_enabled - and dtype in (jnp.uint64, jnp.int64, jnp.float64) - ): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") - if ( - jtu.test_device_matches(["tpu"]) - and dtype in (jnp.uint16, jnp.int16, jnp.bfloat16, jnp.float16) - ): - self.skipTest("16-bit types are not supported on TPU") + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("Only 32-bit select supported") @functools.partial( self.pallas_call, @@ -753,37 +747,6 @@ def kernel(x_ref, o_ref): expected = lax.erf_inv(x) np.testing.assert_array_equal(out, expected) - -class OpsInterpretTest(OpsTest): - INTERPRET = True - - def test_debug_print(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - grid=1, - ) - def kernel(x_ref, o_ref): - jax.debug.print("x = {}", x_ref) - - x = jnp.array([4.2, 2.4]).astype(jnp.float32) - with jtu.capture_stdout() as output: - jax.block_until_ready(kernel(x)) - jax.effects_barrier() - - self.assertIn("x = [4.2 2.4]", output()) - - -class OpsExtraTest(PallasBaseTest): - """These are additional ops tests that have not been ported to TPU yet.""" - # TODO: fix these for TPU and merge with OpsTest. - - def setUp(self): - super().setUp() - if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - # TODO: most tests fail on TPU in non-interpret mode - self.skipTest("On TPU the test works only in interpret mode") - ELEMENTWISE_OPS = [ ( [jnp.abs, jnp.negative], @@ -811,6 +774,17 @@ def setUp(self): for fn, dtype in itertools.product(*args) ) def test_elementwise(self, fn, dtype): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") + + if jtu.test_device_matches(["tpu"]) and fn in ( + jnp.acosh, jnp.asin, jnp.atanh, jnp.cbrt, jnp.cos, jnp.tan, + ): + self.skipTest(f"{fn.__name__} not implemented for TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), dtype), grid=1 ) @@ -841,6 +815,9 @@ def kernel(x_ref, o_ref): ("float64", "float64"), ) def test_pow(self, x_dtype, y_dtype): + if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), x_dtype), grid=1 ) @@ -867,8 +844,12 @@ def kernel(x_ref, o_ref): @parameterized.parameters("float32", "float64") def test_nextafter(self, dtype): - if jtu.test_device_matches(["tpu"]) and dtype == "float64": - self.skipTest("float64 disabled on TPU.") + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented: nextafter") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 ) @@ -901,6 +882,13 @@ def test_comparison(self, fn, dtype): if jtu.test_device_matches(["gpu"]) and dtype == "bool": self.skipTest("Not implemented on GPU.") + if jtu.test_device_matches(["tpu"]) and dtype == "float16": + self.skipTest("float16 is not supported on TPU") + + # TODO(ayx): skipped due to https://github.com/jax-ml/jax/issues/23972 + if jtu.test_device_matches(["tpu"]) and dtype == "uint32": + self.skipTest("not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), grid=1) @@ -979,6 +967,9 @@ def kernel(x_ref, y_ref, o_ref): for fn, dtype in itertools.product(*args) ) def test_binary(self, f, dtype): + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1 ) @@ -986,7 +977,7 @@ def kernel(x_ref, y_ref, o_ref): o_ref[...] = f(x_ref[...], y_ref[...]) x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) - if (f == jnp.bitwise_left_shift): + if f == jnp.bitwise_left_shift: y = jnp.array([3, 1, 4, 5, 2, 2, 2, 4]).astype(dtype) else: y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) @@ -999,6 +990,9 @@ def kernel(x_ref, y_ref, o_ref): ((8, 16, 2), jnp.int8, 1), ) def test_broadcasted_iota(self, shape, dtype, dimension): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Only 32-bit integer iota supported") + f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension) @functools.partial( @@ -1011,8 +1005,12 @@ def kernel(o_ref): @parameterized.parameters("float16", "bfloat16", "float32") def test_approx_tanh(self, dtype): + if jtu.test_device_matches(["tpu"]): + self.skipTest("The test is for GPU only") + if self.INTERPRET: self.skipTest("approx_tanh is not supported in interpret mode") + if (dtype == "bfloat16" and not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") @@ -1034,6 +1032,9 @@ def kernel(x_ref, o_ref): ) def test_elementwise_inline_asm(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented: elementwise_inline_asm_p") + if self.INTERPRET: self.skipTest( "elementwise_inline_asm is not supported in interpret mode" @@ -1127,6 +1128,9 @@ def kernel(x_ref, o_ref): ((64,), (32, 2)), ) def test_reshape(self, in_shape, out_shape): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), @@ -1156,6 +1160,10 @@ def f(x_ref, o_ref): # fmt: on ) def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape): + # Unsupported implicit dim change: from "32,{0,0},(2,128),-1" to none + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), @@ -1182,6 +1190,10 @@ def kernel(o_ref): ) def test_where_broadcasting(self): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 2, 2), floatx), @@ -1207,6 +1219,10 @@ def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): ((), (2, 2), ()), ) def test_broadcast_in_dim(self, in_shape, out_shape, dims): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), @@ -1227,6 +1243,12 @@ def f(x_ref, o_ref): trans_y=[False, True], ) def test_dot(self, size, dtype, trans_x, trans_y): + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") + + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented: Transposed LHS") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((size, size), dtype), @@ -1249,6 +1271,9 @@ def dot(x_ref, y_ref, o_ref): block_size=[1, 2, 32, 64, 128], ) def test_masked_load_store(self, size, block_size): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented") + @functools.partial( self.pallas_call, out_shape=(jax.ShapeDtypeStruct((size,), floatx)), @@ -1290,15 +1315,18 @@ def test_strided_load(self): # Reproducer from https://github.com/jax-ml/jax/issues/20895. @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), ) def kernel(x_ref, o_ref): o_ref[...] = x_ref[::4] - x = jnp.arange(16, dtype=jnp.float32) + x = jnp.arange(64, dtype=jnp.float32).reshape((16, 4)) np.testing.assert_array_equal(kernel(x), x[::4]) def test_broadcasted_load_store(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Unimplemented primitive: broadcast_to") + m, n = 16, 32 @functools.partial( @@ -1320,6 +1348,10 @@ def load(x_ref, o_ref): ((16, 32), (16, 16)), ) def test_invalid_broadcasted_load(self, x_shape, mask_shape): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + if self.INTERPRET: self.skipTest("No broadcasting checks in pl.load in interpret mode") @@ -1342,6 +1374,10 @@ def kernel(x_ref, mask_ref, o_ref): self.fail("Expected exception due to invalid broadcasting") def test_swap(self): + # TODO: skipped due to https://github.com/jax-ml/jax/issues/24023 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + m, n = 16, 32 @functools.partial( @@ -1421,6 +1457,10 @@ def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), ) def test_scalar_atomic(self, op, value, numpy_op): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), value.dtype), @@ -1452,6 +1492,9 @@ def atomic_kernel(x_ref, _, o_ref): @parameterized.parameters((0,), (1,)) def test_array_atomic_add(self, axis): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Unimplemented primitive: broadcast_to") + m, n = 32, 8 if axis == 0: grid = m @@ -1489,6 +1532,10 @@ def reduce(x_ref, _, y_ref): (2, 1, 1), ) def test_atomic_cas(self, init_value, cmp, new_value): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]): self.skipTest("Not supported on GPU in 64-bit mode") @@ -1507,6 +1554,10 @@ def swap(_, lock_ref, out_ref): @parameterized.parameters(1, 2, 3, 4, 8) def test_atomic_counter(self, num_threads): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + if self.INTERPRET: self.skipTest("While loop not supported in interpret mode.") @@ -1532,6 +1583,10 @@ def _cond(_): @parameterized.parameters(False, True) def test_reduce_only_dim(self, use_store): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + m = 32 x = random.normal(random.key(0), (m,), dtype=jnp.float32) out_shape = jax.ShapeDtypeStruct((), x.dtype) @@ -1573,9 +1628,10 @@ def reduce(x_ref, y_ref): if isinstance(axis, int) or "arg" not in op_name ]) def test_array_reduce(self, op, dtype, axis): - m, n = 32, 8 + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") - if not jax.config.x64_enabled and dtype in ("float64", "int64", "uint64"): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects @@ -1587,6 +1643,12 @@ def test_array_reduce(self, op, dtype, axis): ): self.skipTest("Not supported on GPU in 64-bit mode") + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + + m, n = 32, 8 + def make_x(key): if jnp.issubdtype(dtype, jnp.integer): return random.permutation( @@ -1623,6 +1685,9 @@ def reduce(x_ref, y_ref): dtype=["float16", "float32", "int32", "uint32"], ) def test_cumsum(self, dtype, axis): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + m, n = 32, 8 out_dtype = dtype @@ -1649,9 +1714,25 @@ def reduce(x_ref, y_ref): np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) -class OpsExtraInterpretTest(OpsExtraTest): +class OpsInterpretTest(OpsTest): INTERPRET = True + def test_debug_print(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + grid=1, + ) + def kernel(x_ref, o_ref): + jax.debug.print("x = {}", x_ref) + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + jax.effects_barrier() + + self.assertIn("x = [4.2 2.4]", output()) + class PallasPrimitivesTest(PallasBaseTest):