diff --git a/test/TritonGEN/tritongen-addr-payload-opt.mlir b/test/TritonGEN/tritongen-addr-payload-opt.mlir
index d5c00c0bd4..c7109e0ff8 100644
--- a/test/TritonGEN/tritongen-addr-payload-opt.mlir
+++ b/test/TritonGEN/tritongen-addr-payload-opt.mlir
@@ -47,8 +47,8 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 33792
     %27 = arith.cmpi slt, %23, %arg5 : i64
     cf.cond_br %27, ^bb2, ^bb3
   ^bb2:
-    %28 = tt.load %25 {boundaryCheck = array} : !tt.ptr>
-    %29 = tt.load %26 {boundaryCheck = array} : !tt.ptr>
+    %28 = tt.load %25 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
+    %29 = tt.load %26 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
     %30 = tt.dot %28, %29, %24, inputPrecision = tf32 : tensor<8x16xf16, #dot0> * tensor<16x8xf16, #dot1> -> tensor<8x8xf32, #mma>
     %31 = tt.advance %25, [%c0_i32, %c32_i32] : >
     %32 = tt.advance %26, [%c32_i32, %c0_i32] : >
diff --git a/test/TritonIntelGPU/load-to-llvm-2dload.mlir b/test/TritonIntelGPU/load-to-llvm-2dload.mlir
index 03f86376c6..524e93d869 100644
--- a/test/TritonIntelGPU/load-to-llvm-2dload.mlir
+++ b/test/TritonIntelGPU/load-to-llvm-2dload.mlir
@@ -17,8 +17,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
     // CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
     // CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
     // CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
-    %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
-    %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
+    %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
+    %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
     %D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
     %0 = triton_gpu.convert_layout %D : tensor<64x64xf32, #dpas> -> tensor<64x64xf32, #blocked>
     tt.return
@@ -40,8 +40,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
     // CHECK-COUNT-2: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_32b_8r8x2cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
     // CHECK-COUNT-1: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_32b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
     // CHECK-COUNT-4: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv4_fDv8_fS0_({{.*}}) {{.*}} : (vector<4xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-    %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
-    %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
+    %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
+    %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
     %D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf32, #dot0> * tensor<32x64xf32, #dot1> -> tensor<64x64xf32, #dpas>
     tt.return
   }
@@ -109,7 +109,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
     // CHECK: %[[VAL_53:.*]] = llvm.insertelement %[[offsetY]], %[[VAL_52]]{{\[}}%[[VAL_49]] : i32] : vector<2xi32>
     %ptrA = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
     // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x2cPU3AS1viiiDv2_iPt(%[[BASE]], %[[HEIGHT]], %[[WIDTH_i32]], %[[ROW_STRIDE_IN_BYTES]], %[[VAL_53]], %[[VAL_48]])
-    %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
+    %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
     %B = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot1>
     %C = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
     %D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<32x32xf16, #dot0> * tensor<32x32xf16, #dot1> -> tensor<32x32xf32, #dpas>
@@ -179,10 +179,75 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
     // CHECK: %[[VAL_52:.*]] = llvm.insertelement %[[VAL_42]], %[[VAL_51]]{{\[}}%[[VAL_48]] : i32] : vector<2xi32>
     %ptrB = tt.make_tensor_ptr %arg1, [%arg4, %arg3], [%arg7, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
     // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x2cPU3AS1viiiDv2_iPj(%[[BASE]], %[[HEIGHT]], %[[WIDTH_i32]], %[[ROW_STRIDE_IN_BYTES]], %[[VAL_52]], %[[VAL_47]])
-    %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
+    %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
     %A = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot0>
     %C = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
     %D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<32x32xf16, #dot0> * tensor<32x32xf16, #dot1> -> tensor<32x32xf32, #dpas>
     tt.return
   }
 }
