Skip to content

Commit

Permalink
[jax:pallas] Minor cleanup in Triton add lowering.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668411109
  • Loading branch information
chr1sj0nes authored and jax authors committed Aug 28, 2024
1 parent f0a7266 commit 7d38718
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,21 +990,19 @@ def _minus(x: ir.Value) -> ir.Value:
def _add(x: ir.Value, y: ir.Value):
x_element_type = _element_type(x.type)
y_element_type = _element_type(y.type)
if tt_dialect.PointerType.isinstance(y_element_type):
assert not tt_dialect.PointerType.isinstance(x_element_type)
x, y = y, x
x_element_type, y_element_type = y_element_type, x_element_type

if tt_dialect.PointerType.isinstance(x_element_type):
assert not tt_dialect.PointerType.isinstance(y_element_type)
return tt_dialect.addptr(x.type, x, y)
if tt_dialect.PointerType.isinstance(y_element_type):
return tt_dialect.addptr(y.type, y, x)

assert x.type == y.type, (str(x.type), str(y.type))
if isinstance(x_element_type, ir.IntegerType):
return arith_dialect.addi(x, y)
elif isinstance(x_element_type, ir.FloatType):
if isinstance(x_element_type, ir.FloatType):
return arith_dialect.addf(x, y)
else:
raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}")
raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}")


def _sub(x: ir.Value, y: ir.Value) -> ir.Value:
Expand Down

0 comments on commit 7d38718

Please sign in to comment.