Skip to content

Commit

Permalink
compile_ttir_to_ptx_inplace no longer requires an IR context
Browse files Browse the repository at this point in the history
It was only used to convert an ir.Module to tl_ir.module.

PiperOrigin-RevId: 604255805
  • Loading branch information
superbobry authored and The jax_triton Authors committed Feb 5, 2024
1 parent 28ad476 commit 6effbb5
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def aval_size_bytes(aval):

def compile_ttir_to_ptx_inplace(
ttir,
tl_context: tl_ir.Context,
cuda_backend: cb.CUDABackend,
cuda_options: cb.CUDAOptions,
device: int = 0,
Expand All @@ -165,14 +164,18 @@ def compile_ttir_to_ptx_inplace(
if cuda_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
context = _triton.ir.context()
_triton.ir.load_dialects(context)
cuda_backend.load_dialects(context)

# Triton compilation APIs only accept Triton-specific MLIR wrappers.
# So, here we serialize an ir.Module to a file and then deserialize
# it as a tl_ir.module.
with tempfile.NamedTemporaryFile(mode="wb") as f:
ttir.operation.write_bytecode(f)
f.flush()
ttir = tl_ir.parse_mlir_module(f.name, tl_context)
ttir.context = tl_context
ttir = tl_ir.parse_mlir_module(f.name, context)
ttir.context = context
try:
metadata = dict()
opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
Expand Down Expand Up @@ -310,7 +313,6 @@ def get_or_create_triton_kernel(
ptx, kernel_name, shared_mem_bytes, compute_capability, cluster_dims = (
compile_ttir_to_ptx_inplace(
module,
context,
cuda_backend,
cuda_options,
device=device,
Expand Down

0 comments on commit 6effbb5

Please sign in to comment.