diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 3259ee634e1f..f2a4229223bd 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -580,6 +580,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: return False return all( aval.dtype == jnp.dtype(arg_type) + or (aval.weak_type and aval.dtype.kind == jnp.dtype(arg_type).kind) for aval, arg_type in zip(avals, self.arg_types) )