diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 902624192dcf..cc19d96aa194 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -34,6 +34,7 @@ from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect +from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils @@ -677,6 +678,22 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): ) +@register_lowering_rule(lax.select_n_p) +def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): + if len(cases) != 2: + raise NotImplementedError( + "Mosaic GPU lowering only supports select_n with 2 cases, got" + f" {len(cases)}" + ) + pred_aval, *cases_avals = ctx.avals_in + [out_aval] = ctx.avals_out + pred = _ensure_fa(pred, pred_aval.dtype) + cases = _bcast(*cases, *cases_avals, out_aval) + # ``select`` expects the first case to be the true branch, but ``select_n`` + # orders the cases in reverse. + return pred.select(*reversed(cases)) + + @register_lowering_rule(lax.broadcast_in_dim_p) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, @@ -712,6 +729,16 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), lax.div_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x / y), + lax.rem_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x % y), + lax.and_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x & y), + lax.or_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x | y), + lax.xor_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x ^ y), + lax.gt_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x > y), + lax.lt_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x < y), + lax.ge_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x >= y), + lax.le_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x <= y), + lax.eq_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x == y), + lax.ne_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x != y), }) @@ -909,13 +936,41 @@ def _scan_lowering_rule( return for_out +@register_lowering_rule(lax.cond_p) +def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): + index_aval, *_arg_avals = ctx.avals_in + switch_op = scf_dialect.IndexSwitchOp( + map(mgpu_utils.dtype_to_ir_type, ctx.avals_out), + _as_index(_ensure_ir_value(index, index_aval.dtype)), + ir.DenseI64ArrayAttr.get(range(len(branches) - 1)), + num_caseRegions=len(branches) - 1, + ) + + # ``RegionSequence`` in MLIR does not support slicing, so the + # auto-generated Python bindings for ``caseRegions`` fail at runtime! + # We convert it to a list to work around that. + regions = list(switch_op.regions) + # Move the default region to the back. + regions = regions[1:] + regions[:1] + for branch, region in zip(branches, regions): + with ir.InsertionPoint(region.blocks.append()): + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args + ) + scf_dialect.yield_([ + _ensure_ir_value(out, aval.dtype) + for out, aval in zip(outs, ctx.avals_out) + ]) + return list(switch_op.results) + + def _bcast( x: ir.Value, y: ir.Value, x_aval: jax_core.ShapedArray, y_aval: jax_core.ShapedArray, out_aval: jax_core.ShapedArray, -) -> ir.Value: +) -> tuple[mgpu.FragmentedArray, mgpu.FragmentedArray]: if not isinstance(x, mgpu.FragmentedArray): x_dtype = x_aval.dtype if x_aval.weak_type: @@ -935,6 +990,7 @@ def _bcast( def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: if isinstance(x, mgpu.FragmentedArray): + assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype) return x elif isinstance(x, (np.number, np.ndarray, int, float)): return mgpu.FragmentedArray.splat( @@ -944,12 +1000,14 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: ) elif isinstance(x, ir.Value): if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)): + assert x.type == mgpu_utils.dtype_to_ir_type(dtype) return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype)) raise NotImplementedError(f"Unsupported type: {type(x)}") def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: if isinstance(x, ir.Value): + assert x.type == mgpu_utils.dtype_to_ir_type(dtype) return x elif isinstance(x, (np.number, np.ndarray, int, float)): return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ea45a7dcb7b9..ae6c40b9416d 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -784,13 +784,13 @@ def broadcast_minor(self, n): _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed ) - def select(self, x, y): + def select(self, on_true, on_false): if ( not ir.IntegerType.isinstance(self.mlir_dtype) or ir.IntegerType(self.mlir_dtype).width != 1 ): raise NotImplementedError - return self._pointwise(arith.select, x, y) + return self._pointwise(arith.select, on_true, on_false) def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): """Call a function for each value and index.""" diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 4701a72490c8..b5c22734b0d4 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -951,6 +951,8 @@ def dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type: def is_signed(dtype: jax.typing.DTypeLike) -> bool | None: - if jnp.issubdtype(dtype, jnp.integer): + if jnp.issubdtype(dtype, jnp.bool_): + return False + elif jnp.issubdtype(dtype, jnp.integer): return jnp.issubdtype(dtype, jnp.signedinteger) return None diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 3018e0383188..267ddd0c97d5 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -376,6 +376,27 @@ def kernel(o_ref): kernel(), jnp.full([256], 5.0, dtype=jnp.float32) ) + def test_cond(self): + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + acc = x_ref[...].sum() + jax.lax.cond( + acc % 2 == 0, + lambda: pl.debug_print("acc * 2: {}", acc * 2), + lambda: pl.debug_print("acc: {}", acc), + ) + o_ref[...] = jnp.broadcast_to(acc, o_ref.shape) + + x = jnp.arange(256) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn("acc * 2:", output()) + @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): # TensorCores can only fuse transposes of 16-bit values, and RHS