Skip to content

Commit

Permalink
[Mosaic GPU] Add support for short n dimension in WGMMA
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666766079
  • Loading branch information
apaszke authored and jax authors committed Aug 23, 2024
1 parent c767875 commit f54e220
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 4 deletions.
2 changes: 2 additions & 0 deletions jax/experimental/mosaic/gpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ class DynamicSlice:
def memref_slice(ref: ir.Value, index) -> ir.Value:
ref_ty = ir.MemRefType(ref.type)
base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape)
# TODO(apaszke): Check that slice is within the memref (indices might be
# dynamic, but we can at least catch some OOB slices).

memref_strides, offset = ref_ty.get_strides_and_offset()
new_offset = offset
Expand Down
22 changes: 18 additions & 4 deletions jax/experimental/mosaic/gpu/wgmma.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ def wgmma_m64(
out_ty = ir.VectorType(acc.flat[0].type).element_type
if not _supported_wgmma_types(out_ty, element_type):
raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}")
if n % 8:
raise ValueError

i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
index = ir.IndexType.get()
if b_k_stride % 16:
raise ValueError
if n % (swizzle // bytewidth(element_type)):
raise ValueError
# Only 16-bit types support transposes
supports_transpose = bytewidth(element_type) == 2
if not supports_transpose and (a_transpose or b_transpose):
Expand Down Expand Up @@ -326,7 +326,15 @@ def wgmma(
kn_tile = swizzle // element_bytewidth

groups_k, groups_n = b_ty.shape[:2]
if b_ty.shape[2:] != [kn_tile, kn_tile]:
k_group_size, n_group_size = (
b_ty.shape[2:] if b_order == WGMMALayout.ROW_MAJOR else b_ty.shape[:1:-1]
)
# Note that while this technically allows n to be smaller than kn_tile,
# the stride checks below will still enforce that the memory region is padded.
# It might be possible to relax that requirement, but I haven't tested it.
if n_group_size > kn_tile and n_group_size % kn_tile:
raise ValueError(n_group_size, kn_tile)
if k_group_size != kn_tile:
raise ValueError(b_ty.shape)

if a_in_regs:
Expand All @@ -353,6 +361,12 @@ def wgmma(
if a_order == WGMMALayout.COL_MAJOR and swizzle != 128:
# Not sure what the layout is like, since the tiles aren't square.
raise NotImplementedError
expected_acc_shape = (groups_m * 64, groups_n * n_group_size)
if acc.value.shape != expected_acc_shape:
raise ValueError(
f"Accumulator shape mismatch: expected {expected_acc_shape}, got"
f" {acc.value.shape}"
)

row_major = WGMMALayout.ROW_MAJOR
col_major = WGMMALayout.COL_MAJOR
Expand All @@ -375,7 +389,7 @@ def wgmma(
b_transpose=b_order == row_major,
a_k_stride=(2 if a_order == row_major else 128) << 4,
b_k_stride=(swizzle if b_order == row_major else 2) << 4,
n=(groups_n * kn_tile),
n=(groups_n * n_group_size),
swizzle=swizzle,
element_type=ir.FloatTF32Type.get()
if ir.F32Type.isinstance(element_type)
Expand Down
56 changes: 56 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,62 @@ def kernel(ctx, rhs, out, rhs_smem):
rtol = 5e-4
np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)

@parameterized.product(
    rhs_transpose=(False, True),
    swizzle=(32, 64, 128),
)
def test_narrow_n(self, rhs_transpose, swizzle):
  """Tests WGMMA with an n dimension (8) narrower than one full swizzle tile.

  The RHS is staged into SMEM in full nk_tile x nk_tile tiles, but only the
  first n columns are sliced out and fed to the WGMMA, exercising the
  short-n support.
  """
  # n=8 is strictly smaller than nk_tile for every tested swizzle
  # (32//2=16, 64//2=32, 128//2=64).
  m, n, k_steps = 64, 8, 2

  row_major = mgpu.WGMMALayout.ROW_MAJOR
  col_major = mgpu.WGMMALayout.COL_MAJOR
  rhs_order = col_major if rhs_transpose else row_major
  bytewidth = 2  # f16 elements (see jax_dtype below).
  # Tile edge in elements implied by the swizzle byte width.
  nk_tile = swizzle // bytewidth
  k = nk_tile * k_steps

  def kernel(ctx, rhs, out, smem):
    rhs_smem, barrier = smem
    # Copy a full (k, nk_tile) RHS region GMEM->SMEM, but only slice the
    # first n columns of the tiled SMEM buffer when issuing the WGMMA.
    gmem_slice = (ds(0, k), ds(0, nk_tile))
    smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
    transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),)
    if rhs_transpose:
      # Transposed RHS: swap the GMEM slice dims, narrow the other minor
      # tile axis, and permute the two leading (tile-grid) dimensions.
      gmem_slice = gmem_slice[::-1]
      smem_slice = (slice(None), slice(None), ds(0, n), slice(None))
      transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),)
    ctx.async_copy(
        src_ref=rhs,
        dst_ref=rhs_smem,
        swizzle=swizzle,
        gmem_slice=gmem_slice,
        gmem_transform=transform,
        barrier=barrier,
    )
    barrier.wait()  # Wait for the async GMEM->SMEM copy to complete.
    init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
    # LHS is generated in registers; it must match the reference `x` below
    # (np.arange reshaped to (m, k)).
    lhs_regs = iota_tensor(m, k, ir.F16Type.get())
    rhs_smem = memref_slice(rhs_smem, smem_slice)
    acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle)
    nvvm.wgmma_commit_group_sync_aligned()
    # Block until the committed WGMMA group has finished before reading acc.
    nvvm.wgmma_wait_group_sync_aligned(0)
    acc.value.store_untiled(out)

  jax_dtype = jnp.float16
  y_shape = (n, k) if rhs_transpose else (k, n)
  y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
  out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
  # The SMEM scratch still holds full nk_tile x nk_tile tiles even though
  # only an n-wide slice of it is consumed by the WGMMA.
  rhs_scratch_shape = jax.ShapeDtypeStruct(
      (k_steps, 1, nk_tile, nk_tile), jax_dtype
  )
  z = mosaic_gpu.as_gpu_kernel(
      kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()),
  )(y)
  x = np.arange(m * k, dtype=jax_dtype).reshape(m, k)
  ref = jax.lax.dot(
      x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
  )
  np.testing.assert_allclose(z, ref, rtol=5e-4, atol=0)


class BarrierTest(TestCase):

Expand Down

0 comments on commit f54e220

Please sign in to comment.