Skip to content

Commit

Permalink
[Mosaic GPU] Allow tile sizes to exceed dimension size
Browse files Browse the repository at this point in the history
Otherwise, the dimension size still needs to be a multiple of tiling.

PiperOrigin-RevId: 666298624
  • Loading branch information
apaszke authored and jax authors committed Aug 22, 2024
1 parent 4786930 commit 0b4f64e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
14 changes: 11 additions & 3 deletions jax/experimental/mosaic/gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,13 @@ def apply(self, ref: ir.Value) -> ir.Value:
tiling_rank = len(self.tiling)
tiled_rank = untiled_rank + tiling_rank
for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]):
ref = utils.memref_unfold(ref, d, (None, t))
s = ir.MemRefType(ref.type).shape[d]
if s % t and s > t:
raise ValueError(
f"Dimension {d} must have size smaller or a multiple of its tiling"
f" {t}, but got {s}"
)
ref = utils.memref_unfold(ref, d, (None, min(t, s)))
permutation = (
*range(untiled_rank - tiling_rank),
*range(untiled_rank - tiling_rank, tiled_rank, 2),
Expand Down Expand Up @@ -175,8 +181,10 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
for size, tile_size in zip(shape[-tiling_rank:], self.tiling):
if size % tile_size:
raise ValueError(
f"Expected GMEM slice shape {shape} suffix to be a multiple"
f" of tiling {self.tiling}"
f"Expected GMEM slice shape {shape} suffix to be a multiple of"
f" tiling {self.tiling}.\nIf you're using padded async copies, your"
" slice might need to extend out of bounds of the GMEM buffer (OOB"
" accesses will be skipped)."
)
return (
*shape[:-tiling_rank],
Expand Down
67 changes: 67 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,73 @@ def kernel(ctx, src, dst, tmp):
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
np.testing.assert_array_equal(y, x)

@parameterized.parameters(0, 1)
def test_tma_small_tile_load(self, small_dim):
if small_dim == 0:
shape = (4, 128)
elif small_dim == 1:
shape = (128, 8)
else:
raise ValueError("small_dim must be 0 or 1")
tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64)
padded_shape = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2]))
def kernel(ctx, src, dst, smem):
tmp, barrier = smem
ctx.async_copy(
src_ref=src,
dst_ref=tmp,
swizzle=128,
gmem_transform=mosaic_gpu.TileTransform((64, 64)),
gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])),
barrier=barrier,
)
barrier.wait()
copy(tmp, dst, swizzle=128)
x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape)
tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16)
y_tiled = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, tiled, (tiled, mgpu.TMABarrier()),
)(x)
y = y_tiled.swapaxes(1, 2).reshape(padded_shape)
# y should contain x and zero everywhere else.
np.testing.assert_array_equal(y[:shape[0], :shape[1]], x)
y_mut = np.asarray(y).copy()
y_mut[:shape[0], :shape[1]] = 0
np.testing.assert_array_equal(y_mut, np.zeros_like(y_mut))

@parameterized.parameters(0, 1)
def test_tma_small_tile_store(self, small_dim):
if small_dim == 0:
shape = (4, 128)
elif small_dim == 1:
shape = (128, 8)
else:
raise ValueError("small_dim must be 0 or 1")
tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64)
padded_shape = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2]))
def kernel(ctx, dst, tmp):
vals = iota_tensor(
m=padded_shape[0], n=padded_shape[1], mlir_dtype=ir.F16Type.get()
)
vals.store_tiled(tmp, swizzle=128)
ctx.async_copy(
src_ref=tmp,
dst_ref=dst,
swizzle=128,
gmem_transform=mosaic_gpu.TileTransform((64, 64)),
gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])),
)
ctx.await_async_copy(0)
tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16)
out = jax.ShapeDtypeStruct(shape, jnp.float16)
y = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out, tiled,
)()
iota = np.arange(np.prod(padded_shape), dtype=jnp.float16).reshape(
padded_shape
)
np.testing.assert_array_equal(y, iota[:shape[0], :shape[1]])

def test_tma_invalid(self):
def kernel(ctx, src, dst, tmp):
copy(src, tmp)
Expand Down

0 comments on commit 0b4f64e

Please sign in to comment.