diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index a5571409bba682..0c8aee8a494c03 100644 --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -240,7 +240,8 @@ class InferAddressSpacesImpl { SmallVectorImpl *PoisonUsesToFix) const; unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const; - unsigned getPredicatedAddrSpace(const Value &V, Value *Opnd) const; + unsigned getPredicatedAddrSpace(const Value &PtrV, + const Instruction *UserCtxI) const; public: InferAddressSpacesImpl(AssumptionCache &AC, const DominatorTree *DT, @@ -909,18 +910,14 @@ void InferAddressSpacesImpl::inferAddressSpaces( } } -unsigned InferAddressSpacesImpl::getPredicatedAddrSpace(const Value &V, - Value *Opnd) const { - const Instruction *I = dyn_cast(&V); - if (!I) - return UninitializedAddressSpace; - - Opnd = Opnd->stripInBoundsOffsets(); - for (auto &AssumeVH : AC.assumptionsFor(Opnd)) { +unsigned InferAddressSpacesImpl::getPredicatedAddrSpace( + const Value &Ptr, const Instruction *UserCtxI) const { + const Value *StrippedPtr = Ptr.stripInBoundsOffsets(); + for (auto &AssumeVH : AC.assumptionsFor(StrippedPtr)) { if (!AssumeVH) continue; CallInst *CI = cast(AssumeVH); - if (!isValidAssumeForContext(CI, I, DT)) + if (!isValidAssumeForContext(CI, UserCtxI, DT)) continue; const Value *Ptr; @@ -989,7 +986,8 @@ bool InferAddressSpacesImpl::updateAddressSpace( OperandAS = PtrOperand->getType()->getPointerAddressSpace(); if (OperandAS == FlatAddrSpace) { // Check AC for assumption dominating V. - unsigned AS = getPredicatedAddrSpace(V, PtrOperand); + unsigned AS = + getPredicatedAddrSpace(*PtrOperand, &cast(V)); if (AS != UninitializedAddressSpace) { LLVM_DEBUG(dbgs() << " deduce operand AS from the predicate addrspace "