Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 29, 2024
1 parent 8a74023 commit 005f4fc
Showing 1 changed file with 149 additions and 101 deletions.
250 changes: 149 additions & 101 deletions enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,53 +30,138 @@ using namespace mlir;
using namespace enzyme;
namespace {

// TODO: Expand to region branches??
bool reachable(Operation *a, Operation *b) {
Block *aBlock = a->getBlock();
Block *bBlock = b->getBlock();
if (aBlock == bBlock) {
if (a->isBeforeInBlock(b)) {
return true;
}
}

// Starting at the beginning of blk, is there a path that can execute
// check before end.
bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) {
auto reg = blk->getParent();
assert(reg->isAncestor(end->getParentRegion()));

DenseSet<Block *> visitedBlocks;

SmallVector<Block *> blocksToVisit;
for (auto succ : blk->getSuccessors()) {
blocksToVisit.push_back(succ);
}

blocksToVisit.push_back(aBlock);
while (!blocksToVisit.empty()) {
Block *processedBlock = blocksToVisit[blocksToVisit.size() - 1];
blocksToVisit.pop_back();
Block *cur = blocksToVisit.pop_back_val();

if (visitedBlocks.contains(cur))
continue;

visitedBlocks.insert(cur);

for (Block *successor : processedBlock->getSuccessors()) {
if (!visitedBlocks.contains(successor)) {
visitedBlocks.insert(successor);
blocksToVisit.push_back(successor);
bool seenEnd = false;
for (auto &op : *cur) {

if (successor == bBlock)
return true;
// If we've seen the thing to check with, it may execute before
if (op.isAncestor(check)) {
// The sole exception to this is if they are in the same sub region, which is
// known to execute only once. TODO this later
/*
if (op.isAncestor(end)) {
for (auto reg2 : op.getRegions()) {
}
}
*/

return true;
}

// Otherwise if we've seen the end op, this path is over as the route we found here
// didn't first find a check.
if (op.isAncestor(end)) {
seenEnd = true;
break;
}
}

if (seenEnd) continue;

// If we didn't find the end, try all successors
for (auto succ : cur->getSuccessors()) {
blocksToVisit.push_back(succ);
}
}

return false;
}

bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) {

for (auto op = start->getNextNode(); op != nullptr; op++) {
// This check op has been found after start in its block
if (op->isAncestor(check)) {
return true;
}
if (op->isAncestor(end)) {
return false;
}
}

Block* blk = start->getBlock();

auto reg = blk->getParent();
if (reg->isAncestor(end->getParentRegion())) {
return mayExecuteBefore(blk, check, end);
}

// If the check is in the parent op, but the end is not, assume
// we may execute that parent op part before going to any later ops
if (reg->isAncestor(check->getParentRegion())) {
return true;
}

return mayExecuteBetween(start->getParentOp(), check, end);
}

// TODO this isn't necessarily correct. This is because there could be a
// non dominating use bewteen the dominating one and the op, causing
// correctness issues when not seen. In interim, be conservative and only
// succeed if these have the same parent block, and no other ops in path
template <class T>
Operation *findNearestDominatingOpByUse(Operation *op, Value v) {
T findNearestDominatingOpByUse(Operation *op, Value v) {
DominanceInfo dInfo;
PostDominanceInfo pdInfo;

Operation *closestSetOp = nullptr;
SmallVector<T, 1> options;
for (Operation *userSet : v.getUsers()) {
if (auto setOp = dyn_cast<T>(userSet)) {
if (dInfo.dominates(userSet, op)) {
if (closestSetOp == nullptr) {
closestSetOp = userSet;
} else if (dInfo.dominates(closestSetOp, userSet)) {
closestSetOp = userSet;
}
options.push_back(setOp);
}
}
if (options.size() == 1 && dInfo.dominates(options[0], op))
return options[0];

llvm::errs() << " scope: " << *op->getParentOp() << "\n";
llvm::errs() << " want to replace " << *op << "\n";
for (auto opt : options) {
if (!dInfo.dominates(opt, op))
continue;
bool conflict = false;
llvm::errs() << " trying: " << *opt << "\n";
for (auto opt2 : options) {
if (opt == opt2) continue;

llvm::errs() << " conflict check: " << *opt2 << "\n";

if (!mayExecuteBetween(opt, opt2, op)) {
llvm::errs() << " + known good since occurs before store\n";
continue;
}

conflict = true;
}
if (!conflict) {
llvm::errs() << " - replaced with " << *opt << "\n";
return opt;
}
}
return closestSetOp;

return nullptr;
}

struct RemoveUnusedEnzymeOpsPass
Expand All @@ -96,96 +181,59 @@ struct RemoveUnusedEnzymeOpsPass
if (auto type = dyn_cast<enzyme::GradientType>(initOp.getType())) {
bool replaceable = true;
for (Operation *userSet : v.getUsers()) {
if (auto setOp = dyn_cast<enzyme::SetOp>(userSet)) {
for (Operation *userGet : v.getUsers()) {
if (auto getOp = dyn_cast<enzyme::GetOp>(userGet)) {
// We can safely delete an enzyme.gradient op if each pair of
// enzyme.set and enzyme.get ops are either not reachable or
// are reachable and do not exist inside a loop
bool relatedButNotInLoop =
dInfo.dominates(userSet, userGet) &&
!reachable(getOp, setOp);
bool unrelated = !reachable(setOp, getOp);
if (!(relatedButNotInLoop || unrelated)) {
replaceable = false;
}
}
}
}
if (isa<enzyme::SetOp>(userSet)) continue;
if (isa<enzyme::GetOp>(userSet)) continue;
llvm::errs() << " unknown user of grad: " << *userSet << "\n";
replaceable = false;
}
if (replaceable) {
// Do replacing
for (Operation *userGet : v.getUsers()) {
bool allDelete = true;
for (Operation *userGet : make_early_inc_range(v.getUsers())) {
if (auto getOp = dyn_cast<enzyme::GetOp>(userGet)) {
Operation *closestSetOp =
findNearestDominatingOpByUse<enzyme::SetOp>(userGet, v);
auto setOp = cast<enzyme::SetOp>(closestSetOp);
getOp.replaceAllUsesWith(setOp.getValue());
if (auto setOp =
findNearestDominatingOpByUse<enzyme::SetOp>(userGet, v)) {
getOp.replaceAllUsesWith(setOp.getValue());
getOp->erase();
continue;
}
allDelete = false;
}
}
for (Operation *userGet : make_early_inc_range(v.getUsers())) {
userGet->erase();
if (allDelete) {
for (Operation *userGet : make_early_inc_range(v.getUsers())) {
userGet->erase();
}
initOp->erase();
}
initOp->erase();
continue;
}
} else if (auto type = dyn_cast<enzyme::CacheType>(initOp.getType())) {
bool replaceable = true;
for (Operation *userPush : v.getUsers()) {
if (auto pushOp = dyn_cast<enzyme::PushOp>(userPush)) {
// There should only be exactly one push per pop
if (reachable(userPush, userPush)) {
replaceable = false;
}
int numAssociatedPops = 0;
for (Operation *user : v.getUsers()) {
if (auto popOp = dyn_cast<enzyme::PopOp>(user)) {
if (reachable(userPush, user)) {
// Pops always need to be dominated by the push
if (dInfo.dominates(userPush, user)) {
numAssociatedPops++;
} else {
replaceable = false;
}
}
}
if (auto getOp = dyn_cast<enzyme::GetOp>(user)) {
if (reachable(userPush, user)) {
// Gets always need to be dominated by the push
if (!dInfo.dominates(userPush, user)) {
replaceable = false;
}
}
}
}
// There should only be one pop per push
if (numAssociatedPops > 1) {
replaceable = false;
}

SmallVector<enzyme::PopOp, 1> pops;
for (Operation *userSet : v.getUsers()) {
if (isa<enzyme::PushOp>(userSet)) continue;
if (auto pop = dyn_cast<enzyme::PopOp>(userSet)) {
pops.push_back(pop);
continue;
}
llvm::errs() << " unknown user of cache: " << *userSet << "\n";
replaceable = false;
}
if (replaceable) {
// Do replacing
for (Operation *user : v.getUsers()) {
if (auto popOp = dyn_cast<enzyme::PopOp>(user)) {
Operation *closestPushOp =
findNearestDominatingOpByUse<enzyme::PushOp>(user, v);
auto pushOp = dyn_cast<enzyme::PushOp>(closestPushOp);
popOp.replaceAllUsesWith(pushOp.getValue());
}
if (auto getOp = dyn_cast<enzyme::GetOp>(user)) {
Operation *closestPushOp =
findNearestDominatingOpByUse<enzyme::PushOp>(user, v);
auto pushOp = dyn_cast<enzyme::PushOp>(closestPushOp);
getOp.replaceAllUsesWith(pushOp.getValue());
}
}
for (Operation *user : make_early_inc_range(v.getUsers())) {
user->erase();

if (replaceable)
for (auto pop : pops) {
if (auto push = findNearestDominatingOpByUse<enzyme::PushOp>(pop, v)) {
pop.replaceAllUsesWith(push.getValue());
pop->erase();
push->erase();
}
}
if (v.use_empty()) {
initOp->erase();
continue;
}
continue;
}
}
}
Expand Down

0 comments on commit 005f4fc

Please sign in to comment.