Skip to content

Commit

Permalink
[VectorCombine] Add foldShuffleToIdentity (llvm#88693)
Browse files Browse the repository at this point in the history
This patch adds a basic version of a combine that attempts to remove
shuffles that when combined simplify away to an identity shuffle. For
example:
%ab = shufflevector <8 x half> %a, <8 x half> poison, <4 x i32> <i32 3,
i32 2, i32 1, i32 0>
%at = shufflevector <8 x half> %a, <8 x half> poison, <4 x i32> <i32 7,
i32 6, i32 5, i32 4>
  %abt = fneg <4 x half> %at
  %abb = fneg <4 x half> %ab
%r = shufflevector <4 x half> %abt, <4 x half> %abb, <8 x i32> <i32 7,
i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
By looking through the shuffles and fneg, it can be simplified to:
  %r = fneg <8 x half> %a

The code tracks each lane starting from the original shuffle, keeping a
track of a vector of {src, idx}. As we propagate up through the
instructions we will either look through intermediate instructions
(binops and unops) or see a collections of lanes that all have the same
src and incrementing idx (an identity). We can also see a single value
with identical lanes, which we can treat like a splat.

Only the basic version is added here, handling identities, splats,
binops and unops. In follow-up patches other instructions can be added
such as constants, intrinsics, cmp/sel and zext/sext/trunc.
  • Loading branch information
davemgreen authored May 3, 2024
1 parent 46c2d93 commit a4d1026
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 146 deletions.
147 changes: 147 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class VectorCombine {
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
bool foldTruncFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
Expand Down Expand Up @@ -1667,6 +1668,151 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
return true;
}

// Starting from a shuffle, look up through operands tracking the shuffled index
// of each lane. If we can simplify away the shuffles to identities then
// do so.
bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
auto *Ty = dyn_cast<FixedVectorType>(I.getType());
if (!Ty || !isa<Instruction>(I.getOperand(0)) ||
!isa<Instruction>(I.getOperand(1)))
return false;

using InstLane = std::pair<Value *, int>;

auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
unsigned NumElts =
cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
int M = SV->getMaskValue(Lane);
if (M < 0)
return {nullptr, PoisonMaskElem};
else if (M < (int)NumElts) {
V = SV->getOperand(0);
Lane = M;
} else {
V = SV->getOperand(1);
Lane = M - NumElts;
}
}
return InstLane{V, Lane};
};

auto GenerateInstLaneVectorFromOperand =
[&LookThroughShuffles](ArrayRef<InstLane> Item, int Op) {
SmallVector<InstLane> NItem;
for (InstLane V : Item) {
NItem.emplace_back(
!V.first
? InstLane{nullptr, PoisonMaskElem}
: LookThroughShuffles(
cast<Instruction>(V.first)->getOperand(Op), V.second));
}
return NItem;
};

SmallVector<InstLane> Start(Ty->getNumElements());
for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
Start[M] = LookThroughShuffles(&I, M);

SmallVector<SmallVector<InstLane>> Worklist;
Worklist.push_back(Start);
SmallPtrSet<Value *, 4> IdentityLeafs, SplatLeafs;
unsigned NumVisited = 0;

while (!Worklist.empty()) {
SmallVector<InstLane> Item = Worklist.pop_back_val();
if (++NumVisited > MaxInstrsToScan)
return false;

// If we found an undef first lane then bail out to keep things simple.
if (!Item[0].first)
return false;

// Look for an identity value.
if (Item[0].second == 0 && Item[0].first->getType() == Ty &&
all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
return !E.value().first || (E.value().first == Item[0].first &&
E.value().second == (int)E.index());
})) {
IdentityLeafs.insert(Item[0].first);
continue;
}
// Look for a splat value.
if (all_of(drop_begin(Item), [&](InstLane &IL) {
return !IL.first ||
(IL.first == Item[0].first && IL.second == Item[0].second);
})) {
SplatLeafs.insert(Item[0].first);
continue;
}

// We need each element to be the same type of value, and check that each
// element has a single use.
if (!all_of(drop_begin(Item), [&](InstLane IL) {
if (!IL.first)
return true;
if (auto *I = dyn_cast<Instruction>(IL.first); I && !I->hasOneUse())
return false;
if (IL.first->getValueID() != Item[0].first->getValueID())
return false;
auto *II = dyn_cast<IntrinsicInst>(IL.first);
return !II ||
II->getIntrinsicID() ==
cast<IntrinsicInst>(Item[0].first)->getIntrinsicID();
}))
return false;

// Check the operator is one that we support. We exclude div/rem in case
// they hit UB from poison lanes.
if (isa<BinaryOperator>(Item[0].first) &&
!cast<BinaryOperator>(Item[0].first)->isIntDivRem()) {
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 1));
} else if (isa<UnaryOperator>(Item[0].first)) {
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
} else {
return false;
}
}

// If we got this far, we know the shuffles are superfluous and can be
// removed. Scan through again and generate the new tree of instructions.
std::function<Value *(ArrayRef<InstLane>)> Generate =
[&](ArrayRef<InstLane> Item) -> Value * {
if (IdentityLeafs.contains(Item[0].first) &&
all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
return !E.value().first || (E.value().first == Item[0].first &&
E.value().second == (int)E.index());
})) {
return Item[0].first;
}
if (SplatLeafs.contains(Item[0].first)) {
if (auto ILI = dyn_cast<Instruction>(Item[0].first))
Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
else if (isa<Argument>(Item[0].first))
Builder.SetInsertPointPastAllocas(I.getParent()->getParent());
SmallVector<int, 16> Mask(Ty->getNumElements(), Item[0].second);
return Builder.CreateShuffleVector(Item[0].first, Mask);
}

