diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index 7d6e3de31400..dff94c63f16c 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -30,53 +30,138 @@ using namespace mlir; using namespace enzyme; namespace { -// TODO: Expand to region branches?? -bool reachable(Operation *a, Operation *b) { - Block *aBlock = a->getBlock(); - Block *bBlock = b->getBlock(); - if (aBlock == bBlock) { - if (a->isBeforeInBlock(b)) { - return true; - } - } + +// Starting at the beginning of blk, is there a path that can execute +// check before end. +bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) { + auto reg = blk->getParent(); + assert(reg->isAncestor(end->getParentRegion())); + DenseSet visitedBlocks; + SmallVector blocksToVisit; + for (auto succ : blk->getSuccessors()) { + blocksToVisit.push_back(succ); + } - blocksToVisit.push_back(aBlock); while (!blocksToVisit.empty()) { - Block *processedBlock = blocksToVisit[blocksToVisit.size() - 1]; - blocksToVisit.pop_back(); + Block *cur = blocksToVisit.pop_back_val(); + + if (visitedBlocks.contains(cur)) + continue; + + visitedBlocks.insert(cur); - for (Block *successor : processedBlock->getSuccessors()) { - if (!visitedBlocks.contains(successor)) { - visitedBlocks.insert(successor); - blocksToVisit.push_back(successor); + bool seenEnd = false; + for (auto &op : *cur) { - if (successor == bBlock) - return true; + // If we've seen the thing to check with, it may execute before + if (op.isAncestor(check)) { + // The sole exception to this is if they are in the same sub region, which is + // known to execute only once. TODO this later + /* + if (op.isAncestor(end)) { + + for (auto reg2 : op.getRegions()) { + + } + } + */ + + return true; + } + + // Otherwise if we've seen the end op, this path is over as the route we found here + // didn't first find a check. + if (op.isAncestor(end)) { + seenEnd = true; + break; } } + + if (seenEnd) continue; + + // If we didn't find the end, try all successors + for (auto succ : cur->getSuccessors()) { + blocksToVisit.push_back(succ); + } } + return false; } +bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) { + + for (auto op = start->getNextNode(); op != nullptr; op++) { + // This check op has been found after start in its block + if (op->isAncestor(check)) { + return true; + } + if (op->isAncestor(end)) { + return false; + } + } + + Block* blk = start->getBlock(); + + auto reg = blk->getParent(); + if (reg->isAncestor(end->getParentRegion())) { + return mayExecuteBefore(blk, check, end); + } + + // If the check is in the parent op, but the end is not, assume + // we may execute that parent op part before going to any later ops + if (reg->isAncestor(check->getParentRegion())) { + return true; + } + + return mayExecuteBetween(start->getParentOp(), check, end); +} + +// TODO this isn't necessarily correct. This is because there could be a +// non dominating use bewteen the dominating one and the op, causing +// correctness issues when not seen. In interim, be conservative and only +// succeed if these have the same parent block, and no other ops in path template -Operation *findNearestDominatingOpByUse(Operation *op, Value v) { +T findNearestDominatingOpByUse(Operation *op, Value v) { DominanceInfo dInfo; + PostDominanceInfo pdInfo; - Operation *closestSetOp = nullptr; + SmallVector options; for (Operation *userSet : v.getUsers()) { if (auto setOp = dyn_cast(userSet)) { - if (dInfo.dominates(userSet, op)) { - if (closestSetOp == nullptr) { - closestSetOp = userSet; - } else if (dInfo.dominates(closestSetOp, userSet)) { - closestSetOp = userSet; - } + options.push_back(setOp); + } + } + if (options.size() == 1 && dInfo.dominates(options[0], op)) + return options[0]; + + llvm::errs() << " scope: " << *op->getParentOp() << "\n"; + llvm::errs() << " want to replace " << *op << "\n"; + for (auto opt : options) { + if (!dInfo.dominates(opt, op)) + continue; + bool conflict = false; + llvm::errs() << " trying: " << *opt << "\n"; + for (auto opt2 : options) { + if (opt == opt2) continue; + + llvm::errs() << " conflict check: " << *opt2 << "\n"; + + if (!mayExecuteBetween(opt, opt2, op)) { + llvm::errs() << " + known good since occurs before store\n"; + continue; } + + conflict = true; + } + if (!conflict) { + llvm::errs() << " - replaced with " << *opt << "\n"; + return opt; } } - return closestSetOp; + + return nullptr; } struct RemoveUnusedEnzymeOpsPass @@ -96,96 +181,59 @@ struct RemoveUnusedEnzymeOpsPass if (auto type = dyn_cast(initOp.getType())) { bool replaceable = true; for (Operation *userSet : v.getUsers()) { - if (auto setOp = dyn_cast(userSet)) { - for (Operation *userGet : v.getUsers()) { - if (auto getOp = dyn_cast(userGet)) { - // We can safely delete an enzyme.gradient op if each pair of - // enzyme.set and enzyme.get ops are either not reachable or - // are reachable and do not exist inside a loop - bool relatedButNotInLoop = - dInfo.dominates(userSet, userGet) && - !reachable(getOp, setOp); - bool unrelated = !reachable(setOp, getOp); - if (!(relatedButNotInLoop || unrelated)) { - replaceable = false; - } - } - } - } + if (isa(userSet)) continue; + if (isa(userSet)) continue; + llvm::errs() << " unknown user of grad: " << *userSet << "\n"; + replaceable = false; } if (replaceable) { // Do replacing - for (Operation *userGet : v.getUsers()) { + bool allDelete = true; + for (Operation *userGet : make_early_inc_range(v.getUsers())) { if (auto getOp = dyn_cast(userGet)) { - Operation *closestSetOp = - findNearestDominatingOpByUse(userGet, v); - auto setOp = cast(closestSetOp); - getOp.replaceAllUsesWith(setOp.getValue()); + if (auto setOp = + findNearestDominatingOpByUse(userGet, v)) { + getOp.replaceAllUsesWith(setOp.getValue()); + getOp->erase(); + continue; + } + allDelete = false; } } - for (Operation *userGet : make_early_inc_range(v.getUsers())) { - userGet->erase(); + if (allDelete) { + for (Operation *userGet : make_early_inc_range(v.getUsers())) { + userGet->erase(); + } + initOp->erase(); } - initOp->erase(); continue; } } else if (auto type = dyn_cast(initOp.getType())) { bool replaceable = true; - for (Operation *userPush : v.getUsers()) { - if (auto pushOp = dyn_cast(userPush)) { - // There should only be exactly one push per pop - if (reachable(userPush, userPush)) { - replaceable = false; - } - int numAssociatedPops = 0; - for (Operation *user : v.getUsers()) { - if (auto popOp = dyn_cast(user)) { - if (reachable(userPush, user)) { - // Pops always need to be dominated by the push - if (dInfo.dominates(userPush, user)) { - numAssociatedPops++; - } else { - replaceable = false; - } - } - } - if (auto getOp = dyn_cast(user)) { - if (reachable(userPush, user)) { - // Gets always need to be dominated by the push - if (!dInfo.dominates(userPush, user)) { - replaceable = false; - } - } - } - } - // There should only be one pop per push - if (numAssociatedPops > 1) { - replaceable = false; - } + + SmallVector pops; + for (Operation *userSet : v.getUsers()) { + if (isa(userSet)) continue; + if (auto pop = dyn_cast(userSet)) { + pops.push_back(pop); + continue; } + llvm::errs() << " unknown user of cache: " << *userSet << "\n"; + replaceable = false; } - if (replaceable) { - // Do replacing - for (Operation *user : v.getUsers()) { - if (auto popOp = dyn_cast(user)) { - Operation *closestPushOp = - findNearestDominatingOpByUse(user, v); - auto pushOp = dyn_cast(closestPushOp); - popOp.replaceAllUsesWith(pushOp.getValue()); - } - if (auto getOp = dyn_cast(user)) { - Operation *closestPushOp = - findNearestDominatingOpByUse(user, v); - auto pushOp = dyn_cast(closestPushOp); - getOp.replaceAllUsesWith(pushOp.getValue()); - } - } - for (Operation *user : make_early_inc_range(v.getUsers())) { - user->erase(); + + if (replaceable) + for (auto pop : pops) { + if (auto push = findNearestDominatingOpByUse(pop, v)) { + pop.replaceAllUsesWith(push.getValue()); + pop->erase(); + push->erase(); } + } + if (v.use_empty()) { initOp->erase(); - continue; } + continue; } } }