diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index bdb22442c6b7..0d0ac41d11e3 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -201,6 +201,11 @@ def _eval_index_map(
   return tuple(result)
 
 
+def _uses_arguments(cjaxpr: jax_core.ClosedJaxpr) -> list[bool]:
+  jaxpr = cjaxpr.jaxpr
+  return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1]
+
+
 def lower_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
     jaxpr: jax_core.Jaxpr,
@@ -270,8 +275,13 @@ def lower_jaxpr_to_module(
       )
     [sequential_axis] = sequential_axes
     num_steps = grid_mapping.grid[sequential_axis]
+    out_sequential_invariant = [
+        not _uses_arguments(bm.index_map_jaxpr)[sequential_axis]
+        for bm in grid_mapping.block_mappings_output
+    ]
   else:
     num_steps = 1
+    out_sequential_invariant = [True] * len(grid_mapping.out_shapes)
 
   in_in_smem, out_in_smem = util.split_list(
       [
@@ -429,36 +439,42 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
       )
 
     def store(
-        idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value
-    ) -> ir.Value:
+        idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value | None
+    ) -> ir.Value | None:
       if not out_in_smem[idx]:
         return _as_index(-1)
 
-      # We have to do some work to make sure that consecutive stores are not
-      # going to be writing to the same location, or else we'll end up with
-      # multiple concurrent writes and a racy program.
-      # TODO(apaszke,slebedev): In most cases output index maps depend only on
-      # parallel grid axes and in that case we can simply move the store to
-      # happen after the loop.
-      # TODO(apaszke,slebedev): This still diverges significantly from the TPU
-      # semantics in that it will move on to the next SMEM output slice even if
-      # it's not storing the previous one.
       store_slice = gmem_slice(step, out_block_mappings[idx])
-      strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
-      base_offset = _as_index(0)
-      for stride, slc in zip(strides, store_slice):
-        base_offset = arith_dialect.addi(
-            base_offset, arith_dialect.muli(slc.base, _as_index(stride))
+      if out_sequential_invariant[idx]:
+        assert prev_base_offset is None
+        do_store = None  # Lack of predicate defaults to True.
+        base_offset = None
+      else:
+        assert prev_base_offset is not None
+        # We have to do some work to make sure that consecutive stores are not
+        # going to be writing to the same location, or else we'll end up with
+        # multiple concurrent writes and a racy program.
+        # TODO(apaszke,slebedev): In most cases output index maps depend only on
+        # parallel grid axes and in that case we can simply move the store to
+        # happen after the loop.
+        # TODO(apaszke,slebedev): This still diverges significantly from the TPU
+        # semantics in that it will move on to the next SMEM output slice even if
+        # it's not storing the previous one.
+        strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset()
+        base_offset = _as_index(0)
+        for stride, slc in zip(strides, store_slice):
+          base_offset = arith_dialect.addi(
+              base_offset, arith_dialect.muli(slc.base, _as_index(stride))
+          )
+        base_offset_changed = arith_dialect.cmpi(
+            arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
+        )
+        is_last_step = arith_dialect.cmpi(
+            arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
+        )
+        do_store = arith_dialect.andi(
+            is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
         )
-      base_offset_changed = arith_dialect.cmpi(
-          arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset
-      )
-      is_last_step = arith_dialect.cmpi(
-          arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1)
-      )
-      do_store = arith_dialect.andi(
-          is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step)
-      )
       # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
       launch_ctx.async_copy(
           src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
@@ -475,7 +491,7 @@ def store(
       for idx in range(grid_mapping.num_inputs):
         fetch(idx, _as_index(slot), _as_index(slot))
 
-    last_store_offsets = [_as_index(-1)] * grid_mapping.num_outputs
+    last_store_offsets = [None if inv else _as_index(-1) for inv in out_sequential_invariant]
     @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets))
     def _(step, carry):
       accs, last_store_offsets = carry
@@ -510,8 +526,11 @@ def _(step, carry):
       mgpu.commit_shared()
       new_store_offsets = []
       for idx in range(grid_mapping.num_outputs):
+        last_offset = last_store_offsets[idx]
         new_store_offsets.append(
-            store(idx, step, slot, last_store_offsets[idx])
+            store(idx, step, slot, last_offset)
+            if not out_sequential_invariant[idx]
+            else last_offset  # Only store if the output can depend on the step.
         )
 
       next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps))
@@ -526,6 +545,13 @@ def _(step, carry):
 
       return list(new_accs), new_store_offsets
 
+    # Outputs invariant to the sequential axis are never written from inside the
+    # loop. This is the only place where we store them.
+    last_slot = _as_index((num_steps - 1) % max_concurrent_steps)
+    for idx in range(grid_mapping.num_outputs):
+      if out_sequential_invariant[idx]:
+        store(idx, _as_index(0), last_slot, None)
+
     launch_ctx.await_async_copy(0)
 
   scratch_avals = [
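Note on the new helper (not part of the patch): the minimal sketch below mirrors _uses_arguments to show how dead-code elimination reveals which grid indices an output index map actually reads, which is what decides out_sequential_invariant. The out_index_map function is a hypothetical example; the only APIs assumed are jax.make_jaxpr and the internal pe.dce_jaxpr, the latter being the same helper the patch relies on.

import jax
from jax._src import core as jax_core
from jax._src.interpreters import partial_eval as pe


def uses_arguments(cjaxpr: jax_core.ClosedJaxpr) -> list[bool]:
  # Same idea as _uses_arguments above: DCE with every output kept returns a
  # per-input "used" mask as its second result.
  jaxpr = cjaxpr.jaxpr
  return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1]


def out_index_map(i, j, k):
  # Hypothetical output index map that ignores the sequential axis k.
  del k
  return i, j


closed = jax.make_jaxpr(out_index_map)(0, 0, 0)
print(uses_arguments(closed))  # [True, True, False] -> invariant along axis 2

An output whose mask entry for the sequential axis is False never changes its GMEM destination across steps, so its SMEM-to-GMEM copy can be issued once after the loop instead of being predicated on a base-offset change at every step.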