Skip to content

Commit

Permalink
Testing triton integration 2023-09-14
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 572808737
  • Loading branch information
The jax_triton Authors committed Oct 12, 2023
1 parent 06d3517 commit f41e408
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,30 @@ def compile_ttir_to_ptx_inplace(
print(ttir)
try:
ttir = tc.optimize_ttir(ttir, compute_capability)
ttgir = tc.ttir_to_ttgir(ttir, num_warps)
ttgir = tc.optimize_ttgir(ttgir, num_stages, compute_capability)
ttgir = tc.ttir_to_ttgir(
ttir, num_warps, num_ctas=1, arch=compute_capability
)
ttgir = tc.optimize_ttgir(
ttgir,
num_stages,
num_warps,
num_ctas=1,
arch=compute_capability,
cluster_info=_triton.ClusterInfo(),
enable_warp_specialization=False,
enable_persistent=False,
optimize_epilogue=False,
)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if dump:
print(ttgir)
extern_libs = {}
try:
llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability)
llir = tc.ttgir_to_llir(
ttgir, extern_libs, compute_capability, _triton.TMAInfos()
)
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
Expand Down

0 comments on commit f41e408

Please sign in to comment.