diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 687eb17f93fd..37f692f957fb 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -184,12 +184,11 @@ void MEnzymeLogic::handlePredecessors( Block *newPred = gutils->getNewFromOriginal(pred); - OpBuilder predecessorBuilder(gutils->oldFunc->getContext()); - predecessorBuilder.setInsertionPointToStart(newPred); + OpBuilder predecessorBuilder(newPred->getTerminator()); Value pred_idx_c = predecessorBuilder.create(loc, idx - 1, 32); - newBuilder.create(loc, cache, pred_idx_c); + predecessorBuilder.create(loc, cache, pred_idx_c); if (idx == 0) { defaultBlock = reversePred; @@ -222,7 +221,7 @@ void MEnzymeLogic::handlePredecessors( revBuilder.create( loc, flag, defaultBlock, ArrayRef(), ArrayRef(indices), - ArrayRef(blocks), ArrayRef()); + ArrayRef(blocks), SmallVector(indices.size(), ValueRange())); } } diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index f28711dd54d1..32ea4e14f8cc 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -182,7 +182,6 @@ struct RemoveUnusedEnzymeOpsPass } } }); - llvm::errs() << " post: " << *getOperation() << "\n"; }; }; } // end anonymous namespace diff --git a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir index e5bb39eea040..1dd48c3b89f7 100644 --- a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir +++ b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s module { func.func @bbargs(%x: f64) -> f64 { @@ -19,15 +19,55 @@ module { } } -// CHECK: func.func @diff(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { -// CHECK-NEXT: %[[i0:.+]] = call @diffebbargs(%[[arg0]], %[[arg1]]) : (f64, f64) -> f64 -// CHECK-NEXT: return %[[i0:.+]] -// CHECK: func.func private @diffebbargs(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { - -// There should be exactly one block with two f64 args, and their values should be accumulated -// in the shadow. -// CHECK: ^[[BBMULTI:.+]](%[[fst:.+]]: f64, %[[snd:.+]]: f64): -// CHECK-NEXT: "enzyme.set"(%[[shadow:.+]], %[[fst]]) -// CHECK-NEXT: %[[before:.+]] = "enzyme.get"(%[[shadow]]) -// CHECK-NEXT: %[[after:.+]] = arith.addf %[[snd]], %[[before]] -// CHECK-NEXT: "enzyme.set"(%[[shadow]], %[[after]]) +// CHECK: func.func private @diffebbargs(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f64 +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %2 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %3 = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %4 = arith.addf %arg0, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%2, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%4 : f64) +// CHECK-NEXT: ^bb1(%5: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %6 = arith.cmpf ult, %5, %cst_0 : f64 +// CHECK-NEXT: "enzyme.push"(%2, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %6, ^bb1(%5 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // pred: ^bb1 +// CHECK-NEXT: %7 = arith.addf %arg1, %cst_0 : f64 +// CHECK-NEXT: %8 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %9 = arith.cmpi eq, %8, %c-1_i32 : i32 +// CHECK-NEXT: %10 = arith.select %9, %7, %cst_0 : f64 +// CHECK-NEXT: %11 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %12 = arith.addf %11, %10 : f64 +// CHECK-NEXT: "enzyme.set"(%1, %12) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: cf.br ^bb3 +// CHECK-NEXT: ^bb3: // 2 preds: ^bb2, ^bb3 +// CHECK-NEXT: %13 = "enzyme.pop"(%2) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %14 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %15 = arith.cmpi eq, %13, %c-1_i32 : i32 +// CHECK-NEXT: %16 = arith.select %15, %14, %cst_0 : f64 +// CHECK-NEXT: %17 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %18 = arith.addf %17, %16 : f64 +// CHECK-NEXT: "enzyme.set"(%1, %18) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %19 = arith.cmpi eq, %13, %c-1_i32 : i32 +// CHECK-NEXT: %20 = arith.select %19, %14, %cst_0 : f64 +// CHECK-NEXT: %21 = "enzyme.get"(%3) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %22 = arith.addf %21, %20 : f64 +// CHECK-NEXT: "enzyme.set"(%3, %22) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: cf.switch %13 : i32, [ +// CHECK-NEXT: default: ^bb3, +// CHECK-NEXT: 0: ^bb4 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb4: // pred: ^bb3 +// CHECK-NEXT: %23 = "enzyme.get"(%3) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %24 = arith.addf %23, %cst_0 : f64 +// CHECK-NEXT: return %24 : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/square.mlir b/enzyme/test/MLIR/ReverseMode/square.mlir index 0c9fb0cc5264..37d57d426033 100644 --- a/enzyme/test/MLIR/ReverseMode/square.mlir +++ b/enzyme/test/MLIR/ReverseMode/square.mlir @@ -1,4 +1,6 @@ // RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM +// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN module { func.func @square(%x: f64) -> f64 { @@ -54,3 +56,25 @@ module { // CHECK-NEXT: %[[res:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 // CHECK-NEXT: return %[[res]] : f64 // CHECK-NEXT: } + + +// REM: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// REM-NEXT: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 +// REM-NEXT: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64 +// REM-NEXT: %[[pmu:.+]] = arith.mulf %arg0, %arg0 : f64 +// REM-NEXT: cf.br ^bb1 +// REM-NEXT: ^bb1: // pred: ^bb0 +// REM-NEXT: %[[a1:.+]] = arith.addf %[[cst_0]], %arg1 : f64 +// REM-NEXT: %[[cst_1:.+]] = arith.constant 0.000000e+00 : f64 +// REM-NEXT: %[[a2:.+]] = arith.mulf %[[a1]], %arg0 : f64 +// REM-NEXT: %[[a3:.+]] = arith.addf %[[cst]], %[[a2]] : f64 +// REM-NEXT: %[[a4:.+]] = arith.mulf %[[a1]], %arg0 : f64 +// REM-NEXT: %[[a5:.+]] = arith.addf %[[a3]], %[[a4]] : f64 +// REM-NEXT: return %[[a5]] : f64 +// REM-NEXT: } + +// FIN: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// FIN-NEXT: %0 = arith.mulf %arg1, %arg0 : f64 +// FIN-NEXT: %1 = arith.addf %0, %0 : f64 +// FIN-NEXT: return %1 : f64 +// FIN-NEXT: } \ No newline at end of file