Skip to content

Commit

Permalink
[Pallas:TPU] Fix lowering of convert_element_type(int32) -> bool.
Browse files Browse the repository at this point in the history
We need to add a condition on vector type since both operands of arith::CmpIOp must have same type.

PiperOrigin-RevId: 679243567
  • Loading branch information
Google-ML-Automation committed Sep 26, 2024
1 parent 6f7ad64 commit 771c241
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
15 changes: 12 additions & 3 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,7 +1635,8 @@ def _convert_element_type_lowering_rule(
del weak_type
del sharding
out_aval = ctx.avals_out[0]
old_dtype = ctx.avals_in[0].dtype
in_aval = ctx.avals_in[0]
old_dtype = in_aval.dtype
out_type = aval_to_ir_type(out_aval)

if old_dtype == new_dtype:
Expand Down Expand Up @@ -1680,8 +1681,16 @@ def _convert_element_type_lowering_rule(
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
const_type = _dtype_to_ir_type(old_dtype)
const_zero = ir.IntegerAttr.get(const_type, 0)
const_zero = arith.ConstantOp(const_type, const_zero)
return arith.CmpIOp(predicate, x, const_zero).result
if in_aval.shape:
in_type = aval_to_ir_type(in_aval, is_kernel_boundary=False)
vector_zeros = arith.ConstantOp(
in_type,
ir.DenseElementsAttr.get_splat(in_type, const_zero),
)
return arith.CmpIOp(predicate, x, vector_zeros).result
return arith.CmpIOp(
predicate, x, arith.ConstantOp(const_type, const_zero)
).result
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
multiple_results=False)(ctx, x)

Expand Down
2 changes: 1 addition & 1 deletion tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def test_cast(self, from_dtype, to_dtype, data):
self.skipTest("Not supported: bad canonicalization")
if from_dtype == "bool" and to_dtype in {"int16", "int8"}:
self.skipTest("Not supported: cannot extend to sub-32 bit types")
if from_dtype in {"int32", "bfloat16", "float32"} and to_dtype == "bool":
if from_dtype in {"bfloat16", "float32"} and to_dtype == "bool":
self.skipTest("Not supported: unsupported relayout")
if from_dtype == "bool" and to_dtype in {"int32", "bfloat16", "float32"}:
self.skipTest("Not supported: unsupported relayout")
Expand Down

0 comments on commit 771c241

Please sign in to comment.