Fix reverse pass ordering (#2089)
* Fix reverse pass ordering

* fix
wsmoses authored Sep 29, 2024
1 parent 3f6ca08 commit cc98d92
Showing 3 changed files with 128 additions and 50 deletions.
22 changes: 16 additions & 6 deletions enzyme/Enzyme/EnzymeLogic.cpp
@@ -2355,10 +2355,14 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
       CA.compute_uncacheable_load_map();
   gutils->can_modref_map = &can_modref_map;
 
-  gutils->forceAugmentedReturns();
-
   // requires is_value_needed_in_reverse, that needs unnecessaryValues
   // sets knownRecomputeHeuristic
   gutils->computeMinCache();
 
+  // Requires knownRecomputeCache to be set as call to getContext
+  // itself calls createCacheForScope
+  gutils->forceAugmentedReturns();
+
   SmallPtrSet<const Value *, 4> unnecessaryValues;
   SmallPtrSet<const Instruction *, 4> unnecessaryInstructions;
   calculateUnusedValuesInFunction(*gutils->oldFunc, unnecessaryValues,
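
The hunk above is the substance of the fix: `forceAugmentedReturns()` transitively calls `getContext`, which itself calls `createCacheForScope`, and that caching logic consults the `knownRecomputeHeuristic` state that `computeMinCache()` populates. Moving the call after `computeMinCache()` ensures those cache queries see the final heuristic. Below is a minimal sketch of the ordering hazard, using hypothetical stand-in types rather than Enzyme's real `GradientUtils`:

```cpp
#include <cassert>
#include <map>

struct Value {}; // stand-in for llvm::Value

struct GradientUtilsSketch {
  // Populated by computeMinCache(); empty before it runs.
  std::map<const Value *, bool> knownRecomputeHeuristic;
  bool minCacheComputed = false;

  void computeMinCache() {
    // Real analysis elided; the point is that this fills the map above.
    minCacheComputed = true;
  }

  bool createCacheForScope(const Value *V) {
    // A createCacheForScope-style query: only meaningful once
    // knownRecomputeHeuristic has been populated.
    assert(minCacheComputed && "computeMinCache() must run first");
    auto it = knownRecomputeHeuristic.find(V);
    return it != knownRecomputeHeuristic.end() && !it->second;
  }

  void forceAugmentedReturns() {
    // Transitively reaches createCacheForScope (via getContext in Enzyme);
    // before this commit it ran first and consulted an empty heuristic map.
    createCacheForScope(nullptr);
  }
};

int main() {
  GradientUtilsSketch gutils;
  gutils.computeMinCache();       // first: set knownRecomputeHeuristic
  gutils.forceAugmentedReturns(); // then: its cache queries see final state
}
```

The same reordering is applied at the other two call sites below, in `CreatePrimalAndGradient` and `CreateForwardDiff`.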
@@ -4079,8 +4083,6 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
           : CA.compute_uncacheable_load_map();
   gutils->can_modref_map = &can_modref_map;
 
-  gutils->forceAugmentedReturns();
-
   std::map<std::pair<Instruction *, CacheType>, int> mapping;
   if (augmenteddata)
     mapping = augmenteddata->tapeIndices;
@@ -4093,6 +4095,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
   // sets knownRecomputeHeuristic
   gutils->computeMinCache();
 
+  // Requires knownRecomputeCache to be set as call to getContext
+  // itself calls createCacheForScope
+  gutils->forceAugmentedReturns();
+
   SmallPtrSet<const Value *, 4> unnecessaryValues;
   SmallPtrSet<const Instruction *, 4> unnecessaryInstructions;
   calculateUnusedValuesInFunction(*gutils->oldFunc, unnecessaryValues,
@@ -4731,10 +4737,14 @@ Function *EnzymeLogic::CreateForwardDiff(
           CA.compute_uncacheable_load_map());
   gutils->can_modref_map = can_modref_map.get();
 
-  gutils->forceAugmentedReturns();
-
   // requires is_value_needed_in_reverse, that needs unnecessaryValues
   // sets knownRecomputeHeuristic
   gutils->computeMinCache();
 
+  // Requires knownRecomputeCache to be set as call to getContext
+  // itself calls createCacheForScope
+  gutils->forceAugmentedReturns();
+
   auto getIndex = [&](Instruction *I, CacheType u,
                       IRBuilder<> &B) -> unsigned {
     assert(augmenteddata);
91 changes: 47 additions & 44 deletions enzyme/Enzyme/GradientUtils.cpp
@@ -811,50 +811,53 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
           // returned when the original value would be recomputed (e.g. this
           // function would not return null). See note below about the condition
           // as applied to this case.
-          if (orig && knownRecomputeHeuristic.find(orig) !=
-                          knownRecomputeHeuristic.end()) {
-            if (!knownRecomputeHeuristic[orig]) {
-              if (mode == DerivativeMode::ReverseModeCombined) {
-                // Don't unnecessarily cache a value if the caching
-                // heuristic says we should preserve this precise (and not
-                // an lcssa wrapped) value
-                if (!isOriginalBlock(*BuilderM.GetInsertBlock())) {
-                  Value *nval = inst;
-                  if (scope)
-                    nval = fixLCSSA(inst, scope);
-                  if (nval == inst)
-                    goto endCheck;
-                }
-              } else {
-                // Note that this logic (original load must dominate or
-                // alternatively be in the reverse block) is only valid iff when
-                // applicable (here if in split mode), an overwritten load
-                // cannot be hoisted outside of a loop to be used as a loop
-                // limit. This optimization is currently done in the combined
-                // mode (e.g. if a load isn't modified between a prior insertion
-                // point and the actual load, it is legal to recompute).
-                if (!isOriginalBlock(*BuilderM.GetInsertBlock()) ||
-                    DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
-                  assert(inst->getParent()->getParent() == newFunc);
-                  auto placeholder = BuilderM.CreatePHI(
-                      val->getType(), 0,
-                      val->getName() + "_krcAFUWLreplacement");
-                  unwrappedLoads[placeholder] = inst;
-                  SmallVector<Metadata *, 1> avail;
-                  for (auto pair : available)
-                    if (pair.second)
-                      avail.push_back(
-                          MDNode::get(placeholder->getContext(),
-                                      {ValueAsMetadata::get(
-                                           const_cast<Value *>(pair.first)),
-                                       ValueAsMetadata::get(pair.second)}));
-                  placeholder->setMetadata(
-                      "enzyme_available",
-                      MDNode::get(placeholder->getContext(), avail));
-                  if (!permitCache)
-                    return placeholder;
-                  return unwrap_cache[BuilderM.GetInsertBlock()][idx.first]
-                                     [idx.second] = placeholder;
+          if (orig) {
+            auto found = knownRecomputeHeuristic.find(orig);
+            if (found != knownRecomputeHeuristic.end()) {
+              if (!found->second) {
+                if (mode == DerivativeMode::ReverseModeCombined) {
+                  // Don't unnecessarily cache a value if the caching
+                  // heuristic says we should preserve this precise (and not
+                  // an lcssa wrapped) value
+                  if (!isOriginalBlock(*BuilderM.GetInsertBlock())) {
+                    Value *nval = inst;
+                    if (scope)
+                      nval = fixLCSSA(inst, scope);
+                    if (nval == inst)
+                      goto endCheck;
+                  }
+                } else {
+                  // Note that this logic (original load must dominate or
+                  // alternatively be in the reverse block) is only valid iff
+                  // when applicable (here if in split mode), an overwritten
+                  // load cannot be hoisted outside of a loop to be used as a
+                  // loop limit. This optimization is currently done in the
+                  // combined mode (e.g. if a load isn't modified between a
+                  // prior insertion point and the actual load, it is legal to
+                  // recompute).
+                  if (!isOriginalBlock(*BuilderM.GetInsertBlock()) ||
+                      DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
+                    assert(inst->getParent()->getParent() == newFunc);
+                    auto placeholder = BuilderM.CreatePHI(
+                        val->getType(), 0,
+                        val->getName() + "_krcAFUWLreplacement");
+                    unwrappedLoads[placeholder] = inst;
+                    SmallVector<Metadata *, 1> avail;
+                    for (auto pair : available)
+                      if (pair.second)
+                        avail.push_back(
+                            MDNode::get(placeholder->getContext(),
+                                        {ValueAsMetadata::get(
+                                             const_cast<Value *>(pair.first)),
+                                         ValueAsMetadata::get(pair.second)}));
+                    placeholder->setMetadata(
+                        "enzyme_available",
+                        MDNode::get(placeholder->getContext(), avail));
+                    if (!permitCache)
+                      return placeholder;
+                    return unwrap_cache[BuilderM.GetInsertBlock()][idx.first]
+                                       [idx.second] = placeholder;
+                  }
                }
              }
            }
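
Besides hoisting the `orig` check into its own `if`, the rewrite above replaces the `find()`-then-`operator[]` pair with a single saved iterator. A small stand-alone illustration of the two shapes, with hypothetical names (the real map is `GradientUtils::knownRecomputeHeuristic`):

```cpp
#include <map>

// Two ways to read a guarded map entry. The "before" shape walks the tree
// twice; std::map::operator[] would also default-insert on a missing key
// (harmless here because the find() guard proves the key exists, but it
// forces the map to be mutable). The "after" shape walks the tree once.
bool beforeShape(std::map<const void *, bool> &heuristic, const void *orig) {
  if (orig && heuristic.find(orig) != heuristic.end()) // lookup #1
    return !heuristic[orig];                           // lookup #2
  return false;
}

bool afterShape(const std::map<const void *, bool> &heuristic,
                const void *orig) {
  if (orig) {
    auto found = heuristic.find(orig); // single lookup, iterator reused
    if (found != heuristic.end())
      return !found->second;
  }
  return false;
}
```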
65 changes: 65 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/krc_loop2.ll
@@ -0,0 +1,65 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s

define void @f(i64 %i5, i64* noalias %i4, float* noalias %i13, float* noalias %i12) {
bb:
%i6 = icmp ult i64 %i5, 2
br i1 %i6, label %bb48, label %bb7

bb7: ; preds = %bb
%i8 = load i64, i64* %i4, align 8
%i9 = add i64 %i8, -1
br label %bb14

bb14: ; preds = %bb41, %bb7
%i15 = phi i64 [ %i42, %bb41 ], [ 0, %bb7 ]
%i45 = icmp eq i64 %i15, %i5
%i42 = add nuw i64 %i15, 1
br label %bb17

bb17: ; preds = %bb39, %bb14
%i18 = phi i64 [ 0, %bb14 ], [ %i19, %bb39 ]
%i19 = add nuw nsw i64 %i18, 1
%i40 = icmp eq i64 %i18, %i9
br label %bb21

bb21: ; preds = %bb32, %bb17
%i22 = phi i64 [ %i33, %bb32 ], [ 0, %bb17 ]
%i33 = add i64 %i22, 1
%i34 = icmp sle i64 %i15, %i22
br label %bb23

bb23: ; preds = %bb23, %bb21
%i24 = phi i64 [ %i25, %bb23 ], [ 0, %bb21 ]
%i25 = add i64 %i24, 1
%i26 = load float, float* %i12, align 4
%i27 = icmp sgt i64 %i18, %i24
br i1 %i27, label %bb32, label %bb23

bb32: ; preds = %bb23
br i1 %i34, label %bb21, label %bb39

bb39: ; preds = %bb32
store float %i26, float* %i13, align 4
br i1 %i40, label %bb41, label %bb17

bb41: ; preds = %bb39
br i1 %i45, label %bb48, label %bb14

bb48: ; preds = %bb41, %bb
ret void
}

declare i8* @__enzyme_reverse(...)

define void @main() {
bb:
%i = call i8* (...) @__enzyme_reverse(void (i64, i64*, float*, float*)* @f, i64 0, i64* null, metadata !"enzyme_dup", float* null, float* null, metadata !"enzyme_dup", float* null, float* null, i8* null)
ret void
}

; CHECK: define internal void @diffef(i64 %i5, i64* noalias %i4, float* noalias %i13, float* %"i13'", float* noalias %i12, float* %"i12'", i8* %tapeArg)
; CHECK-NEXT: bb:
; CHECK-NEXT: %0 = bitcast i8* %tapeArg to i64*
; CHECK-NEXT: %i9 = load i64, i64* %0, align 4, !enzyme_mustcache
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
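
The CHECK lines assert how the generated reverse function consumes its tape: it receives the cache as `%tapeArg`, reloads the loop bound `%i9` that the augmented forward pass stored (note the `!enzyme_mustcache` annotation), and frees the allocation. A rough sketch of that caching pattern in plain C++, with hypothetical names rather than Enzyme-generated code:

```cpp
#include <cstdint>
#include <cstdlib>

// Augmented forward pass: heap-allocate a tape and cache the value the
// reverse pass will need (here, the loop bound derived from a load).
int64_t *augmented_forward(int64_t loopBound) {
  auto *tape = static_cast<int64_t *>(std::malloc(sizeof(int64_t)));
  *tape = loopBound;
  return tape;
}

// Reverse pass: reload the cached bound, then free the tape, mirroring the
// `load i64, i64* %0, !enzyme_mustcache` and `call void @free(...)` lines
// the test checks for.
void reverse_pass(int64_t *tape) {
  int64_t loopBound = *tape;
  std::free(tape);
  (void)loopBound; // reverse-mode loop work using the bound elided
}
```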
