Skip to content

Commit

Permalink
Bumping Triton version
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 573804972
  • Loading branch information
Moerafaat authored and The jax_triton Authors committed Oct 16, 2023
1 parent f41e408 commit f162738
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,19 @@ def compile_ttir_to_ptx_inplace(
if dump:
print(ttir)
try:
ttir = tc.optimize_ttir(ttir, compute_capability)
target = tc.compiler.CudaTargetDescriptor(capability=compute_capability,
num_warps=num_warps,
enable_fp_fusion=True)
ttir = tc.optimize_ttir(ttir, target)
ttgir = tc.ttir_to_ttgir(
ttir, num_warps, num_ctas=1, arch=compute_capability
ttir, num_warps, num_ctas=1, target=target
)
ttgir = tc.optimize_ttgir(
ttgir,
num_stages,
num_warps,
num_ctas=1,
arch=compute_capability,
target=target,
cluster_info=_triton.ClusterInfo(),
enable_warp_specialization=False,
enable_persistent=False,
Expand All @@ -185,15 +188,15 @@ def compile_ttir_to_ptx_inplace(
extern_libs = {}
try:
llir = tc.ttgir_to_llir(
ttgir, extern_libs, compute_capability, _triton.TMAInfos()
ttgir, extern_libs, target, _triton.TMAInfos()
)
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
if dump:
print(llir)
ptx = tc.llir_to_ptx(llir, compute_capability)
ptx = tc.llir_to_ptx(llir, target)
if dump:
print(ptx)
name = ptx_get_kernel_name(ptx)
Expand Down Expand Up @@ -247,7 +250,7 @@ def get_or_create_triton_kernel(
device = 0
arch = triton_kernel_call_lib.get_compute_capability(device)
module = code_gen.ast_to_ttir(
fn, signature, specialization, constants, debug=dump, arch=arch
fn, signature, specialization, constants, debug=dump, target=arch
)
ttir = str(module) # `module`` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, compute_capability = (
Expand Down

0 comments on commit f162738

Please sign in to comment.