Skip to content

Commit

Permalink
[Performance] 2D load lowering for column major B (#1918)
Browse files Browse the repository at this point in the history
Enable the Intel 2D block IO to load the B matrix resident on column
major memory.

---------

Co-authored-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
chengjunlu and whitneywhtsang authored Aug 28, 2024
1 parent 1ab4017 commit b66d768
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 84 deletions.
4 changes: 2 additions & 2 deletions test/TritonGEN/tritongen-addr-payload-opt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 33792
%27 = arith.cmpi slt, %23, %arg5 : i64
cf.cond_br %27, ^bb2, ^bb3
^bb2:
%28 = tt.load %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf16, #dot0>>
%29 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x8xf16, #dot1>>
%28 = tt.load %25 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<8x16xf16, #dot0>>
%29 = tt.load %26 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<16x8xf16, #dot1>>
%30 = tt.dot %28, %29, %24, inputPrecision = tf32 : tensor<8x16xf16, #dot0> * tensor<16x8xf16, #dot1> -> tensor<8x8xf32, #mma>
%31 = tt.advance %25, [%c0_i32, %c32_i32] : <tensor<8x16xf16, #dot0>>
%32 = tt.advance %26, [%c32_i32, %c0_i32] : <tensor<16x8xf16, #dot1>>
Expand Down
77 changes: 71 additions & 6 deletions test/TritonIntelGPU/load-to-llvm-2dload.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
// CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot1>>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
%0 = triton_gpu.convert_layout %D : tensor<64x64xf32, #dpas> -> tensor<64x64xf32, #blocked>
tt.return
Expand All @@ -40,8 +40,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
// CHECK-COUNT-2: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_32b_8r8x2cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-1: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_32b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-4: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv4_fDv8_fS0_({{.*}}) {{.*}} : (vector<4xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<64x32xf32, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<32x64xf32, #dot1>>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf32, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf32, #dot1>>
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf32, #dot0> * tensor<32x64xf32, #dot1> -> tensor<64x64xf32, #dpas>
tt.return
}
Expand Down Expand Up @@ -109,7 +109,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_53:.*]] = llvm.insertelement %[[offsetY]], %[[VAL_52]]{{\[}}%[[VAL_49]] : i32] : vector<2xi32>
%ptrA = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf16, #dot0>>
// CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x2cPU3AS1viiiDv2_iPt(%[[BASE]], %[[HEIGHT]], %[[WIDTH_i32]], %[[ROW_STRIDE_IN_BYTES]], %[[VAL_53]], %[[VAL_48]])
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<32x32xf16, #dot0>>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x32xf16, #dot0>>
%B = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot1>
%C = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<32x32xf16, #dot0> * tensor<32x32xf16, #dot1> -> tensor<32x32xf32, #dpas>
Expand Down Expand Up @@ -179,10 +179,75 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_52:.*]] = llvm.insertelement %[[VAL_42]], %[[VAL_51]]{{\[}}%[[VAL_48]] : i32] : vector<2xi32>
%ptrB = tt.make_tensor_ptr %arg1, [%arg4, %arg3], [%arg7, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf16, #dot1>>
// CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x2cPU3AS1viiiDv2_iPj(%[[BASE]], %[[HEIGHT]], %[[WIDTH_i32]], %[[ROW_STRIDE_IN_BYTES]], %[[VAL_52]], %[[VAL_47]])
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<32x32xf16, #dot1>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x32xf16, #dot1>>
%A = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot0>
%C = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<32x32xf16, #dot0> * tensor<32x32xf16, #dot1> -> tensor<32x32xf32, #dpas>
tt.return
}
}

// -----

// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 2], A = [8, 16], B = [16, 32], C = [8, 32]}>
#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @column_major_dot_b
tt.func public @column_major_dot_b(%arg0: !tt.ptr<f16>, %col_stride: i64) {
%c64_i32 = arith.constant 64 : i32
%c64_i64 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c32_i64 = arith.constant 32 : i64
%21 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
// CHECK: llvm.ptrtoint
// CHECK: %[[ELEM_BITS:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[TILE_WIDTH:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: %[[TILE_HEIGHT:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[VBLOCKS:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[TRANSPOSE:.*]] = llvm.mlir.constant(true) : i1
// CHECK: %[[VNNI:.*]] = llvm.mlir.constant(false) : i1
// CHECK: %[[VAL_68:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ELEM_BITS]], %[[TILE_WIDTH]], %[[TILE_HEIGHT]], %[[VBLOCKS]], %[[TRANSPOSE]], %[[VNNI]], {{.*}})
// CHECK: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
// CHECK: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
// CHECK: %[[VAL_103:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32
// CHECK: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
// CHECK: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
// CHECK: %[[VAL_138:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32
// CHECK: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
// CHECK: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
// CHECK: %[[VAL_173:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32
// CHECK: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
// CHECK: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
%45 = tt.load %21 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
tt.return
}
}

// -----

// CHECK: llvm.func spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @column_major_dot_b
tt.func public @column_major_dot_b(%arg0: !tt.ptr<f16>, %col_stride: i64) {
%c64_i32 = arith.constant 64 : i32
%c64_i64 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c32_i64 = arith.constant 32 : i64
%21 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj
// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32>
%45 = tt.load %21 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
tt.return
}
}
4 changes: 2 additions & 2 deletions test/TritonIntelGPU/prefetch-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
%ptrB = tt.make_tensor_ptr %arg1, [%arg4, %arg3], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf16, #dot1>>
triton_intel_gpu.prefetch %ptrA {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<32x16xf16, #dot0>>
triton_intel_gpu.prefetch %ptrB {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<16x32xf16, #dot1>>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<32x16xf16, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<16x32xf16, #dot1>>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x16xf16, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<16x32xf16, #dot1>>
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<32x16xf16, #dot0> * tensor<16x32xf16, #dot1> -> tensor<32x32xf32, #dpas>
%0 = triton_gpu.convert_layout %D {allocation.offset = 0 : i32} : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
tt.return
Expand Down
4 changes: 2 additions & 2 deletions test/TritonIntelGPU/store-to-llvm-2dstore.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
%6 = tt.make_tensor_ptr %arg1, [%arg3, %arg4], [%arg7, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot1>>
%7 = tt.advance %3, [%c64_i32, %c-32_i32] : <tensor<64x32xf16, #dot0>>
%8 = tt.advance %7, [%c-64_i32, %c32_i32] : <tensor<64x32xf16, #dot0>>
%9 = tt.load %8 {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot0>>
%10 = tt.load %6 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot1>>
%9 = tt.load %8 {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
%10 = tt.load %6 {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
%11 = tt.dot %9, %10, %cst, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
%12 = arith.truncf %11#0 : tensor<64x64xf32, #dpas> to tensor<64x64xf16, #dpas>
%13 = tt.make_tensor_ptr %arg2, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>>
Expand Down
3 changes: 3 additions & 0 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ static bool isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
if (op.getVnniTransform())
return true;

if (op.getTranspose() && op.getTileHeight() != 16)
return false;

uint32_t tileWidth = op.getTileWidth();
switch (op.getElemSizeInBits()) {
case 8:
Expand Down
Loading

0 comments on commit b66d768

Please sign in to comment.