From a971b85f4d6e1bcb2f3b6917d4ca341e79693eff Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Mon, 11 Nov 2024 13:56:07 +0000
Subject: [PATCH 1/2] [XPU][OptEW] Allow multiple warps in non-sliced dimension

Allow multiple warps in non-sliced dimension as long as there are
`n*sub_group_size` contiguous elements per warp in the non-sliced
dimension.

Signed-off-by: victor-eds
---
 test/TritonIntelGPU/optimize-elementwise.mlir | 88 +++++++++++++++++++
 .../OptimizeElementwiseParallelism.cpp        | 43 +++++++--
 2 files changed, 123 insertions(+), 8 deletions(-)

diff --git a/test/TritonIntelGPU/optimize-elementwise.mlir b/test/TritonIntelGPU/optimize-elementwise.mlir
index d8b64bab8..f01863e3c 100644
--- a/test/TritonIntelGPU/optimize-elementwise.mlir
+++ b/test/TritonIntelGPU/optimize-elementwise.mlir
@@ -63,3 +63,91 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
   }
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+// CHECK-LABEL: tt.func @test_blocked_multi_warp(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> {
+  tt.func @test_blocked_multi_warp(%arg0: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
+// CHECK:         %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<32xf32, #[[$ATTR_1]]>
+// CHECK:         %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<32xf32, #[[$ATTR_1]]>
+// CHECK:         %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<32xf32, #[[$ATTR_1]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+// CHECK:         %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<32xf32, #[[$ATTR_1]]> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
+// CHECK:         tt.return %[[VAL_5]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
+    tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+  }
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [32], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+// CHECK-LABEL: tt.func @test_blocked_multi_warp_double_stride(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> {
+  tt.func @test_blocked_multi_warp_double_stride(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
+// CHECK:         %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<128xf16, #[[$ATTR_1]]>
+// CHECK:         %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<128xf16, #[[$ATTR_1]]>
+// CHECK:         %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_1]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+// CHECK:         %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_1]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
+// CHECK:         tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
+    tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+  }
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [8], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+// CHECK-LABEL: tt.func @test_mma_multi_warp_double_stride(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> {
+  tt.func @test_mma_multi_warp_double_stride(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
+// CHECK:         %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
+// CHECK:         %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
+// CHECK:         %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_0]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+// CHECK:         %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_0]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK:         tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+    tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  }
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+// CHECK-LABEL: tt.func @test_mma_multi_warp_double_stride_repeat(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> {
+  tt.func @test_mma_multi_warp_double_stride_repeat(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
+// CHECK:         %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
+// CHECK:         %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
+// CHECK:         %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_0]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+// CHECK:         %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_0]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK:         tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+    tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  }
+}
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
index 1bd154306..5172a58f2 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
@@ -24,6 +24,27 @@ namespace mlir::triton::gpu::intel {
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
 
 namespace {
+bool isMultiWarpValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
+                                          int32_t numWorkGroupPos,
+                                          PatternRewriter &rewriter) {
+  StringAttr kLane = rewriter.getStringAttr("lane");
+  StringAttr kWarp = rewriter.getStringAttr("warp");
+  int32_t subGroupSize = linearLayout.getInDimSize(kLane);
+  ArrayRef numContiguousPerWarp = linearLayout.getBasis(kWarp, 0);
+  // Check the warp dimension hasn't been sliced away and we have n *
+  // sub_group_size contiguous elements per warp.
+  if (numContiguousPerWarp == ArrayRef{0} ||
+      numContiguousPerWarp[0] % subGroupSize != 0)
+    return false;
+  int32_t expectedValue = numContiguousPerWarp[0] * 2;
+  for (int32_t pos = 1; pos < numWorkGroupPos; ++pos) {
+    if (linearLayout.getBasis(kWarp, pos) != ArrayRef{expectedValue})
+      return false;
+    expectedValue *= 2;
+  }
+  return true;
+}
+
 /// Return whether the input linear layout can be unbroadcasted.
 ///
 /// A layout is valid for being "unbroadcasted" along its lanes if:
@@ -31,8 +52,8 @@ namespace {
 ///   sliced.
 /// - The size of the input 'block' dimension is 1. This is true for XPU
 ///   backend.
-/// - The size of the input 'warp' dimension is 1. This is a limitation to keep
-///   things simple for now.
+/// - The size of the input 'warp' dimension is 1 or there are n*sub_group_size
+///   contiguous elements per warp.
 ///
 /// Broadcasted layouts are layouts with sliced lane, warp or block (not
 /// possible for XPU backend) dimensions, i.e., the same data is owned by
@@ -49,8 +70,11 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
   // Only single block for now.
   if (linearLayout.getInDimSize(kBlock) != 1)
     return false;
-  // Only single warp for now.
-  return linearLayout.getInDimSize(kWarp) == 1;
+  // 'warp' dimension hasn't been sliced away and there are n*sub_group_size
+  // contiguous elements in each warp (or there is a single warp).
+  int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
+  return numWorkGroupPos == 0 || isMultiWarpValidLayoutForUnbroadcast(
+                                     linearLayout, numWorkGroupPos, rewriter);
 }
 
 /// Get optimized unbroadcasted tensor type.
@@ -61,18 +85,21 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
 RankedTensorType getOptimizedType(RankedTensorType type,
                                   const LinearLayout &linearLayout,
                                   PatternRewriter &rewriter) {
+  StringAttr kWarp = rewriter.getStringAttr("warp");
+
   auto encoding = cast(type.getEncoding());
   unsigned threadsPerWarp = product(encoding.getThreadsPerWarp());
-  [[maybe_unused]] unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
-  assert(warpsPerCTA == 1 && "Expecting single warp");
+  unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
   [[maybe_unused]] unsigned ctaSplitNum = product(encoding.getCTASplitNum());
   assert(ctaSplitNum == 1 && "Expecting single CTA");
 
   RankedTensorType::Builder builder(type);
+  int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
+  unsigned sizePerThread =
+      numWorkGroupPos == 0 ? 1 : linearLayout.getBasis(kWarp, 0)[0];
   CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(rewriter.getContext(), 1);
   auto newEncoding = rewriter.getAttr(
-      /*sizePerThread=*/1, threadsPerWarp, /*warpsPerCTA=*/1, /*order=*/0,
-      ctaLayout);
+      sizePerThread, threadsPerWarp, warpsPerCTA, /*order=*/0, ctaLayout);
   builder.setEncoding(newEncoding);
   return builder;
 }

From 40b6670b5b8e0c8a729568fd9084e7f4c85215fb Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Thu, 14 Nov 2024 10:44:26 +0000
Subject: [PATCH 2/2] Fix sizePerThread computation

Divide the number of contiguous elements per warp by the number of
threads per warp (the sub-group size) so `sizePerThread` is expressed
per work-item rather than per warp.
---
 test/TritonIntelGPU/optimize-elementwise.mlir | 8 ++++----
 .../OptimizeElementwiseParallelism.cpp        | 4 +++-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/test/TritonIntelGPU/optimize-elementwise.mlir b/test/TritonIntelGPU/optimize-elementwise.mlir
index f01863e3c..57c7cd415 100644
--- a/test/TritonIntelGPU/optimize-elementwise.mlir
+++ b/test/TritonIntelGPU/optimize-elementwise.mlir
@@ -67,7 +67,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 // -----
 
 // CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
-// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
 
 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
 
@@ -89,7 +89,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
 // -----
 
 // CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
-// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [32], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
 
 #blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
 
@@ -110,7 +110,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [8], order = [0]}>
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [8], order = [0]}>
 // CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
 
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
 
@@ -132,7 +132,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
 // CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
 
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
 
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
index 5172a58f2..dd92968e0 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
@@ -96,7 +96,9 @@ RankedTensorType getOptimizedType(RankedTensorType type,
   RankedTensorType::Builder builder(type);
   int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
   unsigned sizePerThread =
-      numWorkGroupPos == 0 ? 1 : linearLayout.getBasis(kWarp, 0)[0];
+      numWorkGroupPos == 0
+          ? 1
+          : linearLayout.getBasis(kWarp, 0)[0] / threadsPerWarp;
   CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(rewriter.getContext(), 1);
   auto newEncoding = rewriter.getAttr(
       sizePerThread, threadsPerWarp, warpsPerCTA, /*order=*/0, ctaLayout);
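
Not part of the patches: the sketch below is a minimal standalone model of the legality check and the fixed `sizePerThread` computation, using plain integers instead of MLIR's `LinearLayout`. The helper name `unbroadcastSizePerThread` and the `warpBases` argument are illustrative stand-ins for reading `linearLayout.getBasis("warp", pos)[0]`; the example values are taken from the new tests.

// Standalone sketch: models the multi-warp check added in patch 1 and the
// division by the sub-group size added in patch 2.
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// warpBases holds the single coordinate of each "warp" basis of the 1-D
// layout. Returns the per-lane sizePerThread of the unbroadcasted blocked
// layout, or std::nullopt if the layout does not qualify.
std::optional<unsigned>
unbroadcastSizePerThread(const std::vector<int32_t> &warpBases,
                         int32_t subGroupSize) {
  // Single warp: every lane ends up owning one element.
  if (warpBases.empty())
    return 1;
  // The warp dimension must not be sliced away (zero basis) and each warp
  // must own n * sub_group_size contiguous elements.
  int32_t contiguousPerWarp = warpBases.front();
  if (contiguousPerWarp == 0 || contiguousPerWarp % subGroupSize != 0)
    return std::nullopt;
  // Remaining warp bases must keep doubling, i.e. warps tile the tensor
  // with contiguous, evenly strided chunks.
  int32_t expected = contiguousPerWarp * 2;
  for (std::size_t pos = 1; pos < warpBases.size(); ++pos) {
    if (warpBases[pos] != expected)
      return std::nullopt;
    expected *= 2;
  }
  // Patch 2's fix: divide by the sub-group size so the result counts
  // elements per lane rather than per warp.
  return static_cast<unsigned>(contiguousPerWarp / subGroupSize);
}

int main() {
  // tensor<32xf32>, 2 warps, 16 contiguous elements per warp
  // -> sizePerThread = [1], as in test_blocked_multi_warp.
  assert(unbroadcastSizePerThread({16}, 16) == 1u);
  // tensor<128xf16>, 4 warps, 32 contiguous elements per warp
  // -> sizePerThread = [2], as in test_blocked_multi_warp_double_stride.
  assert(unbroadcastSizePerThread({32, 64}, 16) == 2u);
  // A warp stride that is not a multiple of the sub-group size is rejected.
  assert(!unbroadcastSizePerThread({8}, 16).has_value());
  return 0;
}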