[XPU][OptEW] Define -intel-triton-optimize-elementwise-parallelism pass #2631

Merged: 6 commits, Nov 12, 2024
67 changes: 67 additions & 0 deletions test/TritonIntelGPU/optimize-elementwise.mlir
@@ -0,0 +1,67 @@
// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-elementwise-parallelism | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// CHECK-LABEL: tt.func @test_two_convert_layout(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
tt.func @test_two_convert_layout(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
%0 = triton_gpu.convert_layout %arg0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%1 = triton_gpu.convert_layout %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_2]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
%2 = arith.addf %0, %1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// CHECK: tt.return %[[VAL_3]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
tt.return %2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// CHECK: #[[$ATTR_2:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_3:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// CHECK-LABEL: tt.func @test_convert_layout_splat(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: f32
tt.func @test_convert_layout_splat(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: f32) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
%0 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// CHECK: %[[VAL_2:.*]] = tt.splat %[[VAL_1]] : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>>
%1 = tt.splat %arg1 : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// CHECK: %[[VAL_3:.*]] = arith.addf %[[VAL_0]], %[[VAL_2]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>>
// CHECK: %[[VAL_4:.*]] = triton_gpu.convert_layout %[[VAL_3]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_2]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
%2 = arith.addf %0, %1 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// CHECK: tt.return %[[VAL_4]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
tt.return %2 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}

// -----

// CHECK: #[[$ATTR_4:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_5:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// CHECK-LABEL: tt.func @test_chain(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>>, %[[VAL_1:.*]]: tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>>,
// CHECK-SAME: %[[VAL_2:.*]]: f32
tt.func @test_chain(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg2: f32) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
// CHECK: %[[VAL_3:.*]] = arith.addf %[[VAL_0]], %[[VAL_1]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>>
%0 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%1 = triton_gpu.convert_layout %arg1 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// CHECK: %[[VAL_4:.*]] = tt.splat %[[VAL_2]] : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>>
%2 = tt.splat %arg2 : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%3 = arith.addf %0, %1 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_3]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>>
[Contributor author comment] The layout conversion that would otherwise result from %[[VAL_3]] is erased by the pass, as can be seen here.

// CHECK: %[[VAL_6:.*]] = triton_gpu.convert_layout %[[VAL_5]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_4]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>>
%4 = arith.addf %2, %3 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
// CHECK: tt.return %[[VAL_6]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_5]]}>>
tt.return %4 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}
@@ -365,4 +365,46 @@ tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slic
"mlir::triton::gpu::TritonGPUDialect"];
}

def TritonIntelGPUOptimizeElementwiseParallelism
: Pass<"tritonintelgpu-optimize-elementwise-parallelism", "mlir::ModuleOp"> {
let summary =
"Improve parallelism of elementwise operations utilizing more hardware resources";

let description = [{
Detect elementwise operations with an encoding causing sub-par parallelism,
i.e., with data duplication across threads, and convert the operands to a
more optimal encoding if the cost of doing so is not too high according to
some heuristics. As of now, the cost should be 0, meaning we only support
postponing layout conversions and modifying scalar splat operations.

As an example, this pass would modify the following code:
```mlir
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

tt.func @test_two_convert_layout(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
%0 = triton_gpu.convert_layout %arg0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%1 = triton_gpu.convert_layout %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%2 = arith.addf %0, %1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return %2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}
```
Obtaining:
```mlir
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
module {
tt.func @test_two_convert_layout(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
%0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%1 = triton_gpu.convert_layout %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return %1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}
}
```
}];

let dependentDialects = ["mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
[Contributor author comment] If we find an elementwise operation from dialect X, that dialect has already been loaded, so we would not strictly need to list it as a dependent dialect in order to create operations from it. We include these two for future-proofing, though.

}

#endif // TRITON_INTEL_GPU_PASSES
@@ -4,6 +4,7 @@ add_triton_library(TritonIntelGPUTransforms
DistributeToWarps.cpp
MatchTargetSize.cpp
MaterializeBlockPointer.cpp
OptimizeElementwiseParallelism.cpp
OptimizeReductionLocality.cpp
Pipeliner/MatmulLoopPipeline.cpp
Pipeliner/SoftwarePipeliner.cpp
@@ -0,0 +1,220 @@
//===- OptimizeElementwiseParallelism.cpp -------------------------------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// This file implements the `tritonintelgpu-optimize-elementwise-parallelism`
/// pass.
//===----------------------------------------------------------------------===//

#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#define DEBUG_TYPE "tritonintelgpu-optimize-elementwise-parallelism"

namespace mlir::triton::gpu::intel {
#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEELEMENTWISEPARALLELISM
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"

namespace {
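/// Return whether \p layout is a 'triton_gpu.slice' layout whose parent
/// distributes more than one thread or warp along the sliced dimension,
/// i.e., whether tensor elements are duplicated across threads.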
bool isBadLayoutForElementwise(Attribute layout) {
// We only support 'triton_gpu.slice' for now.
auto slicedLayout = dyn_cast<SliceEncodingAttr>(layout);
if (!slicedLayout)
return false;

// Check the parent layout is squeezed across a dimension with more than one
// warp per CTA or thread per warp, i.e., there is data duplication across
// threads along that dimension.
unsigned dim = slicedLayout.getDim();
auto parentLayout = cast<DistributedEncodingTrait>(slicedLayout.getParent());
return parentLayout.getWarpsPerCTA()[dim] != 1 ||
parentLayout.getThreadsPerWarp()[dim] != 1;
}

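/// Rebuild \p val with \p newLayout at no extra cost: splats are re-created
/// with the new encoding, and layout conversions are bypassed by reusing
/// their source value directly.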
Value convertToCheaperLayout(Location loc, Value val, Attribute newLayout,
PatternRewriter &rewriter) {
assert(newLayout && "Expecting valid layout");
return TypeSwitch<Operation *, Value>(val.getDefiningOp())
.Case([loc, newLayout, val, &rewriter](SplatOp splat) {
// This is a cost <= 0 conversion as:
// - If the splat is used by another operation, we just don't use all the
// duplicated elements in our elementwise operation.
// - If the splat is not used by other operations, we reduce data
// duplication and possibly even calculation for this data.
RankedTensorType type =
RankedTensorType::Builder(splat.getResult().getType())
.setEncoding(newLayout);
return rewriter.create<SplatOp>(loc, type, splat.getSrc());
})
.Case([](ConvertLayoutOp convertLayout) {
// This is a cost = 0 conversion as we ensured no other op is using the
// layout conversion result.
return convertLayout.getSrc();
});
}

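/// Convert \p val back to the original \p layout so the rewritten operation
/// still yields a result of the type its users expect.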
Value convertToOriginalLayout(Location loc, Value val, Attribute layout,
PatternRewriter &rewriter) {
RankedTensorType type =
RankedTensorType::Builder(cast<RankedTensorType>(val.getType()))
.setEncoding(layout);
return rewriter.create<ConvertLayoutOp>(loc, type, val);
}

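/// Accumulator used to agree on a single cheap layout across all operands.
/// 'id' is the neutral element, 'error' is absorbing, and combining two
/// distinct concrete layouts yields 'error'.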
class AttributeAcc {
public:
static AttributeAcc id() { return AttributeAcc(std::nullopt); }
static AttributeAcc error() { return AttributeAcc(Attribute()); }

AttributeAcc() = default;
AttributeAcc(Attribute value) : value(value) {}

friend bool operator==(AttributeAcc lhs, AttributeAcc rhs) {
return lhs.value == rhs.value;
}

friend bool operator!=(AttributeAcc lhs, AttributeAcc rhs) {
return !(lhs == rhs);
}

friend AttributeAcc operator+(AttributeAcc lhs, AttributeAcc rhs) {
if (lhs == error() || rhs == error())
return error();
if (lhs == id())
return rhs;
if (rhs == id())
return lhs;
if (lhs != rhs)
return error();
return lhs;
}

Attribute operator*() const {
assert(*this != id() && *this != error() && "Expecting valid layout");
return *value;
}

private:
AttributeAcc(std::optional<Attribute> value) : value(value) {}

std::optional<Attribute> value;
};

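/// Return the layout \p value can be converted to for free: the identity
/// element for scalar splats (any layout is as cheap), the source layout of
/// a single-use layout conversion, or 'error' otherwise.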
AttributeAcc getCheapLayoutToConvertTo(Value value) {
Operation *op = value.getDefiningOp();
if (!op)
return AttributeAcc::error();
return TypeSwitch<Operation *, AttributeAcc>(op)
.Case([](SplatOp splat) {
// Do not support tensor splats, just scalar splats.
return isa<RankedTensorType>(splat.getSrc().getType())
? AttributeAcc::error()
: AttributeAcc::id();
})
.Case([](ConvertLayoutOp convertLayout) -> AttributeAcc {
// If the layout conversion has more than one user, this may worsen
// register pressure, as data would need to coexist in both layouts at
// the same time in registers.
// TODO: Extend with heuristics to check this is cheap to do.
if (!convertLayout->hasOneUse())
return AttributeAcc::error();
return convertLayout.getSrc().getType().getEncoding();
})
.Default(AttributeAcc::error());
}

AttributeAcc accumulateCheapLayoutToConvertTo(AttributeAcc acc, Value val) {
return acc + getCheapLayoutToConvertTo(val);
}

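/// Rewrite elementwise operations whose result layout duplicates data across
/// threads so they operate on a cheaper common operand layout, converting the
/// result back to the original layout afterwards.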
struct ElementwiseOptPattern final
: OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern<OpTrait::Elementwise>::OpTraitRewritePattern;

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
// Restrict to single-result ops whose operands and result share a type; this keeps the pass simple.
if (!op->hasTrait<OpTrait::SameOperandsAndResultType>() ||
op->getNumResults() != 1)
return failure();

// Layout optimizations only apply to tensors.
auto type = dyn_cast<RankedTensorType>(op->getResultTypes().front());
if (!type)
return failure();

// Skip complex operations.
if (op->hasSuccessors() || op->getNumRegions() != 0)
return failure();

// Check if the layout is actually bad.
Attribute layout = type.getEncoding();
if (!layout || !isBadLayoutForElementwise(layout))
return failure();

// Check if we can convert the operands to a common optimal layout.
AttributeAcc layoutAcc =
std::accumulate(op->operand_begin(), op->operand_end(),
AttributeAcc::id(), accumulateCheapLayoutToConvertTo);
if (layoutAcc == AttributeAcc::error() || layoutAcc == AttributeAcc::id())
return failure();

// Check the new layout is good for elementwise operations.
// TODO: Provide heuristics to check it's *better* than the original one
// instead.
Attribute newLayout = *layoutAcc;
assert(newLayout && "Expecting valid layout");
if (isBadLayoutForElementwise(newLayout))
return failure();

// Replace operation with new operation taking operands with a more optimal
// layout.
Location loc = op->getLoc();
StringAttr opName = op->getName().getIdentifier();
SmallVector<Value> newOperands(op->getNumOperands());
llvm::transform(op->getOperands(), std::begin(newOperands),
[loc, newLayout, &rewriter](Value val) {
return convertToCheaperLayout(loc, val, newLayout,
rewriter);
});
Type newType = newOperands.front().getType();
ArrayRef<NamedAttribute> attributes = op->getAttrs();
Operation *newElementwiseOp =
rewriter.create(loc, opName, newOperands, newType, attributes);
assert(newElementwiseOp->getNumResults() == 1 &&
"Expecting single result operation");

// Convert the result back to the original layout for type consistency.
Value newOp = convertToOriginalLayout(loc, newElementwiseOp->getResult(0),
layout, rewriter);

rewriter.replaceOp(op, newOp);
return success();
}
};

struct TritonIntelGPUOptimizeElementwiseParallelism final
: impl::TritonIntelGPUOptimizeElementwiseParallelismBase<
TritonIntelGPUOptimizeElementwiseParallelism> {
using Base::Base;

void runOnOperation() final {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
RewritePatternSet patterns(ctx);
patterns.add<ElementwiseOptPattern>(ctx);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace
} // namespace mlir::triton::gpu::intel