Skip to content

Commit

Permalink
Merge pull request #23193 from ayaka14732:pallas-gpu-type-conversion
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668670205
  • Loading branch information
jax authors committed Aug 29, 2024
2 parents 57a9b18 + f479350 commit 48a9159
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
if len(avals) != len(self.arg_types):
return False
return all(
aval.weak_type or aval.dtype.name == arg_type
aval.dtype.name == arg_type
for aval, arg_type in zip(avals, self.arg_types)
)

Expand Down
11 changes: 11 additions & 0 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,17 @@ def kernel(x_ref, o_ref):
x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)

def test_abs_weak_type(self):
# see https://github.com/google/jax/issues/23191
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.abs(x_ref[...])

x = jnp.broadcast_to(-3.2, (4, 4)) # sets `weak_type` to `True`
np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6)

@parameterized.parameters(
("float32", "int32"),
("float64", "int32"),
Expand Down

0 comments on commit 48a9159

Please sign in to comment.