diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index c6f5bdf0..bcdafb85 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -20,6 +20,7 @@ import copy import dataclasses import functools +import inspect import os import pprint import tempfile @@ -320,17 +321,35 @@ def get_or_create_triton_kernel( cuda_backend.load_dialects(context) codegen_fns = cuda_backend.get_codegen_implementation() - module = code_gen.ast_to_ttir( - fn, - specialization=tc.ASTSource( + module = ( + code_gen.ast_to_ttir( + fn, + specialization=tc.ASTSource( + fn, + constants=constants, + signature=signature, + attrs=specialization_attr, + ), + options=cuda_options, + codegen_fns=codegen_fns, + context=context, + module_map=cuda_backend.get_module_map(), + ) + if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args + # Triton changes ASTSource.ast_to_ttir to include module_map. Handle + # backward compatibility here. + else code_gen.ast_to_ttir( fn, - constants=constants, - signature=signature, - attrs=specialization_attr, - ), - options=cuda_options, - codegen_fns=codegen_fns, - context=context, + specialization=tc.ASTSource( + fn, + constants=constants, + signature=signature, + attrs=specialization_attr, + ), + options=cuda_options, + codegen_fns=codegen_fns, + context=context, + ) ) ttir = str(module)