auto *I = cast<Instruction>(Item[0].first);
SmallVector<Value *> Ops(I->getNumOperands());
for (unsigned Idx = 0, E = I->getNumOperands(); Idx < E; Idx++)
Ops[Idx] = Generate(GenerateInstLaneVectorFromOperand(Item, Idx));
Builder.SetInsertPoint(I);
if (auto BI = dyn_cast<BinaryOperator>(I))
return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
Ops[0], Ops[1]);
assert(isa<UnaryInstruction>(I) &&
"Unexpected instruction type in Generate");
return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
};

Value *V = Generate(Start);
replaceValue(I, *V);
return true;
}

/// Given a commutative reduction, the order of the input lanes does not alter
/// the results. We can use this to remove certain shuffles feeding the
/// reduction, removing the need to shuffle at all.
Expand Down Expand Up @@ -2224,6 +2370,7 @@ bool VectorCombine::run() {
MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldShuffleOfShuffles(I);
MadeChange |= foldSelectShuffle(I);
MadeChange |= foldShuffleToIdentity(I);
break;
case Instruction::BitCast:
MadeChange |= foldBitcastShuffle(I);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,13 @@ define void @add4(ptr noalias noundef %x, ptr noalias noundef %y, i32 noundef %n
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <32 x i16>, ptr [[TMP0]], align 2
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[X]], i64 [[OFFSET_IDX]]
; CHECK-NEXT: [[WIDE_VEC24:%.*]] = load <32 x i16>, ptr [[TMP1]], align 2
; CHECK-NEXT: [[TMP2:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP3:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP4:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP5:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
; CHECK-NEXT: [[TMP6:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP5]]
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <32 x i16> [[TMP2]], <32 x i16> [[TMP3]], <16 x i32> <i32 0, i32 4, i32 8, i32 12, i32 16, i32 20, i32 24, i32 28, i32 33, i32 37, i32 41, i32 45, i32 49, i32 53, i32 57, i32 61>
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <32 x i16> [[TMP4]], <32 x i16> [[TMP6]], <16 x i32> <i32 2, i32 6, i32 10, i32 14, i32 18, i32 22, i32 26, i32 30, i32 35, i32 39, i32 43, i32 47, i32 51, i32 55, i32 59, i32 63>
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <16 x i16> [[TMP7]], <16 x i16> [[TMP8]], <32 x i32> <i32 0, i32 8, i32 16, i32 24, i32 1, i32 9, i32 17, i32 25, i32 2, i32 10, i32 18, i32 26, i32 3, i32 11, i32 19, i32 27, i32 4, i32 12, i32 20, i32 28, i32 5, i32 13, i32 21, i32 29, i32 6, i32 14, i32 22, i32 30, i32 7, i32 15, i32 23, i32 31>
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP2:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP2]]
; CHECK-NEXT: store <32 x i16> [[INTERLEAVED_VEC]], ptr [[GEP]], align 2
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP9]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP3]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: for.end:
; CHECK-NEXT: ret void
;
Expand Down Expand Up @@ -412,22 +406,13 @@ define void @addmul(ptr noalias noundef %x, ptr noundef %y, ptr noundef %z, i32
; CHECK-NEXT: [[TMP2:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X]], i64 [[OFFSET_IDX]]
; CHECK-NEXT: [[WIDE_VEC36:%.*]] = load <32 x i16>, ptr [[TMP3]], align 2
; CHECK-NEXT: [[TMP4:%.*]] = add <32 x i16> [[TMP2]], [[WIDE_VEC36]]
; CHECK-NEXT: [[TMP5:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP6:%.*]] = add <32 x i16> [[TMP5]], [[WIDE_VEC36]]
; CHECK-NEXT: [[TMP7:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP8:%.*]] = add <32 x i16> [[TMP7]], [[WIDE_VEC36]]
; CHECK-NEXT: [[TMP9:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
; CHECK-NEXT: [[TMP10:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP11:%.*]] = add <32 x i16> [[TMP10]], [[WIDE_VEC36]]
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP9]]
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <32 x i16> [[TMP4]], <32 x i16> [[TMP6]], <16 x i32> <i32 0, i32 4, i32 8, i32 12, i32 16, i32 20, i32 24, i32 28, i32 33, i32 37, i32 41, i32 45, i32 49, i32 53, i32 57, i32 61>
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <32 x i16> [[TMP8]], <32 x i16> [[TMP11]], <16 x i32> <i32 2, i32 6, i32 10, i32 14, i32 18, i32 22, i32 26, i32 30, i32 35, i32 39, i32 43, i32 47, i32 51, i32 55, i32 59, i32 63>
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <16 x i16> [[TMP12]], <16 x i16> [[TMP13]], <32 x i32> <i32 0, i32 8, i32 16, i32 24, i32 1, i32 9, i32 17, i32 25, i32 2, i32 10, i32 18, i32 26, i32 3, i32 11, i32 19, i32 27, i32 4, i32 12, i32 20, i32 28, i32 5, i32 13, i32 21, i32 29, i32 6, i32 14, i32 22, i32 30, i32 7, i32 15, i32 23, i32 31>
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = add <32 x i16> [[TMP2]], [[WIDE_VEC36]]
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP4]]
; CHECK-NEXT: store <32 x i16> [[INTERLEAVED_VEC]], ptr [[GEP]], align 2
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP14]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP5]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
; CHECK: for.end:
; CHECK-NEXT: ret void
;
Expand Down
Loading

0 comments on commit a4d1026

Please sign in to comment.