diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 5ee1c4af..22bf6d82 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -162,16 +162,19 @@ def compile_ttir_to_ptx_inplace( if dump: print(ttir) try: - ttir = tc.optimize_ttir(ttir, compute_capability) + target = tc.compiler.CudaTargetDescriptor(capability=compute_capability, + num_warps=num_warps, + enable_fp_fusion=True) + ttir = tc.optimize_ttir(ttir, target) ttgir = tc.ttir_to_ttgir( - ttir, num_warps, num_ctas=1, arch=compute_capability + ttir, num_warps, num_ctas=1, target=target ) ttgir = tc.optimize_ttgir( ttgir, num_stages, num_warps, num_ctas=1, - arch=compute_capability, + target=target, cluster_info=_triton.ClusterInfo(), enable_warp_specialization=False, enable_persistent=False, @@ -185,7 +188,7 @@ def compile_ttir_to_ptx_inplace( extern_libs = {} try: llir = tc.ttgir_to_llir( - ttgir, extern_libs, compute_capability, _triton.TMAInfos() + ttgir, extern_libs, target, _triton.TMAInfos() ) except RuntimeError as e: ttgir.dump() @@ -193,7 +196,7 @@ def compile_ttir_to_ptx_inplace( shared_mem_bytes = _triton.get_shared_memory_size(ttgir) if dump: print(llir) - ptx = tc.llir_to_ptx(llir, compute_capability) + ptx = tc.llir_to_ptx(llir, target) if dump: print(ptx) name = ptx_get_kernel_name(ptx) @@ -247,7 +250,7 @@ def get_or_create_triton_kernel( device = 0 arch = triton_kernel_call_lib.get_compute_capability(device) module = code_gen.ast_to_ttir( - fn, signature, specialization, constants, debug=dump, arch=arch + fn, signature, specialization, constants, debug=dump, target=arch ) ttir = str(module) # `module`` is compiled in-place, so copy TTIR here. ptx, kernel_name, shared_mem_bytes, compute_capability = (