diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index d8b948a2b7..1cc88123d0 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -32,6 +32,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
     "TRITON_INTEL_ADVANCED_PATH",
+    "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
     "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
diff --git a/test/Conversion/intel/tritongpu_to_gen_dot.mlir b/test/Conversion/intel/tritongpu_to_gen_dot.mlir
index a2f2de0d39..7c489ff5b5 100644
--- a/test/Conversion/intel/tritongpu_to_gen_dot.mlir
+++ b/test/Conversion/intel/tritongpu_to_gen_dot.mlir
@@ -1,4 +1,5 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,NO-AGGRESSIVE-REUSE
+// RUN: TRITON_INTEL_AGGRESSIVE_DPAS_REUSE=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,AGGRESSIVE-REUSE

 #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1]}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
@@ -543,22 +544,39 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   // CHECK: %[[C_3_1:.*]] = llvm.insertelement %[[VAL_353]], %[[VAL_416]]{{\[}}%[[CST_7]] : i32] : vector<8xf32>
   // COM: Total 16 dpas ops unrolled.
-  // CHECK: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
-  // CHECK: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
-  // CHECK: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
-  // CHECK: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
-  // CHECK: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
-  // CHECK: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
-  // CHECK: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
-  // CHECK: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
-  // CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
-  // CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])
-  // CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
-  // CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
-  // CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
-  // CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
-  // CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
-  // CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+  // NO-AGGRESSIVE-REUSE: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
+  // NO-AGGRESSIVE-REUSE: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
+  // NO-AGGRESSIVE-REUSE: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
+  // NO-AGGRESSIVE-REUSE: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
+  // NO-AGGRESSIVE-REUSE: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
+  // NO-AGGRESSIVE-REUSE: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
+  // NO-AGGRESSIVE-REUSE: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
+  // NO-AGGRESSIVE-REUSE: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
+  // NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
+  // NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
+  // NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
+  // NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
+  // NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])
+  // NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
+  // NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
+  // NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+
+  // AGGRESSIVE-REUSE: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
+  // AGGRESSIVE-REUSE: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
+  // AGGRESSIVE-REUSE: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
+  // AGGRESSIVE-REUSE: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
+  // AGGRESSIVE-REUSE: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
+  // AGGRESSIVE-REUSE: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
+  // AGGRESSIVE-REUSE: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
+  // AGGRESSIVE-REUSE: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
+  // AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
+  // AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
+  // AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
+  // AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
+  // AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+  // AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
+  // AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
+  // AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])

   %0 = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<32x32xf16, #dot_operand_a> * tensor<32x32xf16, #dot_operand_b> -> tensor<32x32xf32, #dpas>
   tt.return
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
index 529e7f07a6..ecd8eb1140 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
@@ -191,14 +191,43 @@ class DotOpDPASConversionHelper {
     ArrayRef repCluster = dpasEncoding.getRepCluster();
     unsigned rank = repCluster.size();
+
+    auto innerLoop = [&](int b, int k, int outer, unsigned repNumM,
+                         unsigned repNumN, unsigned repInner,
+                         bool reverseLoop = false) {
+      auto body = [&](int b, int k, int outer, int inner) {
+        if (repNumM > repNumN)
+          generateDPASOp(b, inner, outer, k);
+        else
+          generateDPASOp(b, outer, inner, k);
+      };
+
+      if (reverseLoop) {
+        for (int inner = repInner - 1; inner >= 0; --inner)
+          body(b, k, outer, inner);
+        return;
+      }
+
+      for (int inner = 0; inner < repInner; ++inner)
+        body(b, k, outer, inner);
+    };
+
+    // Use the smaller of the two dimensions as the outer loop for better DPAS
+    // operand locality.
+    bool aggressiveReusing =
+        triton::tools::getBoolEnv("TRITON_INTEL_AGGRESSIVE_DPAS_REUSE");
+    unsigned repNumM = repM * repCluster[rank - 2];
+    unsigned repNumN = repN * repCluster[rank - 1];
+    unsigned repOuter = repNumM > repNumN ? repNumN : repNumM;
+    unsigned repInner = repNumM > repNumN ? repNumM : repNumN;
     for (int b = 0; b < repBatch; ++b)
       for (int k = 0; k < repK; ++k)
-        for (int m = 0; m < repM; ++m)
-          for (int n = 0; n < repN; ++n)
-            for (int repRow = 0; repRow < repCluster[rank - 2]; ++repRow)
-              for (int repCol = 0; repCol < repCluster[rank - 1]; ++repCol)
-                generateDPASOp(b, m * repCluster[rank - 2] + repRow,
-                               n * repCluster[rank - 1] + repCol, k);
+        for (int outer = 0; outer < repOuter; ++outer) {
+          // Reverse the inner loop direction on odd outer loop iterations when
+          // aggressively reusing DPAS operands.
+          bool reverseLoop = aggressiveReusing && ((outer % 2) == 1);
+          innerLoop(b, k, outer, repNumM, repNumN, repInner, reverseLoop);
+        }

     Value res = composeValuesToDotOperandLayoutStruct(fc, repBatch, repM, repN,
                                                       resElemTy);
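For context, the reordered loop nest emits the DPAS calls in a serpentine order when TRITON_INTEL_AGGRESSIVE_DPAS_REUSE is set: the smaller of the M/N repetition counts becomes the outer loop, and odd outer iterations walk the inner dimension backwards so adjacent calls share an A or B operand. The following is a minimal standalone sketch, not part of the patch: the printing generateDPASOp and the hard-coded repetition counts are illustrative stand-ins chosen to reproduce the emission order checked by the AGGRESSIVE-REUSE prefix in the test above.

// Standalone sketch of the serpentine DPAS emission order (assumed values,
// not the patched code). Compile with any C++17 compiler and compare the
// printed order against the AGGRESSIVE-REUSE check lines.
#include <algorithm>
#include <cstdio>

// Stand-in for the real generateDPASOp: only records the tile indices.
static void generateDPASOp(int b, int m, int n, int k) {
  std::printf("dpas b=%d m=%d n=%d k=%d\n", b, m, n, k);
}

int main() {
  // Mimics repM * repCluster[rank - 2] = 4 and repN * repCluster[rank - 1] = 2
  // for the 32x32 f16 test case; repBatch = 1, repK = 2.
  const unsigned repBatch = 1, repK = 2;
  const unsigned repNumM = 4, repNumN = 2;
  const bool aggressiveReuse = true; // TRITON_INTEL_AGGRESSIVE_DPAS_REUSE=1
  const unsigned repOuter = std::min(repNumM, repNumN);
  const unsigned repInner = std::max(repNumM, repNumN);

  for (unsigned b = 0; b < repBatch; ++b)
    for (unsigned k = 0; k < repK; ++k)
      for (unsigned outer = 0; outer < repOuter; ++outer) {
        // Serpentine order: odd outer iterations walk the inner dimension
        // backwards, so the first DPAS of this column reuses the A operand
        // of the last DPAS of the previous column.
        const bool reverse = aggressiveReuse && (outer % 2 == 1);
        for (unsigned step = 0; step < repInner; ++step) {
          const unsigned inner = reverse ? repInner - 1 - step : step;
          if (repNumM > repNumN)
            generateDPASOp(b, inner, outer, k); // outer indexes N, inner M
          else
            generateDPASOp(b, outer, inner, k); // outer indexes M, inner N
        }
      }
  return 0;
}

With these assumed counts the sketch prints, for k = 0, the column n = 0 with m = 0..3 followed by n = 1 with m = 3..0, which is the same order as the AGGRESSIVE-REUSE check lines; without the reversal it degenerates to the NO-AGGRESSIVE-REUSE order.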