
[Pallas/MGPU] Implement block spec evaluation correctly
The previous implementation made some surprising assumptions about the contents
of the block specs and wasn't correct in general. The new implementation handles
all the cases and seems to be sufficient to finally run the matmul example with
multiple k steps while producing correct results (it's also shorter!).

PiperOrigin-RevId: 679175212
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent a3284bd commit 076287f
Showing 2 changed files with 40 additions and 61 deletions.
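
For context, a minimal plain-Python sketch of the idea the commit message describes: rather than precomputing per-operand start indices and adding step * block_dim only along the sequential axis, the new lowering evaluates each block spec's index map at the full set of program ids (the parallel program ids plus the current sequential step) on every step, and builds the GMEM slice from the result. The helper name, index map, and numbers below are illustrative only, not taken from the diff.

def gmem_slice(step, parallel_ids, sequential_axes, index_map, block_shape):
  # Substitute the sequential step for grid axes that are not part of the
  # CUDA launch grid; parallel axes keep their program ids.
  program_ids = []
  parallel_it = iter(parallel_ids)
  for axis in range(len(parallel_ids) + len(sequential_axes)):
    program_ids.append(step if axis in sequential_axes else next(parallel_it))
  # The block spec's index map returns block indices; scale by the block shape
  # to get element offsets, then pair each offset with its slice length.
  block_indices = index_map(*program_ids)
  return tuple(
      (idx * dim, dim) for idx, dim in zip(block_indices, block_shape)
  )

# Example: the LHS of a matmul over an (m, n, k) grid with index map
# lambda m, n, k: (m, k), where the k axis (axis 2) is the sequential one.
print(gmem_slice(step=3, parallel_ids=(5, 7), sequential_axes=(2,),
                 index_map=lambda m, n, k: (m, k), block_shape=(64, 64)))
# -> ((320, 64), (192, 64))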
92 changes: 38 additions & 54 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -185,7 +185,7 @@ class LoweringError(Exception):  # pylint: disable=g-bad-exception-name
 def _eval_index_map(
     module_ctx: ModuleContext,
     launch_ctx: mgpu.LaunchContext,
-    idx: ir.Value,
+    idx: Sequence[ir.Value],
     block_mapping: pallas_core.BlockMapping,
 ) -> Sequence[ir.Value]:
   block_indices = lower_jaxpr_to_mosaic_gpu(
@@ -238,10 +238,7 @@ def lower_jaxpr_to_module(
       jaxpr, [True] * len(jaxpr.outvars), instantiate=True
   )
 
-  grid = grid_mapping.grid
-  if len(grid) < 3:
-    grid += (1,) * (3 - len(grid))
-  block = (128,) + (1,) * (len(grid) - 1)
+  block = (128, 1, 1)
   params = compiler_params.get("mosaic_gpu", {})
   approx_math = params.get("approx_math", False)
   max_concurrent_steps = params.get("max_concurrent_steps", 1)
@@ -256,8 +253,25 @@
   sequential_axes = tuple(
       i for i, s in enumerate(dimension_semantics) if s == "sequential"
   )
-  assert all(grid[axis] for axis in sequential_axes)
-  assert all(block[axis] == 1 for axis in sequential_axes)
+
+  grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes]
+  if len(grid) < 3:
+    grid += (1,) * (3 - len(grid))
+  else:
+    raise NotImplementedError(
+        "Only <=3D grids are supported in Mosaic GPU lowering."
+    )
+  # Compute the number of steps along each sequential axis.
+  if sequential_axes:
+    # TODO(slebedev): Support multiple sequential axes.
+    if len(sequential_axes) > 1:
+      raise NotImplementedError(
+          "Multiple sequential axes are not supported in Mosaic GPU lowering."
+      )
+    [sequential_axis] = sequential_axes
+    num_steps = grid_mapping.grid[sequential_axis]
+  else:
+    num_steps = 1
 
   in_in_smem, out_in_smem = util.split_list(
       [
@@ -268,10 +282,9 @@
   )
 
   in_structs_gmem = [*grid_mapping.in_shapes]
-  in_block_shapes = [
-      bm.block_shape
-      for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
-  ]
+  in_block_mappings, out_block_mappings = util.split_list(
+      block_mappings, [grid_mapping.num_inputs]
+  )
   in_structs_smem = [
       jax.ShapeDtypeStruct(
           [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
@@ -283,8 +296,7 @@
       )
   ]
   in_gmem_transforms = [
-      cast(gpu_core.MemoryRefTransform, bm.transforms)
-
+      cast(gpu_core.MemoryRefTransform, bm.transforms)
       for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
   ]
   in_swizzles = map(
@@ -322,17 +334,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
     )
     barriers, *extra_barriers = barriers
 
+    parallel_count = it.count()
+    program_ids_template = [
+        _program_id(next(parallel_count)) if i not in sequential_axes else None
+        for i in range(len(grid_mapping.grid))
+    ]
     module_ctx = ModuleContext(
         name_and_src_info.name, grid_mapping, approx_math, runtime_smem
     )
-    program_ids = map(_program_id, range(len(grid_mapping.grid)))
-    start_indices = map(
-        partial(_eval_index_map, module_ctx, launch_ctx, program_ids),
-        block_mappings,
-    )
-    in_start_indices, out_start_indices = util.split_list(
-        start_indices, [grid_mapping.num_inputs]
-    )
 
     smem_scratch_it = iter(scratch_buffers_smem)
     scratch_buffers_template = []
@@ -385,20 +394,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
       )
 
     def gmem_slice(
-        start_indices: Sequence[ir.Value],
         step: ir.Value,
-        shape: Sequence[int],
+        block_mapping: pallas_core.BlockMapping,
     ) -> Sequence[mgpu.DynamicSlice]:
+      assert len(sequential_axes) == 1
+      program_ids = [step if i is None else i for i in program_ids_template]
+      idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping)
       return tuple(
-          mgpu.ds(
-              arith_dialect.addi(
-                  start_index, arith_dialect.muli(step, _as_index(dim))
-              )
-              if axis in sequential_axes
-              else start_index,
-              dim,
-          )
-          for axis, (start_index, dim) in enumerate(zip(start_indices, shape))
+          mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape)
       )
 
     def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
@@ -410,11 +413,7 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
       launch_ctx.async_copy(
           src_ref=in_buffers_gmem[idx],
           dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot),
-          gmem_slice=gmem_slice(
-              in_start_indices[idx],
-              step,
-              in_block_shapes[idx],
-          ),
+          gmem_slice=gmem_slice(step, in_block_mappings[idx]),
           barrier=barriers[slot],
           gmem_transform=tuple(gmem_transforms),
           swizzle=in_swizzles[idx],
@@ -430,27 +429,11 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
       launch_ctx.async_copy(
           src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
           dst_ref=out_buffers_gmem[idx],
-          gmem_slice=gmem_slice(
-              out_start_indices[idx],
-              step,
-              ir.MemRefType(out_buffers_smem[idx].type).shape[1:],
-          ),
+          gmem_slice=gmem_slice(step, out_block_mappings[idx]),
           swizzle=None,
           uniform=False,
       )
 
-    # Compute the number of steps along each sequential axis.
-    if sequential_axes:
-      # TODO(slebedev): Support multiple sequential axes.
-      if len(sequential_axes) > 1:
-        raise NotImplementedError(
-            "Multiple sequential axes are not supported in Mosaic GPU lowering."
-        )
-      [sequential_axis] = sequential_axes
-      num_steps = grid_mapping.grid[sequential_axis]
-    else:
-      num_steps = 1
-
     with mgpu.single_thread():
       for slot in range(min(max_concurrent_steps, num_steps)):
         barriers[slot].arrive_expect_tx(in_transfer_bytes)
@@ -619,6 +602,7 @@ def write_env(var: jax_core.Var, val):
 
 @register_lowering_rule(primitives.program_id_p)
 def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
+  # TODO(apaszke): Sequential axis should be handled specially!!
   del ctx  # Unused.
   return _program_id(axis)
 
9 changes: 2 additions & 7 deletions tests/pallas/mosaic_gpu_test.py
@@ -462,13 +462,8 @@ def test_realistic_matmul(self):
     dtype = jnp.float16
     swizzle = 128
     elems_128b = swizzle // jnp.dtype(dtype).itemsize
-    # TODO(apaszke): Make the grid and tile sizes larger
-    # grid_m, grid_k, grid_n = 132, 10, 4
-    # TODO(apaszke): Increasing grid_k causes th test to fail.
-    # It seems like our pipelining implementation has a number of races.
-    grid_m, grid_k, grid_n = 2, 1, 2
-    # tile_m = tile_n = 128
-    tile_m = tile_n = 64
+    grid_m, grid_k, grid_n = 132, 10, 4
+    tile_m = tile_n = 128
     tile_k = elems_128b
     m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
     def kernel(a_ref, b_ref, o_ref, acc_ref):
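
For reference, the restored problem sizes work out as follows (plain-Python arithmetic mirroring the test; float16 has a 2-byte itemsize, so a 128-byte swizzle spans 64 elements):

swizzle = 128
elems_128b = swizzle // 2  # float16 itemsize is 2 bytes -> 64 elements
grid_m, grid_k, grid_n = 132, 10, 4
tile_m = tile_n = 128
tile_k = elems_128b  # 64
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
print(m, k, n)  # 16896 640 512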
