[Pallas/MGPU] Skip output transfers when they don't depend on sequential dims

Note that, thanks to the previous revisiting-related checks, we weren't issuing these
transfers anyway, but this way we also avoid having to pay for the checks themselves.
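For intuition, here is a sketch of the kind of kernel this optimization targets (illustrative only, not part of this commit; it uses the generic Pallas API and assumes the last grid axis s is the sequential one). The kernel accumulates into its output tile across s, and since the output index map ignores s, the output's GMEM transfer only ever needs to happen once:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
  # Zero the output tile on the first visit, then accumulate across s.
  @pl.when(pl.program_id(2) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
  o_ref[...] += x_ref[...] @ y_ref[...]

m = n = k = 256
bm = bn = bk = 64
x = jnp.ones((m, k), jnp.float32)
y = jnp.ones((k, n), jnp.float32)
out = pl.pallas_call(
    matmul_kernel,
    grid=(m // bm, n // bn, k // bk),  # last axis assumed sequential
    in_specs=[
        pl.BlockSpec((bm, bk), lambda i, j, s: (i, s)),
        pl.BlockSpec((bk, bn), lambda i, j, s: (s, j)),
    ],
    # The output index map ignores s, so the output is invariant to the
    # sequential axis and its GMEM store can be deferred to the epilogue.
    out_specs=pl.BlockSpec((bm, bn), lambda i, j, s: (i, j)),
    out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
)(x, y)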

PiperOrigin-RevId: 679516275
apaszke authored and Google-ML-Automation committed Sep 27, 2024
1 parent afaf8b8 commit 5740ab3
Showing 1 changed file with 53 additions and 27 deletions.

jax/_src/pallas/mosaic_gpu/lowering.py: 53 additions & 27 deletions
@@ -201,6 +201,11 @@ def _eval_index_map(
   return tuple(result)
 
 
+def _uses_arguments(cjaxpr: jax_core.ClosedJaxpr) -> list[bool]:
+  jaxpr = cjaxpr.jaxpr
+  return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1]
+
+
 def lower_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
     jaxpr: jax_core.Jaxpr,
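To see what _uses_arguments computes, here is a small sketch (illustrative; it relies on pe.dce_jaxpr returning, as its second result, the mask of inputs that survive dead-code elimination when all outputs are kept):

import jax
from jax._src.interpreters import partial_eval as pe

index_map = lambda i, j, s: (i, j)  # never reads the last grid index
closed = jax.make_jaxpr(index_map)(0, 0, 0)
jaxpr = closed.jaxpr
used = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1]
print(used)  # [True, True, False]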
@@ -270,8 +275,13 @@ def lower_jaxpr_to_module(
     )
     [sequential_axis] = sequential_axes
     num_steps = grid_mapping.grid[sequential_axis]
+    out_sequential_invariant = [
+        not _uses_arguments(bm.index_map_jaxpr)[sequential_axis]
+        for bm in grid_mapping.block_mappings_output
+    ]
   else:
     num_steps = 1
+    out_sequential_invariant = [True] * len(grid_mapping.out_shapes)
 
   in_in_smem, out_in_smem = util.split_list(
       [
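Continuing the sketch above: for an output whose index map is lambda i, j, s: (i, j) and a sequential axis of 2, the mask makes the output come out as invariant (hypothetical values):

sequential_axis = 2
used = [True, True, False]  # what _uses_arguments would return
out_sequential_invariant = [not used[sequential_axis]]  # [True]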
@@ -429,36 +439,42 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
       )
 
     def store(
-        idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value
-    ) -> ir.Value:
+        idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value | None
+    ) -> ir.Value | None:
       if not out_in_smem[idx]:
         return _as_index(-1)
 
-      # We have to do some work to make sure that consecutive stores are not
-      # going to be writing to the same location, or else we'll end up with
-      # multiple concurrent writes and a racy program.
-      # TODO(apaszke,slebedev): In most cases output index maps depend only on
-      # parallel grid axes and in that case we can simply move the store to
-      # happen after the loop.
-      # TODO(apaszke,slebedev): This still diverges significantly from the TPU
-      # semantics in that it will move on to the next SMEM output slice even if
-      # it's not storing the previous one.
       store_slice = gmem_slice(step, out_block_mappings[idx])
-      strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
-      base_offset = _as_index(0)
-      for stride, slc in zip(strides, store_slice):
-        base_offset = arith_dialect.addi(
-            base_offset, arith_dialect.muli(slc.base, _as_index(stride))
-        )
-      base_offset_changed = arith_dialect.cmpi(
-          arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
-      )
-      is_last_step = arith_dialect.cmpi(
-          arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
-      )
-      do_store = arith_dialect.andi(
-          is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
-      )
+      if out_sequential_invariant[idx]:
+        assert prev_base_offset is None
+        do_store = None  # Lack of predicate defaults to True.
+        base_offset = None
+      else:
+        assert prev_base_offset is not None
+        # We have to do some work to make sure that consecutive stores are not
+        # going to be writing to the same location, or else we'll end up with
+        # multiple concurrent writes and a racy program.
+        # TODO(apaszke,slebedev): In most cases output index maps depend only on
+        # parallel grid axes and in that case we can simply move the store to
+        # happen after the loop.
+        # TODO(apaszke,slebedev): This still diverges significantly from the TPU
+        # semantics in that it will move on to the next SMEM output slice even if
+        # it's not storing the previous one.
+        strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
+        base_offset = _as_index(0)
+        for stride, slc in zip(strides, store_slice):
+          base_offset = arith_dialect.addi(
+              base_offset, arith_dialect.muli(slc.base, _as_index(stride))
+          )
+        base_offset_changed = arith_dialect.cmpi(
+            arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
+        )
+        is_last_step = arith_dialect.cmpi(
+            arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
+        )
+        do_store = arith_dialect.andi(
+            is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
+        )
       # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
       launch_ctx.async_copy(
           src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
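The store predicate above boils down to the following plain-Python logic (a sketch of the generated IR's semantics, not the lowering code itself):

def should_store(is_memory_thread: bool, base_offset: int,
                 prev_base_offset: int, step: int, num_steps: int) -> bool:
  # Mirrors do_store = andi(is_memory_thread, ori(changed, last)).
  base_offset_changed = base_offset != prev_base_offset
  is_last_step = step == num_steps - 1
  return is_memory_thread and (base_offset_changed or is_last_step)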
@@ -475,7 +491,7 @@ def store(
       for idx in range(grid_mapping.num_inputs):
         fetch(idx, _as_index(slot), _as_index(slot))
 
-    last_store_offsets = [_as_index(-1)] * grid_mapping.num_outputs
+    last_store_offsets = [None if inv else _as_index(-1) for inv in out_sequential_invariant]
     @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets))
     def _(step, carry):
       accs, last_store_offsets = carry
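The new carry initialization, sketched in plain Python with a hypothetical pair of outputs (a plain -1 stands in for _as_index(-1)): None marks outputs that are never stored from inside the loop, so no offset needs to be tracked for them.

out_sequential_invariant = [True, False]
last_store_offsets = [None if inv else -1 for inv in out_sequential_invariant]
assert last_store_offsets == [None, -1]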
@@ -510,8 +526,11 @@ def _(step, carry):
       mgpu.commit_shared()
       new_store_offsets = []
       for idx in range(grid_mapping.num_outputs):
+        last_offset = last_store_offsets[idx]
         new_store_offsets.append(
-            store(idx, step, slot, last_store_offsets[idx])
+            store(idx, step, slot, last_offset)
+            if not out_sequential_invariant[idx]
+            else last_offset  # Only store if the output can depend on the step.
         )
 
       next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps))
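The per-output dispatch can be read as the following comprehension (an equivalent sketch, not the code itself): invariant outputs pass their None offset through unchanged, while the rest are stored and report a fresh offset.

new_store_offsets = [
    off if inv else store(idx, step, slot, off)
    for idx, (inv, off) in enumerate(
        zip(out_sequential_invariant, last_store_offsets))
]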
@@ -526,6 +545,13 @@ def _(step, carry):
 
       return list(new_accs), new_store_offsets
 
+    # Outputs invariant to the sequential axis are never written from inside the
+    # loop. This is the only place where we store them.
+    last_slot = _as_index((num_steps - 1) % max_concurrent_steps)
+    for idx in range(grid_mapping.num_outputs):
+      if out_sequential_invariant[idx]:
+        store(idx, _as_index(0), last_slot, None)
+
     launch_ctx.await_async_copy(0)
 
     scratch_avals = [
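The epilogue targets the SMEM slot written by the final loop step; assuming round-robin slot assignment (slot = step % max_concurrent_steps), a quick worked check:

num_steps, max_concurrent_steps = 5, 2
# Steps 0..4 use slots 0, 1, 0, 1, 0, so the last step wrote slot 0.
last_slot = (num_steps - 1) % max_concurrent_steps
assert last_slot == 0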
