From b3615a9a6aa6371287e22f43b41c2a4252b45c6b Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 27 Aug 2024 22:55:48 +0100 Subject: [PATCH] Test if the new test cases fails --- tests/pallas/ops_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index cf247ac3f6a4..27f3c6083c34 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -739,6 +739,18 @@ def kernel(x_ref, o_ref): x = jnp.array([0.42, 2.4]).astype(dtype) np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + def test_abs_weak_type(self): + # see https://github.com/google/jax/issues/23191 + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.abs(x_ref[...]) + + x = jnp.array([-3.2], dtype=jnp.float32) + x.aval.weak_type = True # manually set `weak_type` to `True` + np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6) + @parameterized.parameters( ("float32", "int32"), ("float64", "int32"),