From c37e09f6c6d1b46c1f96bcd1ec15e53d23ab8401 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Sep 2024 03:28:23 -0700 Subject: [PATCH] [Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to see what we're still missing I annotated a number of issues in the test. To make the test run I also needed to add support for the accumulator reference allocation and discharge in the main lowering part. Ideally, we'd defer it all to run_scoped, but run_scoped can't allocate barriers... PiperOrigin-RevId: 679076014 --- jax/_src/pallas/mosaic_gpu/__init__.py | 2 + jax/_src/pallas/mosaic_gpu/core.py | 14 ++-- jax/_src/pallas/mosaic_gpu/lowering.py | 81 ++++++++++++++++-------- jax/_src/pallas/mosaic_gpu/primitives.py | 8 +-- tests/pallas/mosaic_gpu_test.py | 70 ++++++++++++++++++-- 5 files changed, 129 insertions(+), 46 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index bbada82ace82..187a84478c65 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -18,6 +18,8 @@ from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace +from jax._src.pallas.mosaic_gpu.core import TilingTransform +from jax._src.pallas.mosaic_gpu.core import TransposeTransform from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index df633619e6ee..0b79121c34c1 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -132,10 +132,8 @@ class GPUBlockMapping(pallas_core.BlockMapping): @dataclasses.dataclass class GPUBlockSpec(pallas_core.BlockSpec): - # TODO(justinfu): Replace tiling a list of transforms. - tiling: tuple[int, ...] | None = None - transpose_permutation: tuple[int, ...] | None = None - swizzle: int | None = None + transforms: MemoryRefTransform | tuple[MemoryRefTransform, ...] = () + swizzle: int | None = None # TODO: apaszke - Swizzle is also a transform. def to_block_mapping( self, @@ -155,11 +153,9 @@ def to_block_mapping( grid=grid, mapped_dims=mapped_dims, ) - transforms: tuple[pallas_core.MemoryRefTransform, ...] 
= () - if self.tiling is not None: - transforms += (TilingTransform(self.tiling),) - if self.transpose_permutation is not None: - transforms += (TransposeTransform(self.transpose_permutation),) + transforms = self.transforms + if not isinstance(transforms, tuple): + transforms = (transforms,) return GPUBlockMapping( block_shape=bm.block_shape, block_aval=bm.block_aval, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0c6d61fa9c96..f7eb48ae9a02 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -55,6 +55,7 @@ zip, unsafe_zip = util.safe_zip, zip partial = functools.partial +SMEM = gpu_core.SMEM _smem_estimators = {} @@ -330,6 +331,42 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): start_indices, [grid_mapping.num_inputs] ) + smem_scratch_it = iter(scratch_buffers_smem) + scratch_buffers_template = [] + should_discharge = [] + accs = [] + for aval in scratch_avals: + match aval: + case gpu_core.WGMMAAbstractAccumulatorRef(): + scratch_buffers_template.append(None) + should_discharge.append(True) + accs.append( + mgpu.WGMMAAccumulator.zero( + *aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype) + ) + ) + case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM: + scratch_buffers_template.append(next(smem_scratch_it)) + should_discharge.append(False) + case _: + raise NotImplementedError( + f"Unsupported scratch operand type: {aval}" + ) + assert not jaxpr.outvars + if any(should_discharge): + assert not jaxpr.constvars + should_discharge = ( + [False] * len(grid_mapping.block_mappings) + + should_discharge + + [False] * len(extra_barriers) + ) + with grid_mapping.trace_env(): + lowered_jaxpr, _ = discharge.discharge_state( + jaxpr, (), should_discharge=should_discharge + ) + else: + lowered_jaxpr = jaxpr + # Precompute the total number of bytes transferred from GMEM to SMEM, # so that we can do a single arrive instruction for all of the inputs. in_transfer_bytes = 0 @@ -404,26 +441,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: "Multiple sequential axes are not supported in Mosaic GPU lowering." ) [sequential_axis] = sequential_axes - if any( - b_gmem.shape[sequential_axis] % b_smem.shape[1 + sequential_axis] - for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) - if b_smem - ): - raise ValueError( - "Array dimensions along the sequential axis must be divisible by" - " the corresponding block dimensions." - ) - num_steps, *rest = { - b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis] - for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) - if b_smem - } - if rest: - raise ValueError( - "Array dimensions along the sequential axis must produce the same" - " number of steps when devided by the corresponding block" - " dimensions." - ) + num_steps = grid_mapping.grid[sequential_axis] else: num_steps = 1 @@ -433,8 +451,8 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: for idx in range(grid_mapping.num_inputs): fetch(idx, _as_index(slot), _as_index(slot)) - @mgpu.fori(_as_index(num_steps), ()) - def _(step, _): + @mgpu.fori(_as_index(num_steps), accs) + def _(step, accs): slot = arith_dialect.remui(step, _as_index(num_stages)) if grid_mapping.num_inputs: # Only wait if async copies were issued. 
@@ -446,9 +464,18 @@ def _(step, _): else buffers_gmem[idx] for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem)) ] - args.extend(scratch_buffers_smem) + accs_it = iter(accs) + scratch_buffers = [ + b if b is not None else next(accs_it) + for b in scratch_buffers_template + ] + args.extend(scratch_buffers) + # TODO(apaszke): This assumes barriers come after buffers in scratch args, + # but that's not necessarily true. args.extend(extra_barriers) - _ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args) + new_accs = lower_jaxpr_to_mosaic_gpu( + module_ctx, launch_ctx, lowered_jaxpr, args + ) mgpu.commit_shared() with mgpu.single_thread(): @@ -464,16 +491,17 @@ def _(step, _): fetch(idx, next_step, slot) barriers[slot].arrive_expect_tx(in_transfer_bytes) - return () + return list(new_accs) launch_ctx.await_async_copy(0) scratch_avals = [ var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops] ] + local_spaces = (gpu_core.SMEM, gpu_core.REGS) if not all( isinstance(aval, pallas_core.AbstractMemoryRef) - and aval.memory_space is gpu_core.SMEM + and aval.memory_space in local_spaces for aval in scratch_avals ): raise TypeError( @@ -488,6 +516,7 @@ def _(step, _): jax.ShapeDtypeStruct(aval.shape, aval.dtype) for aval in scratch_avals if not isinstance(aval.dtype, gpu_core.BarrierType) + and aval.memory_space == gpu_core.SMEM ] smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") if smem_scratch_bytes is None: diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 5295507ff2fa..0d1215ddd9f2 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -120,7 +120,7 @@ class _WGMMAPipelineEffect(effects.Effect): wgmma_ref_p = jax_core.Primitive("wgmma_ref") wgmma_ref_p.multiple_results = True -def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128): +def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128): """Asynchronous warp group matmul. The sm90 wgmma instruction, essentially acc[...] += a @ b. 
Requires
@@ -137,12 +137,6 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
   if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
     raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")
 
-  rhs_transpose = (
-      (jnp.dtype(b.dtype).itemsize == 2)
-      if rhs_transpose is None
-      else rhs_transpose
-  )
-
   ma, ka, tma, tka = a.shape
   kb, nb, tkb, tnb = b.shape
   mc, nc = acc.shape
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 431dba133ce3..7c2bb12b4607 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -326,11 +326,15 @@ def kernel(o_ref):
     )
 
   def test_swizzled_blockspec_shapes(self):
+
     @functools.partial(
         pl.pallas_call,
         in_specs=[
             plgpu.GPUBlockSpec(
-                (128, 64), lambda *i: i, tiling=(64, 64), swizzle=128
+                (128, 64),
+                lambda *i: i,
+                transforms=plgpu.TilingTransform((64, 64)),
+                swizzle=128,
             ),
         ],
         out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)),
@@ -390,20 +394,22 @@ def scope(acc_ref):
     a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
     b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype)
 
+    rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
+    if rhs_transpose:
+      rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),)
     res = pl.pallas_call(
         kernel,
         in_specs=[
             plgpu.GPUBlockSpec(
                 (64, 128),
                 lambda i, j: (i, j),
-                tiling=(64, elems_128b),
+                transforms=plgpu.TilingTransform((64, elems_128b)),
                 swizzle=128,
             ),
             plgpu.GPUBlockSpec(
                 (128, 128),
                 lambda *i: i,
-                transpose_permutation=(1, 0, 2, 3) if rhs_transpose else None,
-                tiling=(elems_128b, elems_128b),
+                transforms=rhs_transforms,
                 swizzle=128,
             ),
         ],
@@ -431,6 +437,62 @@ def kernel(a_ref, b_ref):
     )(a)
     np.testing.assert_array_equal(b, np.ones_like(a))
 
+  def test_realistic_matmul(self):
+    dtype = jnp.float16
+    swizzle = 128
+    elems_128b = swizzle // jnp.dtype(dtype).itemsize
+    # TODO(apaszke): Make the grid and tile sizes larger
+    # grid_m, grid_k, grid_n = 132, 10, 4
+    # TODO(apaszke): Increasing grid_k causes the test to fail.
+    # It seems like our pipelining implementation has a number of races.
+    grid_m, grid_k, grid_n = 2, 1, 2
+    # tile_m = tile_n = 128
+    tile_m = tile_n = 64
+    tile_k = elems_128b
+    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
+
+    def kernel(a_ref, b_ref, o_ref, acc_ref):
+      plgpu.wgmma(acc_ref, a_ref, b_ref)
+      plgpu.wgmma_wait(0)  # TODO(apaszke): Delay the pipeline to avoid memory races
+      # TODO(apaszke): Only store in the last step. It doesn't work because we
+      # don't have partial discharge for control flow.
+      # is_last_step = pl.program_id(2) == grid_k - 1
+      # @pl.when(is_last_step)
+      # def _epilogue():
+      #   pl.debug_print("{}", acc_ref[...])
+      # TODO(apaszke): This is an untiled store! It's slow!!
+      o_ref[...] = acc_ref[...]
+ + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) + b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + transforms=plgpu.TilingTransform((64, elems_128b)), + swizzle=128, + ), + plgpu.GPUBlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + transforms=plgpu.TilingTransform((elems_128b, elems_128b)), + swizzle=128, + ), + ], + out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], + grid=(grid_m, grid_n, grid_k), + compiler_params=plgpu.GPUCompilerParams( + dimension_semantics=["parallel", "parallel", "sequential"], + num_stages=2, + ), + )(a, b) + np.testing.assert_allclose(res, a @ b, rtol=1e-3) + if __name__ == "__main__": absltest.main()
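Note for reviewers (not part of the patch): the user-visible API change above replaces GPUBlockSpec's tiling=/transpose_permutation= keyword arguments with a single transforms= argument that takes one memory-ref transform or a tuple of them; swizzle stays a separate keyword for now (see the TODO in core.py). Below is a minimal migration sketch, not a definitive usage guide. It assumes the pl/plgpu aliases used by the tests above, that TilingTransform and TransposeTransform are re-exported on the public plgpu module, and the block shape, index map, and swizzle values are illustrative only.

    # Hypothetical migration example; mirrors the updated tests, adds no new functionality.
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed public re-export

    swizzle = 128
    elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize  # elements per 128-byte swizzle

    # Before this patch: tiling and transpose were separate GPUBlockSpec kwargs.
    #   plgpu.GPUBlockSpec((128, 128), lambda *i: i,
    #                      tiling=(elems_128b, elems_128b),
    #                      transpose_permutation=(1, 0, 2, 3), swizzle=128)

    # After this patch: both are expressed as a tuple of memory-ref transforms.
    rhs_spec = plgpu.GPUBlockSpec(
        (128, 128),
        lambda *i: i,
        transforms=(
            plgpu.TilingTransform((elems_128b, elems_128b)),
            plgpu.TransposeTransform((1, 0, 2, 3)),
        ),
        swizzle=128,  # still separate; the TODO in core.py suggests folding it into transforms later
    )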