Skip to content

Commit

Permalink
[NIT][OptRed] Cleanup -tritonintelgpu-optimize-reduction-locality c…
Browse files Browse the repository at this point in the history
…ode (#2632)

Use the `Base` alias to refer to the base pass implementation, and use
`CTALayoutAttr::getDefault` to obtain the default `CTALayoutAttr`.

Signed-off-by: victor-eds <victor.perez@codeplay.com>
  • Loading branch information
victor-eds authored Nov 5, 2024
1 parent 29c0ece commit 1442ff4
Showing 1 changed file with 5 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,6 @@ namespace mlir::triton::gpu::intel {
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"

namespace {
/// Build a `CTALayoutAttr` of the given \p rank in which `CTAsPerCGA` and
/// `CTASplitNum` are 1 for every dimension and the CTA order is the
/// descending sequence `rank - 1, ..., 1, 0`.
///
/// \param rewriter Rewriter used to create the attribute.
/// \param rank Number of dimensions of the layout.
/// \returns The newly built attribute.
static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
                                              size_t rank) {
  // Every dimension gets a value of 1 for both CTAsPerCGA and CTASplitNum.
  SmallVector<unsigned> onePerDim(rank, 1);
  SmallVector<unsigned> splitPerDim(rank, 1);

  // Fill the order from the back so that the last slot holds 0 and the first
  // slot holds rank - 1.
  SmallVector<unsigned> descendingOrder(rank);
  for (size_t dim = 0; dim < rank; ++dim)
    descendingOrder[rank - 1 - dim] = static_cast<unsigned>(dim);

  return rewriter.getAttr<CTALayoutAttr>(onePerDim, splitPerDim,
                                         descendingOrder);
}

// clang-format off
/// Optimize reduction with DPAS-encoded input.
///
Expand Down Expand Up @@ -282,7 +273,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
1, 1, oldEncoding.getWarpsPerCTA()[1],
1};
std::array<unsigned, rank> order{3, 4, 5, 6, 0, 1, 2};
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);

auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
Expand Down Expand Up @@ -341,7 +332,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
dpasEncoding.getWarpsPerCTA()[0], 1,
dpasEncoding.getWarpsPerCTA()[1]};
std::array<unsigned, rank> order{3, 4, 0, 1, 2};
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);

auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
Expand All @@ -368,7 +359,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
std::array<unsigned, rank> warpsPerCTA{
1, 1, oldEncoding.getWarpsPerCTA()[2], oldEncoding.getWarpsPerCTA()[4]};
std::array<unsigned, rank> order{3, 0, 1, 2};
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);

auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
Expand Down Expand Up @@ -407,7 +398,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
dpasEncoding.getWarpsPerCTA()[1]};
std::array<unsigned, rankBeforeLastReduction> order{3, 0, 1, 2};
CTALayoutAttr ctaLayout =
getIdentityCTALayoutAttr(rewriter, rankBeforeLastReduction);
CTALayoutAttr::getDefault(getContext(), rankBeforeLastReduction);

auto blockedEncoding = rewriter.getAttr<BlockedEncodingAttr>(
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
Expand All @@ -432,9 +423,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
struct TritonIntelGPUOptimizeReductionLocality final
: impl::TritonIntelGPUOptimizeReductionLocalityBase<
TritonIntelGPUOptimizeReductionLocality> {
using impl::TritonIntelGPUOptimizeReductionLocalityBase<
TritonIntelGPUOptimizeReductionLocality>::
TritonIntelGPUOptimizeReductionLocalityBase;
using Base::Base;

void runOnOperation() final {
Operation *op = getOperation();
Expand Down

0 comments on commit 1442ff4

Please sign in to comment.