+
+// -----
+
+// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 2], A = [8, 16], B = [16, 32], C = [8, 32]}>
+#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
+module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @column_major_dot_b
+  tt.func public @column_major_dot_b(%arg0: !tt.ptr, %col_stride: i64) {
+    %c64_i32 = arith.constant 64 : i32
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i64 = arith.constant 32 : i64
+    %21 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array} : >>
+    // CHECK: llvm.ptrtoint
+    // CHECK: %[[ELEM_BITS:.*]] = llvm.mlir.constant(32 : i32) : i32
+    // CHECK: %[[TILE_WIDTH:.*]] = llvm.mlir.constant(8 : i32) : i32
+    // CHECK: %[[TILE_HEIGHT:.*]] = llvm.mlir.constant(32 : i32) : i32
+    // CHECK: %[[VBLOCKS:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[TRANSPOSE:.*]] = llvm.mlir.constant(true) : i1
+    // CHECK: %[[VNNI:.*]] = llvm.mlir.constant(false) : i1
+    // CHECK: %[[VAL_68:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ELEM_BITS]], %[[TILE_WIDTH]], %[[TILE_HEIGHT]], %[[VBLOCKS]], %[[TRANSPOSE]], %[[VNNI]], {{.*}})
+    // CHECK: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+    // CHECK: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+    // CHECK: %[[VAL_103:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32
+    // CHECK: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+    // CHECK: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+    // CHECK: %[[VAL_138:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32
+    // CHECK: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+    // CHECK: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+    // CHECK: %[[VAL_173:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32
+    // CHECK: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+    // CHECK: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+    %45 = tt.load %21 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK: llvm.func spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
+module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @column_major_dot_b
+  tt.func public @column_major_dot_b(%arg0: !tt.ptr, %col_stride: i64) {
+    %c64_i32 = arith.constant 64 : i32
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i64 = arith.constant 32 : i64
+    %21 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array} : >>
+    // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
+    // CHECK: llvm.shufflevector {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
+    // CHECK: llvm.shufflevector {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
+    // CHECK: llvm.shufflevector {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
+    // CHECK: llvm.shufflevector {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32>
+    %45 = tt.load %21 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr>>
+    tt.return
+  }
+}
diff --git a/test/TritonIntelGPU/prefetch-to-llvm.mlir b/test/TritonIntelGPU/prefetch-to-llvm.mlir
index 45fe7a1bb7..d8d2edb01f 100644
--- a/test/TritonIntelGPU/prefetch-to-llvm.mlir
+++ b/test/TritonIntelGPU/prefetch-to-llvm.mlir
@@ -48,8 +48,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
     %ptrB = tt.make_tensor_ptr %arg1, [%arg4, %arg3], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
     triton_intel_gpu.prefetch %ptrA {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr>
     triton_intel_gpu.prefetch %ptrB {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr>
-    %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
-    %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
+    %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
+    %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
     %D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<32x16xf16, #dot0> * tensor<16x32xf16, #dot1> -> tensor<32x32xf32, #dpas>
     %0 = triton_gpu.convert_layout %D {allocation.offset = 0 : i32} : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
     tt.return
diff --git a/test/TritonIntelGPU/store-to-llvm-2dstore.mlir b/test/TritonIntelGPU/store-to-llvm-2dstore.mlir
index da5f1032a9..a801dfa4c6 100644
--- a/test/TritonIntelGPU/store-to-llvm-2dstore.mlir
+++ b/test/TritonIntelGPU/store-to-llvm-2dstore.mlir
@@ -17,8 +17,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
     %6 = tt.make_tensor_ptr %arg1, [%arg3, %arg4], [%arg7, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
     %7 = tt.advance %3, [%c64_i32, %c-32_i32] : >
     %8 = tt.advance %7, [%c-64_i32, %c32_i32] : >
-    %9 = tt.load %8 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
-    %10 = tt.load %6 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
+    %9 = tt.load %8 {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
+    %10 = tt.load %6 {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
     %11 = tt.dot %9, %10, %cst, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
     %12 = arith.truncf %11#0 : tensor<64x64xf32, #dpas> to tensor<64x64xf16, #dpas>
     %13 = tt.make_tensor_ptr %arg2, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
index a1ef0767c7..e71fb96ea2 100644
--- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
+++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
@@ -214,6 +214,9 @@ static bool isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
   if (op.getVnniTransform())
     return true;
 
+  if (op.getTranspose() && op.getTileHeight() != 16)
+    return false;
+
   uint32_t tileWidth = op.getTileWidth();
   switch (op.getElemSizeInBits()) {
   case 8:
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 5b897ee1cc..fe84f7214f 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -366,7 +366,25 @@ struct LoadOpConversion
     if (!hasDotDpasEncoding(tensorType))
       return failure();
 
+    Attribute blockIOAttr =
+        op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
+    if (!blockIOAttr)
+      return failure();
+
+    // Only support rank 2 dot layout, either row major or column major.
+    StringRef memoryLayoutInfo = cast<StringAttr>(blockIOAttr).getValue();
+    assert((memoryLayoutInfo == "row_major" ||
+            memoryLayoutInfo == "column_major") &&
+           "Only row_major or column_major is supported");
+    const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
+    DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
+    auto dotOrder = dotLayout.getThreadOrder();
+    const bool valueRowMajor = (dotOrder[0] == 1 && dotOrder[1] == 0);
+    assert((valueRowMajor || (dotOrder[0] == 0 && dotOrder[1] == 1)) &&
+           "Only row_major or column_major is allowed");
+    const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
+    auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
 
     const unsigned opIdx = dotLayout.getOpIdx();
@@ -376,14 +394,15 @@ struct LoadOpConversion
     SmallVector numReps =
         dpasLayout.getDPASRepetitions(tensorShape, opIdx);
     const SmallVector warpsPerCTA = dpasLayout.getWarpsPerCTA();
-    SmallVector order = triton::gpu::getOrder(dpasLayout);
+    SmallVector dpasOrder = triton::gpu::getOrder(dpasLayout);
     int threadsPerWarp = triton::gpu::getWarpSize(dpasLayout);
 
     Value warpId = rewriter.create(
         loc, i32_ty,
         rewriter.create(loc, /*upperBound=*/nullptr));
+
     SmallVector multiDimWarpId =
-        delinearize(rewriter, loc, warpId, warpsPerCTA, order);
+        delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
 
     bool isOperandA = (opIdx == 0);
     SmallVector dpasInstShape = isOperandA
@@ -445,48 +464,64 @@
                 offsetBaseY] =
         getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
 
-    // Load the operand.
-    unsigned numRepOuter = numReps[opIdx];
-    unsigned numRepK = numReps[!opIdx];
-
-    unsigned tileHeight;
-    unsigned vBlocks;
+    unsigned tileWidth = elemsPerDPASInst[dotOrder[0]];
+    unsigned tileHeight = elemsPerDPASInst[dotOrder[1]];
+    unsigned vBlocks = 1;
     unsigned numOperandsOuterDimPerLoad = 1;
-    unsigned numOperandsKDimPerLoad = 1;
-    // dot shape [M, N] = [M, K] * [K, N]
-    if (isOperandA) {
-      // Use the warp shape as the tileHeight to load multiple operands A in M
-      // dim.
-      tileHeight = warpShape[0];
-      assert(tileHeight <= 32 && "invalid tile height.");
-      numOperandsOuterDimPerLoad = repCluster[0];
-
-      // PVC 2D load supports 64 bytes per row at most.
-      unsigned totalBytesPerRowPerDPASOp =
-          elemsPerDPASInst[1] * eltTy.getIntOrFloatBitWidth() / 8;
-      vBlocks = std::min(numRepK, 64 / totalBytesPerRowPerDPASOp);
-      numOperandsKDimPerLoad = vBlocks;
+    unsigned numOperandsInnerDimPerLoad = 1;
+
+    unsigned numOperandsPer2DLoadM, numOperandsPer2DloadN;
+    if (!isTransposeRequired) {
+      numOperandsPer2DLoadM = isOperandA ? repCluster[opIdx] : numReps[!opIdx];
+      numOperandsPer2DloadN = isOperandA ? numReps[!opIdx] : repCluster[opIdx];
     } else {
-      // PVC 2D load supports 32 rows at most. Load multiple operands B in K
-      // dim.
-      numOperandsKDimPerLoad = std::min(numRepK, 32 / elemsPerDPASInst[0]);
-      tileHeight = elemsPerDPASInst[0] * numOperandsKDimPerLoad;
-
-      // PVC 2D load supports 64 bytes per row at most.
-      unsigned totalBytesPerRowPerDPASOp =
-          elemsPerDPASInst[1] * eltTy.getIntOrFloatBitWidth() / 8;
-      // Use block array length to load multiple operands B in N dim if
-      // possible.
-      vBlocks = std::min(repCluster[1], 64 / totalBytesPerRowPerDPASOp);
-      numOperandsOuterDimPerLoad = vBlocks;
+      if (isOperandA)
+        return op.emitOpError("Transposing load doesn't support dot A layout.");
+
+      if (!usePackedType)
+        return op.emitOpError(
+            "Transposing load doesn't support un-pack-able dot B layout.");
+
+      std::swap(tileHeight, tileWidth);
+
+      // We can decompose the matrix returned by a transposed large 2D load
+      // when threads per warp < column size. Otherwise we have to load one
+      // operand per instruction.
+      // Note: tileHeight and numOperandsPer2DLoadM now refer to the column
+      // size.
+      numOperandsPer2DLoadM =
+          (threadsPerWarp <= tileHeight) ? repCluster[1] : 1;
+      // The transposed 2D load only supports one operand per instruction on
+      // the column (vBlocks = 1).
+      numOperandsPer2DloadN = 1;
     }
 
+    // PVC 2D load supports 32 rows at most. Load multiple dot operands by
+    // enlarging the tileHeight.
+    numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight);
+    tileHeight = tileHeight * numOperandsPer2DLoadM;
+
+    // PVC 2D load supports 64 bytes per row at most. Load multiple dot
+    // operands by enlarging vBlocks.
+    unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8;
+    numOperandsPer2DloadN =
+        std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
+    vBlocks = numOperandsPer2DloadN;
+
+    numOperandsOuterDimPerLoad =
+        isOperandA ? numOperandsPer2DLoadM : numOperandsPer2DloadN;
+    numOperandsInnerDimPerLoad =
+        isOperandA ? numOperandsPer2DloadN : numOperandsPer2DLoadM;
+
+    if (isTransposeRequired)
+      std::swap(numOperandsOuterDimPerLoad, numOperandsInnerDimPerLoad);
+
     unsigned numLoadPerOutRepCluster =
         mlir::ceil(repCluster[opIdx], numOperandsOuterDimPerLoad);
 
     unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst *
                                 numOperandsOuterDimPerLoad *
-                                numOperandsKDimPerLoad;
+                                numOperandsInnerDimPerLoad;
     Type load2DGenXType =
         LLVM::getFixedVectorType(loadResultElemType, numValuesPerLoad);
@@ -496,10 +531,19 @@ struct LoadOpConversion
     unsigned warpOuterStride = warpShape[opIdx];
     unsigned repKStride = elemsPerDPASInst[opIdx == 0 ? 1 : 0];
 
+    unsigned numRepOuter = numReps[opIdx];
+    unsigned numRepInner = numReps[!opIdx];
+    unsigned originalElemBits = elemBits;
+    if (isTransposeRequired) {
+      // Adjust the block I/O parameters to fit the HW's limitations on
+      // transposed loads.
+      tileWidth = tileWidth / (32 / originalElemBits);
+      elemBits = 32;
+    }
     ValueTable loadVals;
     for (int outer = 0; outer < numRepOuter; ++outer) {
       for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
-        for (int k = 0; k < numRepK; k += numOperandsKDimPerLoad) {
+        for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
           Value offsetX, offsetY;
           if (opIdx == 0) {
             // A
@@ -517,26 +561,38 @@ struct LoadOpConversion
           offsetY = add(offsetY, offsetBaseY);
           baseWidth = trunc(i32_ty, baseWidth);
           baseHeight = trunc(i32_ty, baseHeight);
-          rowStride = trunc(i32_ty, rowStride);
+          Value pitch = trunc(i32_ty, rowStride);
+          Value elemSizeInBytes = i32_val(originalElemBits / 8);
+
+          if (!memoryRowMajor) {
+            // Column major memory: swap X and Y because the HW only supports
+            // row major memory layouts.
+            pitch = trunc(i32_ty, colStride);
+            std::swap(offsetX, offsetY);
+            std::swap(baseWidth, baseHeight);
+          }
 
-          unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
-          Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
+          if (isTransposeRequired) {
+            // Adjust the block I/O parameters to fit the HW's limitations on
+            // transposed loads.
+            offsetX = udiv(offsetX, i32_val(32 / originalElemBits));
+          }
 
           auto load2dOp = rewriter.create(
               loc, load2DGenXType,
               /*ptr*/ base,
               /*base_width*/ mul(baseWidth, elemSizeInBytes),
               /*base_height*/ baseHeight,
-              /*base_pitch*/ mul(rowStride, elemSizeInBytes),
+              /*base_pitch*/ mul(pitch, elemSizeInBytes),
              /*x*/ trunc(i32_ty, offsetX),
              /*y*/ trunc(i32_ty, offsetY),
-              /*elem_size_in_bits*/ elemSizeInBits,
-              /*tile_width*/ elemsPerDPASInst[1],
+              /*elem_size_in_bits*/ elemBits,
+              /*tile_width*/ tileWidth,
              /*tile_height*/ tileHeight,
              /*v_blocks*/ vBlocks,
-              /*transpose*/ false,
+              /*transpose*/ isTransposeRequired,
              /*vnni_transform*/
-              (usePackedType && !isOperandA &&
+              (usePackedType && !isOperandA && !isTransposeRequired &&
               eltTy.getIntOrFloatBitWidth() != 32));
           if (failed(load2dOp.verify())) {
             // Explicitly invoke verifier because `triton_gen` ops are
@@ -544,34 +600,45 @@ struct LoadOpConversion
             return failure();
           }
 
-          unsigned packedRowNum =
-              opIdx == 0 ? numOperandsOuterDimPerLoad : numOperandsKDimPerLoad;
-          unsigned packedColNum =
-              opIdx == 0 ? numOperandsKDimPerLoad : numOperandsOuterDimPerLoad;
-          unsigned offset = 0;
-          // The register value returned by 2D load is contiguous on the row.
-          for (int col = 0; col < packedColNum; ++col) {
-            for (int row = 0; row < packedRowNum; ++row) {
-
-              Value loadVal = undef(packedDPASOperandType);
-              for (int elemIdx = 0; elemIdx < packedElemsPerLanePerDPASInst;
-                   ++elemIdx) {
-                Value loaded = extract_element(load2dOp, i32_val(offset++));
-                loadVal = insert_element(loadVal, loaded, i32_val(elemIdx));
-              }
-
-              // Save the unpacked vals to the map;
-              if (opIdx == 0) {
-                loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
-                              rep * packedRowNum + row,
-                          k + col}] = bitcast(loadVal, unpackedDPASOperandType);
-              } else {
-                loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
-                              rep * packedColNum + col,
-                          k + row}] = bitcast(loadVal, unpackedDPASOperandType);
+          unsigned packedRowNum = opIdx == 0 ? numOperandsOuterDimPerLoad
+                                             : numOperandsInnerDimPerLoad;
+          unsigned packedColNum = opIdx == 0 ? numOperandsInnerDimPerLoad
+                                             : numOperandsOuterDimPerLoad;
+
+          // Decompose the returned value into multiple operands.
+          unsigned packedColNumPerVBlock = packedColNum / vBlocks;
+          for (int vblk = 0; vblk < vBlocks; ++vblk)
+            for (int row = 0; row < packedRowNum; ++row)
+              for (int col = 0; col < packedColNumPerVBlock; ++col) {
+
+                unsigned operandStartOffset = (vblk * packedRowNum + row) *
+                                              packedColNumPerVBlock *
+                                              packedElemsPerLanePerDPASInst;
+
+                SmallVector<int32_t> indices(packedElemsPerLanePerDPASInst);
+                for (int elemIdx = 0; elemIdx < packedElemsPerLanePerDPASInst;
+                     ++elemIdx) {
+                  indices[elemIdx] = operandStartOffset +
+                                     elemIdx * packedColNumPerVBlock + col;
+                }
+                DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices);
+                Value loadVal = rewriter.create(
+                    loc, packedDPASOperandType, load2dOp, load2dOp, attr);
+
+                // Save the decomposed values to the map.
+                if (opIdx == 0) {
+                  loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
+                                rep * packedRowNum + row,
+                            k + vblk * packedColNumPerVBlock + col}] =
+                      bitcast(loadVal, unpackedDPASOperandType);
+                } else {
+                  loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
+                                rep * packedColNum +
+                                vblk * packedColNumPerVBlock + col,
+                            k + row}] =
+                      bitcast(loadVal, unpackedDPASOperandType);
+                }
              }
-            }
-          }
         }
       }
     }
@@ -580,7 +647,7 @@ struct LoadOpConversion
     // expected order for the layout.
     SmallVector unpackedLoadedVals;
     for (int outer = 0; outer < numRepOuter; ++outer) {
-      for (int k = 0; k < numRepK; ++k) {
+      for (int k = 0; k < numRepInner; ++k) {
         for (int rep = 0; rep < repCluster[opIdx]; ++rep) {
           Value loadVal = loadVals.at({outer * repCluster[opIdx] + rep, k});
           VectorType loadTy = cast(loadVal.getType());