diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 16a11db1529c..7d74fdafbdbf 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -111,13 +111,15 @@ def Gradient : Enzyme_Type<"Gradient"> { def SetOp : Enzyme_Op<"set"> { let summary = "Store the current value of the gradient"; - let arguments = (ins AnyType : $gradient, AnyType : $value); + let arguments = (ins Arg:$gradient, AnyType : $value); let results = (outs ); } def GetOp : Enzyme_Op<"get"> { let summary = "Load current value of gradient"; - let arguments = (ins AnyType : $gradient); + let arguments = (ins Arg:$gradient); let results = (outs AnyType); } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 8c39fc25bf16..adcabcc4b3db 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -46,7 +46,7 @@ struct ForOpInterfaceReverse for (OpResult v : forOp.getResults()) { if (!gutils->isConstantValue(v)) { auto autoDiffType = cast(v.getType()); - if (autoDiffType.isMutable()) { + if (!autoDiffType.isMutable()) { auto prev = gutils->diffe(v, builder); gutils->zeroDiffe(v, builder); resDiffes.push_back(prev); @@ -72,6 +72,7 @@ struct ForOpInterfaceReverse OperandRange operandRange = termIface.getSuccessorOperands(successor); assert(operandRange.size() == resDiffes.size()); + // There is an assumption here that only regions branch to the successor. 
// Specifically, otherwise we would need to gutils->addToDiffe select (if came from that result) for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) { @@ -108,7 +109,7 @@ struct ForOpInterfaceReverse auto condition = builder.create(loc, arith::CmpIPredicate::sge, lhs, end); - for (auto [arg, init_arg] : llvm::zip(oBB->getArguments(), forOp.getInitArgs())) { + for (auto [arg, init_arg] : llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) { if (!gutils->isConstantValue(arg) && !cast(arg.getType()).isMutable()) { auto diffe = gutils->diffe(arg, builder); gutils->zeroDiffe(arg, builder); @@ -132,6 +133,8 @@ struct ForOpInterfaceReverse for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { gutils->mapReverseModeBlocks.map(&oBB, &revBB); + } + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { gutils->Logic.visitChildren(&oBB, &revBB, gutils); Block *newBB = gutils->getNewFromOriginal(&oBB); gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, buildFuncReturnOp); diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index dff94c63f16c..14c705fce44a 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -92,7 +92,7 @@ bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) { bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) { - for (auto op = start->getNextNode(); op != nullptr; op++) { + for (auto op = start->getNextNode(); op != nullptr; op = op->getNextNode()) { // This check op has been found after start in its block if (op->isAncestor(check)) { return true; @@ -136,27 +136,20 @@ T findNearestDominatingOpByUse(Operation *op, Value v) { if (options.size() == 1 && dInfo.dominates(options[0], op)) return options[0]; - llvm::errs() << " scope: " << *op->getParentOp() << "\n"; - llvm::errs() << " want to replace " << *op << "\n"; for (auto opt : options) { if 
(!dInfo.dominates(opt, op)) continue; bool conflict = false; - llvm::errs() << " trying: " << *opt << "\n"; for (auto opt2 : options) { if (opt == opt2) continue; - llvm::errs() << " conflict check: " << *opt2 << "\n"; - if (!mayExecuteBetween(opt, opt2, op)) { - llvm::errs() << " + known good since occurs before store\n"; continue; } conflict = true; } if (!conflict) { - llvm::errs() << " - replaced with " << *opt << "\n"; return opt; } } diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir index e0a74d88145a..3fe3ffb0ee66 100644 --- a/enzyme/test/MLIR/ReverseMode/pow.mlir +++ b/enzyme/test/MLIR/ReverseMode/pow.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s +// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -enzyme-simplify-math -canonicalize %s | FileCheck %s module { func.func @ppow(%x: f64) -> f64 { @@ -19,29 +19,46 @@ module { } } -// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 +// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 { +// CHECK-NEXT: %c10 = arith.constant 10 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %[[one:.+]] = arith.constant 1.0 +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[xshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[itshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[xcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () 
-// Make sure the right values are being cached in the primal -// CHECK: %[[one:.+]] = arith.constant 1.0 -// CHECK: scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -// CHECK-NEXT: "enzyme.push"(%[[rcache:.+]], %[[r_it]]) -// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]]) - -// Ensure the right value is yielded in the adjoint -// CHECK: "enzyme.set"(%[[rshadow:.+]], %[[dr]]) -// CHECK: %[[dr:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK: scf.for %[[iv:.+]] = %[[lb:.+]] to %[[ub:.+]] step %[[step:.+]] iter_args(%[[dr_it:.+]] = %[[dr]]) -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_it]]) -// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) -// CHECK-NEXT: %[[x:.+]] = "enzyme.pop"(%[[xcache]]) -// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_next]]) -// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] -// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow:.+]]) : -// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] -// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) -// CHECK-NEXT: %[[dr_next:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: scf.yield %[[dr_next]] -// CHECK: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) -// CHECK-NEXT: return %[[final]] +// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -> (f64) { +// CHECK-NEXT: "enzyme.push"(%[[rcache]], %[[r_it]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[xcache]], %[[x]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[fwd]] : f64 +// CHECK-NEXT: } +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[dr]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: scf.for %[[div:.+]] = %c0 to %c10 step %c1 { +// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) : (!enzyme.Gradient) -> f64 
+// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[x_cached:.+]] = "enzyme.pop"(%[[xcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x_cached]] +// CHECK-NEXT: %[[previts:.+]] = "enzyme.get"(%[[itshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postits:.+]] = arith.addf %[[previts]], %[[dr_next]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[postits]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64 +// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow]]) : +// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[divp1:.+]] = arith.addi %[[div]], %c1 : index +// CHECK-NEXT: %[[last:.+]] = arith.cmpi sge, %[[divp1]], %c10 : index +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[sel:.+]] = arith.select %[[last]], %[[zero]], %[[postits]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[sel]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) +// CHECK-NEXT: return %[[final]] \ No newline at end of file