diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 92cd385be2a3..a64d4cc9f9b1 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -404,6 +404,8 @@ def gmem_slice(
           mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape)
       )
 
+    is_memory_thread = mgpu.single_thread_predicate(per_block=True)
+
     def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
       if not in_in_smem[idx]:
         return
@@ -419,36 +421,67 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
           swizzle=in_swizzles[idx],
          arrive=False,  # The caller must do ``arrive_expect_tx`` manually!
           uniform=False,
+          predicate=is_memory_thread,
       )
 
-    def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
+    def store(
+        idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value
+    ) -> ir.Value:
       if not out_in_smem[idx]:
         return
 
+      # We have to do some work to make sure that consecutive stores are not
+      # going to be writing to the same location, or else we'll end up with
+      # multiple concurrent writes and a racy program.
+      # TODO(apaszke,slebedev): In most cases output index maps depend only on
+      # parallel grid axes and in that case we can simply move the store to
+      # happen after the loop.
+      # TODO(apaszke,slebedev): This still diverges significantly from the TPU
+      # semantics in that it will move on to the next SMEM output slice even if
+      # it's not storing the previous one.
+      store_slice = gmem_slice(step, out_block_mappings[idx])
+      strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
+      base_offset = _as_index(0)
+      for stride, slc in zip(strides, store_slice):
+        base_offset = arith_dialect.addi(
+            base_offset, arith_dialect.muli(slc.base, _as_index(stride))
+        )
+      base_offset_changed = arith_dialect.cmpi(
+          arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
+      )
+      is_last_step = arith_dialect.cmpi(
+          arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
+      )
+      do_store = arith_dialect.andi(
+          is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
+      )
       # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
       launch_ctx.async_copy(
           src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
           dst_ref=out_buffers_gmem[idx],
-          gmem_slice=gmem_slice(step, out_block_mappings[idx]),
+          gmem_slice=store_slice,
           swizzle=None,
           uniform=False,
+          predicate=do_store,
       )
+      return base_offset
 
-    with mgpu.single_thread():
-      for slot in range(min(max_concurrent_steps, num_steps)):
-        barriers[slot].arrive_expect_tx(in_transfer_bytes)
-        for idx in range(grid_mapping.num_inputs):
-          fetch(idx, _as_index(slot), _as_index(slot))
+    for slot in range(min(max_concurrent_steps, num_steps)):
+      barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread)
+      for idx in range(grid_mapping.num_inputs):
+        fetch(idx, _as_index(slot), _as_index(slot))
 
-    @mgpu.fori(_as_index(num_steps), accs)
-    def _(step, accs):
+    last_store_offsets = [_as_index(-1)] * grid_mapping.num_outputs
+    @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets))
+    def _(step, carry):
+      accs, last_store_offsets = carry
       slot = arith_dialect.remui(step, _as_index(max_concurrent_steps))
       if grid_mapping.num_inputs:
         # Only wait if async copies were issued.
         barriers[slot].wait()
 
       # We need to make sure the output copy is complete before the kernel starts
       # writing to the output window.
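The base-offset check in store() is easier to see on plain integers. The sketch below is a plain-Python model of the rule "issue the TMA store only when the flattened output offset changes, or on the last step"; the helper names (flat_offset, should_store) and the strides/index map are made up for illustration, while the real code builds the same values as arith ops on MLIR index values.

def flat_offset(strides, slice_bases):
  # Same reduction as the arith.muli/arith.addi chain in store().
  return sum(stride * base for stride, base in zip(strides, slice_bases))

def should_store(step, num_steps, base_offset, prev_base_offset):
  # Store when the output window moved, or on the final step, so the last
  # revision of a repeatedly overwritten SMEM window still reaches GMEM.
  return base_offset != prev_base_offset or step == num_steps - 1

# Hypothetical output with strides (128, 1) whose index map ignores the
# sequential grid axis: every step maps to the same window. The first step
# stores (the carried offset starts at -1) and so does the last one; the
# steps in between are skipped, avoiding racy duplicate writes.
strides, num_steps, prev = (128, 1), 4, -1
for step in range(num_steps):
  base = flat_offset(strides, (0, 0))
  print(step, should_store(step, num_steps, base, prev))
  prev = base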
-      launch_ctx.await_async_copy(max_concurrent_steps - 1)
+      launch_ctx.await_async_copy(max_concurrent_steps - 1, await_read_only=True)
 
       args = [
           mgpu.memref_slice(buffers_smem[idx], slot)
@@ -468,23 +501,26 @@ def _(step, accs):
       new_accs = lower_jaxpr_to_mosaic_gpu(
           module_ctx, launch_ctx, lowered_jaxpr, args
       )
-      mgpu.commit_shared()
-      with mgpu.single_thread():
-        for idx in range(grid_mapping.num_outputs):
-          store(idx, step, slot)
+      # TODO(apaszke): Elide this if we're not going to perform any stores
+      mgpu.commit_shared()
+      new_store_offsets = []
+      for idx in range(grid_mapping.num_outputs):
+        new_store_offsets.append(
+            store(idx, step, slot, last_store_offsets[idx])
+        )
 
       next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps))
       next_step_in_bounds = arith_dialect.cmpi(
           arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps)
       )
       next_slot = slot  # (x + y) % y == x % y
-      with mgpu.when(next_step_in_bounds), mgpu.single_thread():
+      with mgpu.when(next_step_in_bounds):
+        barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread)
         for idx in range(grid_mapping.num_inputs):
           fetch(idx, next_step, next_slot)
-        barriers[slot].arrive_expect_tx(in_transfer_bytes)
 
-      return list(new_accs)
+      return list(new_accs), new_store_offsets
 
     launch_ctx.await_async_copy(0)
diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index 8057f97a7dfc..f5944c862480 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -51,6 +51,7 @@
     memref_unfold,
     memref_unsqueeze,
     single_thread,
+    single_thread_predicate,
     thread_idx,
     tile_shape,
     warp_idx,
diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index 9a03afeb4a8c..22a996efd64a 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -347,6 +347,7 @@ def async_copy(
       arrive: bool | None = None,
       uniform: bool = True,
       collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
+      predicate: ir.Value | None = None,
   ):
     index = ir.IndexType.get()
     i16 = ir.IntegerType.get_signless(16)
@@ -503,14 +504,17 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
       barrier_ptr = barrier.get_ptr()
       with uniform_ctx():
         if arrive:
-          nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes)
+          nvvm.mbarrier_arrive_expect_tx_shared(
+              barrier_ptr, transfer_bytes, predicate=predicate
+          )
         nvvm.cp_async_bulk_tensor_shared_cluster_global(
-            smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask,
+            smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [],
+            multicast_mask=multicast_mask, predicate=predicate
         )
     else:
       with uniform_ctx():
         nvvm.cp_async_bulk_tensor_global_shared_cta(
-            tma_desc, smem_ptr, rev_dyn_base_indices
+            tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
         )
         nvvm.cp_async_bulk_commit_group()
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index b5c22734b0d4..a59ddbea5565 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -229,6 +229,15 @@ class ThreadSubset(enum.IntEnum):
 _ONCE_PER: ThreadSubset | None = None
 
 
+def single_thread_predicate(per_block=True):
+  warp = warp_idx()
+  if not per_block:
+    warp = arith.remui(warp, c(4, warp.type))
+  first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type))
+  elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
+  return arith.andi(first_warp, elected)
+
+
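For intuition, the predicate returned by single_thread_predicate above is true for exactly one thread. Below is a plain-Python model of that selection (a hypothetical single_thread_predicate_model, assuming 32-thread warps and 4-warp warpgroups, and modelling nvvm.elect_sync as electing lane 0):

def single_thread_predicate_model(tid, per_block=True):
  warp, lane = tid // 32, tid % 32
  if not per_block:
    warp %= 4  # select one warp per 128-thread warpgroup instead of per block
  # arith.cmpi(eq, warp, 0) combined with nvvm.elect_sync, on plain integers.
  return warp == 0 and lane == 0

# In a 256-thread block: one thread passes per block, one per warpgroup.
assert sum(single_thread_predicate_model(t) for t in range(256)) == 1
assert sum(single_thread_predicate_model(t, per_block=False) for t in range(256)) == 2

Passing this value as predicate= to arrive_expect_tx and async_copy lets every thread execute the same straight-line MLIR, while only the elected thread actually issues the mbarrier arrival and the TMA copy, which is what the lowering uses in place of wrapping the copies in with mgpu.single_thread():.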
 @contextlib.contextmanager
 def single_thread(per_block=True):
   """Runs the context only from a single thread.
@@ -244,16 +253,10 @@ def single_thread(per_block=True):
     yield
     return
 
-  warp = warp_idx()
-  if not per_block:
-    warp = arith.remui(warp, c(4, warp.type))
-  first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type))
-  elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
-  should_run = arith.andi(first_warp, elected)
-  if_op = scf.IfOp(should_run)
   prev_scope = _ONCE_PER
   _ONCE_PER = scope
   try:
+    if_op = scf.IfOp(single_thread_predicate(per_block))
     with ir.InsertionPoint(if_op.then_block):
       yield
       scf.YieldOp([])
@@ -610,14 +613,15 @@ def arrive(self):
     i64 = ir.IntegerType.get_signless(64)
     nvvm.mbarrier_arrive_shared(i64, self.get_ptr())
 
-  def arrive_expect_tx(self, bytes: int | ir.Value):
+  def arrive_expect_tx(
+      self, bytes: int | ir.Value, predicate: ir.Value | None = None
+  ):
     if isinstance(bytes, int):
       bytes = c(bytes, ir.IntegerType.get_signless(32))
     elif ir.IndexType.isinstance(bytes.type):
       i32 = ir.IntegerType.get_signless(32)
       bytes = arith.index_cast(i32, bytes)
-
-    nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes)
+    nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes, predicate=predicate)
 
   def get_ptr(self):
     ptr = ir.Type.parse("!llvm.ptr<3>")
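Stepping back, the fori loop in the updated lowering still follows the same revolving-slot schedule: prefetch up to max_concurrent_steps steps, then on every step wait on the slot's barrier, run the kernel body, store the outputs (now predicated as above), and prefetch the step that will reuse the slot. A plain-Python sketch of that schedule, with made-up sizes, integers in place of MLIR index values, and the carried (accs, last_store_offsets) tuple elided:

max_concurrent_steps, num_steps = 2, 5  # illustrative sizes only

# Prologue: issue the first fetches, one expect-tx arrival per slot.
for slot in range(min(max_concurrent_steps, num_steps)):
  print(f"prefetch step {slot} -> slot {slot}")

for step in range(num_steps):
  slot = step % max_concurrent_steps
  print(f"wait barriers[{slot}]; run body for step {step}; store outputs")
  next_step = step + max_concurrent_steps
  if next_step < num_steps:
    # (x + y) % y == x % y, so the next fetch reuses the slot this step
    # just finished reading from.
    print(f"prefetch step {next_step} -> slot {slot}")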