Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 29, 2024
1 parent 005f4fc commit a90ea7e
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 38 deletions.
6 changes: 4 additions & 2 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ def Gradient : Enzyme_Type<"Gradient"> {

def SetOp : Enzyme_Op<"set"> {
let summary = "Store the current value of the gradient";
let arguments = (ins AnyType : $gradient, AnyType : $value);
let arguments = (ins Arg<AnyType, "the reference to store to",
[MemWrite]>:$gradient, AnyType : $value);
let results = (outs );
}

def GetOp : Enzyme_Op<"get"> {
let summary = "Load current value of gradient";
let arguments = (ins AnyType : $gradient);
let arguments = (ins Arg<AnyType, "the reference to load from",
[MemRead]>:$gradient);
let results = (outs AnyType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct ForOpInterfaceReverse
for (OpResult v : forOp.getResults()) {
if (!gutils->isConstantValue(v)) {
auto autoDiffType = cast<AutoDiffTypeInterface>(v.getType());
if (autoDiffType.isMutable()) {
if (!autoDiffType.isMutable()) {
auto prev = gutils->diffe(v, builder);
gutils->zeroDiffe(v, builder);
resDiffes.push_back(prev);
Expand All @@ -72,6 +72,7 @@ struct ForOpInterfaceReverse
OperandRange operandRange =
termIface.getSuccessorOperands(successor);
assert(operandRange.size() == resDiffes.size());

// There is an assumption here that there is only regions that branch to the successor.
// Specifically, otherwise we would need to gutils->addToDiffe select (if came from that result)
for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) {
Expand Down Expand Up @@ -108,7 +109,7 @@ struct ForOpInterfaceReverse

auto condition = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, lhs, end);

for (auto [arg, init_arg] : llvm::zip(oBB->getArguments(), forOp.getInitArgs())) {
for (auto [arg, init_arg] : llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) {
if (!gutils->isConstantValue(arg) && !cast<AutoDiffTypeInterface>(arg.getType()).isMutable()) {
auto diffe = gutils->diffe(arg, builder);
gutils->zeroDiffe(arg, builder);
Expand All @@ -132,6 +133,8 @@ struct ForOpInterfaceReverse

for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
gutils->mapReverseModeBlocks.map(&oBB, &revBB);
}
for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
gutils->Logic.visitChildren(&oBB, &revBB, gutils);
Block *newBB = gutils->getNewFromOriginal(&oBB);
gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, buildFuncReturnOp);
Expand Down
9 changes: 1 addition & 8 deletions enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) {

bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) {

for (auto op = start->getNextNode(); op != nullptr; op++) {
for (auto op = start->getNextNode(); op != nullptr; op = op->getNextNode()) {
// This check op has been found after start in its block
if (op->isAncestor(check)) {
return true;
Expand Down Expand Up @@ -136,27 +136,20 @@ T findNearestDominatingOpByUse(Operation *op, Value v) {
if (options.size() == 1 && dInfo.dominates(options[0], op))
return options[0];

llvm::errs() << " scope: " << *op->getParentOp() << "\n";
llvm::errs() << " want to replace " << *op << "\n";
for (auto opt : options) {
if (!dInfo.dominates(opt, op))
continue;
bool conflict = false;
llvm::errs() << " trying: " << *opt << "\n";
for (auto opt2 : options) {
if (opt == opt2) continue;

llvm::errs() << " conflict check: " << *opt2 << "\n";

if (!mayExecuteBetween(opt, opt2, op)) {
llvm::errs() << " + known good since occurs before store\n";
continue;
}

conflict = true;
}
if (!conflict) {
llvm::errs() << " - replaced with " << *opt << "\n";
return opt;
}
}
Expand Down
69 changes: 43 additions & 26 deletions enzyme/test/MLIR/ReverseMode/pow.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s
// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -enzyme-simplify-math -canonicalize %s | FileCheck %s

module {
func.func @ppow(%x: f64) -> f64 {
Expand All @@ -19,29 +19,46 @@ module {
}
}

// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64
// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 {
// CHECK-NEXT: %c10 = arith.constant 10 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %[[one:.+]] = arith.constant 1.0
// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-NEXT: %[[xshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[zero]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %[[itshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %[[xcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache<f64>
// CHECK-NEXT: %[[rcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache<f64>
// CHECK-NEXT: %[[rshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient<f64>, f64) -> ()

// Make sure the right values are being cached in the primal
// CHECK: %[[one:.+]] = arith.constant 1.0
// CHECK: scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]])
// CHECK-NEXT: "enzyme.push"(%[[rcache:.+]], %[[r_it]])
// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]])

// Ensure the right value is yielded in the adjoint
// CHECK: "enzyme.set"(%[[rshadow:.+]], %[[dr]])
// CHECK: %[[dr:.+]] = "enzyme.get"(%[[rshadow]])
// CHECK: scf.for %[[iv:.+]] = %[[lb:.+]] to %[[ub:.+]] step %[[step:.+]] iter_args(%[[dr_it:.+]] = %[[dr]])
// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_it]])
// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]])
// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]])
// CHECK-NEXT: %[[x:.+]] = "enzyme.pop"(%[[xcache]])
// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]]
// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_next]])
// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]]
// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow:.+]]) :
// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]]
// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]])
// CHECK-NEXT: %[[dr_next:.+]] = "enzyme.get"(%[[rshadow]])
// CHECK-NEXT: scf.yield %[[dr_next]]
// CHECK: %[[final:.+]] = "enzyme.get"(%[[xshadow]])
// CHECK-NEXT: return %[[final]]
// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -> (f64) {
// CHECK-NEXT: "enzyme.push"(%[[rcache]], %[[r_it]]) : (!enzyme.Cache<f64>, f64) -> ()
// CHECK-NEXT: "enzyme.push"(%[[xcache]], %[[x]]) : (!enzyme.Cache<f64>, f64) -> ()
// CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64
// CHECK-NEXT: scf.yield %[[fwd]] : f64
// CHECK-NEXT: }
// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[dr]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: scf.for %[[div:.+]] = %c0 to %c10 step %c1 {
// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) : (!enzyme.Cache<f64>) -> f64
// CHECK-NEXT: %[[x_cached:.+]] = "enzyme.pop"(%[[xcache]]) : (!enzyme.Cache<f64>) -> f64
// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x_cached]]
// CHECK-NEXT: %[[previts:.+]] = "enzyme.get"(%[[itshadow]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT: %[[postits:.+]] = arith.addf %[[previts]], %[[dr_next]] : f64
// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[postits]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64
// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow]]) :
// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]]
// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %[[divp1:.+]] = arith.addi %[[div]], %c1 : index
// CHECK-NEXT: %[[last:.+]] = arith.cmpi sge, %[[divp1]], %c10 : index
// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %[[sel:.+]] = arith.select %[[last]], %[[zero]], %12 : f64
// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[sel]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: %[[final:.+]] = "enzyme.get"(%[[xshadow]])
// CHECK-NEXT: return %[[final]]

0 comments on commit a90ea7e

Please sign in to comment.