Skip to content

Commit

Permalink
Option to prevent caching certain values (#2078)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 15, 2024
1 parent 8aa216e commit ed3ae59
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
18 changes: 18 additions & 0 deletions enzyme/Enzyme/DifferentialUseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,9 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
if (ASC->getSrcAddressSpace() == 10 && ASC->getDestAddressSpace() == 0)
continue;
}
if (hasNoCache((*mp.begin()).V)) {
continue;
}
// If an allocation call, we cannot cache any "capturing" users
if (isAllocationCall(V, TLI) || isa<AllocaInst>(V)) {
auto next = (*mp.begin()).V;
Expand Down Expand Up @@ -960,6 +963,20 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
}
}

// Fix up non-cacheable calls to use their operand(s) instead
for (auto V : Intermediates) {
if (!hasNoCache(V))
continue;
if (!MinReq.count(V))
continue;
MinReq.remove(V);
for (auto &pair : Orig) {
if (pair.second.count(Node(V, false))) {
MinReq.insert(pair.first.V);
}
}
}

// Fix up non-repeatable writing calls that chain within rematerialized
// allocations. We could iterate from the keys of the valuemap, but that would
// be a non-determinstic ordering.
Expand Down Expand Up @@ -995,6 +1012,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
// values that we are keeping for stores.
MinReq.insert(V);
}

return;
}

Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2554,6 +2554,7 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
assert(malloc);
assert(BuilderQ.GetInsertBlock()->getParent() == newFunc);
assert(isOriginalBlock(*BuilderQ.GetInsertBlock()));
assert(!hasNoCache(malloc));
if (mode == DerivativeMode::ReverseModeCombined) {
assert(!tape);
return malloc;
Expand Down
20 changes: 16 additions & 4 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,7 @@ static inline llvm::StringRef getFuncName(llvm::Function *called) {
return called->getName();
}

template <typename T> static inline llvm::StringRef getFuncNameFromCall(T *op) {
static inline llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op) {
auto AttrList =
op->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
if (AttrList.hasAttribute("enzyme_math"))
Expand All @@ -1170,11 +1170,23 @@ template <typename T> static inline llvm::StringRef getFuncNameFromCall(T *op) {
return "";
}

template <typename T>
static inline bool hasNoCache(llvm::Value *op) {
using namespace llvm;
if (auto CB = dyn_cast<CallBase>(op)) {
if (auto called = getFunctionFromCall(CB)) {
if (called->hasFnAttribute("enzyme_nocache"))
return true;
}
}
return false;
}

#if LLVM_VERSION_MAJOR >= 16
static inline std::optional<size_t> getAllocationIndexFromCall(T *op)
static inline std::optional<size_t>
getAllocationIndexFromCall(const llvm::CallBase *op)
#else
static inline llvm::Optional<size_t> getAllocationIndexFromCall(T *op)
static inline llvm::Optional<size_t>
getAllocationIndexFromCall(const llvm::CallBase *op)
#endif
{
auto AttrList =
Expand Down

0 comments on commit ed3ae59

Please sign in to comment.