Skip to content

Commit

Permalink
Fix additional loop var issue and speed up enzyme-print-type when deb…
Browse files Browse the repository at this point in the history
…uginfo is present
  • Loading branch information
wsmoses committed Oct 22, 2023
1 parent a7e7868 commit bc80ae4
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 72 deletions.
6 changes: 3 additions & 3 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ class AdjointGenerator
auto alignment = LI.getAlign();
auto &DL = gutils->newFunc->getParent()->getDataLayout();

bool constantval = parseTBAA(LI, DL).Inner0().isIntegral();
bool constantval = parseTBAA(LI, DL, nullptr).Inner0().isIntegral();
visitLoadLike(LI, alignment, constantval);
eraseIfUnused(LI);
}
Expand Down Expand Up @@ -975,7 +975,7 @@ class AdjointGenerator
NewI->setMetadata(LLVMContext::MD_noalias, noscope);

bool constantval = gutils->isConstantValue(orig_val) ||
parseTBAA(I, DL).Inner0().isIntegral();
parseTBAA(I, DL, nullptr).Inner0().isIntegral();

IRBuilder<> BuilderZ(NewI);
BuilderZ.setFastMathFlags(getFast());
Expand Down Expand Up @@ -3484,7 +3484,7 @@ class AdjointGenerator
auto align0 = cast<ConstantInt>(I.getOperand(1))->getZExtValue();
auto align = MaybeAlign(align0);
auto &DL = gutils->newFunc->getParent()->getDataLayout();
bool constantval = parseTBAA(I, DL).Inner0().isIntegral();
bool constantval = parseTBAA(I, DL, nullptr).Inner0().isIntegral();
visitLoadLike(I, align, constantval,
/*mask*/ gutils->getNewFromOriginal(I.getOperand(2)),
/*orig_maskInit*/ I.getOperand(3));
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3284,7 +3284,7 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) {
auto &DL = newFunc->getParent()->getDataLayout();

bool constantval = isConstantValue(orig_val) ||
parseTBAA(I, DL).Inner0().isIntegral();
parseTBAA(I, DL, nullptr).Inner0().isIntegral();

// TODO allow recognition of other types that could contain
// pointers [e.g. {void*, void*} or <2 x i64> ]
Expand Down
60 changes: 43 additions & 17 deletions enzyme/Enzyme/TypeAnalysis/TBAA.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,28 +383,51 @@ extern llvm::cl::opt<bool> EnzymePrintType;

/// Derive the ConcreteType corresponding to the string TypeName
/// The llvm::Instruction I denotes the context in which this was found
static inline ConcreteType getTypeFromTBAAString(std::string TypeName,
llvm::Instruction &I) {
static inline ConcreteType
getTypeFromTBAAString(std::string TypeName, llvm::Instruction &I,
std::shared_ptr<llvm::ModuleSlotTracker> MST) {
if (TypeName == "long long" || TypeName == "long" || TypeName == "int" ||
TypeName == "bool" || TypeName == "jtbaa_arraysize" ||
TypeName == "jtbaa_arraylen") {
if (EnzymePrintType) {
llvm::errs() << "known tbaa " << I << " " << TypeName << "\n";
llvm::errs() << "known tbaa ";
if (MST)
I.print(llvm::errs(), *MST);
else
llvm::errs() << I;
llvm::errs() << " " << TypeName << "\n";
}
return ConcreteType(BaseType::Integer);
} else if (TypeName == "any pointer" || TypeName == "vtable pointer" ||
TypeName == "jtbaa_arrayptr" || TypeName == "jtbaa_tag") {
if (EnzymePrintType) {
llvm::errs() << "known tbaa " << I << " " << TypeName << "\n";
llvm::errs() << "known tbaa ";
if (MST)
I.print(llvm::errs(), *MST);
else
llvm::errs() << I;
llvm::errs() << " " << TypeName << "\n";
}
return ConcreteType(BaseType::Pointer);
} else if (TypeName == "float") {
if (EnzymePrintType)
llvm::errs() << "known tbaa " << I << " " << TypeName << "\n";
if (EnzymePrintType) {
llvm::errs() << "known tbaa ";
if (MST)
I.print(llvm::errs(), *MST);
else
llvm::errs() << I;
llvm::errs() << " " << TypeName << "\n";
}
return llvm::Type::getFloatTy(I.getContext());
} else if (TypeName == "double") {
if (EnzymePrintType)
llvm::errs() << "known tbaa " << I << " " << TypeName << "\n";
if (EnzymePrintType) {
llvm::errs() << "known tbaa ";
if (MST)
I.print(llvm::errs(), *MST);
else
llvm::errs() << I;
llvm::errs() << " " << TypeName << "\n";
}
return llvm::Type::getDoubleTy(I.getContext());
}
return ConcreteType(BaseType::Unknown);
Expand All @@ -415,10 +438,11 @@ static inline ConcreteType getTypeFromTBAAString(std::string TypeName,
/// corresponding offsets in the result
static inline TypeTree parseTBAA(TBAAStructTypeNode AccessType,
llvm::Instruction &I,
const llvm::DataLayout &DL) {
const llvm::DataLayout &DL,
std::shared_ptr<llvm::ModuleSlotTracker> MST) {

if (auto *Id = llvm::dyn_cast<llvm::MDString>(AccessType.getId())) {
auto CT = getTypeFromTBAAString(Id->getString().str(), I);
auto CT = getTypeFromTBAAString(Id->getString().str(), I, MST);
if (CT.isKnown()) {
return TypeTree(CT).Only(-1, &I);
}
Expand All @@ -428,7 +452,7 @@ static inline TypeTree parseTBAA(TBAAStructTypeNode AccessType,
for (unsigned i = 0, size = AccessType.getNumFields(); i < size; ++i) {
auto SubAccess = AccessType.getFieldType(i);
auto Offset = AccessType.getFieldOffset(i);
auto SubResult = parseTBAA(SubAccess, I, DL);
auto SubResult = parseTBAA(SubAccess, I, DL, MST);
Result |= SubResult.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1,
/*addOffset*/ Offset);
}
Expand All @@ -439,13 +463,14 @@ static inline TypeTree parseTBAA(TBAAStructTypeNode AccessType,
/// Given a TBAA metadata node return the corresponding TypeTree
/// Modified from llvm::MDNode::isTBAAVtableAccess()
static inline TypeTree parseTBAA(const llvm::MDNode *M, llvm::Instruction &I,
const llvm::DataLayout &DL) {
const llvm::DataLayout &DL,
std::shared_ptr<llvm::ModuleSlotTracker> MST) {
if (!isStructPathTBAA(M)) {
if (M->getNumOperands() < 1)
return TypeTree();
if (const llvm::MDString *Tag1 =
llvm::dyn_cast<llvm::MDString>(M->getOperand(0))) {
return TypeTree(getTypeFromTBAAString(Tag1->getString().str(), I))
return TypeTree(getTypeFromTBAAString(Tag1->getString().str(), I, MST))
.Only(0, &I);
}
return TypeTree();
Expand All @@ -454,20 +479,21 @@ static inline TypeTree parseTBAA(const llvm::MDNode *M, llvm::Instruction &I,
// For struct-path aware TBAA, we use the access type of the tag.
TBAAStructTagNode Tag(M);
TBAAStructTypeNode AccessType(Tag.getAccessType());
return parseTBAA(AccessType, I, DL);
return parseTBAA(AccessType, I, DL, MST);
}

/// Given an llvm::Instruction, return a TypeTree representing any
/// types that can be derived from TBAA metadata attached
static inline TypeTree parseTBAA(llvm::Instruction &I,
const llvm::DataLayout &DL) {
const llvm::DataLayout &DL,
std::shared_ptr<llvm::ModuleSlotTracker> MST) {
TypeTree Result;
if (const llvm::MDNode *M =
I.getMetadata(llvm::LLVMContext::MD_tbaa_struct)) {
for (unsigned i = 0, size = M->getNumOperands(); i < size; i += 3) {
if (const llvm::MDNode *M2 =
llvm::dyn_cast<llvm::MDNode>(M->getOperand(i + 2))) {
auto SubResult = parseTBAA(M2, I, DL);
auto SubResult = parseTBAA(M2, I, DL, MST);
auto Start = llvm::cast<llvm::ConstantInt>(
llvm::cast<llvm::ConstantAsMetadata>(M->getOperand(i))
->getValue())
Expand All @@ -484,7 +510,7 @@ static inline TypeTree parseTBAA(llvm::Instruction &I,
}
}
if (const llvm::MDNode *M = I.getMetadata(llvm::LLVMContext::MD_tbaa)) {
Result |= parseTBAA(M, I, DL);
Result |= parseTBAA(M, I, DL, MST);
}
Result |= TypeTree(BaseType::Pointer);
return Result;
Expand Down
Loading

0 comments on commit bc80ae4

Please sign in to comment.