diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0c6d61fa9c96..0211b549a49e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -404,26 +404,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: "Multiple sequential axes are not supported in Mosaic GPU lowering." ) [sequential_axis] = sequential_axes - if any( - b_gmem.shape[sequential_axis] % b_smem.shape[1 + sequential_axis] - for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) - if b_smem - ): - raise ValueError( - "Array dimensions along the sequential axis must be divisible by" - " the corresponding block dimensions." - ) - num_steps, *rest = { - b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis] - for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) - if b_smem - } - if rest: - raise ValueError( - "Array dimensions along the sequential axis must produce the same" - " number of steps when devided by the corresponding block" - " dimensions." - ) + num_steps = grid_mapping.grid[sequential_axis] else: num_steps = 1