diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 6db00a53671e..c2009d7d24c9 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -990,21 +990,19 @@ def _minus(x: ir.Value) -> ir.Value: def _add(x: ir.Value, y: ir.Value): x_element_type = _element_type(x.type) y_element_type = _element_type(y.type) - if tt_dialect.PointerType.isinstance(y_element_type): - assert not tt_dialect.PointerType.isinstance(x_element_type) - x, y = y, x - x_element_type, y_element_type = y_element_type, x_element_type if tt_dialect.PointerType.isinstance(x_element_type): + assert not tt_dialect.PointerType.isinstance(y_element_type) return tt_dialect.addptr(x.type, x, y) + if tt_dialect.PointerType.isinstance(y_element_type): + return tt_dialect.addptr(y.type, y, x) assert x.type == y.type, (str(x.type), str(y.type)) if isinstance(x_element_type, ir.IntegerType): return arith_dialect.addi(x, y) - elif isinstance(x_element_type, ir.FloatType): + if isinstance(x_element_type, ir.FloatType): return arith_dialect.addf(x, y) - else: - raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}") + raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}") def _sub(x: ir.Value, y: ir.Value) -> ir.Value: