Skip to content

Commit

Permalink
Refactors triton_kernel_call_lowering to support both cuda and rocm.
Browse files Browse the repository at this point in the history
This is a rollforward of #287 with fixes.

PiperOrigin-RevId: 673490870
  • Loading branch information
pschuh authored and The jax_triton Authors committed Sep 11, 2024
1 parent e82c529 commit 285dba9
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 40 deletions.
206 changes: 169 additions & 37 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import types
from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
import zlib
from functools import partial

from absl import logging
import jax
Expand Down Expand Up @@ -56,6 +57,14 @@
CAN_USE_TRITON = True
except ModuleNotFoundError:
pass

try:
import triton.backends.amd.compiler as hb
except ImportError:
hb = None
pass


try:
from jax._src.lib import gpu_triton as triton_kernel_call_lib
except ImportError:
Expand Down Expand Up @@ -90,7 +99,6 @@
jnp.dtype("bool"): "B",
}


Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]

Expand Down Expand Up @@ -157,22 +165,61 @@ def aval_size_bytes(aval):
return np.dtype(aval.dtype).itemsize * aval.size


def get_cuda_backend(device, compute_capability):
target = cb.GPUTarget('cuda', compute_capability, 32)
backend = cb.CUDABackend(target)
return backend

def get_hip_backend(device, compute_capability):
arch = triton_kernel_call_lib.get_arch_details(device)
arch = arch.split(":")[0]
target = hb.GPUTarget('hip', arch, 64)
backend = hb.HIPBackend(target)
return backend

@dataclasses.dataclass
class PtxCompilationResult:
ptx: str
class CompilationResult:
binary: str
name: str
shared_mem_bytes: int
cluster_dims: tuple
ttgir: Optional[str]
llir: Optional[str]

def compile_ttir_inplace(
ttir,
backend: [cb.CUDABackend | hb.HIPBackend],
options: [cb.CUDAOptions | hb.HIPOptions],
compute_capability,
platform
):
if platform == 'cuda':
return compile_ttir_to_ptx_inplace(
ttir,
backend,
options,
compute_capability,
)

elif platform == 'rocm':
return compile_ttir_to_hsaco_inplace(
ttir,
backend,
options,
compute_capability,
)
else:
raise ValueError(
"Unsupported device."
)


