From 97066352a40b3c66f9a1f41ec1802af255216c0c Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 5 Mar 2024 21:52:12 -0600 Subject: [PATCH] MLIR support LogicalResult return in reversemode (#1779) --- .../CoreDialectsAutoDiffImplementations.h | 23 ++++++----- .../LinalgAutoDiffOpInterfaceImpl.cpp | 7 ++-- .../MemRefAutoDiffOpInterfaceImpl.cpp | 22 ++++++----- .../SCFAutoDiffOpInterfaceImpl.cpp | 7 ++-- .../MLIR/Interfaces/AutoDiffOpInterface.td | 2 +- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h | 19 +++++----- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 38 ++++++++++--------- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 6 ++- 8 files changed, 69 insertions(+), 55 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index cbad734656b1..14ebb622ac36 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -113,9 +113,11 @@ class NoopRevAutoDiffInterface : public ReverseAutoDiffOpInterface::ExternalModel< NoopRevAutoDiffInterface, OpTy> { public: - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const {} + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + return success(); + } SmallVector cacheValues(Operation *op, MGradientUtilsReverse *gutils) const { @@ -131,10 +133,11 @@ class ReturnRevAutoDiffInterface : public ReverseAutoDiffOpInterface::ExternalModel< ReturnRevAutoDiffInterface, OpTy> { public: - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { returnReverseHandler(op, builder, gutils); + return success(); } SmallVector cacheValues(Operation *op, @@ -181,9 +184,11 @@ class AutoDiffUsingAllocationRev : public ReverseAutoDiffOpInterface::ExternalModel< AutoDiffUsingAllocationRev, OpTy> { public: - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const {} + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + return success(); + } SmallVector cacheValues(Operation *op, MGradientUtilsReverse *gutils) const { diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index db9f4b08d6e1..4fffa1557433 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -68,9 +68,9 @@ template struct GenericOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel< GenericOpInterfaceReverse, T_> { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { auto linalgOp = cast(op); assert(linalgOp.hasPureBufferSemantics() && "Linalg op with tensor semantics not yet supported"); @@ -255,6 +255,7 @@ struct GenericOpInterfaceReverse cacheBuilder.getArrayAttr(indexingMapsAttr)); adjoint->setAttr(adjoint.getIndexingMapsAttrName(), builder.getArrayAttr(indexingMapsAttrAdjoint)); + return success(); } SmallVector cacheValues(Operation *op, diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index cd21c5c548b9..b25917c41cad 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -34,9 +34,9 @@ namespace { struct LoadOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { auto loadOp = cast(op); Value memref = loadOp.getMemref(); @@ -62,6 +62,7 @@ struct LoadOpInterfaceReverse ArrayRef(retrievedArguments)); } } + return success(); } SmallVector cacheValues(Operation *op, @@ -96,9 +97,9 @@ struct LoadOpInterfaceReverse struct StoreOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { auto storeOp = cast(op); Value val = storeOp.getValue(); Value memref = storeOp.getMemref(); @@ -133,6 +134,7 @@ struct StoreOpInterfaceReverse ArrayRef(retrievedArguments)); } } + return success(); } SmallVector cacheValues(Operation *op, @@ -167,9 +169,11 @@ struct StoreOpInterfaceReverse struct SubViewOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel< SubViewOpInterfaceReverse, memref::SubViewOp> { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const {} + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + return success(); + } SmallVector cacheValues(Operation *op, MGradientUtilsReverse *gutils) const { diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 72fbe2106a53..e22cb9daef35 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -35,9 +35,9 @@ namespace { struct ForOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { auto forOp = cast(op); // Begin Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 @@ -149,6 +149,7 @@ struct ForOpInterfaceReverse buildFuncReturnOp); } } + return success(); } SmallVector cacheValues(Operation *op, diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 9123dfcaa22e..5507dc712eaf 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -87,7 +87,7 @@ def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> { /*desc=*/[{ Emits a reverse-mode adjoint of the given function. }], - /*retTy=*/"void", + /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"createReverseModeAdjoint", /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::enzyme::MGradientUtilsReverse *":$gutils, "SmallVector":$caches) >, diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index 56d49bf79b09..0005377cc5d7 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -127,18 +127,17 @@ class MEnzymeLogic { handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils, llvm::function_ref buildReturnOp); - void visitChildren(Block *oBB, Block *reverseBB, - MGradientUtilsReverse *gutils); - void visitChild(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils); + LogicalResult visitChildren(Block *oBB, Block *reverseBB, + MGradientUtilsReverse *gutils); + LogicalResult visitChild(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils); void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); - SmallVector getDominatorToposort(MGradientUtilsReverse *gutils, - Region ®ion); - void differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, - Region &newRegion, - llvm::function_ref buildFuncRetrunOp, - std::function(Type)> cacheCreator); + LogicalResult + differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, + Region &newRegion, + llvm::function_ref buildFuncRetrunOp, + std::function(Type)> cacheCreator); }; } // Namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 25e8f1818cd2..96d5a65a884b 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -38,40 +38,35 @@ void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, /* Create reverse mode adjoint for an operation. */ -void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) { +LogicalResult MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) { if ((op->getBlock()->getTerminator() != op) && llvm::all_of(op->getResults(), [gutils](Value v) { return gutils->isConstantValue(v); }) && gutils->isConstantInstruction(op)) { - return; + return success(); } if (auto ifaceOp = dyn_cast(op)) { SmallVector caches = ifaceOp.cacheValues(gutils); - ifaceOp.createReverseModeAdjoint(builder, gutils, caches); - return; - /* - for (int indexResult = 0; indexResult < (int)op->getNumResults(); - indexResult++) { - Value result = op->getResult(indexResult); - gutils->clearValue(result, builder); - } - */ + return ifaceOp.createReverseModeAdjoint(builder, gutils, caches); } op->emitError() << "could not compute the adjoint for this operation " << *op; + return failure(); } -void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, - MGradientUtilsReverse *gutils) { +LogicalResult MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, + MGradientUtilsReverse *gutils) { OpBuilder revBuilder(reverseBB, reverseBB->end()); + bool valid = true; if (!oBB->empty()) { auto first = oBB->rbegin(); auto last = oBB->rend(); for (auto it = first; it != last; ++it) { Operation *op = &*it; - visitChild(op, revBuilder, gutils); + valid &= visitChild(op, revBuilder, gutils).succeeded(); } } + return success(valid); } void MEnzymeLogic::handlePredecessors( @@ -161,7 +156,7 @@ void MEnzymeLogic::handlePredecessors( } } -void MEnzymeLogic::differentiate( +LogicalResult MEnzymeLogic::differentiate( MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, llvm::function_ref buildFuncReturnOp, std::function(Type)> cacheCreator) { @@ -171,13 +166,15 @@ void MEnzymeLogic::differentiate( gutils->createReverseModeBlocks(oldRegion, newRegion); + bool valid = true; for (auto &oBB : oldRegion) { Block *newBB = gutils->getNewFromOriginal(&oBB); Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB); handleReturns(&oBB, newBB, reverseBB, gutils); - visitChildren(&oBB, reverseBB, gutils); + valid &= visitChildren(&oBB, reverseBB, gutils).succeeded(); handlePredecessors(&oBB, newBB, reverseBB, gutils, buildFuncReturnOp); } + return success(valid); } FunctionOpInterface MEnzymeLogic::CreateReverseDiff( @@ -212,7 +209,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( gutils->forceAugmentedReturns(); - differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr); + auto res = + differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr); auto nf = gutils->newFunc; @@ -221,5 +219,9 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( // llvm::errs() << "nf end\n"; delete gutils; + + if (!res.succeeded()) + return nullptr; + return nf; } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 6eff540b6622..7796486fe5de 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1341,7 +1341,7 @@ static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree, os << " MGradientUtilsReverse *gutils) const " "{}\n"; - os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " + os << " LogicalResult createReverseModeAdjoint(Operation *op0, OpBuilder " "&builder,\n"; os << " MGradientUtilsReverse *gutils,\n"; os << " SmallVector caches) const {\n"; @@ -1995,6 +1995,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, if (intrinsic == IntrDerivatives || intrinsic == CallDerivatives) os << " return true;\n }\n"; + else if (intrinsic == MLIRDerivatives) + os << " return success();\n }\n"; else os << " return;\n }\n"; if (intrinsic == MLIRDerivatives) @@ -2063,7 +2065,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, auto origName = "op"; emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); - os << " return;\n"; + os << " return success();\n"; os << " }\n"; os << " };\n"; }