Skip to content

Commit

Permalink
[Mosaic GPU] Support tiled stores of arrays with fewer columns than swizzling
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 666798285
  • Loading branch information
apaszke authored and jax authors committed Aug 23, 2024
1 parent 71b7e78 commit be59f6e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
13 changes: 10 additions & 3 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,8 @@ def store_tiled(self, ref, swizzle: int | None):
assert m % 64 == 0 # This is implied by the layout.
cols_per_tile = swizzle // bw
expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
if n < cols_per_tile: # We allow singular tiles shorter than swizzle.
expected_shape = [m // 64, 1, 64, cols_per_tile]
if ir.MemRefType(ref.type).shape != expected_shape:
raise ValueError(ref.type, (m, n))
for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle):
Expand Down Expand Up @@ -715,9 +717,12 @@ def transfer_tiled(shape, dtype, swizzle: int | None):
# TODO(apaszke): We could use ldmatrix/stmatrix for 16-bit types.
bw = mgpu.bytewidth(dtype)
m, n = shape
cols_per_tile = swizzle // bw
if n % cols_per_tile != 0:
raise NotImplementedError
assert m % 64 == 0 and n % 8 == 0 # Implied by the layout.
cols_per_tile = swizzle_elems = swizzle // bw
if n < swizzle_elems:
cols_per_tile = n
else:
assert n % swizzle_elems == 0, (n, swizzle_elems)
if swizzle not in {32, 64, 128}:
raise NotImplementedError("Only swizzled stores supported")

Expand Down Expand Up @@ -752,6 +757,8 @@ def transfer_tiled(shape, dtype, swizzle: int | None):
case _:
raise AssertionError(swizzle)
stagger_amount = swizzle // 64
if (cols_per_tile // 8) % (stagger_amount + 1):
raise NotImplementedError
else:
# We rely on canonicalization to clean up the selects.
i1 = ir.IntegerType.get_signless(1)
Expand Down
31 changes: 31 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,37 @@ def kernel(ctx, out, smem):
)()
np.testing.assert_array_equal(iota, expected)

@parameterized.product(
    dtypes=(
        # Pairs of (MLIR element-type factory, matching JAX dtype).
        (ir.F16Type.get, jnp.float16),
        (partial(ir.IntegerType.get_signless, 8), jnp.int8),
    ),
    swizzle=(32, 64, 128),  # Swizzle width in bytes.
)
def test_store_tiled_short_n(self, dtypes, swizzle):
  """Stores a tiled array whose column count is smaller than the swizzle width.

  Exercises the new `store_tiled` path that allows a single tile narrower
  than `swizzle // bytewidth(dtype)` elements (see the matching change in
  fragmented_array.py in this commit).
  """
  mlir_dtype_cls, jax_dtype = dtypes
  mlir_dtype = mlir_dtype_cls()
  # Number of elements that one swizzle period spans for this dtype.
  col_tiling = swizzle // bytewidth(mlir_dtype)
  m = 128
  # Deliberately pick n (16 bytes worth of elements) smaller than
  # col_tiling for every tested swizzle, to hit the "short n" case.
  n = 16 // bytewidth(mlir_dtype)
  tiling = (64, col_tiling)
  def kernel(ctx, out, smem):
    # Write an iota pattern into tiled SMEM, then copy it out to GMEM.
    iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=swizzle)
    ctx.async_copy(
        src_ref=smem,
        dst_ref=out,
        swizzle=swizzle,
        # NOTE(review): the column slice uses col_tiling (a full swizzle
        # tile) even though out has only n columns — presumably async_copy
        # clamps/handles the short destination; confirm against its API.
        gmem_slice=(ds(0, m), ds(0, col_tiling)),
        gmem_transform=mosaic_gpu.TileTransform(tiling),
    )
    ctx.await_async_copy(0)
  # SMEM holds exactly one column tile (second dim == 1), matching the
  # relaxed expected_shape in store_tiled for n < cols_per_tile.
  smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), jax_dtype)
  expected = np.arange(m * n, dtype=jax_dtype).reshape(m, n)
  iota = mosaic_gpu.as_gpu_kernel(
      kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape
  )()
  np.testing.assert_array_equal(iota, expected)

@parameterized.named_parameters(
("bf16_i8",
ir.BF16Type.get, jnp.bfloat16,
Expand Down

0 comments on commit be59f6e

Please sign in to comment.