From f4793501168e2f25c584f751750f0196cfa03b2a Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 29 Aug 2024 00:50:37 +0100 Subject: [PATCH] Disable implicit type conversion during type matching --- jax/_src/pallas/triton/lowering.py | 2 +- tests/pallas/ops_test.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 6db00a53671e..e4e2dc0791ad 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -564,7 +564,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: if len(avals) != len(self.arg_types): return False return all( - aval.weak_type or aval.dtype.name == arg_type + aval.dtype.name == arg_type for aval, arg_type in zip(avals, self.arg_types) ) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index cf247ac3f6a4..ae099abc0eb3 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -739,6 +739,17 @@ def kernel(x_ref, o_ref): x = jnp.array([0.42, 2.4]).astype(dtype) np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + def test_abs_weak_type(self): + # see https://github.com/google/jax/issues/23191 + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.abs(x_ref[...]) + + x = jnp.broadcast_to(-3.2, (4, 4)) # sets `weak_type` to `True` + np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6) + @parameterized.parameters( ("float32", "int32"), ("float64", "int32"),