From dab19dac94eee19483ba1a7c37bdec4b8501acc3 Mon Sep 17 00:00:00 2001 From: Alexey Bataev <a.bataev@outlook.com> Date: Fri, 23 Aug 2024 07:34:26 -0700 Subject: [PATCH] [SLP]Fix a crash for the strided nodes with reversed order and externally used pointer. If the strided node is reversed, need to check for the last instruction, not the first one in the list of scalars, when checking if the root pointer must be extracted. --- .../Transforms/Vectorize/SLPVectorizer.cpp | 27 +++++++--- ...reversed-strided-node-with-external-ptr.ll | 49 +++++++++++++++++++ 2 files changed, 70 insertions(+), 6 deletions(-) create mode 100644 llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index caee3bf9c958d5..949579772b94d5 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1020,6 +1020,8 @@ static bool allSameType(ArrayRef<Value *> VL) { /// possible scalar operand in vectorized instruction. static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, TargetLibraryInfo *TLI) { + if (!UserInst) + return false; unsigned Opcode = UserInst->getOpcode(); switch (Opcode) { case Instruction::Load: { @@ -2809,6 +2811,11 @@ class BoUpSLP { /// \ returns the graph entry for the \p Idx operand of the \p E entry. const TreeEntry *getOperandEntry(const TreeEntry *E, unsigned Idx) const; + /// Gets the root instruction for the given node. If the node is a strided + /// load/store node with the reverse order, the root instruction is the last + /// one. + Instruction *getRootEntryInstruction(const TreeEntry &Entry) const; + /// \returns Cast context for the given graph node. 
TargetTransformInfo::CastContextHint getCastContextHint(const TreeEntry &TE) const; @@ -5987,6 +5994,15 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { VectorizableTree.front()->ReorderIndices.clear(); } +Instruction *BoUpSLP::getRootEntryInstruction(const TreeEntry &Entry) const { + if ((Entry.getOpcode() == Instruction::Store || + Entry.getOpcode() == Instruction::Load) && + Entry.State == TreeEntry::StridedVectorize && + !Entry.ReorderIndices.empty() && isReverseOrder(Entry.ReorderIndices)) + return dyn_cast<Instruction>(Entry.Scalars[Entry.ReorderIndices.front()]); + return dyn_cast<Instruction>(Entry.Scalars.front()); +} + void BoUpSLP::buildExternalUses( const ExtraValueToDebugLocsMap &ExternallyUsedValues) { DenseMap<Value *, unsigned> ScalarToExtUses; @@ -6036,7 +6052,7 @@ void BoUpSLP::buildExternalUses( // be used. if (UseEntry->State == TreeEntry::ScatterVectorize || !doesInTreeUserNeedToExtract( - Scalar, cast<Instruction>(UseEntry->Scalars.front()), TLI)) { + Scalar, getRootEntryInstruction(*UseEntry), TLI)) { LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U << ".\n"); assert(!UseEntry->isGather() && "Bad state"); @@ -8450,8 +8466,8 @@ void BoUpSLP::transformNodes() { Instruction::Store, VecTy, BaseSI->getPointerOperand(), /*VariableMask=*/false, CommonAlignment, CostKind, BaseSI); if (StridedCost < OriginalVecCost) - // Strided load is more profitable than consecutive load + reverse - - // transform the node to strided load. + // Strided store is more profitable than reverse + consecutive store - + // transform the node to strided store. 
E.State = TreeEntry::StridedVectorize; } break; @@ -13776,7 +13792,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { ST = Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign()); } else { assert(E->State == TreeEntry::StridedVectorize && - "Expected either strided or conseutive stores."); + "Expected either strided or consecutive stores."); if (!E->ReorderIndices.empty()) { SI = cast<StoreInst>(E->Scalars[E->ReorderIndices.front()]); Ptr = SI->getPointerOperand(); @@ -14380,8 +14396,7 @@ Value *BoUpSLP::vectorizeTree( (E->State == TreeEntry::Vectorize || E->State == TreeEntry::StridedVectorize) && doesInTreeUserNeedToExtract( - Scalar, - cast<Instruction>(UseEntry->Scalars.front()), + Scalar, getRootEntryInstruction(*UseEntry), TLI); })) && "Scalar with nullptr User must be registered in " diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll new file mode 100644 index 00000000000000..3fa42047162e45 --- /dev/null +++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll @@ -0,0 +1,49 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -S --passes=slp-vectorizer -slp-threshold=-99999 -mtriple=riscv64 -mattr=+v < %s | FileCheck %s + +define void @test(ptr %a, i64 %0) { +; CHECK-LABEL: define void @test( +; CHECK-SAME: ptr [[A:%.*]], i64 [[TMP0:%.*]]) #[[ATTR0:[0-9]+]] { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[A]], i32 0 +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer +; CHECK-NEXT: br label %[[BB:.*]] +; CHECK: [[BB]]: +; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i64 [[TMP0]], 1 +; CHECK-NEXT: [[ARRAYIDX17_I28_1:%.*]] = getelementptr double, ptr [[A]], i64 [[TMP3]] +; CHECK-NEXT: [[TMP4:%.*]] = insertelement <2 x i64> poison, i64 [[TMP3]], 
i32 0 +; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i64> [[TMP4]], i64 0, i32 1 +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr double, <2 x ptr> [[TMP2]], <2 x i64> [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = call <2 x double> @llvm.masked.gather.v2f64.v2p0(<2 x ptr> [[TMP6]], i32 8, <2 x i1> <i1 true, i1 true>, <2 x double> poison) +; CHECK-NEXT: [[TMP8:%.*]] = load <2 x double>, ptr [[A]], align 8 +; CHECK-NEXT: [[TMP9:%.*]] = load <2 x double>, ptr [[A]], align 8 +; CHECK-NEXT: [[TMP10:%.*]] = fsub <2 x double> [[TMP8]], [[TMP9]] +; CHECK-NEXT: [[TMP11:%.*]] = fsub <2 x double> [[TMP7]], [[TMP10]] +; CHECK-NEXT: call void @llvm.experimental.vp.strided.store.v2f64.p0.i64(<2 x double> [[TMP11]], ptr align 8 [[ARRAYIDX17_I28_1]], i64 -8, <2 x i1> <i1 true, i1 true>, i32 2) +; CHECK-NEXT: br label %[[BB]] +; +entry: + br label %bb + +bb: + %indvars.iv.next239.i = add i64 0, 0 + %arrayidx.i.1 = getelementptr double, ptr %a, i64 %indvars.iv.next239.i + %1 = load double, ptr %arrayidx.i.1, align 8 + %arrayidx10.i.1 = getelementptr double, ptr %a, i64 %0 + %2 = or disjoint i64 %0, 1 + %arrayidx17.i28.1 = getelementptr double, ptr %a, i64 %2 + %3 = load double, ptr %arrayidx17.i28.1, align 8 + %4 = load double, ptr %a, align 8 + %5 = load double, ptr %a, align 8 + %arrayidx38.i.1 = getelementptr double, ptr %a, i64 1 + %6 = load double, ptr %arrayidx38.i.1, align 8 + %arrayidx41.i.1 = getelementptr double, ptr %a, i64 1 + %7 = load double, ptr %arrayidx41.i.1, align 8 + %sub47.i.1 = fsub double %4, %5 + %sub54.i.1 = fsub double %6, %7 + %sub69.i.1 = fsub double %1, %sub54.i.1 + store double %sub69.i.1, ptr %arrayidx10.i.1, align 8 + %sub72.i.1 = fsub double %3, %sub47.i.1 + store double %sub72.i.1, ptr %arrayidx17.i28.1, align 8 + br label %bb
}