[XPU][OptEW] Allow multiple warps in non-sliced dimension #2670

Merged (3 commits) on Nov 15, 2024
88 changes: 88 additions & 0 deletions test/TritonIntelGPU/optimize-elementwise.mlir
@@ -63,3 +63,91 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_blocked_multi_warp(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> {
tt.func @test_blocked_multi_warp(%arg0: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<32xf32, #[[$ATTR_1]]>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<32xf32, #[[$ATTR_1]]>
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<32xf32, #[[$ATTR_1]]>
%0 = arith.addf %arg0, %arg1 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<32xf32, #[[$ATTR_1]]> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
// CHECK: tt.return %[[VAL_5]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_blocked_multi_warp_double_stride(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> {
tt.func @test_blocked_multi_warp_double_stride(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<128xf16, #[[$ATTR_1]]>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<128xf16, #[[$ATTR_1]]>
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_1]]>
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_1]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [8], order = [0]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_mma_multi_warp_double_stride(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> {
tt.func @test_mma_multi_warp_double_stride(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_0]]>
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_0]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_mma_multi_warp_double_stride_repeat(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> {
tt.func @test_mma_multi_warp_double_stride_repeat(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_0]]>
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_0]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
}
}
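The new tests exercise layouts with more than one warp along the non-sliced dimension. As a quick sanity check on the layouts the pass is expected to produce (not part of the PR; plain C++ with values restated from the CHECK lines above), the sketch below verifies that each sliced tensor's size is a whole multiple of the elements covered by one repetition of the optimized 1-D blocked layout, i.e. sizePerThread * threadsPerWarp * warpsPerCTA, so the layout tiles the tensor cleanly:

```cpp
// Sanity-check sketch: the optimized 1-D blocked layout must tile the sliced
// tensor. The struct and case names are illustrative, not Triton API.
#include <cassert>
#include <cstdint>

struct Case {
  int64_t sizePerThread, threadsPerWarp, warpsPerCTA, numElements;
};

int main() {
  const Case cases[] = {
      {1, 16, 2, 32},   // test_blocked_multi_warp
      {2, 16, 4, 128},  // test_blocked_multi_warp_double_stride
      {1, 16, 8, 128},  // test_mma_multi_warp_double_stride
      {1, 16, 2, 128},  // test_mma_multi_warp_double_stride_repeat (wraps 4x)
  };
  for (const Case &c : cases) {
    int64_t covered = c.sizePerThread * c.threadsPerWarp * c.warpsPerCTA;
    assert(c.numElements % covered == 0 && "layout must tile the tensor");
  }
  return 0;
}
```

The *_repeat case shows that the optimized layout may cover fewer elements than the tensor; the remaining elements simply wrap around the layout.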
@@ -24,15 +24,36 @@ namespace mlir::triton::gpu::intel {
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"

namespace {
bool isMultiWarpValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
int32_t numWorkGroupPos,
PatternRewriter &rewriter) {
StringAttr kLane = rewriter.getStringAttr("lane");
StringAttr kWarp = rewriter.getStringAttr("warp");
int32_t subGroupSize = linearLayout.getInDimSize(kLane);
ArrayRef<int32_t> numContiguousPerWarp = linearLayout.getBasis(kWarp, 0);
// Check the warp dimension hasn't been sliced away and we have n *
// sub_group_size contiguous elements per warp.
if (numContiguousPerWarp == ArrayRef<int32_t>{0} ||
numContiguousPerWarp[0] % subGroupSize != 0)
return false;
int32_t expectedValue = numContiguousPerWarp[0] * 2;
for (int32_t pos = 1; pos < numWorkGroupPos; ++pos) {
if (linearLayout.getBasis(kWarp, pos) != ArrayRef<int32_t>{expectedValue})
return false;
expectedValue *= 2;
}
return true;
}

/// Return whether the input linear layout can be unbroadcasted.
///
/// A layout is valid for being "unbroadcasted" along its lanes if:
/// - The 'lane' input dimension is zero: this means the lane dimension has been
/// sliced.
/// - The size of the input 'block' dimension is 1. This is true for XPU
/// backend.
/// - The size of the input 'warp' dimension is 1. This is a limitation to keep
/// things simple for now.
/// - The size of the input 'warp' dimension is 1 or there are n*sub_group_size
/// contiguous elements per warp.
///
/// Broadcasted layouts are layouts with sliced lane, warp or block (not
/// possible for XPU backend) dimensions, i.e., the same data is owned by
@@ -49,8 +70,11 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
// Only single block for now.
if (linearLayout.getInDimSize(kBlock) != 1)
return false;
// Only single warp for now.
return linearLayout.getInDimSize(kWarp) == 1;
// 'warp' dimension hasn't been sliced away and there are n*sub_group_size
// contiguous elements in each warp (or there is a single warp).
int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
return numWorkGroupPos == 0 || isMultiWarpValidLayoutForUnbroadcast(
linearLayout, numWorkGroupPos, rewriter);
}

/// Get optimized unbroadcasted tensor type.
@@ -61,18 +85,23 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
RankedTensorType getOptimizedType(RankedTensorType type,
const LinearLayout &linearLayout,
PatternRewriter &rewriter) {
StringAttr kWarp = rewriter.getStringAttr("warp");

auto encoding = cast<DistributedEncodingTrait>(type.getEncoding());
unsigned threadsPerWarp = product(encoding.getThreadsPerWarp());
[[maybe_unused]] unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
assert(warpsPerCTA == 1 && "Expecting single warp");
unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
[[maybe_unused]] unsigned ctaSplitNum = product(encoding.getCTASplitNum());
assert(ctaSplitNum == 1 && "Expecting single CTA");

RankedTensorType::Builder builder(type);
int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
unsigned sizePerThread =
numWorkGroupPos == 0
? 1
: linearLayout.getBasis(kWarp, 0)[0] / threadsPerWarp;
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(rewriter.getContext(), 1);
auto newEncoding = rewriter.getAttr<BlockedEncodingAttr>(
/*sizePerThread=*/1, threadsPerWarp, /*warpsPerCTA=*/1, /*order=*/0,
ctaLayout);
sizePerThread, threadsPerWarp, warpsPerCTA, /*order=*/0, ctaLayout);
builder.setEncoding(newEncoding);
return builder;
}
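For readers unfamiliar with linear-layout bases, here is a minimal standalone sketch (plain C++, not the Triton LinearLayout API) of the check that isMultiWarpValidLayoutForUnbroadcast performs: the first warp basis must be a non-zero multiple of the sub-group size, and each subsequent basis must double the previous one, so every warp owns one contiguous run of n * sub_group_size elements. The vector of integers stands in for the first component of each warp basis; the helper name and test values are illustrative.

```cpp
// Standalone sketch of the multi-warp validity check.
#include <cassert>
#include <cstdint>
#include <vector>

bool isContiguousMultiWarp(const std::vector<int32_t> &warpBases,
                           int32_t subGroupSize) {
  if (warpBases.empty())
    return true; // single warp: nothing to check
  // Warp dimension must not be sliced away and must span a multiple of the
  // sub-group size contiguous elements.
  if (warpBases[0] == 0 || warpBases[0] % subGroupSize != 0)
    return false;
  // Each further basis must double the previous one (contiguous warp chunks).
  int32_t expected = warpBases[0] * 2;
  for (size_t pos = 1; pos < warpBases.size(); ++pos) {
    if (warpBases[pos] != expected)
      return false;
    expected *= 2;
  }
  return true;
}

int main() {
  // For test_blocked_multi_warp_double_stride above, the warp bases should
  // work out to {32, 64} with a sub-group size of 16 (assuming the usual
  // blocked-to-linear-layout conversion).
  assert(isContiguousMultiWarp({32, 64}, 16));
  // A sliced-away warp dimension (first basis 0) is rejected.
  assert(!isContiguousMultiWarp({0, 32}, 16));
  // 8 contiguous elements per warp is not a multiple of the sub-group size.
  assert(!isContiguousMultiWarp({8, 16}, 16));
  return 0;
}
```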