Adjust the DPAS instruction order in TritonIntelGPUToLLVM (#2627)
Adjust the order of the generated DPAS instructions to improve DPAS
operand locality in the generated LLVM IR.

The new logic orders the generated DPAS instructions so that the A or B
operand is reused across multiple consecutive DPAS instructions.

Add a new environment variable `TRITON_INTEL_AGGRESSIVE_DPAS_REUSE` that
enables an even more aggressive, experimental DPAS instruction order. The
aggressive order will become the default once the IGC scalar backend can
reliably schedule it into the best-performing kernel.
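
For illustration, consider a hypothetical tile with 4 repetitions along M
and 2 along N. The default order emits

    dpas(A0, B0), dpas(A1, B0), dpas(A2, B0), dpas(A3, B0),
    dpas(A0, B1), dpas(A1, B1), dpas(A2, B1), dpas(A3, B1)

so each B operand is reused by four consecutive instructions, while the
aggressive order additionally reverses every odd column:

    dpas(A0, B0), dpas(A1, B0), dpas(A2, B0), dpas(A3, B0),
    dpas(A3, B1), dpas(A2, B1), dpas(A1, B1), dpas(A0, B1)

so the A operand at the turn (A3) also stays live across the column switch.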

---------

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
Co-authored-by: Tiotto, Ettore <ettore.tiotto@intel.com>
chengjunlu and etiotto authored Nov 5, 2024
1 parent 290bfa9 commit 29c0ece
Showing 3 changed files with 71 additions and 23 deletions.
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -32,6 +32,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
     "TRITON_INTEL_ADVANCED_PATH",
+    "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
     "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
52 changes: 35 additions & 17 deletions test/Conversion/intel/tritongpu_to_gen_dot.mlir
@@ -1,4 +1,5 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,NO-AGGRESSIVE-REUSE
+// RUN: TRITON_INTEL_AGGRESSIVE_DPAS_REUSE=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,AGGRESSIVE-REUSE
 
 #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1]}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
@@ -543,22 +544,39 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 // CHECK: %[[C_3_1:.*]] = llvm.insertelement %[[VAL_353]], %[[VAL_416]]{{\[}}%[[CST_7]] : i32] : vector<8xf32>
 
 // COM: Total 16 dpas ops unrolled.
-// CHECK: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
-// CHECK: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
-// CHECK: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
-// CHECK: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
-// CHECK: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
-// CHECK: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
-// CHECK: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
-// CHECK: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
+// NO-AGGRESSIVE-REUSE: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
+// NO-AGGRESSIVE-REUSE: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
+// NO-AGGRESSIVE-REUSE: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+
+// AGGRESSIVE-REUSE: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
+// AGGRESSIVE-REUSE: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
+// AGGRESSIVE-REUSE: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
+// AGGRESSIVE-REUSE: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
+// AGGRESSIVE-REUSE: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
+// AGGRESSIVE-REUSE: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
+// AGGRESSIVE-REUSE: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
+// AGGRESSIVE-REUSE: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])
 
 %0 = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<32x32xf16, #dot_operand_a> * tensor<32x32xf16, #dot_operand_b> -> tensor<32x32xf32, #dpas>
 tt.return
41 changes: 35 additions & 6 deletions third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
@@ -191,14 +191,43 @@ class DotOpDPASConversionHelper {
 
     ArrayRef<unsigned> repCluster = dpasEncoding.getRepCluster();
     unsigned rank = repCluster.size();
+
+    auto innerLoop = [&](int b, int k, int outer, unsigned repNumM,
+                         unsigned repNumN, unsigned repInner,
+                         bool reverseLoop = false) {
+      auto body = [&](int b, int k, int outer, int inner) {
+        if (repNumM > repNumN)
+          generateDPASOp(b, inner, outer, k);
+        else
+          generateDPASOp(b, outer, inner, k);
+      };
+
+      if (reverseLoop) {
+        for (int inner = repInner - 1; inner >= 0; --inner)
+          body(b, k, outer, inner);
+        return;
+      }
+
+      for (int inner = 0; inner < repInner; ++inner)
+        body(b, k, outer, inner);
+    };
+
+    // Use the smaller of the two dimensions as the outer loop for better DPAS
+    // operand locality.
+    bool aggressiveReusing =
+        triton::tools::getBoolEnv("TRITON_INTEL_AGGRESSIVE_DPAS_REUSE");
+    unsigned repNumM = repM * repCluster[rank - 2];
+    unsigned repNumN = repN * repCluster[rank - 1];
+    unsigned repOuter = repNumM > repNumN ? repNumN : repNumM;
+    unsigned repInner = repNumM > repNumN ? repNumM : repNumN;
     for (int b = 0; b < repBatch; ++b)
       for (int k = 0; k < repK; ++k)
-        for (int m = 0; m < repM; ++m)
-          for (int n = 0; n < repN; ++n)
-            for (int repRow = 0; repRow < repCluster[rank - 2]; ++repRow)
-              for (int repCol = 0; repCol < repCluster[rank - 1]; ++repCol)
-                generateDPASOp(b, m * repCluster[rank - 2] + repRow,
-                               n * repCluster[rank - 1] + repCol, k);
+        for (int outer = 0; outer < repOuter; ++outer) {
+          // Reverse the inner loop direction on odd outer-loop iterations
+          // when aggressively reusing DPAS operands.
+          bool reverseLoop = aggressiveReusing && ((outer % 2) == 1);
+          innerLoop(b, k, outer, repNumM, repNumN, repInner, reverseLoop);
+        }
 
     Value res = composeValuesToDotOperandLayoutStruct(fc, repBatch, repM, repN,
                                                       resElemTy);
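
As a rough, self-contained illustration of the new loop structure (not the
pass code itself: `generateDPASOp` is replaced by a print, the shape is
hard-coded to the 32x32x32 `tt.dot` exercised by the test above, and the
env-var check is simplified relative to `triton::tools::getBoolEnv`):

```cpp
#include <cstdio>
#include <cstdlib>

// Stand-in for generateDPASOp: print the (b, m, n, k) emission order.
static void generateDPASOp(int b, int m, int n, int k) {
  std::printf("dpas b=%d m=%d n=%d k=%d\n", b, m, n, k);
}

int main() {
  // Shape of the 32x32x32 test case: repBatch = 1, repK = 2, and
  // repNumM = 4, repNumN = 2 after folding repCluster into repM/repN.
  const int repBatch = 1, repK = 2;
  const int repNumM = 4, repNumN = 2;
  // Simplified toggle; the pass uses triton::tools::getBoolEnv.
  const bool aggressiveReusing =
      std::getenv("TRITON_INTEL_AGGRESSIVE_DPAS_REUSE") != nullptr;

  // The outer loop runs over the smaller of the M/N repetition counts, so
  // the operand it indexes (B when repNumM > repNumN) stays fixed and is
  // reused by repInner consecutive DPAS ops.
  const int repOuter = repNumM > repNumN ? repNumN : repNumM;
  const int repInner = repNumM > repNumN ? repNumM : repNumN;

  auto body = [&](int b, int k, int outer, int inner) {
    if (repNumM > repNumN)
      generateDPASOp(b, inner, outer, k); // inner indexes M, outer indexes N
    else
      generateDPASOp(b, outer, inner, k); // outer indexes M, inner indexes N
  };

  for (int b = 0; b < repBatch; ++b)
    for (int k = 0; k < repK; ++k)
      for (int outer = 0; outer < repOuter; ++outer) {
        // Serpentine order: reverse the inner loop on odd outer iterations
        // so the operand at the turn stays live across the switch.
        const bool reverseLoop = aggressiveReusing && (outer % 2 == 1);
        if (reverseLoop)
          for (int inner = repInner - 1; inner >= 0; --inner)
            body(b, k, outer, inner);
        else
          for (int inner = 0; inner < repInner; ++inner)
            body(b, k, outer, inner);
      }
  return 0;
}
```

Run as-is it prints the NO-AGGRESSIVE-REUSE order checked in the test above;
run with `TRITON_INTEL_AGGRESSIVE_DPAS_REUSE=1` it prints the serpentine
AGGRESSIVE-REUSE order.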
