diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 0211b549a49e..3e5c7403ed6d 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -55,6 +55,7 @@
 zip, unsafe_zip = util.safe_zip, zip
 
 partial = functools.partial
+SMEM = gpu_core.SMEM
 
 _smem_estimators = {}
 
@@ -330,6 +331,45 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
         start_indices, [grid_mapping.num_inputs]
     )
 
+    smem_scratch_it = iter(scratch_buffers_smem)
+    scratch_buffers_template = []
+    should_discharge = []
+    accs = []
+    for aval in scratch_avals:
+      match aval:
+        case gpu_core.WGMMAAbstractAccumulatorRef():
+          scratch_buffers_template.append(None)
+          should_discharge.append(True)
+          accs.append(
+              mgpu.WGMMAAccumulator.zero(
+                  *aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype)
+              )
+          )
+        case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM:
+          scratch_buffers_template.append(next(smem_scratch_it))
+          should_discharge.append(False)
+        case _:
+          raise NotImplementedError(
+              f"Unsupported scratch operand type: {aval}"
+          )
+    assert not jaxpr.outvars
+    if any(should_discharge):
+      # User-visible WGMMA APIs use the effectful accumulator references, but we
+      # can't lower them directly to Mosaic GPU, which uses pure dataflow for
+      # accumulators. So we have to discharge the effects first.
+      assert not jaxpr.constvars
+      should_discharge = (
+          [False] * len(grid_mapping.block_mappings)
+          + should_discharge
+          + [False] * len(extra_barriers)
+      )
+      with grid_mapping.trace_env():
+        lowered_jaxpr, _ = discharge.discharge_state(
+            jaxpr, (), should_discharge=should_discharge
+        )
+    else:
+      lowered_jaxpr = jaxpr
+
     # Precompute the total number of bytes transferred from GMEM to SMEM,
     # so that we can do a single arrive instruction for all of the inputs.
     in_transfer_bytes = 0
@@ -414,8 +454,8 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
         for idx in range(grid_mapping.num_inputs):
           fetch(idx, _as_index(slot), _as_index(slot))
 
-    @mgpu.fori(_as_index(num_steps), ())
-    def _(step, _):
+    @mgpu.fori(_as_index(num_steps), accs)
+    def _(step, accs):
       slot = arith_dialect.remui(step, _as_index(num_stages))
       if grid_mapping.num_inputs:
         # Only wait if async copies were issued.
@@ -427,9 +467,18 @@ def _(step, _):
           else buffers_gmem[idx]
          for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem))
       ]
-      args.extend(scratch_buffers_smem)
+      accs_it = iter(accs)
+      scratch_buffers = [
+          b if b is not None else next(accs_it)
+          for b in scratch_buffers_template
+      ]
+      args.extend(scratch_buffers)
+      # TODO(apaszke): This assumes barriers come after buffers in scratch args,
+      # but that's not necessarily true.
       args.extend(extra_barriers)
-      _ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args)
+      new_accs = lower_jaxpr_to_mosaic_gpu(
+          module_ctx, launch_ctx, lowered_jaxpr, args
+      )
       mgpu.commit_shared()
 
       with mgpu.single_thread():
@@ -445,20 +494,22 @@ def _(step, _):
             fetch(idx, next_step, slot)
           barriers[slot].arrive_expect_tx(in_transfer_bytes)
 
-      return ()
+      return list(new_accs)
 
   launch_ctx.await_async_copy(0)
 
  scratch_avals = [
      var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops]
  ]
+  local_spaces = (gpu_core.SMEM, gpu_core.REGS)
  if not all(
      isinstance(aval, pallas_core.AbstractMemoryRef)
-      and aval.memory_space is gpu_core.SMEM
+      and aval.memory_space in local_spaces
      for aval in scratch_avals
  ):
    raise TypeError(
-        f"All scratch operands must be in SMEM, but got: {scratch_avals}"
+        "All scratch operands must be SMEM references or accumulators (ACC),"
+        f" but got: {scratch_avals}"
    )
  extra_barriers = [
      mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
@@ -469,6 +520,7 @@ def _(step, _):
      jax.ShapeDtypeStruct(aval.shape, aval.dtype)
      for aval in scratch_avals
      if not isinstance(aval.dtype, gpu_core.BarrierType)
+      and aval.memory_space == gpu_core.SMEM
  ]
  smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
  if smem_scratch_bytes is None:
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 5295507ff2fa..dcec631e389b 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -25,7 +25,6 @@
 from jax._src.pallas.mosaic_gpu import core as gpu_core
 from jax._src.pallas.mosaic_gpu import lowering
 import jax.experimental.mosaic.gpu as mgpu
-import jax.numpy as jnp
 
 async_copy_p = jax_core.Primitive("async_copy")
 async_copy_p.multiple_results = True
@@ -120,7 +119,7 @@ class _WGMMAPipelineEffect(effects.Effect):
 wgmma_ref_p = jax_core.Primitive("wgmma_ref")
 wgmma_ref_p.multiple_results = True
 
-def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
+def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
   """Asynchronous warp group matmul.
 
   The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
@@ -137,12 +136,6 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
   if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
     raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")
 
-  rhs_transpose = (
-      (jnp.dtype(b.dtype).itemsize == 2)
-      if rhs_transpose is None
-      else rhs_transpose
-  )
-
   ma, ka, tma, tka = a.shape
   kb, nb, tkb, tnb = b.shape
   mc, nc = acc.shape
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 5abf4e2e363e..7c2bb12b4607 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -437,6 +437,62 @@ def kernel(a_ref, b_ref):
     )(a)
     np.testing.assert_array_equal(b, np.ones_like(a))
 
+  def test_realistic_matmul(self):
+    dtype = jnp.float16
+    swizzle = 128
+    elems_128b = swizzle // jnp.dtype(dtype).itemsize
+    # TODO(apaszke): Make the grid and tile sizes larger
+    # grid_m, grid_k, grid_n = 132, 10, 4
+    # TODO(apaszke): Increasing grid_k causes the test to fail.
+    # It seems like our pipelining implementation has a number of races.
+    grid_m, grid_k, grid_n = 2, 1, 2
+    # tile_m = tile_n = 128
+    tile_m = tile_n = 64
+    tile_k = elems_128b
+    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
+    def kernel(a_ref, b_ref, o_ref, acc_ref):
+      plgpu.wgmma(acc_ref, a_ref, b_ref)
+      plgpu.wgmma_wait(0)  # TODO(apaszke): Delay the pipeline to avoid memory races
+      # TODO(apaszke): Only store in the last step. It doesn't work because we
+      # don't have partial discharge for control flow.
+      # is_last_step = pl.program_id(2) == grid_k - 1
+      # @pl.when(is_last_step)
+      # def _epilogue():
+      #   pl.debug_print("{}", acc_ref[...])
+      # TODO(apaszke): This is an untiled store! It's slow!!
+      o_ref[...] = acc_ref[...]
+
+    key1, key2 = jax.random.split(jax.random.key(42), 2)
+    a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
+    b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)
+
+    res = pl.pallas_call(
+        kernel,
+        in_specs=[
+            plgpu.GPUBlockSpec(
+                (tile_m, tile_k),
+                lambda m, n, k: (m, k),
+                transforms=plgpu.TilingTransform((64, elems_128b)),
+                swizzle=128,
+            ),
+            plgpu.GPUBlockSpec(
+                (tile_k, tile_n),
+                lambda m, n, k: (k, n),
+                transforms=plgpu.TilingTransform((elems_128b, elems_128b)),
+                swizzle=128,
+            ),
+        ],
+        out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)),
+        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
+        scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
+        grid=(grid_m, grid_n, grid_k),
+        compiler_params=plgpu.GPUCompilerParams(
+            dimension_semantics=["parallel", "parallel", "sequential"],
+            num_stages=2,
+        ),
+    )(a, b)
+    np.testing.assert_allclose(res, a @ b, rtol=1e-3)
+
 
 if __name__ == "__main__":
   absltest.main()
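
Note on the lowering change (not part of the patch): after discharge.discharge_state removes the accumulator's state effect, the pipeline loop threads the WGMMAAccumulator values as explicit carries of mgpu.fori and the loop callback returns the updated ones instead of mutating a reference. A plain-JAX analogy of that carry pattern, illustrative only and not the actual Mosaic GPU lowering:

    # Plain-JAX analogy only: the accumulator becomes a loop carry, the way
    # fori_loop threads `acc` below, rather than an in-place updated Ref.
    import jax
    import jax.numpy as jnp

    def step(_, acc):
      # Stand-ins for the SMEM tiles that the real pipeline stages in.
      a_tile = jnp.ones((64, 64), jnp.float16)
      b_tile = jnp.ones((64, 64), jnp.float16)
      return acc + jnp.dot(a_tile, b_tile, preferred_element_type=jnp.float32)

    acc = jax.lax.fori_loop(0, 4, step, jnp.zeros((64, 64), jnp.float32))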
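
Note on the primitives.py change (not part of the patch): wgmma no longer infers rhs_transpose from the element type. The removed fallback made rhs_transpose default to True for 16-bit b operands; the new default is False, so callers that relied on the implicit transpose now have to pass rhs_transpose=True explicitly. A minimal sketch of a kernel body written against the new API, reusing the names from test_realistic_matmul above (acc_ref is the plgpu.ACC scratch operand; the surrounding pallas_call harness is as in the test):

    def kernel(a_ref, b_ref, o_ref, acc_ref):
      # Before this change, a float16 b_ref silently implied rhs_transpose=True.
      # Now the transpose must be requested explicitly when it is wanted.
      plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=False)  # the new default, spelled out
      plgpu.wgmma_wait(0)  # wait for the async matmul before reading the accumulator
      o_ref[...] = acc_ref[...]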