From b2ae82824def14fdcb780ece7a99551f711f2ed4 Mon Sep 17 00:00:00 2001
From: Mohammed Anany
Date: Fri, 12 Jan 2024 08:34:51 -0800
Subject: [PATCH] Import openai/triton from GitHub.

PiperOrigin-RevId: 597848585
---
 jax_triton/triton_lib.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index 8ef12e4a..4ef52c42 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -186,6 +186,7 @@ def get_arch_default_num_stages(device_type, capability):
 
 def compile_ttir_to_ptx_inplace(
     ttir,
+    tl_context: tl_ir.Context,
     cuda_backend: cb.CUDABackend,
     cuda_options: cb.CUDAOptions,
     device: int = 0,
@@ -198,8 +199,6 @@ def compile_ttir_to_ptx_inplace(
   # Triton compilation APIs only accept Triton-specific MLIR wrappers.
   # So, here we serialize an ir.Module to a file and then deserialize
   # it as a tl_ir.module.
-  tl_context = tl_ir.context()
-  tl_context.load_triton()
   with tempfile.NamedTemporaryFile(mode="wb") as f:
     ttir.operation.write_bytecode(f)
     f.flush()
@@ -229,7 +228,7 @@ def compile_ttir_to_ptx_inplace(
   except RuntimeError as e:
     ttgir.dump()
     raise ValueError("TTGIR->LLIR pass failed!") from e
-  shared_mem_bytes = _triton.translation.get_shared_memory_size(ttgir)
+  shared_mem_bytes = metadata["shared"]
   if cuda_options.debug:
     print(llir)
   ptx = cuda_backend.make_ptx(
@@ -322,6 +321,10 @@ def get_or_create_triton_kernel(
       )
   )
 
+  context = _triton.ir.context()
+  _triton.ir.load_dialects(context)
+  cuda_backend.load_dialects(context)
+
   module = code_gen.ast_to_ttir(
       fn,
       specialization=tc.ASTSource(
@@ -331,12 +334,14 @@ def get_or_create_triton_kernel(
           attrs=specialization_attr,
       ),
       options=cuda_options,
+      context=context,
   )
   ttir = str(module)  # `module` is compiled in-place, so copy TTIR here.
 
   ptx, kernel_name, shared_mem_bytes, compute_capability = (
       compile_ttir_to_ptx_inplace(
           module,
+          context,
           cuda_backend,
           cuda_options,
           device=device,
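
For reference, a minimal sketch of the call pattern this patch establishes: get_or_create_triton_kernel now owns the Triton MLIR context, loads the core and NVIDIA-backend dialects into it, and threads it through both ast_to_ttir and compile_ttir_to_ptx_inplace, which previously created its own context via tl_ir.context() / load_triton(). The aliases (_triton, tl_ir, cb, code_gen, tc) and the surrounding bindings (fn, signature, specialization_attr, cuda_backend, cuda_options, device) are assumed from the existing triton_lib.py, and the ASTSource fields elided between hunks are marked as assumptions; this is an illustration of the diff, not code verified against a particular Triton release.

  # Caller-owned MLIR context (sketch; names assumed from triton_lib.py):
  context = _triton.ir.context()
  _triton.ir.load_dialects(context)    # core Triton dialects
  cuda_backend.load_dialects(context)  # NVIDIA backend dialects

  module = code_gen.ast_to_ttir(
      fn,
      specialization=tc.ASTSource(
          fn=fn,                # assumed: this field is elided between hunks
          signature=signature,  # assumed: this field is elided between hunks
          attrs=specialization_attr,
      ),
      options=cuda_options,
      context=context,  # new keyword: lowering reuses the shared context
  )
  ttir = str(module)  # `module` is compiled in-place, so copy TTIR first

  ptx, kernel_name, shared_mem_bytes, compute_capability = (
      compile_ttir_to_ptx_inplace(
          module,
          context,  # new argument, replacing the internally created context
          cuda_backend,
          cuda_options,
          device=device,
      )
  )

Note also that inside compile_ttir_to_ptx_inplace, shared_mem_bytes is now read from the metadata dict populated during compilation (metadata["shared"]) rather than from the removed _triton.translation.get_shared_memory_size(ttgir) helper.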