diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 406f5eaba3b2..892cd2d09332 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -686,6 +686,8 @@ def store_tiled(self, ref, swizzle: int | None): assert m % 64 == 0 # This is implied by the layout. cols_per_tile = swizzle // bw expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] + if n < cols_per_tile: # We allow singular tiles shorter than swizzle. + expected_shape = [m // 64, 1, 64, cols_per_tile] if ir.MemRefType(ref.type).shape != expected_shape: raise ValueError(ref.type, (m, n)) for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle): @@ -715,9 +717,12 @@ def transfer_tiled(shape, dtype, swizzle: int | None): # TODO(apaszke): We could use ldmatrix/stmatrix for 16-bit types. bw = mgpu.bytewidth(dtype) m, n = shape - cols_per_tile = swizzle // bw - if n % cols_per_tile != 0: - raise NotImplementedError + assert m % 64 == 0 and n % 8 == 0 # Implied by the layout. + cols_per_tile = swizzle_elems = swizzle // bw + if n < swizzle_elems: + cols_per_tile = n + else: + assert n % swizzle_elems == 0, (n, swizzle_elems) if swizzle not in {32, 64, 128}: raise NotImplementedError("Only swizzled stores supported") @@ -752,6 +757,8 @@ def transfer_tiled(shape, dtype, swizzle: int | None): case _: raise AssertionError(swizzle) stagger_amount = swizzle // 64 + if (cols_per_tile // 8) % (stagger_amount + 1): + raise NotImplementedError else: # We rely on canonicalization to clean up the selects. i1 = ir.IntegerType.get_signless(1) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index b7f99ab7b290..ce1c02f5a01b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -423,6 +423,37 @@ def kernel(ctx, out, smem): )() np.testing.assert_array_equal(iota, expected) + @parameterized.product( + dtypes=( + (ir.F16Type.get, jnp.float16), + (partial(ir.IntegerType.get_signless, 8), jnp.int8), + ), + swizzle=(32, 64, 128), + ) + def test_store_tiled_short_n(self, dtypes, swizzle): + mlir_dtype_cls, jax_dtype = dtypes + mlir_dtype = mlir_dtype_cls() + col_tiling = swizzle // bytewidth(mlir_dtype) + m = 128 + n = 16 // bytewidth(mlir_dtype) + tiling = (64, col_tiling) + def kernel(ctx, out, smem): + iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=swizzle) + ctx.async_copy( + src_ref=smem, + dst_ref=out, + swizzle=swizzle, + gmem_slice=(ds(0, m), ds(0, col_tiling)), + gmem_transform=mosaic_gpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), jax_dtype) + expected = np.arange(m * n, dtype=jax_dtype).reshape(m, n) + iota = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape + )() + np.testing.assert_array_equal(iota, expected) + @parameterized.named_parameters( ("bf16_i8", ir.BF16Type.get, jnp.bfloat16,