Skip to content

Commit

Permalink
[TTI] Support scalable offsets in getScalingFactorCost (llvm#88113)
Browse files Browse the repository at this point in the history
Part of the work to support vscale-relative immediates in LSR.
  • Loading branch information
huntergr-arm committed May 10, 2024
1 parent 64d4ade commit 2e8d815
Show file tree
Hide file tree
Showing 11 changed files with 28 additions and 20 deletions.
6 changes: 3 additions & 3 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ class TargetTransformInfo {
/// If the AM is not supported, it returns a negative value.
/// TODO: Handle pre/postinc as well.
InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale,
unsigned AddrSpace = 0) const;

Expand Down Expand Up @@ -1891,7 +1891,7 @@ class TargetTransformInfo::Concept {
virtual bool hasVolatileVariant(Instruction *I, unsigned AddrSpace) = 0;
virtual bool prefersVectorizedAddressing() = 0;
virtual InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset,
StackOffset BaseOffset,
bool HasBaseReg, int64_t Scale,
unsigned AddrSpace) = 0;
virtual bool LSRWithInstrQueries() = 0;
Expand Down Expand Up @@ -2403,7 +2403,7 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.prefersVectorizedAddressing();
}
InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale,
unsigned AddrSpace) override {
return Impl.getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
Expand Down
8 changes: 5 additions & 3 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Function;
/// Base class for use as a mix-in that aids implementing
/// a TargetTransformInfo-compatible class.
class TargetTransformInfoImplBase {

protected:
typedef TargetTransformInfo TTI;

Expand Down Expand Up @@ -326,12 +327,13 @@ class TargetTransformInfoImplBase {
bool prefersVectorizedAddressing() const { return true; }

InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale,
unsigned AddrSpace) const {
// Guess that all legal addressing mode are free.
if (isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
AddrSpace))
if (isLegalAddressingMode(Ty, BaseGV, BaseOffset.getFixed(), HasBaseReg,
Scale, AddrSpace, /*I=*/nullptr,
BaseOffset.getScalable()))
return 0;
return -1;
}
Expand Down
5 changes: 3 additions & 2 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
}

InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale, unsigned AddrSpace) {
TargetLoweringBase::AddrMode AM;
AM.BaseGV = BaseGV;
AM.BaseOffs = BaseOffset;
AM.BaseOffs = BaseOffset.getFixed();
AM.HasBaseReg = HasBaseReg;
AM.Scale = Scale;
AM.ScalableOffset = BaseOffset.getScalable();
if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
return 0;
return -1;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ bool TargetTransformInfo::prefersVectorizedAddressing() const {
}

InstructionCost TargetTransformInfo::getScalingFactorCost(
Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg,
Type *Ty, GlobalValue *BaseGV, StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale, unsigned AddrSpace) const {
InstructionCost Cost = TTIImpl->getScalingFactorCost(
Ty, BaseGV, BaseOffset, HasBaseReg, Scale, AddrSpace);
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4183,7 +4183,7 @@ bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) {

InstructionCost
AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale, unsigned AddrSpace) const {
// Scaling factors are not free at all.
// Operands | Rt Latency
Expand All @@ -4194,9 +4194,10 @@ AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
// Rt, [Xn, Wm, <extend> #imm] |
TargetLoweringBase::AddrMode AM;
AM.BaseGV = BaseGV;
AM.BaseOffs = BaseOffset;
AM.BaseOffs = BaseOffset.getFixed();
AM.HasBaseReg = HasBaseReg;
AM.Scale = Scale;
AM.ScalableOffset = BaseOffset.getScalable();
if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
// Scale represents reg2 * scale, thus account for 1 if
// it is not equal to 0 or 1.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
/// If the AM is supported, the return value must be >= 0.
/// If the AM is not supported, it returns a negative value.
InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale, unsigned AddrSpace) const;
/// @}

Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2571,14 +2571,15 @@ bool ARMTTIImpl::preferPredicatedReductionSelect(
}

InstructionCost ARMTTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset,
StackOffset BaseOffset,
bool HasBaseReg, int64_t Scale,
unsigned AddrSpace) const {
TargetLoweringBase::AddrMode AM;
AM.BaseGV = BaseGV;
AM.BaseOffs = BaseOffset;
AM.BaseOffs = BaseOffset.getFixed();
AM.HasBaseReg = HasBaseReg;
AM.Scale = Scale;
AM.ScalableOffset = BaseOffset.getScalable();
if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace)) {
if (ST->hasFPAO())
return AM.Scale < 0 ? 1 : 0; // positive offsets execute faster
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/ARM/ARMTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
/// If the AM is supported, the return value must be >= 0.
/// If the AM is not supported, the return value must be negative.
InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale, unsigned AddrSpace) const;

bool maybeLoweredToCall(Instruction &I);
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6741,7 +6741,7 @@ InstructionCost X86TTIImpl::getInterleavedMemoryOpCost(
}

InstructionCost X86TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset,
StackOffset BaseOffset,
bool HasBaseReg, int64_t Scale,
unsigned AddrSpace) const {
// Scaling factors are not free at all.
Expand All @@ -6764,9 +6764,10 @@ InstructionCost X86TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
// vmovaps %ymm1, (%r8) can use port 2, 3, or 7.
TargetLoweringBase::AddrMode AM;
AM.BaseGV = BaseGV;
AM.BaseOffs = BaseOffset;
AM.BaseOffs = BaseOffset.getFixed();
AM.HasBaseReg = HasBaseReg;
AM.Scale = Scale;
AM.ScalableOffset = BaseOffset.getScalable();
if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
// Scale represents reg2 * scale, thus account for 1
// as soon as we use a second register.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/X86/X86TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
/// If the AM is supported, the return value must be >= 0.
/// If the AM is not supported, it returns a negative value.
InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
StackOffset BaseOffset, bool HasBaseReg,
int64_t Scale, unsigned AddrSpace) const;

bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1817,10 +1817,12 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
case LSRUse::Address: {
// Check the scaling factor cost with both the min and max offsets.
InstructionCost ScaleCostMinOffset = TTI.getScalingFactorCost(
LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MinOffset, F.HasBaseReg,
LU.AccessTy.MemTy, F.BaseGV,
StackOffset::getFixed(F.BaseOffset + LU.MinOffset), F.HasBaseReg,
F.Scale, LU.AccessTy.AddrSpace);
InstructionCost ScaleCostMaxOffset = TTI.getScalingFactorCost(
LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MaxOffset, F.HasBaseReg,
LU.AccessTy.MemTy, F.BaseGV,
StackOffset::getFixed(F.BaseOffset + LU.MaxOffset), F.HasBaseReg,
F.Scale, LU.AccessTy.AddrSpace);

assert(ScaleCostMinOffset.isValid() && ScaleCostMaxOffset.isValid() &&
Expand Down

0 comments on commit 2e8d815

Please sign in to comment.