diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 7d4ccead7764..ead8baad9261 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -5,7 +5,6 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/FunctionInterfaces.h" // TODO: this shouldn't depend on specific dialects except Enzyme. diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index dc98a5c4c6a6..56d49bf79b09 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -119,8 +119,7 @@ class MEnzymeLogic { std::vector constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, - std::vector volatile_args, void *augmented, - SymbolTableCollection &symbolTable); + std::vector volatile_args, void *augmented); void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); @@ -132,8 +131,6 @@ class MEnzymeLogic { MGradientUtilsReverse *gutils); void visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils); - bool visitChildCustom(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils); void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); SmallVector getDominatorToposort(MGradientUtilsReverse *gutils, @@ -145,4 +142,4 @@ class MEnzymeLogic { }; } // Namespace enzyme -} // Namespace mlir \ No newline at end of file +} // Namespace mlir diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index f39766a0a639..25e8f1818cd2 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -35,76 +35,6 @@ void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, } } -bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) { - std::string nameDiffe = "diffe_" + op->getName().getDialectNamespace().str() + - "_" + op->getName().stripDialect().str(); - std::string nameStore = "store_" + op->getName().getDialectNamespace().str() + - "_" + op->getName().stripDialect().str(); - - StringRef srDiffe(nameDiffe); - StringRef srStore(nameStore); - - OperationName opNameDiffe(srDiffe, op->getContext()); - OperationName opNameStore(srStore, op->getContext()); - - Operation *symbolDiffe = gutils->symbolTable.lookupNearestSymbolFrom( - op, opNameDiffe.getIdentifier()); - Operation *symbolStore = gutils->symbolTable.lookupNearestSymbolFrom( - op, opNameStore.getIdentifier()); - - if (symbolDiffe != nullptr) { - SmallVector caches; - if (symbolStore != nullptr) { - Operation *newOp = gutils->getNewFromOriginal(op); - - func::FuncOp funcStore = cast(symbolStore); - - SmallVector storeResultTypes; - for (auto x : funcStore.getFunctionType().getResults()) { - storeResultTypes.push_back(x); - } - - SmallVector storeArgs; - for (auto x : newOp->getOperands()) { - storeArgs.push_back(x); - } - - OpBuilder storeBuilder(newOp); - func::CallOp storeCI = storeBuilder.create( - op->getLoc(), srStore, storeResultTypes, storeArgs); - for (auto x : storeCI.getResults()) { - caches.push_back(gutils->initAndPushCache(x, storeBuilder)); - } - } - - SmallVector args; - for (Value opResult : op->getResults()) { - if (!gutils->isConstantValue(opResult)) { - Value invertValue = gutils->invertPointerM(opResult, builder); - args.push_back(invertValue); - } - } - for (Value cache : caches) { - args.push_back(gutils->popCache(cache, builder)); - } - - SmallVector resultTypes; - for (auto x : op->getOperands()) { - resultTypes.push_back(x.getType()); - } - - func::CallOp dCI = - builder.create(op->getLoc(), srDiffe, resultTypes, args); - for (int i = 0; i < (int)op->getNumOperands(); i++) { - gutils->setDiffe(op->getOperand(i), dCI.getResult(i), builder); - } - - return true; - } - return false; -} - /* Create reverse mode adjoint for an operation. */ @@ -139,10 +69,7 @@ void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, auto last = oBB->rend(); for (auto it = first; it != last; ++it) { Operation *op = &*it; - bool customFound = visitChildCustom(op, revBuilder, gutils); - if (!customFound) { - visitChild(op, revBuilder, gutils); - } + visitChild(op, revBuilder, gutils); } } } @@ -257,8 +184,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( FunctionOpInterface fn, DIFFE_TYPE retType, std::vector constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented, - SymbolTableCollection &symbolTable) { + MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; @@ -268,7 +194,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( ReturnType returnValue = ReturnType::Args; MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone( *this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true, - constants, returnValue, addedType, symbolTable); + constants, returnValue, addedType); Region &oldRegion = gutils->oldFunc.getFunctionBody(); Region &newRegion = gutils->newFunc.getFunctionBody(); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 20d34247fb3a..b57fbe68b594 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -35,12 +35,11 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( const SmallPtrSetImpl &activevals_, DIFFE_TYPE ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_) + DerivativeMode mode_, unsigned width) : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, invertedPointers_, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, - originalToNewFnOps_, mode_, width, /*omp*/ false), - symbolTable(symbolTable_) {} + originalToNewFnOps_, mode_, width, /*omp*/ false) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -135,8 +134,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg, - SymbolTableCollection &symbolTable_) { + ReturnType returnValue, mlir::Type additionalArg) { std::string prefix; switch (mode_) { @@ -168,8 +166,8 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( prefix + todiff.getName(), originalToNew, originalToNewOps, diffeReturnArg, additionalArg); - return new MGradientUtilsReverse( - Logic, newFunc, todiff, TA, invertedPointers, constant_values, - nonconstant_values, retType, constant_args, originalToNew, - originalToNewOps, mode_, width, symbolTable_); + return new MGradientUtilsReverse(Logic, newFunc, todiff, TA, invertedPointers, + constant_values, nonconstant_values, retType, + constant_args, originalToNew, + originalToNewOps, mode_, width); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index d3b2e818391f..96e899939538 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -31,13 +31,10 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width, - SymbolTableCollection &symbolTable_); + DerivativeMode mode_, unsigned width); IRMapping mapReverseModeBlocks; - SymbolTableCollection &symbolTable; - void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder); @@ -67,8 +64,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg, - SymbolTableCollection &symbolTable_); + ReturnType returnValue, mlir::Type additionalArg); }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index de7328dcb186..b7d33b6faedc 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -153,7 +153,7 @@ struct DifferentiatePass : public DifferentiatePassBase { fn, retType, constants, TA, /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr, symbolTable); + /*augmented*/ nullptr); if (!newFunc) return failure(); diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index b8d91f2c82e8..b48705c220d1 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -91,12 +91,21 @@ struct DifferentiateWrapperPass volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined)); } - FunctionOpInterface newFunc = Logic.CreateForwardDiff( - fn, retType, constants, TA, - /*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory, - width, - /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + FunctionOpInterface newFunc; + if (mode == DerivativeMode::ForwardMode) { + newFunc = Logic.CreateForwardDiff( + fn, retType, constants, TA, + /*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory, + width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + } else { + newFunc = Logic.CreateReverseDiff( + fn, retType, constants, TA, + /*should return*/ false, mode, freeMemory, width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + } if (!newFunc) { signalPassFailure(); return;