[Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to see what we're still missing

I annotated a number of issues in the test. To make the test run, I also needed to add support
for accumulator reference allocation and discharge in the main lowering code. Ideally,
we'd defer it all to run_scoped, but run_scoped can't allocate barriers...
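
For context, a minimal plain-JAX sketch (not part of this change; all names are illustrative) of the
pattern that discharging an accumulator reference produces: instead of mutating hidden state inside
the pipeline loop, the accumulator is threaded through the loop as an explicit carry, which is what
the new mgpu.fori(..., accs) call in the lowering below does.

import jax
import jax.numpy as jnp

def blocked_matmul(a_chunks, b_chunks):
  # a_chunks: (num_steps, m, k), b_chunks: (num_steps, k, n).
  num_steps, m, _ = a_chunks.shape
  n = b_chunks.shape[2]

  def step(i, acc):
    # The accumulator is a loop-carried value, not a mutable reference.
    return acc + a_chunks[i] @ b_chunks[i]

  acc0 = jnp.zeros((m, n), jnp.float32)
  return jax.lax.fori_loop(0, num_steps, step, acc0)

In the lowering change below, discharge_state performs the analogous rewrite on the kernel jaxpr,
and the loop body returns the updated accumulators (return list(new_accs)).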

PiperOrigin-RevId: 679076014
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent b6d668e commit c37e09f
Showing 5 changed files with 129 additions and 46 deletions.
2 changes: 2 additions & 0 deletions jax/_src/pallas/mosaic_gpu/__init__.py
@@ -18,6 +18,8 @@
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem
14 changes: 5 additions & 9 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -132,10 +132,8 @@ class GPUBlockMapping(pallas_core.BlockMapping):

@dataclasses.dataclass
class GPUBlockSpec(pallas_core.BlockSpec):
# TODO(justinfu): Replace tiling with a list of transforms.
tiling: tuple[int, ...] | None = None
transpose_permutation: tuple[int, ...] | None = None
swizzle: int | None = None
transforms: MemoryRefTransform | tuple[MemoryRefTransform, ...] = ()
swizzle: int | None = None # TODO: apaszke - Swizzle is also a transform.

def to_block_mapping(
self,
@@ -155,11 +153,9 @@ def to_block_mapping(
grid=grid,
mapped_dims=mapped_dims,
)
transforms: tuple[pallas_core.MemoryRefTransform, ...] = ()
if self.tiling is not None:
transforms += (TilingTransform(self.tiling),)
if self.transpose_permutation is not None:
transforms += (TransposeTransform(self.transpose_permutation),)
transforms = self.transforms
if not isinstance(transforms, tuple):
transforms = (transforms,)
return GPUBlockMapping(
block_shape=bm.block_shape,
block_aval=bm.block_aval,
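For readers updating call sites, a small migration sketch (hypothetical spec, not taken from this
diff; assumes the plgpu alias for jax.experimental.pallas.mosaic_gpu used in the tests below) showing
how the removed tiling= and transpose_permutation= arguments map onto the new transforms= tuple:

from jax.experimental.pallas import mosaic_gpu as plgpu

# Before: plgpu.GPUBlockSpec((128, 128), lambda *i: i, tiling=(64, 64),
#                            transpose_permutation=(1, 0, 2, 3), swizzle=128)
# After: the same spec spells the transforms out explicitly, tiling first.
spec = plgpu.GPUBlockSpec(
    (128, 128),
    lambda *i: i,
    transforms=(
        plgpu.TilingTransform((64, 64)),
        plgpu.TransposeTransform((1, 0, 2, 3)),
    ),
    swizzle=128,  # Per the TODO above, swizzle may itself become a transform later.
)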
81 changes: 55 additions & 26 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -55,6 +55,7 @@
zip, unsafe_zip = util.safe_zip, zip

partial = functools.partial
SMEM = gpu_core.SMEM

_smem_estimators = {}

@@ -330,6 +331,42 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
start_indices, [grid_mapping.num_inputs]
)

smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
should_discharge = []
accs = []
for aval in scratch_avals:
match aval:
case gpu_core.WGMMAAbstractAccumulatorRef():
scratch_buffers_template.append(None)
should_discharge.append(True)
accs.append(
mgpu.WGMMAAccumulator.zero(
*aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype)
)
)
case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM:
scratch_buffers_template.append(next(smem_scratch_it))
should_discharge.append(False)
case _:
raise NotImplementedError(
f"Unsupported scratch operand type: {aval}"
)
assert not jaxpr.outvars
if any(should_discharge):
assert not jaxpr.constvars
should_discharge = (
[False] * len(grid_mapping.block_mappings)
+ should_discharge
+ [False] * len(extra_barriers)
)
with grid_mapping.trace_env():
lowered_jaxpr, _ = discharge.discharge_state(
jaxpr, (), should_discharge=should_discharge
)
else:
lowered_jaxpr = jaxpr

# Precompute the total number of bytes transferred from GMEM to SMEM,
# so that we can do a single arrive instruction for all of the inputs.
in_transfer_bytes = 0
@@ -404,26 +441,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
"Multiple sequential axes are not supported in Mosaic GPU lowering."
)
[sequential_axis] = sequential_axes
if any(
b_gmem.shape[sequential_axis] % b_smem.shape[1 + sequential_axis]
for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem)
if b_smem
):
raise ValueError(
"Array dimensions along the sequential axis must be divisible by"
" the corresponding block dimensions."
)
num_steps, *rest = {
b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis]
for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem)
if b_smem
}
if rest:
raise ValueError(
"Array dimensions along the sequential axis must produce the same"
" number of steps when devided by the corresponding block"
" dimensions."
)
num_steps = grid_mapping.grid[sequential_axis]
else:
num_steps = 1

@@ -433,8 +451,8 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))

@mgpu.fori(_as_index(num_steps), ())
def _(step, _):
@mgpu.fori(_as_index(num_steps), accs)
def _(step, accs):
slot = arith_dialect.remui(step, _as_index(num_stages))
if grid_mapping.num_inputs:
# Only wait if async copies were issued.
@@ -446,9 +464,18 @@ def _(step, _):
else buffers_gmem[idx]
for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem))
]
args.extend(scratch_buffers_smem)
accs_it = iter(accs)
scratch_buffers = [
b if b is not None else next(accs_it)
for b in scratch_buffers_template
]
args.extend(scratch_buffers)
# TODO(apaszke): This assumes barriers come after buffers in scratch args,
# but that's not necessarily true.
args.extend(extra_barriers)
_ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args)
new_accs = lower_jaxpr_to_mosaic_gpu(
module_ctx, launch_ctx, lowered_jaxpr, args
)
mgpu.commit_shared()

with mgpu.single_thread():
@@ -464,16 +491,17 @@ def _(step, _):
fetch(idx, next_step, slot)
barriers[slot].arrive_expect_tx(in_transfer_bytes)

return ()
return list(new_accs)

launch_ctx.await_async_copy(0)

scratch_avals = [
var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops]
]
local_spaces = (gpu_core.SMEM, gpu_core.REGS)
if not all(
isinstance(aval, pallas_core.AbstractMemoryRef)
and aval.memory_space is gpu_core.SMEM
and aval.memory_space in local_spaces
for aval in scratch_avals
):
raise TypeError(
@@ -488,6 +516,7 @@ def _(step, _):
jax.ShapeDtypeStruct(aval.shape, aval.dtype)
for aval in scratch_avals
if not isinstance(aval.dtype, gpu_core.BarrierType)
and aval.memory_space == gpu_core.SMEM
]
smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
if smem_scratch_bytes is None:
8 changes: 1 addition & 7 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -120,7 +120,7 @@ class _WGMMAPipelineEffect(effects.Effect):
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
"""Asynchronous warp group matmul.
The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
@@ -137,12 +137,6 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

rhs_transpose = (
(jnp.dtype(b.dtype).itemsize == 2)
if rhs_transpose is None
else rhs_transpose
)

ma, ka, tma, tka = a.shape
kb, nb, tkb, tnb = b.shape
mc, nc = acc.shape
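Because rhs_transpose no longer defaults based on the element type, callers that relied on the old
16-bit inference now opt in explicitly and, as the updated test shows, attach a TransposeTransform to
the B operand's block spec. A hypothetical kernel fragment (the refs are assumed to come from a
surrounding pallas_call, as in the tests below; assumes the plgpu alias for
jax.experimental.pallas.mosaic_gpu):

from jax.experimental.pallas import mosaic_gpu as plgpu

def kernel(a_ref, b_ref, o_ref, acc_ref):
  # Request the transposed-RHS MMA explicitly; it is no longer inferred from the dtype.
  plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=True)
  plgpu.wgmma_wait(0)  # Wait for the async MMA to complete before reading the accumulator.
  o_ref[...] = acc_ref[...]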
70 changes: 66 additions & 4 deletions tests/pallas/mosaic_gpu_test.py
@@ -326,11 +326,15 @@ def kernel(o_ref):
)

def test_swizzled_blockspec_shapes(self):

@functools.partial(
pl.pallas_call,
in_specs=[
plgpu.GPUBlockSpec(
(128, 64), lambda *i: i, tiling=(64, 64), swizzle=128
(128, 64),
lambda *i: i,
transforms=plgpu.TilingTransform((64, 64)),
swizzle=128,
),
],
out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)),
@@ -390,20 +394,22 @@ def scope(acc_ref):
a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype)

rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
if rhs_transpose:
rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),)
res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(64, 128),
lambda i, j: (i, j),
tiling=(64, elems_128b),
transforms=plgpu.TilingTransform((64, elems_128b)),
swizzle=128,
),
plgpu.GPUBlockSpec(
(128, 128),
lambda *i: i,
transpose_permutation=(1, 0, 2, 3) if rhs_transpose else None,
tiling=(elems_128b, elems_128b),
transforms=rhs_transforms,
swizzle=128,
),
],
@@ -431,6 +437,62 @@ def kernel(a_ref, b_ref):
)(a)
np.testing.assert_array_equal(b, np.ones_like(a))

def test_realistic_matmul(self):
dtype = jnp.float16
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
# TODO(apaszke): Make the grid and tile sizes larger
# grid_m, grid_k, grid_n = 132, 10, 4
# TODO(apaszke): Increasing grid_k causes the test to fail.
# It seems like our pipelining implementation has a number of races.
grid_m, grid_k, grid_n = 2, 1, 2
# tile_m = tile_n = 128
tile_m = tile_n = 64
tile_k = elems_128b
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
def kernel(a_ref, b_ref, o_ref, acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref)
plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races
# TODO(apaszke): Only store in the last step. It doesn't work because we
# don't have partial discharge for control flow.
# is_last_step = pl.program_id(2) == grid_k - 1
# @pl.when(is_last_step)
# def _epilogue():
# pl.debug_print("{}", acc_ref[...])
# TODO(apaszke): This is an untiled store! It's slow!!
o_ref[...] = acc_ref[...]

key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)

res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(tile_m, tile_k),
lambda m, n, k: (m, k),
transforms=plgpu.TilingTransform((64, elems_128b)),
swizzle=128,
),
plgpu.GPUBlockSpec(
(tile_k, tile_n),
lambda m, n, k: (k, n),
transforms=plgpu.TilingTransform((elems_128b, elems_128b)),
swizzle=128,
),
],
out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)),
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
grid=(grid_m, grid_n, grid_k),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "parallel", "sequential"],
num_stages=2,
),
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)


if __name__ == "__main__":
absltest.main()
