Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 29, 2024
1 parent 005f4fc commit f5c4679
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 156 deletions.
6 changes: 4 additions & 2 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ def Gradient : Enzyme_Type<"Gradient"> {

// `enzyme.set`: takes a $gradient operand and a $value operand and produces no results.
// NOTE(review): two `let arguments` lines are visible here because this is a rendered
// diff; the Arg<..., [MemWrite]> form is presumably the post-change one - confirm.
def SetOp : Enzyme_Op<"set"> {
let summary = "Store the current value of the gradient";
let arguments = (ins AnyType : $gradient, AnyType : $value);
// $gradient is described as "the reference to store to" and carries a MemWrite
// annotation; $value carries no annotation.
let arguments = (ins Arg<AnyType, "the reference to store to",
[MemWrite]>:$gradient, AnyType : $value);
let results = (outs );
}

// `enzyme.get`: takes a $gradient operand and produces a single result of any type.
// NOTE(review): two `let arguments` lines are visible here because this is a rendered
// diff; the Arg<..., [MemRead]> form is presumably the post-change one - confirm.
def GetOp : Enzyme_Op<"get"> {
let summary = "Load current value of gradient";
let arguments = (ins AnyType : $gradient);
// $gradient is described as "the reference to load from" and carries a MemRead
// annotation.
let arguments = (ins Arg<AnyType, "the reference to load from",
[MemRead]>:$gradient);
let results = (outs AnyType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct ForOpInterfaceReverse
for (OpResult v : forOp.getResults()) {
if (!gutils->isConstantValue(v)) {
auto autoDiffType = cast<AutoDiffTypeInterface>(v.getType());
if (autoDiffType.isMutable()) {
if (!autoDiffType.isMutable()) {
auto prev = gutils->diffe(v, builder);
gutils->zeroDiffe(v, builder);
resDiffes.push_back(prev);
Expand All @@ -72,6 +72,7 @@ struct ForOpInterfaceReverse
OperandRange operandRange =
termIface.getSuccessorOperands(successor);
assert(operandRange.size() == resDiffes.size());

// There is an assumption here that there are only regions that branch to the successor.
// Specifically, otherwise we would need to gutils->addToDiffe select (if came from that result)
for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) {
Expand Down Expand Up @@ -108,7 +109,7 @@ struct ForOpInterfaceReverse

auto condition = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, lhs, end);

for (auto [arg, init_arg] : llvm::zip(oBB->getArguments(), forOp.getInitArgs())) {
for (auto [arg, init_arg] : llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) {
if (!gutils->isConstantValue(arg) && !cast<AutoDiffTypeInterface>(arg.getType()).isMutable()) {
auto diffe = gutils->diffe(arg, builder);
gutils->zeroDiffe(arg, builder);
Expand All @@ -132,6 +133,8 @@ struct ForOpInterfaceReverse

for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
gutils->mapReverseModeBlocks.map(&oBB, &revBB);
}
for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
gutils->Logic.visitChildren(&oBB, &revBB, gutils);
Block *newBB = gutils->getNewFromOriginal(&oBB);
gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, buildFuncReturnOp);
Expand Down
219 changes: 140 additions & 79 deletions enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Transforms/DialectConversion.h"

#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/IR/Dominance.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -92,7 +93,7 @@ bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) {

bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) {

for (auto op = start->getNextNode(); op != nullptr; op++) {
for (auto op = start->getNextNode(); op != nullptr; op = op->getNextNode()) {
// This check op has been found after start in its block
if (op->isAncestor(check)) {
return true;
Expand Down Expand Up @@ -122,120 +123,180 @@ bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) {
// non dominating use between the dominating one and the op, causing
// correctness issues when not seen. In interim, be conservative and only
// succeed if these have the same parent block, and no other ops in path
// Searches the users of `v` for an operation of type T that dominates `op`
// and can stand in for it (callers in this file use the returned op's value
// to replace `op`).
//
// Template parameters:
//   T  - the kind of user that may be returned (collected in `options`).
//   T2 - an additional kind of user that is never returned but is treated as
//        a potential conflict (collected in `conflicts`); defaults to T.
//
// Returns the chosen T, or a null T when no candidate qualifies.
//
// NOTE(review): the two `template` headers and the two inner `for (auto opt2`
// headers below come from the rendered diff (old line followed by new line);
// only one of each belongs in the real file.
template <class T>
template <class T, class T2=T>
T findNearestDominatingOpByUse(Operation *op, Value v) {
DominanceInfo dInfo;
// NOTE(review): pdInfo is not referenced anywhere in the visible body.
PostDominanceInfo pdInfo;

SmallVector<T, 1> options;
SmallVector<Operation*, 1> conflicts;
// Classify every user of `v`: T users are both candidates and conflicts,
// T2 users are conflicts only; users of any other kind are ignored here.
for (Operation *userSet : v.getUsers()) {
if (auto setOp = dyn_cast<T>(userSet)) {
options.push_back(setOp);
conflicts.push_back(setOp);
continue;
}
if (auto setOp = dyn_cast<T2>(userSet)) {
conflicts.push_back(setOp);
continue;
}
}
// Fast path: a single candidate that dominates `op` is returned without
// running the conflict scan below.
if (options.size() == 1 && dInfo.dominates(options[0], op))
return options[0];

// The llvm::errs() lines in the rest of this function are diagnostic output
// emitted on every call that reaches this point.
llvm::errs() << " scope: " << *op->getParentOp() << "\n";
llvm::errs() << " want to replace " << *op << "\n";
for (auto opt : options) {
// Only candidates that dominate `op` are considered.
if (!dInfo.dominates(opt, op))
continue;
bool conflict = false;
llvm::errs() << " trying: " << *opt << "\n";
for (auto opt2 : options) {
for (auto opt2 : conflicts) {
// A candidate never conflicts with itself.
if (opt == opt2) continue;

llvm::errs() << " conflict check: " << *opt2 << "\n";
// The op being replaced is not a conflict either.
if (opt2 == op) continue;

// mayExecuteBetween is defined elsewhere in this file; presumably it reports
// whether opt2 can run after `opt` and before `op` - confirm against its body.
if (!mayExecuteBetween(opt, opt2, op)) {
llvm::errs() << " + known good since occurs before store\n";
continue;
}

conflict = true;
}
// The first dominating candidate with no conflicts wins.
if (!conflict) {
llvm::errs() << " - replaced with " << *opt << "\n";
return opt;
}
}

return nullptr;
}




/// Rewrite pattern that forwards a pushed value straight to the pop that
/// reads it, removing both cache operations.
///
/// The pattern only fires when all of the following hold:
///   * the popped cache is produced by an `enzyme::InitOp`,
///   * every user of that cache is an `enzyme::PushOp` or `enzyme::PopOp`,
///   * `findNearestDominatingOpByUse` finds a push for this pop (with other
///     pushes and pops treated as conflicts), and
///   * that push is in the same block as the pop.
struct PopSimplify : public OpRewritePattern<enzyme::PopOp> {
  using OpRewritePattern<enzyme::PopOp>::OpRewritePattern;

  /// Replaces `pop` with the value operand of its matching push and erases
  /// the push. Returns failure() without touching the IR when any of the
  /// preconditions listed on the struct does not hold.
  LogicalResult matchAndRewrite(enzyme::PopOp pop,
                                PatternRewriter &rewriter) const final {
    auto init = pop.getCache().getDefiningOp<enzyme::InitOp>();
    if (!init)
      return failure();

    // Only push and pop users are understood here; bail out on anything else.
    for (Operation *user : init.getResult().getUsers()) {
      if (isa<enzyme::PushOp>(user) || isa<enzyme::PopOp>(user))
        continue;
      return failure();
    }

    auto push =
        findNearestDominatingOpByUse<enzyme::PushOp, enzyme::PopOp>(pop, init);
    if (!push)
      return failure();

    // Do the block check to conservatively avoid multi execute push/pop
    if (pop->getBlock() != push->getBlock())
      return failure();

    rewriter.replaceOp(pop, push.getValue());
    rewriter.eraseOp(push);
    return success();
  }
};


/// Rewrite pattern that replaces an `enzyme::GetOp` with the value written by
/// the set that `findNearestDominatingOpByUse` selects for it.
///
/// Applies only when the gradient comes from an `enzyme::InitOp` whose users
/// are exclusively gets and sets.
struct GetSimplify : public OpRewritePattern<enzyme::GetOp> {
  using OpRewritePattern<enzyme::GetOp>::OpRewritePattern;

  /// Returns success() after replacing `get` with the selected set's value,
  /// or failure() when no rewrite is performed.
  LogicalResult matchAndRewrite(enzyme::GetOp get,
                                PatternRewriter &rewriter) const final {
    auto init = get.getGradient().getDefiningOp<enzyme::InitOp>();
    if (!init)
      return failure();

    // Reject gradients that have any user other than a get or a set.
    for (Operation *user : init.getResult().getUsers()) {
      const bool isGetOrSet =
          isa<enzyme::GetOp>(user) || isa<enzyme::SetOp>(user);
      if (!isGetOrSet)
        return failure();
    }

    auto set = findNearestDominatingOpByUse<enzyme::SetOp>(get, init);
    if (!set)
      return failure();

    rewriter.replaceOp(get, set.getValue());
    return success();
  }
};


/// Rewrite pattern that erases an `enzyme::SetOp` whose gradient is produced
/// by an `enzyme::InitOp` and is used by nothing except sets.
struct SetSimplify : public OpRewritePattern<enzyme::SetOp> {
  using OpRewritePattern<enzyme::SetOp>::OpRewritePattern;

  /// Returns success() after erasing `set`, or failure() when the gradient
  /// has no defining InitOp or has any non-set user.
  LogicalResult matchAndRewrite(enzyme::SetOp set,
                                PatternRewriter &rewriter) const final {
    auto init = set.getGradient().getDefiningOp<enzyme::InitOp>();
    if (!init)
      return failure();

    // Any user that is not a set blocks the erase.
    for (Operation *user : init.getResult().getUsers())
      if (!isa<enzyme::SetOp>(user))
        return failure();

    rewriter.eraseOp(set);
    return success();
  }
};


/// Rewrite pattern that erases an `enzyme::PushOp` whose cache is produced by
/// an `enzyme::InitOp` and is used by nothing except pushes.
struct PushSimplify : public OpRewritePattern<enzyme::PushOp> {
  using OpRewritePattern<enzyme::PushOp>::OpRewritePattern;

  /// Returns success() after erasing `push`, or failure() when the cache has
  /// no defining InitOp or has any non-push user.
  LogicalResult matchAndRewrite(enzyme::PushOp push,
                                PatternRewriter &rewriter) const final {
    auto init = push.getCache().getDefiningOp<enzyme::InitOp>();
    if (!init)
      return failure();

    // Any user that is not a push blocks the erase.
    for (Operation *user : init.getResult().getUsers())
      if (!isa<enzyme::PushOp>(user))
        return failure();

    rewriter.eraseOp(push);
    return success();
  }
};


/// Rewrite pattern that erases an `enzyme::InitOp` that has no uses.
struct InitSimplify : public OpRewritePattern<enzyme::InitOp> {
  using OpRewritePattern<enzyme::InitOp>::OpRewritePattern;

  /// Returns success() after erasing `init` when it is unused; otherwise
  /// returns failure() and leaves it in place.
  LogicalResult matchAndRewrite(enzyme::InitOp init,
                                PatternRewriter &rewriter) const final {
    if (!init.use_empty())
      return failure();

    rewriter.eraseOp(init);
    return success();
  }
};

// Pass that simplifies enzyme init/get/set/push/pop operations under the
// operation it runs on.
//
// NOTE(review): this block is taken from a rendered diff. The manual loop over
// `inits` is presumably the pre-change implementation and the pattern-driver
// section at the end the post-change one - confirm against the real file.
struct RemoveUnusedEnzymeOpsPass
: public enzyme::RemoveUnusedEnzymeOpsPassBase<RemoveUnusedEnzymeOpsPass> {
void runOnOperation() override {

// Collect every InitOp up front so the loop below can erase ops without
// mutating the IR during the walk.
SmallVector<enzyme::InitOp, 1> inits;
getOperation()->walk([&](Operation *op) {
if (auto initOp = dyn_cast<enzyme::InitOp>(op)) {
inits.push_back(initOp);
}
});

for (auto initOp : inits) {
DominanceInfo dInfo;
Value v = initOp;
// Gradient-typed init: only handled when every user is a set or a get.
if (auto type = dyn_cast<enzyme::GradientType>(initOp.getType())) {
bool replaceable = true;
for (Operation *userSet : v.getUsers()) {
if (isa<enzyme::SetOp>(userSet)) continue;
if (isa<enzyme::GetOp>(userSet)) continue;
llvm::errs() << " unknown user of grad: " << *userSet << "\n";
replaceable = false;
}
if (replaceable) {
// Do replacing
// allDelete stays true only if every get was forwarded from a set.
bool allDelete = true;
// Early-inc iteration because gets are erased while walking the use list.
for (Operation *userGet : make_early_inc_range(v.getUsers())) {
if (auto getOp = dyn_cast<enzyme::GetOp>(userGet)) {
if (auto setOp =
findNearestDominatingOpByUse<enzyme::SetOp>(userGet, v)) {
getOp.replaceAllUsesWith(setOp.getValue());
getOp->erase();
continue;
}
allDelete = false;
}
}
// With no gets left, the remaining users and the init itself are erased.
if (allDelete) {
for (Operation *userGet : make_early_inc_range(v.getUsers())) {
userGet->erase();
}
initOp->erase();
}
continue;
}
// Cache-typed init: only handled when every user is a push or a pop.
} else if (auto type = dyn_cast<enzyme::CacheType>(initOp.getType())) {
bool replaceable = true;

SmallVector<enzyme::PopOp, 1> pops;
for (Operation *userSet : v.getUsers()) {
if (isa<enzyme::PushOp>(userSet)) continue;
if (auto pop = dyn_cast<enzyme::PopOp>(userSet)) {
pops.push_back(pop);
continue;
}
llvm::errs() << " unknown user of cache: " << *userSet << "\n";
replaceable = false;
}

// Forward each pop from the push selected for it, erasing both ops.
if (replaceable)
for (auto pop : pops) {
if (auto push = findNearestDominatingOpByUse<enzyme::PushOp>(pop, v)) {
pop.replaceAllUsesWith(push.getValue());
pop->erase();
push->erase();
}
}
// The init is erased once nothing uses the cache any more.
if (v.use_empty()) {
initOp->erase();
}
continue;
}
}
// Register the five simplification patterns and apply them greedily.
RewritePatternSet patterns(&getContext());
patterns.insert<PopSimplify, GetSimplify, PushSimplify, SetSimplify, InitSimplify>(&getContext());

GreedyRewriteConfig config;
// The driver's result is deliberately discarded via the (void) cast.
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);

}
};

Expand Down
48 changes: 48 additions & 0 deletions enzyme/test/MLIR/Passes/dualpush.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: %eopt -remove-unnecessary-enzyme-ops %s | FileCheck %s

// This pop cannot be removed: even though we know the first popped value will be -1,
// the other pops are conditional.

// Test input: one push in the entry block, one push per execution of the
// self-looping ^bb1, and a single pop in the self-looping ^bb4.
module {
func.func private @diffebbargs(%arg0: f64) {
%c0_i32 = arith.constant 0 : i32
%c-1_i32 = arith.constant -1 : i32
%cst = arith.constant 0.000000e+00 : f64
%3 = "enzyme.init"() : () -> !enzyme.Cache<i32>
"enzyme.push"(%3, %c0_i32) : (!enzyme.Cache<i32>, i32) -> ()
cf.br ^bb1(%arg0 : f64)
^bb1(%7: f64): // 2 preds: ^bb0, ^bb1
%8 = arith.cmpf ult, %7, %cst : f64
"enzyme.push"(%3, %c-1_i32) : (!enzyme.Cache<i32>, i32) -> ()
cf.cond_br %8, ^bb1(%7 : f64), ^bb4
^bb4: // 2 preds: ^bb1, ^bb4
%18 = "enzyme.pop"(%3) : (!enzyme.Cache<i32>) -> i32
// The popped value selects between looping back to ^bb4 and exiting via ^bb5.
cf.switch %18 : i32, [
default: ^bb4,
0: ^bb5
]
^bb5: // pred: ^bb4
return
}
}

// CHECK: func.func private @diffebbargs(%arg0: f64) {
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64
// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache<i32>
// CHECK-NEXT: "enzyme.push"(%0, %c0_i32) : (!enzyme.Cache<i32>, i32) -> ()
// CHECK-NEXT: cf.br ^bb1(%arg0 : f64)
// CHECK-NEXT: ^bb1(%1: f64): // 2 preds: ^bb0, ^bb1
// CHECK-NEXT: %2 = arith.cmpf ult, %1, %cst : f64
// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache<i32>, i32) -> ()
// CHECK-NEXT: cf.cond_br %2, ^bb1(%1 : f64), ^bb2
// CHECK-NEXT: ^bb2: // 2 preds: ^bb1, ^bb2
// CHECK-NEXT: %3 = "enzyme.pop"(%0) : (!enzyme.Cache<i32>) -> i32
// CHECK-NEXT: cf.switch %3 : i32, [
// CHECK-NEXT: default: ^bb2,
// CHECK-NEXT: 0: ^bb3
// CHECK-NEXT: ]
// CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: return
// CHECK-NEXT: }
Loading

0 comments on commit f5c4679

Please sign in to comment.