Skip to content

Commit

Permalink
[Pallas/Mosaic GPU] Disable inference of sequential axis shapes
Browse files Browse the repository at this point in the history
They should just be specified in the grid, so we don't need to do this. It's
also incorrect, because it's not guaranteed that each input is sliced in the
same dimension by the sequential axis.

PiperOrigin-RevId: 679114626
  • Loading branch information
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent a6b4648 commit 5788773
Showing 1 changed file with 1 addition and 20 deletions.
21 changes: 1 addition & 20 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,26 +404,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
"Multiple sequential axes are not supported in Mosaic GPU lowering."
)
[sequential_axis] = sequential_axes
if any(
b_gmem.shape[sequential_axis] % b_smem.shape[1 + sequential_axis]
for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem)
if b_smem
):
raise ValueError(
"Array dimensions along the sequential axis must be divisible by"
" the corresponding block dimensions."
)
num_steps, *rest = {
b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis]
for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem)
if b_smem
}
if rest:
raise ValueError(
"Array dimensions along the sequential axis must produce the same"
" number of steps when devided by the corresponding block"
" dimensions."
)
num_steps = grid_mapping.grid[sequential_axis]
else:
num_steps = 1

Expand Down

0 comments on commit 5788773

Please sign in to comment.