Added support for lax.cond_p to Pallas Mosaic GPU lowering
PiperOrigin-RevId: 676133222
superbobry authored and Google-ML-Automation committed Sep 26, 2024
1 parent 0a66e2d commit afb65ae
Showing 4 changed files with 85 additions and 4 deletions.
60 changes: 59 additions & 1 deletion jax/_src/pallas/mosaic_gpu/lowering.py
@@ -34,6 +34,7 @@
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.mlir.dialects import memref as memref_dialect
from jax._src.lib.mlir.dialects import scf as scf_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
@@ -677,6 +678,22 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
  )


@register_lowering_rule(lax.select_n_p)
def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
  if len(cases) != 2:
    raise NotImplementedError(
        "Mosaic GPU lowering only supports select_n with 2 cases, got"
        f" {len(cases)}"
    )
  pred_aval, *cases_avals = ctx.avals_in
  [out_aval] = ctx.avals_out
  pred = _ensure_fa(pred, pred_aval.dtype)
  cases = _bcast(*cases, *cases_avals, out_aval)
  # ``select`` expects the first case to be the true branch, but ``select_n``
  # orders the cases in reverse.
  return pred.select(*reversed(cases))
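Note (not part of this commit, for context): jnp.where(pred, a, b) lowers to select_n(pred, b, a), where case 0 is taken when pred is false, while FragmentedArray.select takes the true operand first; reversing the cases reconciles the two conventions. A hedged sketch of a kernel body that would hit this rule, with illustrative ref names:

def kernel(x_ref, o_ref):
  x = x_ref[...]
  o_ref[...] = jnp.where(x > 0, x, x * 2)  # select_n with exactly two cases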


@register_lowering_rule(lax.broadcast_in_dim_p)
def _broadcast_in_dim_lowering_rule(
    ctx: LoweringRuleContext,
@@ -712,6 +729,16 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
    lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y),
    lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y),
    lax.div_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x / y),
    lax.rem_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x % y),
    lax.and_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x & y),
    lax.or_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x | y),
    lax.xor_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x ^ y),
    lax.gt_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x > y),
    lax.lt_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x < y),
    lax.ge_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x >= y),
    lax.le_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x <= y),
    lax.eq_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x == y),
    lax.ne_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x != y),
})
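Note (not part of this commit, for context): a hedged sketch of a kernel body exercising several of the newly registered elementwise rules; ref names and shapes are illustrative.

def kernel(x_ref, o_ref):
  x = x_ref[...]
  is_even = x % 2 == 0              # rem_p and eq_p
  in_range = (x >= 0) & (x < 128)   # ge_p, lt_p and and_p
  o_ref[...] = jnp.where(is_even | in_range, x, x - 1)  # or_p and sub_p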


@@ -909,13 +936,41 @@ def _scan_lowering_rule(
  return for_out


@register_lowering_rule(lax.cond_p)
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
  index_aval, *_arg_avals = ctx.avals_in
  switch_op = scf_dialect.IndexSwitchOp(
      map(mgpu_utils.dtype_to_ir_type, ctx.avals_out),
      _as_index(_ensure_ir_value(index, index_aval.dtype)),
      ir.DenseI64ArrayAttr.get(range(len(branches) - 1)),
      num_caseRegions=len(branches) - 1,
  )

  # ``RegionSequence`` in MLIR does not support slicing, so the
  # auto-generated Python bindings for ``caseRegions`` fail at runtime!
  # We convert it to a list to work around that.
  regions = list(switch_op.regions)
  # Move the default region to the back.
  regions = regions[1:] + regions[:1]
  for branch, region in zip(branches, regions):
    with ir.InsertionPoint(region.blocks.append()):
      outs = lower_jaxpr_to_mosaic_gpu(
          ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args
      )
      scf_dialect.yield_([
          _ensure_ir_value(out, aval.dtype)
          for out, aval in zip(outs, ctx.avals_out)
      ])
  return list(switch_op.results)
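Note (not part of this commit, for context): lax.switch also lowers through cond_p with an integer index and several branches, so a kernel like the hedged sketch below should map onto the same scf.index_switch construction, with the last branch becoming the MLIR default region. Only the kernel body is shown; the pl.pallas_call wrapper would look like the test added at the bottom of this commit.

def kernel(x_ref, o_ref):
  acc = x_ref[...].sum()
  jax.lax.switch(
      acc % 3,
      [
          lambda: pl.debug_print("case 0, acc: {}", acc),
          lambda: pl.debug_print("case 1, acc: {}", acc),
          lambda: pl.debug_print("case 2, acc: {}", acc),
      ],
  )
  o_ref[...] = jnp.broadcast_to(acc, o_ref.shape)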


def _bcast(
    x: ir.Value,
    y: ir.Value,
    x_aval: jax_core.ShapedArray,
    y_aval: jax_core.ShapedArray,
    out_aval: jax_core.ShapedArray,
) -> ir.Value:
) -> tuple[mgpu.FragmentedArray, mgpu.FragmentedArray]:
  if not isinstance(x, mgpu.FragmentedArray):
    x_dtype = x_aval.dtype
    if x_aval.weak_type:
@@ -935,6 +990,7 @@

def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
  if isinstance(x, mgpu.FragmentedArray):
    assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
    return x
  elif isinstance(x, (np.number, np.ndarray, int, float)):
    return mgpu.FragmentedArray.splat(
@@ -944,12 +1000,14 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
    )
  elif isinstance(x, ir.Value):
    if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)):
      assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
      return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype))
  raise NotImplementedError(f"Unsupported type: {type(x)}")


def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
  if isinstance(x, ir.Value):
    assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
    return x
  elif isinstance(x, (np.number, np.ndarray, int, float)):
    return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
4 changes: 2 additions & 2 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -784,13 +784,13 @@ def broadcast_minor(self, n):
        _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed
    )

  def select(self, x, y):
  def select(self, on_true, on_false):
    if (
        not ir.IntegerType.isinstance(self.mlir_dtype)
        or ir.IntegerType(self.mlir_dtype).width != 1
    ):
      raise NotImplementedError
    return self._pointwise(arith.select, x, y)
    return self._pointwise(arith.select, on_true, on_false)

  def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]):
    """Call a function for each value and index."""
4 changes: 3 additions & 1 deletion jax/experimental/mosaic/gpu/utils.py
@@ -951,6 +951,8 @@ def dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type:


def is_signed(dtype: jax.typing.DTypeLike) -> bool | None:
  if jnp.issubdtype(dtype, jnp.integer):
  if jnp.issubdtype(dtype, jnp.bool_):
    return False
  elif jnp.issubdtype(dtype, jnp.integer):
    return jnp.issubdtype(dtype, jnp.signedinteger)
  return None
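Note (not part of this commit): illustrative expectations for the updated helper. Previously a bool dtype fell through to None because it is not a numpy integer subtype; now it is explicitly reported as unsigned.

assert is_signed(jnp.bool_) is False
assert is_signed(jnp.int32) is True
assert is_signed(jnp.uint32) is False
assert is_signed(jnp.float32) is None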
21 changes: 21 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -376,6 +376,27 @@ def kernel(o_ref):
        kernel(), jnp.full([256], 5.0, dtype=jnp.float32)
    )

  def test_cond(self):

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    )
    def kernel(x_ref, o_ref):
      acc = x_ref[...].sum()
      jax.lax.cond(
          acc % 2 == 0,
          lambda: pl.debug_print("acc * 2: {}", acc * 2),
          lambda: pl.debug_print("acc: {}", acc),
      )
      o_ref[...] = jnp.broadcast_to(acc, o_ref.shape)

    x = jnp.arange(256)
    with jtu.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertIn("acc * 2:", output())

  @parameterized.parameters(jnp.float16, jnp.float32)
  def test_wgmma(self, dtype):
    # TensorCores can only fuse transposes of 16-bit values, and RHS
