diff --git a/test/TritonIntelGPU/optimize-elementwise.mlir b/test/TritonIntelGPU/optimize-elementwise.mlir
new file mode 100644
index 0000000000..d8b64bab89
--- /dev/null
+++ b/test/TritonIntelGPU/optimize-elementwise.mlir
@@ -0,0 +1,65 @@
+// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-elementwise-parallelism | FileCheck %s
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+// CHECK-LABEL: tt.func @test_dpas(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>)
+  tt.func @test_dpas(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
+// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<16xf32, #[[$ATTR_0]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: tt.return %[[VAL_5]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+    tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  }
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+// CHECK-LABEL: tt.func @test_blocked(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>)
+  tt.func @test_blocked(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
+// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<16xf32, #[[$ATTR_1]]>
+// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<16xf32, #[[$ATTR_1]]>
+// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<16xf32, #[[$ATTR_1]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<16xf32, #[[$ATTR_1]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
+// CHECK: tt.return %[[VAL_5]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
+    tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+  }
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+// CHECK-LABEL: tt.func @test_blocked_repeat(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>)
+  tt.func @test_blocked_repeat(%arg0: tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
+// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<64xf32, #[[$ATTR_1]]>
+// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<64xf32, #[[$ATTR_1]]>
+// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<64xf32, #[[$ATTR_1]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<64xf32, #[[$ATTR_1]]> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
+// CHECK: tt.return %[[VAL_5]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
+    tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+  }
+}
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
index c551a96856..1d81bc4741 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -365,4 +365,52 @@ tt.func @test(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #triton_gpu.slic
                            "mlir::triton::gpu::TritonGPUDialect"];
 }
 
+def TritonIntelGPUOptimizeElementwiseParallelism
+    : Pass<"tritonintelgpu-optimize-elementwise-parallelism", "mlir::ModuleOp"> {
+  let summary =
+      "Improve parallelism of elementwise operations by better utilizing hardware resources.";
+
+  let description = [{
+    Detect elementwise operations with an encoding causing sub-par parallelism,
+    i.e., with data duplicated across threads, and convert the operands to a
+    more optimal encoding if the cost of doing so is heuristically estimated
+    to be sufficiently low. As of now, the cost must be zero: we only support
+    "unbroadcasting" tensors, i.e., dropping values duplicated across threads
+    by re-distributing them.
+
+    As an example, this pass would modify the following code:
+```mlir
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  tt.func @test_blocked(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
+    %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+  }
+}
+```
+    Obtaining:
+```mlir
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  tt.func @test_blocked(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #blocked1>
+    %1 = triton_gpu.convert_layout %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #blocked1>
+    %2 = arith.addf %0, %1 : tensor<16xf32, #blocked1>
+    %3 = triton_gpu.convert_layout %2 : tensor<16xf32, #blocked1> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    tt.return %3 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+  }
+}
+```
+
+    Note how the converted tensors no longer use a sliced layout, so each
+    element of the tensor is held by a single thread.
+  }];
+
+  let dependentDialects = [];
+}
+
+
 #endif // TRITON_INTEL_GPU_PASSES
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
index dbc641e2a3..46d121a070 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_triton_library(TritonIntelGPUTransforms
   DistributeToWarps.cpp
   MatchTargetSize.cpp
   MaterializeBlockPointer.cpp
+  OptimizeElementwiseParallelism.cpp
   OptimizeReductionLocality.cpp
   Pipeliner/MatmulLoopPipeline.cpp
   Pipeliner/SoftwarePipeliner.cpp
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
new file mode 100644
index 0000000000..1bd154306d
--- /dev/null
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
@@ -0,0 +1,160 @@
+//===- OptimizeElementwiseParallelism.cpp -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// This file implements the `tritonintelgpu-optimize-elementwise-parallelism`
+/// pass.
+//===----------------------------------------------------------------------===//
+
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+#define DEBUG_TYPE "tritonintelgpu-optimize-elementwise-parallelism"
+
+namespace mlir::triton::gpu::intel {
+#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEELEMENTWISEPARALLELISM
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
+
+namespace {
+/// Return whether the input linear layout can be unbroadcasted.
+///
+/// A layout is valid for being "unbroadcasted" along its lanes if:
+/// - The 'lane' input dimension maps to zero: the lane dimension has been
+///   sliced away, so all lanes hold the same data.
+/// - The size of the input 'block' dimension is 1. This is always true for
+///   the XPU backend.
+/// - The size of the input 'warp' dimension is 1. This is a limitation to keep
+///   things simple for now.
+///
+/// Broadcasted layouts are layouts with sliced lane, warp or block dimensions
+/// (the latter not possible for the XPU backend), i.e., the same data is owned
+/// by different threads.
+bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
+                                 PatternRewriter &rewriter) {
+  StringAttr kLane = rewriter.getStringAttr("lane");
+  StringAttr kWarp = rewriter.getStringAttr("warp");
+  StringAttr kBlock = rewriter.getStringAttr("block");
+  StringAttr kDim0 = rewriter.getStringAttr("dim0");
+  // 'lane' dimension must have been sliced away completely.
+  if (!linearLayout.sublayoutIsZero(kLane, kDim0))
+    return false;
+  // Only single block for now.
+  if (linearLayout.getInDimSize(kBlock) != 1)
+    return false;
+  // Only single warp for now.
+  return linearLayout.getInDimSize(kWarp) == 1;
+}
+
+/// Get optimized unbroadcasted tensor type.
+///
+/// Get the optimized ranked tensor type after unbroadcasting. As we only
+/// support 1D tensors, this is as simple as building an "unbroadcasted"
+/// blocked-encoded 1D tensor type.
+RankedTensorType getOptimizedType(RankedTensorType type,
+                                  const LinearLayout &linearLayout,
+                                  PatternRewriter &rewriter) {
+  auto encoding = cast<DistributedEncodingTrait>(type.getEncoding());
+  unsigned threadsPerWarp = product(encoding.getThreadsPerWarp());
+  [[maybe_unused]] unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
+  assert(warpsPerCTA == 1 && "Expecting single warp");
+  [[maybe_unused]] unsigned ctaSplitNum = product(encoding.getCTASplitNum());
+  assert(ctaSplitNum == 1 && "Expecting single CTA");
+
+  RankedTensorType::Builder builder(type);
+  CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(rewriter.getContext(), 1);
+  auto newEncoding = rewriter.getAttr<BlockedEncodingAttr>(
+      /*sizePerThread=*/1, threadsPerWarp, /*warpsPerCTA=*/1, /*order=*/0,
+      ctaLayout);
+  builder.setEncoding(newEncoding);
+  return builder;
+}
+
+struct ElementwiseOptPattern final
+    : OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern<OpTrait::Elementwise>::OpTraitRewritePattern;
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    // Rely on this trait for a simpler pass.
+    if (!op->hasTrait<OpTrait::SameOperandsAndResultType>() ||
+        op->getNumResults() != 1)
+      return failure();
+
+    // Skip complex operations.
+    if (op->hasSuccessors() || op->getNumRegions() != 0)
+      return failure();
+
+    // Layout optimizations only apply to tensors.
+    auto type = dyn_cast<RankedTensorType>(op->getResultTypes().front());
+    if (!type)
+      return failure();
+
+    // Check if the layout is actually bad and can be optimized using our
+    // approach. We only support 1D tensors for now as these are easier to
+    // handle.
+    Attribute layout = type.getEncoding();
+    if (!layout || type.getRank() != 1)
+      return failure();
+    std::optional<LinearLayout> linearLayout =
+        toLinearLayout(type.getShape(), layout);
+    if (!linearLayout || !isValidLayoutForUnbroadcast(*linearLayout, rewriter))
+      return failure();
+
+    // Check the operands are not used by other operations. This prevents an
+    // increase in register pressure:
+    if (!llvm::all_of(op->getOperands(),
+                      [](Value val) { return val.hasOneUse(); }))
+      return failure();
+
+    // As we are dealing with 1D tensors, we can do a simple transform to
+    // obtain a more optimized operation.
+    Location loc = op->getLoc();
+    RankedTensorType newType = getOptimizedType(type, *linearLayout, rewriter);
+    SmallVector<Value> newOperands(op->getNumOperands());
+    llvm::transform(op->getOperands(), std::begin(newOperands),
+                    [&rewriter, loc, newType](Value operand) {
+                      return rewriter.create<ConvertLayoutOp>(loc, newType,
+                                                              operand);
+                    });
+
+    // Now we create the optimized operation:
+    StringAttr opName = op->getName().getIdentifier();
+    ArrayRef<NamedAttribute> attributes = op->getAttrs();
+    Operation *newElementwiseOp =
+        rewriter.create(loc, opName, newOperands, newType, attributes);
+    assert(newElementwiseOp->getNumResults() == 1 &&
+           "Expecting single result operation");
+
+    // Convert back to the unoptimized encoding for further use.
+    Value newValue = newElementwiseOp->getResult(0);
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, type, newValue);
+
+    return success();
+  }
+};
+
+struct TritonIntelGPUOptimizeElementwiseParallelism final
+    : impl::TritonIntelGPUOptimizeElementwiseParallelismBase<
+          TritonIntelGPUOptimizeElementwiseParallelism> {
+  using Base::Base;
+
+  void runOnOperation() final {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ElementwiseOptPattern>(ctx);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+} // namespace mlir::triton::gpu::intel
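The lit tests above only cover cases where the rewrite fires. A negative test along the following lines could additionally document the single-warp restriction enforced by `isValidLayoutForUnbroadcast`; this is a minimal sketch that is not part of the patch, and the `@test_multi_warp` name, the two-warp layout, and the expectation that the IR stays unchanged are illustrative assumptions only:

```mlir
// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-elementwise-parallelism | FileCheck %s

// Hypothetical layout with two warps per CTA: isValidLayoutForUnbroadcast
// currently bails out when the 'warp' dimension is larger than one, so the
// pass is expected to leave this function untouched.
#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_multi_warp(
// CHECK-NOT: triton_gpu.convert_layout
  tt.func @test_multi_warp(%arg0: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
    %0 = arith.addf %arg0, %arg1 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
  }
}
```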