diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index bcdafb8..0784ca1 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -27,6 +27,7 @@
 import types
 from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
 import zlib
+from functools import partial
 
 from absl import logging
 import jax
@@ -56,6 +57,14 @@
   CAN_USE_TRITON = True
 except ModuleNotFoundError:
   pass
+
+try:
+  import triton.backends.amd.compiler as hb
+except ImportError:
+  hb = None
+  pass
+
+
 try:
   from jax._src.lib import gpu_triton as triton_kernel_call_lib
 except ImportError:
@@ -90,7 +99,6 @@
     jnp.dtype("bool"): "B",
 }
 
-
 Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
 GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]
 
@@ -157,22 +165,61 @@ def aval_size_bytes(aval):
   return np.dtype(aval.dtype).itemsize * aval.size
 
 
+def get_cuda_backend(device, compute_capability):
+  target = cb.GPUTarget('cuda', compute_capability, 32)
+  backend = cb.CUDABackend(target)
+  return backend
+
+def get_hip_backend(device, compute_capability):
+  arch = triton_kernel_call_lib.get_arch_details(device)
+  arch = arch.split(":")[0]
+  target = hb.GPUTarget('hip', arch, 64)
+  backend = hb.HIPBackend(target)
+  return backend
+
 @dataclasses.dataclass
-class PtxCompilationResult:
-  ptx: str
+class CompilationResult:
+  binary: str
   name: str
   shared_mem_bytes: int
   cluster_dims: tuple
   ttgir: Optional[str]
   llir: Optional[str]
 
 
+def compile_ttir_inplace(
+    ttir,
+    backend: cb.CUDABackend | hb.HIPBackend,
+    options: cb.CUDAOptions | hb.HIPOptions,
+    compute_capability,
+    platform
+):
+  if platform == 'cuda':
+    return compile_ttir_to_ptx_inplace(
+        ttir,
+        backend,
+        options,
+        compute_capability,
+    )
+
+  elif platform == 'rocm':
+    return compile_ttir_to_hsaco_inplace(
+        ttir,
+        backend,
+        options,
+        compute_capability,
+    )
+  else:
+    raise ValueError(
+        f"Unsupported platform: {platform}"
+    )
+
 def compile_ttir_to_ptx_inplace(
     ttir,
     cuda_backend: cb.CUDABackend,
     cuda_options: cb.CUDAOptions,
     compute_capability,
-) -> PtxCompilationResult:
+) -> CompilationResult:
   if cuda_options.debug:
     print(ttir)
   if isinstance(ttir, ir.Module):
@@ -189,7 +236,7 @@ def compile_ttir_to_ptx_inplace(
       ttir = tl_ir.parse_mlir_module(f.name, context)
       ttir.context = context
   try:
-    metadata = dict()
+    metadata = {}
     opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
     ttgir = cuda_backend.make_ttgir(
         opt_ttir,
@@ -227,8 +274,8 @@
   cluster_dims = metadata["cluster_dims"]
   ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
   llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
-  return PtxCompilationResult(
-      ptx=ptx,
+  return CompilationResult(
+      binary=ptx,
       name=name,
       shared_mem_bytes=shared_mem_bytes,
       cluster_dims=cluster_dims,
@@ -236,11 +283,86 @@
       llir=llir,
   )
 
+def compile_ttir_to_hsaco_inplace(
+    ttir,
+    hip_backend: hb.HIPBackend,
+    hip_options: hb.HIPOptions,
+    compute_capability,
+) -> CompilationResult:
+  if hip_options.debug:
+    print(ttir)
+  if isinstance(ttir, ir.Module):
+    context = _triton.ir.context()
+    _triton.ir.load_dialects(context)
+    hip_backend.load_dialects(context)
+
+    # Triton compilation APIs only accept Triton-specific MLIR wrappers.
+    # So, here we serialize an ir.Module to a file and then deserialize
+    # it as a tl_ir.module.
+    with tempfile.NamedTemporaryFile(mode="wb") as f:
+      ttir.operation.write_bytecode(f)
+      f.flush()
+      ttir = tl_ir.parse_mlir_module(f.name, context)
+      ttir.context = context
+  try:
+    metadata = {}
+    opt_ttir = hip_backend.make_ttir(ttir, metadata, hip_options)
+    ttgir = hip_backend.make_ttgir(
+        opt_ttir,
+        metadata,
+        hip_options
+    )
+  except RuntimeError as e:
+    ttir.dump()
+    raise ValueError("TTIR->TTGIR pass failed!") from e
+  if hip_options.debug:
+    print(ttgir)
+  try:
+    llir = hip_backend.make_llir(
+        ttgir,
+        metadata,
+        hip_options
+    )
+  except RuntimeError as e:
+    ttgir.dump()
+    raise ValueError("TTGIR->LLIR pass failed!") from e
+  shared_mem_bytes = metadata["shared"]
+  if hip_options.debug:
+    print(llir)
+
+  amdgcn = hip_backend.make_amdgcn(llir, metadata, hip_options)
+  hsaco = hip_backend.make_hsaco(amdgcn, metadata, hip_options)
+
+  if hip_options.debug:
+    print(amdgcn)
+  name = metadata["name"]
+  ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
+  llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
+  # Cluster dims are not used on the HIP backend; we fill in a
+  # placeholder value for API compatibility.
+  cluster_dims = (0, 0, 0)
+  # Instead of passing the hsaco bytes directly, we first write them
+  # to a file and pass the path as a string. This is needed because
+  # nanobind doesn't automatically convert between bytes and str.
+  # https://github.com/wjakob/nanobind/discussions/137
+  fd, hsaco_path = tempfile.mkstemp()
+  with os.fdopen(fd, "wb") as f:
+    f.write(hsaco)
+  return CompilationResult(
+      binary=hsaco_path,
+      name=name,
+      shared_mem_bytes=shared_mem_bytes,
+      cluster_dims=cluster_dims,
+      ttgir=ttgir,
+      llir=llir,
+  )
 
 _COMPILED_KERNEL_CACHE = {}  # TODO(cjfj): Convert to LRU cache?
 
 
 def get_or_create_triton_kernel(
+    backend_init_func,
+    platform,
     fn,
     arg_dtypes,
     scalar_args,
@@ -257,11 +379,11 @@
     num_warps = 4
   if num_stages is None:
     num_stages = 3
+  # TODO(sharadmv): handle multiple devices, right now we assume device 0
+  # which is fine when we have multiple of the same GPU but this won't work in
+  # general.
+  device = 0
   if compute_capability is None:
-    # TODO(sharadmv): handle multiple devices, right now we assume device 0
-    # which is fine when we have multiple of the same GPU but this won't work in
-    # general.
-    device = 0
     compute_capability = triton_kernel_call_lib.get_compute_capability(device)
   if num_ctas > 1 and compute_capability < 90:
     raise ValueError("num_ctas > 1 unsupported before Hopper.")
@@ -297,29 +419,29 @@
   kernel = _COMPILED_KERNEL_CACHE.get(cache_key)
 
   if kernel is None:
-    target = cb.GPUTarget('cuda', compute_capability, 32)
-    cuda_backend = cb.CUDABackend(target)
-    cuda_options = cuda_backend.parse_options(
-        dict(
-            num_warps=num_warps,
-            num_stages=num_stages,
-            num_ctas=num_ctas,
-            optimize_epilogue=False,
-            debug=dump,
-            enable_fp_fusion=enable_fp_fusion,
-        )
-    )
+    opts = {
+        "num_warps": num_warps,
+        "num_stages": num_stages,
+        "num_ctas": num_ctas,
+        "optimize_epilogue": False,
+        "debug": dump,
+        "enable_fp_fusion": enable_fp_fusion,
+    }
+
+    backend = backend_init_func(device, compute_capability)
+    options = backend.parse_options(opts)
+
     kernel_hash = abs(hash(cache_key))
     if _JAX_TRITON_DUMP_DIR:
       os.makedirs(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}")
       with open(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/config", "w") as f:
         pprint.pprint(cache_key, stream=f)
-        pprint.pprint(cuda_options, stream=f)
+        pprint.pprint(options, stream=f)
 
     context = _triton.ir.context()
     _triton.ir.load_dialects(context)
-    cuda_backend.load_dialects(context)
-    codegen_fns = cuda_backend.get_codegen_implementation()
+    backend.load_dialects(context)
+    codegen_fns = backend.get_codegen_implementation()
 
     module = (
         code_gen.ast_to_ttir(
@@ -330,10 +452,10 @@
                 signature=signature,
                 attrs=specialization_attr,
             ),
-            options=cuda_options,
+            options=options,
             codegen_fns=codegen_fns,
             context=context,
-            module_map=cuda_backend.get_module_map(),
+            module_map=backend.get_module_map(),
         )
         if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
         # Triton changes ASTSource.ast_to_ttir to include module_map. Handle
@@ -346,19 +468,21 @@
                 signature=signature,
                 attrs=specialization_attr,
             ),
-            options=cuda_options,
+            options=options,
             codegen_fns=codegen_fns,
             context=context,
         )
     )
 
     ttir = str(module)
-    compilation_result = compile_ttir_to_ptx_inplace(
-        module,
-        cuda_backend,
-        cuda_options,
-        compute_capability,
+    compilation_result = compile_ttir_inplace(
+        module,
+        backend,
+        options,
+        compute_capability,
+        platform
     )
+
     kernel_name = compilation_result.name
     if _JAX_TRITON_DUMP_DIR:
       with open(
@@ -391,7 +515,7 @@
         kernel_name,
         num_warps,
         compilation_result.shared_mem_bytes,
-        compilation_result.ptx,
+        compilation_result.binary,
         ttir,
         compute_capability,
         *compilation_result.cluster_dims,
@@ -403,6 +527,7 @@
 
 
 def triton_kernel_call_lowering(
+    backend_init_func,
     ctx,
     *array_args,
     fn,
@@ -427,6 +552,7 @@
         "`input_output_aliases` only supported on `jaxlib>=0.3.22"
     )
 
+  kernel_call_name = name
   args = list(ctx.avals_in)
   arg_dtypes = list(map(get_triton_type, ctx.avals_in))
 
@@ -521,6 +647,8 @@ def prune_configs(configs, named_args, **kwargs):
   kernel_calls = []
   for params in config_params:
     kernel, specialization_attr = get_or_create_triton_kernel(
+        backend_init_func,
+        ctx.module_context.platforms[0],
         fn,
         arg_dtypes,
         scalar_args,
@@ -590,9 +718,13 @@ def prune_configs(configs, named_args, **kwargs):
       operand_output_aliases=dict(input_output_aliases),
   ).results
 
+mlir.register_lowering(triton_kernel_call_p,
+                       partial(triton_kernel_call_lowering, get_cuda_backend),
+                       platform='cuda')
 
-mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)
-
+mlir.register_lowering(triton_kernel_call_p,
+                       partial(triton_kernel_call_lowering, get_hip_backend),
+                       platform='rocm')
 
 class ShapeDtype(Protocol):
 
diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py
index b921cf8..098e6b5 100644
--- a/tests/triton_call_test.py
+++ b/tests/triton_call_test.py
@@ -263,6 +263,26 @@ def add_scalar(x, y):
     x = jnp.array([1.0])
     np.testing.assert_allclose(add_scalar(x, scalar), x + scalar)
 
+  def test_explicit_compute_capability(self):
+    scalar = np.float32(8)
+
+    @triton.jit
+    def add_scalar_kernel(x_ptr, y, output_ptr):
+      tl.store(output_ptr, tl.load(x_ptr) + y)
+
+    def add_scalar(x, y):
+      return jt.triton_call(
+          x,
+          y,
+          kernel=add_scalar_kernel,
+          compute_capability=jt.get_compute_capability(0),
+          out_shape=jax.ShapeDtypeStruct((), x.dtype),
+          grid=1,
+      )
+
+    x = jnp.array([1.0])
+    np.testing.assert_allclose(add_scalar(x, scalar), x + scalar)
+
   def test_input_output_aliasing(self):
     @triton.jit
     def add_inplace_kernel(_, n_elements, output_ptr, BLOCK_SIZE: tl.constexpr):
@@ -328,16 +348,16 @@ def test_kernel_cache_equivalent_kernels(self):
     x1, y1 = create_random_inputs([42])
     x2, y2 = create_random_inputs([43])
 
-    compile_ttir_to_ptx_inplace = jt.triton_lib.compile_ttir_to_ptx_inplace
+    compile_ttir_inplace = jt.triton_lib.compile_ttir_inplace
 
     call_count = [0]
 
    def my_compile(*args, **kwargs):
       call_count[0] += 1
-      return compile_ttir_to_ptx_inplace(*args, **kwargs)
+      return compile_ttir_inplace(*args, **kwargs)
 
     with mock.patch.object(
-        jt.triton_lib, "compile_ttir_to_ptx_inplace", new=my_compile
+        jt.triton_lib, "compile_ttir_inplace", new=my_compile
     ):
       _ = fn1(x1, y1)
       self.assertEqual(call_count[0], 1)
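
Note (not part of the patch): the user-facing API is unchanged by this diff; the per-platform lowering registered via functools.partial selects the CUDA (PTX) or ROCm (HSACO) compile path from ctx.module_context.platforms[0]. The sketch below is a minimal, hedged usage example of jax_triton.triton_call that is expected to run unchanged on either backend. The kernel name, block size, and array sizes are illustrative only; it assumes jax, triton, and jax-triton are installed and a supported NVIDIA (CUDA) or AMD (ROCm) GPU is visible to JAX.

# usage_sketch.py -- illustrative only, assumptions noted above
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
  # 1D launch grid; each program instance handles one contiguous block.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(output_ptr + offsets, x + y)


def add(x, y, block_size=8):
  # Sizes here are chosen to be an exact multiple of block_size,
  # so no masking is needed in the kernel.
  return jt.triton_call(
      x,
      y,
      kernel=add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=(x.size // block_size,),
      block_size=block_size,
  )


x = jnp.arange(64, dtype=jnp.float32)
y = jnp.ones(64, dtype=jnp.float32)
print(jax.jit(add)(x, y))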