Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 29, 2024
1 parent 01d8acf commit 8a74023
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 14 deletions.
16 changes: 10 additions & 6 deletions enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ struct ForOpInterfaceReverse
auto end = gutils->popCache(caches[1], builder);
auto step = gutils->popCache(caches[2], builder);

SmallVector<Value> nArgs;
auto repFor = builder.create<scf::ForOp>(
forOp.getLoc(), start, end, step, ArrayRef<Value>()); // TODO
repFor.getRegion().begin()->erase();
forOp.getLoc(), start, end, step, ArrayRef<Value>());
// erase scf yield
repFor.getBody()->begin()->erase();

for (auto &&[oldReg, newReg] : llvm::zip(op->getRegions(), repFor->getRegions())) {

Expand All @@ -99,7 +99,7 @@ struct ForOpInterfaceReverse

auto loc = oBB->rbegin()->getLoc();

auto idx = repFor.getRegion().begin()->getArgument(0);
auto idx = repFor.getInductionVar();

auto lhs = builder.create<arith::AddIOp>(loc, idx, step);

Expand Down Expand Up @@ -130,8 +130,12 @@ struct ForOpInterfaceReverse
builder.create<scf::YieldOp>(loc);
};

gutils->Logic.differentiate(gutils, oldReg, newReg, buildFuncReturnOp,
nullptr);
for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
gutils->mapReverseModeBlocks.map(&oBB, &revBB);
gutils->Logic.visitChildren(&oBB, &revBB, gutils);
Block *newBB = gutils->getNewFromOriginal(&oBB);
gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, buildFuncReturnOp);
}
}
}

Expand Down
21 changes: 14 additions & 7 deletions enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,15 @@ struct RemoveUnusedEnzymeOpsPass
: public enzyme::RemoveUnusedEnzymeOpsPassBase<RemoveUnusedEnzymeOpsPass> {
void runOnOperation() override {

SmallVector<enzyme::InitOp, 1> inits;
getOperation()->walk([&](Operation *op) {
DominanceInfo dInfo;
if (auto initOp = dyn_cast<enzyme::InitOp>(op)) {
inits.push_back(initOp);
}
});

for (auto initOp : inits) {
DominanceInfo dInfo;
Value v = initOp;
if (auto type = dyn_cast<enzyme::GradientType>(initOp.getType())) {
bool replaceable = true;
Expand Down Expand Up @@ -120,8 +126,8 @@ struct RemoveUnusedEnzymeOpsPass
for (Operation *userGet : make_early_inc_range(v.getUsers())) {
userGet->erase();
}
op->erase();
return;
initOp->erase();
continue;
}
} else if (auto type = dyn_cast<enzyme::CacheType>(initOp.getType())) {
bool replaceable = true;
Expand Down Expand Up @@ -177,13 +183,14 @@ struct RemoveUnusedEnzymeOpsPass
for (Operation *user : make_early_inc_range(v.getUsers())) {
user->erase();
}
op->erase();
initOp->erase();
continue;
}
}
}
});
};
}
}
};

} // end anonymous namespace

namespace mlir {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ReverseMode/pow.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %eopt --enzyme %s | FileCheck %s
// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s

module {
func.func @ppow(%x: f64) -> f64 {
Expand Down

0 comments on commit 8a74023

Please sign in to comment.