[Pallas/MGPU] Skip outgoing TMA when the output is being revisited
Otherwise we end up with programs that race on writes to the same GMEM location.

PiperOrigin-RevId: 679189227
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent 076287f commit dd2ee8c
Showing 4 changed files with 75 additions and 30 deletions.
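
The fix in lowering.py below gates each outgoing TMA on whether the output window has moved since the previous grid step, or on the step being the last one. As a rough illustration of that predicate only, here is a minimal pure-Python sketch; plan_stores and its inputs are hypothetical stand-ins for this note, not part of the change:

def plan_stores(base_offsets):
  """For each grid step, decide whether the outgoing copy should be issued.

  Mirrors the `base_offset_changed or is_last_step` predicate in the lowering:
  a store runs only when the destination base offset differs from the previous
  step's, or when this is the final step.
  """
  num_steps = len(base_offsets)
  decisions = []
  prev = -1  # the lowering seeds the carry with _as_index(-1)
  for step, offset in enumerate(base_offsets):
    changed = offset != prev
    is_last = step == num_steps - 1
    decisions.append(changed or is_last)
    prev = offset
  return decisions

# An output revisited on steps 0-1 and again on steps 2-3: the revisit at
# step 1 is skipped, so no two in-flight copies target the same GMEM window.
assert plan_stores([0, 0, 128, 128]) == [True, False, True, True]
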
70 changes: 53 additions & 17 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -404,6 +404,8 @@ def gmem_slice(
mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape)
)

is_memory_thread = mgpu.single_thread_predicate(per_block=True)

def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
if not in_in_smem[idx]:
return
@@ -419,36 +421,67 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
swizzle=in_swizzles[idx],
arrive=False, # The caller must do ``arrive_expect_tx`` manually!
uniform=False,
predicate=is_memory_thread,
)

def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
def store(
idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value
) -> ir.Value:
if not out_in_smem[idx]:
return

# We have to do some work to make sure that consecutive stores are not
# going to be writing to the same location, or else we'll end up with
# multiple concurrent writes and a racy program.
# TODO(apaszke,slebedev): In most cases output index maps depend only on
# parallel grid axes and in that case we can simply move the store to
# happen after the loop.
# TODO(apaszke,slebedev): This still diverges significantly from the TPU
# semantics in that it will move on to the next SMEM output slice even if
# it's not storing the previous one.
store_slice = gmem_slice(step, out_block_mappings[idx])
strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
base_offset = _as_index(0)
for stride, slc in zip(strides, store_slice):
base_offset = arith_dialect.addi(
base_offset, arith_dialect.muli(slc.base, _as_index(stride))
)
base_offset_changed = arith_dialect.cmpi(
arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
)
is_last_step = arith_dialect.cmpi(
arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
)
do_store = arith_dialect.andi(
is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
)
# TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
launch_ctx.async_copy(
src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
dst_ref=out_buffers_gmem[idx],
gmem_slice=gmem_slice(step, out_block_mappings[idx]),
gmem_slice=store_slice,
swizzle=None,
uniform=False,
predicate=do_store,
)
return base_offset

with mgpu.single_thread():
for slot in range(min(max_concurrent_steps, num_steps)):
barriers[slot].arrive_expect_tx(in_transfer_bytes)
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))
for slot in range(min(max_concurrent_steps, num_steps)):
barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread)
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))

@mgpu.fori(_as_index(num_steps), accs)
def _(step, accs):
last_store_offsets = [_as_index(-1)] * grid_mapping.num_outputs
@mgpu.fori(_as_index(num_steps), (accs, last_store_offsets))
def _(step, carry):
accs, last_store_offsets = carry
slot = arith_dialect.remui(step, _as_index(max_concurrent_steps))
if grid_mapping.num_inputs:
# Only wait if async copies were issued.
barriers[slot].wait()
# We need to make sure the output copy is complete before the kernel starts
# writing to the output window.
launch_ctx.await_async_copy(max_concurrent_steps - 1)
launch_ctx.await_async_copy(max_concurrent_steps - 1, await_read_only=True)

args = [
mgpu.memref_slice(buffers_smem[idx], slot)
@@ -468,23 +501,26 @@ def _(step, accs):
new_accs = lower_jaxpr_to_mosaic_gpu(
module_ctx, launch_ctx, lowered_jaxpr, args
)
mgpu.commit_shared()

with mgpu.single_thread():
for idx in range(grid_mapping.num_outputs):
store(idx, step, slot)
# TODO(apaszke): Elide this if we're not going to perform any stores
mgpu.commit_shared()
new_store_offsets = []
for idx in range(grid_mapping.num_outputs):
new_store_offsets.append(
store(idx, step, slot, last_store_offsets[idx])
)

next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps))
next_step_in_bounds = arith_dialect.cmpi(
arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps)
)
next_slot = slot # (x + y) % y == x % y
with mgpu.when(next_step_in_bounds), mgpu.single_thread():
with mgpu.when(next_step_in_bounds):
barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread)
for idx in range(grid_mapping.num_inputs):
fetch(idx, next_step, next_slot)
barriers[slot].arrive_expect_tx(in_transfer_bytes)

return list(new_accs)
return list(new_accs), new_store_offsets

launch_ctx.await_async_copy(0)

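For reference, the "base offset" compared in `store` above is the linearized start of the output window: the base index of the slice in each dimension multiplied by the corresponding GMEM memref stride, summed over dimensions. A small pure-Python sketch of that computation; gmem_base_offset is a hypothetical helper used only for this illustration:

def gmem_base_offset(slice_starts, strides):
  """Linearized offset of a window's first element, as computed in `store`:
  sum of (slice base index * memref stride) over all dimensions."""
  assert len(slice_starts) == len(strides)
  return sum(start * stride for start, stride in zip(slice_starts, strides))

# For a row-major output with strides (128, 1), two steps whose windows start
# at the same (row, col) produce equal offsets, so the second store is
# skipped; a window that moves produces a different offset and is stored.
strides = [128, 1]
assert gmem_base_offset([0, 64], strides) == gmem_base_offset([0, 64], strides)
assert gmem_base_offset([64, 0], strides) != gmem_base_offset([0, 64], strides)
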
1 change: 1 addition & 0 deletions jax/experimental/mosaic/gpu/__init__.py
@@ -51,6 +51,7 @@
memref_unfold,
memref_unsqueeze,
single_thread,
single_thread_predicate,
thread_idx,
tile_shape,
warp_idx,
10 changes: 7 additions & 3 deletions jax/experimental/mosaic/gpu/core.py
@@ -347,6 +347,7 @@ def async_copy(
arrive: bool | None = None,
uniform: bool = True,
collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
predicate: ir.Value | None = None,
):
index = ir.IndexType.get()
i16 = ir.IntegerType.get_signless(16)
@@ -503,14 +504,17 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
barrier_ptr = barrier.get_ptr()
with uniform_ctx():
if arrive:
nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes)
nvvm.mbarrier_arrive_expect_tx_shared(
barrier_ptr, transfer_bytes, predicate=predicate
)
nvvm.cp_async_bulk_tensor_shared_cluster_global(
smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask,
smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [],
multicast_mask=multicast_mask, predicate=predicate
)
else:
with uniform_ctx():
nvvm.cp_async_bulk_tensor_global_shared_cta(
tma_desc, smem_ptr, rev_dyn_base_indices
tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
)
nvvm.cp_async_bulk_commit_group()

24 changes: 14 additions & 10 deletions jax/experimental/mosaic/gpu/utils.py
@@ -229,6 +229,15 @@ class ThreadSubset(enum.IntEnum):
_ONCE_PER: ThreadSubset | None = None


def single_thread_predicate(per_block=True):
warp = warp_idx()
if not per_block:
warp = arith.remui(warp, c(4, warp.type))
first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type))
elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
return arith.andi(first_warp, elected)


@contextlib.contextmanager
def single_thread(per_block=True):
"""Runs the context only from a single thread.
@@ -244,16 +253,10 @@ def single_thread(per_block=True):
yield
return

warp = warp_idx()
if not per_block:
warp = arith.remui(warp, c(4, warp.type))
first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type))
elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
should_run = arith.andi(first_warp, elected)
if_op = scf.IfOp(should_run)
prev_scope = _ONCE_PER
_ONCE_PER = scope
try:
if_op = scf.IfOp(single_thread_predicate(per_block))
with ir.InsertionPoint(if_op.then_block):
yield
scf.YieldOp([])
@@ -610,14 +613,15 @@ def arrive(self):
i64 = ir.IntegerType.get_signless(64)
nvvm.mbarrier_arrive_shared(i64, self.get_ptr())

def arrive_expect_tx(self, bytes: int | ir.Value):
def arrive_expect_tx(
self, bytes: int | ir.Value, predicate: ir.Value | None = None
):
if isinstance(bytes, int):
bytes = c(bytes, ir.IntegerType.get_signless(32))
elif ir.IndexType.isinstance(bytes.type):
i32 = ir.IntegerType.get_signless(32)
bytes = arith.index_cast(i32, bytes)

nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes)
nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes, predicate=predicate)

def get_ptr(self):
ptr = ir.Type.parse("!llvm.ptr<3>")
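The utils.py refactor above extracts the thread-election logic into single_thread_predicate so it can be reused as a TMA predicate. Below is a pure-Python model of which thread the predicate selects; it models the behavior only (the real function emits NVVM ops), it assumes warp_idx() is the warp index within the block, and lane election via elect_sync is approximated as lane 0:

def single_thread_predicate_model(thread_idx, per_block=True):
  """True for exactly one thread: the elected lane of warp 0 of the block,
  or of each 4-warp warpgroup when per_block=False (the `warp % 4` case)."""
  warp = thread_idx // 32
  if not per_block:
    warp = warp % 4               # reduce to a warp index within the warpgroup
  first_warp = warp == 0
  elected = thread_idx % 32 == 0  # stand-in for nvvm.elect_sync
  return first_warp and elected

threads = range(256)  # a block made of two 128-thread warpgroups
assert sum(single_thread_predicate_model(t, per_block=True) for t in threads) == 1
assert sum(single_thread_predicate_model(t, per_block=False) for t in threads) == 2
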
