[Pallas/MGPU] Fix a race in the pipelining code
We never checked if the output windows are done writing before we reused them.
Also, rename num_stages to max_concurrent_steps since we always only have 2 stages,
but might be running multiple iterations at a time.

Also fix the test for this that has been passing for reasons that I don't understand
(it didn't even write to all entries in the output??).

PiperOrigin-RevId: 679102765
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent 8599dbc commit 2d65abf
Showing 3 changed files with 22 additions and 16 deletions.
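The race fixed here comes from the way the output SMEM windows are multi-buffered: step i writes its output block into slot i % max_concurrent_steps and then issues an asynchronous SMEM-to-GMEM copy, and step i + max_concurrent_steps reuses the same slot. The new launch_ctx.await_async_copy(max_concurrent_steps - 1) call waits until at most max_concurrent_steps - 1 copies remain in flight, which guarantees that the oldest copy (the one still reading from the slot about to be overwritten) has completed. The plain-Python model below only illustrates that invariant; it is not the Mosaic GPU lowering itself.

from collections import deque

def simulate_pipeline(num_steps: int, max_concurrent_steps: int) -> None:
  # Slots whose async output copy (SMEM -> GMEM) has been issued but may not
  # have completed yet, oldest first.
  in_flight = deque()
  for step in range(num_steps):
    slot = step % max_concurrent_steps
    # The fix: drain copies until at most max_concurrent_steps - 1 remain,
    # mirroring launch_ctx.await_async_copy(max_concurrent_steps - 1).
    while len(in_flight) > max_concurrent_steps - 1:
      in_flight.popleft()  # the oldest copy completes; its slot is free again
    assert slot not in in_flight, "race: overwriting a slot still being copied out"
    # ... kernel body writes this step's output block into `slot` ...
    in_flight.append(slot)  # async copy of this slot's contents to GMEM starts

simulate_pipeline(num_steps=8, max_concurrent_steps=2)  # passes; drop the while loop to see the race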
jax/_src/pallas/mosaic_gpu/core.py: 6 changes (3 additions, 3 deletions)
@@ -41,13 +41,13 @@ class GPUCompilerParams(pallas_core.CompilerParams):
dimension of the kernel. Either "parallel" for dimensions that can
execute in any order, or "sequential" for dimensions that must be
executed sequentially.
- num_stages: The number of pipline stages in the kernel. Defaults to 1,
-   meaning no pipelining is done.
+ max_concurrent_steps: The maximum number of sequential stages that are
+   active concurrently. Defaults to 1.
"""
PLATFORM: ClassVar[str] = "mosaic_gpu"
approx_math: bool = False
dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
- num_stages: int = 1
+ max_concurrent_steps: int = 1


class GPUMemorySpace(enum.Enum):
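After this rename, the pipelining depth is requested through GPUCompilerParams(max_concurrent_steps=...). The sketch below shows how a kernel might pass it via pl.pallas_call; the block shapes, grid, output shape, and kernel body are invented for illustration and are not part of this commit.

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

@functools.partial(
    pl.pallas_call,
    in_specs=[pl.BlockSpec((128, 64), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((128, 64), lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct([256, 256], jnp.float32),
    grid=(2, 4),
    compiler_params=plgpu.GPUCompilerParams(
        dimension_semantics=["parallel", "sequential"],
        max_concurrent_steps=2,  # called num_stages before this commit
    ),
)
def add_one(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((256, 256), jnp.float32)
y = add_one(x)  # y == x + 1.0 everywhere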
jax/_src/pallas/mosaic_gpu/lowering.py: 22 changes (14 additions, 8 deletions)
@@ -243,7 +243,7 @@ def lower_jaxpr_to_module(
block = (128,) + (1,) * (len(grid) - 1)
params = compiler_params.get("mosaic_gpu", {})
approx_math = params.get("approx_math", False)
- num_stages = params.get("num_stages", 1)
+ max_concurrent_steps = params.get("max_concurrent_steps", 1)
dimension_semantics = params.get("dimension_semantics")
if dimension_semantics is None:
dimension_semantics = ["parallel"] * len(grid_mapping.grid)
@@ -272,7 +272,9 @@ def lower_jaxpr_to_module(
for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
in_structs_smem = [
- jax.ShapeDtypeStruct([num_stages, *bm.ref_aval.shape], bm.ref_aval.dtype)
+ jax.ShapeDtypeStruct(
+     [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
+ )
if in_smem
else None
for bm, in_smem in zip(
@@ -293,7 +295,7 @@
out_structs_gmem = [*grid_mapping.out_shapes]
# TODO(justinfu): Implement output Memref transforms
out_structs_smem = [
- jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype)
+ jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
if in_smem
else None
for bm, in_smem, s in zip(
@@ -449,17 +451,20 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
num_steps = 1

with mgpu.single_thread():
- for slot in range(min(num_stages, num_steps)):
+ for slot in range(min(max_concurrent_steps, num_steps)):
barriers[slot].arrive_expect_tx(in_transfer_bytes)
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))

@mgpu.fori(_as_index(num_steps), accs)
def _(step, accs):
- slot = arith_dialect.remui(step, _as_index(num_stages))
+ slot = arith_dialect.remui(step, _as_index(max_concurrent_steps))
if grid_mapping.num_inputs:
# Only wait if async copies were issued.
barriers[slot].wait()
+ # We need to make sure the output copy is complete before the kernel starts
+ # writing to the output window.
+ launch_ctx.await_async_copy(max_concurrent_steps - 1)

args = [
mgpu.memref_slice(buffers_smem[idx], slot)
@@ -485,13 +490,14 @@ def _(step, accs):
for idx in range(grid_mapping.num_outputs):
store(idx, step, slot)

- next_step = arith_dialect.addi(step, _as_index(num_stages))
+ next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps))
next_step_in_bounds = arith_dialect.cmpi(
arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps)
)
+ next_slot = slot # (x + y) % y == x % y
with mgpu.when(next_step_in_bounds), mgpu.single_thread():
for idx in range(grid_mapping.num_inputs):
- fetch(idx, next_step, slot)
+ fetch(idx, next_step, next_slot)
barriers[slot].arrive_expect_tx(in_transfer_bytes)

return list(new_accs)
@@ -540,7 +546,7 @@ def _(step, accs):
(*in_structs_smem, *out_structs_smem),
*extra_smem_scratch,
(
- mgpu.Barrier(arrival_count=1, num_barriers=num_stages),
+ mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
*extra_barriers,
),
),
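Putting the lowering changes together, the per-step schedule now looks roughly as follows. This is a host-side Python paraphrase of the emitted loop (the real code builds MLIR through the mgpu helpers and arith dialect), shown only to make the slot arithmetic and the new wait explicit.

def pipeline_schedule(num_steps: int, max_concurrent_steps: int, num_inputs: int):
  # Prologue: prefetch inputs for the first steps, one SMEM slot per step.
  for slot in range(min(max_concurrent_steps, num_steps)):
    print(f"arrive_expect_tx + fetch inputs of step {slot} into slot {slot}")

  for step in range(num_steps):
    slot = step % max_concurrent_steps
    if num_inputs:
      print(f"wait on barriers[{slot}] for step {step}'s inputs")
    # New in this commit: the output window in `slot` may still be feeding an
    # async copy issued max_concurrent_steps iterations ago; wait it out first.
    print(f"await_async_copy({max_concurrent_steps - 1}) before writing slot {slot}")
    print(f"run kernel body for step {step}; store outputs from slot {slot}")
    next_step = step + max_concurrent_steps
    if next_step < num_steps:
      # (step + max_concurrent_steps) % max_concurrent_steps == step % max_concurrent_steps,
      # so the prefetch for next_step lands in the slot we just finished with.
      print(f"fetch inputs of step {next_step} into slot {slot}")

pipeline_schedule(num_steps=4, max_concurrent_steps=2, num_inputs=1)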
tests/pallas/mosaic_gpu_test.py: 10 changes (5 additions, 5 deletions)
@@ -96,8 +96,8 @@ def kernel(x_ref, o_ref, scratch_ref):
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)

- @parameterized.product(num_stages=[1, 2, 3])
- def test_add_one_grid_pipelined(self, num_stages):
+ @parameterized.product(max_concurrent_steps=[1, 2, 3, 4])
+ def test_add_one_grid_pipelined(self, max_concurrent_steps):

@functools.partial(
pl.pallas_call,
@@ -106,9 +106,9 @@ def test_add_one_grid_pipelined(self, num_stages):
out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "sequential"],
- num_stages=num_stages,
+ max_concurrent_steps=max_concurrent_steps,
),
- grid=(2, 1),
+ grid=(2, 4),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
@@ -488,7 +488,7 @@ def kernel(a_ref, b_ref, o_ref, acc_ref):
grid=(grid_m, grid_n, grid_k),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "parallel", "sequential"],
- num_stages=2,
+ max_concurrent_steps=2,
),
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
