Skip to content

Commit

Permalink
Integrate Triton up to [b2de88f8](https://github.com/openai/triton/commit/b2de88f8)
Browse files Browse the repository at this point in the history
  • Loading branch information
The jax_triton Authors committed Sep 5, 2024
1 parent b3a2b6c commit 67fe01a
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import copy
import dataclasses
import functools
import inspect
import os
import pprint
import tempfile
Expand Down Expand Up @@ -320,17 +321,35 @@ def get_or_create_triton_kernel(
cuda_backend.load_dialects(context)
codegen_fns = cuda_backend.get_codegen_implementation()

module = code_gen.ast_to_ttir(
fn,
specialization=tc.ASTSource(
module = (
code_gen.ast_to_ttir(
fn,
specialization=tc.ASTSource(
fn,
constants=constants,
signature=signature,
attrs=specialization_attr,
),
options=cuda_options,
codegen_fns=codegen_fns,
context=context,
module_map=cuda_backend.get_module_map(),
)
if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
# Triton changes ASTSource.ast_to_ttir to include module_map. Handle
# backward compatibility here.
else code_gen.ast_to_ttir(
fn,
constants=constants,
signature=signature,
attrs=specialization_attr,
),
options=cuda_options,
codegen_fns=codegen_fns,
context=context,
specialization=tc.ASTSource(
fn,
constants=constants,
signature=signature,
attrs=specialization_attr,
),
options=cuda_options,
codegen_fns=codegen_fns,
context=context,
)
)
ttir = str(module)

Expand Down

0 comments on commit 67fe01a

Please sign in to comment.