Skip to content

Commit

Permalink
[Mosaic GPU] Add support for WGMMA lhs in registers for swizzles other than 128
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 653626991
  • Loading branch information
apaszke authored and jax authors committed Jul 18, 2024
1 parent 47e6da3 commit a07b9ad
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
3 changes: 2 additions & 1 deletion jax/experimental/mosaic/gpu/wgmma.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def wgmma_m64(
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, 64):
# Column count must be equal to swizzle // bytewidth.
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, swizzle // 2):
raise ValueError("Unsupported A register array layout")
if a_k_stride is not None or a_transpose is not None:
raise ValueError("Unsupported WGMMA features with A in registers")
Expand Down
24 changes: 16 additions & 8 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,31 +613,39 @@ def quantize(x):
n=(64, 128, 192),
k_steps=(1, 2),
rhs_transpose=(False, True),
swizzle=(32, 64, 128),
mlir_dtype_cls=(ir.F16Type, ir.BF16Type),
)
def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, mlir_dtype_cls):
k = 64 * k_steps
def test_wgmma_reg_lhs(
self, m, n, k_steps, rhs_transpose, swizzle, mlir_dtype_cls
):
index = ir.IndexType.get()

row_major = mgpu.WGMMALayout.ROW_MAJOR
col_major = mgpu.WGMMALayout.COL_MAJOR
rhs_order = col_major if rhs_transpose else row_major
bytewidth = 2
nk_tile = swizzle // bytewidth
k = nk_tile * k_steps

def kernel(ctx, rhs, out, rhs_smem):
del ctx
for ki in range(k_steps):
for ni in range(n // 64):
rhs_slice = (ds(c(ki * 64, index), 64), ds(c(ni * 64, index), 64))
for ni in range(n // nk_tile):
rhs_slice = (
ds(c(ki * nk_tile, index), nk_tile),
ds(c(ni * nk_tile, index), nk_tile),
)
if rhs_transpose:
rhs_slice = rhs_slice[::-1]
copy(
src=memref_slice(rhs, rhs_slice),
dst=memref_slice(rhs_smem, (ki, ni)),
swizzle=128,
swizzle=swizzle,
)
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get())
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order)
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
acc.value.store_untiled(out)
Expand All @@ -647,7 +655,7 @@ def kernel(ctx, rhs, out, rhs_smem):
y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
scratch_shape = jax.ShapeDtypeStruct(
(k_steps, n // 64, 64, 64), jax_dtype
(k_steps, n // nk_tile, nk_tile, nk_tile), jax_dtype
)
z = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape
Expand All @@ -656,7 +664,7 @@ def kernel(ctx, rhs, out, rhs_smem):
ref = jax.lax.dot(
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
)
rtol = 0 if k_steps == 1 else 2.2e-4
rtol = 5e-4
np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)


Expand Down

0 comments on commit a07b9ad

Please sign in to comment.