Skip to content

Commit

Permalink
[Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to s…
Browse files Browse the repository at this point in the history
…ee what we're still missing

I annotated a number of issues in the test. To make the test run I also needed to add support
for the accumulator reference allocation and discharge in the main lowering part. Ideally,
we'd defer it all to run_scoped, but run_scoped can't allocate barriers...

PiperOrigin-RevId: 679076014
  • Loading branch information
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent 5788773 commit 832cd74
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 15 deletions.
66 changes: 59 additions & 7 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
zip, unsafe_zip = util.safe_zip, zip

partial = functools.partial
SMEM = gpu_core.SMEM

_smem_estimators = {}

Expand Down Expand Up @@ -330,6 +331,45 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
start_indices, [grid_mapping.num_inputs]
)

smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
should_discharge = []
accs = []
for aval in scratch_avals:
match aval:
case gpu_core.WGMMAAbstractAccumulatorRef():
scratch_buffers_template.append(None)
should_discharge.append(True)
accs.append(
mgpu.WGMMAAccumulator.zero(
*aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype)
)
)
case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM:
scratch_buffers_template.append(next(smem_scratch_it))
should_discharge.append(False)
case _:
raise NotImplementedError(
f"Unsupported scratch operand type: {aval}"
)
assert not jaxpr.outvars
if any(should_discharge):
# User-visible WGMMA APIs use the effectful accumulator references, but we
# can't lower that directly to Mosaic GPU that uses pure dataflow for
# accumulators. So we have to discharge the effects first.
assert not jaxpr.constvars
should_discharge = (
[False] * len(grid_mapping.block_mappings)
+ should_discharge
+ [False] * len(extra_barriers)
)
with grid_mapping.trace_env():
lowered_jaxpr, _ = discharge.discharge_state(
jaxpr, (), should_discharge=should_discharge
)
else:
lowered_jaxpr = jaxpr

# Precompute the total number of bytes transferred from GMEM to SMEM,
# so that we can do a single arrive instruction for all of the inputs.
in_transfer_bytes = 0
Expand Down Expand Up @@ -414,8 +454,8 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))

@mgpu.fori(_as_index(num_steps), ())
def _(step, _):
@mgpu.fori(_as_index(num_steps), accs)
def _(step, accs):
slot = arith_dialect.remui(step, _as_index(num_stages))
if grid_mapping.num_inputs:
# Only wait if async copies were issued.
Expand All @@ -427,9 +467,18 @@ def _(step, _):
else buffers_gmem[idx]
for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem))
]
args.extend(scratch_buffers_smem)
accs_it = iter(accs)
scratch_buffers = [
b if b is not None else next(accs_it)
for b in scratch_buffers_template
]
args.extend(scratch_buffers)
# TODO(apaszke): This assumes barriers come after buffers in scratch args,
# but that's not necessarily true.
args.extend(extra_barriers)
_ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args)
new_accs = lower_jaxpr_to_mosaic_gpu(
module_ctx, launch_ctx, lowered_jaxpr, args
)
mgpu.commit_shared()

with mgpu.single_thread():
Expand All @@ -445,20 +494,22 @@ def _(step, _):
fetch(idx, next_step, slot)
barriers[slot].arrive_expect_tx(in_transfer_bytes)

return ()
return list(new_accs)

launch_ctx.await_async_copy(0)

scratch_avals = [
var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops]
]
local_spaces = (gpu_core.SMEM, gpu_core.REGS)
if not all(
isinstance(aval, pallas_core.AbstractMemoryRef)
and aval.memory_space is gpu_core.SMEM
and aval.memory_space in local_spaces
for aval in scratch_avals
):
raise TypeError(
f"All scratch operands must be in SMEM, but got: {scratch_avals}"
"All scratch operands must be SMEM references or accumulators (ACC),"
f" but got: {scratch_avals}"
)
extra_barriers = [
mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
Expand All @@ -469,6 +520,7 @@ def _(step, _):
jax.ShapeDtypeStruct(aval.shape, aval.dtype)
for aval in scratch_avals
if not isinstance(aval.dtype, gpu_core.BarrierType)
and aval.memory_space == gpu_core.SMEM
]
smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
if smem_scratch_bytes is None:
Expand Down
9 changes: 1 addition & 8 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import lowering
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp

async_copy_p = jax_core.Primitive("async_copy")
async_copy_p.multiple_results = True
Expand Down Expand Up @@ -120,7 +119,7 @@ class _WGMMAPipelineEffect(effects.Effect):
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
"""Asynchronous warp group matmul.
The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
Expand All @@ -137,12 +136,6 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

rhs_transpose = (
(jnp.dtype(b.dtype).itemsize == 2)
if rhs_transpose is None
else rhs_transpose
)

ma, ka, tma, tka = a.shape
kb, nb, tkb, tnb = b.shape
mc, nc = acc.shape
Expand Down
56 changes: 56 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,62 @@ def kernel(a_ref, b_ref):
)(a)
np.testing.assert_array_equal(b, np.ones_like(a))

def test_realistic_matmul(self):
dtype = jnp.float16
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
# TODO(apaszke): Make the grid and tile sizes larger
# grid_m, grid_k, grid_n = 132, 10, 4
# TODO(apaszke): Increasing grid_k causes th test to fail.
# It seems like our pipelining implementation has a number of races.
grid_m, grid_k, grid_n = 2, 1, 2
# tile_m = tile_n = 128
tile_m = tile_n = 64
tile_k = elems_128b
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
def kernel(a_ref, b_ref, o_ref, acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref)
plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races
# TODO(apaszke): Only store in the last step. It doesn't work because we
# don't have partial discharge for control flow.
# is_last_step = pl.program_id(2) == grid_k - 1
# @pl.when(is_last_step)
# def _epilogue():
# pl.debug_print("{}", acc_ref[...])
# TODO(apaszke): This is an untiled store! It's slow!!
o_ref[...] = acc_ref[...]

key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)

res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(tile_m, tile_k),
lambda m, n, k: (m, k),
transforms=plgpu.TilingTransform((64, elems_128b)),
swizzle=128,
),
plgpu.GPUBlockSpec(
(tile_k, tile_n),
lambda m, n, k: (k, n),
transforms=plgpu.TilingTransform((elems_128b, elems_128b)),
swizzle=128,
),
],
out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)),
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
grid=(grid_m, grid_n, grid_k),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "parallel", "sequential"],
num_stages=2,
),
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)


if __name__ == "__main__":
absltest.main()

0 comments on commit 832cd74

Please sign in to comment.