diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index cbad734656b1..c556aa4b6e33 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -113,9 +113,11 @@ class NoopRevAutoDiffInterface : public ReverseAutoDiffOpInterface::ExternalModel< NoopRevAutoDiffInterface, OpTy> { public: - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, - SmallVector caches) const {} + SmallVector caches) const { + return success(); + } SmallVector cacheValues(Operation *op, MGradientUtilsReverse *gutils) const { @@ -131,10 +133,11 @@ class ReturnRevAutoDiffInterface : public ReverseAutoDiffOpInterface::ExternalModel< ReturnRevAutoDiffInterface, OpTy> { public: - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector caches) const { returnReverseHandler(op, builder, gutils); + return success(); } SmallVector cacheValues(Operation *op, @@ -181,9 +184,11 @@ class AutoDiffUsingAllocationRev : public ReverseAutoDiffOpInterface::ExternalModel< AutoDiffUsingAllocationRev, OpTy> { public: - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, - SmallVector caches) const {} + SmallVector caches) const { + return success(); + } SmallVector cacheValues(Operation *op, MGradientUtilsReverse *gutils) const { diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index db9f4b08d6e1..7bd821e4bbfc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -68,7 +68,7 @@ template struct GenericOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel< GenericOpInterfaceReverse, T_> { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector caches) const { auto linalgOp = cast(op); @@ -255,6 +255,7 @@ struct GenericOpInterfaceReverse cacheBuilder.getArrayAttr(indexingMapsAttr)); adjoint->setAttr(adjoint.getIndexingMapsAttrName(), builder.getArrayAttr(indexingMapsAttrAdjoint)); + return success(); } SmallVector cacheValues(Operation *op, diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index cd21c5c548b9..229648e3e715 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -34,7 +34,7 @@ namespace { struct LoadOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector caches) const { auto loadOp = cast(op); @@ -62,6 +62,7 @@ struct LoadOpInterfaceReverse ArrayRef(retrievedArguments)); } } + return success(); } SmallVector cacheValues(Operation *op, @@ -96,7 +97,7 @@ struct LoadOpInterfaceReverse struct StoreOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector caches) const { auto storeOp = cast(op); @@ -133,6 +134,7 @@ struct StoreOpInterfaceReverse ArrayRef(retrievedArguments)); } } + return success(); } SmallVector cacheValues(Operation *op, @@ -167,9 +169,11 @@ struct StoreOpInterfaceReverse struct SubViewOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel< SubViewOpInterfaceReverse, memref::SubViewOp> { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, - SmallVector caches) const {} + SmallVector caches) const { + return success(); + } SmallVector cacheValues(Operation *op, MGradientUtilsReverse *gutils) const { diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 72fbe2106a53..6187eaec5b09 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -35,7 +35,7 @@ namespace { struct ForOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector caches) const { auto forOp = cast(op); @@ -149,6 +149,7 @@ struct ForOpInterfaceReverse buildFuncReturnOp); } } + return success(); } SmallVector cacheValues(Operation *op, diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 9123dfcaa22e..5507dc712eaf 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -87,7 +87,7 @@ def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> { /*desc=*/[{ Emits a reverse-mode adjoint of the given function. }], - /*retTy=*/"void", + /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"createReverseModeAdjoint", /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::enzyme::MGradientUtilsReverse *":$gutils, "SmallVector":$caches) >, diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index 56d49bf79b09..9463569b573f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -127,15 +127,13 @@ class MEnzymeLogic { handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils, llvm::function_ref buildReturnOp); - void visitChildren(Block *oBB, Block *reverseBB, + LogicalResult visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); - void visitChild(Operation *op, OpBuilder &builder, + LogicalResult visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils); void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); - SmallVector getDominatorToposort(MGradientUtilsReverse *gutils, - Region ®ion); - void differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, + LogicalResult differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, llvm::function_ref buildFuncRetrunOp, std::function(Type)> cacheCreator); diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 25e8f1818cd2..30ea226730da 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -38,40 +38,35 @@ void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, /* Create reverse mode adjoint for an operation. */ -void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, +LogicalResult MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) { if ((op->getBlock()->getTerminator() != op) && llvm::all_of(op->getResults(), [gutils](Value v) { return gutils->isConstantValue(v); }) && gutils->isConstantInstruction(op)) { - return; + return success(); } if (auto ifaceOp = dyn_cast(op)) { SmallVector caches = ifaceOp.cacheValues(gutils); - ifaceOp.createReverseModeAdjoint(builder, gutils, caches); - return; - /* - for (int indexResult = 0; indexResult < (int)op->getNumResults(); - indexResult++) { - Value result = op->getResult(indexResult); - gutils->clearValue(result, builder); - } - */ + return ifaceOp.createReverseModeAdjoint(builder, gutils, caches); } op->emitError() << "could not compute the adjoint for this operation " << *op; + return failure(); } -void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, +LogicalResult MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils) { OpBuilder revBuilder(reverseBB, reverseBB->end()); + bool valid = true; if (!oBB->empty()) { auto first = oBB->rbegin(); auto last = oBB->rend(); for (auto it = first; it != last; ++it) { Operation *op = &*it; - visitChild(op, revBuilder, gutils); + valid &= visitChild(op, revBuilder, gutils).succeeded(); } } + return success(valid); } void MEnzymeLogic::handlePredecessors( @@ -161,7 +156,7 @@ void MEnzymeLogic::handlePredecessors( } } -void MEnzymeLogic::differentiate( +LogicalResult MEnzymeLogic::differentiate( MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, llvm::function_ref buildFuncReturnOp, std::function(Type)> cacheCreator) { @@ -171,13 +166,15 @@ void MEnzymeLogic::differentiate( gutils->createReverseModeBlocks(oldRegion, newRegion); + bool valid = true; for (auto &oBB : oldRegion) { Block *newBB = gutils->getNewFromOriginal(&oBB); Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB); handleReturns(&oBB, newBB, reverseBB, gutils); - visitChildren(&oBB, reverseBB, gutils); + valid &= visitChildren(&oBB, reverseBB, gutils).succeeded(); handlePredecessors(&oBB, newBB, reverseBB, gutils, buildFuncReturnOp); } + return success(valid); } FunctionOpInterface MEnzymeLogic::CreateReverseDiff( @@ -212,14 +209,19 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( gutils->forceAugmentedReturns(); - differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr); + auto res = differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr); auto nf = gutils->newFunc; + // llvm::errs() << "nf\n"; // nf.dump(); // llvm::errs() << "nf end\n"; delete gutils; + + if (!res.succeeded()) + return nullptr; + return nf; } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 6eff540b6622..7796486fe5de 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1341,7 +1341,7 @@ static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree, os << " MGradientUtilsReverse *gutils) const " "{}\n"; - os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " + os << " LogicalResult createReverseModeAdjoint(Operation *op0, OpBuilder " "&builder,\n"; os << " MGradientUtilsReverse *gutils,\n"; os << " SmallVector caches) const {\n"; @@ -1995,6 +1995,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, if (intrinsic == IntrDerivatives || intrinsic == CallDerivatives) os << " return true;\n }\n"; + else if (intrinsic == MLIRDerivatives) + os << " return success();\n }\n"; else os << " return;\n }\n"; if (intrinsic == MLIRDerivatives) @@ -2063,7 +2065,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, auto origName = "op"; emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); - os << " return;\n"; + os << " return success();\n"; os << " }\n"; os << " };\n"; }