From 6effbb5a30d2a0d350553ff9e29b6baefafb115d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 5 Feb 2024 02:35:56 -0800 Subject: [PATCH] compile_ttir_to_ptx_inplace no longer requires an IR context It was only used to convert an ir.Module to tl_ir.module. PiperOrigin-RevId: 604255805 --- jax_triton/triton_lib.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 842673e5..61150dd2 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -155,7 +155,6 @@ def aval_size_bytes(aval): def compile_ttir_to_ptx_inplace( ttir, - tl_context: tl_ir.Context, cuda_backend: cb.CUDABackend, cuda_options: cb.CUDAOptions, device: int = 0, @@ -165,14 +164,18 @@ def compile_ttir_to_ptx_inplace( if cuda_options.debug: print(ttir) if isinstance(ttir, ir.Module): + context = _triton.ir.context() + _triton.ir.load_dialects(context) + cuda_backend.load_dialects(context) + # Triton compilation APIs only accept Triton-specific MLIR wrappers. # So, here we serialize an ir.Module to a file and then deserialize # it as a tl_ir.module. with tempfile.NamedTemporaryFile(mode="wb") as f: ttir.operation.write_bytecode(f) f.flush() - ttir = tl_ir.parse_mlir_module(f.name, tl_context) - ttir.context = tl_context + ttir = tl_ir.parse_mlir_module(f.name, context) + ttir.context = context try: metadata = dict() opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options) @@ -310,7 +313,6 @@ def get_or_create_triton_kernel( ptx, kernel_name, shared_mem_bytes, compute_capability, cluster_dims = ( compile_ttir_to_ptx_inplace( module, - context, cuda_backend, cuda_options, device=device,