[Pallas/MGPU] Skip output transfers when they don't depend on sequential dims

Note that thanks to the previous revisiting-related checks we weren't doing the
transfers anyway, but this way we can also avoid having to pay for the checks.

PiperOrigin-RevId: 679129331
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent 8599dbc commit 1c7626d
Showing 6 changed files with 160 additions and 106 deletions.
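Before the diffs, a note on the mechanism: the commit decides, per output, whether its index map reads the sequential grid axis by running dead-code elimination over the index-map jaxpr (the new `_uses_arguments` helper in lowering.py below). The following standalone sketch mirrors that check; it assumes JAX's internal `partial_eval` module (the same import the lowering uses), and `out_index_map` is a hypothetical matmul-style index map, not code from this commit.

```python
import jax
from jax._src.interpreters import partial_eval as pe  # internal API, as used by the lowering

def out_index_map(i, j, k):
  # Hypothetical output index map: the output tile is selected by the two
  # parallel axes (i, j) and never moves along the sequential axis k.
  return i, j

closed = jax.make_jaxpr(out_index_map)(0, 0, 0)
_, used_inputs = pe.dce_jaxpr(
    closed.jaxpr, used_outputs=[True] * len(closed.jaxpr.outvars)
)
print(used_inputs)  # [True, True, False] -> sequential-invariant output, so its
                    # GMEM store can be issued once instead of on every step.
```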
6 changes: 3 additions & 3 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -41,13 +41,13 @@ class GPUCompilerParams(pallas_core.CompilerParams):
dimension of the kernel. Either "parallel" for dimensions that can
execute in any order, or "sequential" for dimensions that must be
executed sequentially.
num_stages: The number of pipeline stages in the kernel. Defaults to 1,
meaning no pipelining is done.
max_concurrent_steps: The maximum number of sequential stages that are
active concurrently. Defaults to 1.
"""
PLATFORM: ClassVar[str] = "mosaic_gpu"
approx_math: bool = False
dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
num_stages: int = 1
max_concurrent_steps: int = 1


class GPUMemorySpace(enum.Enum):
206 changes: 128 additions & 78 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -184,7 +184,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name
def _eval_index_map(
module_ctx: ModuleContext,
launch_ctx: mgpu.LaunchContext,
idx: ir.Value,
idx: Sequence[ir.Value],
block_mapping: pallas_core.BlockMapping,
) -> Sequence[ir.Value]:
block_indices = lower_jaxpr_to_mosaic_gpu(
@@ -200,6 +200,11 @@ def _eval_index_map(
return tuple(result)


def _uses_arguments(cjaxpr: jax_core.ClosedJaxpr) -> list[bool]:
jaxpr = cjaxpr.jaxpr
return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1]


def lower_jaxpr_to_module(
grid_mapping: pallas_core.GridMapping,
jaxpr: jax_core.Jaxpr,
@@ -237,13 +242,10 @@ def lower_jaxpr_to_module(
jaxpr, [True] * len(jaxpr.outvars), instantiate=True
)

grid = grid_mapping.grid
if len(grid) < 3:
grid += (1,) * (3 - len(grid))
block = (128,) + (1,) * (len(grid) - 1)
block = (128, 1, 1)
params = compiler_params.get("mosaic_gpu", {})
approx_math = params.get("approx_math", False)
num_stages = params.get("num_stages", 1)
max_concurrent_steps = params.get("max_concurrent_steps", 1)
dimension_semantics = params.get("dimension_semantics")
if dimension_semantics is None:
dimension_semantics = ["parallel"] * len(grid_mapping.grid)
@@ -255,8 +257,30 @@
sequential_axes = tuple(
i for i, s in enumerate(dimension_semantics) if s == "sequential"
)
assert all(grid[axis] for axis in sequential_axes)
assert all(block[axis] == 1 for axis in sequential_axes)

grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes]
if len(grid) < 3:
grid += (1,) * (3 - len(grid))
else:
raise NotImplementedError(
"Only <=3D grids are supported in Mosaic GPU lowering."
)
# Compute the number of steps along each sequential axis.
if sequential_axes:
# TODO(slebedev): Support multiple sequential axes.
if len(sequential_axes) > 1:
raise NotImplementedError(
"Multiple sequential axes are not supported in Mosaic GPU lowering."
)
[sequential_axis] = sequential_axes
num_steps = grid_mapping.grid[sequential_axis]
out_sequential_invariant = [
not _uses_arguments(bm.index_map_jaxpr)[sequential_axis]
for bm in grid_mapping.block_mappings_output
]
else:
num_steps = 1
out_sequential_invariant = [True] * len(grid_mapping.out_shapes)

in_in_smem, out_in_smem = util.split_list(
[
@@ -267,21 +291,21 @@
)

in_structs_gmem = [*grid_mapping.in_shapes]
in_block_shapes = [
bm.block_shape
for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
in_block_mappings, out_block_mappings = util.split_list(
block_mappings, [grid_mapping.num_inputs]
)
in_structs_smem = [
jax.ShapeDtypeStruct([num_stages, *bm.ref_aval.shape], bm.ref_aval.dtype)
jax.ShapeDtypeStruct(
[max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
)
if in_smem
else None
for bm, in_smem in zip(
block_mappings[: grid_mapping.num_inputs], in_in_smem
)
]
in_gmem_transforms = [
cast(gpu_core.MemoryRefTransform, bm.transforms)
for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
in_swizzles = map(
@@ -293,7 +317,7 @@ def lower_jaxpr_to_module(
out_structs_gmem = [*grid_mapping.out_shapes]
# TODO(justinfu): Implement output Memref transforms
out_structs_smem = [
jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype)
jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
if in_smem
else None
for bm, in_smem, s in zip(
@@ -319,17 +343,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
)
barriers, *extra_barriers = barriers

parallel_count = it.count()
program_ids_template = [
_program_id(next(parallel_count)) if i not in sequential_axes else None
for i in range(len(grid_mapping.grid))
]
module_ctx = ModuleContext(
name_and_src_info.name, grid_mapping, approx_math, runtime_smem
)
program_ids = map(_program_id, range(len(grid_mapping.grid)))
start_indices = map(
partial(_eval_index_map, module_ctx, launch_ctx, program_ids),
block_mappings,
)
in_start_indices, out_start_indices = util.split_list(
start_indices, [grid_mapping.num_inputs]
)

smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
@@ -382,22 +403,18 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
)

def gmem_slice(
start_indices: Sequence[ir.Value],
step: ir.Value,
shape: Sequence[int],
block_mapping: pallas_core.BlockMapping,
) -> Sequence[mgpu.DynamicSlice]:
assert len(sequential_axes) == 1
program_ids = [step if i is None else i for i in program_ids_template]
idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping)
return tuple(
mgpu.ds(
arith_dialect.addi(
start_index, arith_dialect.muli(step, _as_index(dim))
)
if axis in sequential_axes
else start_index,
dim,
)
for axis, (start_index, dim) in enumerate(zip(start_indices, shape))
mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape)
)

is_memory_thread = mgpu.single_thread_predicate(per_block=True)

def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
if not in_in_smem[idx]:
return
@@ -407,59 +424,79 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
launch_ctx.async_copy(
src_ref=in_buffers_gmem[idx],
dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot),
gmem_slice=gmem_slice(
in_start_indices[idx],
step,
in_block_shapes[idx],
),
gmem_slice=gmem_slice(step, in_block_mappings[idx]),
barrier=barriers[slot],
gmem_transform=tuple(gmem_transforms),
swizzle=in_swizzles[idx],
arrive=False, # The caller must do ``arrive_expect_tx`` manually!
uniform=False,
predicate=is_memory_thread,
)

def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
def store(
idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value | None
) -> ir.Value | None:
if not out_in_smem[idx]:
return

store_slice = gmem_slice(step, out_block_mappings[idx])
if out_sequential_invariant[idx]:
assert prev_base_offset is None
do_store = None # Lack of predicate defaults to True.
base_offset = None
else:
assert prev_base_offset is not None
# We have to do some work to make sure that consecutive stores are not
# going to be writing to the same location, or else we'll end up with
# multiple concurrent writes and a racy program.
# TODO(apaszke,slebedev): In most cases output index maps depend only on
# parallel grid axes and in that case we can simply move the store to
# happen after the loop.
# TODO(apaszke,slebedev): This still diverges significantly from the TPU
# semantics in that it will move on to the next SMEM output slice even if
# it's not storing the previous one.
strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
base_offset = _as_index(0)
for stride, slc in zip(strides, store_slice):
base_offset = arith_dialect.addi(
base_offset, arith_dialect.muli(slc.base, _as_index(stride))
)
base_offset_changed = arith_dialect.cmpi(
arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
)
is_last_step = arith_dialect.cmpi(
arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
)
do_store = arith_dialect.andi(
is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
)
# TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
launch_ctx.async_copy(
src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
dst_ref=out_buffers_gmem[idx],
gmem_slice=gmem_slice(
out_start_indices[idx],
step,
ir.MemRefType(out_buffers_smem[idx].type).shape[1:],
),
gmem_slice=store_slice,
swizzle=None,
uniform=False,
predicate=do_store,
)

# Compute the number of steps along each sequential axis.
if sequential_axes:
# TODO(slebedev): Support multiple sequential axes.
if len(sequential_axes) > 1:
raise NotImplementedError(
"Multiple sequential axes are not supported in Mosaic GPU lowering."
)
[sequential_axis] = sequential_axes
num_steps = grid_mapping.grid[sequential_axis]
else:
num_steps = 1

with mgpu.single_thread():
for slot in range(min(num_stages, num_steps)):
barriers[slot].arrive_expect_tx(in_transfer_bytes)
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))

@mgpu.fori(_as_index(num_steps), accs)
def _(step, accs):
slot = arith_dialect.remui(step, _as_index(num_stages))
return base_offset

for slot in range(min(max_concurrent_steps, num_steps)):
barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread)
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))

last_store_offsets = [None if inv else _as_index(-1) for inv in out_sequential_invariant]
@mgpu.fori(_as_index(num_steps), (accs, last_store_offsets))
def _(step, carry):
accs, last_store_offsets = carry
slot = arith_dialect.remui(step, _as_index(max_concurrent_steps))
if grid_mapping.num_inputs:
# Only wait if async copies were issued.
barriers[slot].wait()
# We need to make sure the output copy is complete before the kernel starts
# writing to the output window.
launch_ctx.await_async_copy(max_concurrent_steps - 1, await_read_only=True)

args = [
mgpu.memref_slice(buffers_smem[idx], slot)
@@ -479,22 +516,34 @@ def _(step, accs):
new_accs = lower_jaxpr_to_mosaic_gpu(
module_ctx, launch_ctx, lowered_jaxpr, args
)
mgpu.commit_shared()

with mgpu.single_thread():
for idx in range(grid_mapping.num_outputs):
store(idx, step, slot)
# TODO(apaszke): Elide this if we're not going to perform any stores
mgpu.commit_shared()
new_store_offsets = []
for idx in range(grid_mapping.num_outputs):
last_offset = last_store_offsets[idx]
new_store_offsets.append(
store(idx, step, slot, last_offset)
if not out_sequential_invariant[idx]
else last_offset # Only store if the output can depend on the step.
)

next_step = arith_dialect.addi(step, _as_index(num_stages))
next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps))
next_step_in_bounds = arith_dialect.cmpi(
arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps)
)
with mgpu.when(next_step_in_bounds), mgpu.single_thread():
next_slot = slot # (x + y) % y == x % y
with mgpu.when(next_step_in_bounds):
barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread)
for idx in range(grid_mapping.num_inputs):
fetch(idx, next_step, slot)
barriers[slot].arrive_expect_tx(in_transfer_bytes)
fetch(idx, next_step, next_slot)

return list(new_accs), new_store_offsets

return list(new_accs)
last_slot = _as_index((num_steps - 1) % max_concurrent_steps)
for idx in range(grid_mapping.num_outputs):
if out_sequential_invariant[idx]:
store(idx, _as_index(0), last_slot, None)

launch_ctx.await_async_copy(0)

@@ -540,7 +589,7 @@ def _(step, accs):
(*in_structs_smem, *out_structs_smem),
*extra_smem_scratch,
(
mgpu.Barrier(arrival_count=1, num_barriers=num_stages),
mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
*extra_barriers,
),
),
@@ -612,6 +661,7 @@ def write_env(var: jax_core.Var, val):

@register_lowering_rule(primitives.program_id_p)
def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
# TODO(apaszke): Sequential axis should be handled specially!!
del ctx # Unused.
return _program_id(axis)

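A plain-Python model of the store gating added in the loop above: a sequential-variant output is written back only when its GMEM window has moved since the previous step, or on the final step, which is what the `base_offset` / `is_last_step` predicate in `store` computes. This is a sketch of the semantics only, not Mosaic GPU code; the helper names are made up.

```python
def linear_base_offset(strides, slice_starts):
  # GMEM offset of the store window, analogous to the strides/base loop in `store`.
  return sum(stride * start for stride, start in zip(strides, slice_starts))

def should_store(step, num_steps, base_offset, prev_base_offset):
  # Write only when the window moved, or on the last step; otherwise the
  # overlapping pipeline steps would issue racy writes to the same location.
  return base_offset != prev_base_offset or step == num_steps - 1

# Example: with row-major strides (64, 1), moving the row start from 1 to 2
# changes the base offset (64 -> 128), so the store goes through.
print(should_store(step=3, num_steps=8,
                   base_offset=linear_base_offset((64, 1), (2, 0)),
                   prev_base_offset=linear_base_offset((64, 1), (1, 0))))  # True
```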
1 change: 1 addition & 0 deletions jax/experimental/mosaic/gpu/__init__.py
@@ -51,6 +51,7 @@
memref_unfold,
memref_unsqueeze,
single_thread,
single_thread_predicate,
thread_idx,
tile_shape,
warp_idx,
10 changes: 7 additions & 3 deletions jax/experimental/mosaic/gpu/core.py
@@ -347,6 +347,7 @@ def async_copy(
arrive: bool | None = None,
uniform: bool = True,
collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
predicate: ir.Value | None = None,
):
index = ir.IndexType.get()
i16 = ir.IntegerType.get_signless(16)
@@ -503,14 +504,17 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
barrier_ptr = barrier.get_ptr()
with uniform_ctx():
if arrive:
nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes)
nvvm.mbarrier_arrive_expect_tx_shared(
barrier_ptr, transfer_bytes, predicate=predicate
)
nvvm.cp_async_bulk_tensor_shared_cluster_global(
smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask,
smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [],
multicast_mask=multicast_mask, predicate=predicate
)
else:
with uniform_ctx():
nvvm.cp_async_bulk_tensor_global_shared_cta(
tma_desc, smem_ptr, rev_dyn_base_indices
tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
)
nvvm.cp_async_bulk_commit_group()

