From e82c5293aa98b204746ca8c8cfdb368062b69dd7 Mon Sep 17 00:00:00 2001 From: Taehee Jeong Date: Wed, 11 Sep 2024 11:39:22 -0700 Subject: [PATCH] Copybara import of the project: -- 1c012dc651617426b15fa6ba9f5be609e7f2e0f2 by Rahul Batra : ROCM updates PiperOrigin-RevId: 673474325 --- jax_triton/triton_lib.py | 198 +++++++------------------------------- tests/triton_call_test.py | 6 +- 2 files changed, 36 insertions(+), 168 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 0c2f9e5..bcdafb8 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -27,7 +27,6 @@ import types from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union import zlib -from functools import partial from absl import logging import jax @@ -57,14 +56,6 @@ CAN_USE_TRITON = True except ModuleNotFoundError: pass - -try: - import triton.backends.amd.compiler as hb -except ImportError: - hb = None - pass - - try: from jax._src.lib import gpu_triton as triton_kernel_call_lib except ImportError: @@ -99,6 +90,7 @@ jnp.dtype("bool"): "B", } + Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]] GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]] @@ -165,61 +157,22 @@ def aval_size_bytes(aval): return np.dtype(aval.dtype).itemsize * aval.size -def get_cuda_backend(device, compute_capability): - target = cb.GPUTarget('cuda', compute_capability, 32) - backend = cb.CUDABackend(target) - return backend - -def get_hip_backend(device, compute_capability): - arch = triton_kernel_call_lib.get_arch_details(device) - arch = arch.split(":")[0] - target = hb.GPUTarget('hip', arch, 64) - backend = hb.HIPBackend(target) - return backend - @dataclasses.dataclass -class CompilationResult: - binary: str +class PtxCompilationResult: + ptx: str name: str shared_mem_bytes: int cluster_dims: tuple ttgir: Optional[str] llir: Optional[str] -def compile_ttir_inplace( - ttir, - backend: [cb.CUDABackend | hb.HIPBackend], - options: [cb.CUDAOptions | hb.HIPOptions], - compute_capability, - platform -): - if platform == 'cuda': - return compile_ttir_to_ptx_inplace( - ttir, - backend, - options, - compute_capability, - ) - - elif platform == 'rocm': - return compile_ttir_to_hsaco_inplace( - ttir, - backend, - options, - compute_capability, - ) - else: - raise ValueError( - "Unsupported device." 
- ) - def compile_ttir_to_ptx_inplace( ttir, cuda_backend: cb.CUDABackend, cuda_options: cb.CUDAOptions, compute_capability, -) -> CompilationResult: +) -> PtxCompilationResult: if cuda_options.debug: print(ttir) if isinstance(ttir, ir.Module): @@ -236,7 +189,7 @@ def compile_ttir_to_ptx_inplace( ttir = tl_ir.parse_mlir_module(f.name, context) ttir.context = context try: - metadata = {} + metadata = dict() opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options) ttgir = cuda_backend.make_ttgir( opt_ttir, @@ -274,8 +227,8 @@ def compile_ttir_to_ptx_inplace( cluster_dims = metadata["cluster_dims"] ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None llir = str(llir) if _JAX_TRITON_DUMP_DIR else None - return CompilationResult( - binary=ptx, + return PtxCompilationResult( + ptx=ptx, name=name, shared_mem_bytes=shared_mem_bytes, cluster_dims=cluster_dims, @@ -283,86 +236,11 @@ def compile_ttir_to_ptx_inplace( llir=llir, ) -def compile_ttir_to_hsaco_inplace( - ttir, - hip_backend: hb.HIPBackend, - hip_options: hb.HIPOptions, - compute_capability, -) -> CompilationResult: - if hip_options.debug: - print(ttir) - if isinstance(ttir, ir.Module): - context = _triton.ir.context() - _triton.ir.load_dialects(context) - hip_backend.load_dialects(context) - - # Triton compilation APIs only accept Triton-specific MLIR wrappers. - # So, here we serialize an ir.Module to a file and then deserialize - # it as a tl_ir.module. - with tempfile.NamedTemporaryFile(mode="wb") as f: - ttir.operation.write_bytecode(f) - f.flush() - ttir = tl_ir.parse_mlir_module(f.name, context) - ttir.context = context - try: - metadata = {} - opt_ttir = hip_backend.make_ttir(ttir, metadata, hip_options) - ttgir = hip_backend.make_ttgir( - opt_ttir, - metadata, - hip_options - ) - except RuntimeError as e: - ttir.dump() - raise ValueError("TTIR->TTGIR pass failed!") from e - if hip_options.debug: - print(ttgir) - try: - llir = hip_backend.make_llir( - ttgir, - metadata, - hip_options - ) - except RuntimeError as e: - ttgir.dump() - raise ValueError("TTGIR->LLIR pass failed!") from e - shared_mem_bytes = metadata["shared"] - if hip_options.debug: - print(llir) - - amdgcn = hip_backend.make_amdgcn(llir, metadata, hip_options) - hsaco = hip_backend.make_hsaco(amdgcn, metadata, hip_options) - - if hip_options.debug: - print(x) - name = metadata["name"] - ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None - llir = str(llir) if _JAX_TRITON_DUMP_DIR else None - # cluster dims are NOT useful on hip backend. - # We just fill up with some value for API compatibility - cluster_dims = (0, 0, 0) - # Instead of passing hsaco which are "bytes", we first write - # to a file and then pass the "string" path. This is needed because - # nanobind doesn't automatically convert between bytes and string. - # https://github.com/wjakob/nanobind/discussions/137 - fd, hsaco_path = tempfile.mkstemp() - with os.fdopen(fd, "wb") as f: - f.write(hsaco) - return CompilationResult( - binary=hsaco_path, - name=name, - shared_mem_bytes=shared_mem_bytes, - cluster_dims=cluster_dims, - ttgir=ttgir, - llir=llir, - ) _COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache? 
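
For context on the hunk above: the removed `compile_ttir_to_hsaco_inplace` path wrote the compiled HSACO bytes to a temporary file and returned the file path, because nanobind does not implicitly convert between bytes and str across the binding boundary (https://github.com/wjakob/nanobind/discussions/137). A minimal sketch of that workaround, using illustrative names rather than the actual jax_triton internals:

    import os
    import tempfile

    def write_binary_to_path(binary: bytes) -> str:
        # Sketch of the temp-file workaround used by the removed ROCm path:
        # write the binary to disk and hand a path string (not raw bytes)
        # to the kernel-call machinery. Names here are illustrative only.
        fd, path = tempfile.mkstemp()
        with os.fdopen(fd, "wb") as f:
            f.write(binary)
        return path

The CUDA path never needed this, since it passes PTX text rather than a binary blob, which is why the post-patch `PtxCompilationResult` carries a `ptx` string directly.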
def get_or_create_triton_kernel( - backend_init_func, - platform, fn, arg_dtypes, scalar_args, @@ -419,29 +297,29 @@ def get_or_create_triton_kernel( kernel = _COMPILED_KERNEL_CACHE.get(cache_key) if kernel is None: - opts = { - "num_warps": num_warps, - "num_stages": num_stages, - "num_ctas": num_ctas, - "optimize_epilogue": False, - "debug": dump, - "enable_fp_fusion": enable_fp_fusion, - } - - backend = backend_init_func(device, compute_capability) - options = backend.parse_options(opts) - + target = cb.GPUTarget('cuda', compute_capability, 32) + cuda_backend = cb.CUDABackend(target) + cuda_options = cuda_backend.parse_options( + dict( + num_warps=num_warps, + num_stages=num_stages, + num_ctas=num_ctas, + optimize_epilogue=False, + debug=dump, + enable_fp_fusion=enable_fp_fusion, + ) + ) kernel_hash = abs(hash(cache_key)) if _JAX_TRITON_DUMP_DIR: os.makedirs(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}") with open(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/config", "w") as f: pprint.pprint(cache_key, stream=f) - pprint.pprint(options, stream=f) + pprint.pprint(cuda_options, stream=f) context = _triton.ir.context() _triton.ir.load_dialects(context) - backend.load_dialects(context) - codegen_fns = backend.get_codegen_implementation() + cuda_backend.load_dialects(context) + codegen_fns = cuda_backend.get_codegen_implementation() module = ( code_gen.ast_to_ttir( @@ -452,10 +330,10 @@ def get_or_create_triton_kernel( signature=signature, attrs=specialization_attr, ), - options=options, + options=cuda_options, codegen_fns=codegen_fns, context=context, - module_map=backend.get_module_map(), + module_map=cuda_backend.get_module_map(), ) if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args # Triton changes ASTSource.ast_to_ttir to include module_map. 
Handle @@ -468,21 +346,19 @@ def get_or_create_triton_kernel( signature=signature, attrs=specialization_attr, ), - options=options, + options=cuda_options, codegen_fns=codegen_fns, context=context, ) ) ttir = str(module) - compilation_result = compile_ttir_inplace( - module, - backend, - options, - compute_capability, - platform + compilation_result = compile_ttir_to_ptx_inplace( + module, + cuda_backend, + cuda_options, + compute_capability, ) - kernel_name = compilation_result.name if _JAX_TRITON_DUMP_DIR: with open( @@ -515,7 +391,7 @@ def get_or_create_triton_kernel( kernel_name, num_warps, compilation_result.shared_mem_bytes, - compilation_result.binary, + compilation_result.ptx, ttir, compute_capability, *compilation_result.cluster_dims, @@ -527,7 +403,6 @@ def get_or_create_triton_kernel( def triton_kernel_call_lowering( - backend_init_func, ctx, *array_args, fn, @@ -552,7 +427,6 @@ def triton_kernel_call_lowering( "`input_output_aliases` only supported on `jaxlib>=0.3.22" ) - kernel_call_name = name args = list(ctx.avals_in) arg_dtypes = list(map(get_triton_type, ctx.avals_in)) @@ -647,8 +521,6 @@ def prune_configs(configs, named_args, **kwargs): kernel_calls = [] for params in config_params: kernel, specialization_attr = get_or_create_triton_kernel( - backend_init_func, - ctx.module_context.platforms[0], fn, arg_dtypes, scalar_args, @@ -718,13 +590,9 @@ def prune_configs(configs, named_args, **kwargs): operand_output_aliases=dict(input_output_aliases), ).results -mlir.register_lowering(triton_kernel_call_p, - partial(triton_kernel_call_lowering, get_cuda_backend), - platform='cuda') -mlir.register_lowering(triton_kernel_call_p, - partial(triton_kernel_call_lowering, get_hip_backend), - platform='rocm') +mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering) + class ShapeDtype(Protocol): diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py index c8091a2..b921cf8 100644 --- a/tests/triton_call_test.py +++ b/tests/triton_call_test.py @@ -328,16 +328,16 @@ def test_kernel_cache_equivalent_kernels(self): x1, y1 = create_random_inputs([42]) x2, y2 = create_random_inputs([43]) - compile_ttir_inplace = jt.triton_lib.compile_ttir_inplace + compile_ttir_to_ptx_inplace = jt.triton_lib.compile_ttir_to_ptx_inplace call_count = [0] def my_compile(*args, **kwargs): call_count[0] += 1 - return compile_ttir_inplace(*args, **kwargs) + return compile_ttir_to_ptx_inplace(*args, **kwargs) with mock.patch.object( - jt.triton_lib, "compile_ttir_inplace", new=my_compile + jt.triton_lib, "compile_ttir_to_ptx_inplace", new=my_compile ): _ = fn1(x1, y1) self.assertEqual(call_count[0], 1)
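
For context on the test change above: only the patched symbol is renamed (`compile_ttir_inplace` back to `compile_ttir_to_ptx_inplace`); the pattern of counting compilations through `mock.patch.object` to observe kernel-cache hits is unchanged. A self-contained sketch of that counting pattern, using a stand-in module rather than `jt.triton_lib`:

    import types
    from unittest import mock

    # Stand-in for the module whose compile function the test wraps.
    fake_lib = types.SimpleNamespace(compile_fn=lambda x: x * 2)

    call_count = [0]
    real_compile = fake_lib.compile_fn

    def counting_compile(*args, **kwargs):
      # Delegate to the real function, but record how many times it ran.
      call_count[0] += 1
      return real_compile(*args, **kwargs)

    with mock.patch.object(fake_lib, "compile_fn", new=counting_compile):
      assert fake_lib.compile_fn(3) == 6  # first call goes through the wrapper
      assert fake_lib.compile_fn(4) == 8  # so does the second

    assert call_count[0] == 2

In the actual test, a second call with an equivalent kernel leaves the count unchanged; an unchanged count after the second call is how the cache hit is observed.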