Exposed sequential iteration index via pl.program_id in Pallas Mosaic GPU

PiperOrigin-RevId: 679261898
superbobry authored and Google-ML-Automation committed Sep 26, 2024
1 parent 6f7ad64 commit 27d2fac
Showing 2 changed files with 41 additions and 10 deletions.
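A minimal usage sketch (not part of this commit) of what the change enables: reading the index of a "sequential" (pipelined) grid dimension with pl.program_id inside a Pallas Mosaic GPU kernel. It mirrors the test added below, extended to also read the parallel index. The import aliases are an assumption based on common JAX Pallas usage, and running it requires a GPU that Mosaic GPU supports.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct([32, 64], jnp.int32),
    compiler_params=plgpu.GPUCompilerParams(
        dimension_semantics=["parallel", "sequential"],
        max_concurrent_steps=2,
    ),
    grid=(2, 4),
)
def kernel(o_ref):
  # Axis 0 is "parallel", axis 1 is "sequential"; both indices are readable now.
  o_ref[...] = jnp.broadcast_to(pl.program_id(0) * 4 + pl.program_id(1), o_ref.shape)


out = kernel()  # The (i, j) output block holds the value i * 4 + j.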
32 changes: 22 additions & 10 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -102,6 +102,7 @@ def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
class ModuleContext:
name: str
grid_mapping: pallas_core.GridMapping
+ program_ids: Sequence[ir.Value]
approx_math: bool
runtime_smem: ir.Value # ir.MemRefType
smem_used_bytes: int = 0
@@ -261,7 +262,6 @@ def lower_jaxpr_to_module(
raise NotImplementedError(
"Only <=3D grids are supported in Mosaic GPU lowering."
)
- # Compute the number of steps along each sequential axis.
if sequential_axes:
# TODO(slebedev): Support multiple sequential axes.
if len(sequential_axes) > 1:
@@ -340,14 +340,16 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
for i in range(len(grid_mapping.grid))
]
module_ctx = ModuleContext(
- name_and_src_info.name, grid_mapping, approx_math, runtime_smem
+ name_and_src_info.name, grid_mapping, (), approx_math, runtime_smem
)

smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
should_discharge = []
accs = []
for aval in scratch_avals:
+ if isinstance(aval.dtype, gpu_core.BarrierType):
+   continue
match aval:
case gpu_core.WGMMAAbstractAccumulatorRef():
scratch_buffers_template.append(None)
@@ -397,7 +399,6 @@ def gmem_slice(
step: ir.Value,
block_mapping: pallas_core.BlockMapping,
) -> Sequence[mgpu.DynamicSlice]:
- assert len(sequential_axes) == 1
program_ids = [step if i is None else i for i in program_ids_template]
idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping)
return tuple(
@@ -428,7 +429,7 @@ def store(
idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value
) -> ir.Value:
if not out_in_smem[idx]:
- return
+ return _as_index(-1)

# We have to do some work to make sure that consecutive stores are not
# going to be writing to the same location, or else we'll end up with
@@ -498,9 +499,19 @@ def _(step, carry):
# TODO(apaszke): This assumes barriers come after buffers in scratch args,
# but that's not necessarily true.
args.extend(extra_barriers)
- new_accs = lower_jaxpr_to_mosaic_gpu(
-     module_ctx, launch_ctx, lowered_jaxpr, args
- )
+ program_ids = [
+     arith_dialect.index_cast(ir.IntegerType.get_signless(32), step)
+     if i is None
+     else i
+     for i in program_ids_template
+ ]
+ step_module_ctx = dataclasses.replace(module_ctx, program_ids=program_ids)
+ with pallas_core.grid_env(
+     map(pallas_core.GridAxis, program_ids, grid_mapping.grid)
+ ):
+   new_accs = lower_jaxpr_to_mosaic_gpu(
+       step_module_ctx, launch_ctx, lowered_jaxpr, args
+   )

# TODO(apaszke): Elide this if we're not going to perform any stores
mgpu.commit_shared()
@@ -638,9 +649,10 @@ def write_env(var: jax_core.Var, val):

@register_lowering_rule(primitives.program_id_p)
def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
- # TODO(apaszke): Sequential axis should be handled specially!!
- del ctx  # Unused.
- return _program_id(axis)
+ try:
+   return ctx.module_ctx.program_ids[axis]
+ except IndexError:
+   return _i32_constant(0)


def _program_id(axis: int) -> ir.Value:
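Taken together, the two hunks above follow a simple pattern: the pipeline loop body materializes the per-step program ids (the loop step, cast to i32, for the sequential axis; block indices for the parallel axes), stores them on a per-step copy of ModuleContext via dataclasses.replace, and the program_id lowering rule then indexes into that tuple, falling back to a zero constant when no index is available for an axis. A plain-Python sketch of the pattern, with hypothetical stand-in names and ints in place of MLIR ir.Values:

import dataclasses
from typing import Sequence


@dataclasses.dataclass(frozen=True)
class FakeModuleContext:
  # Stand-in for ModuleContext; the real field holds ir.Values.
  program_ids: Sequence[int] = ()


def fake_program_id(ctx: FakeModuleContext, axis: int) -> int:
  # Mirrors the new _program_id_lowering_rule: indexed lookup, zero fallback.
  try:
    return ctx.program_ids[axis]
  except IndexError:
    return 0


base_ctx = FakeModuleContext()               # top-level context: no ids bound yet
assert fake_program_id(base_ctx, 1) == 0

for step in range(4):                        # the sequential (pipelined) axis
  step_ctx = dataclasses.replace(base_ctx, program_ids=(0, step))
  assert fake_program_id(step_ctx, 1) == step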
19 changes: 19 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -116,6 +116,25 @@ def kernel(x_ref, o_ref):
x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)

+ def test_add_one_grid_pipelined_program_id(self):
+
+   @functools.partial(
+       pl.pallas_call,
+       out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
+       out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32),
+       compiler_params=plgpu.GPUCompilerParams(
+           dimension_semantics=["parallel", "sequential"],
+           max_concurrent_steps=2,
+       ),
+       grid=(1, 4),
+   )
+   def kernel(o_ref):
+     o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape)
+
+   out = kernel()
+   for i in range(4):
+     np.testing.assert_array_equal(out[:, i * 16 : (i + 1) * 16], i)
+
def test_add_one_with_async_copy_smem_to_gmem(self):
@functools.partial(
pl.pallas_call,
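One way to run just the new test, assuming a pytest-based setup and a device that Mosaic GPU supports:

import pytest

# Equivalent to: pytest tests/pallas/mosaic_gpu_test.py -k test_add_one_grid_pipelined_program_id
pytest.main([
    "tests/pallas/mosaic_gpu_test.py",
    "-k", "test_add_one_grid_pipelined_program_id",
])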
