# Review notes for this patch (commentary only; not part of any hunk).
# - build_kernel: the single `cluster` tuple parameter becomes two int
#   parameters, `cluster_m` and `cluster_n`; the kernel is launched with
#   cluster=(cluster_n, cluster_m, 1).
# - verify: forwards cluster_m / cluster_n to build_kernel as separate keywords.
# - __main__: sweeps tile_m, tile_n, cluster_m, cluster_n, stages and swizzle
#   with itertools.product, times each kernel with profiler.measure, treats a
#   ValueError containing "Mosaic GPU kernel exceeds available shared memory"
#   as an infinite runtime, and passes the fastest configuration to verify().
# NOTE(review): the sweep filters on (m // tile_m) % cluster_n and
#   (n // tile_n) % cluster_m. The grid layout is not visible in this patch, so
#   whether that m/n pairing matches the launch grid needs confirming.
# NOTE(review): line breaks and leading indentation inside the hunks were
#   reconstructed (2-space style, hunk line counts match the @@ headers);
#   verify whitespace against the target file before applying.
diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py
index 51b1e2041b4f..53cb270b3cdc 100644
--- a/jax/experimental/mosaic/gpu/examples/matmul.py
+++ b/jax/experimental/mosaic/gpu/examples/matmul.py
@@ -15,9 +15,9 @@
 """Matmul kernels for H100."""
 
 import dataclasses
-import functools
-from typing import Any
+import itertools
 import math
+from typing import Any
 
 import jax
 from jax import random
@@ -115,7 +115,8 @@ def build_kernel(
     tile_m: int = 128,
     tile_n: int = 128,
     swizzle: int = 128,
-    cluster: tuple[int, int] = (1, 1),
+    cluster_m: int = 1,
+    cluster_n: int = 1,
     rhs_transpose: bool = False,
     wgmma_impl=WGMMADefaultImpl,
     profiler_spec: profiler.ProfilerSpec | None = None,
@@ -304,10 +305,10 @@ def stage_loop_body(ki, accs):
           ClusterBarrier(
               collective_dims=(gpu.Dimension.x, gpu.Dimension.y),
               num_barriers=stages,
-          ) if math.prod(cluster) > 1 else None,
+          ) if cluster_m * cluster_n > 1 else None,
       ),
       profiler_spec,
-      cluster=(*cluster, 1),
+      cluster=(cluster_n, cluster_m, 1),
   )
 
 
@@ -339,7 +340,8 @@ def verify(
       stages=stages,
       tile_m=tile_m,
       tile_n=tile_n,
-      cluster=(cluster_m, cluster_n),
+      cluster_m=cluster_m,
+      cluster_n=cluster_n,
       rhs_transpose=rhs_transpose,
       swizzle=swizzle,
       wgmma_impl=WGMMADefaultImpl,
@@ -375,9 +377,54 @@ def ref_f(x, y):
 
 
 if __name__ == "__main__":
-  m, k, n = 4 * 33 * 128, 2048, 4 * 128
-  runtime, ref_runtime = verify(m=m, k=k, n=n, cluster_m=1, cluster_n=4)
+  dtype = jnp.dtype(jnp.float16)
+  m, k, n = 16384, 2048, 16384
+
+  kx, ky = random.split(random.key(1234))
+  x = random.uniform(kx, (m, k), dtype=dtype)
+  y = random.uniform(ky, (k, n), dtype=dtype)
+
+  tile_m = tile_n = (64, 128, 256)
+  cluster_m = cluster_n = (1, 2)
+  swizzle = (128,)
+  stages = (2, 4, 5, 6)
+  configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle)
+  names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle")
+  best_runtime = float("inf")
+  best_kwargs = {}
+  for config in configs:
+    kwargs = dict(zip(names, config))
+    if kwargs["cluster_m"] * kwargs["cluster_n"] > 8:
+      continue
+    if m < kwargs["tile_m"] or n < kwargs["tile_n"]:
+      continue
+    if (m // kwargs["tile_m"]) % kwargs["cluster_n"]:
+      continue
+    if (n // kwargs["tile_n"]) % kwargs["cluster_m"]:
+      continue
+    try:
+      f = build_kernel(
+          m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs
+      )
+      _, runtime = profiler.measure(f, x, y)
+    except ValueError as e:
+      if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
+        raise
+      runtime = float("inf")
+    # Enable this to get more detailed information.
+    # else:
+    #   print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
+    if runtime < best_runtime:
+      best_runtime = runtime
+      best_kwargs = kwargs
+  if not best_kwargs:
+    raise ValueError("No valid configuration found")
+
+  runtime, ref_runtime = verify(
+      m=m, k=k, n=n, in_dtype=dtype, out_dtype=dtype, **best_kwargs
+  )
   tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12
   ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
+  print("Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items()))
   print(f"Kernel: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
   print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")