Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 29, 2024
1 parent 5b7e443 commit b9ab98f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 18 deletions.
7 changes: 3 additions & 4 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,11 @@ void MEnzymeLogic::handlePredecessors(

Block *newPred = gutils->getNewFromOriginal(pred);

OpBuilder predecessorBuilder(gutils->oldFunc->getContext());
predecessorBuilder.setInsertionPointToStart(newPred);
OpBuilder predecessorBuilder(newPred->getTerminator());

Value pred_idx_c =
predecessorBuilder.create<arith::ConstantIntOp>(loc, idx - 1, 32);
newBuilder.create<enzyme::PushOp>(loc, cache, pred_idx_c);
predecessorBuilder.create<enzyme::PushOp>(loc, cache, pred_idx_c);

if (idx == 0) {
defaultBlock = reversePred;
Expand Down Expand Up @@ -222,7 +221,7 @@ void MEnzymeLogic::handlePredecessors(

revBuilder.create<cf::SwitchOp>(
loc, flag, defaultBlock, ArrayRef<Value>(), ArrayRef<APInt>(indices),
ArrayRef<Block *>(blocks), ArrayRef<ValueRange>());
ArrayRef<Block *>(blocks), SmallVector<ValueRange>(indices.size(), ValueRange()));

}
}
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ struct RemoveUnusedEnzymeOpsPass
}
}
});
llvm::errs() << " post: " << *getOperation() << "\n";
};
};
} // end anonymous namespace
Expand Down
66 changes: 53 additions & 13 deletions enzyme/test/MLIR/ReverseMode/bbarg-order.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %eopt --enzyme %s | FileCheck %s
// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s

module {
func.func @bbargs(%x: f64) -> f64 {
Expand All @@ -19,15 +19,55 @@ module {
}
}

// CHECK: func.func @diff(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
// CHECK-NEXT: %[[i0:.+]] = call @diffebbargs(%[[arg0]], %[[arg1]]) : (f64, f64) -> f64
// CHECK-NEXT: return %[[i0:.+]]
// CHECK: func.func private @diffebbargs(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {

// There should be exactly one block with two f64 args, and their values should be accumulated
// in the shadow.
// CHECK: ^[[BBMULTI:.+]](%[[fst:.+]]: f64, %[[snd:.+]]: f64):
// CHECK-NEXT: "enzyme.set"(%[[shadow:.+]], %[[fst]])
// CHECK-NEXT: %[[before:.+]] = "enzyme.get"(%[[shadow]])
// CHECK-NEXT: %[[after:.+]] = arith.addf %[[snd]], %[[before]]
// CHECK-NEXT: "enzyme.set"(%[[shadow]], %[[after]])
// CHECK: func.func private @diffebbargs(%arg0: f64, %arg1: f64) -> f64 {
// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f64
// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f64
// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache<i32>
// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK-NEXT: "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %2 = "enzyme.init"() : () -> !enzyme.Cache<i32>
// CHECK-NEXT: %3 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK-NEXT: "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %4 = arith.addf %arg0, %cst : f64
// CHECK-NEXT: "enzyme.push"(%2, %c0_i32) : (!enzyme.Cache<i32>, i32) -> ()
// CHECK-NEXT: cf.br ^bb1(%4 : f64)
// CHECK-NEXT: ^bb1(%5: f64): // 2 preds: ^bb0, ^bb1
// CHECK-NEXT: %6 = arith.cmpf ult, %5, %cst_0 : f64
// CHECK-NEXT: "enzyme.push"(%2, %c-1_i32) : (!enzyme.Cache<i32>, i32) -> ()
// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache<i32>, i32) -> ()
// CHECK-NEXT: cf.cond_br %6, ^bb1(%5 : f64), ^bb2
// CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %7 = arith.addf %arg1, %cst_0 : f64
// CHECK-NEXT: %8 = "enzyme.pop"(%0) : (!enzyme.Cache<i32>) -> i32
// CHECK-NEXT: %9 = arith.cmpi eq, %8, %c-1_i32 : i32
// CHECK-NEXT: %10 = arith.select %9, %7, %cst_0 : f64
// CHECK-NEXT: %11 = "enzyme.get"(%1) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT: %12 = arith.addf %11, %10 : f64
// CHECK-NEXT: "enzyme.set"(%1, %12) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb3: // 2 preds: ^bb2, ^bb3
// CHECK-NEXT: %13 = "enzyme.pop"(%2) : (!enzyme.Cache<i32>) -> i32
// CHECK-NEXT: %14 = "enzyme.get"(%1) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT: "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %15 = arith.cmpi eq, %13, %c-1_i32 : i32
// CHECK-NEXT: %16 = arith.select %15, %14, %cst_0 : f64
// CHECK-NEXT: %17 = "enzyme.get"(%1) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT: %18 = arith.addf %17, %16 : f64
// CHECK-NEXT: "enzyme.set"(%1, %18) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %19 = arith.cmpi eq, %13, %c-1_i32 : i32
// CHECK-NEXT: %20 = arith.select %19, %14, %cst_0 : f64
// CHECK-NEXT: %21 = "enzyme.get"(%3) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT: %22 = arith.addf %21, %20 : f64
// CHECK-NEXT: "enzyme.set"(%3, %22) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: cf.switch %13 : i32, [
// CHECK-NEXT: default: ^bb3,
// CHECK-NEXT: 0: ^bb4
// CHECK-NEXT: ]
// CHECK-NEXT: ^bb4: // pred: ^bb3
// CHECK-NEXT: %23 = "enzyme.get"(%3) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT: "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT: %24 = arith.addf %23, %cst_0 : f64
// CHECK-NEXT: return %24 : f64
// CHECK-NEXT: }
24 changes: 24 additions & 0 deletions enzyme/test/MLIR/ReverseMode/square.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: %eopt --enzyme %s | FileCheck %s
// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM
// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN

module {
func.func @square(%x: f64) -> f64 {
Expand Down Expand Up @@ -54,3 +56,25 @@ module {
// CHECK-NEXT: %[[res:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT: return %[[res]] : f64
// CHECK-NEXT: }


// REM: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 {
// REM-NEXT: %[[cst:.+]] = arith.constant 0.000000e+00 : f64
// REM-NEXT: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64
// REM-NEXT: %[[pmu:.+]] = arith.mulf %arg0, %arg0 : f64
// REM-NEXT: cf.br ^bb1
// REM-NEXT: ^bb1: // pred: ^bb0
// REM-NEXT: %[[a1:.+]] = arith.addf %[[cst_0]], %arg1 : f64
// REM-NEXT: %[[cst_1:.+]] = arith.constant 0.000000e+00 : f64
// REM-NEXT: %[[a2:.+]] = arith.mulf %[[a1]], %arg0 : f64
// REM-NEXT: %[[a3:.+]] = arith.addf %[[cst]], %[[a2]] : f64
// REM-NEXT: %[[a4:.+]] = arith.mulf %[[a1]], %arg0 : f64
// REM-NEXT: %[[a5:.+]] = arith.addf %[[a3]], %[[a4]] : f64
// REM-NEXT: return %[[a5]] : f64
// REM-NEXT: }

// FIN: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 {
// FIN-NEXT: %0 = arith.mulf %arg1, %arg0 : f64
// FIN-NEXT: %1 = arith.addf %0, %0 : f64
// FIN-NEXT: return %1 : f64
// FIN-NEXT: }

0 comments on commit b9ab98f

Please sign in to comment.