Skip to content

Commit

Permalink
[Mosaic GPU] Add an autotuning harness to the matmul example
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662521895
  • Loading branch information
apaszke authored and jax authors committed Aug 13, 2024
1 parent f4c0b1f commit bab096e
Showing 1 changed file with 55 additions and 8 deletions.
63 changes: 55 additions & 8 deletions jax/experimental/mosaic/gpu/examples/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
"""Matmul kernels for H100."""

import dataclasses
import functools
from typing import Any
import itertools
import math
from typing import Any

import jax
from jax import random
Expand Down Expand Up @@ -115,7 +115,8 @@ def build_kernel(
tile_m: int = 128,
tile_n: int = 128,
swizzle: int = 128,
cluster: tuple[int, int] = (1, 1),
cluster_m: int = 1,
cluster_n: int = 1,
rhs_transpose: bool = False,
wgmma_impl=WGMMADefaultImpl,
profiler_spec: profiler.ProfilerSpec | None = None,
Expand Down Expand Up @@ -304,10 +305,10 @@ def stage_loop_body(ki, accs):
ClusterBarrier(
collective_dims=(gpu.Dimension.x, gpu.Dimension.y),
num_barriers=stages,
) if math.prod(cluster) > 1 else None,
) if cluster_m * cluster_n > 1 else None,
),
profiler_spec,
cluster=(*cluster, 1),
cluster=(cluster_n, cluster_m, 1),
)


Expand Down Expand Up @@ -339,7 +340,8 @@ def verify(
stages=stages,
tile_m=tile_m,
tile_n=tile_n,
cluster=(cluster_m, cluster_n),
cluster_m=cluster_m,
cluster_n=cluster_n,
rhs_transpose=rhs_transpose,
swizzle=swizzle,
wgmma_impl=WGMMADefaultImpl,
Expand Down Expand Up @@ -375,9 +377,54 @@ def ref_f(x, y):


if __name__ == "__main__":
m, k, n = 4 * 33 * 128, 2048, 4 * 128
runtime, ref_runtime = verify(m=m, k=k, n=n, cluster_m=1, cluster_n=4)
dtype = jnp.dtype(jnp.float16)
m, k, n = 16384, 2048, 16384

kx, ky = random.split(random.key(1234))
x = random.uniform(kx, (m, k), dtype=dtype)
y = random.uniform(ky, (k, n), dtype=dtype)

tile_m = tile_n = (64, 128, 256)
cluster_m = cluster_n = (1, 2)
swizzle = (128,)
stages = (2, 4, 5, 6)
configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle)
names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle")
best_runtime = float("inf")
best_kwargs = {}
for config in configs:
kwargs = dict(zip(names, config))
if kwargs["cluster_m"] * kwargs["cluster_n"] > 8:
continue
if m < kwargs["tile_m"] or n < kwargs["tile_n"]:
continue
if (m // kwargs["tile_m"]) % kwargs["cluster_n"]:
continue
if (n // kwargs["tile_n"]) % kwargs["cluster_m"]:
continue
try:
f = build_kernel(
m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs
)
_, runtime = profiler.measure(f, x, y)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
raise
runtime = float("inf")
# Enable this to get more detailed information.
# else:
# print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
if runtime < best_runtime:
best_runtime = runtime
best_kwargs = kwargs
if not best_kwargs:
raise ValueError("No valid configuration found")

runtime, ref_runtime = verify(
m=m, k=k, n=n, in_dtype=dtype, out_dtype=dtype, **best_kwargs
)
tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12
ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
print("Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items()))
print(f"Kernel: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")

0 comments on commit bab096e

Please sign in to comment.