From 319d182a0b2cc98ecb8b3560f0383f1a8726c868 Mon Sep 17 00:00:00 2001 From: The jax_triton Authors Date: Mon, 26 Aug 2024 02:17:27 -0700 Subject: [PATCH] Integrate Triton up to [b2de88f8](https://github.com/openai/triton/commits/b2de88f89c1ff8082c92165535e48dece55da392) PiperOrigin-RevId: 667508787 --- jax_triton/triton_lib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index c6f5bdf0..f057da1c 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -319,6 +319,7 @@ def get_or_create_triton_kernel( _triton.ir.load_dialects(context) cuda_backend.load_dialects(context) codegen_fns = cuda_backend.get_codegen_implementation() + module_map = cuda_backend.get_module_map() module = code_gen.ast_to_ttir( fn, @@ -331,6 +332,7 @@ def get_or_create_triton_kernel( options=cuda_options, codegen_fns=codegen_fns, context=context, + module_map=module_map ) ttir = str(module)