Skip to content

Commit

Permalink
Improved extern selection in Pallas GPU
Browse files Browse the repository at this point in the history
Previously,

* weakly typed avals matched the wrong externs;
* this was addressed by #23193, which disallowed weakly typed avals entirely.

Here we check if a weakly typed aval can be casted to the extern input dtype
when selecting an extern.

PiperOrigin-RevId: 669378582
  • Loading branch information
superbobry authored and jax authors committed Aug 30, 2024
1 parent f8a4662 commit fb7fa2a
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
return False
return all(
aval.dtype == jnp.dtype(arg_type)
or (aval.weak_type and aval.dtype.kind == jnp.dtype(arg_type).kind)
for aval, arg_type in zip(avals, self.arg_types)
)

Expand Down

0 comments on commit fb7fa2a

Please sign in to comment.