Skip to content

Commit

Permalink
[SLP]Fix a crash for the strided nodes with reversed order and extern…
Browse files Browse the repository at this point in the history
…ally used pointer.

If the strided node is reversed, we need to check the last instruction,
not the first one in the list of scalars, when determining whether the root
pointer must be extracted.
  • Loading branch information
alexey-bataev committed Aug 23, 2024
1 parent 67a9093 commit dab19da
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 6 deletions.
27 changes: 21 additions & 6 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,8 @@ static bool allSameType(ArrayRef<Value *> VL) {
/// possible scalar operand in vectorized instruction.
static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
TargetLibraryInfo *TLI) {
if (!UserInst)
return false;
unsigned Opcode = UserInst->getOpcode();
switch (Opcode) {
case Instruction::Load: {
Expand Down Expand Up @@ -2809,6 +2811,11 @@ class BoUpSLP {
/// \ returns the graph entry for the \p Idx operand of the \p E entry.
const TreeEntry *getOperandEntry(const TreeEntry *E, unsigned Idx) const;

/// Gets the root instruction for the given node. If the node is a strided
/// load/store node with the reverse order, the root instruction is the last
/// one.
Instruction *getRootEntryInstruction(const TreeEntry &Entry) const;

/// \returns Cast context for the given graph node.
TargetTransformInfo::CastContextHint
getCastContextHint(const TreeEntry &TE) const;
Expand Down Expand Up @@ -5987,6 +5994,15 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
VectorizableTree.front()->ReorderIndices.clear();
}

Instruction *BoUpSLP::getRootEntryInstruction(const TreeEntry &Entry) const {
  // Strided load/store nodes with a reversed reorder mask emit their lanes
  // backwards, so the node's root instruction is the scalar the mask places
  // first (the last one in program order), not Scalars.front(). Every other
  // kind of node roots at the first scalar.
  const bool IsLoadOrStore = Entry.getOpcode() == Instruction::Load ||
                             Entry.getOpcode() == Instruction::Store;
  if (IsLoadOrStore && Entry.State == TreeEntry::StridedVectorize &&
      !Entry.ReorderIndices.empty() && isReverseOrder(Entry.ReorderIndices))
    return dyn_cast<Instruction>(Entry.Scalars[Entry.ReorderIndices.front()]);
  return dyn_cast<Instruction>(Entry.Scalars.front());
}

void BoUpSLP::buildExternalUses(
const ExtraValueToDebugLocsMap &ExternallyUsedValues) {
DenseMap<Value *, unsigned> ScalarToExtUses;
Expand Down Expand Up @@ -6036,7 +6052,7 @@ void BoUpSLP::buildExternalUses(
// be used.
if (UseEntry->State == TreeEntry::ScatterVectorize ||
!doesInTreeUserNeedToExtract(
Scalar, cast<Instruction>(UseEntry->Scalars.front()), TLI)) {
Scalar, getRootEntryInstruction(*UseEntry), TLI)) {
LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
<< ".\n");
assert(!UseEntry->isGather() && "Bad state");
Expand Down Expand Up @@ -8450,8 +8466,8 @@ void BoUpSLP::transformNodes() {
Instruction::Store, VecTy, BaseSI->getPointerOperand(),
/*VariableMask=*/false, CommonAlignment, CostKind, BaseSI);
if (StridedCost < OriginalVecCost)
// Strided load is more profitable than consecutive load + reverse -
// transform the node to strided load.
// Strided store is more profitable than reverse + consecutive store -
// transform the node to strided store.
E.State = TreeEntry::StridedVectorize;
}
break;
Expand Down Expand Up @@ -13776,7 +13792,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
ST = Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
} else {
assert(E->State == TreeEntry::StridedVectorize &&
"Expected either strided or conseutive stores.");
"Expected either strided or consecutive stores.");
if (!E->ReorderIndices.empty()) {
SI = cast<StoreInst>(E->Scalars[E->ReorderIndices.front()]);
Ptr = SI->getPointerOperand();
Expand Down Expand Up @@ -14380,8 +14396,7 @@ Value *BoUpSLP::vectorizeTree(
(E->State == TreeEntry::Vectorize ||
E->State == TreeEntry::StridedVectorize) &&
doesInTreeUserNeedToExtract(
Scalar,
cast<Instruction>(UseEntry->Scalars.front()),
Scalar, getRootEntryInstruction(*UseEntry),
TLI);
})) &&
"Scalar with nullptr User must be registered in "
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S --passes=slp-vectorizer -slp-threshold=-99999 -mtriple=riscv64 -mattr=+v < %s | FileCheck %s

; Regression test: a strided store group with reversed order whose root
; pointer (%arrayidx17.i28.1) is also used externally — it feeds both a load
; and a store. The emitted vp.strided.store below uses a negative stride
; (-8), i.e. the reversed form; the vectorizer must take the last scalar
; store, not the first, as the node's root when deciding whether the pointer
; needs extraction (previously this crashed — see the SLPVectorizer change
; in this commit).
define void @test(ptr %a, i64 %0) {
; CHECK-LABEL: define void @test(
; CHECK-SAME: ptr [[A:%.*]], i64 [[TMP0:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[A]], i32 0
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: br label %[[BB:.*]]
; CHECK: [[BB]]:
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i64 [[TMP0]], 1
; CHECK-NEXT: [[ARRAYIDX17_I28_1:%.*]] = getelementptr double, ptr [[A]], i64 [[TMP3]]
; CHECK-NEXT: [[TMP4:%.*]] = insertelement <2 x i64> poison, i64 [[TMP3]], i32 0
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i64> [[TMP4]], i64 0, i32 1
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr double, <2 x ptr> [[TMP2]], <2 x i64> [[TMP5]]
; CHECK-NEXT: [[TMP7:%.*]] = call <2 x double> @llvm.masked.gather.v2f64.v2p0(<2 x ptr> [[TMP6]], i32 8, <2 x i1> <i1 true, i1 true>, <2 x double> poison)
; CHECK-NEXT: [[TMP8:%.*]] = load <2 x double>, ptr [[A]], align 8
; CHECK-NEXT: [[TMP9:%.*]] = load <2 x double>, ptr [[A]], align 8
; CHECK-NEXT: [[TMP10:%.*]] = fsub <2 x double> [[TMP8]], [[TMP9]]
; CHECK-NEXT: [[TMP11:%.*]] = fsub <2 x double> [[TMP7]], [[TMP10]]
; CHECK-NEXT: call void @llvm.experimental.vp.strided.store.v2f64.p0.i64(<2 x double> [[TMP11]], ptr align 8 [[ARRAYIDX17_I28_1]], i64 -8, <2 x i1> <i1 true, i1 true>, i32 2)
; CHECK-NEXT: br label %[[BB]]
;
entry:
br label %bb

bb:
%indvars.iv.next239.i = add i64 0, 0
%arrayidx.i.1 = getelementptr double, ptr %a, i64 %indvars.iv.next239.i
%1 = load double, ptr %arrayidx.i.1, align 8
%arrayidx10.i.1 = getelementptr double, ptr %a, i64 %0
%2 = or disjoint i64 %0, 1
%arrayidx17.i28.1 = getelementptr double, ptr %a, i64 %2
%3 = load double, ptr %arrayidx17.i28.1, align 8
%4 = load double, ptr %a, align 8
%5 = load double, ptr %a, align 8
%arrayidx38.i.1 = getelementptr double, ptr %a, i64 1
%6 = load double, ptr %arrayidx38.i.1, align 8
%arrayidx41.i.1 = getelementptr double, ptr %a, i64 1
%7 = load double, ptr %arrayidx41.i.1, align 8
%sub47.i.1 = fsub double %4, %5
%sub54.i.1 = fsub double %6, %7
%sub69.i.1 = fsub double %1, %sub54.i.1
store double %sub69.i.1, ptr %arrayidx10.i.1, align 8
%sub72.i.1 = fsub double %3, %sub47.i.1
store double %sub72.i.1, ptr %arrayidx17.i28.1, align 8
br label %bb
}

0 comments on commit dab19da

Please sign in to comment.