Skip to content

Commit

Permalink
[Mosaic GPU] Add control over the output format in the matmul example
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662478648
  • Loading branch information
apaszke authored and jax authors committed Aug 13, 2024
1 parent 5cf89b3 commit f4c0b1f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 19 deletions.
47 changes: 29 additions & 18 deletions jax/experimental/mosaic/gpu/examples/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def wrap(*args, **kw):
@mlir_context
def build_kernel(
m, n, k,
lhs_dtype, rhs_dtype,
lhs_dtype, rhs_dtype, out_dtype,
stages: int = 2,
tile_m: int = 128,
tile_n: int = 128,
Expand All @@ -134,15 +134,20 @@ def build_kernel(
if swizzle not in {32, 64, 128}:
raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}")

if tile_n % 32 == 0:
out_swizzle = 128
elif tile_n % 16 == 0:
out_swizzle = 64
else:
raise NotImplementedError(f"{tile_n=} must by divisible by 16")
out_swizzle_elems = out_swizzle // bytewidth(f32)
out_mlir_dtype = mlir.dtype_to_ir_type(out_dtype)
out_swizzle = swizzle
if bytewidth(out_mlir_dtype) == 4:
if tile_n % 32 == 0:
out_swizzle = 128
elif tile_n % 16 == 0:
out_swizzle = 64
else:
raise NotImplementedError(
f"{tile_n=} must by divisible by 16 for 32-bit output"
)
out_swizzle_elems = out_swizzle // bytewidth(out_mlir_dtype)
out_tiling = (64, out_swizzle_elems)
out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), jnp.float32)
out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), out_dtype)

lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype))
rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype))
Expand Down Expand Up @@ -271,7 +276,7 @@ def stage_loop_body(ki, accs):

with ctx.named_region("SMEM store"):
acc_val = wgmma_impl.get_result(stage_loop_body.result)
acc_val.store_tiled(epilogue_smem, swizzle=out_swizzle)
acc_val.astype(out_mlir_dtype).store_tiled(epilogue_smem, swizzle=out_swizzle)
commit_shared() # Make sure the stores are visible to TMA.

with ctx.named_region("GMEM store"):
Expand All @@ -292,7 +297,7 @@ def stage_loop_body(ki, accs):
jax.ShapeDtypeStruct((m, k), lhs_dtype),
jax.ShapeDtypeStruct((n, k) if rhs_transpose else (k, n), rhs_dtype),
),
jax.ShapeDtypeStruct((m, n), jnp.float32),
jax.ShapeDtypeStruct((m, n), out_dtype),
(
smem_shape,
TMABarrier(num_barriers=stages),
Expand All @@ -318,6 +323,7 @@ def verify(
swizzle=128,
profile=False,
in_dtype=jnp.float16,
out_dtype=jnp.float32,
rhs_transpose=False,
):
lhs_dtype, rhs_dtype = in_dtype, in_dtype
Expand All @@ -329,7 +335,7 @@ def verify(
prof_spec = profiler.ProfilerSpec(4096) if profile else None
f = build_kernel(
m, n, k,
jnp.dtype(lhs_dtype), jnp.dtype(rhs_dtype),
jnp.dtype(lhs_dtype), jnp.dtype(rhs_dtype), jnp.dtype(out_dtype),
stages=stages,
tile_m=tile_m,
tile_n=tile_n,
Expand All @@ -352,14 +358,19 @@ def verify(
for v in (x, y)
)

ref_f = functools.partial(
jax.lax.dot_general,
dimension_numbers=dimension_numbers,
preferred_element_type=jnp.float32,
)
@jax.jit
def ref_f(x, y):
return jax.lax.dot_general(
x,
y,
dimension_numbers=dimension_numbers,
preferred_element_type=jnp.float32,
).astype(out_dtype)

ref, ref_runtime = profiler.measure(ref_f, x, y)
np.testing.assert_allclose(z, ref, atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(
z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
)
return runtime, ref_runtime


Expand Down
9 changes: 8 additions & 1 deletion tests/mosaic/matmul_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,14 @@ def setUp(self):
def test_matmul(self, data):
in_dtype = data.draw(
hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]),
label="dtype",
label="in_dtype",
)
out_dtype = jnp.float32
if in_dtype != jnp.float32:
out_dtype = data.draw(
hps.sampled_from([in_dtype, jnp.float32]),
label="out_dtype",
)
bytewidth = jnp.dtype(in_dtype).itemsize
m, n, k = (
data.draw(hps.sampled_from([128, 256, 512, 2048]), label=d)
Expand Down Expand Up @@ -102,6 +108,7 @@ def test_matmul(self, data):
tile_m=tile_m,
tile_n=tile_n,
in_dtype=in_dtype,
out_dtype=out_dtype,
cluster_m=cluster_m,
cluster_n=cluster_n,
swizzle=swizzle,
Expand Down

0 comments on commit f4c0b1f

Please sign in to comment.