diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index cf247ac3f6a4..ae099abc0eb3 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -739,6 +739,17 @@ def kernel(x_ref, o_ref): x = jnp.array([0.42, 2.4]).astype(dtype) np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + def test_abs_weak_type(self): + # see https://github.com/google/jax/issues/23191 + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.abs(x_ref[...]) + + x = jnp.broadcast_to(-3.2, (4, 4)) # sets `weak_type` to `True` + np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6) + @parameterized.parameters( ("float32", "int32"), ("float64", "int32"),