diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index 78ffaa7f..d3d272f6 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -18,8 +18,10 @@
 from __future__ import annotations
 
 import copy
+import dataclasses
 import functools
 import os
+import pprint
 import tempfile
 import types
 from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
@@ -64,6 +66,7 @@
 )
 
 os.environ["TRITON_CACHE_DIR"] = ""
+_JAX_TRITON_DUMP_DIR = os.environ.get("JAX_TRITON_DUMP_DIR")
 
 map, unsafe_map = util.safe_map, map
 zip, unsafe_zip = util.safe_zip, zip
@@ -153,12 +156,22 @@ def aval_size_bytes(aval):
   return np.dtype(aval.dtype).itemsize * aval.size
 
 
+@dataclasses.dataclass
+class PtxCompilationResult:
+  ptx: str
+  name: str
+  shared_mem_bytes: int
+  cluster_dims: tuple
+  ttgir: Optional[str]
+  llir: Optional[str]
+
+
 def compile_ttir_to_ptx_inplace(
     ttir,
     cuda_backend: cb.CUDABackend,
     cuda_options: cb.CUDAOptions,
     compute_capability,
-) -> Tuple[str, str, int, Any]:
+) -> PtxCompilationResult:
   if cuda_options.debug:
     print(ttir)
   if isinstance(ttir, ir.Module):
@@ -211,7 +224,16 @@ def compile_ttir_to_ptx_inplace(
     print(ptx)
   name = metadata["name"]
   cluster_dims = metadata["cluster_dims"]
-  return ptx, name, shared_mem_bytes, cluster_dims
+  ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
+  llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
+  return PtxCompilationResult(
+      ptx=ptx,
+      name=name,
+      shared_mem_bytes=shared_mem_bytes,
+      cluster_dims=cluster_dims,
+      ttgir=ttgir,
+      llir=llir,
+  )
 
 
 _COMPILED_KERNEL_CACHE = {}  # TODO(cjfj): Convert to LRU cache?
@@ -286,6 +308,12 @@ def get_or_create_triton_kernel(
             enable_fp_fusion=enable_fp_fusion,
         )
     )
+    kernel_hash = abs(hash(cache_key))
+    if _JAX_TRITON_DUMP_DIR:
+      os.makedirs(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}")
+      with open(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/config", "w") as f:
+        pprint.pprint(cache_key, stream=f)
+        pprint.pprint(cuda_options, stream=f)
 
     context = _triton.ir.context()
     _triton.ir.load_dialects(context)
@@ -302,25 +330,50 @@ def get_or_create_triton_kernel(
         options=cuda_options,
         context=context,
     )
+    ttir = str(module)
-    ttir = str(module)  # `module`` is compiled in-place, so copy TTIR here.
-    ptx, kernel_name, shared_mem_bytes, cluster_dims = (
-        compile_ttir_to_ptx_inplace(
-            module,
-            cuda_backend,
-            cuda_options,
-            compute_capability,
-        )
+    compilation_result = compile_ttir_to_ptx_inplace(
+        module,
+        cuda_backend,
+        cuda_options,
+        compute_capability,
     )
+    kernel_name = compilation_result.name
+    if _JAX_TRITON_DUMP_DIR:
+      with open(
+          f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ttir", "w"
+      ) as f:
+        f.write(ttir)
+      with open(
+          f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ptx", "w"
+      ) as f:
+        f.write(compilation_result.ptx)
+      with open(
+          f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ttgir", "w"
+      ) as f:
+        f.write(compilation_result.ttgir)
+      with open(
+          f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.llir", "w"
+      ) as f:
+        f.write(compilation_result.llir)
+      with open(
+          f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.compile_info",
+          "w",
+      ) as f:
+        f.write(
+            f"{kernel_name}: shared_mem_bytes:"
+            f" {compilation_result.shared_mem_bytes}, cluster_dims:"
+            f" {compilation_result.cluster_dims}\n"
+        )
     kernel = triton_kernel_call_lib.TritonKernel(
         kernel_name,
         num_warps,
-        shared_mem_bytes,
-        ptx,
+        compilation_result.shared_mem_bytes,
+        compilation_result.ptx,
         ttir,
         compute_capability,
-        *cluster_dims,
+        *compilation_result.cluster_dims,
     )
 
     _COMPILED_KERNEL_CACHE[cache_key] = kernel
diff --git a/tests/cluster_test.py b/tests/cluster_test.py
index 19a8f15a..79f0e9ce 100644
--- a/tests/cluster_test.py
+++ b/tests/cluster_test.py
@@ -51,7 +51,7 @@ def test_cluster(self, num_ctas):
     def my_compile_ttir_to_ptx(*args, **kwargs):
       nonlocal cluster_dims, original_compile_ttir_to_ptx_fn
       ret_args = original_compile_ttir_to_ptx_fn(*args, **kwargs)
-      cluster_dims = ret_args[-1]
+      cluster_dims = ret_args.cluster_dims
       return ret_args
 
     my_triton_call = functools.partial(jt.triton_call, num_ctas=num_ctas)
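
Usage note (not part of the patch): a minimal sketch of how the dump hook above can be exercised. Only the JAX_TRITON_DUMP_DIR environment variable and the dumped file names come from the diff; the directory path and the add_kernel example below are illustrative assumptions. Because _JAX_TRITON_DUMP_DIR is read at module import time, the variable must be set before jax_triton is imported.

# Sketch only: an arbitrary example kernel used to trigger one compilation and dump.
import os
os.environ["JAX_TRITON_DUMP_DIR"] = "/tmp/jax_triton_dump"  # example path

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, block_size: tl.constexpr):
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(out_ptr + offsets, x + y)


def add(x, y):
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  return jt.triton_call(
      x, y, kernel=add_kernel, out_shape=out_shape,
      grid=(x.size // 8,), block_size=8)


x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))
# After the first call, {JAX_TRITON_DUMP_DIR}/{kernel_hash}/ should contain
# config, <name>.ttir, <name>.ttgir, <name>.llir, <name>.ptx and
# <name>.compile_info, as written by get_or_create_triton_kernel above.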