From ebdbde90c8129eb3e0471ee2d49f8281a9ed24f5 Mon Sep 17 00:00:00 2001
From: Mohammed Anany
Date: Thu, 2 May 2024 01:41:56 -0700
Subject: [PATCH] Integrate Triton up to
 [8e0c7b42](https://github.com/openai/triton/commits/8e0c7b425ac149c43183de966ffa423fd46e4762)

PiperOrigin-RevId: 629987238
---
 jax_triton/triton_lib.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index d3d272f6..680735ca 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -318,6 +318,7 @@ def get_or_create_triton_kernel(
   context = _triton.ir.context()
   _triton.ir.load_dialects(context)
   cuda_backend.load_dialects(context)
+  codegen_fns = cuda_backend.get_codegen_implementation()
 
   module = code_gen.ast_to_ttir(
       fn,
@@ -328,6 +329,7 @@ def get_or_create_triton_kernel(
           attrs=specialization_attr,
       ),
       options=cuda_options,
+      codegen_fns=codegen_fns,
       context=context,
   )
   ttir = str(module)
@@ -429,7 +431,7 @@ def triton_kernel_call_lowering(
     # TODO(cjfj): Prune explicit `num_warps` / `num_stages`.
     prev_early_config_prune_fn = fn.early_config_prune
 
-    def prune_configs(configs, named_args):
+    def prune_configs(configs, named_args, **kwargs):
       pruned_configs = []
       for config in configs:
         if config.pre_hook is not None: