diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index d3d272f6..680735ca 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -318,6 +318,7 @@ def get_or_create_triton_kernel( context = _triton.ir.context() _triton.ir.load_dialects(context) cuda_backend.load_dialects(context) + codegen_fns = cuda_backend.get_codegen_implementation() module = code_gen.ast_to_ttir( fn, @@ -328,6 +329,7 @@ def get_or_create_triton_kernel( attrs=specialization_attr, ), options=cuda_options, + codegen_fns=codegen_fns, context=context, ) ttir = str(module) @@ -429,7 +431,7 @@ def triton_kernel_call_lowering( # TODO(cjfj): Prune explicit `num_warps` / `num_stages`. prev_early_config_prune_fn = fn.early_config_prune - def prune_configs(configs, named_args): + def prune_configs(configs, named_args, **kwargs): pruned_configs = [] for config in configs: if config.pre_hook is not None: