Skip to content

Commit

Permalink
Fix SCEV memory error
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 4, 2023
1 parent dfb8b55 commit 1d69c3e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 41 deletions.
111 changes: 71 additions & 40 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6094,6 +6094,11 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
mode == DerivativeMode::ReverseModeCombined);

assert(val->getName() != "<badref>");
{
auto found = incoming_available.find(val);
if (found != incoming_available.end())
return found->second;
}
if (isa<Constant>(val)) {
return val;
}
Expand Down Expand Up @@ -6121,7 +6126,6 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
}

auto inst = cast<Instruction>(val);
assert(inst->getName() != "<badref>");
if (inversionAllocs && inst->getParent() == inversionAllocs) {
return val;
}
Expand Down Expand Up @@ -6418,7 +6422,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
auto li2obj = getBaseObject(li2->getPointerOperand());

if (liobj == li2obj && DT.dominates(li2, li)) {
auto orig2 = isOriginal(li2);
auto orig2 = dyn_cast_or_null<LoadInst>(isOriginal(li2));
if (!orig2)
continue;

Expand All @@ -6427,8 +6431,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
// llvm::errs() << "found potential candidate loads: oli:"
// << *origInst << " oli2: " << *orig2 << "\n";

auto scev1 = SE.getSCEV(li->getPointerOperand());
auto scev2 = SE.getSCEV(li2->getPointerOperand());
auto scev1 = SE.getSCEV(origInst->getPointerOperand());
auto scev2 = SE.getSCEV(orig2->getPointerOperand());
// llvm::errs() << " scev1: " << *scev1 << " scev2: " << *scev2
// << "\n";

Expand All @@ -6449,11 +6453,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,

if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) {
if (auto ar2 = dyn_cast<SCEVAddRecExpr>(scev2)) {
if (ar1->getStart() != SE.getCouldNotCompute() &&
if (ar1->getStart() != OrigSE.getCouldNotCompute() &&
ar1->getStart() == ar2->getStart() &&
ar1->getStepRecurrence(SE) != SE.getCouldNotCompute() &&
ar1->getStepRecurrence(SE) ==
ar2->getStepRecurrence(SE)) {
ar1->getStepRecurrence(OrigSE) !=
OrigSE.getCouldNotCompute() &&
ar1->getStepRecurrence(OrigSE) ==
ar2->getStepRecurrence(OrigSE)) {

LoopContext l1;
getContext(ar1->getLoop()->getHeader(), l1);
Expand Down Expand Up @@ -6848,20 +6853,20 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
}
}

auto scev1 = SE.getSCEV(li->getPointerOperand());
auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand());
// Store in memcpy opt
Value *lim = nullptr;
BasicBlock *ctx = nullptr;
Value *start = nullptr;
Value *offset = nullptr;
if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) {
if (auto step =
dyn_cast<SCEVConstant>(ar1->getStepRecurrence(SE))) {
dyn_cast<SCEVConstant>(ar1->getStepRecurrence(OrigSE))) {
if (step->getAPInt() != loadSize)
goto noSpeedCache;

LoopContext l1;
getContext(ar1->getLoop()->getHeader(), l1);
getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), l1);

if (l1.dynamic)
goto noSpeedCache;
Expand All @@ -6886,40 +6891,66 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
lim = v.CreateAdd(lim, ConstantInt::get(lim->getType(), 1), "",
true, true);

SmallVector<Instruction *, 4> toErase;
{
#if LLVM_VERSION_MAJOR >= 12
SCEVExpander Exp(SE,
ctx->getParent()->getParent()->getDataLayout(),
"enzyme");
#else
fake::SCEVExpander Exp(
SE, ctx->getParent()->getParent()->getDataLayout(),
"enzyme");
#endif
Exp.setInsertPoint(l1.header->getTerminator());
Value *start0 = Exp.expandCodeFor(
ar1->getStart(), li->getPointerOperand()->getType());
start = unwrapM(start0, v,
/*available*/ ValueToValueMapTy(),
UnwrapMode::AttemptFullUnwrapWithLookup);
std::set<Value *> todo = {start0};
while (todo.size()) {
Value *now = *todo.begin();
todo.erase(now);
if (Instruction *inst = dyn_cast<Instruction>(now)) {
if (inst != start && inst->getNumUses() == 0 &&
Exp.isInsertedInstruction(inst)) {
for (auto &op : inst->operands()) {
todo.insert(op);
}
toErase.push_back(inst);
}
Value *start0;
SmallVector<Instruction *, 32> InsertedInstructions;
{
SCEVExpander OrigExp(
OrigSE, ctx->getParent()->getParent()->getDataLayout(),
"enzyme");

OrigExp.setInsertPoint(
isOriginal(l1.header)->getTerminator());

start0 = OrigExp.expandCodeFor(
ar1->getStart(), li->getPointerOperand()->getType());
InsertedInstructions = OrigExp.getAllInsertedInstructions();
}

ValueToValueMapTy available;
for (const auto &pair : originalToNewFn) {
available[pair.first] = pair.second;
}

// Sort so that later instructions do not dominate earlier
// instructions.
llvm::stable_sort(InsertedInstructions,
[this](Instruction *A, Instruction *B) {
return OrigDT.dominates(A, B);
});
for (auto a : InsertedInstructions) {
assert(!isa<PHINode>(a));
auto uw = cast<Instruction>(
unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap,
/*scope*/ nullptr, /*cache*/ false));
for (size_t i = 0; i < uw->getNumOperands(); i++) {
auto op = uw->getOperand(i);
if (auto arg = dyn_cast<Argument>(op))
assert(arg->getParent() == newFunc);
else if (auto inst = dyn_cast<Instruction>(op))
assert(inst->getParent()->getParent() == newFunc);
}
available[a] = uw;
unwrappedLoads.erase(cast<Instruction>(uw));
}

start =
isa<Constant>(start0) ? start0 : (Value *)available[start0];
if (!start) {
llvm::errs() << "old: " << *oldFunc << "\n";
llvm::errs() << "new: " << *newFunc << "\n";
llvm::errs() << "start0: " << *start0 << "\n";
}
assert(start);

available.clear();
for (auto I : llvm::reverse(InsertedInstructions)) {
assert(I->getNumUses() == 0);
I->eraseFromParent();
}
#endif
}
for (auto a : toErase)
erase(a);

if (!start)
goto noSpeedCache;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ReverseMode/rwrloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ attributes #9 = { noreturn nounwind }

; CHECK: for.cond1.preheader: ; preds = %for.cond.cleanup3, %entry
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.cleanup3 ], [ 0, %entry ]
; CHECK-NEXT: %[[a2:.+]] = mul {{(nuw nsw )?}}i64 %iv, 10
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
; CHECK-NEXT: br i1 %cmp233, label %for.body4.lr.ph, label %for.cond.cleanup3

; CHECK: for.body4.lr.ph: ; preds = %for.cond1.preheader
; CHECK-NEXT: %[[a2:.+]] = mul {{(nuw nsw )?}}i64 %iv, 10
; CHECK-NEXT: %[[a3:.+]] = load i32, i32* %N, align 4, !tbaa !2, !alias.scope !8, !noalias !11, !invariant.group ![[INVG:[0-9]]]
; CHECK-NEXT: %[[a4:.+]] = getelementptr inbounds i32, i32* %[[malloccache12]], i64 %iv
; CHECK-NEXT: store i32 %[[a3]], i32* %[[a4]], align 4, !tbaa !2, !invariant.group ![[INVG]]
Expand Down

0 comments on commit 1d69c3e

Please sign in to comment.