Skip to content

Commit

Permalink
Add option to dump intermediate artifacts for jax-triton compilation
Browse files Browse the repository at this point in the history
Users can set the environment variable `JAX_TRITON_DUMP_DIR` to dump the intermediate ttir, ttgir, llir, and ptx artifacts generated during jax-triton compilation. The dump folder looks something like this:

```

config                                                                                    multi_head_attention_0d1d2d34d5d6d7c8d9d10d11c12d13d14d15c16d17d18d19c2021d.llir  multi_head_attention_0d1d2d34d5d6d7c8d9d10d11c12d13d14d15c16d17d18d19c2021d.ttgir
multi_head_attention_0d1d2d34d5d6d7c8d9d10d11c12d13d14d15c16d17d18d19c2021d.compile_info  multi_head_attention_0d1d2d34d5d6d7c8d9d10d11c12d13d14d15c16d17d18d19c2021d.ptx   multi_head_attention_0d1d2d34d5d6d7c8d9d10d11c12d13d14d15c16d17d18d19c2021d.ttir
```

PiperOrigin-RevId: 626420876
  • Loading branch information
The jax_triton Authors committed Apr 19, 2024
1 parent 1999d9b commit 7c08ae4
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 14 deletions.
79 changes: 66 additions & 13 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from __future__ import annotations

import copy
import dataclasses
import functools
import os
import pprint
import tempfile
import types
from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
Expand Down Expand Up @@ -64,6 +66,7 @@
)

os.environ["TRITON_CACHE_DIR"] = ""
_JAX_TRITON_DUMP_DIR = os.environ.get("JAX_TRITON_DUMP_DIR")
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

Expand Down Expand Up @@ -153,12 +156,22 @@ def aval_size_bytes(aval):
return np.dtype(aval.dtype).itemsize * aval.size


@dataclasses.dataclass
class PtxCompilationResult:
  """Outputs of compiling Triton TTIR down to PTX.

  Returned by ``compile_ttir_to_ptx_inplace``. The ``ttgir`` and ``llir``
  fields carry textual IR dumps and are populated only when the
  ``JAX_TRITON_DUMP_DIR`` environment variable is set; otherwise they
  are ``None``.
  """

  # Generated PTX assembly for the kernel.
  ptx: str
  # Kernel name taken from the Triton backend's compilation metadata.
  name: str
  # Shared memory required by the kernel, in bytes.
  shared_mem_bytes: int
  # Cluster dimensions from backend metadata; presumably an (x, y, z)
  # triple of ints — TODO confirm against the Triton backend.
  cluster_dims: tuple
  # Textual TTGIR dump (None unless dumping is enabled).
  ttgir: Optional[str]
  # Textual LLIR dump (None unless dumping is enabled).
  llir: Optional[str]

def compile_ttir_to_ptx_inplace(
ttir,
cuda_backend: cb.CUDABackend,
cuda_options: cb.CUDAOptions,
compute_capability,
) -> Tuple[str, str, int, Any]:
) -> PtxCompilationResult:
if cuda_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
Expand Down Expand Up @@ -211,7 +224,16 @@ def compile_ttir_to_ptx_inplace(
print(ptx)
name = metadata["name"]
cluster_dims = metadata["cluster_dims"]
return ptx, name, shared_mem_bytes, cluster_dims
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
return PtxCompilationResult(
ptx=ptx,
name=name,
shared_mem_bytes=shared_mem_bytes,
cluster_dims=cluster_dims,
ttgir=ttgir,
llir=llir,
)


_COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache?
Expand Down Expand Up @@ -286,6 +308,12 @@ def get_or_create_triton_kernel(
enable_fp_fusion=enable_fp_fusion,
)
)
kernel_hash = abs(hash(cache_key))
if _JAX_TRITON_DUMP_DIR:
os.makedirs(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}")
with open(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/config", "w") as f:
pprint.pprint(cache_key, stream=f)
pprint.pprint(cuda_options, stream=f)

context = _triton.ir.context()
_triton.ir.load_dialects(context)
Expand All @@ -302,25 +330,50 @@ def get_or_create_triton_kernel(
options=cuda_options,
context=context,
)
ttir = str(module)

ttir = str(module) # `module`` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, cluster_dims = (
compile_ttir_to_ptx_inplace(
module,
cuda_backend,
cuda_options,
compute_capability,
)
compilation_result = compile_ttir_to_ptx_inplace(
module,
cuda_backend,
cuda_options,
compute_capability,
)
kernel_name = compilation_result.name
if _JAX_TRITON_DUMP_DIR:
with open(
f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ttir", "w"
) as f:
f.write(ttir)
with open(
f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ptx", "w"
) as f:
f.write(compilation_result.ptx)
with open(
f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ttgir", "w"
) as f:
f.write(compilation_result.ttgir)
with open(
f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.llir", "w"
) as f:
f.write(compilation_result.llir)
with open(
f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.compile_info",
"w",
) as f:
f.write(
f"{kernel_name}: shared_mem_bytes:"
f" {compilation_result.shared_mem_bytes}, cluster_dims:"
f" {compilation_result.cluster_dims}\n"
)

kernel = triton_kernel_call_lib.TritonKernel(
kernel_name,
num_warps,
shared_mem_bytes,
ptx,
compilation_result.shared_mem_bytes,
compilation_result.ptx,
ttir,
compute_capability,
*cluster_dims,
*compilation_result.cluster_dims,
)

_COMPILED_KERNEL_CACHE[cache_key] = kernel
Expand Down
2 changes: 1 addition & 1 deletion tests/cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_cluster(self, num_ctas):
def my_compile_ttir_to_ptx(*args, **kwargs):
nonlocal cluster_dims, original_compile_ttir_to_ptx_fn
ret_args = original_compile_ttir_to_ptx_fn(*args, **kwargs)
cluster_dims = ret_args[-1]
cluster_dims = ret_args.cluster_dims
return ret_args

my_triton_call = functools.partial(jt.triton_call, num_ctas=num_ctas)
Expand Down

0 comments on commit 7c08ae4

Please sign in to comment.