[Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to see what we're still missing #23937

Merged
1 commit merged on Sep 26, 2024
66 changes: 59 additions & 7 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -55,6 +55,7 @@
zip, unsafe_zip = util.safe_zip, zip

partial = functools.partial
SMEM = gpu_core.SMEM

_smem_estimators = {}

@@ -330,6 +331,45 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
start_indices, [grid_mapping.num_inputs]
)

smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
should_discharge = []
accs = []
for aval in scratch_avals:
match aval:
case gpu_core.WGMMAAbstractAccumulatorRef():
scratch_buffers_template.append(None)
should_discharge.append(True)
accs.append(
mgpu.WGMMAAccumulator.zero(
*aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype)
)
)
case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM:
scratch_buffers_template.append(next(smem_scratch_it))
should_discharge.append(False)
case _:
raise NotImplementedError(
f"Unsupported scratch operand type: {aval}"
)
assert not jaxpr.outvars
if any(should_discharge):
# User-visible WGMMA APIs use effectful accumulator references, but we
# can't lower those directly to Mosaic GPU, which uses pure dataflow for
# accumulators. So we have to discharge the effects first.
assert not jaxpr.constvars
should_discharge = (
[False] * len(grid_mapping.block_mappings)
+ should_discharge
+ [False] * len(extra_barriers)
)
with grid_mapping.trace_env():
lowered_jaxpr, _ = discharge.discharge_state(
jaxpr, (), should_discharge=should_discharge
)
else:
lowered_jaxpr = jaxpr

# Precompute the total number of bytes transferred from GMEM to SMEM,
# so that we can do a single arrive instruction for all of the inputs.
in_transfer_bytes = 0
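The discharge step above is the key idea of this hunk: the user-facing kernel mutates a WGMMA accumulator reference, while Mosaic GPU wants the accumulator as a plain value flowing through the program. Here is a minimal sketch of what the rewrite amounts to, using a hypothetical `pure_step` helper; this is an illustration of the idea, not the lowering code itself.

```python
import jax.numpy as jnp

# Before discharge (conceptually): acc_ref[...] += a @ b mutates a reference.
# After discharge: the accumulator is an explicit operand and result.
def pure_step(acc, a, b):
    return acc + jnp.dot(a, b, preferred_element_type=jnp.float32)

a = jnp.ones((64, 64), jnp.float16)
b = jnp.ones((64, 64), jnp.float16)
acc = jnp.zeros((64, 64), jnp.float32)
acc = pure_step(acc, a, b)  # same result as the in-place update, but pure
```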
@@ -414,8 +454,8 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))

- @mgpu.fori(_as_index(num_steps), ())
- def _(step, _):
+ @mgpu.fori(_as_index(num_steps), accs)
+ def _(step, accs):
slot = arith_dialect.remui(step, _as_index(num_stages))
if grid_mapping.num_inputs:
# Only wait if async copies were issued.
@@ -427,9 +467,18 @@ def _(step, _):
else buffers_gmem[idx]
for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem))
]
- args.extend(scratch_buffers_smem)
+ accs_it = iter(accs)
+ scratch_buffers = [
+ b if b is not None else next(accs_it)
+ for b in scratch_buffers_template
+ ]
+ args.extend(scratch_buffers)
# TODO(apaszke): This assumes barriers come after buffers in scratch args,
# but that's not necessarily true.
args.extend(extra_barriers)
- _ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args)
+ new_accs = lower_jaxpr_to_mosaic_gpu(
+ module_ctx, launch_ctx, lowered_jaxpr, args
+ )
mgpu.commit_shared()

with mgpu.single_thread():
@@ -445,20 +494,22 @@ def _(step, _):
fetch(idx, next_step, slot)
barriers[slot].arrive_expect_tx(in_transfer_bytes)

- return ()
+ return list(new_accs)

launch_ctx.await_async_copy(0)
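With the effects discharged, the accumulators become the carry of the pipeline loop: `accs` go into `mgpu.fori`, the lowered jaxpr produces `new_accs`, and those feed the next step. A rough sketch of the same carry pattern in plain JAX, with hypothetical `a_tiles`/`b_tiles` inputs (the real loop iterates pipeline stages over SMEM buffers, not array slices):

```python
import jax
import jax.numpy as jnp

def pipelined_matmul(a_tiles, b_tiles):
    # a_tiles: (num_steps, tile_m, tile_k); b_tiles: (num_steps, tile_k, tile_n).
    num_steps, tile_m, _ = a_tiles.shape
    tile_n = b_tiles.shape[-1]

    def step(i, acc):
        # The accumulator is the loop carry; each step adds one partial product.
        return acc + jnp.dot(a_tiles[i], b_tiles[i],
                             preferred_element_type=jnp.float32)

    return jax.lax.fori_loop(
        0, num_steps, step, jnp.zeros((tile_m, tile_n), jnp.float32)
    )
```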

scratch_avals = [
var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops]
]
local_spaces = (gpu_core.SMEM, gpu_core.REGS)
if not all(
isinstance(aval, pallas_core.AbstractMemoryRef)
- and aval.memory_space is gpu_core.SMEM
+ and aval.memory_space in local_spaces
for aval in scratch_avals
):
raise TypeError(
f"All scratch operands must be in SMEM, but got: {scratch_avals}"
"All scratch operands must be SMEM references or accumulators (ACC),"
f" but got: {scratch_avals}"
)
extra_barriers = [
mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
@@ -469,6 +520,7 @@ def _(step, _):
jax.ShapeDtypeStruct(aval.shape, aval.dtype)
for aval in scratch_avals
if not isinstance(aval.dtype, gpu_core.BarrierType)
and aval.memory_space == gpu_core.SMEM
]
smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
if smem_scratch_bytes is None:
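When `smem_scratch_bytes` is not given in the compiler params, the hidden branch presumably computes an estimate from the SMEM scratch shapes collected above. The basic arithmetic for sizing that scratch would look like the following sketch, using a hypothetical `naive_smem_scratch_bytes` helper (not the estimators registered in `_smem_estimators`, which may also account for swizzling and alignment):

```python
import math

import jax
import jax.numpy as jnp

def naive_smem_scratch_bytes(shapes):
    # shapes: jax.ShapeDtypeStruct entries for the SMEM scratch operands.
    # Naive upper bound: element count times element width, summed.
    return sum(math.prod(s.shape) * jnp.dtype(s.dtype).itemsize for s in shapes)

print(naive_smem_scratch_bytes([jax.ShapeDtypeStruct((64, 64), jnp.float32)]))
# 16384
```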
9 changes: 1 addition & 8 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -25,7 +25,6 @@
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import lowering
import jax.experimental.mosaic.gpu as mgpu
- import jax.numpy as jnp

async_copy_p = jax_core.Primitive("async_copy")
async_copy_p.multiple_results = True
@@ -120,7 +119,7 @@ class _WGMMAPipelineEffect(effects.Effect):
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

- def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
+ def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
"""Asynchronous warp group matmul.

The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
@@ -137,12 +136,6 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

- rhs_transpose = (
- (jnp.dtype(b.dtype).itemsize == 2)
- if rhs_transpose is None
- else rhs_transpose
- )

ma, ka, tma, tka = a.shape
kb, nb, tkb, tnb = b.shape
mc, nc = acc.shape
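The signature change makes `rhs_transpose` default to `False` instead of inferring it from the element width, which is why the `jnp` import above is removed. As a pure-JAX reference for the documented semantics, `acc[...] += a @ b`, here is a hypothetical `wgmma_reference` helper that ignores the tiled SMEM layout and the asynchrony:

```python
import jax.numpy as jnp

def wgmma_reference(acc, a, b, *, rhs_transpose: bool = False):
    # Logical result of the instruction: acc + a @ b, accumulated in float32,
    # with b stored transposed when rhs_transpose is True.
    rhs = b.T if rhs_transpose else b
    return acc + jnp.dot(a, rhs, preferred_element_type=jnp.float32)
```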
56 changes: 56 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -437,6 +437,62 @@ def kernel(a_ref, b_ref):
)(a)
np.testing.assert_array_equal(b, np.ones_like(a))

def test_realistic_matmul(self):
dtype = jnp.float16
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
# TODO(apaszke): Make the grid and tile sizes larger
# grid_m, grid_k, grid_n = 132, 10, 4
# TODO(apaszke): Increasing grid_k causes the test to fail.
# It seems like our pipelining implementation has a number of races.
grid_m, grid_k, grid_n = 2, 1, 2
# tile_m = tile_n = 128
tile_m = tile_n = 64
tile_k = elems_128b
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
def kernel(a_ref, b_ref, o_ref, acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref)
plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races
# TODO(apaszke): Only store in the last step. It doesn't work because we
# don't have partial discharge for control flow.
# is_last_step = pl.program_id(2) == grid_k - 1
# @pl.when(is_last_step)
# def _epilogue():
# pl.debug_print("{}", acc_ref[...])
# TODO(apaszke): This is an untiled store! It's slow!!
o_ref[...] = acc_ref[...]

key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)

res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(tile_m, tile_k),
lambda m, n, k: (m, k),
transforms=plgpu.TilingTransform((64, elems_128b)),
swizzle=128,
),
plgpu.GPUBlockSpec(
(tile_k, tile_n),
lambda m, n, k: (k, n),
transforms=plgpu.TilingTransform((elems_128b, elems_128b)),
swizzle=128,
),
],
out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)),
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
grid=(grid_m, grid_n, grid_k),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "parallel", "sequential"],
num_stages=2,
),
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
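For reference, the grid and BlockSpec index maps above implement a standard blocked matmul: program (m, n, k) reads the (m, k) tile of `a` and the (k, n) tile of `b`, and the sequential k dimension accumulates into the ACC scratch. A plain-JAX sketch of that blocking, using a hypothetical `blocked_matmul_reference` helper with no Pallas involved, which the kernel is expected to match:

```python
import jax.numpy as jnp

def blocked_matmul_reference(a, b, *, tile_m=64, tile_n=64, tile_k=64):
    m, k = a.shape
    _, n = b.shape
    out = jnp.zeros((m, n), jnp.float32)
    for i in range(m // tile_m):            # grid_m (parallel)
        for j in range(n // tile_n):        # grid_n (parallel)
            acc = jnp.zeros((tile_m, tile_n), jnp.float32)
            for s in range(k // tile_k):    # grid_k (sequential accumulation)
                a_tile = a[i * tile_m:(i + 1) * tile_m, s * tile_k:(s + 1) * tile_k]
                b_tile = b[s * tile_k:(s + 1) * tile_k, j * tile_n:(j + 1) * tile_n]
                acc += jnp.dot(a_tile, b_tile, preferred_element_type=jnp.float32)
            out = out.at[i * tile_m:(i + 1) * tile_m,
                         j * tile_n:(j + 1) * tile_n].set(acc)
    return out
```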


if __name__ == "__main__":
absltest.main()