From 076287fb5cf7d107bcfe6e1a0e471b71a0c06e5a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Sep 2024 09:14:38 -0700 Subject: [PATCH] [Pallas/MGPU] Implement block spec evaluation correctly The preivous implementation made some surprising assumptions about the contents of the block specs and wasn't correct in general. The new implementation handles all the cases and seems to be sufficient to finally run the matmul example with multiple k steps while producing correct results (it's also shorter!). PiperOrigin-RevId: 679175212 --- jax/_src/pallas/mosaic_gpu/lowering.py | 92 +++++++++++--------------- tests/pallas/mosaic_gpu_test.py | 9 +-- 2 files changed, 40 insertions(+), 61 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index cc19d96aa194..92cd385be2a3 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -185,7 +185,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name def _eval_index_map( module_ctx: ModuleContext, launch_ctx: mgpu.LaunchContext, - idx: ir.Value, + idx: Sequence[ir.Value], block_mapping: pallas_core.BlockMapping, ) -> Sequence[ir.Value]: block_indices = lower_jaxpr_to_mosaic_gpu( @@ -238,10 +238,7 @@ def lower_jaxpr_to_module( jaxpr, [True] * len(jaxpr.outvars), instantiate=True ) - grid = grid_mapping.grid - if len(grid) < 3: - grid += (1,) * (3 - len(grid)) - block = (128,) + (1,) * (len(grid) - 1) + block = (128, 1, 1) params = compiler_params.get("mosaic_gpu", {}) approx_math = params.get("approx_math", False) max_concurrent_steps = params.get("max_concurrent_steps", 1) @@ -256,8 +253,25 @@ def lower_jaxpr_to_module( sequential_axes = tuple( i for i, s in enumerate(dimension_semantics) if s == "sequential" ) - assert all(grid[axis] for axis in sequential_axes) - assert all(block[axis] == 1 for axis in sequential_axes) + + grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes] + if len(grid) < 3: + grid += (1,) * (3 - len(grid)) + else: + raise NotImplementedError( + "Only <=3D grids are supported in Mosaic GPU lowering." + ) + # Compute the number of steps along each sequential axis. + if sequential_axes: + # TODO(slebedev): Support multiple sequential axes. + if len(sequential_axes) > 1: + raise NotImplementedError( + "Multiple sequential axes are not supported in Mosaic GPU lowering." + ) + [sequential_axis] = sequential_axes + num_steps = grid_mapping.grid[sequential_axis] + else: + num_steps = 1 in_in_smem, out_in_smem = util.split_list( [ @@ -268,10 +282,9 @@ def lower_jaxpr_to_module( ) in_structs_gmem = [*grid_mapping.in_shapes] - in_block_shapes = [ - bm.block_shape - for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] - ] + in_block_mappings, out_block_mappings = util.split_list( + block_mappings, [grid_mapping.num_inputs] + ) in_structs_smem = [ jax.ShapeDtypeStruct( [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype @@ -283,8 +296,7 @@ def lower_jaxpr_to_module( ) ] in_gmem_transforms = [ - cast(gpu_core.MemoryRefTransform, bm.transforms) - + cast(gpu_core.MemoryRefTransform, bm.transforms) for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] ] in_swizzles = map( @@ -322,17 +334,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): ) barriers, *extra_barriers = barriers + parallel_count = it.count() + program_ids_template = [ + _program_id(next(parallel_count)) if i not in sequential_axes else None + for i in range(len(grid_mapping.grid)) + ] module_ctx = ModuleContext( name_and_src_info.name, grid_mapping, approx_math, runtime_smem ) - program_ids = map(_program_id, range(len(grid_mapping.grid))) - start_indices = map( - partial(_eval_index_map, module_ctx, launch_ctx, program_ids), - block_mappings, - ) - in_start_indices, out_start_indices = util.split_list( - start_indices, [grid_mapping.num_inputs] - ) smem_scratch_it = iter(scratch_buffers_smem) scratch_buffers_template = [] @@ -385,20 +394,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): ) def gmem_slice( - start_indices: Sequence[ir.Value], step: ir.Value, - shape: Sequence[int], + block_mapping: pallas_core.BlockMapping, ) -> Sequence[mgpu.DynamicSlice]: + assert len(sequential_axes) == 1 + program_ids = [step if i is None else i for i in program_ids_template] + idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping) return tuple( - mgpu.ds( - arith_dialect.addi( - start_index, arith_dialect.muli(step, _as_index(dim)) - ) - if axis in sequential_axes - else start_index, - dim, - ) - for axis, (start_index, dim) in enumerate(zip(start_indices, shape)) + mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape) ) def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: @@ -410,11 +413,7 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: launch_ctx.async_copy( src_ref=in_buffers_gmem[idx], dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot), - gmem_slice=gmem_slice( - in_start_indices[idx], - step, - in_block_shapes[idx], - ), + gmem_slice=gmem_slice(step, in_block_mappings[idx]), barrier=barriers[slot], gmem_transform=tuple(gmem_transforms), swizzle=in_swizzles[idx], @@ -430,27 +429,11 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: launch_ctx.async_copy( src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot), dst_ref=out_buffers_gmem[idx], - gmem_slice=gmem_slice( - out_start_indices[idx], - step, - ir.MemRefType(out_buffers_smem[idx].type).shape[1:], - ), + gmem_slice=gmem_slice(step, out_block_mappings[idx]), swizzle=None, uniform=False, ) - # Compute the number of steps along each sequential axis. - if sequential_axes: - # TODO(slebedev): Support multiple sequential axes. - if len(sequential_axes) > 1: - raise NotImplementedError( - "Multiple sequential axes are not supported in Mosaic GPU lowering." - ) - [sequential_axis] = sequential_axes - num_steps = grid_mapping.grid[sequential_axis] - else: - num_steps = 1 - with mgpu.single_thread(): for slot in range(min(max_concurrent_steps, num_steps)): barriers[slot].arrive_expect_tx(in_transfer_bytes) @@ -619,6 +602,7 @@ def write_env(var: jax_core.Var, val): @register_lowering_rule(primitives.program_id_p) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): + # TODO(apaszke): Sequential axis should be handled specially!! del ctx # Unused. return _program_id(axis) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 267ddd0c97d5..b35658ed4845 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -462,13 +462,8 @@ def test_realistic_matmul(self): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize - # TODO(apaszke): Make the grid and tile sizes larger - # grid_m, grid_k, grid_n = 132, 10, 4 - # TODO(apaszke): Increasing grid_k causes th test to fail. - # It seems like our pipelining implementation has a number of races. - grid_m, grid_k, grid_n = 2, 1, 2 - # tile_m = tile_n = 128 - tile_m = tile_n = 64 + grid_m, grid_k, grid_n = 132, 10, 4 + tile_m = tile_n = 128 tile_k = elems_128b m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n def kernel(a_ref, b_ref, o_ref, acc_ref):