def compile_ttir_to_ptx_inplace(
ttir,
cuda_backend: cb.CUDABackend,
cuda_options: cb.CUDAOptions,
compute_capability,
) -> PtxCompilationResult:
) -> CompilationResult:
if cuda_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
Expand All @@ -189,7 +236,7 @@ def compile_ttir_to_ptx_inplace(
ttir = tl_ir.parse_mlir_module(f.name, context)
ttir.context = context
try:
metadata = dict()
metadata = {}
opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
ttgir = cuda_backend.make_ttgir(
opt_ttir,
Expand Down Expand Up @@ -227,20 +274,95 @@ def compile_ttir_to_ptx_inplace(
cluster_dims = metadata["cluster_dims"]
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
return PtxCompilationResult(
ptx=ptx,
return CompilationResult(
binary=ptx,
name=name,
shared_mem_bytes=shared_mem_bytes,
cluster_dims=cluster_dims,
ttgir=ttgir,
llir=llir,
)

def compile_ttir_to_hsaco_inplace(
ttir,
hip_backend: hb.HIPBackend,
hip_options: hb.HIPOptions,
compute_capability,
) -> CompilationResult:
if hip_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
context = _triton.ir.context()
_triton.ir.load_dialects(context)
hip_backend.load_dialects(context)

# Triton compilation APIs only accept Triton-specific MLIR wrappers.
# So, here we serialize an ir.Module to a file and then deserialize
# it as a tl_ir.module.
with tempfile.NamedTemporaryFile(mode="wb") as f:
ttir.operation.write_bytecode(f)
f.flush()
ttir = tl_ir.parse_mlir_module(f.name, context)
ttir.context = context
try:
metadata = {}
opt_ttir = hip_backend.make_ttir(ttir, metadata, hip_options)
ttgir = hip_backend.make_ttgir(
opt_ttir,
metadata,
hip_options
)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if hip_options.debug:
print(ttgir)
try:
llir = hip_backend.make_llir(
ttgir,
metadata,
hip_options
)
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = metadata["shared"]
if hip_options.debug:
print(llir)

amdgcn = hip_backend.make_amdgcn(llir, metadata, hip_options)
hsaco = hip_backend.make_hsaco(amdgcn, metadata, hip_options)

if hip_options.debug:
print(x)
name = metadata["name"]
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
# cluster dims are NOT useful on hip backend.
# We just fill up with some value for API compatibility
cluster_dims = (0, 0, 0)
# Instead of passing hsaco which are "bytes", we first write
# to a file and then pass the "string" path. This is needed because
# nanobind doesn't automatically convert between bytes and string.
# https://github.com/wjakob/nanobind/discussions/137
fd, hsaco_path = tempfile.mkstemp()
with os.fdopen(fd, "wb") as f:
f.write(hsaco)
return CompilationResult(
binary=hsaco_path,
name=name,
shared_mem_bytes=shared_mem_bytes,
cluster_dims=cluster_dims,
ttgir=ttgir,
llir=llir,
)

_COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache?


def get_or_create_triton_kernel(
backend_init_func,
platform,
fn,
arg_dtypes,
scalar_args,
Expand All @@ -257,11 +379,11 @@ def get_or_create_triton_kernel(
num_warps = 4
if num_stages is None:
num_stages = 3
# TODO(sharadmv): handle multiple devices, right now we assume device 0
# which is fine when we have multiple of the same GPU but this won't work in
# general.
device = 0
if compute_capability is None:
# TODO(sharadmv): handle multiple devices, right now we assume device 0
# which is fine when we have multiple of the same GPU but this won't work in
# general.
device = 0
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if num_ctas > 1 and compute_capability < 90:
raise ValueError("num_ctas > 1 unsupported before Hopper.")
Expand Down Expand Up @@ -297,29 +419,29 @@ def get_or_create_triton_kernel(
kernel = _COMPILED_KERNEL_CACHE.get(cache_key)

if kernel is None:
target = cb.GPUTarget('cuda', compute_capability, 32)
cuda_backend = cb.CUDABackend(target)
cuda_options = cuda_backend.parse_options(
dict(
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
optimize_epilogue=False,
debug=dump,
enable_fp_fusion=enable_fp_fusion,
)
)
opts = {
"num_warps": num_warps,
"num_stages": num_stages,
"num_ctas": num_ctas,
"optimize_epilogue": False,
"debug": dump,
"enable_fp_fusion": enable_fp_fusion,
}

backend = backend_init_func(device, compute_capability)
options = backend.parse_options(opts)

kernel_hash = abs(hash(cache_key))
if _JAX_TRITON_DUMP_DIR:
os.makedirs(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}")
with open(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/config", "w") as f:
pprint.pprint(cache_key, stream=f)
pprint.pprint(cuda_options, stream=f)
pprint.pprint(options, stream=f)

context = _triton.ir.context()
_triton.ir.load_dialects(context)
cuda_backend.load_dialects(context)
codegen_fns = cuda_backend.get_codegen_implementation()
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()

module = (
code_gen.ast_to_ttir(
Expand All @@ -330,10 +452,10 @@ def get_or_create_triton_kernel(
signature=signature,
attrs=specialization_attr,
),
options=cuda_options,
options=options,
codegen_fns=codegen_fns,
context=context,
module_map=cuda_backend.get_module_map(),
module_map=backend.get_module_map(),
)
if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
# Triton changes ASTSource.ast_to_ttir to include module_map. Handle
Expand All @@ -346,19 +468,21 @@ def get_or_create_triton_kernel(
signature=signature,
attrs=specialization_attr,
),
options=cuda_options,
options=options,
codegen_fns=codegen_fns,
context=context,
)
)
ttir = str(module)

compilation_result = compile_ttir_to_ptx_inplace(
module,
cuda_backend,
cuda_options,
compute_capability,
compilation_result = compile_ttir_inplace(
module,
backend,
options,
compute_capability,
platform
)

kernel_name = compilation_result.name
if _JAX_TRITON_DUMP_DIR:
with open(
Expand Down Expand Up @@ -391,7 +515,7 @@ def get_or_create_triton_kernel(
kernel_name,
num_warps,
compilation_result.shared_mem_bytes,
compilation_result.ptx,
compilation_result.binary,
ttir,
compute_capability,
*compilation_result.cluster_dims,
Expand All @@ -403,6 +527,7 @@ def get_or_create_triton_kernel(


def triton_kernel_call_lowering(
backend_init_func,
ctx,
*array_args,
fn,
Expand All @@ -427,6 +552,7 @@ def triton_kernel_call_lowering(
"`input_output_aliases` only supported on `jaxlib>=0.3.22"
)


kernel_call_name = name
args = list(ctx.avals_in)
arg_dtypes = list(map(get_triton_type, ctx.avals_in))
Expand Down Expand Up @@ -521,6 +647,8 @@ def prune_configs(configs, named_args, **kwargs):
kernel_calls = []
for params in config_params:
kernel, specialization_attr = get_or_create_triton_kernel(
backend_init_func,
ctx.module_context.platforms[0],
fn,
arg_dtypes,
scalar_args,
Expand Down Expand Up @@ -590,9 +718,13 @@ def prune_configs(configs, named_args, **kwargs):
operand_output_aliases=dict(input_output_aliases),
).results

mlir.register_lowering(triton_kernel_call_p,
partial(triton_kernel_call_lowering, get_cuda_backend),
platform='cuda')

mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)

mlir.register_lowering(triton_kernel_call_p,
partial(triton_kernel_call_lowering, get_hip_backend),
platform='rocm')

class ShapeDtype(Protocol):

Expand Down
Loading

0 comments on commit 285dba9

Please sign in to comment.