From 429b3e05ede72fd49c0ffc170cff841027c0a78e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 9 Mar 2024 16:34:16 -0500 Subject: [PATCH] MLIR handle multiple distinct return activities --- .../Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 20 ++- .../Enzyme/MLIR/Analysis/ActivityAnalysis.h | 4 +- .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 130 +++++++----------- enzyme/Enzyme/MLIR/Interfaces/CloneFunction.h | 19 ++- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 93 ++++--------- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h | 14 +- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 9 +- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 51 +------ enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 29 ++-- .../MLIR/Interfaces/GradientUtilsReverse.cpp | 12 +- .../MLIR/Interfaces/GradientUtilsReverse.h | 11 +- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 64 +++++---- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 71 ++++++---- enzyme/Enzyme/MLIR/Passes/Passes.td | 16 +-- .../MLIR/Passes/PrintActivityAnalysis.cpp | 9 +- enzyme/test/MLIR/ForwardMode/wrap.mlir | 2 +- enzyme/test/MLIR/ReverseMode/multiret.mlir | 27 ++++ 17 files changed, 276 insertions(+), 305 deletions(-) create mode 100644 enzyme/test/MLIR/ReverseMode/multiret.mlir diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 1415559b3f6b..cb84062d9066 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -3073,7 +3073,15 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( continue; } if (isFunctionReturn(a)) { - if (ActiveReturns == DIFFE_TYPE::CONSTANT) { + bool allConstant = true; + assert(a->getOperands().size() == ActiveReturns.size()); + for (auto &&[retval, act] : llvm::zip(a->getOperands(), ActiveReturns)) { + if (retval != parent) continue; + if (act == DIFFE_TYPE::CONSTANT) continue; + allConstant = false; + break; + } + if (allConstant) { continue; } else { return false; @@ -3424,7 +3432,15 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( return false; } if (isFunctionReturn(a)) { - if (ActiveReturns == DIFFE_TYPE::CONSTANT) + bool allConstant = true; + assert(a->getOperands().size() == ActiveReturns.size()); + for (auto &&[retval, act] : llvm::zip(a->getOperands(), ActiveReturns)) { + if (retval != val) continue; + if (act == DIFFE_TYPE::CONSTANT) continue; + allConstant = false; + break; + } + if (allConstant) continue; if (EnzymePrintActivity) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h index c73df2025883..bbc2a01e7084 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h @@ -30,7 +30,7 @@ class ActivityAnalyzer { public: /// Whether the returns of the function being analyzed are active - const DIFFE_TYPE ActiveReturns; + const llvm::ArrayRef ActiveReturns; private: /// Direction of current analysis @@ -71,7 +71,7 @@ class ActivityAnalyzer { // llvm::TargetLibraryInfo &TLI_, const llvm::SmallPtrSetImpl &ConstantValues, const llvm::SmallPtrSetImpl &ActiveValues, - DIFFE_TYPE ActiveReturns) + llvm::ArrayRef ActiveReturns) : notForAnalysis(notForAnalysis_), ActiveReturns(ActiveReturns), directions(UP | DOWN), ConstantValues(ConstantValues.begin(), ConstantValues.end()), diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 9b5c007c62ee..fc53347bb5f6 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -14,71 +14,48 @@ Type getShadowType(Type type, unsigned width) { mlir::FunctionType getFunctionTypeForClone( mlir::FunctionType FTy, DerivativeMode mode, unsigned width, - mlir::Type additionalArg, llvm::ArrayRef constant_args, - bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE ReturnType) { + mlir::Type additionalArg, + const std::vector &returnPrimals, + const std::vector &returnShadows, + llvm::ArrayRef ReturnActivity, + llvm::ArrayRef ArgActivity) { + SmallVector RetTypes; - if (returnValue == ReturnType::ArgsWithReturn || - returnValue == ReturnType::Return) { - assert(FTy.getNumResults() >= 1); - for (size_t i = 0; i < FTy.getNumResults(); i++) { - if (ReturnType != DIFFE_TYPE::CONSTANT && - ReturnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(FTy.getResult(i), width)); - } else { - RetTypes.push_back(FTy.getResult(i)); - } - } - } else if (returnValue == ReturnType::ArgsWithTwoReturns || - returnValue == ReturnType::TwoReturns) { - assert(FTy.getNumResults() >= 1); - for (size_t i = 0; i < FTy.getNumResults(); i++) { - RetTypes.push_back(FTy.getResult(i)); - if (ReturnType != DIFFE_TYPE::CONSTANT && - ReturnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(FTy.getResult(i), width)); - } else { - RetTypes.push_back(FTy.getResult(i)); + + for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip(FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) { + if (returnPrimal) + RetTypes.push_back(Ty); + if (returnShadow) { + assert(activity != DIFFE_TYPE::CONSTANT); + assert(activity != DIFFE_TYPE::OUT_DIFF); + RetTypes.push_back(getShadowType(Ty, width)); } } - } SmallVector ArgTypes; - // The user might be deleting arguments to the function by specifying them in - // the VMap. If so, we need to not add the arguments to the arg ty vector - unsigned argno = 0; - - for (auto I : FTy.getInputs()) { - ArgTypes.push_back(I); - if (constant_args[argno] == DIFFE_TYPE::DUP_ARG || - constant_args[argno] == DIFFE_TYPE::DUP_NONEED) { - ArgTypes.push_back(getShadowType(I, width)); - } else if (constant_args[argno] == DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(I, width)); + for (auto &&[ITy, act] : llvm::zip(FTy.getInputs(), ArgActivity)) { + ArgTypes.push_back(ITy); + if (act == DIFFE_TYPE::DUP_ARG || + act == DIFFE_TYPE::DUP_NONEED) { + ArgTypes.push_back(getShadowType(ITy, width)); + } else if (act == DIFFE_TYPE::OUT_DIFF) { + RetTypes.push_back(getShadowType(ITy, width)); } - ++argno; } - // TODO: Expand for multiple returns - if (diffeReturnArg) { - ArgTypes.push_back(getShadowType(FTy.getResult(0), width)); + for (auto &&[Ty, activity] : llvm::zip(FTy.getResults(), ReturnActivity)) { + if (activity == DIFFE_TYPE::OUT_DIFF) { + ArgTypes.push_back(getShadowType(Ty, width)); + } } + if (additionalArg) { ArgTypes.push_back(additionalArg); } - OpBuilder builder(FTy.getContext()); - if (returnValue == ReturnType::TapeAndTwoReturns || - returnValue == ReturnType::TapeAndReturn) { - RetTypes.insert(RetTypes.begin(), - LLVM::LLVMPointerType::get(FTy.getContext())); - } else if (returnValue == ReturnType::Tape) { - for (auto I : FTy.getInputs()) { - RetTypes.push_back(I); - } - } - // Create a new function type... + OpBuilder builder(FTy.getContext()); return builder.getFunctionType(ArgTypes, RetTypes); } @@ -205,19 +182,21 @@ void cloneInto(Region *src, Region *dest, Region::iterator destPos, FunctionOpInterface CloneFunctionWithReturns( DerivativeMode mode, unsigned width, FunctionOpInterface F, - IRMapping &ptrInputs, ArrayRef constant_args, + IRMapping &ptrInputs, ArrayRef ArgActivity, SmallPtrSetImpl &constants, SmallPtrSetImpl &nonconstants, - SmallPtrSetImpl &returnvals, ReturnType returnValue, - DIFFE_TYPE DReturnType, Twine name, IRMapping &VMap, - std::map &OpMap, bool diffeReturnArg, + SmallPtrSetImpl &returnvals, + const std::vector &returnPrimals, + const std::vector &returnShadows, + ArrayRef RetActivity, Twine name, IRMapping &VMap, + std::map &OpMap, mlir::Type additionalArg) { assert(!F.getFunctionBody().empty()); // F = preprocessForClone(F, mode); // llvm::ValueToValueMapTy VMap; auto FTy = getFunctionTypeForClone( F.getFunctionType().cast(), mode, width, - additionalArg, constant_args, diffeReturnArg, returnValue, DReturnType); + additionalArg, returnPrimals, returnShadows, RetActivity, ArgActivity); /* for (Block &BB : F.getFunctionBody().getBlocks()) { @@ -245,19 +224,19 @@ FunctionOpInterface CloneFunctionWithReturns( { auto &blk = NewF.getFunctionBody().front(); assert(F.getFunctionBody().front().getNumArguments() == - constant_args.size()); - for (ssize_t i = constant_args.size() - 1; i >= 0; i--) { + ArgActivity.size()); + for (ssize_t i = ArgActivity.size() - 1; i >= 0; i--) { mlir::Value oval = F.getFunctionBody().front().getArgument(i); - if (constant_args[i] == DIFFE_TYPE::CONSTANT) + if (ArgActivity[i] == DIFFE_TYPE::CONSTANT) constants.insert(oval); - else if (constant_args[i] == DIFFE_TYPE::OUT_DIFF) + else if (ArgActivity[i] == DIFFE_TYPE::OUT_DIFF) nonconstants.insert(oval); - else if (constant_args[i] == DIFFE_TYPE::DUP_ARG || - constant_args[i] == DIFFE_TYPE::DUP_NONEED) { + else if (ArgActivity[i] == DIFFE_TYPE::DUP_ARG || + ArgActivity[i] == DIFFE_TYPE::DUP_NONEED) { nonconstants.insert(oval); mlir::Value val = blk.getArgument(i); mlir::Value dval; - if (i == constant_args.size() - 1) + if (i == ArgActivity.size() - 1) dval = blk.addArgument(val.getType(), val.getLoc()); else dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(), @@ -265,11 +244,11 @@ FunctionOpInterface CloneFunctionWithReturns( ptrInputs.map(oval, dval); } } - // TODO: Add support for mulitple outputs? - if (diffeReturnArg) { - auto location = blk.getArgument(blk.getNumArguments() - 1).getLoc(); - auto val = F.getFunctionType().cast().getResult(0); - blk.addArgument(val, location); + for (auto &&[Ty, activity] : llvm::zip(F.getFunctionType().cast().getResults(), RetActivity)) { + if (activity == DIFFE_TYPE::OUT_DIFF) { + auto location = blk.getArgument(blk.getNumArguments() - 1).getLoc(); + blk.addArgument(getShadowType(Ty, width), location); + } } } @@ -285,15 +264,7 @@ FunctionOpInterface CloneFunctionWithReturns( size_t oldi = 0; size_t newi = 0; while (oldi < F.getNumResults()) { - bool primalReturn = returnValue == ReturnType::ArgsWithReturn || - returnValue == ReturnType::ArgsWithTwoReturns || - (returnValue == ReturnType::TapeAndReturn && - DReturnType == DIFFE_TYPE::CONSTANT) || - returnValue == ReturnType::TapeAndTwoReturns || - returnValue == ReturnType::TwoReturns || - (returnValue == ReturnType::Return && - DReturnType == DIFFE_TYPE::CONSTANT); - if (primalReturn) { + if (returnPrimals[oldi]) { for (auto attrName : ToClone) { auto attrNameS = StringAttr::get(F->getContext(), attrName); NewF.removeResultAttr(newi, attrNameS); @@ -310,8 +281,7 @@ FunctionOpInterface CloneFunctionWithReturns( } newi++; } - if (DReturnType == DIFFE_TYPE::DUP_ARG || - DReturnType == DIFFE_TYPE::DUP_NONEED) { + if (returnShadows[oldi]) { for (auto attrName : ToClone) { auto attrNameS = StringAttr::get(F->getContext(), attrName); NewF.removeResultAttr(newi, attrNameS); @@ -350,8 +320,8 @@ FunctionOpInterface CloneFunctionWithReturns( } newi++; - if (constant_args[oldi] == DIFFE_TYPE::DUP_ARG || - constant_args[oldi] == DIFFE_TYPE::DUP_NONEED) { + if (ArgActivity[oldi] == DIFFE_TYPE::DUP_ARG || + ArgActivity[oldi] == DIFFE_TYPE::DUP_NONEED) { for (auto attrName : ToClone) { NewF.removeArgAttr(newi, attrName); diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.h b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.h index 2766f53357b8..0b58a1542eee 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.h +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.h @@ -26,8 +26,11 @@ Type getShadowType(Type type, unsigned width = 1); mlir::FunctionType getFunctionTypeForClone( mlir::FunctionType FTy, DerivativeMode mode, unsigned width, - mlir::Type additionalArg, llvm::ArrayRef constant_args, - bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE ReturnType); + mlir::Type additionalArg, + llvm::ArrayRef returnPrimals, + llvm::ArrayRef returnShadows, + llvm::ArrayRef ReturnActivity, + llvm::ArrayRef ArgActivity); void cloneInto(Region *src, Region *dest, Region::iterator destPos, IRMapping &mapper, std::map &opMap); @@ -41,10 +44,12 @@ Operation *clone(Operation *src, IRMapping &mapper, FunctionOpInterface CloneFunctionWithReturns( DerivativeMode mode, unsigned width, FunctionOpInterface F, - IRMapping &ptrInputs, ArrayRef constant_args, + IRMapping &ptrInputs, ArrayRef ArgActivity, SmallPtrSetImpl &constants, SmallPtrSetImpl &nonconstants, - SmallPtrSetImpl &returnvals, ReturnType returnValue, - DIFFE_TYPE ReturnType, Twine name, IRMapping &VMap, - std::map &OpMap, bool diffeReturnArg, - mlir::Type additionalArg); \ No newline at end of file + SmallPtrSetImpl &returnvals, + const std::vector &returnPrimals, + const std::vector &returnShadows, + ArrayRef ReturnActivity, Twine name, IRMapping &VMap, + std::map &OpMap, + mlir::Type additionalArg); diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index ead8baad9261..34068ef61d61 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -21,8 +21,8 @@ using namespace mlir; using namespace mlir::enzyme; -void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, - DIFFE_TYPE retType, ReturnType retVal) { +void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, const std::vector &returnPrimals, + const std::vector &returnShadows) { auto inst = oBB->getTerminator(); mlir::Block *nBB = gutils->getNewFromOriginal(inst->getBlock()); @@ -42,67 +42,22 @@ void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, return; SmallVector retargs; - - switch (retVal) { - case ReturnType::Return: { - for (size_t i = 0; i < inst->getNumOperands(); i++) { - auto ret = inst->getOperand(i); - - mlir::Value toret; - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!isa(ret.getType()) && - true /*type analysis*/) { - toret = gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - toret = gutils->invertPointerM(ret, nBuilder); - } else { - Type retTy = - ret.getType().cast().getShadowType(); - toret = retTy.cast().createNullValue( - nBuilder, ret.getLoc()); - } - retargs.push_back(toret); - } - - break; - } - case ReturnType::TwoReturns: { - if (retType == DIFFE_TYPE::CONSTANT) - assert(false && "Invalid return type"); - for (size_t i = 0; i < inst->getNumOperands(); i++) { - auto ret = inst->getOperand(i); - + + for (auto &&[ret, returnPrimal, returnShadow] : llvm::zip(inst->getOperands(), returnPrimals, returnShadows)) { + if (returnPrimal) { retargs.push_back(gutils->getNewFromOriginal(ret)); - - mlir::Value toret; - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!isa(ret.getType()) && - true /*type analysis*/) { - toret = gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - toret = gutils->invertPointerM(ret, nBuilder); + } + if (returnShadow) { + if (!gutils->isConstantValue(ret)) { + retargs.push_back(gutils->invertPointerM(ret, nBuilder)); } else { Type retTy = ret.getType().cast().getShadowType(); - toret = retTy.cast().createNullValue( + auto toret = retTy.cast().createNullValue( nBuilder, ret.getLoc()); + retargs.push_back(gutils->invertPointerM(toret, nBuilder)); } - retargs.push_back(toret); } - break; - } - case ReturnType::Void: { - break; - } - default: { - llvm::errs() << "Invalid return type: " - << "for function: \n" - << gutils->newFunc << "\n"; - assert(false && "Invalid return type for function"); - return; - } } nBB->push_back( @@ -117,34 +72,36 @@ void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, //===----------------------------------------------------------------------===// FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( - FunctionOpInterface fn, DIFFE_TYPE retType, - std::vector constants, MTypeAnalysis &TA, bool returnUsed, + FunctionOpInterface fn, std::vector RetActivity, + std::vector ArgActivity, MTypeAnalysis &TA, std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; llvm_unreachable("Differentiating empty function"); } - assert(fn.getFunctionBody().front().getNumArguments() == constants.size()); + assert(fn.getFunctionBody().front().getNumArguments() == ArgActivity.size()); assert(fn.getFunctionBody().front().getNumArguments() == volatile_args.size()); MForwardCacheKey tup = { - fn, retType, constants, + fn, RetActivity, ArgActivity, // std::map(_uncacheable_args.begin(), // _uncacheable_args.end()), - returnUsed, mode, static_cast(width), addedType, type_args}; + returnPrimals, mode, static_cast(width), addedType, type_args}; if (ForwardCachedFunctions.find(tup) != ForwardCachedFunctions.end()) { return ForwardCachedFunctions.find(tup)->second; } - bool retActive = retType != DIFFE_TYPE::CONSTANT; - ReturnType returnValue = - returnUsed ? (retActive ? ReturnType::TwoReturns : ReturnType::Return) - : (retActive ? ReturnType::Return : ReturnType::Void); + std::vector returnShadows; + for (auto act : RetActivity) { + returnShadows.push_back(act != DIFFE_TYPE::CONSTANT); + } auto gutils = MDiffeGradientUtils::CreateFromClone( - *this, mode, width, fn, TA, type_args, retType, - /*diffeReturnArg*/ false, constants, returnValue, addedType, + *this, mode, width, fn, TA, type_args, + returnPrimals, returnShadows, + RetActivity, + ArgActivity, addedType, /*omp*/ false); ForwardCachedFunctions[tup] = gutils->newFunc; @@ -208,7 +165,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( valid &= res.succeeded(); } - createTerminator(gutils, &oBB, retType, returnValue); + createTerminator(gutils, &oBB, returnPrimals, returnShadows); } // if (mode == DerivativeMode::ForwardModeSplit && augmenteddata) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index 0005377cc5d7..7ef985ddcc60 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -47,10 +47,10 @@ class MEnzymeLogic { public: struct MForwardCacheKey { FunctionOpInterface todiff; - DIFFE_TYPE retType; + const std::vector retType; const std::vector constant_args; // std::map uncacheable_args; - bool returnUsed; + std::vector returnUsed; DerivativeMode mode; unsigned width; mlir::Type additionalType; @@ -108,16 +108,18 @@ class MEnzymeLogic { std::map ForwardCachedFunctions; FunctionOpInterface - CreateForwardDiff(FunctionOpInterface fn, DIFFE_TYPE retType, + CreateForwardDiff(FunctionOpInterface fn, std::vector retType, std::vector constants, MTypeAnalysis &TA, - bool returnUsed, DerivativeMode mode, bool freeMemory, + std::vector returnPrimals, + DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector volatile_args, void *augmented); FunctionOpInterface - CreateReverseDiff(FunctionOpInterface fn, DIFFE_TYPE retType, + CreateReverseDiff(FunctionOpInterface fn, std::vector retType, std::vector constants, MTypeAnalysis &TA, - bool returnUsed, DerivativeMode mode, bool freeMemory, + std::vector returnPrimals, std::vector returnShadows, + DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector volatile_args, void *augmented); void diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 96d5a65a884b..bc96fbc18fdd 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -178,8 +178,9 @@ LogicalResult MEnzymeLogic::differentiate( } FunctionOpInterface MEnzymeLogic::CreateReverseDiff( - FunctionOpInterface fn, DIFFE_TYPE retType, - std::vector constants, MTypeAnalysis &TA, bool returnUsed, + FunctionOpInterface fn, std::vector retType, + std::vector constants, MTypeAnalysis &TA, std::vector returnPrimals, + std::vector returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { @@ -188,10 +189,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( llvm_unreachable("Differentiating empty function"); } - ReturnType returnValue = ReturnType::Args; MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone( - *this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true, - constants, returnValue, addedType); + *this, mode, width, fn, TA, type_args, returnPrimals, returnShadows, retType, constants, addedType); Region &oldRegion = gutils->oldFunc.getFunctionBody(); Region &newRegion = gutils->newFunc.getFunctionBody(); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index ebae44c9efa1..7ab1567b4413 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -32,7 +32,7 @@ mlir::enzyme::MGradientUtils::MGradientUtils( FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, MTypeResults TR_, IRMapping &invertedPointers_, const SmallPtrSetImpl &constantvalues_, - const SmallPtrSetImpl &activevals_, DIFFE_TYPE ReturnActivity, + const SmallPtrSetImpl &activevals_, ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp) @@ -42,54 +42,7 @@ mlir::enzyme::MGradientUtils::MGradientUtils( activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), - RetDiffeTypes(1, ReturnActivity) { - - /* - for (BasicBlock &BB : *oldFunc) { - for (Instruction &I : BB) { - if (auto CI = dyn_cast(&I)) { - originalCalls.push_back(CI); - } - } - } - */ - - /* - for (BasicBlock &oBB : *oldFunc) { - for (Instruction &oI : oBB) { - newToOriginalFn[originalToNewFn[&oI]] = &oI; - } - newToOriginalFn[originalToNewFn[&oBB]] = &oBB; - } - for (Argument &oArg : oldFunc->args()) { - newToOriginalFn[originalToNewFn[&oArg]] = &oArg; - } - */ - /* - for (BasicBlock &BB : *newFunc) { - originalBlocks.emplace_back(&BB); - } - tape = nullptr; - tapeidx = 0; - assert(originalBlocks.size() > 0); - - SmallVector ReturningBlocks; - for (BasicBlock &BB : *oldFunc) { - if (isa(BB.getTerminator())) - ReturningBlocks.push_back(&BB); - } - for (BasicBlock &BB : *oldFunc) { - bool legal = true; - for (auto BRet : ReturningBlocks) { - if (!(BRet == &BB || OrigDT.dominates(&BB, BRet))) { - legal = false; - break; - } - } - if (legal) - BlocksDominatingAllReturns.insert(&BB); - } - */ + RetDiffeTypes(ReturnActivity) { } mlir::Value mlir::enzyme::MGradientUtils::getNewFromOriginal( diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index f6f1861039eb..3c4269d77525 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -39,7 +39,7 @@ class MGradientUtils { unsigned width; ArrayRef ArgDiffeTypes; - SmallVector RetDiffeTypes; + ArrayRef RetDiffeTypes; mlir::Value getNewFromOriginal(const mlir::Value originst) const; mlir::Block *getNewFromOriginal(mlir::Block *originst) const; @@ -50,7 +50,7 @@ class MGradientUtils { MTypeResults TR_, IRMapping &invertedPointers_, const SmallPtrSetImpl &constantvalues_, const SmallPtrSetImpl &activevals_, - DIFFE_TYPE ReturnActivity, ArrayRef ArgDiffeTypes_, + ArrayRef ReturnActivities, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp); @@ -103,14 +103,14 @@ class MDiffeGradientUtils : public MGradientUtils { MTypeResults TR, IRMapping &invertedPointers_, const SmallPtrSetImpl &constantvalues_, const SmallPtrSetImpl &activevals_, - DIFFE_TYPE ActiveReturn, - ArrayRef constant_values, + ArrayRef RetActivity, + ArrayRef ArgActivity, IRMapping &origToNew_, std::map &origToNewOps_, DerivativeMode mode, unsigned width, bool omp) : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, - constantvalues_, activevals_, ActiveReturn, - constant_values, origToNew_, origToNewOps_, mode, width, + constantvalues_, activevals_, RetActivity, + ArgActivity, origToNew_, origToNewOps_, mode, width, omp), initializationBlock(&*(newFunc.getFunctionBody().begin())) {} @@ -118,9 +118,12 @@ class MDiffeGradientUtils : public MGradientUtils { static MDiffeGradientUtils * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, - MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, - bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg, bool omp) { + MFnTypeInfo &oldTypeInfo, + const std::vector &returnPrimals, + const std::vector &returnShadows, + ArrayRef RetActivity, + ArrayRef ArgActivity, + mlir::Type additionalArg, bool omp) { std::string prefix; switch (mode) { @@ -147,14 +150,14 @@ class MDiffeGradientUtils : public MGradientUtils { SmallPtrSet nonconstant_values; IRMapping invertedPointers; FunctionOpInterface newFunc = CloneFunctionWithReturns( - mode, width, todiff, invertedPointers, constant_args, constant_values, - nonconstant_values, returnvals, returnValue, retType, + mode, width, todiff, invertedPointers, ArgActivity, constant_values, + nonconstant_values, returnvals, returnPrimals, returnShadows, RetActivity, prefix + todiff.getName(), originalToNew, originalToNewOps, - diffeReturnArg, additionalArg); + additionalArg); MTypeResults TR; // TODO return new MDiffeGradientUtils( Logic, newFunc, todiff, TA, TR, invertedPointers, constant_values, - nonconstant_values, retType, constant_args, originalToNew, + nonconstant_values, RetActivity, ArgActivity, originalToNew, originalToNewOps, mode, width, omp); } }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index b57fbe68b594..4aa8e82789d7 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -32,7 +32,7 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, IRMapping invertedPointers_, const SmallPtrSetImpl &constantvalues_, - const SmallPtrSetImpl &activevals_, DIFFE_TYPE ReturnActivity, + const SmallPtrSetImpl &activevals_, ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, DerivativeMode mode_, unsigned width) @@ -133,8 +133,10 @@ void MGradientUtilsReverse::createReverseModeBlocks(Region &oldFunc, MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, - DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg) { + const std::vector &returnPrimals, + const std::vector &returnShadows, + ArrayRef retType, ArrayRef constant_args, + mlir::Type additionalArg) { std::string prefix; switch (mode_) { @@ -162,9 +164,9 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( IRMapping invertedPointers; FunctionOpInterface newFunc = CloneFunctionWithReturns( mode_, width, todiff, invertedPointers, constant_args, constant_values, - nonconstant_values, returnvals, returnValue, retType, + nonconstant_values, returnvals, returnPrimals, returnShadows, retType, prefix + todiff.getName(), originalToNew, originalToNewOps, - diffeReturnArg, additionalArg); + additionalArg); return new MGradientUtilsReverse(Logic, newFunc, todiff, TA, invertedPointers, constant_values, nonconstant_values, retType, diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index 96e899939538..8a17c881af98 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -27,7 +27,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { IRMapping invertedPointers_, const SmallPtrSetImpl &constantvalues_, const SmallPtrSetImpl &activevals_, - DIFFE_TYPE ReturnActivity, + ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, @@ -62,9 +62,12 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, - MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, - bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg); + MFnTypeInfo &oldTypeInfo, + const std::vector &returnPrimals, + const std::vector &returnShadows, + llvm::ArrayRef retType, + llvm::ArrayRef constant_args, + mlir::Type additionalArg); }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index b7d33b6faedc..41773e909312 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -31,15 +31,20 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; - static DIFFE_TYPE mode_from_fn(FunctionOpInterface fn, DerivativeMode mode) { - DIFFE_TYPE retType = DIFFE_TYPE::CONSTANT; - if (fn.getNumResults() != 0) { + static std::vector mode_from_fn(FunctionOpInterface fn, DerivativeMode mode) { + std::vector retTypes; + for (auto ty : fn.getResultTypes()) { + if (isa(ty)) { + retTypes.push_back(DIFFE_TYPE::CONSTANT); + continue; + } + if (mode == DerivativeMode::ReverseModeCombined) - retType = DIFFE_TYPE::OUT_DIFF; + retTypes.push_back(DIFFE_TYPE::OUT_DIFF); else - retType = DIFFE_TYPE::DUP_ARG; + retTypes.push_back(DIFFE_TYPE::DUP_ARG); } - return retType; + return retTypes; } template @@ -72,7 +77,7 @@ struct DifferentiatePass : public DifferentiatePassBase { auto fn = cast(symbolOp); auto mode = DerivativeMode::ForwardMode; - DIFFE_TYPE retType = mode_from_fn(fn, mode); + std::vector retType = mode_from_fn(fn, mode); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); @@ -85,9 +90,15 @@ struct DifferentiatePass : public DifferentiatePassBase { volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined)); } + std::vector returnPrimals; + for (auto act : retType) { + (void)act; + returnPrimals.push_back(false); + } + FunctionOpInterface newFunc = Logic.CreateForwardDiff( fn, retType, constants, TA, - /*should return*/ false, mode, freeMemory, width, + returnPrimals, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); if (!newFunc) @@ -107,36 +118,39 @@ struct DifferentiatePass : public DifferentiatePassBase { std::vector constants; SmallVector args; - size_t truei = 0; - auto activityAttr = CI.getActivity(); - - for (unsigned i = 0; i < CI.getInputs().size() - 1; ++i) { - mlir::Value res = CI.getInputs()[i]; - - auto mop = activityAttr[truei]; - auto iattr = cast(mop); + size_t call_idx=0; + { + for (auto act : CI.getActivity()) { + mlir::Value res = CI.getInputs()[call_idx]; + ++call_idx; + + auto iattr = cast(act); DIFFE_TYPE ty = (DIFFE_TYPE)(iattr.getValue()); constants.push_back(ty); args.push_back(res); if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) { - ++i; - res = CI.getInputs()[i]; + res = CI.getInputs()[call_idx]; + ++call_idx; args.push_back(res); } - - truei++; + } } auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); auto mode = DerivativeMode::ReverseModeCombined; - DIFFE_TYPE retType = mode_from_fn(fn, mode); + std::vector retType = mode_from_fn(fn, mode); // Add the return gradient - mlir::Value res = CI.getInputs()[CI.getInputs().size() - 1]; + for (auto act : retType) { + if (act == DIFFE_TYPE::OUT_DIFF) { + mlir::Value res = CI.getInputs()[call_idx]; + call_idx++; args.push_back(res); + } + } MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); @@ -144,14 +158,18 @@ struct DifferentiatePass : public DifferentiatePassBase { size_t width = 1; std::vector volatile_args; + std::vector returnPrimals; + std::vector returnShadows; for (auto &a : fn.getFunctionBody().getArguments()) { (void)a; volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined)); + returnPrimals.push_back(false); + returnShadows.push_back(false); } FunctionOpInterface newFunc = Logic.CreateReverseDiff( fn, retType, constants, TA, - /*should return*/ false, mode, freeMemory, width, + returnPrimals, returnShadows, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); if (!newFunc) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index b48705c220d1..782af01ebdc6 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -28,6 +28,29 @@ using namespace mlir; using namespace mlir::enzyme; using namespace enzyme; +std::vector parseActivityString(StringRef inp) { + std::vector ArgActivity; + SmallVector split; + StringRef(inp.data(), inp.size()) + .split(split, ','); + for (auto &str : split) { + if (str == "enzyme_dup") + ArgActivity.push_back(DIFFE_TYPE::DUP_ARG); + else if (str == "enzyme_const") + ArgActivity.push_back(DIFFE_TYPE::CONSTANT); + else if (str == "enzyme_dupnoneed") + ArgActivity.push_back(DIFFE_TYPE::DUP_NONEED); + else if (str == "enzyme_out") + ArgActivity.push_back(DIFFE_TYPE::OUT_DIFF); + else { + llvm::errs() << "unknown activity to parse, found: '" << str + << "'\n"; + assert(0 && " unknown constant"); + } + } + return ArgActivity; +} + namespace { struct DifferentiateWrapperPass : public DifferentiateWrapperPassBase { @@ -51,34 +74,30 @@ struct DifferentiateWrapperPass } } auto fn = cast(symbolOp); - SmallVector split; - StringRef(argTys.getValue().data(), argTys.getValue().size()) - .split(split, ','); - std::vector constants; - for (auto &str : split) { - if (str == "enzyme_dup") - constants.push_back(DIFFE_TYPE::DUP_ARG); - else if (str == "enzyme_const") - constants.push_back(DIFFE_TYPE::CONSTANT); - else if (str == "enzyme_dupnoneed") - constants.push_back(DIFFE_TYPE::DUP_NONEED); - else if (str == "enzyme_out") - constants.push_back(DIFFE_TYPE::OUT_DIFF); - else { - llvm::errs() << "unknown argument activity to parse, found: '" << str - << "'\n"; - assert(0 && " unknown constant"); - } - } + + std::vector ArgActivity = parseActivityString(argTys.getValue()); - if (constants.size() != fn.getFunctionBody().front().getNumArguments()) { + if (ArgActivity.size() != fn.getFunctionBody().front().getNumArguments()) { fn->emitError() << "Incorrect number of arg activity states for function, found " - << split; + << ArgActivity.size() << " expected " << fn.getFunctionBody().front().getNumArguments(); return; } - DIFFE_TYPE retType = retTy.getValue(); + std::vector RetActivity = parseActivityString(retTys.getValue()); + if (RetActivity.size() != fn.getFunctionType().cast().getNumResults()) { + fn->emitError() + << "Incorrect number of ret activity states for function, found " + << RetActivity.size() <<" expected " << fn.getFunctionType().cast().getNumResults(); + return; + } + std::vector returnPrimal; + std::vector returnShadow; + for (auto act : RetActivity) { + returnPrimal.push_back(act == DIFFE_TYPE::DUP_ARG); + returnShadow.push_back(false); + } + MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); @@ -94,15 +113,15 @@ struct DifferentiateWrapperPass FunctionOpInterface newFunc; if (mode == DerivativeMode::ForwardMode) { newFunc = Logic.CreateForwardDiff( - fn, retType, constants, TA, - /*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory, + fn, RetActivity, ArgActivity, TA, + returnPrimal, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); } else { newFunc = Logic.CreateReverseDiff( - fn, retType, constants, TA, - /*should return*/ false, mode, freeMemory, width, + fn, RetActivity, ArgActivity, TA, + returnPrimal, returnShadow, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); } diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index c1617eaf4af0..e9d31a8c7b3b 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -55,17 +55,11 @@ def DifferentiateWrapperPass : Pass<"enzyme-wrap"> { )}] >, Option< - /*C++ variable name=*/"retTy", - /*CLI argument=*/"retTy", - /*type=*/"DIFFE_TYPE", - /*default=*/"DIFFE_TYPE::DUP_ARG", - /*description=*/"activity of the return", -[{::llvm::cl::values( - clEnumValN(DIFFE_TYPE::DUP_ARG, "enzyme_dup", "Duplicated (default)"), - clEnumValN(DIFFE_TYPE::OUT_DIFF, "enzyme_out", "Active"), - clEnumValN(DIFFE_TYPE::CONSTANT, "enzyme_const", "Constant"), - clEnumValN(DIFFE_TYPE::DUP_NONEED, "enzyme_dupnoneed", "Duplicated noneed") - )}] + /*C++ variable name=*/"retTys", + /*CLI argument=*/"retTys", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"The activity of the returns" >, Option< /*C++ variable name=*/"argTys", diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index f86535c1d82d..6d08cd761ac6 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -148,10 +148,13 @@ struct PrintActivityAnalysisPass else activevals_.insert(arg); } - auto ReturnActivity = DIFFE_TYPE::CONSTANT; - for (auto act : resultActivities) + SmallVector ReturnActivity; + for (auto act : resultActivities) { if (act != enzyme::Activity::enzyme_const) - ReturnActivity = DIFFE_TYPE::DUP_ARG; + ReturnActivity.push_back(DIFFE_TYPE::DUP_ARG); + else + ReturnActivity.push_back(DIFFE_TYPE::CONSTANT); + } enzyme::ActivityAnalyzer activityAnalyzer( blocksNotForAnalysis, constant_values, activevals_, ReturnActivity); diff --git a/enzyme/test/MLIR/ForwardMode/wrap.mlir b/enzyme/test/MLIR/ForwardMode/wrap.mlir index 5ff0e1540d13..13c58e7c3434 100644 --- a/enzyme/test/MLIR/ForwardMode/wrap.mlir +++ b/enzyme/test/MLIR/ForwardMode/wrap.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTy=enzyme_dup argTys=enzyme_dup mode=ForwardMode" %s | FileCheck %s +// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTys=enzyme_dup argTys=enzyme_dup mode=ForwardMode" %s | FileCheck %s module { func.func @square(%x : f64) -> f64{ diff --git a/enzyme/test/MLIR/ReverseMode/multiret.mlir b/enzyme/test/MLIR/ReverseMode/multiret.mlir new file mode 100644 index 000000000000..82e678390076 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/multiret.mlir @@ -0,0 +1,27 @@ +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s + +module { + func.func @square(%x: f64, %y : i32, %z : f32) -> (f64, i32, f32) { + %x2 = arith.mulf %x, %x : f64 + %y2 = arith.muli %y, %y : i32 + %z2 = arith.mulf %z, %z : f32 + return %x2, %y2, %z2 : f64, i32, f32 + } + + func.func @dsquare(%x: f64, %y : i32, %z : f32, %dx: f64, %dz : f32) -> (f64, f32) { + %r:2 = enzyme.autodiff @square(%x, %y, %z, %dx, %dz) { activity=[#enzyme, #enzyme, #enzyme] } : (f64, i32, f32, f64, f32) -> (f64, f32) + return %r#0, %r#1 : f64, f32 + } +} + +// CHECK: func.func @dsquare(%arg0: f64, %arg1: i32, %arg2: f32, %arg3: f64, %arg4: f32) -> (f64, f32) { +// CHECK-NEXT: %0:2 = call @diffesquare(%arg0, %arg1, %arg2, %arg3, %arg4) : (f64, i32, f32, f64, f32) -> (f64, f32) +// CHECK-NEXT: return %0#0, %0#1 : f64, f32 +// CHECK-NEXT: } +// CHECK: func.func private @diffesquare(%arg0: f64, %arg1: i32, %arg2: f32, %arg3: f64, %arg4: f32) -> (f64, f32) { +// CHECK-NEXT: %0 = arith.mulf %arg4, %arg2 : f32 +// CHECK-NEXT: %1 = arith.addf %0, %0 : f32 +// CHECK-NEXT: %2 = arith.mulf %arg3, %arg0 : f64 +// CHECK-NEXT: %3 = arith.addf %2, %2 : f64 +// CHECK-NEXT: return %3, %1 : f64, f32 +// CHECK-NEXT: }