diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 16a11db1529c..7d74fdafbdbf 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -111,13 +111,15 @@ def Gradient : Enzyme_Type<"Gradient"> { def SetOp : Enzyme_Op<"set"> { let summary = "Store the current value of the gradient"; - let arguments = (ins AnyType : $gradient, AnyType : $value); + let arguments = (ins Arg:$gradient, AnyType : $value); let results = (outs ); } def GetOp : Enzyme_Op<"get"> { let summary = "Load current value of gradient"; - let arguments = (ins AnyType : $gradient); + let arguments = (ins Arg:$gradient); let results = (outs AnyType); } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 8c39fc25bf16..adcabcc4b3db 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -46,7 +46,7 @@ struct ForOpInterfaceReverse for (OpResult v : forOp.getResults()) { if (!gutils->isConstantValue(v)) { auto autoDiffType = cast(v.getType()); - if (autoDiffType.isMutable()) { + if (!autoDiffType.isMutable()) { auto prev = gutils->diffe(v, builder); gutils->zeroDiffe(v, builder); resDiffes.push_back(prev); @@ -72,6 +72,7 @@ struct ForOpInterfaceReverse OperandRange operandRange = termIface.getSuccessorOperands(successor); assert(operandRange.size() == resDiffes.size()); + // There is an assumption here that there is only regions that branch to the successor. // Specifically, otherwise we would need to gutils->addToDiffe select (if came from that result) for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) { @@ -108,7 +109,7 @@ struct ForOpInterfaceReverse auto condition = builder.create(loc, arith::CmpIPredicate::sge, lhs, end); - for (auto [arg, init_arg] : llvm::zip(oBB->getArguments(), forOp.getInitArgs())) { + for (auto [arg, init_arg] : llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) { if (!gutils->isConstantValue(arg) && !cast(arg.getType()).isMutable()) { auto diffe = gutils->diffe(arg, builder); gutils->zeroDiffe(arg, builder); @@ -132,6 +133,8 @@ struct ForOpInterfaceReverse for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { gutils->mapReverseModeBlocks.map(&oBB, &revBB); + } + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { gutils->Logic.visitChildren(&oBB, &revBB, gutils); Block *newBB = gutils->getNewFromOriginal(&oBB); gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, buildFuncReturnOp); diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index dff94c63f16c..bd7ae8cca69f 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/IR/Dominance.h" #include "llvm/Support/raw_ostream.h" @@ -92,7 +93,7 @@ bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) { bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) { - for (auto op = start->getNextNode(); op != nullptr; op++) { + for (auto op = start->getNextNode(); op != nullptr; op = op->getNextNode()) { // This check op has been found after start in its block if (op->isAncestor(check)) { return true; @@ -122,41 +123,40 @@ bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) { // non dominating use bewteen the dominating one and the op, causing // correctness issues when not seen. In interim, be conservative and only // succeed if these have the same parent block, and no other ops in path -template +template T findNearestDominatingOpByUse(Operation *op, Value v) { DominanceInfo dInfo; PostDominanceInfo pdInfo; SmallVector options; + SmallVector conflicts; for (Operation *userSet : v.getUsers()) { if (auto setOp = dyn_cast(userSet)) { options.push_back(setOp); + conflicts.push_back(setOp); + continue; + } + if (auto setOp = dyn_cast(userSet)) { + conflicts.push_back(setOp); + continue; } } - if (options.size() == 1 && dInfo.dominates(options[0], op)) - return options[0]; - llvm::errs() << " scope: " << *op->getParentOp() << "\n"; - llvm::errs() << " want to replace " << *op << "\n"; for (auto opt : options) { if (!dInfo.dominates(opt, op)) continue; bool conflict = false; - llvm::errs() << " trying: " << *opt << "\n"; - for (auto opt2 : options) { + for (auto opt2 : conflicts) { if (opt == opt2) continue; - - llvm::errs() << " conflict check: " << *opt2 << "\n"; + if (opt2 == op) continue; if (!mayExecuteBetween(opt, opt2, op)) { - llvm::errs() << " + known good since occurs before store\n"; continue; } conflict = true; } if (!conflict) { - llvm::errs() << " - replaced with " << *opt << "\n"; return opt; } } @@ -164,78 +164,139 @@ T findNearestDominatingOpByUse(Operation *op, Value v) { return nullptr; } + + + + struct PopSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PopOp pop, + PatternRewriter &rewriter) const final { + + auto init = pop.getCache().getDefiningOp(); + if (!init) return failure(); + + SmallVector pops; + SmallVector pushes; + for (Operation *userSet : init.getResult().getUsers()) { + if (auto push = dyn_cast(userSet)) { + pushes.push_back(push); + continue; + } + if (auto pop = dyn_cast(userSet)) { + pops.push_back(pop); + continue; + } + return failure(); + } + + + if (auto push = findNearestDominatingOpByUse(pop, init)) { + // Do the block check to conservatively avoid multi execute push/pop + if (pop->getBlock() == push->getBlock() ) { + rewriter.replaceOp(pop, push.getValue()); + rewriter.eraseOp(push); + return success(); + } + } + + return failure(); + } +}; + + +struct GetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::GetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) continue; + if (isa(userSet)) continue; + return failure(); + } + + + if (auto set = findNearestDominatingOpByUse(get, init)) { + rewriter.replaceOp(get, set.getValue()); + return success(); + } + return failure(); + } +}; + + +struct SetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::SetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) continue; + return failure(); + } + + + rewriter.eraseOp(get); + return success(); + } +}; + + +struct PushSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PushOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getCache().getDefiningOp(); + if (!init) return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + + +struct InitSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::InitOp get, + PatternRewriter &rewriter) const final { + + if (get.use_empty()) { + rewriter.eraseOp(get); + return success(); + } + return failure(); + } +}; + struct RemoveUnusedEnzymeOpsPass : public enzyme::RemoveUnusedEnzymeOpsPassBase { void runOnOperation() override { - SmallVector inits; - getOperation()->walk([&](Operation *op) { - if (auto initOp = dyn_cast(op)) { - inits.push_back(initOp); - } - }); - - for (auto initOp : inits) { - DominanceInfo dInfo; - Value v = initOp; - if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - for (Operation *userSet : v.getUsers()) { - if (isa(userSet)) continue; - if (isa(userSet)) continue; - llvm::errs() << " unknown user of grad: " << *userSet << "\n"; - replaceable = false; - } - if (replaceable) { - // Do replacing - bool allDelete = true; - for (Operation *userGet : make_early_inc_range(v.getUsers())) { - if (auto getOp = dyn_cast(userGet)) { - if (auto setOp = - findNearestDominatingOpByUse(userGet, v)) { - getOp.replaceAllUsesWith(setOp.getValue()); - getOp->erase(); - continue; - } - allDelete = false; - } - } - if (allDelete) { - for (Operation *userGet : make_early_inc_range(v.getUsers())) { - userGet->erase(); - } - initOp->erase(); - } - continue; - } - } else if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - - SmallVector pops; - for (Operation *userSet : v.getUsers()) { - if (isa(userSet)) continue; - if (auto pop = dyn_cast(userSet)) { - pops.push_back(pop); - continue; - } - llvm::errs() << " unknown user of cache: " << *userSet << "\n"; - replaceable = false; - } - if (replaceable) - for (auto pop : pops) { - if (auto push = findNearestDominatingOpByUse(pop, v)) { - pop.replaceAllUsesWith(push.getValue()); - pop->erase(); - push->erase(); - } - } - if (v.use_empty()) { - initOp->erase(); - } - continue; - } - } + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); + } }; diff --git a/enzyme/test/MLIR/Passes/dualpush.mlir b/enzyme/test/MLIR/Passes/dualpush.mlir new file mode 100644 index 000000000000..582feddeabff --- /dev/null +++ b/enzyme/test/MLIR/Passes/dualpush.mlir @@ -0,0 +1,48 @@ +// RUN: %eopt -remove-unnecessary-enzyme-ops %s | FileCheck %s + +// This pop cannot be removed even though we know the first popped value with be -1 +// the other pops will be conditional + +module { + func.func private @diffebbargs(%arg0: f64) { + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant 0.000000e+00 : f64 + %3 = "enzyme.init"() : () -> !enzyme.Cache + "enzyme.push"(%3, %c0_i32) : (!enzyme.Cache, i32) -> () + cf.br ^bb1(%arg0 : f64) + ^bb1(%7: f64): // 2 preds: ^bb0, ^bb1 + %8 = arith.cmpf ult, %7, %cst : f64 + "enzyme.push"(%3, %c-1_i32) : (!enzyme.Cache, i32) -> () + cf.cond_br %8, ^bb1(%7 : f64), ^bb4 + ^bb4: // 2 preds: ^bb3, ^bb4 + %18 = "enzyme.pop"(%3) : (!enzyme.Cache) -> i32 + cf.switch %18 : i32, [ + default: ^bb4, + 0: ^bb5 + ] + ^bb5: // pred: ^bb4 + return + } +} + +// CHECK: func.func private @diffebbargs(%arg0: f64) { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: "enzyme.push"(%0, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%arg0 : f64) +// CHECK-NEXT: ^bb1(%1: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %2 = arith.cmpf ult, %1, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %2, ^bb1(%1 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // 2 preds: ^bb1, ^bb2 +// CHECK-NEXT: %3 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: cf.switch %3 : i32, [ +// CHECK-NEXT: default: ^bb2, +// CHECK-NEXT: 0: ^bb3 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb3: // pred: ^bb2 +// CHECK-NEXT: return +// CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir index 1dd48c3b89f7..141ff46aaade 100644 --- a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir +++ b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s module { func.func @bbargs(%x: f64) -> f64 { @@ -25,49 +25,32 @@ module { // CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f64 // CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache -// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Gradient -// CHECK-NEXT: "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %2 = "enzyme.init"() : () -> !enzyme.Cache -// CHECK-NEXT: %3 = "enzyme.init"() : () -> !enzyme.Gradient -// CHECK-NEXT: "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %4 = arith.addf %arg0, %cst : f64 -// CHECK-NEXT: "enzyme.push"(%2, %c0_i32) : (!enzyme.Cache, i32) -> () -// CHECK-NEXT: cf.br ^bb1(%4 : f64) -// CHECK-NEXT: ^bb1(%5: f64): // 2 preds: ^bb0, ^bb1 -// CHECK-NEXT: %6 = arith.cmpf ult, %5, %cst_0 : f64 -// CHECK-NEXT: "enzyme.push"(%2, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %2 = arith.addf %arg0, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%2 : f64) +// CHECK-NEXT: ^bb1(%3: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %4 = arith.cmpf ult, %3, %cst_0 : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c-1_i32) : (!enzyme.Cache, i32) -> () // CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () -// CHECK-NEXT: cf.cond_br %6, ^bb1(%5 : f64), ^bb2 +// CHECK-NEXT: cf.cond_br %4, ^bb1(%3 : f64), ^bb2 // CHECK-NEXT: ^bb2: // pred: ^bb1 -// CHECK-NEXT: %7 = arith.addf %arg1, %cst_0 : f64 -// CHECK-NEXT: %8 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 -// CHECK-NEXT: %9 = arith.cmpi eq, %8, %c-1_i32 : i32 -// CHECK-NEXT: %10 = arith.select %9, %7, %cst_0 : f64 -// CHECK-NEXT: %11 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: %12 = arith.addf %11, %10 : f64 -// CHECK-NEXT: "enzyme.set"(%1, %12) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %5 = arith.addf %arg1, %cst_0 : f64 +// CHECK-NEXT: %6 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %7 = arith.cmpi eq, %6, %c-1_i32 : i32 +// CHECK-NEXT: %8 = arith.select %7, %5, %cst_0 : f64 +// CHECK-NEXT: %9 = arith.addf %8, %cst_0 : f64 // CHECK-NEXT: cf.br ^bb3 // CHECK-NEXT: ^bb3: // 2 preds: ^bb2, ^bb3 -// CHECK-NEXT: %13 = "enzyme.pop"(%2) : (!enzyme.Cache) -> i32 -// CHECK-NEXT: %14 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %15 = arith.cmpi eq, %13, %c-1_i32 : i32 -// CHECK-NEXT: %16 = arith.select %15, %14, %cst_0 : f64 -// CHECK-NEXT: %17 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: %18 = arith.addf %17, %16 : f64 -// CHECK-NEXT: "enzyme.set"(%1, %18) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %19 = arith.cmpi eq, %13, %c-1_i32 : i32 -// CHECK-NEXT: %20 = arith.select %19, %14, %cst_0 : f64 -// CHECK-NEXT: %21 = "enzyme.get"(%3) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: %22 = arith.addf %21, %20 : f64 -// CHECK-NEXT: "enzyme.set"(%3, %22) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: cf.switch %13 : i32, [ +// CHECK-NEXT: %10 = "enzyme.pop"(%1) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %11 = arith.cmpi eq, %10, %c-1_i32 : i32 +// CHECK-NEXT: %12 = arith.select %11, %9, %cst_0 : f64 +// CHECK-NEXT: %13 = arith.addf %12, %cst_0 : f64 +// CHECK-NEXT: cf.switch %10 : i32, [ // CHECK-NEXT: default: ^bb3, // CHECK-NEXT: 0: ^bb4 // CHECK-NEXT: ] // CHECK-NEXT: ^bb4: // pred: ^bb3 -// CHECK-NEXT: %23 = "enzyme.get"(%3) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %24 = arith.addf %23, %cst_0 : f64 -// CHECK-NEXT: return %24 : f64 +// CHECK-NEXT: %14 = arith.addf %13, %cst_0 : f64 +// CHECK-NEXT: return %14 : f64 // CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir index e0a74d88145a..9934152def61 100644 --- a/enzyme/test/MLIR/ReverseMode/pow.mlir +++ b/enzyme/test/MLIR/ReverseMode/pow.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -enzyme-simplify-math -canonicalize %s | FileCheck %s module { func.func @ppow(%x: f64) -> f64 { @@ -19,29 +19,46 @@ module { } } -// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 +// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 { +// CHECK-NEXT: %c10 = arith.constant 10 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %[[one:.+]] = arith.constant 1.0 +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[xshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[itshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[xcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () -// Make sure the right values are being cached in the primal -// CHECK: %[[one:.+]] = arith.constant 1.0 -// CHECK: scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -// CHECK-NEXT: "enzyme.push"(%[[rcache:.+]], %[[r_it]]) -// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]]) - -// Ensure the right value is yielded in the adjoint -// CHECK: "enzyme.set"(%[[rshadow:.+]], %[[dr]]) -// CHECK: %[[dr:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK: scf.for %[[iv:.+]] = %[[lb:.+]] to %[[ub:.+]] step %[[step:.+]] iter_args(%[[dr_it:.+]] = %[[dr]]) -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_it]]) -// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) -// CHECK-NEXT: %[[x:.+]] = "enzyme.pop"(%[[xcache]]) -// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_next]]) -// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] -// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow:.+]]) : -// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] -// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) -// CHECK-NEXT: %[[dr_next:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: scf.yield %[[dr_next]] -// CHECK: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) -// CHECK-NEXT: return %[[final]] +// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -> (f64) { +// CHECK-NEXT: "enzyme.push"(%[[rcache]], %[[r_it]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[xcache]], %[[x]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[fwd]] : f64 +// CHECK-NEXT: } +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[dr]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: scf.for %[[div:.+]] = %c0 to %c10 step %c1 { +// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[x_cached:.+]] = "enzyme.pop"(%[[xcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x_cached]] +// CHECK-NEXT: %[[previts:.+]] = "enzyme.get"(%[[itshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postits:.+]] = arith.addf %[[previts]], %[[dr_next]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[postits]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64 +// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow]]) : +// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[divp1:.+]] = arith.addi %[[div]], %c1 : index +// CHECK-NEXT: %[[last:.+]] = arith.cmpi sge, %[[divp1]], %c10 : index +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[sel:.+]] = arith.select %[[last]], %[[zero]], %12 : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[sel]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) +// CHECK-NEXT: return %[[final]] \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/square.mlir b/enzyme/test/MLIR/ReverseMode/square.mlir index 37d57d426033..4bcae3bb8000 100644 --- a/enzyme/test/MLIR/ReverseMode/square.mlir +++ b/enzyme/test/MLIR/ReverseMode/square.mlir @@ -1,6 +1,6 @@ // RUN: %eopt --enzyme %s | FileCheck %s -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN module { func.func @square(%x: f64) -> f64 { @@ -60,14 +60,9 @@ module { // REM: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { // REM-NEXT: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 -// REM-NEXT: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64 -// REM-NEXT: %[[pmu:.+]] = arith.mulf %arg0, %arg0 : f64 -// REM-NEXT: cf.br ^bb1 -// REM-NEXT: ^bb1: // pred: ^bb0 -// REM-NEXT: %[[a1:.+]] = arith.addf %[[cst_0]], %arg1 : f64 -// REM-NEXT: %[[cst_1:.+]] = arith.constant 0.000000e+00 : f64 +// REM-NEXT: %[[a1:.+]] = arith.addf %arg1, %[[cst]] : f64 // REM-NEXT: %[[a2:.+]] = arith.mulf %[[a1]], %arg0 : f64 -// REM-NEXT: %[[a3:.+]] = arith.addf %[[cst]], %[[a2]] : f64 +// REM-NEXT: %[[a3:.+]] = arith.addf %[[a2]], %[[cst]] : f64 // REM-NEXT: %[[a4:.+]] = arith.mulf %[[a1]], %arg0 : f64 // REM-NEXT: %[[a5:.+]] = arith.addf %[[a3]], %[[a4]] : f64 // REM-NEXT: return %[[a5]] : f64