Skip to content

Commit

Permalink
Integrate Triton up to [8e0c7b42](https://github.com/openai/triton/co…
Browse files Browse the repository at this point in the history
  • Loading branch information
Moerafaat authored and The jax_triton Authors committed May 2, 2024
1 parent 7c08ae4 commit ebdbde9
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def get_or_create_triton_kernel(
context = _triton.ir.context()
_triton.ir.load_dialects(context)
cuda_backend.load_dialects(context)
codegen_fns = cuda_backend.get_codegen_implementation()

module = code_gen.ast_to_ttir(
fn,
Expand All @@ -328,6 +329,7 @@ def get_or_create_triton_kernel(
attrs=specialization_attr,
),
options=cuda_options,
codegen_fns=codegen_fns,
context=context,
)
ttir = str(module)
Expand Down Expand Up @@ -429,7 +431,7 @@ def triton_kernel_call_lowering(
# TODO(cjfj): Prune explicit `num_warps` / `num_stages`.
prev_early_config_prune_fn = fn.early_config_prune

def prune_configs(configs, named_args):
def prune_configs(configs, named_args, **kwargs):
pruned_configs = []
for config in configs:
if config.pre_hook is not None:
Expand Down

0 comments on commit ebdbde9

Please sign in to comment.