[Pallas/MGPU] Fix a race in the pipelining code
We never checked whether the output windows had finished being copied out before we reused them.
Also, rename num_stages to max_concurrent_steps since we always have only 2 stages,
but might be running multiple iterations at a time.

Also fix the test for this, which had been passing for reasons I don't understand
(it didn't even write to all entries in the output??).

PiperOrigin-RevId: 679102765
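
For reference, a minimal sketch of the renamed parameter in use, modeled on the pipelined test in this commit; the BlockSpecs, shapes, and grid are illustrative assumptions, not part of the diff:

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

@functools.partial(
    pl.pallas_call,
    in_specs=[pl.BlockSpec((128, 64), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((128, 64), lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct([128 * 2, 64 * 4], jnp.float32),
    compiler_params=plgpu.GPUCompilerParams(
        dimension_semantics=["parallel", "sequential"],
        max_concurrent_steps=2,  # formerly num_stages=2
    ),
    grid=(2, 4),
)
def add_one(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(128 * 2 * 64 * 4, dtype=jnp.float32).reshape(128 * 2, 64 * 4)
np.testing.assert_array_equal(add_one(x), x + 1.0)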
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent 5788773 commit 86bb409
Showing 4 changed files with 137 additions and 30 deletions.
6 changes: 3 additions & 3 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -41,13 +41,13 @@ class GPUCompilerParams(pallas_core.CompilerParams):
dimension of the kernel. Either "parallel" for dimensions that can
execute in any order, or "sequential" for dimensions that must be
executed sequentially.
num_stages: The number of pipeline stages in the kernel. Defaults to 1,
meaning no pipelining is done.
max_concurrent_steps: The maximum number of sequential stages that are
active concurrently. Defaults to 1.
"""
PLATFORM: ClassVar[str] = "mosaic_gpu"
approx_math: bool = False
dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
num_stages: int = 1
max_concurrent_steps: int = 1


class GPUMemorySpace(enum.Enum):
88 changes: 73 additions & 15 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -55,6 +55,7 @@
zip, unsafe_zip = util.safe_zip, zip

partial = functools.partial
SMEM = gpu_core.SMEM

_smem_estimators = {}

@@ -242,7 +243,7 @@ def lower_jaxpr_to_module(
block = (128,) + (1,) * (len(grid) - 1)
params = compiler_params.get("mosaic_gpu", {})
approx_math = params.get("approx_math", False)
num_stages = params.get("num_stages", 1)
max_concurrent_steps = params.get("max_concurrent_steps", 1)
dimension_semantics = params.get("dimension_semantics")
if dimension_semantics is None:
dimension_semantics = ["parallel"] * len(grid_mapping.grid)
@@ -271,7 +272,9 @@ def lower_jaxpr_to_module(
for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
in_structs_smem = [
jax.ShapeDtypeStruct([num_stages, *bm.ref_aval.shape], bm.ref_aval.dtype)
jax.ShapeDtypeStruct(
[max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
)
if in_smem
else None
for bm, in_smem in zip(
@@ -292,7 +295,7 @@ def lower_jaxpr_to_module(
out_structs_gmem = [*grid_mapping.out_shapes]
# TODO(justinfu): Implement output Memref transforms
out_structs_smem = [
jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype)
jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
if in_smem
else None
for bm, in_smem, s in zip(
@@ -330,6 +333,45 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
start_indices, [grid_mapping.num_inputs]
)

smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
should_discharge = []
accs = []
for aval in scratch_avals:
match aval:
case gpu_core.WGMMAAbstractAccumulatorRef():
scratch_buffers_template.append(None)
should_discharge.append(True)
accs.append(
mgpu.WGMMAAccumulator.zero(
*aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype)
)
)
case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM:
scratch_buffers_template.append(next(smem_scratch_it))
should_discharge.append(False)
case _:
raise NotImplementedError(
f"Unsupported scratch operand type: {aval}"
)
assert not jaxpr.outvars
if any(should_discharge):
# User-visible WGMMA APIs use the effectful accumulator references, but we
# can't lower that directly to Mosaic GPU that uses pure dataflow for
# accumulators. So we have to discharge the effects first.
assert not jaxpr.constvars
should_discharge = (
[False] * len(grid_mapping.block_mappings)
+ should_discharge
+ [False] * len(extra_barriers)
)
with grid_mapping.trace_env():
lowered_jaxpr, _ = discharge.discharge_state(
jaxpr, (), should_discharge=should_discharge
)
else:
lowered_jaxpr = jaxpr

# Precompute the total number of bytes transferred from GMEM to SMEM,
# so that we can do a single arrive instruction for all of the inputs.
in_transfer_bytes = 0
@@ -409,56 +451,71 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
num_steps = 1

with mgpu.single_thread():
for slot in range(min(num_stages, num_steps)):
for slot in range(min(max_concurrent_steps, num_steps)):
barriers[slot].arrive_expect_tx(in_transfer_bytes)
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))

@mgpu.fori(_as_index(num_steps), ())
def _(step, _):
slot = arith_dialect.remui(step, _as_index(num_stages))
@mgpu.fori(_as_index(num_steps), accs)
def _(step, accs):
slot = arith_dialect.remui(step, _as_index(max_concurrent_steps))
if grid_mapping.num_inputs:
# Only wait if async copies were issued.
barriers[slot].wait()
# We need to make sure the output copy is complete before the kernel starts
# writing to the output window.
launch_ctx.await_async_copy(max_concurrent_steps - 1)

args = [
mgpu.memref_slice(buffers_smem[idx], slot)
if in_smem
else buffers_gmem[idx]
for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem))
]
args.extend(scratch_buffers_smem)
accs_it = iter(accs)
scratch_buffers = [
b if b is not None else next(accs_it)
for b in scratch_buffers_template
]
args.extend(scratch_buffers)
# TODO(apaszke): This assumes barriers come after buffers in scratch args,
# but that's not necessarily true.
args.extend(extra_barriers)
_ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args)
new_accs = lower_jaxpr_to_mosaic_gpu(
module_ctx, launch_ctx, lowered_jaxpr, args
)
mgpu.commit_shared()

with mgpu.single_thread():
for idx in range(grid_mapping.num_outputs):
store(idx, step, slot)

next_step = arith_dialect.addi(step, _as_index(num_stages))
next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps))
next_step_in_bounds = arith_dialect.cmpi(
arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps)
)
next_slot = slot # (x + y) % y == x % y
with mgpu.when(next_step_in_bounds), mgpu.single_thread():
for idx in range(grid_mapping.num_inputs):
fetch(idx, next_step, slot)
fetch(idx, next_step, next_slot)
barriers[slot].arrive_expect_tx(in_transfer_bytes)

return ()
return list(new_accs)

launch_ctx.await_async_copy(0)

scratch_avals = [
var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops]
]
local_spaces = (gpu_core.SMEM, gpu_core.REGS)
if not all(
isinstance(aval, pallas_core.AbstractMemoryRef)
and aval.memory_space is gpu_core.SMEM
and aval.memory_space in local_spaces
for aval in scratch_avals
):
raise TypeError(
f"All scratch operands must be in SMEM, but got: {scratch_avals}"
"All scratch operands must be SMEM references or accumulators (ACC),"
f" but got: {scratch_avals}"
)
extra_barriers = [
mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
@@ -469,6 +526,7 @@ def _(step, _):
jax.ShapeDtypeStruct(aval.shape, aval.dtype)
for aval in scratch_avals
if not isinstance(aval.dtype, gpu_core.BarrierType)
and aval.memory_space == gpu_core.SMEM
]
smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
if smem_scratch_bytes is None:
@@ -488,7 +546,7 @@ def _(step, _):
(*in_structs_smem, *out_structs_smem),
*extra_smem_scratch,
(
mgpu.Barrier(arrival_count=1, num_barriers=num_stages),
mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
*extra_barriers,
),
),
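
To summarize the fix, here is a Python-level sketch of the schedule the lowering now emits; fetch, store, wait_for_input, and run_kernel are hypothetical stand-ins for the Mosaic GPU operations above, and only await_async_copy mirrors a real call (launch_ctx.await_async_copy):

def pipelined_schedule(num_steps, max_concurrent_steps):
    # Prologue: prefetch up to max_concurrent_steps input windows.
    for slot in range(min(max_concurrent_steps, num_steps)):
        fetch(step=slot, slot=slot)
    for step in range(num_steps):
        slot = step % max_concurrent_steps
        wait_for_input(slot)  # barriers[slot].wait()
        # The fix: allow at most max_concurrent_steps - 1 output copies to
        # remain in flight, so the copy out of this slot has completed before
        # the kernel overwrites the output window.
        await_async_copy(max_concurrent_steps - 1)
        run_kernel(step, slot)
        store(step, slot)  # Async copy SMEM -> GMEM.
        next_step = step + max_concurrent_steps
        if next_step < num_steps:
            fetch(step=next_step, slot=slot)  # (x + y) % y == x % y
    await_async_copy(0)  # Epilogue: drain all remaining output copies.

The new accumulator handling follows the usual state-discharge pattern: an effectful accumulator reference on the Pallas side becomes a pure loop-carried value for Mosaic GPU. A toy analogy of that transformation, not the real API:

def effectful(acc_ref, x):  # Pallas view: the accumulator is a mutable ref.
    acc_ref[...] += x

def discharged(acc, x):  # Mosaic GPU view: pure dataflow, carried by fori.
    return acc + x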
9 changes: 1 addition & 8 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -25,7 +25,6 @@
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import lowering
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp

async_copy_p = jax_core.Primitive("async_copy")
async_copy_p.multiple_results = True
@@ -120,7 +119,7 @@ class _WGMMAPipelineEffect(effects.Effect):
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
"""Asynchronous warp group matmul.
The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
@@ -137,12 +136,6 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

rhs_transpose = (
(jnp.dtype(b.dtype).itemsize == 2)
if rhs_transpose is None
else rhs_transpose
)

ma, ka, tma, tka = a.shape
kb, nb, tkb, tnb = b.shape
mc, nc = acc.shape
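
The signature change above also affects call sites: rhs_transpose used to be inferred (True for 16-bit element types) when left as None, but now defaults to False, so callers must opt in explicitly. A hedged sketch, with ref names borrowed from the test below:

plgpu.wgmma(acc_ref, a_ref, b_ref)  # No transpose: the new default.
plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=True)  # Now an explicit opt-in.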
64 changes: 60 additions & 4 deletions tests/pallas/mosaic_gpu_test.py
@@ -96,8 +96,8 @@ def kernel(x_ref, o_ref, scratch_ref):
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)

@parameterized.product(num_stages=[1, 2, 3])
def test_add_one_grid_pipelined(self, num_stages):
@parameterized.product(max_concurrent_steps=[1, 2, 3, 4])
def test_add_one_grid_pipelined(self, max_concurrent_steps):

@functools.partial(
pl.pallas_call,
@@ -106,9 +106,9 @@ def test_add_one_grid_pipelined(self, num_stages):
out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "sequential"],
num_stages=num_stages,
max_concurrent_steps=max_concurrent_steps,
),
grid=(2, 1),
grid=(2, 4),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
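
The grid change from (2, 1) to (2, 4) is what makes this test meaningful: with four steps along the sequential dimension, output slots are actually reused across iterations and the race is exercised. An illustrative sketch of the slot rotation (not part of the test):

for max_concurrent_steps in (1, 2, 3, 4):
    slots = [step % max_concurrent_steps for step in range(4)]
    # E.g. max_concurrent_steps=2 gives [0, 1, 0, 1]: steps 2 and 3 reuse
    # windows whose previous contents may still be copying out to GMEM.
    print(max_concurrent_steps, slots)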
@@ -437,6 +437,62 @@ def kernel(a_ref, b_ref):
)(a)
np.testing.assert_array_equal(b, np.ones_like(a))

def test_realistic_matmul(self):
dtype = jnp.float16
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
# TODO(apaszke): Make the grid and tile sizes larger
# grid_m, grid_k, grid_n = 132, 10, 4
# TODO(apaszke): Increasing grid_k causes the test to fail.
# It seems like our pipelining implementation has a number of races.
grid_m, grid_k, grid_n = 2, 1, 2
# tile_m = tile_n = 128
tile_m = tile_n = 64
tile_k = elems_128b
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
def kernel(a_ref, b_ref, o_ref, acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref)
plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races
# TODO(apaszke): Only store in the last step. It doesn't work because we
# don't have partial discharge for control flow.
# is_last_step = pl.program_id(2) == grid_k - 1
# @pl.when(is_last_step)
# def _epilogue():
# pl.debug_print("{}", acc_ref[...])
# TODO(apaszke): This is an untiled store! It's slow!!
o_ref[...] = acc_ref[...]

key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)

res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(tile_m, tile_k),
lambda m, n, k: (m, k),
transforms=plgpu.TilingTransform((64, elems_128b)),
swizzle=128,
),
plgpu.GPUBlockSpec(
(tile_k, tile_n),
lambda m, n, k: (k, n),
transforms=plgpu.TilingTransform((elems_128b, elems_128b)),
swizzle=128,
),
],
out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)),
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
grid=(grid_m, grid_n, grid_k),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "parallel", "sequential"],
max_concurrent_steps=2,
),
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)


if __name__ == "__main__":
absltest.main()
