From 4ad395373e8800d2f98f69e47503454272fb89f4 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Fri, 5 Jan 2024 17:54:45 +0100 Subject: [PATCH 001/131] [mlir] correctly handle "unknown" state in activity analysis (#1571) This requires injecting ModRef information about library functions. --- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp | 54 ++++++++++++++-- .../Analysis/DataFlowActivityAnalysis.cpp | 62 ++++++++++++++----- 2 files changed, 93 insertions(+), 23 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp index ac302c6a9ab5..691be34eb25c 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp @@ -456,6 +456,51 @@ static bool mayWriteArg(FunctionOpInterface callee, unsigned argNo, return !hasReadOnlyAttr && !hasReadNoneAttr && funcMayWrite; } +/// Returns information indicating whether the function may read or write into +/// the memory pointed to by its arguments. When unknown, returns `nullopt`. +static std::optional +getFunctionArgModRef(FunctionOpInterface func) { + // First, handle some library functions with statically known behavior. + StringRef name = cast(func.getOperation()).getName(); + auto hardcoded = llvm::StringSwitch>(name) + // printf: only reads from arguments. + .Case("printf", LLVM::ModRefInfo::Ref) + // operator delete(void *) doesn't read from arguments. + .Case("_ZdlPv", LLVM::ModRefInfo::NoModRef) + .Default(std::nullopt); + if (hardcoded) + return hardcoded; + + if (auto memoryAttr = + func->getAttrOfType(kLLVMMemoryAttrName)) + return memoryAttr.getArgMem(); + return std::nullopt; +} + +/// Returns information indicating whether the function may read or write into +/// the memory other than that pointed to by its arguments, though still +/// accessible from (any) calling context. When unknown, returns `nullopt`. 
+static std::optional +getFunctionOtherModRef(FunctionOpInterface func) { + // First, handle some library functions with statically known behavior. + StringRef name = cast(func.getOperation()).getName(); + auto hardcoded = + llvm::StringSwitch>(name) + // printf: doesn't access other (technically, stdout is pointer-like, + // but we cannot flow information through it since it is write-only. + .Case("printf", LLVM::ModRefInfo::NoModRef) + // operator delete(void *) doesn't access other. + .Case("_ZdlPv", LLVM::ModRefInfo::NoModRef) + .Default(std::nullopt); + if (hardcoded) + return hardcoded; + + if (auto memoryAttr = + func->getAttrOfType(kLLVMMemoryAttrName)) + return memoryAttr.getOther(); + return std::nullopt; +} + void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( CallOpInterface call, CallControlFlowAction action, const PointsToSets &before, PointsToSets *after) { @@ -497,13 +542,10 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( // into pointers that are non-arguments. if (auto callee = SymbolTable::lookupNearestSymbolFrom( call, symbol.getLeafReference())) { - auto memoryAttr = - callee->getAttrOfType(kLLVMMemoryAttrName); - std::optional argModRef = - memoryAttr ? std::make_optional(memoryAttr.getArgMem()) - : std::nullopt; + std::optional argModRef = getFunctionArgModRef(callee); std::optional otherModRef = - memoryAttr ? 
std::make_optional(memoryAttr.getOther()) : std::nullopt; + getFunctionOtherModRef(callee); + SmallVector pointerLikeOperands; for (auto &&[i, operand] : llvm::enumerate(call.getArgOperands())) { if (isPointerLike(operand.getType())) diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index 01fb4a26046a..db04f280123e 100644 --- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -794,6 +794,20 @@ void traverseCallGraph(FunctionOpInterface root, } } +static const enzyme::AliasClassSet & +getDefaultPointsTo(const enzyme::PointsToSets &pointsToSets) { + // Get the default points-to alias class set, which is where the + // "unknown" and any other unlisted class set points to. + const enzyme::AliasClassSet &defaultPointsTo = + pointsToSets.getPointsTo(nullptr); + // Unknown class can point to unknown or nothing, unless further + // refined. + assert((defaultPointsTo.isUnknown() || + defaultPointsTo.getAliasClasses().empty()) && + "new case introduced for AliasClassSet?"); + return defaultPointsTo; +} + void printActivityAnalysisResults(const DataFlowSolver &solver, FunctionOpInterface callee, const SmallPtrSet &returnOps, @@ -815,19 +829,33 @@ void printActivityAnalysisResults(const DataFlowSolver &solver, auto *bma = solver.lookupState( &callee.getFunctionBody().front().front()); - auto *pointsToSets = + const enzyme::PointsToSets *pointsToSets = solver.lookupState(*returnOps.begin()); auto *aliasClassLattice = solver.lookupState(value); // Traverse the points-to sets in a simple BFS std::deque frontier; DenseSet visited; - // TODO(zinenko): FIXME, handle unknown... 
- if (!aliasClassLattice->isUnknown()) { + auto scheduleVisit = [&](auto range) { + for (DistinctAttr neighbor : range) { + if (!visited.contains(neighbor)) { + visited.insert(neighbor); + frontier.push_back(neighbor); + } + } + }; + + if (aliasClassLattice->isUnknown()) { + // If this pointer is in unknown alias class, it may point to active + // data if the unknown alias class is known to point to something and + // may not point to active data if the unknown alias class is known not + // to point to anything. + auto &defaultPointsTo = getDefaultPointsTo(*pointsToSets); + return !defaultPointsTo.isUnknown() && + defaultPointsTo.getAliasClasses().empty(); + } else { const DenseSet &aliasClasses = aliasClassLattice->getAliasClasses(); - frontier.insert(frontier.end(), aliasClasses.begin(), - aliasClasses.end()); - visited.insert(aliasClasses.begin(), aliasClasses.end()); + scheduleVisit(aliasClasses); } while (!frontier.empty()) { DistinctAttr aliasClass = frontier.front(); @@ -841,19 +869,19 @@ void printActivityAnalysisResults(const DataFlowSolver &solver, // Or if it points to a pointer that points to active data. if (pointsToSets->getPointsTo(aliasClass).isUnknown()) { - // TODO(zinenko): FIXME handle unknown. Conservative assumption here - // is to assume the value is active (or unknown if we can return - // that). Is there a less conservative option? + // If a pointer points to an unknown alias set, query the default + // points-to alias set (which also applies to the unknown alias set). + auto &defaultPointsTo = getDefaultPointsTo(*pointsToSets); + // If it is in turn unknown, conservatively assume the pointer may be + // pointing to some active data. + if (defaultPointsTo.isUnknown()) + return false; + // Otherwise look at classes pointed to by unknown (which can only be + // an empty set as of time of writing). 
+ scheduleVisit(defaultPointsTo.getAliasClasses()); continue; } - const DenseSet &neighbors = - pointsToSets->getPointsTo(aliasClass).getAliasClasses(); - for (DistinctAttr neighbor : neighbors) { - if (!visited.contains(neighbor)) { - visited.insert(neighbor); - frontier.push_back(neighbor); - } - } + scheduleVisit(pointsToSets->getPointsTo(aliasClass).getAliasClasses()); } // Otherwise, it's constant return true; From b7971718f89e919b72d0c282581361efada9e7e8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 11 Jan 2024 21:04:23 -0500 Subject: [PATCH 002/131] [WIP] Auto truncation mode (#1581) * [WIP] Auto truncation * Add FCmp and cast * Fix cmp * Add cmp test * Add select * clang-format * Weird newlines * Make output reproducable * Handle math intrinsics * Add checks * Fix older LLVM versions * Limit test to llvm > 12 * Util function for intrinsic creation --------- Co-authored-by: Ivan Radanov Ivanov --- enzyme/Enzyme/Enzyme.cpp | 38 ++ enzyme/Enzyme/EnzymeLogic.cpp | 513 +++++++++++++++++++++ enzyme/Enzyme/EnzymeLogic.h | 6 + enzyme/Enzyme/Utils.cpp | 32 ++ enzyme/Enzyme/Utils.h | 7 + enzyme/test/Enzyme/CMakeLists.txt | 1 + enzyme/test/Enzyme/Truncate/CMakeLists.txt | 12 + enzyme/test/Enzyme/Truncate/cmp.ll | 34 ++ enzyme/test/Enzyme/Truncate/intrinsic.ll | 62 +++ enzyme/test/Enzyme/Truncate/select.ll | 39 ++ enzyme/test/Enzyme/Truncate/simple.ll | 42 ++ 11 files changed, 786 insertions(+) create mode 100644 enzyme/test/Enzyme/Truncate/CMakeLists.txt create mode 100644 enzyme/test/Enzyme/Truncate/cmp.ll create mode 100644 enzyme/test/Enzyme/Truncate/intrinsic.ll create mode 100644 enzyme/test/Enzyme/Truncate/select.ll create mode 100644 enzyme/test/Enzyme/Truncate/simple.ll diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index d09682c0aff7..a514a826351b 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1314,6 +1314,33 @@ class EnzymeBase { return type_args; } + bool HandleTruncate(CallInst *CI) { + IRBuilder<> 
Builder(CI); + Function *F = parseFunctionParameter(CI); + if (!F) + return false; + if (CI->arg_size() != 3) { + EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, + "Had incorrect number of args to __enzyme_truncate", *CI, + " - expected 3"); + return false; + } + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto = cast(CI->getArgOperand(2)); + assert(Cto); + RequestContext context(CI, &Builder); + llvm::Value *res = Logic.CreateTruncate( + context, F, (unsigned)Cfrom->getValue().getZExtValue(), + (unsigned)Cto->getValue().getZExtValue()); + if (!res) + return false; + res = Builder.CreatePointerCast(res, CI->getType()); + CI->replaceAllUsesWith(res); + CI->eraseFromParent(); + return true; + } + bool HandleBatch(CallInst *CI) { unsigned width = 1; unsigned truei = 0; @@ -2028,6 +2055,7 @@ class EnzymeBase { Fn->getName().contains("__enzyme_augmentfwd") || Fn->getName().contains("__enzyme_augmentsize") || Fn->getName().contains("__enzyme_reverse") || + Fn->getName().contains("__enzyme_truncate") || Fn->getName().contains("__enzyme_batch") || Fn->getName().contains("__enzyme_trace") || Fn->getName().contains("__enzyme_condition"))) @@ -2060,6 +2088,7 @@ class EnzymeBase { MapVector toVirtual; MapVector toSize; SmallVector toBatch; + SmallVector toTruncate; MapVector toProbProg; SetVector InactiveCalls; SetVector IterCalls; @@ -2369,6 +2398,7 @@ class EnzymeBase { bool virtualCall = false; bool sizeOnly = false; bool batch = false; + bool truncate = false; bool probProg = false; DerivativeMode derivativeMode; ProbProgMode probProgMode; @@ -2398,6 +2428,9 @@ class EnzymeBase { } else if (Fn->getName().contains("__enzyme_batch")) { enableEnzyme = true; batch = true; + } else if (Fn->getName().contains("__enzyme_truncate")) { + enableEnzyme = true; + truncate = true; } else if (Fn->getName().contains("__enzyme_likelihood")) { enableEnzyme = true; probProgMode = ProbProgMode::Likelihood; @@ -2455,6 +2488,8 @@ class EnzymeBase { toSize[CI] = 
derivativeMode; else if (batch) toBatch.push_back(CI); + else if (truncate) + toTruncate.push_back(CI); else if (probProg) { toProbProg[CI] = probProgMode; } else @@ -2548,6 +2583,9 @@ class EnzymeBase { for (auto call : toBatch) { HandleBatch(call); } + for (auto call : toTruncate) { + HandleTruncate(call); + } for (auto &&[call, mode] : toProbProg) { HandleProbProg(call, mode, calls); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 57aef11c087f..5000ed34bcaa 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -29,6 +29,7 @@ //===----------------------------------------------------------------------===// #include "ActivityAnalysis.h" #include "AdjointGenerator.h" +#include "llvm/IR/Intrinsics.h" #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -4815,6 +4816,518 @@ Function *EnzymeLogic::CreateForwardDiff( return nf; } +class TruncateGenerator : public llvm::InstVisitor { +private: + ValueToValueMapTy &originalToNewFn; + unsigned fromwidth; + unsigned towidth; + Function *oldFunc; + Function *newFunc; + AllocaInst *tmpBlock; + EnzymeLogic &Logic; + +public: + TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, + unsigned towidth, Function *oldFunc, Function *newFunc, + EnzymeLogic &Logic) + : originalToNewFn(originalToNewFn), fromwidth(fromwidth), + towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) { + IRBuilder<> B(&newFunc->getEntryBlock().front()); + tmpBlock = B.CreateAlloca(getTypeForWidth(fromwidth)); + } + + void visitInstruction(llvm::Instruction &inst) { + using namespace llvm; + + // TODO explicitly handle all instructions rather than using the catch all + // below + + switch (inst.getOpcode()) { + // #include "InstructionDerivatives.inc" + default: + break; + } + + todo(inst); + } + + Type *getTypeForWidth(unsigned width) { + switch (width) { + default: + return llvm::Type::getIntNTy(oldFunc->getContext(), width); + case 64: + return 
llvm::Type::getDoubleTy(oldFunc->getContext()); + case 32: + return llvm::Type::getFloatTy(oldFunc->getContext()); + case 16: + return llvm::Type::getHalfTy(oldFunc->getContext()); + } + } + + Type *getFromType() { return getTypeForWidth(fromwidth); } + + Type *getToType() { return getTypeForWidth(towidth); } + + Value *truncate(IRBuilder<> &B, Value *v) { + Type *nextType = getTypeForWidth(towidth); + B.CreateStore( + v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); + return B.CreateLoad( + nextType, + B.CreatePointerCast(tmpBlock, PointerType::getUnqual(nextType))); + } + + Value *expand(IRBuilder<> &B, Value *v) { + Type *origT = getFromType(); + auto c0 = Constant::getNullValue( + llvm::Type::getIntNTy(oldFunc->getContext(), fromwidth)); + B.CreateStore(c0, B.CreatePointerCast( + tmpBlock, PointerType::getUnqual(c0->getType()))); + B.CreateStore( + v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); + return B.CreateLoad( + origT, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(origT))); + } + + void todo(llvm::Instruction &I) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "cannot handle unknown instruction\n" << I; + if (CustomErrorHandler) { + IRBuilder<> Builder2(getNewFromOriginal(&I)); + CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate, + this, nullptr, wrap(&Builder2)); + return; + } else { + EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str()); + return; + } + } + + void visitAllocaInst(llvm::AllocaInst &I) { return; } + void visitICmpInst(llvm::ICmpInst &I) { return; } + void visitFCmpInst(llvm::FCmpInst &CI) { + auto newI = getNewFromOriginal(&CI); + IRBuilder<> B(newI); + auto truncLHS = truncate(B, getNewFromOriginal(CI.getOperand(0))); + auto truncRHS = truncate(B, getNewFromOriginal(CI.getOperand(1))); + auto nres = + cast(B.CreateFCmp(CI.getPredicate(), truncLHS, truncRHS)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(nres); + 
newI->eraseFromParent(); + return; + } + void visitLoadInst(llvm::LoadInst &LI) { + auto alignment = LI.getAlign(); + visitLoadLike(LI, alignment); + } + void visitStoreInst(llvm::StoreInst &SI) { + auto align = SI.getAlign(); + visitCommonStore(SI, SI.getPointerOperand(), SI.getValueOperand(), align, + SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), + /*mask=*/nullptr); + } + void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } + void visitPHINode(llvm::PHINode &phi) { return; } + void visitCastInst(llvm::CastInst &CI) { + Value *newCI = nullptr; + auto newI = getNewFromOriginal(&CI); + std::string oldName = CI.getName().str(); + newI->setName(""); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (CI.getDestTy() == getToType()) { + auto newI = getNewFromOriginal(&CI); + IRBuilder<> B(newI); + newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (newCI) { + newI->replaceAllUsesWith(newCI); + newI->eraseFromParent(); + } + return; + } + void visitSelectInst(llvm::SelectInst &SI) { + auto newI = getNewFromOriginal(&SI); + IRBuilder<> B(newI); + auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); + auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); + auto nres = cast( + B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + return; + } + void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } + void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } + void visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { return; } + void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } + void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } + void 
visitBinaryOperator(llvm::BinaryOperator &BO) { + + switch (BO.getOpcode()) { + default: + break; + case BinaryOperator::Add: + case BinaryOperator::Sub: + case BinaryOperator::Mul: + case BinaryOperator::UDiv: + case BinaryOperator::SDiv: + case BinaryOperator::URem: + case BinaryOperator::SRem: + case BinaryOperator::AShr: + case BinaryOperator::LShr: + case BinaryOperator::Shl: + case BinaryOperator::And: + case BinaryOperator::Or: + case BinaryOperator::Xor: + return; + } + + if (towidth == 32 || towidth == 16 || towidth == 64) { + auto newI = getNewFromOriginal(&BO); + IRBuilder<> B(newI); + auto newLHS = truncate(B, getNewFromOriginal(BO.getOperand(0))); + auto newRHS = truncate(B, getNewFromOriginal(BO.getOperand(1))); + switch (BO.getOpcode()) { + default: + break; + case BinaryOperator::FMul: { + auto nres = cast(B.CreateFMul(newLHS, newRHS)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + } + return; + case BinaryOperator::FAdd: { + auto nres = cast(B.CreateFAdd(newLHS, newRHS)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + } + return; + case BinaryOperator::FSub: { + auto nres = cast(B.CreateFSub(newLHS, newRHS)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + } + return; + case BinaryOperator::FDiv: { + auto nres = cast(B.CreateFDiv(newLHS, newRHS)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + } + return; + case BinaryOperator::FRem: { + auto nres = cast(B.CreateFRem(newLHS, newRHS)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + } + return; + } + } + todo(BO); + return; + } + void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); } + void 
visitMemSetCommon(llvm::CallInst &MS) { return; } + void visitMemTransferInst(llvm::MemTransferInst &MTI) { + using namespace llvm; + Value *isVolatile = getNewFromOriginal(MTI.getOperand(3)); + auto srcAlign = MTI.getSourceAlign(); + auto dstAlign = MTI.getDestAlign(); + visitMemTransferCommon(MTI.getIntrinsicID(), srcAlign, dstAlign, MTI, + MTI.getOperand(0), MTI.getOperand(1), + getNewFromOriginal(MTI.getOperand(2)), isVolatile); + } + void visitMemTransferCommon(llvm::Intrinsic::ID ID, llvm::MaybeAlign srcAlign, + llvm::MaybeAlign dstAlign, llvm::CallInst &MTI, + llvm::Value *orig_dst, llvm::Value *orig_src, + llvm::Value *new_size, llvm::Value *isVolatile) { + return; + } + void visitFenceInst(llvm::FenceInst &FI) { return; } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { + SmallVector orig_ops(II.arg_size()); + for (unsigned i = 0; i < II.arg_size(); ++i) + orig_ops[i] = II.getOperand(i); + if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops)) + return; + + bool hasFromType = false; + auto newI = cast(getNewFromOriginal(&II)); + IRBuilder<> B(newI); + SmallVector new_ops(II.arg_size()); + for (unsigned i = 0; i < II.arg_size(); ++i) { + if (orig_ops[i]->getType() == getFromType()) { + new_ops[i] = truncate(B, getNewFromOriginal(orig_ops[i])); + hasFromType = true; + } else { + new_ops[i] = getNewFromOriginal(orig_ops[i]); + } + } + Type *retTy = II.getType(); + if (II.getType() == getFromType()) { + hasFromType = true; + retTy = getToType(); + } + + if (!hasFromType) + return; + + // TODO check that the intrinsic is overloaded + + CallInst *intr; + Value *nres = intr = createIntrinsicCall(B, II.getIntrinsicID(), retTy, + new_ops, &II, II.getName()); + if (II.getType() == getFromType()) + nres = expand(B, nres); + intr->copyIRFlags(newI); + newI->replaceAllUsesWith(nres); + newI->eraseFromParent(); + + return; + } + + void visitReturnInst(llvm::ReturnInst &I) { return; } + + void visitBranchInst(llvm::BranchInst &I) { return; } + void 
visitSwitchInst(llvm::SwitchInst &I) { return; } + void visitUnreachableInst(llvm::UnreachableInst &I) { return; } + void visitLoadLike(llvm::Instruction &I, llvm::MaybeAlign alignment, + llvm::Value *mask = nullptr, + llvm::Value *orig_maskInit = nullptr) { + return; + } + + void visitCommonStore(llvm::Instruction &I, llvm::Value *orig_ptr, + llvm::Value *orig_val, llvm::MaybeAlign prevalign, + bool isVolatile, llvm::AtomicOrdering ordering, + llvm::SyncScope::ID syncScope, llvm::Value *mask) { + return; + } + + bool + handleAdjointForIntrinsic(llvm::Intrinsic::ID ID, llvm::Instruction &I, + llvm::SmallVectorImpl &orig_ops) { + using namespace llvm; + + switch (ID) { + case Intrinsic::nvvm_ldu_global_i: + case Intrinsic::nvvm_ldu_global_p: + case Intrinsic::nvvm_ldu_global_f: + case Intrinsic::nvvm_ldg_global_i: + case Intrinsic::nvvm_ldg_global_p: + case Intrinsic::nvvm_ldg_global_f: { + auto CI = cast(I.getOperand(1)); + visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue())); + return true; + } + default: + break; + } + + if (ID == Intrinsic::masked_store) { + auto align0 = cast(I.getOperand(2))->getZExtValue(); + auto align = MaybeAlign(align0); + visitCommonStore(I, /*orig_ptr*/ I.getOperand(1), + /*orig_val*/ I.getOperand(0), align, + /*isVolatile*/ false, llvm::AtomicOrdering::NotAtomic, + SyncScope::SingleThread, + /*mask*/ getNewFromOriginal(I.getOperand(3))); + return true; + } + if (ID == Intrinsic::masked_load) { + auto align0 = cast(I.getOperand(1))->getZExtValue(); + auto align = MaybeAlign(align0); + visitLoadLike(I, align, + /*mask*/ getNewFromOriginal(I.getOperand(2)), + /*orig_maskInit*/ I.getOperand(3)); + return true; + } + + return false; + } + + llvm::Value *getNewFromOriginal(llvm::Value *v) { + auto found = originalToNewFn.find(v); + assert(found != originalToNewFn.end()); + return found->second; + } + + llvm::Instruction *getNewFromOriginal(llvm::Instruction *v) { + return cast(getNewFromOriginal((llvm::Value *)v)); + } + + bool 
handleKnownCalls(llvm::CallInst &call, llvm::Function *called, + llvm::StringRef funcName, + llvm::CallInst *const newCall) { + return false; + } + + Value *GetShadow(RequestContext &ctx, Value *v) { + if (auto F = dyn_cast(v)) + return Logic.CreateTruncate(ctx, F, fromwidth, towidth); + llvm::errs() << " unknown get truncated func: " << *v << "\n"; + llvm_unreachable("unknown get truncated func"); + return v; + } + // Return + void visitCallInst(llvm::CallInst &call) { + using namespace llvm; + + CallInst *const newCall = cast(getNewFromOriginal(&call)); + IRBuilder<> BuilderZ(newCall); + + if (auto called = call.getCalledFunction()) + if (handleKnownCalls(call, called, getFuncNameFromCall(&call), newCall)) + return; + + RequestContext ctx(&call, &BuilderZ); + auto val = GetShadow(ctx, getNewFromOriginal(call.getCalledOperand())); + newCall->setCalledOperand(val); + return; + } +}; + +llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context, + llvm::Function *totrunc, + unsigned fromwidth, + unsigned towidth) { + if (fromwidth == towidth) + return totrunc; + + TruncateCacheKey tup(totrunc, fromwidth, towidth); + if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { + return TruncateCachedFunctions.find(tup)->second; + } + + FunctionType *orig_FTy = totrunc->getFunctionType(); + SmallVector params; + + for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) { + params.push_back(orig_FTy->getParamType(i)); + } + + Type *NewTy = totrunc->getReturnType(); + + FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); + Function *NewF = + Function::Create(FTy, totrunc->getLinkage(), + "trunc_" + std::to_string(fromwidth) + "_" + + std::to_string(towidth) + totrunc->getName(), + totrunc->getParent()); + + NewF->setLinkage(Function::LinkageTypes::InternalLinkage); + + TruncateCachedFunctions[tup] = NewF; + + if (totrunc->empty()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No truncate mode found for " + 
totrunc->getName() << "\n"; + llvm::Value *toshow = totrunc; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *totrunc << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(totrunc), + wrap(context.ip)); + return NewF; + } + if (context.req) { + EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, + ss.str()); + return NewF; + } + llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; + llvm::errs() << *totrunc << "\n"; + llvm_unreachable("attempting to truncate function without definition"); + } + + if (fromwidth < towidth) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Cannot truncate into a large width\n"; + llvm::Value *toshow = totrunc; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *totrunc << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(totrunc), + wrap(context.ip)); + return NewF; + } + if (context.req) { + EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, + ss.str()); + return NewF; + } + llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; + llvm::errs() << *totrunc << "\n"; + llvm_unreachable("attempting to truncate function without definition"); + } + + ValueToValueMapTy originalToNewFn; + + for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); + i != totrunc->arg_end();) { + originalToNewFn[i] = j; + j->setName(i->getName()); + ++j; + ++i; + } + + SmallVector Returns; +#if LLVM_VERSION_MAJOR >= 13 + CloneFunctionInto(NewF, totrunc, originalToNewFn, + CloneFunctionChangeType::LocalChangesOnly, Returns, "", + nullptr); +#else + CloneFunctionInto(NewF, totrunc, originalToNewFn, true, Returns, "", nullptr); +#endif + + NewF->setLinkage(Function::LinkageTypes::InternalLinkage); + + TruncateGenerator handle(originalToNewFn, 
fromwidth, towidth, totrunc, NewF, + *this); + for (auto &BB : *totrunc) + for (auto &I : BB) + handle.visit(&I); + + if (llvm::verifyFunction(*NewF, &llvm::errs())) { + llvm::errs() << *totrunc << "\n"; + llvm::errs() << *NewF << "\n"; + report_fatal_error("function failed verification (5)"); + } + + return NewF; +} + llvm::Function *EnzymeLogic::CreateBatch(RequestContext context, Function *tobatch, unsigned width, ArrayRef arg_types, diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index c7f7c4bae86e..4ce25e8ae465 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -510,6 +510,12 @@ class EnzymeLogic { llvm::ArrayRef arg_types, BATCH_TYPE ret_type); + using TruncateCacheKey = std::tuple; + std::map TruncateCachedFunctions; + llvm::Function *CreateTruncate(RequestContext context, + llvm::Function *tobatch, unsigned fromwidth, + unsigned towidth); + /// Create a traced version of a function /// \p context the instruction which requested this trace (or null). 
/// \p totrace is the function to trace diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 455346addf21..a177dfc7fad4 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2825,3 +2825,35 @@ bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, return true; #endif } + +llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B, + llvm::Intrinsic::ID ID, llvm::Type *RetTy, + llvm::ArrayRef Args, + llvm::Instruction *FMFSource, + const llvm::Twine &Name) { +#if LLVM_VERSION_MAJOR >= 16 + llvm::CallInst *nres = B.CreateIntrinsic(RetTy, ID, Args, FMFSource, Name); +#else + SmallVector Table; + Intrinsic::getIntrinsicInfoTableEntries(ID, Table); + ArrayRef TableRef(Table); + + SmallVector ArgTys; + ArgTys.reserve(Args.size()); + for (auto &I : Args) + ArgTys.push_back(I->getType()); + FunctionType *FTy = FunctionType::get(RetTy, ArgTys, false); + SmallVector OverloadTys; + Intrinsic::MatchIntrinsicTypesResult Res = + matchIntrinsicSignature(FTy, TableRef, OverloadTys); + (void)Res; + assert(Res == Intrinsic::MatchIntrinsicTypes_Match && TableRef.empty() && + "Wrong types for intrinsic!"); + Function *Fn = Intrinsic::getDeclaration(B.GetInsertPoint()->getModule(), ID, + OverloadTys); + CallInst *nres = B.CreateCall(Fn, Args, {}, Name); + if (FMFSource) + nres->copyFastMathFlags(FMFSource); +#endif + return nres; +} diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index c2b41ed3bf68..90043bdd3db3 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -83,6 +83,7 @@ enum class ErrorType { MixedActivityError = 7, IllegalReplaceFicticiousPHIs = 8, GetIndexError = 9, + NoTruncate = 10, }; extern "C" { @@ -1808,4 +1809,10 @@ bool collectOffset(llvm::GEPOperator *gep, const llvm::DataLayout &DL, unsigned BitWidth, llvm::MapVector &VariableOffsets, llvm::APInt &ConstantOffset); + +llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B, + llvm::Intrinsic::ID ID, llvm::Type *RetTy, + llvm::ArrayRef 
Args, + llvm::Instruction *FMFSource = nullptr, + const llvm::Twine &Name = ""); #endif // ENZYME_UTILS_H diff --git a/enzyme/test/Enzyme/CMakeLists.txt b/enzyme/test/Enzyme/CMakeLists.txt index 0187644409f2..d88af6ddd95e 100644 --- a/enzyme/test/Enzyme/CMakeLists.txt +++ b/enzyme/test/Enzyme/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Sparse) +add_subdirectory(Truncate) add_subdirectory(ReverseMode) add_subdirectory(ReverseModeVector) add_subdirectory(ForwardMode) diff --git a/enzyme/test/Enzyme/Truncate/CMakeLists.txt b/enzyme/test/Enzyme/Truncate/CMakeLists.txt new file mode 100644 index 000000000000..79e649ab8e4b --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-trunc "Running enzyme truncation tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-enzyme-trunc PROPERTIES FOLDER "Tests") + +# add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} +# DEPENDS ${ENZYME_TEST_DEPS} +# ) diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll new file mode 100644 index 000000000000..3c2cffec9979 --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -0,0 +1,34 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s + +define i1 @f(double %x, double %y) { + %res = fcmp olt double %x, %y + ret i1 %res +} + +declare i1 (double, double)* @__enzyme_truncate(...) + +define i1 @tester(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) 
@__enzyme_truncate(i1 (double, double)* @f, i64 64, i64 32) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} + +; CHECK: define i1 @tester(double %x, double %y) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %res = call i1 @trunc_64_32f(double %x, double %y) +; CHECK-NEXT: ret i1 %res +; CHECK-NEXT: } + +; CHECK: define internal i1 @trunc_64_32f(double %x, double %y) { +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %res = fcmp olt float %3, %5 +; CHECK-DAG: ret i1 %res +; CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll new file mode 100644 index 000000000000..ea92f5d96bbc --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -0,0 +1,62 @@ +; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi +; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi + +declare double @llvm.pow.f64(double %Val, double %Power) +declare double @llvm.powi.f64.i16(double %Val, i16 %power) +declare void @llvm.nvvm.barrier0() + +define double @f(double %x, double %y) { + %res1 = call double @llvm.pow.f64(double %x, double %y) + %res2 = call double @llvm.powi.f64.i16(double %x, i16 2) + %res = fadd double %res1, %res2 + call void @llvm.nvvm.barrier0() + ret double %res +} + +declare double (double, double)* @__enzyme_truncate(...) + +define double @tester(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) 
@__enzyme_truncate(double (double, double)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y) + ret double %res +} + +; CHECK: define internal double @trunc_64_32f(double %x, double %y) { +; CHECK-NEXT: %1 = alloca double, align 8 +; CHECK-NEXT: store double %x, double* %1, align 8 +; CHECK-NEXT: %2 = bitcast double* %1 to float* +; CHECK-NEXT: %3 = load float, float* %2, align 4 +; CHECK-NEXT: store double %y, double* %1, align 8 +; CHECK-NEXT: %4 = bitcast double* %1 to float* +; CHECK-NEXT: %5 = load float, float* %4, align 4 +; CHECK-NEXT: %res11 = call float @llvm.pow.f32(float %3, float %5) +; CHECK-NEXT: %6 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %6, align 4 +; CHECK-NEXT: %7 = bitcast double* %1 to float* +; CHECK-NEXT: store float %res11, float* %7, align 4 +; CHECK-NEXT: %8 = load double, double* %1, align 8 +; CHECK-NEXT: store double %x, double* %1, align 8 +; CHECK-NEXT: %9 = bitcast double* %1 to float* +; CHECK-NEXT: %10 = load float, float* %9, align 4 +; CHECK-NEXT: %res22 = call float @llvm.powi.f32.i16(float %10, i16 2) +; CHECK-NEXT: %11 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %11, align 4 +; CHECK-NEXT: %12 = bitcast double* %1 to float* +; CHECK-NEXT: store float %res22, float* %12, align 4 +; CHECK-NEXT: %13 = load double, double* %1, align 8 +; CHECK-NEXT: store double %8, double* %1, align 8 +; CHECK-NEXT: %14 = bitcast double* %1 to float* +; CHECK-NEXT: %15 = load float, float* %14, align 4 +; CHECK-NEXT: store double %13, double* %1, align 8 +; CHECK-NEXT: %16 = bitcast double* %1 to float* +; CHECK-NEXT: %17 = load float, float* %16, align 4 +; CHECK-NEXT: %res = fadd float %15, %17 +; CHECK-NEXT: %18 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %18, align 4 +; CHECK-NEXT: %19 = bitcast double* %1 to float* +; CHECK-NEXT: store float %res, float* %19, align 4 +; CHECK-NEXT: %20 = load double, double* %1, align 8 +; CHECK-NEXT: call void @llvm.nvvm.barrier0() 
+; CHECK-NEXT: ret double %20 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll new file mode 100644 index 000000000000..ae539469b9a2 --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -0,0 +1,39 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s + +define double @f(double %x, double %y, i1 %cond) { + %res = select i1 %cond, double %x, double %y + ret double %res +} + +declare double (double, double, i1)* @__enzyme_truncate(...) + +define double @tester(double %x, double %y, i1 %cond) { +entry: + %ptr = call double (double, double, i1)* (...) @__enzyme_truncate(double (double, double, i1)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y, i1 %cond) + ret double %res +} + +; CHECK: define double @tester(double %x, double %y, i1 %cond) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %res = call double @trunc_64_32f(double %x, double %y, i1 %cond) +; CHECK-NEXT: ret double %res +; CHECK-NEXT: } + +; CHECK: define internal double @trunc_64_32f(double %x, double %y, i1 %cond) { +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %res = select i1 %cond, float %3, float %5 +; CHECK-DAG: %6 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %6, align 4 +; CHECK-DAG: %7 = bitcast double* %1 to float* +; CHECK-DAG: store float %res, float* %7, align 4 +; CHECK-DAG: %8 = load double, double* %1, align 8 +; CHECK-DAG: ret double %8 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll new file mode 100644 index 
000000000000..69990236a29e --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -0,0 +1,42 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s + +define void @f(double* %x) { + %y = load double, double* %x + %m = fmul double %y, %y + store double %m, double* %x + ret void +} + +declare void (double*)* @__enzyme_truncate(...) + +define void @tester(double* %data) { +entry: + %ptr = call void (double*)* (...) @__enzyme_truncate(void (double*)* @f, i64 64, i64 32) + call void %ptr(double* %data) + ret void +} + +; CHECK: define void @tester(double* %data) +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @trunc_64_32f(double* %data) +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal void @trunc_64_32f(double* %x) +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %m = fmul float %3, %5 +; CHECK-DAG: %6 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %6, align 4 +; CHECK-DAG: %7 = bitcast double* %1 to float* +; CHECK-DAG: store float %m, float* %7, align 4 +; CHECK-DAG: %8 = load double, double* %1, align 8 +; CHECK-DAG: store double %8, double* %x, align 8 +; CHECK-DAG: ret void From 9b979e37d7db91cc286ddb46dd46cc607babab1b Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 16 Jan 2024 03:12:44 -0500 Subject: [PATCH 003/131] Fix sparse (#1606) * tmp * progress * fixed memory issue * nicer errors * add a simplified test case * wip * get rid of args in tests * static and inline everywhere * logic fix * fix logic error * fix * add or replace rule * fix compile error * fdiv 
sparse prop * mul const const * fixup * reverse distribution * fix order bug * mulconstconst optimization * fix rotate * fix * extend constraint language * fixup * continued simplify * canonicalize combine * with builtin assumes * now compiling * format * cleanup * forced termination * improve solve * Don't do redundant soltns * continued fixes * fixing * fixing tests * fmt * fix build * fix 11 * fixtest * addl fix * smpl * fix --------- Co-authored-by: Jesse Michel --- enzyme/Enzyme/CacheUtility.cpp | 14 + enzyme/Enzyme/FunctionUtils.cpp | 2906 +++++++++++++---- enzyme/test/Enzyme/ReverseMode/incloop.ll | 4 +- enzyme/test/Enzyme/ReverseMode/insertsort.ll | 4 +- .../test/Enzyme/ReverseMode/mincachechain5.ll | 4 +- enzyme/test/Enzyme/ReverseMode/ompsqloop.ll | 4 +- .../Enzyme/ReverseMode/ompsqloopoutofplace.ll | 2 +- enzyme/test/Enzyme/ReverseMode/reorderrep.ll | 2 +- enzyme/test/Integration/Sparse/ringspring.cpp | 29 +- .../Sparse/ringspring2Dextenddata.cpp | 159 + .../Sparse/ringspring3Dextenddata.cpp | 57 +- .../ringspring3Dextenddatarestlengthone.cpp | 68 +- .../Sparse/ringspring3Drestlengthone.cpp | 64 +- enzyme/test/Integration/Sparse/sqrtspring.cpp | 35 +- 14 files changed, 2660 insertions(+), 692 deletions(-) create mode 100644 enzyme/test/Integration/Sparse/ringspring2Dextenddata.cpp diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp index fd34be6c1146..68c6e46784cf 100644 --- a/enzyme/Enzyme/CacheUtility.cpp +++ b/enzyme/Enzyme/CacheUtility.cpp @@ -257,6 +257,20 @@ void RemoveRedundantIVs( // and must thus be expanded after all phi's Value *NewIV = Exp.expandCodeFor(S, Tmp->getType(), Header->getFirstNonPHI()); + + // Explicity preserve wrap behavior from original iv. 
This is necessary + // until this PR in llvm is merged: + // https://github.com/llvm/llvm-project/pull/78199 + if (auto addrec = dyn_cast(S)) { + if (addrec->getLoop()->getHeader() == Header) { + if (auto add_or_mul = dyn_cast(NewIV)) { + if (addrec->getNoWrapFlags(llvm::SCEV::FlagNUW)) + add_or_mul->setHasNoUnsignedWrap(true); + if (addrec->getNoWrapFlags(llvm::SCEV::FlagNSW)) + add_or_mul->setHasNoSignedWrap(true); + } + } + } replacer(Tmp, NewIV); eraser(Tmp); } diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 2346cf3cc82c..0e8a5019fdfb 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -29,6 +29,7 @@ #include "GradientUtils.h" #include "LibraryFuncs.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" @@ -52,6 +53,7 @@ #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include #if LLVM_VERSION_MAJOR < 16 #include "llvm/Analysis/CFLSteensAliasAnalysis.h" @@ -109,6 +111,13 @@ #include "CacheUtility.h" +#if LLVM_VERSION_MAJOR >= 14 +#define addAttribute addAttributeAtIndex +#define removeAttribute removeAttributeAtIndex +#define getAttribute getAttributeAtIndex +#define hasAttribute hasAttributeAtIndex +#endif + #define DEBUG_TYPE "enzyme" using namespace llvm; @@ -2604,9 +2613,192 @@ static bool isNot(Value *a, Value *b) { return false; } +struct compare_insts { +public: + DominatorTree &DT; + LoopInfo &LI; + compare_insts(DominatorTree &DT, LoopInfo &LI) : DT(DT), LI(LI) {} + + // return true if A appears later than B. 
+ bool operator()(Instruction *A, Instruction *B) const { + if (A == B) { + return false; + } + if (A->getParent() == B->getParent()) { + return !A->comesBefore(B); + } + auto AB = A->getParent(); + auto BB = B->getParent(); + assert(AB->getParent() == BB->getParent()); + + for (auto prev = BB->getPrevNode(); prev; prev = prev->getPrevNode()) { + if (prev == AB) + return false; + } + return true; + } +}; + +class DominatorOrderSet : public std::set { +public: + DominatorOrderSet(DominatorTree &DT, LoopInfo &LI) + : std::set(compare_insts(DT, LI)) {} + bool contains(Instruction *I) const { + auto __i = find(I); + return __i != end(); + } + void remove(Instruction *I) { + auto __i = find(I); + assert(__i != end()); + erase(__i); + } + Instruction *pop_back_val() { + auto back = end(); + back--; + auto v = *back; + erase(back); + return v; + } +}; + +bool directlySparse(Value *z) { + if (isa(z)) + return true; + if (isa(z)) + return true; + if (isa(z)) + return true; + if (isa(z)) + return true; + if (auto SI = dyn_cast(z)) { + if (auto CI = dyn_cast(SI->getTrueValue())) + if (CI->isZero()) + return true; + if (auto CI = dyn_cast(SI->getFalseValue())) + if (CI->isZero()) + return true; + } + return false; +} + +typedef DominatorOrderSet QueueType; + +Function *getProductIntrinsic(llvm::Module &M, llvm::Type *T) { + std::string name = "__enzyme_product."; + if (T->isFloatTy()) + name += "f32"; + else if (T->isDoubleTy()) + name += "f64"; + else if (T->isIntegerTy()) + name += "i" + std::to_string(cast(T)->getBitWidth()); + else + assert(0); + auto FT = llvm::FunctionType::get(T, {}, true); + AttributeList AL; + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::ReadNone); + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::NoUnwind); +#if LLVM_VERSION_MAJOR >= 14 + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::NoFree); + AL = AL.addAttribute(T->getContext(), 
AttributeList::FunctionIndex, + Attribute::NoSync); + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::WillReturn); +#endif + return cast(M.getOrInsertFunction(name, FT, AL).getCallee()); +} + +Function *getSumIntrinsic(llvm::Module &M, llvm::Type *T) { + std::string name = "__enzyme_sum."; + if (T->isFloatTy()) + name += "f32"; + else if (T->isDoubleTy()) + name += "f64"; + else if (T->isIntegerTy()) + name += "i" + std::to_string(cast(T)->getBitWidth()); + else + assert(0); + auto FT = llvm::FunctionType::get(T, {}, true); + AttributeList AL; + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::ReadNone); + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::NoUnwind); +#if LLVM_VERSION_MAJOR >= 14 + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::NoFree); + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::NoSync); + AL = AL.addAttribute(T->getContext(), AttributeList::FunctionIndex, + Attribute::WillReturn); +#endif + return cast(M.getOrInsertFunction(name, FT, AL).getCallee()); +} + +CallInst *isProduct(llvm::Value *v) { + if (auto prod = dyn_cast(v)) + if (auto F = getFunctionFromCall(prod)) + if (startsWith(F->getName(), "__enzyme_product")) + return prod; + return nullptr; +} + +CallInst *isSum(llvm::Value *v) { + if (auto prod = dyn_cast(v)) + if (auto F = getFunctionFromCall(prod)) + if (startsWith(F->getName(), "__enzyme_sum")) + return prod; + return nullptr; +} + +SmallVector callOperands(llvm::CallBase *CB) { +#if LLVM_VERSION_MAJOR >= 14 + return SmallVector(CB->args().begin(), CB->args().end()); +#else + return SmallVector(CB->arg_operands().begin(), + CB->arg_operands().end()); +#endif +} + +bool guaranteedDataDependent(Value *z) { + if (isa(z)) + return true; + if (isa(z)) + return true; + if (auto BO = dyn_cast(z)) + return guaranteedDataDependent(BO->getOperand(0)) && + 
guaranteedDataDependent(BO->getOperand(1)); + if (auto C = dyn_cast(z)) + return guaranteedDataDependent(C->getOperand(0)); + if (auto S = isSum(z)) { + for (auto op : callOperands(S)) + if (guaranteedDataDependent(op)) + return true; + return false; + } + if (auto S = isProduct(z)) { + for (auto op : callOperands(S)) + if (!guaranteedDataDependent(op)) + return false; + return true; + } + if (auto II = dyn_cast(z)) { + switch (II->getIntrinsicID()) { + case Intrinsic::sqrt: + case Intrinsic::sin: + case Intrinsic::cos: + return guaranteedDataDependent(II->getArgOperand(0)); + default: + break; + } + } + return false; +} + std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, - llvm::SetVector &Q, - DominatorTree &DT, + QueueType &Q, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI, const DataLayout &DL) { auto push = [&](llvm::Value *V) { @@ -2634,7 +2826,34 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, candidate = dyn_cast(U); if (!candidate) continue; - if (candidate == I || !candidate->isIdenticalTo(I)) { + if (candidate == I && candidate->getType() != I->getType()) { + candidate = nullptr; + continue; + } + bool isSame = candidate->isIdenticalTo(I); + if (!isSame) { + if (auto P1 = isProduct(I)) + if (auto P2 = isProduct(I2)) { + std::multiset s1; + std::multiset s2; + for (auto &v : callOperands(P1)) + s1.insert(v); + for (auto &v : callOperands(P2)) + s2.insert(v); + isSame = s1 == s2; + } + if (auto P1 = isSum(I)) + if (auto P2 = isSum(I2)) { + std::multiset s1; + std::multiset s2; + for (auto &v : callOperands(P1)) + s1.insert(v); + for (auto &v : callOperands(P2)) + s2.insert(v); + isSame = s1 == s2; + } + } + if (!isSame) { candidate = nullptr; continue; } @@ -2664,24 +2883,30 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, for (size_t i = 0; i < I->getNumOperands(); i++) { if (auto I2 = dyn_cast(I->getOperand(i))) { if ((!I2->mayWriteToMemory() || - (isa(I2) && cast(I2)->onlyReadsMemory()))) 
+ (isa(I2) && isReadOnly(cast(I2))))) operands.insert(I2); } } + if (Q.contains(I)) { + Q.remove(I); + } + assert(!Q.contains(I)); I->eraseFromParent(); for (auto op : operands) if (op->getNumUses() == 0) { - Q.remove(op); + if (Q.contains(op)) + Q.remove(op); op->eraseFromParent(); } }; if (!cur->getType()->isVoidTy() && (!cur->mayWriteToMemory() || - (isa(cur) && cast(cur)->onlyReadsMemory()))) { + (isa(cur) && isReadOnly(cast(cur))))) { // DCE if (cur->getNumUses() == 0) { for (size_t i = 0; i < cur->getNumOperands(); i++) push(cur->getOperand(i)); + assert(!Q.contains(cur)); cur->eraseFromParent(); return "DCE"; } @@ -2695,7 +2920,35 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, candidate = dyn_cast(U); if (!candidate) continue; - if (candidate == cur || !candidate->isIdenticalTo(cur)) { + if (candidate == cur && candidate->getType() != cur->getType()) { + candidate = nullptr; + continue; + } + bool isSame = candidate->isIdenticalTo(cur); + if (!isSame) { + if (auto P1 = isProduct(candidate)) + if (auto P2 = isProduct(cur)) { + std::multiset s1; + std::multiset s2; + for (auto &v : callOperands(P1)) + s1.insert(v); + for (auto &v : callOperands(P2)) + s2.insert(v); + isSame = s1 == s2; + } + if (auto P1 = isSum(candidate)) + if (auto P2 = isSum(cur)) { + std::multiset s1; + std::multiset s2; + for (auto &v : callOperands(P1)) + s1.insert(v); + for (auto &v : callOperands(P2)) + s2.insert(v); + isSame = s1 == s2; + } + } + + if (!isSame) { candidate = nullptr; continue; } @@ -2710,7 +2963,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } if (candidate) { if (reverse) { - Q.remove(candidate); + if (Q.contains(candidate)) + Q.remove(candidate); auto tmp = candidate; candidate = cur; cur = tmp; @@ -2742,7 +2996,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, return "OrZero"; } // or a, 1 -> 1 - if (C->isOne()) { + if (C->isOne() && cur->getType()->isIntegerTy(1)) { replaceAndErase(cur, C); 
return "OrOne"; } @@ -2753,7 +3007,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, for (int i = 0; i < 2; i++) { if (auto C = dyn_cast(cur->getOperand(i))) { // and a, 1 -> a - if (C->isOne()) { + if (C->isOne() && cur->getType()->isIntegerTy(1)) { replaceAndErase(cur, cur->getOperand(1 - i)); return "AndOne"; } @@ -2767,6 +3021,12 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } IRBuilder<> B(cur); + if (auto CI = dyn_cast(cur)) + if (auto C = dyn_cast(CI->getOperand(0))) { + replaceAndErase( + cur, cast(B.CreateCast(CI->getOpcode(), C, CI->getType()))); + return "CastConstProp"; + } std::function replace = [&](Value *val, Value *orig, Value *with) { @@ -2781,7 +3041,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (auto I = dyn_cast(val)) { if (I->mayWriteToMemory() && - !(isa(I) && cast(I)->onlyReadsMemory())) + !(isa(I) && isReadOnly(cast(I)))) return val; if (I->getOpcode() == Instruction::Add) { @@ -2921,156 +3181,854 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, fval == SI->getFalseValue()) return val; push(I); + if (auto CI = dyn_cast(cond)) { + if (CI->isOne()) + return tval; + else + return fval; + } return pushcse(B.CreateSelect(cond, tval, fval, "sel." + I->getName())); } + + if (isProduct(I) || isSum(I)) { + auto C = cast(I); + auto ops = callOperands(C); + bool changed = false; + for (auto &op : ops) { + auto next = replace(op, orig, with); + if (next != op) { + changed = true; + op = next; + } + } + if (!changed) + return (Value *)I; + push(I); + pushcse( + B.CreateCall(getFunctionFromCall(C), ops, "sel." 
+ I->getName())); + } } return val; }; - // mul (mul a, const1), (mul b, const2) -> mul (mul a, b), (const1, const2) - if (cur->getOpcode() == Instruction::FMul) - if (cur->isFast()) - if (auto mul1 = dyn_cast(cur->getOperand(0))) - if (mul1->getOpcode() == Instruction::FMul && mul1->isFast()) - if (auto mul2 = dyn_cast(cur->getOperand(1))) - if (mul2->getOpcode() == Instruction::FMul && mul2->isFast()) - for (auto i1 = 0; i1 < 2; i1++) - for (auto i2 = 0; i2 < 2; i2++) - if (isa(mul1->getOperand(i1))) - if (isa(mul2->getOperand(i2))) { - - auto n0 = pushcse( - B.CreateFMulFMF(mul1->getOperand(1 - i1), - mul2->getOperand(1 - i2), cur)); - auto n1 = pushcse(B.CreateFMulFMF( - mul1->getOperand(i1), mul2->getOperand(i2), cur)); - auto n2 = pushcse(B.CreateFMulFMF(n0, n1, cur)); - push(mul1); - push(mul2); - replaceAndErase(cur, n2); - return "MulMulConstConst"; - } - - // mul (mul a, const), b -> mul (mul a, b), const - // note we avoid the case where b = (mul a, const) since otherwise - // we create an infinite recursion - if (cur->getOpcode() == Instruction::FMul) - if (cur->isFast() && cur->getOperand(0) != cur->getOperand(1)) - for (auto ic = 0; ic < 2; ic++) - if (auto mul = dyn_cast(cur->getOperand(ic))) - if (mul->getOpcode() == Instruction::FMul && mul->isFast()) - if (!isa(cur->getOperand(1 - ic))) { + if (auto II = dyn_cast(cur)) + if (II->getIntrinsicID() == Intrinsic::fmuladd || + II->getIntrinsicID() == Intrinsic::fma) { + B.setFastMathFlags(getFast()); + auto mul = pushcse(B.CreateFMul(II->getOperand(0), II->getOperand(1))); + auto add = pushcse(B.CreateFAdd(mul, II->getOperand(2))); + replaceAndErase(cur, add); + return "FMulAddExpand"; + } + + if (auto BO = dyn_cast(cur)) { + if (BO->getOpcode() == Instruction::FMul && BO->isFast()) { + Value *args[2] = {BO->getOperand(0), BO->getOperand(1)}; + auto mul = pushcse( + B.CreateCall(getProductIntrinsic(*F.getParent(), BO->getType()), args, + cur->getName())); + replaceAndErase(cur, mul); + return 
"FMulToProduct"; + } + if (BO->getOpcode() == Instruction::FDiv && BO->isFast()) { + auto c0 = dyn_cast(BO->getOperand(0)); + if (!c0 || !c0->isExactlyValue(1.0)) { + B.setFastMathFlags(getFast()); + auto div = pushcse(B.CreateFDivFMF(ConstantFP::get(BO->getType(), 1.0), + BO->getOperand(1), BO)); + auto mul = pushcse( + B.CreateFMulFMF(BO->getOperand(0), div, BO, cur->getName())); + replaceAndErase(cur, mul); + return "FDivToFMul"; + } + } + if (BO->getOpcode() == Instruction::FAdd && BO->isFast()) { + Value *args[2] = {BO->getOperand(0), BO->getOperand(1)}; + auto mul = pushcse( + B.CreateCall(getSumIntrinsic(*F.getParent(), BO->getType()), args)); + replaceAndErase(cur, mul); + return "FAddToSum"; + } + if (BO->getOpcode() == Instruction::FSub && BO->isFast()) { + B.setFastMathFlags(getFast()); + Value *args[2] = {BO->getOperand(0), + pushcse(B.CreateFNeg(BO->getOperand(1)))}; + auto mul = + pushcse(B.CreateCall(getSumIntrinsic(*F.getParent(), BO->getType()), + args, cur->getName())); + replaceAndErase(cur, mul); + return "FAddToSum"; + } + } + if (cur->getOpcode() == Instruction::FNeg) { + B.setFastMathFlags(getFast()); + auto mul = + pushcse(B.CreateFMulFMF(ConstantFP::get(cur->getType(), -1.0), + cur->getOperand(0), cur, cur->getName())); + replaceAndErase(cur, mul); + return "FNegToMul"; + } - for (int i = 0; i < 2; i++) - if (auto C = dyn_cast(mul->getOperand(i))) { - auto n0 = pushcse(B.CreateFMulFMF( - mul->getOperand(1 - i), cur->getOperand(1 - ic), mul)); - auto n1 = pushcse(B.CreateFMulFMF(n0, C, cur)); - push(mul); + if (auto SI = dyn_cast(cur)) { + if (auto tc = dyn_cast(SI->getTrueValue())) + if (auto fc = dyn_cast(SI->getFalseValue())) + if (fc->isZero()) { + if (tc->isExactlyValue(1.0)) { + auto res = + pushcse(B.CreateUIToFP(SI->getCondition(), tc->getType())); + replaceAndErase(cur, res); + return "SelToUIFP"; + } + if (tc->isExactlyValue(-1.0)) { + auto res = + pushcse(B.CreateSIToFP(SI->getCondition(), tc->getType())); + replaceAndErase(cur, 
res); + return "SelToSIFP"; + } + } + } - replaceAndErase(cur, n1); - return "MulMulConst"; - } - } + if (auto P = isProduct(cur)) { + SmallVector operands; + std::optional constval; + bool changed = false; + for (auto &v : callOperands(P)) - if (auto fcmp = dyn_cast(cur)) { - if (fcmp->getPredicate() == FCmpInst::FCMP_OEQ) { - for (int i = 0; i < 2; i++) - if (auto C = dyn_cast(fcmp->getOperand(i))) { - if (C->isZero()) { - if (auto fmul = dyn_cast(fcmp->getOperand(1 - i))) { - // (a*b) == 0 -> (a == 0) || (b == 0) - if (fmul->getOpcode() == Instruction::FMul) { - auto ncmp1 = pushcse( - B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C)); - auto ncmp2 = pushcse( - B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(1), C)); - auto ori = pushcse(B.CreateOr(ncmp1, ncmp2)); - replaceAndErase(cur, ori); - return "CmpFMulSplit"; - } - // (a/b) == 0 -> (a == 0) - if (fmul->getOpcode() == Instruction::FDiv) { - auto ncmp1 = pushcse( - B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C)); - replaceAndErase(cur, ncmp1); - return "CmpFDivSplit"; - } - // (a - b) ?= 0 -> a ?= b - if (fmul->getOpcode() == Instruction::FSub) { - auto ncmp1 = pushcse(B.CreateFCmp(fcmp->getPredicate(), - fmul->getOperand(0), - fmul->getOperand(1))); - replaceAndErase(cur, ncmp1); - return "CmpFSubSplit"; - } - } - if (auto cast = dyn_cast(fcmp->getOperand(1 - i))) { - auto ncmp1 = pushcse(B.CreateICmp( - ICmpInst::ICMP_EQ, cast->getOperand(0), - ConstantInt::get(cast->getOperand(0)->getType(), 0))); - replaceAndErase(cur, ncmp1); - return "SFCmpToICmp"; - } - if (auto cast = dyn_cast(fcmp->getOperand(1 - i))) { - auto ncmp1 = pushcse(B.CreateICmp( - ICmpInst::ICMP_EQ, cast->getOperand(0), - ConstantInt::get(cast->getOperand(0)->getType(), 0))); - replaceAndErase(cur, ncmp1); - return "UFCmpToICmp"; - } - if (auto SI = dyn_cast(fcmp->getOperand(1 - i))) { - auto res = pushcse( - B.CreateSelect(SI->getCondition(), - pushcse(B.CreateCmp(fcmp->getPredicate(), C, - 
SI->getTrueValue())), - pushcse(B.CreateCmp(fcmp->getPredicate(), C, - SI->getFalseValue())))); - replaceAndErase(cur, res); - return "FCmpSelect"; - } - } + { + if (auto P2 = isProduct(v)) { + for (auto &v2 : callOperands(P2)) { + push(v2); + operands.push_back(v2); + } + push(P2); + changed = true; + continue; + } + if (auto C = dyn_cast(v)) { + if (C->isExactlyValue(1.0)) { + changed = true; + continue; + } + if (C->isZero()) { + replaceAndErase(cur, C); + return "ZeroProduct"; } + if (!constval) { + constval = C->getValue(); + continue; + } + constval = (*constval) * C->getValue(); + changed = true; + continue; + } + operands.push_back(v); } - } - if (auto fcmp = dyn_cast(cur)) { - if (fcmp->getPredicate() == CmpInst::ICMP_EQ || - fcmp->getPredicate() == CmpInst::ICMP_NE || - fcmp->getPredicate() == CmpInst::FCMP_OEQ || - fcmp->getPredicate() == CmpInst::FCMP_ONE) { + if (constval) + operands.push_back(ConstantFP::get(cur->getType(), *constval)); - // a + c ?= a -> c ?= 0 , if fast - for (int i = 0; i < 2; i++) - if (auto inst = dyn_cast(fcmp->getOperand(i))) - if (inst->getOpcode() == Instruction::FAdd && inst->isFast()) - for (int i2 = 0; i2 < 2; i2++) - if (inst->getOperand(i2) == fcmp->getOperand(1 - i)) { - auto res = pushcse( - B.CreateCmp(fcmp->getPredicate(), inst->getOperand(1 - i2), - ConstantFP::get(inst->getType(), 0))); - replaceAndErase(cur, res); - return "CmpFAddSame"; - } + if (operands.size() == 0) { + replaceAndErase(cur, ConstantFP::get(cur->getType(), 1.0)); + return "EmptyProduct"; + } + if (operands.size() == 1) { + replaceAndErase(cur, operands[0]); + return "SingleProduct"; + } + if (changed) { + auto mul = pushcse( + B.CreateCall(getProductIntrinsic(*F.getParent(), cur->getType()), + operands, cur->getName())); + replaceAndErase(cur, mul); + return "ProductSimplification"; + } + } - // a == b -> a & b | !a & !b - // a != b -> a & !b | !a & b - if (fcmp->getOperand(0)->getType()->isIntegerTy(1)) { - auto a = fcmp->getOperand(0); - auto 
b = fcmp->getOperand(1); - if (fcmp->getPredicate() == CmpInst::ICMP_EQ) { - auto res = pushcse( - B.CreateOr(pushcse(B.CreateAnd(a, b)), - pushcse(B.CreateAnd(pushcse(B.CreateNot(a)), - pushcse(B.CreateNot(b)))))); - replaceAndErase(cur, res); - return "CmpI1EQ"; + if (auto P = isSum(cur)) { + // map from operand, to number of counts + std::map operands; + std::optional constval; + bool changed = false; + for (auto &v : callOperands(P)) { + if (auto P2 = isSum(v)) { + for (auto &v2 : callOperands(P2)) { + push(v2); + operands[v2]++; } - if (fcmp->getPredicate() == CmpInst::ICMP_NE) { - auto res = pushcse( - B.CreateOr(pushcse(B.CreateAnd(pushcse(B.CreateNot(a)), b)), - pushcse(B.CreateAnd(a, pushcse(B.CreateNot(b)))))); - replaceAndErase(cur, res); - return "CmpI1NE"; + push(P2); + changed = true; + continue; + } + if (auto C = dyn_cast(v)) { + if (C->isExactlyValue(0.0)) { + changed = true; + continue; + } + if (!constval) { + constval = C->getValue(); + continue; + } + constval = (*constval) + C->getValue(); + changed = true; + continue; + } + operands[v]++; + } + if (constval) + operands[ConstantFP::get(cur->getType(), *constval)]++; + + if (operands.size() == 0) { + replaceAndErase(cur, ConstantFP::get(cur->getType(), 0.0)); + return "EmptySum"; + } + SmallVector args; + for (auto &pair : operands) { + if (pair.second == 1) { + args.push_back(pair.first); + continue; + } + changed = true; + Value *sargs[] = {pair.first, + ConstantFP::get(cur->getType(), (double)pair.second)}; + args.push_back(pushcse(B.CreateCall( + getProductIntrinsic(*F.getParent(), cur->getType()), sargs))); + } + if (args.size() == 1) { + replaceAndErase(cur, args[0]); + return "SingleSum"; + } + if (changed) { + auto sum = + pushcse(B.CreateCall(getSumIntrinsic(*F.getParent(), cur->getType()), + args, cur->getName())); + replaceAndErase(cur, sum); + return "SumSimplification"; + } + } + + if (auto P = isProduct(cur)) { + SmallVector operands; + SmallVector conditions; + for (auto &v : 
callOperands(P)) { + // z = uitofp i1 c to float -> select c, (prod withot z), 0 + if (auto op = dyn_cast(v)) { + if (op->getOperand(0)->getType()->isIntegerTy(1)) { + conditions.push_back(op->getOperand(0)); + continue; + } + } + // z = sitofp i1 c to float -> select c, (-prod withot z), 0 + if (auto op = dyn_cast(v)) { + if (op->getOperand(0)->getType()->isIntegerTy(1)) { + conditions.push_back(op->getOperand(0)); + operands.push_back(ConstantFP::get(cur->getType(), -1.0)); + continue; + } + } + if (auto op = dyn_cast(v)) { + if (auto tc = dyn_cast(op->getTrueValue())) + if (tc->isZero()) { + conditions.push_back(pushcse(B.CreateNot(op->getCondition()))); + operands.push_back(op->getFalseValue()); + continue; + } + if (auto tc = dyn_cast(op->getFalseValue())) + if (tc->isZero()) { + conditions.push_back(op->getCondition()); + operands.push_back(op->getTrueValue()); + continue; + } + } + operands.push_back(v); + } + + if (conditions.size()) { + auto mul = pushcse(B.CreateCall( + getProductIntrinsic(*F.getParent(), cur->getType()), operands)); + Value *condition = nullptr; + for (auto v : conditions) { + assert(v->getType()->isIntegerTy(1)); + if (condition == nullptr) { + condition = v; + continue; + } + condition = pushcse(B.CreateAnd(condition, v)); + } + auto zero = ConstantFP::get(cur->getType(), 0.0); + auto sel = pushcse(B.CreateSelect(condition, mul, zero, cur->getName())); + replaceAndErase(cur, sel); + return "ProductSelect"; + } + } + + // TODO + if (auto P = isSum(cur)) { + // whether negated + SmallVector, 1> conditions; + for (auto &v : callOperands(P)) { + // z = uitofp i1 c to float -> select c, (prod withot z), 0 + if (auto op = dyn_cast(v)) { + if (op->getOperand(0)->getType()->isIntegerTy(1)) { + conditions.emplace_back(op->getOperand(0), false); + continue; + } + } + // z = sitofp i1 c to float -> select c, (-prod withot z), 0 + if (auto op = dyn_cast(v)) { + if (op->getOperand(0)->getType()->isIntegerTy(1)) { + 
conditions.emplace_back(op->getOperand(0), false); + continue; + } + } + if (auto op = dyn_cast(v)) { + if (auto tc = dyn_cast(op->getTrueValue())) + if (tc->isZero()) { + conditions.emplace_back(op->getCondition(), true); + continue; + } + if (auto tc = dyn_cast(op->getFalseValue())) + if (tc->isZero()) { + conditions.emplace_back(op->getCondition(), false); + continue; + } + } + } + Value *condition = nullptr; + for (size_t i = 0; i < conditions.size(); i++) { + size_t count = 0; + for (size_t j = 0; j < conditions.size(); j++) { + if (((conditions[i].first == conditions[j].first) && + (conditions[i].second == conditions[i].second)) || + ((isNot(conditions[i].first, conditions[j].first) && + (conditions[i].second != conditions[i].second)))) + count++; + } + if (count == conditions.size() && count > 1) { + condition = conditions[i].first; + if (conditions[i].second) + condition = pushcse(B.CreateNot(condition, "sumpnot")); + break; + } + } + + if (condition) { + + SmallVector operands; + for (auto &v : callOperands(P)) { + // z = uitofp i1 c to float -> select c, (prod withot z), 0 + if (auto op = dyn_cast(v)) { + if (op->getOperand(0)->getType()->isIntegerTy(1)) { + operands.push_back(ConstantFP::get(cur->getType(), 1.0)); + continue; + } + } + // z = sitofp i1 c to float -> select c, (-prod withot z), 0 + if (auto op = dyn_cast(v)) { + if (op->getOperand(0)->getType()->isIntegerTy(1)) { + operands.push_back(ConstantFP::get(cur->getType(), -1.0)); + continue; + } + } + if (auto op = dyn_cast(v)) { + if (auto tc = dyn_cast(op->getTrueValue())) + if (tc->isZero()) { + operands.push_back(op->getFalseValue()); + continue; + } + if (auto tc = dyn_cast(op->getFalseValue())) + if (tc->isZero()) { + operands.push_back(op->getTrueValue()); + continue; + } + } + assert(0); + } + + if (conditions.size()) { + auto sum = pushcse(B.CreateCall( + getSumIntrinsic(*F.getParent(), cur->getType()), operands)); + auto zero = ConstantFP::get(cur->getType(), 0.0); + auto sel = + 
pushcse(B.CreateSelect(condition, sum, zero, cur->getName())); + replaceAndErase(cur, sel); + return "SumSelect"; + } + } + } + // (a1*b1) + (a1*c1) + (a1*d1 ) + ... -> a1 * (b1 + c1 + d1 + ...) + if (auto S = isSum(cur)) { + SmallVector allOps; + auto combine = [](const SmallVector &lhs, + SmallVector rhs) { + SmallVector out; + for (auto v : lhs) { + bool seen = false; + for (auto &v2 : rhs) { + if (v == v2) { + v2 = nullptr; + seen = true; + break; + } + } + if (seen) { + out.push_back(v); + } + } + return out; + }; + auto subtract = [](SmallVector lhs, + const SmallVector &rhs) { + for (auto v : rhs) { + auto found = find(lhs, v); + assert(found != lhs.end()); + lhs.erase(found); + } + return lhs; + }; + bool seen = false; + bool legal = true; + for (auto op : callOperands(S)) { + auto P = isProduct(op); + if (!P) { + legal = false; + break; + } + if (!seen) { + allOps = callOperands(P); + seen = true; + continue; + } + allOps = combine(allOps, callOperands(P)); + } + + if (legal && allOps.size() > 0) { + SmallVector operands; + for (auto op : callOperands(S)) { + auto P = isProduct(op); + push(op); + auto sub = subtract(callOperands(P), allOps); + auto newprod = pushcse(B.CreateCall( + getProductIntrinsic(*F.getParent(), S->getType()), sub)); + operands.push_back(newprod); + } + auto newsum = pushcse(B.CreateCall( + getSumIntrinsic(*F.getParent(), S->getType()), operands)); + allOps.push_back(newsum); + auto fprod = pushcse(B.CreateCall( + getProductIntrinsic(*F.getParent(), S->getType()), allOps)); + replaceAndErase(cur, fprod); + return "SumFactor"; + } + } + + /* + // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1 != + c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2; j++) + if (auto c0 = dyn_cast(cur->getOperand(j))) + if (auto cmp0 = dyn_cast(c0->getOperand(0))) + if (auto c1 = dyn_cast(cur->getOperand(1-j))) + if (auto cmp1 = dyn_cast(c0->getOperand(0))) + if (cmp0->getPredicate() == ICmpInst::ICMP_EQ && + 
cmp1->getPredicate() == ICmpInst::ICMP_EQ) + { + for (size_t i0 = 0; i0 < 2; i0++) + for (size_t i1 = 0; i1 < 2; i1++) + if (cmp0->getOperand(1 - i0) == cmp1->getOperand(1 - i1)) + auto e0 = SE.getSCEV(cmp0->getOperand(i0)); + auto e1 = SE.getSCEV(cmp1->getOperand(i1)); + auto m = SE.getMinusSCEV(e0, e1, SCEV::NoWrapMask); + if (auto C = dyn_cast(m)) { + // if c1 == c2 don't need the and they are equivalent + if (C->getValue()->isZero()) { + } else { + auto sel0 = pushcse(B.CreateSelect(cmp0, + ConstantInt::get(cur->getType(), isa(cmp0) ? 1 : -1), + ConstantInt::get(cur->getType(), 0)); + // if non one constant they must be distinct. + replaceAndErase(cur, + ConstantInt::getFalse(cur->getContext())); + return "AndNEExpr"; + } + } + } + } + */ + + if (auto fcmp = dyn_cast(cur)) { + auto predicate = fcmp->getPredicate(); + if (predicate == FCmpInst::FCMP_OEQ || predicate == FCmpInst::FCMP_UEQ || + predicate == FCmpInst::FCMP_UNE || predicate == FCmpInst::FCMP_ONE) { + for (int i = 0; i < 2; i++) + if (auto C = dyn_cast(fcmp->getOperand(i))) { + if (C->isZero()) { + // (a1*a2*...an) == 0 -> (a1 == 0) || (a2 == 0) || ... (a2 == 0) + // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... (a2 == 0) + // ] + if (auto P = isProduct(fcmp->getOperand(1 - i))) { + Value *res = nullptr; + + auto eq_predicate = predicate; + if (predicate == FCmpInst::FCMP_UNE || + predicate == FCmpInst::FCMP_ONE) + eq_predicate = fcmp->getInversePredicate(); + + for (auto &v : callOperands(P)) { + auto ncmp1 = pushcse(B.CreateFCmp(eq_predicate, v, C)); + if (!res) + res = ncmp1; + else + res = pushcse(B.CreateOr(res, ncmp1)); + } + + if (predicate == FCmpInst::FCMP_UNE || + predicate == FCmpInst::FCMP_ONE) { + res = pushcse(B.CreateNot(res)); + } + + replaceAndErase(cur, res); + return "CmpProductSplit"; + } + + // (a1*b1) + (a1*c1) + (a1*d1 ) + ... ?= 0 -> a1 * (b1 + c1 + d1 + + // ...) 
?= 0 + if (auto S = isSum(fcmp->getOperand(1 - i))) { + SmallVector allOps; + auto combine = [](const SmallVector &lhs, + SmallVector rhs) { + SmallVector out; + for (auto v : lhs) { + bool seen = false; + for (auto &v2 : rhs) { + if (v == v2) { + v2 = nullptr; + seen = true; + break; + } + } + if (seen) { + out.push_back(v); + } + } + return out; + }; + auto subtract = [](SmallVector lhs, + const SmallVector &rhs) { + for (auto v : rhs) { + auto found = find(lhs, v); + assert(found != lhs.end()); + lhs.erase(found); + } + return lhs; + }; + bool seen = false; + bool legal = true; + for (auto op : callOperands(S)) { + auto P = isProduct(op); + if (!P) { + legal = false; + break; + } + if (!seen) { + allOps = callOperands(P); + seen = true; + continue; + } + allOps = combine(allOps, callOperands(P)); + } + + if (legal && allOps.size() > 0) { + SmallVector operands; + for (auto op : callOperands(S)) { + auto P = isProduct(op); + push(op); + auto sub = subtract(callOperands(P), allOps); + auto newprod = pushcse(B.CreateCall( + getProductIntrinsic(*F.getParent(), C->getType()), sub)); + operands.push_back(newprod); + } + auto newsum = pushcse(B.CreateCall( + getSumIntrinsic(*F.getParent(), C->getType()), operands)); + allOps.push_back(newsum); + auto fprod = pushcse(B.CreateCall( + getProductIntrinsic(*F.getParent(), C->getType()), allOps)); + auto fcmp = pushcse(B.CreateCmp(predicate, fprod, C)); + replaceAndErase(cur, fcmp); + return "CmpSumFactor"; + } + } + } + } + } + } + + if (auto fcmp = dyn_cast(cur)) { + auto predicate = fcmp->getPredicate(); + if (predicate == FCmpInst::FCMP_OEQ || predicate == FCmpInst::FCMP_UEQ || + predicate == FCmpInst::FCMP_UNE || predicate == FCmpInst::FCMP_ONE) { + for (int i = 0; i < 2; i++) + if (auto C = dyn_cast(fcmp->getOperand(i))) { + if (C->isZero()) { + // a + b == 0 -> ( (a == 0 & b == 0) || a == -b) + if (auto S = isSum(fcmp->getOperand(1 - i))) { + auto allOps = callOperands(S); + if (!llvm::any_of(allOps, 
guaranteedDataDependent)) { + auto eq_predicate = predicate; + if (predicate == FCmpInst::FCMP_UNE || + predicate == FCmpInst::FCMP_ONE) + eq_predicate = fcmp->getInversePredicate(); + + Value *op_checks = nullptr; + for (auto a : allOps) { + auto a_e0 = pushcse(B.CreateFCmp(eq_predicate, a, C)); + if (op_checks == nullptr) + op_checks = a_e0; + else + op_checks = pushcse(B.CreateAnd(op_checks, a_e0)); + } + SmallVector slice; + for (size_t i = 1; i < allOps.size(); i++) + slice.push_back(allOps[i]); + auto ane = pushcse(B.CreateFCmp( + eq_predicate, pushcse(B.CreateFNeg(allOps[0])), + pushcse(B.CreateCall(getFunctionFromCall(S), slice)))); + auto ori = pushcse(B.CreateOr(op_checks, ane)); + if (predicate == FCmpInst::FCMP_UNE || + predicate == FCmpInst::FCMP_ONE) { + ori = pushcse(B.CreateNot(ori)); + } + replaceAndErase(cur, ori); + return "Sum2ZeroSplit"; + } + } + } + } + } + } + + // (zext a) + (zext b) ?= 0 -> zext a ?= - zext b + if (auto icmp = dyn_cast(cur)) { + if (icmp->getPredicate() == CmpInst::ICMP_EQ || + icmp->getPredicate() == CmpInst::ICMP_NE) { + for (int i = 0; i < 2; i++) + if (auto C = dyn_cast(icmp->getOperand(i))) + if (C->isZero()) + if (auto add = dyn_cast(icmp->getOperand(1 - i))) + if (add->getOpcode() == Instruction::Add) + if (auto a0 = dyn_cast(add->getOperand(0))) + if (auto a1 = dyn_cast(add->getOperand(1))) + if (a0->getOperand(0)->getType() == + a1->getOperand(0)->getType() && + (isa(a0) || isa(a0))) { + auto cmp2 = pushcse(B.CreateCmp( + icmp->getPredicate(), a0, pushcse(B.CreateNeg(a1)))); + replaceAndErase(cur, cmp2); + return "CmpExt0Shuffle"; + } + } + } + + // sub 0, (zext i1 to N) -> sext i1 to N + // sub 0, (sext i1 to N) -> zext i1 to N + if (auto sub = dyn_cast(cur)) + if (sub->getOpcode() == Instruction::Sub) + if (auto C = dyn_cast(sub->getOperand(0))) + if (C->isZero()) + if (auto a0 = dyn_cast(sub->getOperand(1))) + if (a0->getOperand(0)->getType()->isIntegerTy(1)) { + + Value *tmp = nullptr; + if (isa(a0)) + tmp = 
B.CreateSExt(a0->getOperand(0), a0->getType()); + else if (isa(a0)) + tmp = B.CreateZExt(a0->getOperand(0), a0->getType()); + else + assert(0); + tmp = pushcse(tmp); + replaceAndErase(cur, tmp); + return "NegSZExtI1"; + } + + // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if C2 + // divides C1 + if ((cur->getOpcode() == Instruction::LShr || + cur->getOpcode() == Instruction::SDiv || + cur->getOpcode() == Instruction::UDiv) && + cur->isExact()) + if (auto C2 = dyn_cast(cur->getOperand(1))) + if (auto mul = dyn_cast(cur->getOperand(0))) + if (mul->getOpcode() == Instruction::Mul) + for (int i0 = 0; i0 < 2; i0++) + if (auto C1 = dyn_cast(mul->getOperand(i0))) { + auto lhs = C1->getValue(); + APInt rhs = C2->getValue(); + if (cur->getOpcode() == Instruction::LShr) { + rhs = APInt(rhs.getBitWidth(), 1) << rhs; + } + + APInt div, rem; + if (cur->getOpcode() == Instruction::LShr || + cur->getOpcode() == Instruction::UDiv) + APInt::udivrem(lhs, rhs, div, rem); + else + APInt::sdivrem(lhs, rhs, div, rem); + if (rem == 0) { + auto res = B.CreateMul(mul->getOperand(1 - i0), + ConstantInt::get(cur->getType(), div), + "mdiv." 
+ cur->getName(), + mul->hasNoUnsignedWrap(), + mul->hasNoSignedWrap()); + push(mul); + replaceAndErase(cur, res); + return "IMulDivConst"; + } + } + + // mul (mul a, const1), (mul b, const2) -> mul (mul a, b), (const1, const2) + if (cur->getOpcode() == Instruction::FMul) + if (cur->isFast()) + if (auto mul1 = dyn_cast(cur->getOperand(0))) + if (mul1->getOpcode() == Instruction::FMul && mul1->isFast()) + if (auto mul2 = dyn_cast(cur->getOperand(1))) + if (mul2->getOpcode() == Instruction::FMul && mul2->isFast()) { + for (auto i1 = 0; i1 < 2; i1++) + for (auto i2 = 0; i2 < 2; i2++) + if (isa(mul1->getOperand(i1))) + if (isa(mul2->getOperand(i2))) { + + auto n0 = pushcse( + B.CreateFMulFMF(mul1->getOperand(1 - i1), + mul2->getOperand(1 - i2), cur)); + auto n1 = pushcse(B.CreateFMulFMF( + mul1->getOperand(i1), mul2->getOperand(i2), cur)); + auto n2 = pushcse(B.CreateFMulFMF(n0, n1, cur)); + push(mul1); + push(mul2); + replaceAndErase(cur, n2); + return "MulMulConstConst"; + } + } + + // mul (mul a, const1), const2 -> mul a, (mul const1, const2) + if ((cur->getOpcode() == Instruction::FMul && cur->isFast()) || + cur->getOpcode() == Instruction::Mul) + for (auto i1 = 0; i1 < 2; i1++) + if (auto mul1 = dyn_cast(cur->getOperand(i1))) + if (((mul1->getOpcode() == Instruction::FMul && mul1->isFast())) || + mul1->getOpcode() == Instruction::FMul) + if (auto const2 = dyn_cast(cur->getOperand(1 - i1))) + for (auto i2 = 0; i2 < 2; i2++) + if (auto const1 = dyn_cast(mul1->getOperand(i2))) { + Value *res = nullptr; + if (cur->getOpcode() == Instruction::FMul) { + auto const3 = pushcse(B.CreateFMulFMF(const1, const2, mul1)); + res = pushcse( + B.CreateFMulFMF(mul1->getOperand(1 - i2), const3, cur)); + } else { + auto const3 = pushcse(B.CreateMul(const1, const2)); + res = pushcse(B.CreateMul(mul1->getOperand(1 - i2), const3)); + } + push(mul1); + replaceAndErase(cur, res); + return "MulConstConst"; + } + + if (auto fcmp = dyn_cast(cur)) { + if (fcmp->getPredicate() == 
FCmpInst::FCMP_OEQ) { + for (int i = 0; i < 2; i++) + if (auto C = dyn_cast(fcmp->getOperand(i))) { + if (C->isZero()) { + if (auto fmul = dyn_cast(fcmp->getOperand(1 - i))) { + // (a*b) == 0 -> (a == 0) || (b == 0) + if (fmul->getOpcode() == Instruction::FMul) { + auto ncmp1 = pushcse( + B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C)); + auto ncmp2 = pushcse( + B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(1), C)); + auto ori = pushcse(B.CreateOr(ncmp1, ncmp2)); + replaceAndErase(cur, ori); + return "CmpFMulSplit"; + } + // (a/b) == 0 -> (a == 0) + if (fmul->getOpcode() == Instruction::FDiv) { + auto ncmp1 = pushcse( + B.CreateFCmp(fcmp->getPredicate(), fmul->getOperand(0), C)); + replaceAndErase(cur, ncmp1); + return "CmpFDivSplit"; + } + // (a - b) ?= 0 -> a ?= b + if (fmul->getOpcode() == Instruction::FSub) { + auto ncmp1 = pushcse(B.CreateFCmp(fcmp->getPredicate(), + fmul->getOperand(0), + fmul->getOperand(1))); + replaceAndErase(cur, ncmp1); + return "CmpFSubSplit"; + } + } + if (auto cast = dyn_cast(fcmp->getOperand(1 - i))) { + auto ncmp1 = pushcse(B.CreateICmp( + ICmpInst::ICMP_EQ, cast->getOperand(0), + ConstantInt::get(cast->getOperand(0)->getType(), 0))); + replaceAndErase(cur, ncmp1); + return "SFCmpToICmp"; + } + if (auto cast = dyn_cast(fcmp->getOperand(1 - i))) { + auto ncmp1 = pushcse(B.CreateICmp( + ICmpInst::ICMP_EQ, cast->getOperand(0), + ConstantInt::get(cast->getOperand(0)->getType(), 0))); + replaceAndErase(cur, ncmp1); + return "UFCmpToICmp"; + } + if (auto SI = dyn_cast(fcmp->getOperand(1 - i))) { + auto res = pushcse( + B.CreateSelect(SI->getCondition(), + pushcse(B.CreateCmp(fcmp->getPredicate(), C, + SI->getTrueValue())), + pushcse(B.CreateCmp(fcmp->getPredicate(), C, + SI->getFalseValue())))); + replaceAndErase(cur, res); + return "FCmpSelect"; + } + } + } + } + } + if (auto fcmp = dyn_cast(cur)) { + if (fcmp->getPredicate() == CmpInst::ICMP_EQ || + fcmp->getPredicate() == CmpInst::ICMP_NE || + fcmp->getPredicate() == 
CmpInst::FCMP_OEQ || + fcmp->getPredicate() == CmpInst::FCMP_ONE) { + + // a + c ?= a -> c ?= 0 , if fast + for (int i = 0; i < 2; i++) + if (auto inst = dyn_cast(fcmp->getOperand(i))) + if (inst->getOpcode() == Instruction::FAdd && inst->isFast()) + for (int i2 = 0; i2 < 2; i2++) + if (inst->getOperand(i2) == fcmp->getOperand(1 - i)) { + auto res = pushcse( + B.CreateCmp(fcmp->getPredicate(), inst->getOperand(1 - i2), + ConstantFP::get(inst->getType(), 0))); + replaceAndErase(cur, res); + return "CmpFAddSame"; + } + + // a == b -> a & b | !a & !b + // a != b -> a & !b | !a & b + if (fcmp->getOperand(0)->getType()->isIntegerTy(1)) { + auto a = fcmp->getOperand(0); + auto b = fcmp->getOperand(1); + if (fcmp->getPredicate() == CmpInst::ICMP_EQ) { + auto res = pushcse( + B.CreateOr(pushcse(B.CreateAnd(a, b)), + pushcse(B.CreateAnd(pushcse(B.CreateNot(a)), + pushcse(B.CreateNot(b)))))); + replaceAndErase(cur, res); + return "CmpI1EQ"; + } + if (fcmp->getPredicate() == CmpInst::ICMP_NE) { + auto res = pushcse( + B.CreateOr(pushcse(B.CreateAnd(pushcse(B.CreateNot(a)), b)), + pushcse(B.CreateAnd(a, pushcse(B.CreateNot(b)))))); + replaceAndErase(cur, res); + return "CmpI1NE"; } } @@ -3645,15 +4603,43 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (cur->getOpcode() == Instruction::And) { auto lhs = replace(cur->getOperand(0), cur->getOperand(1), ConstantInt::getTrue(cur->getContext())); + if (lhs != cur->getOperand(0)) { + auto res = pushcse( + B.CreateAnd(lhs, cur->getOperand(1), "postand." + cur->getName())); + replaceAndErase(cur, res); + return "AndReplaceLHS"; + } auto rhs = replace(cur->getOperand(1), cur->getOperand(0), ConstantInt::getTrue(cur->getContext())); - if (lhs != cur->getOperand(0) || rhs != cur->getOperand(1)) { - auto res = pushcse(B.CreateAnd(lhs, rhs, "postand." + cur->getName())); + if (rhs != cur->getOperand(1)) { + auto res = pushcse( + B.CreateAnd(cur->getOperand(0), rhs, "postand." 
+ cur->getName())); replaceAndErase(cur, res); - return "AndReplace"; + return "AndReplaceRHS"; } } + // or a, b -> or a b[with a false] + if (cur->getOpcode() == Instruction::Or) { + auto lhs = replace(cur->getOperand(0), cur->getOperand(1), + ConstantInt::getFalse(cur->getContext())); + if (lhs != cur->getOperand(0)) { + auto res = pushcse( + B.CreateOr(lhs, cur->getOperand(1), "postor." + cur->getName())); + replaceAndErase(cur, res); + return "OrReplaceLHS"; + } + auto rhs = replace(cur->getOperand(1), cur->getOperand(0), + ConstantInt::getFalse(cur->getContext())); + if (rhs != cur->getOperand(1)) { + auto res = pushcse( + B.CreateOr(cur->getOperand(0), rhs, "postor." + cur->getName())); + replaceAndErase(cur, res); + return "OrReplaceRHS"; + } + } + + /* // and (i == c), (i != d) -> and (i == c) && (c != d) if (cur->getOpcode() == Instruction::And) { auto lhs = replace(cur->getOperand(0), cur->getOperand(1), @@ -3663,9 +4649,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (lhs != cur->getOperand(0) || rhs != cur->getOperand(1)) { auto res = pushcse(B.CreateAnd(lhs, rhs, "postand." 
+ cur->getName())); replaceAndErase(cur, res); - return "AndReplace"; + return "AndReplace2"; } } + */ // and a, (or q, (not a)) -> and a q if (cur->getOpcode() == Instruction::And) { @@ -3799,6 +4786,24 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } + // (a | b) == 0 -> a == 0 & b == 0 + if (auto icmp = dyn_cast(cur)) + if (icmp->getPredicate() == ICmpInst::ICMP_EQ && + cur->getType()->isIntegerTy(1)) + for (int i = 0; i < 2; i++) + if (auto C = dyn_cast(icmp->getOperand(i))) + if (C->isZero()) + if (auto z = dyn_cast(icmp->getOperand(1 - i))) + if (z->getOpcode() == BinaryOperator::Or) { + auto a0 = pushcse(B.CreateICmpEQ(z->getOperand(0), C)); + auto b0 = pushcse(B.CreateICmpEQ(z->getOperand(1), C)); + auto res = pushcse(B.CreateAnd(a0, b0)); + push(z); + push(icmp); + replaceAndErase(cur, res); + return "OrEQZero"; + } + // add (mul a b), (mul c, b) -> mul (add a, c), b if (cur->getOpcode() == Instruction::Sub || cur->getOpcode() == Instruction::Add) { @@ -3858,13 +4863,19 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (cur->getOpcode() == Instruction::FSub || cur->getOpcode() == Instruction::FAdd || cur->getOpcode() == Instruction::FMul || - cur->getOpcode() == Instruction::FNeg) { + cur->getOpcode() == Instruction::FNeg || + (isSum(cur) && callOperands(cast(cur)).size() == 2)) { + auto opcode = cur->getOpcode(); + if (isSum(cur)) + opcode = Instruction::FAdd; auto Ty = B.getInt64Ty(); SmallVector temporaries; SmallVector precasts; Value *lhs = nullptr; - Value *prelhs = cur->getOperand(0); + Value *prelhs = (cur->getOpcode() == Instruction::FNeg) + ? ConstantFP::get(cur->getType(), 0.0) + : cur->getOperand(0); Value *prerhs = (cur->getOpcode() == Instruction::FNeg) ? 
cur->getOperand(0) : cur->getOperand(1); @@ -3914,7 +4925,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (ext->getOperand(0)->getType() == Ty) lhs = ext->getOperand(0); else if (ity->getBitWidth() < Ty->getBitWidth()) { - lhs = B.CreateZExt(ext->getOperand(0), Ty); + if (ext->getOpcode() == Instruction::UIToFP) + lhs = B.CreateZExt(ext->getOperand(0), Ty); + else + lhs = B.CreateSExt(ext->getOperand(0), Ty); if (auto I = dyn_cast(lhs)) temporaries.push_back(I); } @@ -3930,10 +4944,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, &isExact); if (isExact || C->isZero()) { rhs = ConstantInt::get(Ty, Tmp); - switch (cur->getOpcode()) { + switch (opcode) { case Instruction::FAdd: - minval *= Tmp; - maxval *= Tmp; + minval += Tmp; + maxval += Tmp; break; case Instruction::FSub: case Instruction::FNeg: @@ -3941,8 +4955,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, maxval -= Tmp; break; case Instruction::FMul: - minval -= Tmp; - maxval -= Tmp; + minval *= Tmp; + maxval *= Tmp; break; default: llvm_unreachable("Illegal opcode"); @@ -3982,7 +4996,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, APInt::getSignedMinValue(ity->getBitWidth()).zextOrTrunc(64); } } - switch (cur->getOpcode()) { + switch (opcode) { case Instruction::FAdd: minval += rhsMin; maxval += rhsMax; @@ -4009,7 +5023,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (ext->getOperand(0)->getType() == Ty) rhs = ext->getOperand(0); else if (ity->getBitWidth() < Ty->getBitWidth()) { - rhs = B.CreateZExt(ext->getOperand(0), Ty); + if (ext->getOpcode() == Instruction::UIToFP) + rhs = B.CreateZExt(ext->getOperand(0), Ty); + else + rhs = B.CreateSExt(ext->getOperand(0), Ty); if (auto I = dyn_cast(rhs)) temporaries.push_back(I); } @@ -4018,7 +5035,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (lhs && rhs) { Value *res = nullptr; - switch (cur->getOpcode()) { + 
switch (opcode) { case Instruction::FAdd: res = B.CreateAdd(lhs, rhs, "", false, true); break; @@ -4032,6 +5049,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, default: llvm_unreachable("Illegal opcode"); } + res = pushcse(res); for (auto I : temporaries) push(I); for (auto I : precasts) @@ -4057,6 +5075,109 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } + // select(cond, const1, b) ?= const2 -> select(cond, const1 ?= const2, b ?= + // const2) + if (auto fcmp = dyn_cast(cur)) + for (int i = 0; i < 2; i++) + if (auto const2 = dyn_cast(fcmp->getOperand(i))) + if (auto sel = dyn_cast(fcmp->getOperand(1 - i))) + if (isa(sel->getTrueValue()) || + isa(sel->getFalseValue())) { + auto tval = pushcse(B.CreateFCmp(fcmp->getPredicate(), + sel->getTrueValue(), const2)); + auto fval = pushcse(B.CreateFCmp(fcmp->getPredicate(), + sel->getFalseValue(), const2)); + auto res = pushcse(B.CreateSelect(sel->getCondition(), tval, fval)); + replaceAndErase(cur, res); + return "FCmpSelectConst"; + } + + // mul (mul a, const), b:not_sparse_or_const -> mul (mul a, b), const + // note we avoid the case where b = (mul a, const) since otherwise + // we create an infinite recursion + // and also we make sure b isn't sparse, since sparse is the first + // precedence for pushing, then constant, then others + if (cur->getOpcode() == Instruction::FMul) + if (cur->isFast() && cur->getOperand(0) != cur->getOperand(1)) + for (auto ic = 0; ic < 2; ic++) + if (auto mul = dyn_cast(cur->getOperand(ic))) + if (mul->getOpcode() == Instruction::FMul && mul->isFast()) { + auto b = cur->getOperand(1 - ic); + if (!isa(b) && !directlySparse(b)) { + + for (int i = 0; i < 2; i++) + if (auto C = dyn_cast(mul->getOperand(i))) { + auto n0 = + pushcse(B.CreateFMulFMF(mul->getOperand(1 - i), b, mul)); + auto n1 = pushcse(B.CreateFMulFMF(n0, C, cur)); + push(mul); + + replaceAndErase(cur, n1); + return "MulMulConst"; + } + } + } + + // (mul c, a) +/- (mul c, b) -> mul 
c, (a +/- b) + if (cur->getOpcode() == Instruction::FAdd || + cur->getOpcode() == Instruction::FSub) { + if (auto mul1 = dyn_cast(cur->getOperand(0))) { + if (mul1->getOpcode() == Instruction::FMul && mul1->isFast()) { + if (auto mul2 = dyn_cast(cur->getOperand(1))) { + if (mul2->getOpcode() == Instruction::FMul && mul2->isFast()) { + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 2; j++) { + if (mul1->getOperand(i) == mul2->getOperand(j)) { + auto c = mul1->getOperand(i); + auto a = mul1->getOperand(1 - i); + auto b = mul2->getOperand(1 - j); + Value *intermediate = nullptr; + + if (cur->getOpcode() == Instruction::FAdd) + intermediate = pushcse(B.CreateFAddFMF(a, b, cur)); + else + intermediate = pushcse(B.CreateFSubFMF(a, b, cur)); + + auto res = pushcse(B.CreateFMulFMF(c, intermediate, cur)); + push(mul1); + push(mul2); + replaceAndErase(cur, res); + return "FAddMulConstMulConst"; + } + } + } + } + } + } + } + } + + // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))), (sitofp + // b) + + if (cur->getOpcode() == Instruction::FMul && cur->isFast()) { + for (int i = 0; i < 2; i++) + if (auto z = dyn_cast(cur->getOperand(i))) + if (isa(z) || isa(z)) + if (auto imul = dyn_cast(z->getOperand(0))) + if (imul->getOpcode() == Instruction::Mul) + for (int j = 0; j < 2; j++) + if (auto c = dyn_cast(imul->getOperand(j))) { + auto b = imul->getOperand(1 - j); + auto a = cur->getOperand(1 - i); + + auto c_fp = pushcse(B.CreateSIToFP(c, cur->getType())); + auto b_fp = pushcse(B.CreateSIToFP(b, cur->getType())); + auto n_mul = pushcse(B.CreateFMulFMF(a, c_fp, cur)); + auto res = pushcse( + B.CreateFMulFMF(n_mul, b_fp, cur, cur->getName())); + push(imul); + push(z); + replaceAndErase(cur, res); + return "FMulIMulConstRotate"; + } + } + if (cur->getOpcode() == Instruction::FDiv) { Value *prelhs = cur->getOperand(0); Value *b = cur->getOperand(1); @@ -4085,20 +5206,31 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, return 
"FDivSIToFPProp"; } } - // fdiv (select c, 0, a), b -> select c, 0 (fdiv a, b) if (auto SI = dyn_cast(prelhs)) { auto tvalC = dyn_cast(SI->getTrueValue()); auto fvalC = dyn_cast(SI->getFalseValue()); if ((tvalC && tvalC->isZero()) || (fvalC && fvalC->isZero())) { push(SI); - auto ntval = (tvalC && tvalC->isZero()) - ? tvalC - : pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur)); + auto ntval = + (tvalC && tvalC->isZero()) + ? tvalC + : pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur, + "sfdiv2_t." + cur->getName())); auto nfval = (fvalC && fvalC->isZero()) ? fvalC - : pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur)); + : pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur, + "sfdiv2_f." + cur->getName())); + + // Work around bad fdivfmf, fixed in LLVM 16+ + // https://github.com/llvm/llvm-project/commit/4f3b1c6dd6ef6c7b5bb79f058e3b7ba4bcdf4566 +#if LLVM_VERSION_MAJOR < 16 + for (auto v : {ntval, nfval}) + if (auto I = dyn_cast(v)) + I->setFastMathFlags(cur->getFastMathFlags()); +#endif + auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval, "sfdiv2." 
+ cur->getName())); @@ -4108,12 +5240,107 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } + // div (mul a:not_sparse, b:is_sparse), c -> mul (div, a, c), b:is_sparse + if (cur->getOpcode() == Instruction::FDiv) { + auto c = cur->getOperand(1); + if (auto z = dyn_cast(cur->getOperand(0))) { + if (z->getOpcode() == Instruction::FMul) { + for (int i = 0; i < 2; i++) { + + Value *a = z->getOperand(i); + Value *b = z->getOperand(1 - i); + if (directlySparse(a)) + continue; + if (!directlySparse(b)) + continue; + + Value *inner_fdiv = pushcse(B.CreateFDivFMF(a, c, cur)); + Value *outer_fmul = pushcse(B.CreateFMulFMF(inner_fdiv, b, z)); + push(z); + replaceAndErase(cur, outer_fmul); + return "FDivFMulSparseProp"; + } + } + } + } + if (cur->getOpcode() == Instruction::FMul) for (int i = 0; i < 2; i++) { Value *prelhs = cur->getOperand(i); Value *b = cur->getOperand(1 - i); + // fmul (fmul x:constant, y):z, b:constant . + if (isa(b)) + if (auto z = dyn_cast(prelhs)) { + if (z->getOpcode() == Instruction::FMul) { + for (int j = 0; j < 2; j++) { + auto x = z->getOperand(i); + if (!isa(x)) + continue; + auto y = z->getOperand(1 - i); + Value *inner_fmul = pushcse(B.CreateFMulFMF(x, b, cur)); + Value *outer_fmul = pushcse(B.CreateFMulFMF(inner_fmul, y, z)); + push(z); + replaceAndErase(cur, outer_fmul); + return "FMulFMulConstantReorder"; + } + } + } + + auto integralFloat = [](Value *z) { + if (auto C = dyn_cast(z)) { + APSInt Tmp(64); + bool isExact = false; + C->getValue().convertToInteger(Tmp, llvm::RoundingMode::TowardZero, + &isExact); + if (isExact || C->isZero()) { + return true; + } + } + return false; + }; + + // fmul (fmul x:sparse, y):z, b + // 1) If x and y are both sparse, do nothing and let the inner fmul be + // simplified into a single sparse instruction. Thus, we may assume + // y is not sparse. + // 2) if b is sparse, swap it to be fmul (fmul x, b), y so the inner + // sparsity can be simplified. 
+ // 3) otherwise b is not sparse and we should push the sparsity to + // be the outermost value + if (auto z = dyn_cast(prelhs)) { + if (z->getOpcode() == Instruction::FMul) { + for (int j = 0; j < 2; j++) { + auto x = z->getOperand(j); + if (!directlySparse(x)) + continue; + auto y = z->getOperand(1 - j); + if (directlySparse(y)) + continue; + + if (directlySparse(b) || integralFloat(b)) { + push(z); + Value *inner_fmul = pushcse( + B.CreateFMulFMF(x, b, cur, "mulisr." + cur->getName())); + Value *outer_fmul = pushcse( + B.CreateFMulFMF(inner_fmul, y, z, "mulisr." + z->getName())); + replaceAndErase(cur, outer_fmul); + return "FMulFMulSparseReorder"; + } else { + push(z); + Value *inner_fmul = pushcse( + B.CreateFMulFMF(y, b, cur, "mulisp." + cur->getName())); + Value *outer_fmul = pushcse( + B.CreateFMulFMF(inner_fmul, x, z, "mulisp." + z->getName())); + replaceAndErase(cur, outer_fmul); + return "FMulFMulSparsePush"; + } + } + } + } + + /* auto contains = [](MDNode *MD, Value *V) { if (!MD) return false; @@ -4126,7 +5353,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, }; // fmul (sitofp a), b -> select (a == 0), 0 [noprop fmul ( sitofp a), b] - if (!contains(hasMetadata(cur, "enzyme_fmulnoprop"), prelhs)) + if (true || !contains(hasMetadata(cur, "enzyme_fmulnoprop"), prelhs)) if (auto ext = dyn_cast(prelhs)) { if (ext->getOpcode() == Instruction::UIToFP || ext->getOpcode() == Instruction::SIToFP) { @@ -4158,6 +5385,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, return "FMulSIToFPProp"; } } + */ // fmul (select c, 0, a), b -> select c, 0 (fmul a, b) if (auto SI = dyn_cast(prelhs)) { @@ -4826,6 +6054,107 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, class Constraints; raw_ostream &operator<<(raw_ostream &os, const Constraints &c); + +struct ConstraintComparator { + bool operator()(std::shared_ptr lhs, + std::shared_ptr rhs) const; +}; + +struct ConstraintContext { + ScalarEvolution &SE; + 
const Loop *loopToSolve; + const SmallVectorImpl &Assumptions; + DominatorTree &DT; + using InnerTy = std::shared_ptr; + using SetTy = std::set; + SetTy seen; + ConstraintContext(ScalarEvolution &SE, const Loop *loopToSolve, + const SmallVectorImpl &Assumptions, + DominatorTree &DT) + : SE(SE), loopToSolve(loopToSolve), Assumptions(Assumptions), DT(DT) { + assert(loopToSolve); + } + ConstraintContext(const ConstraintContext &) = delete; + ConstraintContext(const ConstraintContext &ctx, InnerTy lhs) + : SE(ctx.SE), loopToSolve(ctx.loopToSolve), Assumptions(ctx.Assumptions), + DT(ctx.DT), seen(ctx.seen) { + seen.insert(lhs); + } + ConstraintContext(const ConstraintContext &ctx, InnerTy lhs, InnerTy rhs) + : SE(ctx.SE), loopToSolve(ctx.loopToSolve), Assumptions(ctx.Assumptions), + DT(ctx.DT), seen(ctx.seen) { + seen.insert(lhs); + seen.insert(rhs); + } + bool contains(InnerTy x) const { return seen.count(x) != 0; } +}; + +bool cannotDependOnLoopIV(const SCEV *S, const Loop *L) { + assert(L); + if (isa(S)) + return true; + if (auto M = dyn_cast(S)) { + for (auto o : M->operands()) + if (!cannotDependOnLoopIV(o, L)) + return false; + return true; + } + if (auto M = dyn_cast(S)) { + for (auto o : M->operands()) + if (!cannotDependOnLoopIV(o, L)) + return false; + return true; + } + if (auto M = dyn_cast(S)) { + for (auto o : {M->getLHS(), M->getRHS()}) + if (!cannotDependOnLoopIV(o, L)) + return false; + return true; + } + if (auto UV = dyn_cast(S)) { + auto U = UV->getValue(); + if (isa(U)) + return true; + if (isa(U)) + return true; + auto I = cast(U); + return !L->contains(I->getParent()); + } + if (auto addrec = dyn_cast(S)) { + if (addrec->getLoop() == L) + return false; + for (auto o : addrec->operands()) + if (!cannotDependOnLoopIV(o, L)) + return false; + return true; + } + llvm::errs() << " cannot tell if depends on loop iv: " << *S << "\n"; + return false; +} + +const SCEV *evaluateAtLoopIter(const SCEV *V, ScalarEvolution &SE, + const Loop *find, const SCEV 
*replace) { + assert(find); + if (cannotDependOnLoopIV(V, find)) + return V; + if (auto addrec = dyn_cast(V)) { + if (addrec->getLoop() == find) { + auto V2 = addrec->evaluateAtIteration(replace, SE); + return evaluateAtLoopIter(V2, SE, find, replace); + } + } + if (auto div = dyn_cast(V)) { + auto lhs = evaluateAtLoopIter(div->getLHS(), SE, find, replace); + if (!lhs) + return nullptr; + auto rhs = evaluateAtLoopIter(div->getRHS(), SE, find, replace); + if (!rhs) + return nullptr; + return SE.getUDivExpr(lhs, rhs); + } + return nullptr; +} + class Constraints : public std::enable_shared_from_this { public: const enum class Type { @@ -4838,41 +6167,6 @@ class Constraints : public std::enable_shared_from_this { using InnerTy = std::shared_ptr; - struct ConstraintComparator { - bool operator()(InnerTy lhs, InnerTy rhs) const { - if (lhs->ty < rhs->ty) - return true; - else if (lhs->ty > rhs->ty) - return false; - - if (lhs->node < rhs->node) - return true; - else if (lhs->node > rhs->node) - return false; - - if (lhs->isEqual < rhs->isEqual) - return true; - else if (lhs->isEqual > rhs->isEqual) - return false; - - return lhs->values < rhs->values; - /* - auto lhss = lhs->values.size(); - auto rhss = rhs->values.size(); - if (lhss < rhss) - return true; - else if (lhss > rhss) - return false; - for (int i=0; ioperator()(lhs->values[i], rhs->values[i])) - return true; - if (this->operator()(rhs->values[i], lhs->values[i])) - return false; - } - return false; - */ - } - }; using SetTy = std::set; const SetTy values; @@ -4880,28 +6174,57 @@ class Constraints : public std::enable_shared_from_this { const SCEV *const node; // whether equal to the node, or not equal to the node bool isEqual; + // the loop of the iv comparing against. 
+ const llvm::Loop *const Loop; // using SetTy = SmallVector; // using SetTy = SetVector, // std::set>; - Constraints() : ty(Type::Union), values(), node(nullptr), isEqual(false) {} + Constraints() + : ty(Type::Union), values(), node(nullptr), isEqual(false), + Loop(nullptr) {} + +private: + Constraints(const SCEV *v, bool isEqual, const llvm::Loop *Loop, bool) + : ty(Type::Compare), values(), node(v), isEqual(isEqual), Loop(Loop) {} + +public: + static InnerTy make_compare(const SCEV *v, bool isEqual, + const llvm::Loop *Loop, + const ConstraintContext &ctx); - Constraints(const SCEV *v, bool isEqual) - : ty(Type::Compare), values(), node(v), isEqual(isEqual) {} - Constraints(Type t) : ty(t), values(), node(nullptr), isEqual(false) { + Constraints(Type t) + : ty(t), values(), node(nullptr), isEqual(false), Loop(nullptr) { assert(t == Type::All || t == Type::None); } - Constraints(Type t, const SetTy &c) - : ty(t), values(c), node(nullptr), isEqual(false) { + Constraints(Type t, const SetTy &c, bool check = true) + : ty(t), values(c), node(nullptr), isEqual(false), Loop(nullptr) { assert(t != Type::All); assert(t != Type::None); assert(c.size() != 0); assert(c.size() != 1); - /* - for (int i=0; i tmp(c.begin(), c.end()); + for (unsigned i = 0; i < tmp.size(); i++) + for (unsigned j = 0; j < i; j++) + assert(*tmp[i] != *tmp[j]); + if (t == Type::Intersect) { + for (auto &v : c) { + assert(v->ty != Type::Intersect); + } + } + if (t == Type::Union) { + for (auto &v : c) { + assert(v->ty != Type::Union); + } + } + if (t == Type::Intersect && check) { + for (unsigned i = 0; i < tmp.size(); i++) + if (tmp[i]->ty == Type::Compare && tmp[i]->isEqual && tmp[i]->Loop) + for (unsigned j = 0; j < tmp.size(); j++) + if (tmp[j]->ty == Type::Compare) + if (auto s = dyn_cast(tmp[j]->node)) + assert(s->getLoop() != tmp[i]->Loop); + } } bool operator==(const Constraints &rhs) const { @@ -4914,6 +6237,9 @@ class Constraints : public std::enable_shared_from_this { if (isEqual != 
rhs.isEqual) { return false; } + if (Loop != rhs.Loop) { + return false; + } if (values.size() != rhs.values.size()) { return false; } @@ -4949,6 +6275,12 @@ return true; if (isEqual > rhs.isEqual) { return false; } + if (Loop < rhs.Loop) { + return true; + } + if (Loop > rhs.Loop) { + return false; + } if (values.size() < rhs.values.size()) { return true; } @@ -4961,17 +6293,12 @@ return true; if (*std::get<0>(pair) > *std::get<1>(pair)) return false; } - return true; - //) && !(rhs.values < values) - /* -for (size_t i=0; i::getHashValue(node) + isEqual; + res = llvm::detail::combineHashValue(res, (unsigned)(size_t)Loop); for (auto v : values) res = llvm::detail::combineHashValue(res, v->hash()); return res; @@ -4989,6 +6316,11 @@ return true; bool isAll() const { return ty == Type::All; } static void insert(SetTy &set, InnerTy ty) { set.insert(ty); + int mcount = 0; + for (auto &v : set) + if (*v == *ty) + mcount++; + assert(mcount == 1); /* for (auto &v : set) if (*v == *ty) @@ -4996,6 +6328,13 @@ return true; set.push_back(ty); */ } + static SetTy intersect(const SetTy &lhs, const SetTy &rhs) { + SetTy res; + for (auto &v : lhs) + if (rhs.count(v)) + res.insert(v); + return res; + } static void set_subtract(SetTy &set, const SetTy &rhs) { for (auto &v : rhs) if (set.count(v)) @@ -5010,125 +6349,51 @@ return true; } */ } - InnerTy notB() const { + __attribute__((noinline)) void dump() const { llvm::errs() << *this << "\n"; } + InnerTy notB(const ConstraintContext &ctx) const { switch (ty) { case Type::None: return Constraints::all(); case Type::All: return Constraints::none(); case Type::Compare: - return std::make_shared(node, !isEqual); + return make_compare(node, !isEqual, Loop, ctx); case Type::Union: { - // not of or's is and or not's + // not of or's is and of not's SetTy next; for (const auto &v : values) - insert(next, v->notB()); + insert(next, v->notB(ctx)); if (next.size() == 1) llvm::errs() << " uold : " << *this << "\n"; return 
std::make_shared(Type::Intersect, next); } case Type::Intersect: { - // not of and's is or or not's + // not of and's is or of not's SetTy next; for (const auto &v : values) - insert(next, v->notB()); + insert(next, v->notB(ctx)); if (next.size() == 1) llvm::errs() << " old : " << *this << "\n"; return std::make_shared(Type::Union, next); } - } - return Constraints::none(); - } - InnerTy orB(InnerTy rhs, ScalarEvolution &SE) const { - return notB()->andB(rhs->notB(), SE)->notB(); - /* - if (*rhs == *this) return shared_from_this(); - if (rhs->isNone()) return shared_from_this(); - if (rhs->isAll()) return rhs; - if (isNone()) return rhs; - if (isAll()) return shared_from_this(); - - if (ty == Type::Compare && rhs->ty == Type::Compare) { - auto sub = SE.getMinusSCEV(node, rhs->node); - if (auto cst = dyn_cast(sub)) { - // the two solves are equivalent to each other - if (cst->getAPInt().isZero()) { - // iv = a or iv = a - // also iv != a or iv != a - if (isEqual == rhs->isEqual) - return shared_from_this(); - else { - // iv = a or iv != a - return Constraints::all(); - } - } else { - // the two solves are guaranteed to be distinct - // iv == 0 or iv == 1 - if (isEqual && rhs->isEqual) { - SetTy vals; - insert(vals, shared_from_this()); - insert(vals, rhs); - return std::make_shared(Type::Union, - vals); } else if (!isEqual && !rhs->isEqual) { - // iv != 0 or iv != 1 - return Constraints::all(); - } else if (!isEqual) { - assert(rhs->isEqual); - // iv != 0 or iv == 1 - return shared_from_this(); - } else { - assert(isEqual); - assert(!rhs->isEqual); - return rhs; - } - } - } - SetTy vals; - insert(vals, shared_from_this()); - insert(vals, rhs); - return std::make_shared(Type::Union, vals); - } - if (ty == Type::Union && rhs->ty == Type::Union) { - SetTy vals = values; - for (const auto &v : rhs->values) - insert(vals, v); - return std::make_shared(Type::Union, vals); - } - if (rhs->ty == Type::Union) { - SetTy vals = rhs->values; - insert(vals, shared_from_this()); - 
return std::make_shared(Type::Union, vals); - } - if (ty == Type::Union) { - SetTy vals = values; - insert(vals, rhs); - return std::make_shared(Type::Union, vals); - } - // (m and a and b and d) or (m and a and c and e ...) -> m and a and - ( (b and d) or (c and e)) if (ty == Type::Intersect && rhs->ty == - Type::Intersect) { SetTy intersection = values; - set_subtract(intersection, rhs->values); - if (intersection.size() != 0) { - InnerTy other_lhs = remove(intersection); - InnerTy other_rhs = rhs->remove(intersection); - InnerTy remainder; - if (intersection.size() == 1) - remainder = intersection[0]; - else { - remainder = - std::make_shared(Type::Intersect, intersection); - } - return remainder->andB(other_lhs->orB(other_rhs, SE), SE); - } - SetTy vals; - insert(vals, shared_from_this()); - insert(vals, rhs); - return std::make_shared(Type::Union, vals); - } - llvm_unreachable("Illegal predicate state"); - */ + } + return Constraints::none(); + } + InnerTy orB(InnerTy rhs, const ConstraintContext &ctx) const { + auto notLHS = notB(ctx); + if (!notLHS) + return nullptr; + auto notRHS = rhs->notB(ctx); + if (!notRHS) + return nullptr; + auto andV = notLHS->andB(notRHS, ctx); + if (!andV) + return nullptr; + auto res = andV->notB(ctx); + return res; } - InnerTy andB(const InnerTy rhs, ScalarEvolution &SE) const { + InnerTy andB(const InnerTy rhs, const ConstraintContext &ctx) const { + assert(rhs); if (*rhs == *this) return shared_from_this(); if (rhs->isNone()) @@ -5140,69 +6405,206 @@ return true; if (isAll()) return rhs; + // llvm::errs() << " anding: " << *this << " with " << *rhs << "\n"; + if (ctx.contains(shared_from_this()) || ctx.contains(rhs)) { + // llvm::errs() << " %%% stopping recursion\n"; + return nullptr; + } if (ty == Type::Compare && rhs->ty == Type::Compare) { - auto sub = SE.getMinusSCEV(node, rhs->node); - if (auto cst = dyn_cast(sub)) { - // the two solves are equivalent to each other - if (cst->getValue()->isZero()) { - // iv = a and iv = 
a - // also iv != a and iv != a - if (isEqual == rhs->isEqual) - return shared_from_this(); - else { - // iv = a and iv != a - return Constraints::none(); - } - } else { - // the two solves are guaranteed to be distinct - // iv == 0 and iv == 1 - if (isEqual && rhs->isEqual) { - return Constraints::none(); - - } else if (!isEqual && !rhs->isEqual) { - // iv != 0 and iv != 1 - SetTy vals; - insert(vals, shared_from_this()); - insert(vals, rhs); - return std::make_shared(Type::Intersect, vals); - } else if (!isEqual) { - assert(rhs->isEqual); - // iv != 0 and iv == 1 - return rhs; - ; + auto sub = ctx.SE.getMinusSCEV(node, rhs->node); + if (Loop == rhs->Loop) { + // llvm::errs() << " + sameloop, sub=" << *sub << "\n"; + if (auto cst = dyn_cast(sub)) { + // the two solves are equivalent to each other + if (cst->getValue()->isZero()) { + // iv = a and iv = a + // also iv != a and iv != a + if (isEqual == rhs->isEqual) + return shared_from_this(); + else { + // iv = a and iv != a + return Constraints::none(); + } } else { - // iv == 0 and iv != 1 - assert(isEqual); - assert(!rhs->isEqual); - return shared_from_this(); + // the two solves are guaranteed to be distinct + // iv == 0 and iv == 1 + if (isEqual && rhs->isEqual) { + return Constraints::none(); + + } else if (!isEqual && !rhs->isEqual) { + // iv != 0 and iv != 1 + SetTy vals; + insert(vals, shared_from_this()); + insert(vals, rhs); + return std::make_shared(Type::Intersect, vals); + } else if (!isEqual) { + assert(rhs->isEqual); + // iv != 0 and iv == 1 + return rhs; + ; + } else { + // iv == 0 and iv != 1 + assert(isEqual); + assert(!rhs->isEqual); + return shared_from_this(); + } + } + } else if (isEqual || rhs->isEqual) { + // llvm::errs() << " + botheq\n"; + // eq(i, a) & i ?= b -> eq(i, a) & (a ?= b) + if (auto addrec = dyn_cast(sub)) { + // we want a ?= b, but we can only represent loopvar ?= something + // so suppose a-b is of the form X + Y * lv then a-b ?= 0 is + // X + Y * lv ?= 0 -> lv ?= - X / Y + 
if (addrec->isAffine()) { + auto X = addrec->getStart(); + auto Y = addrec->getStepRecurrence(ctx.SE); + auto MinusX = X; + + if (isa(Y) && + cast(Y)->getAPInt().isNegative()) + Y = ctx.SE.getNegativeSCEV(Y); + else + MinusX = ctx.SE.getNegativeSCEV(X); + + auto div = ctx.SE.getUDivExpr(MinusX, Y); + auto div_e = ctx.SE.getUDivExactExpr(MinusX, Y); + // in case of inexact division, check that these exactly equal for + // replacement + + if (div == div_e) { + if (isEqual) { + auto res = make_compare(div, /*isEqual*/ rhs->isEqual, + addrec->getLoop(), ctx); + // llvm::errs() << " simplified rhs to: " << *res << "\n"; + return andB(res, ctx); + } else { + assert(rhs->isEqual); + auto res = make_compare(div, /*isEqual*/ isEqual, + addrec->getLoop(), ctx); + // llvm::errs() << " simplified lhs to: " << *res << "\n"; + return rhs->andB(res, ctx); + } + } + } + } + if (isEqual && rhs->Loop && + cannotDependOnLoopIV(sub, ctx.loopToSolve)) { + auto res = make_compare(sub, /*isEqual*/ rhs->isEqual, + /*loop*/ nullptr, ctx); + // llvm::errs() << " simplified(noloop) rhs from " << *rhs + // << " to: " << *res << "\n"; + return andB(res, ctx); + } + if (rhs->isEqual && Loop && + cannotDependOnLoopIV(sub, ctx.loopToSolve)) { + auto res = + make_compare(sub, /*isEqual*/ isEqual, /*loop*/ nullptr, ctx); + // llvm::errs() << " simplified(noloop) lhs from " << *rhs + // << " to: " << *res << "\n"; + return rhs->andB(res, ctx); + } + + llvm::errs() << " warning: potential but unhandled simplification of " + "equalities: " + << *this << " and " << *rhs << " sub: " << *sub << "\n"; + } + } + + if (isEqual) { + if (Loop) + if (auto rep = evaluateAtLoopIter(rhs->node, ctx.SE, Loop, node)) + if (rep != rhs->node) { + auto newrhs = make_compare(rep, rhs->isEqual, rhs->Loop, ctx); + return andB(newrhs, ctx); + } + + // not loop -> node == 0 + if (!Loop) { + for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node), + ctx.SE.getMinusSCEV(rhs->node, node)}) { + // llvm::errs() << " maybe replace 
lhs: " << *this << " rhs: " << + // *rhs + // << " sub1: " << *sub1 << "\n"; + auto newrhs = make_compare(sub1, rhs->isEqual, rhs->Loop, ctx); + if (*newrhs == *this) + return shared_from_this(); + if (!isa(rhs->node) && isa(sub1)) { + return andB(newrhs, ctx); + } + } + } + } + + if (rhs->isEqual) { + if (rhs->Loop) + if (auto rep = evaluateAtLoopIter(node, ctx.SE, rhs->Loop, rhs->node)) + if (rep != node) { + auto newlhs = make_compare(rep, isEqual, Loop, ctx); + return newlhs->andB(rhs, ctx); + } + + // not loop -> node == 0 + if (!rhs->Loop) { + for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node), + ctx.SE.getMinusSCEV(rhs->node, node)}) { + // llvm::errs() << " maybe replace lhs2: " << *this << " rhs: " << + // *rhs + // << " sub1: " << *sub1 << "\n"; + auto newlhs = make_compare(sub1, isEqual, Loop, ctx); + if (*newlhs == *this) + return shared_from_this(); + if (!isa(node) && isa(sub1)) { + return newlhs->andB(rhs, ctx); + } } } } + + if (!Loop && !rhs->Loop && isEqual == rhs->isEqual) { + if (node == ctx.SE.getNegativeSCEV(rhs->node)) + return shared_from_this(); + } + SetTy vals; insert(vals, shared_from_this()); insert(vals, rhs); - return std::make_shared(Type::Intersect, vals); + if (vals.size() == 1) { + llvm::errs() << "this: " << *this << " rhs: " << *rhs << "\n"; + } + auto res = std::make_shared(Type::Intersect, vals); + // llvm::errs() << " naiive comp merge: " << *res << "\n"; + return res; } if (ty == Type::Intersect && rhs->ty == Type::Intersect) { - SetTy vals = values; - for (const auto &v : rhs->values) - insert(vals, v); - return std::make_shared(Type::Intersect, vals); + auto tmp = shared_from_this(); + for (const auto &v : rhs->values) { + auto tmp2 = tmp->andB(v, ctx); + if (!tmp2) + return nullptr; + tmp = std::move(tmp2); + } + return tmp; } if (ty == Type::Intersect && rhs->ty == Type::Compare) { SetTy vals; // Force internal merging to do individual compares bool foldedIn = false; - for (const auto &v : values) { + for (auto en : 
llvm::enumerate(values)) { + auto i = en.index(); + auto v = en.value(); assert(v->ty != Type::Intersect); assert(v->ty != Type::All); assert(v->ty != Type::None); + assert(v->ty == Type::Compare || v->ty == Type::Union); if (foldedIn) { insert(vals, v); continue; } // this is either a compare or a union - auto tmp = rhs->andB(v, SE); + auto tmp = rhs->andB(v, ctx); + if (!tmp) + return nullptr; switch (tmp->ty) { case Type::Union: case Type::All: @@ -5215,62 +6617,137 @@ return true; break; // if intersected, these two were not foldable, try folding into later case Type::Intersect: { + SetTy fuse; + insert(fuse, rhs); + insert(fuse, v); + + Constraints trivialFuse(Type::Intersect, fuse, false); + + // If this is not just making an intersect of the two operands, + // remerge. + if (trivialFuse != *tmp) { + InnerTy newlhs = Constraints::all(); + bool legal = true; + for (auto en2 : llvm::enumerate(values)) { + auto i2 = en2.index(); + auto v2 = en2.value(); + if (i2 == i) + continue; + auto newlhs2 = newlhs->andB(v2, ctx); + if (!newlhs2) { + legal = false; + break; + } + newlhs = std::move(newlhs2); + } + if (legal) { + return newlhs->andB(tmp, ctx); + } + } insert(vals, v); } } } - if (!foldedIn) + if (!foldedIn) { insert(vals, rhs); - return std::make_shared(Type::Intersect, vals); + return std::make_shared(Type::Intersect, vals); + } else { + auto cur = Constraints::all(); + for (auto &iv : vals) { + auto cur2 = cur->andB(iv, ctx); + if (!cur2) + return nullptr; + cur = std::move(cur2); + } + return cur; + } } - if (ty == Type::Intersect && rhs->ty == Type::Union) { + if ((ty == Type::Intersect || ty == Type::Compare) && + rhs->ty == Type::Union) { SetTy unionVals = rhs->values; bool changed = false; - for (const auto &iv : values) { + SetTy ivVals; + if (ty == Type::Intersect) + ivVals = values; + else + insert(ivVals, shared_from_this()); + + ConstraintContext ctxd(ctx, shared_from_this(), rhs); + + for (const auto &iv : ivVals) { SetTy nextunionVals; + bool 
midchanged = false; for (auto &uv : unionVals) { - - auto tmp = iv->andB(uv, SE); + auto tmp = iv->andB(uv, ctxd); + if (!tmp) { + midchanged = false; + nextunionVals = unionVals; + break; + } switch (tmp->ty) { + case Type::None: case Type::Compare: case Type::Union: - case Type::None: insert(nextunionVals, tmp); - changed = true; + changed |= tmp != uv; break; - case Type::Intersect: + case Type::Intersect: { + SetTy fuse; + if (uv->ty == Type::Intersect) + fuse = uv->values; + else { + assert(uv->ty == Type::Compare); + insert(fuse, uv); + } + insert(fuse, iv); + + Constraints trivialFuse(Type::Intersect, fuse, false); + if (trivialFuse != *tmp) { + insert(nextunionVals, tmp); + midchanged = true; + break; + } + insert(nextunionVals, uv); break; + } case Type::All: llvm_unreachable("Impossible"); } } - unionVals = nextunionVals; + if (midchanged) { + unionVals = nextunionVals; + changed = true; + } } - auto cur = rhs; if (changed) { - cur = Constraints::all(); - for (auto uv : unionVals) - cur = cur->orB(uv, SE); + auto cur = Constraints::none(); + for (auto uv : unionVals) { + cur = cur->orB(uv, ctxd); + if (!cur) + break; + } - if (cur->ty != Type::Union) - return andB(cur, SE); + if (*cur != *rhs) + return andB(cur, ctx); } - SetTy vals = values; - insert(vals, cur); + SetTy vals = ivVals; + insert(vals, rhs); return std::make_shared(Type::Intersect, vals); } // Handled above via symmetry - if (rhs->ty == Type::Intersect) { - return rhs->andB(shared_from_this(), SE); + if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) { + return rhs->andB(shared_from_this(), ctx); } // (m or a or b or d) and (m or a or c or e ...) 
-> m or a or ( (b or d) and // (c or e)) if (ty == Type::Union && rhs->ty == Type::Union) { - SetTy intersection = values; - set_subtract(intersection, rhs->values); + if (*this == *rhs->notB(ctx)) { + return Constraints::none(); + } + SetTy intersection = intersect(values, rhs->values); if (intersection.size() != 0) { InnerTy other_lhs = remove(intersection); InnerTy other_rhs = rhs->remove(intersection); @@ -5280,22 +6757,84 @@ return true; else { remainder = std::make_shared(Type::Union, intersection); } - return remainder->orB(other_lhs->andB(other_rhs, SE), SE); + return remainder->orB(other_lhs->andB(other_rhs, ctx), ctx); + } + + bool changed = false; + SetTy lhsVals = values; + SetTy rhsVals = rhs->values; + + ConstraintContext ctxd(ctx, shared_from_this(), rhs); + + SetTy distributedVals; + for (const auto &l1 : lhsVals) { + bool subchanged = false; + SetTy subDistributedVals; + for (auto &r1 : rhsVals) { + auto tmp = l1->andB(r1, ctxd); + if (!tmp) { + subchanged = false; + break; + } + + if (l1->ty == Type::Intersect || r1->ty == Type::Intersect) { + subchanged = true; + insert(subDistributedVals, tmp); + } else { + + SetTy fuse; + insert(fuse, l1); + insert(fuse, r1); + assert(fuse.size() == 2); + Constraints trivialFuse(Type::Intersect, fuse); + if ((trivialFuse != *tmp) || distributedVals.count(tmp)) { + subchanged = true; + } + insert(subDistributedVals, tmp); + } + } + if (subchanged) { + for (auto sub : subDistributedVals) + insert(distributedVals, sub); + changed = true; + } else { + auto midand = l1->andB(rhs, ctxd); + if (!midand) { + changed = false; + break; + } + insert(distributedVals, midand); + } + } + + if (changed) { + auto cur = Constraints::none(); + bool legal = true; + for (auto &uv : distributedVals) { + auto cur2 = cur->orB(uv, ctxd); + if (!cur2) { + legal = false; + break; + } + cur = std::move(cur2); + } + if (legal) { + return cur; + } } + SetTy vals; insert(vals, shared_from_this()); insert(vals, rhs); auto res = 
std::make_shared(Type::Intersect, vals); - llvm::errs() << " res: " << *res << "lhs: " << *this << " rhs " << *rhs - << " eq " << (*this == *rhs) << "\n"; return res; } + llvm::errs() << " andB this: " << *this << " rhs: " << *rhs << "\n"; llvm_unreachable("Illegal predicate state"); } // what this would be like when removing the following list of constraints InnerTy remove(const SetTy &sub) const { - assert(ty == Type::Union); - assert(ty == Type::Intersect); + assert(ty == Type::Union || ty == Type::Intersect); SetTy res = values; set_subtract(res, sub); // res.set_subtract(sub); @@ -5310,32 +6849,20 @@ return true; return std::make_shared(ty, res); } } - SmallVector allSolutions(SCEVExpander &Exp, llvm::Type *T, - Instruction *IP) const; - bool canEvaluateSolutions() const { - switch (ty) { - case Type::None: - return true; - case Type::All: - return false; - case Type::Compare: - if (isEqual) { - return true; - } - return false; - case Type::Union: { - for (auto v : values) - if (!v->canEvaluateSolutions()) - return false; - return true; - } - case Type::Intersect: - return false; - } - return false; - } + SmallVector, 1> + allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP, + const ConstraintContext &ctx, IRBuilder<> &B) const; }; +void dump(const Constraints &c) { c.dump(); } +void dump(std::shared_ptr c) { c->dump(); } + +bool ConstraintComparator::operator()( + std::shared_ptr lhs, + std::shared_ptr rhs) const { + return *lhs < *rhs; +} + raw_ostream &operator<<(raw_ostream &os, const Constraints &c) { switch (c.ty) { case Constraints::Type::All: @@ -5357,46 +6884,294 @@ raw_ostream &operator<<(raw_ostream &os, const Constraints &c) { return os; } case Constraints::Type::Compare: { - if (c.isEqual) { - os << "(eq " << *c.node << ")"; - } else { - os << "(ne " << *c.node << ")"; - } - return os; + if (c.isEqual) + os << "(eq "; + else + os << "(ne "; + os << *c.node << ", L="; + if (c.Loop) + os << c.Loop->getHeader()->getName(); + else + os << 
"nullptr"; + return os << ")"; } } return os; } -SmallVector Constraints::allSolutions(SCEVExpander &Exp, - llvm::Type *T, - Instruction *IP) const { +SmallVector, 1> +Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP, + const ConstraintContext &ctx, IRBuilder<> &B) const { switch (ty) { case Type::None: return {}; case Type::All: llvm::errs() << *this << "\n"; llvm_unreachable("All not handled"); - case Type::Compare: + case Type::Compare: { + Value *cond = ConstantInt::getTrue(T->getContext()); + if (ctx.loopToSolve != Loop) { + assert(ctx.loopToSolve); + Value *ivVal = Exp.expandCodeFor(node, T, IP); + Value *iv = nullptr; + if (Loop) { + iv = Loop->getCanonicalInductionVariable(); + assert(iv); + } else { + iv = ConstantInt::getNullValue(ivVal->getType()); + } + if (isEqual) + cond = B.CreateICmpEQ(ivVal, iv); + else + cond = B.CreateICmpNE(ivVal, iv); + return {std::make_pair((Value *)nullptr, cond)}; + } if (isEqual) { - return {Exp.expandCodeFor(node, T, IP)}; + return {std::make_pair(Exp.expandCodeFor(node, T, IP), cond)}; } - llvm::errs() << *this << "\n"; - llvm_unreachable("Constraint ne not handled"); + EmitFailure("NoSparsification", IP->getDebugLoc(), IP, + "Negated solution not handled: ", *this); + assert(0); + return {}; + } case Type::Union: { - SmallVector vals; + SmallVector, 1> vals; for (auto v : values) - for (auto sol : v->allSolutions(Exp, T, IP)) + for (auto sol : v->allSolutions(Exp, T, IP, ctx, B)) vals.push_back(sol); return vals; } - case Type::Intersect: - llvm::errs() << *this << "\n"; - llvm_unreachable("Intersect not handled"); + case Type::Intersect: { + { + SmallVector vals(values.begin(), values.end()); + ssize_t unionidx = -1; + for (unsigned i = 0; i < vals.size(); i++) { + if (vals[i]->ty == Type::Union) { + unionidx = i; + bool allne = true; + for (auto &v : vals[i]->values) { + if (v->ty != Type::Compare || v->isEqual) { + allne = false; + break; + } + } + if (allne) + break; + } + } + if (unionidx 
!= -1) { + auto others = Constraints::all(); + for (unsigned j = 0; j < vals.size(); j++) + if (unionidx != j) + others = others->andB(vals[j], ctx); + SmallVector, 1> resvals; + for (auto &v : vals[unionidx]->values) { + auto tmp = v->andB(others, ctx); + for (const auto &sol : tmp->allSolutions(Exp, T, IP, ctx, B)) + resvals.push_back(sol); + } + return resvals; + } + } + Value *solVal = nullptr; + Value *cond = ConstantInt::getTrue(T->getContext()); + for (auto v : values) { + auto sols = v->allSolutions(Exp, T, IP, ctx, B); + if (sols.size() != 1) { + llvm::errs() << *this << "\n"; + for (auto s : sols) + if (s.first) + llvm::errs() << " + sol: " << *s.first << " " << *s.second << "\n"; + else + llvm::errs() << " + sol: " << s.first << " " << *s.second << "\n"; + llvm::errs() << " v: " << *v << " this: " << *this << "\n"; + llvm_unreachable("Intersect not handled (solsize>1)"); + } + auto sol = sols[0]; + if (sol.first) { + if (solVal != nullptr) { + llvm::errs() << *this << "\n"; + llvm::errs() << " prevsolVal: " << *solVal << "\n"; + llvm_unreachable("Intersect not handled (prevsolval)"); + } + assert(solVal == nullptr); + solVal = sol.first; + } + cond = B.CreateAnd(cond, sol.second); + } + return {std::make_pair(solVal, cond)}; + } } return {}; } +std::shared_ptr +getSparseConditions(bool &legal, Value *val, + std::shared_ptr defaultFloat, + Instruction *scope, const ConstraintContext &ctx) { + if (auto I = dyn_cast(val)) { + // Binary `and` is a bit-wise `umin`. 
+ if (I->getOpcode() == Instruction::And) { + auto lhs = getSparseConditions(legal, I->getOperand(0), + Constraints::all(), I, ctx); + auto rhs = getSparseConditions(legal, I->getOperand(1), + Constraints::all(), I, ctx); + auto res = lhs->andB(rhs, ctx); + assert(res); + assert(ctx.seen.size() == 0); + llvm::errs() << " getSparse(and, " << *I << "), lhs(" << *I->getOperand(0) + << ") = " << *lhs << "\n"; + llvm::errs() << " getSparse(and, " << *I << "), rhs(" << *I->getOperand(1) + << ") = " << *rhs << "\n"; + llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n"; + return res; + } + + // Binary `or` is a bit-wise `umax`. + if (I->getOpcode() == Instruction::Or) { + auto lhs = getSparseConditions(legal, I->getOperand(0), + Constraints::none(), I, ctx); + auto rhs = getSparseConditions(legal, I->getOperand(1), + Constraints::none(), I, ctx); + auto res = lhs->orB(rhs, ctx); + llvm::errs() << " getSparse(or, " << *I << "), lhs(" << *I->getOperand(0) + << ") = " << *lhs << "\n"; + llvm::errs() << " getSparse(or, " << *I << "), rhs(" << *I->getOperand(1) + << ") = " << *rhs << "\n"; + llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n"; + return res; + } + + if (I->getOpcode() == Instruction::Xor) { + for (int i = 0; i < 2; i++) { + if (auto C = dyn_cast(I->getOperand(i))) + if (C->isOne()) { + auto pres = + getSparseConditions(legal, I->getOperand(1 - i), + defaultFloat->notB(ctx), scope, ctx); + auto res = pres->notB(ctx); + llvm::errs() << " getSparse(not, " << *I << "), prev (" + << *I->getOperand(0) << ") = " << *pres << "\n"; + llvm::errs() << " getSparse(not, " << *I << ") = " << *res << "\n"; + return res; + } + } + } + + if (auto icmp = dyn_cast(I)) { + auto L = ctx.loopToSolve; + auto lhs = ctx.SE.getSCEVAtScope(icmp->getOperand(0), L); + auto rhs = ctx.SE.getSCEVAtScope(icmp->getOperand(1), L); + llvm::errs() << " lhs: " << *lhs << "\n"; + llvm::errs() << " rhs: " << *rhs << "\n"; + + auto sub1 = ctx.SE.getMinusSCEV(lhs, rhs); + + 
if (icmp->getPredicate() == ICmpInst::ICMP_EQ || + icmp->getPredicate() == ICmpInst::ICMP_NE) { + if (auto add = dyn_cast(sub1)) { + if (add->isAffine()) { + // 0 === A + B * inc -> -A / B = inc + auto A = add->getStart(); + if (auto B = + dyn_cast(add->getStepRecurrence(ctx.SE))) { + + auto MA = A; + if (B->getAPInt().isNegative()) + B = cast(ctx.SE.getNegativeSCEV(B)); + else + MA = ctx.SE.getNegativeSCEV(A); + auto div = ctx.SE.getUDivExpr(MA, B); + auto div_e = ctx.SE.getUDivExactExpr(MA, B); + if (div == div_e) { + auto res = Constraints::make_compare( + div, icmp->getPredicate() == ICmpInst::ICMP_EQ, + add->getLoop(), ctx); + llvm::errs() + << " getSparse(icmp, " << *I << ") = " << *res << "\n"; + return res; + } + } + } + } + if (cannotDependOnLoopIV(sub1, ctx.loopToSolve)) { + auto res = Constraints::make_compare( + sub1, icmp->getPredicate() == ICmpInst::ICMP_EQ, nullptr, ctx); + llvm::errs() << " getSparse(icmp_noloop, " << *I << ") = " << *res + << "\n"; + return res; + } + } + if (scope) + EmitFailure("NoSparsification", I->getDebugLoc(), I, + " No sparsification: not sparse solvable(icmp): ", *sub1); + legal = false; + return defaultFloat; + } + + // cmp x, 1.0 -> false/true + if (auto fcmp = dyn_cast(I)) { + auto res = defaultFloat; + llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n"; + return res; + + if (fcmp->getPredicate() == CmpInst::FCMP_OEQ || + fcmp->getPredicate() == CmpInst::FCMP_UEQ) { + return Constraints::all(); + } else if (fcmp->getPredicate() == CmpInst::FCMP_ONE || + fcmp->getPredicate() == CmpInst::FCMP_UNE) { + return Constraints::none(); + } + } + } + + if (scope) { + EmitFailure("NoSparsification", scope->getDebugLoc(), scope, + " No sparsification: not sparse solvable: ", *val); + } + legal = false; + return defaultFloat; +} + +Constraints::InnerTy Constraints::make_compare(const SCEV *v, bool isEqual, + const llvm::Loop *Loop, + const ConstraintContext &ctx) { + if (!Loop) { + assert(!isa(v)); + SmallVector 
noassumption; + ConstraintContext ctx2(ctx.SE, ctx.loopToSolve, noassumption, ctx.DT); + for (auto I : ctx.Assumptions) { + bool legal = true; + auto parsedCond = getSparseConditions(legal, I->getOperand(0), + Constraints::none(), nullptr, ctx2); + bool dominates = ctx.DT.dominates(I, ctx.loopToSolve->getHeader()); + if (legal && dominates) { + if (parsedCond->ty == Type::Compare && !parsedCond->Loop) { + if (parsedCond->node == v || + parsedCond->node == ctx.SE.getNegativeSCEV(v)) { + InnerTy res; + if (parsedCond->isEqual == isEqual) + res = Constraints::all(); + else + res = Constraints::none(); + return res; + } + } + } + } + } + // cannot have negative loop canonical induction var + if (Loop) + if (auto C = dyn_cast(v)) + if (C->getAPInt().isNegative()) { + if (isEqual) + return Constraints::none(); + else + return Constraints::all(); + } + return InnerTy(new Constraints(v, isEqual, Loop, false)); +} + void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, SetVector &toDenseBlocks) { @@ -5405,7 +7180,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, auto &LI = FAM.getResult(F); auto &DL = F.getParent()->getDataLayout(); - llvm::SetVector Q; + QueueType Q(DT, LI); { llvm::SetVector todoBlocks; for (auto b : toDenseBlocks) { @@ -5417,31 +7192,35 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, } for (auto BB : todoBlocks) for (auto &I : *BB) - if (!I.getType()->isVoidTy()) + if (!I.getType()->isVoidTy()) { Q.insert(&I); + assert(Q.contains(&I)); + } } - // llvm::errs() << " pre fix inner: " << F << "\n"; + llvm::errs() << " pre fix inner: " << F << "\n"; // Full simplification while (!Q.empty()) { auto cur = Q.pop_back_val(); - SetVector prev(Q.begin(), Q.end()); - // llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n"; + std::set prev; + for (auto v : Q) + prev.insert(v); + llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n"; auto changed = 
fixSparse_inner(cur, F, Q, DT, SE, LI, DL); (void)changed; - /* if (changed) { - llvm::errs() << "changed: " << *changed << "\n"; + llvm::errs() << "changed: " << *changed << "\n"; - for (auto I : Q) - if (!prev.contains(I)) - llvm::errs() << " + " << *I << "\n"; - llvm::errs() << F << "\n\n"; + for (auto I : Q) + if (!prev.count(I)) + llvm::errs() << " + " << *I << "\n"; + llvm::errs() << F << "\n\n"; } - */ } + llvm::errs() << " post fix inner " << F << "\n"; + SmallVector, 1> sparseBlocks; bool legalToSparse = true; for (auto &B : F) @@ -5472,37 +7251,31 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, auto L = LI.getLoopFor(blk); if (!L) { legalToSparse = false; - llvm::errs() << " F: " << F << "\n"; - llvm::errs() - << " Sparsification disabled, could not find loop for : " - << *blk << "\n"; + EmitFailure("NoSparsification", br->getDebugLoc(), br, + "F: ", F, "\nCould not find loop for: ", *blk); break; } auto idx = L->getCanonicalInductionVariable(); if (!idx) { legalToSparse = false; - llvm::errs() << " F: " << F << "\n"; - llvm::errs() << " L: " << *L << "\n"; - llvm::errs() - << " Sparsification disabled, could not find loop index " - << *L->getHeader() << "\n"; + EmitFailure("NoSparsification", br->getDebugLoc(), br, + "F: ", F, "\nL:", *L, + "\nCould not find loop index: ", *L->getHeader()); break; } assert(idx); auto preheader = L->getLoopPreheader(); if (!preheader) { legalToSparse = false; - llvm::errs() << " F: " << F << "\n"; - llvm::errs() << " L: " << *L << "\n"; - llvm::errs() << " Sparsification disabled, could not find " - "loop preheader\n"; + EmitFailure("NoSparsification", br->getDebugLoc(), br, + "F: ", F, "\nL:", *L, + "\nCould not find loop preheader"); break; } sparseBlocks.emplace_back(blk, br); } if (!legalToSparse) { - llvm::errs() << " was found not legal to sparsify\n"; return; } @@ -5514,6 +7287,15 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, 1>>> forSparsification; 
+ SmallVector Assumptions; + for (auto &BB : F) + for (auto &I : BB) + if (auto II = dyn_cast(&I)) + if (II->getIntrinsicID() == Intrinsic::assume) + Assumptions.push_back(II); + + bool sawError = false; + for (auto [blk, br] : sparseBlocks) { auto L = LI.getLoopFor(blk); assert(L); @@ -5543,7 +7325,8 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, return true; if (isa(I)) return false; - llvm::errs() << " bad datadependent values check " << *val << "\n"; + EmitFailure("NoSparsification", I->getDebugLoc(), I, + " No sparsification: bad datadepedent values check: ", *I); legal = false; return true; }; @@ -5554,106 +7337,34 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, // result may become more true // - std::function( - Value *, std::shared_ptr)> - getSparseConditions = - [&](Value *val, std::shared_ptr defaultFloat) - -> std::shared_ptr { - if (auto I = dyn_cast(val)) { - // Binary `and` is a bit-wise `umin`. - if (I->getOpcode() == Instruction::And) { - auto res = getSparseConditions(I->getOperand(0), Constraints::all()) - ->andB(getSparseConditions(I->getOperand(1), - Constraints::all()), - SE); - return res; - } - - // Binary `or` is a bit-wise `umax`. 
- if (I->getOpcode() == Instruction::Or) { - auto res = getSparseConditions(I->getOperand(0), Constraints::none()) - ->orB(getSparseConditions(I->getOperand(1), - Constraints::none()), - SE); - return res; - } - - // cmp x, 1.0 -> false/true - if (auto icmp = dyn_cast(I)) { - auto lhs = SE.getSCEVAtScope(icmp->getOperand(0), L); - auto rhs = SE.getSCEVAtScope(icmp->getOperand(1), L); - - auto sub1 = SE.getMinusSCEV(lhs, rhs); - - if (auto add = dyn_cast(sub1)) { - if (add->getLoop() == L) { - if (add->isAffine()) { - // 0 === A + B * inc -> -A / B = inc - auto A = add->getStart(); - if (auto B = - dyn_cast(add->getStepRecurrence(SE))) { - - auto MA = A; - if (B->getAPInt().isNegative()) - B = cast(SE.getNegativeSCEV(B)); - else - SE.getNegativeSCEV(A); - auto div = SE.getUDivExpr(MA, B); - auto div_e = SE.getUDivExactExpr(MA, B); - if (div == div_e) { - auto res = std::make_shared( - div, icmp->getPredicate() == ICmpInst::ICMP_EQ); - return res; - } - } - } - } - llvm::errs() << " not sparse solvable " << *sub1 << "\n"; - legal = false; - } - } - - if (auto fcmp = dyn_cast(I)) { - auto res = defaultFloat; - return res; - - if (fcmp->getPredicate() == CmpInst::FCMP_OEQ || - fcmp->getPredicate() == CmpInst::FCMP_UEQ) { - return Constraints::all(); - } else if (fcmp->getPredicate() == CmpInst::FCMP_ONE || - fcmp->getPredicate() == CmpInst::FCMP_UNE) { - return Constraints::none(); - } - } - } - - llvm::errs() << " not sparse solvable " << *val << "\n"; - legal = false; - return Constraints::all(); - }; - auto solutions = getSparseConditions(cond, negated ? 
Constraints::all() - : Constraints::none()); - if (!negated) - solutions = solutions->notB(); - if (!legal) - continue; - - if (!solutions->canEvaluateSolutions()) { - llvm::errs() << "F: " << F << "\n"; - llvm::errs() << " L: " << *L << " blk: " << *blk << "\n"; - llvm::errs() << " cond: " << *cond << " negated: " << negated << "\n"; - - llvm::errs() << " not sparse solvable " << *solutions << "\n"; - legal = false; + // default is condition avoids sparse, negated is condition goes + // to sparse + Instruction *context = + isa(cond) ? cast(cond) : idx; + ConstraintContext cctx(SE, L, Assumptions, DT); + auto solutions = getSparseConditions( + legal, cond, negated ? Constraints::all() : Constraints::none(), + context, cctx); + // llvm::errs() << " solutions pre negate: " << *solutions << "\n"; + if (!negated) { + solutions = solutions->notB(cctx); + } + // llvm::errs() << " solutions post negate: " << *solutions << "\n"; + if (!legal) { + sawError = true; continue; } - if (solutions == Constraints::none()) { - llvm::errs() << "F: " << F << "\n"; - llvm::errs() << " L: " << *L << " blk: " << *blk << "\n"; - llvm::errs() << " cond: " << *cond << " negated: " << negated << "\n"; + + if (solutions == Constraints::none() || solutions == Constraints::all()) { + EmitFailure( + "NoSparsification", context->getDebugLoc(), context, "F: ", F, + "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, + "\n No sparsification: not sparse solvable(nosoltn): solutions:", + *solutions); + sawError = true; } - llvm::errs() << " found solvable solutions " << *solutions << "\n"; + // llvm::errs() << " found solvable solutions " << *solutions << "\n"; if (forSparsification.count(L) == 0) { { @@ -5696,8 +7407,23 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, forSparsification[L].second.emplace_back(blk, solutions); } + if (sawError) { + for (auto &pair : forSparsification) { + for (auto PN : {pair.second.first.first, pair.second.first.second}) { + 
PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + PN->eraseFromParent(); + } + } + if (llvm::verifyFunction(F, &llvm::errs())) { + llvm::errs() << F << "\n"; + report_fatal_error("function failed verification (6)"); + } + return; + } + if (forSparsification.size() == 0) { llvm::errs() << " found no stores for sparsification\n"; + assert(0); return; } @@ -5803,16 +7529,43 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, auto phterm = ph->getTerminator(); IRBuilder<> B(phterm); - SCEVExpander Exp(SE, DL, "sparseenzyme"); + + // We extracted code, reset analyses. + /* + DT.reset(); + SE.forgetAllLoops(); + */ for (auto en : llvm::enumerate(pair.second.second)) { auto off = en.index(); auto &solutions = en.value().second; - for (auto sol : solutions->allSolutions(Exp, idxty, phterm)) { + ConstraintContext ctx(SE, L, Assumptions, DT); +#if LLVM_VERSION_MAJOR >= 12 + SCEVExpander Exp(SE, DL, "sparseenzyme", /*preservelcssa*/ false); +#else + SCEVExpander Exp(SE, DL, "sparseenzyme"); +#endif + auto sols = solutions->allSolutions(Exp, idxty, phterm, ctx, B); + SmallVector prevSols; + for (auto [sol, condition] : sols) { SmallVector args(Inputs.begin(), Inputs.end()); args[off_idx] = ConstantInt::get(idxty, off); args[induct_idx] = sol; + for (auto sol2 : prevSols) + condition = B.CreateAnd(condition, B.CreateICmpNE(sol, sol2)); + prevSols.push_back(sol); + auto BB = B.GetInsertBlock(); + auto B2 = BB->splitBasicBlock(B.GetInsertPoint(), "poststore"); + B2->moveAfter(BB); + BB->getTerminator()->eraseFromParent(); + B.SetInsertPoint(BB); + auto callB = BasicBlock::Create(BB->getContext(), "tostore", + BB->getParent(), B2); + B.CreateCondBr(condition, callB, B2); + B.SetInsertPoint(callB); B.CreateCall(F2, args); + B.CreateBr(B2); + B.SetInsertPoint(B2->getTerminator()); } auto blk = en.value().first; auto term = blk->getTerminator(); @@ -5823,9 +7576,6 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, 
PN->eraseFromParent(); - // B.CreateCondBr(ConstantInt::getTrue(B.getContext()), L->getExitBlock(), - // L->getHeader()); phterm->eraseFromParent(); - for (auto &I : *L2Header) { auto boundsCheck = dyn_cast(&I); if (!boundsCheck) @@ -5858,6 +7608,48 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, break; } } + + for (auto &F2 : F.getParent()->functions()) { + if (startsWith(F2.getName(), "__enzyme_product")) { + SmallVector toErase; + for (llvm::User *I : F2.users()) { + auto CB = cast(I); + IRBuilder<> B(CB); + B.setFastMathFlags(getFast()); + Value *res = nullptr; + for (auto v : callOperands(CB)) { + if (res == nullptr) + res = v; + else { + res = B.CreateFMul(res, v); + } + } + CB->replaceAllUsesWith(res); + toErase.push_back(CB); + } + for (auto CB : toErase) + CB->eraseFromParent(); + } else if (startsWith(F2.getName(), "__enzyme_sum")) { + SmallVector toErase; + for (llvm::User *I : F2.users()) { + auto CB = cast(I); + IRBuilder<> B(CB); + B.setFastMathFlags(getFast()); + Value *res = nullptr; + for (auto v : callOperands(CB)) { + if (res == nullptr) + res = v; + else { + res = B.CreateFAdd(res, v); + } + } + CB->replaceAllUsesWith(res); + toErase.push_back(CB); + } + for (auto CB : toErase) + CB->eraseFromParent(); + } + } } bool LowerSparsification(llvm::Function *F, bool replaceAll) { @@ -6134,11 +7926,11 @@ bool LowerSparsification(llvm::Function *F, bool replaceAll) { PB.registerCGSCCAnalyses(CGAM); PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + SimplifyCFGPass(SimplifyCFGOptions()).run(*F, FAM); InstCombinePass().run(*F, FAM); // required to make preheaders LoopSimplifyPass().run(*F, FAM); fixSparseIndices(*F, FAM, toDenseBlocks); - llvm::errs() << " post ind: " << *F << "\n"; } return changed; } diff --git a/enzyme/test/Enzyme/ReverseMode/incloop.ll b/enzyme/test/Enzyme/ReverseMode/incloop.ll index 5e4354819015..8555e6364ed2 100644 --- a/enzyme/test/Enzyme/ReverseMode/incloop.ll +++ 
b/enzyme/test/Enzyme/ReverseMode/incloop.ll @@ -166,7 +166,7 @@ attributes #8 = { noreturn nounwind } ; CHECK: for.body: ; preds = %for.body, %for.body.preheader ; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.body ], [ 0, %for.body.preheader ] ; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 -; CHECK-NEXT: %6 = mul i64 %0, %iv +; CHECK-NEXT: %6 = mul {{(nuw )?}}{{(nsw )?}}i64 %0, %iv ; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %a, i64 %6 ; CHECK-NEXT: %7 = load double, double* %arrayidx, align 8, !tbaa !2 ; CHECK-NEXT: store double 0.000000e+00, double* %arrayidx, align 8, !tbaa !2 @@ -204,7 +204,7 @@ attributes #8 = { noreturn nounwind } ; CHECK: invertfor.body: ; preds = %incinvertfor.body, %invertfor.cond.cleanup.loopexit ; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %_unwrap4.zext, %invertfor.cond.cleanup.loopexit ], [ %[[i19:.+]], %incinvertfor.body ] ; CHECK-NEXT: %_unwrap5 = zext i32 %b to i64 -; CHECK-NEXT: %_unwrap6 = mul i64 %_unwrap5, %"iv'ac.0" +; CHECK-NEXT: %_unwrap6 = mul {{(nuw )?}}{{(nsw )?}}i64 %_unwrap5, %"iv'ac.0" ; CHECK-NEXT: %"arrayidx'ipg_unwrap" = getelementptr inbounds double, double* %"a'", i64 %_unwrap6 ; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg_unwrap", align 8 ; CHECK-NEXT: %12 = getelementptr inbounds double, double* %_cache.0, i64 %"iv'ac.0" diff --git a/enzyme/test/Enzyme/ReverseMode/insertsort.ll b/enzyme/test/Enzyme/ReverseMode/insertsort.ll index 9dbd3ad656a9..f6de9d7626c0 100644 --- a/enzyme/test/Enzyme/ReverseMode/insertsort.ll +++ b/enzyme/test/Enzyme/ReverseMode/insertsort.ll @@ -55,7 +55,7 @@ attributes #0 = { noinline norecurse nounwind uwtable } ; CHECK-NEXT: %iv = phi i64 [ %iv.next, %while.body ], [ 0, %land.rhs.preheader ] ; CHECK-DAG: %iv.next = add nuw nsw i64 %iv, 1 ; CHECK-DAG: %1 = mul {{(nuw )?}}{{(nsw )?}}i64 %iv, -1 -; CHECK-DAG: %[[a1:.+]] = add i64 %0, %1 +; CHECK-DAG: %[[a1:.+]] = add nsw i64 %0, %1 ; CHECK-DAG: %indvars.iv.next = add nsw i64 %[[a1]], -1 ; CHECK-NEXT: %arrayidx = 
getelementptr inbounds float, float* %array, i64 %indvars.iv.next ; CHECK-NEXT: %[[a2:.+]] = load float, float* %arrayidx, align 4 @@ -79,7 +79,7 @@ attributes #0 = { noinline norecurse nounwind uwtable } ; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %"iv'ac.1", %invertwhile.body ], [ %loopLimit_cache.0, %invertwhile.end.loopexit ] ; CHECK-NEXT: %_unwrap = sext i32 %i to i64 ; CHECK-NEXT: %_unwrap1 = mul {{(nuw )?}}{{(nsw )?}}i64 %"iv'ac.0", -1 -; CHECK-NEXT: %[[_unwrap2:.+]] = add i64 %_unwrap, %_unwrap1 +; CHECK-NEXT: %[[_unwrap2:.+]] = add nsw i64 %_unwrap, %_unwrap1 ; CHECK-NEXT: %[[arrayidx2ipg:.+]] = getelementptr inbounds float, float* %"array'", i64 %[[_unwrap2]] ; CHECK-NEXT: %[[a4:.+]] = load float, float* %[[arrayidx2ipg]], align 4 ; CHECK-NEXT: %[[a5:.+]] = fadd fast float %[[a4]], %"'de.0" diff --git a/enzyme/test/Enzyme/ReverseMode/mincachechain5.ll b/enzyme/test/Enzyme/ReverseMode/mincachechain5.ll index a4a916c17ac9..e1e4e28647d7 100644 --- a/enzyme/test/Enzyme/ReverseMode/mincachechain5.ll +++ b/enzyme/test/Enzyme/ReverseMode/mincachechain5.ll @@ -81,7 +81,7 @@ attributes #0 = { readnone speculatable } ; CHECK: for.body59: ; preds = %for.body59, %for.body ; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %for.body59 ], [ 0, %for.body ] ; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1 -; CHECK-NEXT: %[[a3:.+]] = mul i64 {{(%iv1, %step|%step, %iv1)}} +; CHECK-NEXT: %[[a3:.+]] = mul {{(nuw )?}}{{(nsw )?}}i64 {{(%iv1, %step|%step, %iv1)}} ; CHECK-NEXT: %add61 = add nuw nsw i64 %[[a3]], %step ; CHECK-NEXT: %_augmented = call fast double @augmented_inner(double* %x, double* %"x'") ; CHECK-NEXT: %[[a5:.+]] = mul nuw nsw i64 %iv, %[[a0]] @@ -117,7 +117,7 @@ attributes #0 = { readnone speculatable } ; CHECK: for.body59: ; preds = %for.body59, %for.body ; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %for.body59 ], [ 0, %for.body ] ; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1 -; CHECK-NEXT: %[[a4:.+]] = mul i64 {{(%iv1, %step|%step, %iv1)}} +; CHECK-NEXT: %[[a4:.+]] = mul
{{(nuw )?}}{{(nsw )?}}i64 {{(%iv1, %step|%step, %iv1)}} ; CHECK-NEXT: %add61 = add nuw nsw i64 %[[a4]], %step ; CHECK-NEXT: %cmp57 = icmp slt i64 %add61, 100 ; CHECK-NEXT: br i1 %cmp57, label %for.body59, label %for.cond.loopexit diff --git a/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll b/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll index b50739920d37..01a96bb9de09 100644 --- a/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll +++ b/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll @@ -145,7 +145,7 @@ attributes #1 = { argmemonly } ; CHECK: omp.inner.for.body: ; preds = %omp.precond.then, %omp.inner.for.body ; CHECK-NEXT: %iv = phi i64 [ %iv.next, %omp.inner.for.body ], [ 0, %omp.precond.then ] ; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 -; CHECK-NEXT: %[[true1iv:.+]] = add i64 +; CHECK-NEXT: %[[true1iv:.+]] = add nuw i64 ; %[[lb]], %iv ; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %tmp, i64 %[[true1iv]] ; CHECK-NEXT: %[[ld:.+]] = load double, double* %arrayidx, align 8, !tbaa !9 @@ -204,7 +204,7 @@ attributes #1 = { argmemonly } ; CHECK: invertomp.inner.for.body: ; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %[[_unwrap88:.+]], %invertomp.loop.exit.loopexit ], [ %[[i19:.+]], %incinvertomp.inner.for.body ] ; CHECK-NEXT: %_unwrap2 = load i64, i64* %.omp.lb_smpl -; CHECK-NEXT: %_unwrap3 = add i64 %_unwrap2, %"iv'ac.0" +; CHECK-NEXT: %_unwrap3 = add nuw i64 %_unwrap2, %"iv'ac.0" ; CHECK-NEXT: %"arrayidx'ipg_unwrap" = getelementptr inbounds double, double* %"tmp'", i64 %_unwrap3 ; CHECK-NEXT: %[[i8:.+]] = load double, double* %"arrayidx'ipg_unwrap", align 8 ; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg_unwrap", align 8 diff --git a/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll b/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll index 7b28c6b18998..27d44e9907e8 100644 --- a/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll +++ b/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll @@ -150,7 +150,7 @@ attributes #1 = { argmemonly } ; 
CHECK: invertomp.inner.for.body: ; preds = %invertomp.loop.exit.loopexit, %incinvertomp.inner.for.body ; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %[[_unwrap7:.+]], %invertomp.loop.exit.loopexit ], [ %9, %incinvertomp.inner.for.body ] ; CHECK-NEXT: %_unwrap2 = load i64, i64* %.omp.lb_smpl -; CHECK-NEXT: %_unwrap3 = add i64 {{((%_unwrap2, %"iv'ac.0")|%"iv'ac.0", %_unwrap2)}} +; CHECK-NEXT: %_unwrap3 = add nuw i64 {{((%_unwrap2, %"iv'ac.0")|%"iv'ac.0", %_unwrap2)}} ; CHECK-NEXT: %"outidx'ipg_unwrap" = getelementptr inbounds double, double* %"out'", i64 %_unwrap3 ; CHECK-NEXT: %1 = load double, double* %"outidx'ipg_unwrap", align 8 ; CHECK-NEXT: store double 0.000000e+00, double* %"outidx'ipg_unwrap", align 8 diff --git a/enzyme/test/Enzyme/ReverseMode/reorderrep.ll b/enzyme/test/Enzyme/ReverseMode/reorderrep.ll index 4d70316034f9..3a9aa4ccffb8 100644 --- a/enzyme/test/Enzyme/ReverseMode/reorderrep.ll +++ b/enzyme/test/Enzyme/ReverseMode/reorderrep.ll @@ -96,7 +96,7 @@ attributes #3 = { nounwind } ; CHECK: bb377: ; preds = %bb381, %bexit ; CHECK-NEXT: %iv = phi i64 [ %iv.next, %bb381 ], [ 0, %bexit ] ; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 -; CHECK-NEXT: %[[a4:.+]] = add i64 {{(%iv, %.020|%.020, %iv)}} +; CHECK-NEXT: %[[a4:.+]] = add nsw i64 {{(%iv, %.020|%.020, %iv)}} ; CHECK-NEXT: %tmp378 = icmp slt i64 %[[a4]], 10 ; CHECK-NEXT: br i1 %tmp378, label %bb381, label %bb450 diff --git a/enzyme/test/Integration/Sparse/ringspring.cpp b/enzyme/test/Integration/Sparse/ringspring.cpp index 7c020a23e7b5..0ecae72bef5e 100644 --- a/enzyme/test/Integration/Sparse/ringspring.cpp +++ b/enzyme/test/Integration/Sparse/ringspring.cpp @@ -1,11 +1,11 @@ // This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... 
-// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli 
- ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // everything should be always inline @@ -40,8 +40,8 @@ extern void __enzyme_fwddiff(void *, ...); extern double* __enzyme_todense(void *, ...) noexcept; -/// Compute energy -double f(size_t N, double* input) { +__attribute__((always_inline)) +static double f(size_t N, double* input) { double out = 0; // __builtin_assume(!((N-1) == 0)); for (size_t i=0; i hess_f(size_t N, double* input) { std::vector triplets; input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); + __builtin_assume(N != 1); for (size_t i=0; i +#include +#include + + +#include + +struct triple { + size_t row; + size_t col; + double val; + triple(triple&&) = default; + triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {} +}; + + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +extern void __enzyme_autodiff(void *, ...); + +extern void __enzyme_fwddiff(void *, ...); + +extern double* __enzyme_todense(void *, ...) 
noexcept; + + +__attribute__((always_inline)) +static double f(size_t N, double* pos) { + double e = 0.; + for (size_t i = 0; i < N; i ++) { + __builtin_assume(i < 1000000000); + double vx = pos[2 * i]; + double vy = pos[2 * i + 1]; + + double wx = pos[2 * i + 2]; + double wy = pos[2 * i + 3]; + e += (wx - vx) * (wx - vx) + (wy - vy) * (wy - vy); + } + return e; +} + + +__attribute__((always_inline)) +static void grad_f(size_t N, double* input, double* dinput) { + __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); +} + +__attribute__((always_inline)) +void ident_store(double , int64_t idx, size_t i) { + assert(0 && "should never load"); +} + +__attribute__((always_inline)) +double ident_load(size_t idx, size_t i, size_t N) { + idx /= sizeof(double); + return (double)(idx == i);// ? 1.0 : 0.0; +} + +__attribute__((enzyme_sparse_accumulate)) +void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { + printf("row=%d col=%d val=%f\n", row, col % N, val); + // assert(abs(val) > 0.00001); + triplets.emplace_back(row % N, col % N, val); +} + +__attribute__((always_inline)) +void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { + if (val == 0.0) return; + idx /= sizeof(double); + inner_store(i, idx, N, val, triplets); +} + +__attribute__((always_inline)) +double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { + return 0.0; +} + +__attribute__((always_inline)) +void never_store(double val, int64_t idx, double* input, size_t N) { + assert(0 && "this is a read only input, why are you storing here..."); +} + +__attribute__((always_inline)) +double mod_load(int64_t idx, double* input, size_t N) { + idx /= sizeof(double); + return input[idx % N]; +} + +__attribute__((noinline)) +std::vector hess_f(size_t N, double* input) { + std::vector triplets; + // input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); + __builtin_assume(N > 0); + for (size_t i=0; 
i hess_f2(size_t N, double* input) { + std::vector triplets; + input = + ((void*)mod_load, (void*)never_store, input, N); + hess_f(N, input); +} +*/ +// int argc, char** argv +int __attribute__((always_inline)) main() { + + + // if (argc != 2) { + // printf("Usage: %s \n", argv[0]); + // return 1; + // } + + // size_t N = atoi(argv[1]); + size_t N = 16; + + double x[2 * N + 2]; + for (int i = 0; i < N; ++i) { + double angle = 2 * M_PI * i / N; + x[2 * i] = cos(angle) ;//+ normal(generator); + x[2 * i + 1] = sin(angle) ;//+ normal(generator); + } + x[2 * N] = x[0]; + x[2 * N + 1] = x[1]; + auto res = hess_f(N, &x[0]); + + printf("%ld\n", res.size()); + + for (auto & tup : res) + printf("%ld, %ld = %f\n", tup.row, tup.col, tup.val); + + return 0; +} + diff --git a/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp b/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp index 248d2e7e7f89..72408a73df27 100644 --- a/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp +++ b/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp @@ -1,16 +1,14 @@ // This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... 
-// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then 
%clang++ -fno-exceptions -ffast-math -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // everything should be always inline -// XFAIL: * - #include #include #include @@ -28,8 +26,6 @@ struct triple { }; -size_t N; - extern int enzyme_dup; extern int enzyme_dupnoneed; extern int enzyme_out; @@ -42,10 +38,12 @@ extern void __enzyme_fwddiff(void *, ...); extern double* __enzyme_todense(void *, ...) noexcept; -/// Compute energy -double f(size_t N, double* pos) { +__attribute__((always_inline)) +static double f(size_t N, double* pos) { double e = 0.; - for (size_t i = 0; i < N; i += 3) { + __builtin_assume(N != 0); + for (size_t i = 0; i < N; i+=3) { + __builtin_assume(i < 1000000000); double vx = pos[i]; double vy = pos[i + 1]; double vz = pos[i + 2]; @@ -59,34 +57,34 @@ double f(size_t N, double* pos) { } -/// Perform dinput += gradient(f) -void grad_f(size_t N, double* input, double* dinput) { +__attribute__((always_inline)) +static void grad_f(size_t N, double* input, double* dinput) { __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } - -void ident_store(double , int64_t idx, size_t i) { +__attribute__((always_inline)) +static void ident_store(double , int64_t idx, size_t i) { assert(0 && "should never load"); } __attribute__((always_inline)) -double ident_load(int64_t idx, size_t i, size_t N) { +static double ident_load(int64_t idx, size_t i, size_t N) { idx /= sizeof(double); return (double)(idx == i);// ? 
1.0 : 0.0; } __attribute__((enzyme_sparse_accumulate)) -void inner_store(int64_t row, int64_t col, double val, std::vector &triplets) { +static void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { printf("row=%d col=%d val=%f\n", row, col % N, val); // assert(abs(val) > 0.00001); triplets.emplace_back(row % N, col % N, val); } __attribute__((always_inline)) -void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { +static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { if (val == 0.0) return; idx /= sizeof(double); - inner_store(i, idx, val, triplets); + inner_store(i, idx, N, val, triplets); } __attribute__((always_inline)) @@ -110,6 +108,7 @@ std::vector hess_f(size_t N, double* input) { std::vector triplets; // input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); + __builtin_assume(N < 10000000000); for (size_t i=0; i hess_f2(size_t N, double* input) { hess_f(N, input); } */ - -int __attribute__((always_inline)) main(int argc, char** argv) { +// int argc, char** argv +int __attribute__((always_inline)) main() { std::mt19937 generator(0); // Seed the random number generator std::uniform_real_distribution normal(0, 0.05); - if (argc != 2) { - printf("Usage: %s \n", argv[0]); - return 1; - } + // if (argc != 2) { + // printf("Usage: %s \n", argv[0]); + // return 1; + // } - size_t N = atoi(argv[1]); - // size_t N = 16; + // size_t N = atoi(argv[1]); + size_t N = 30; double x[3 * N + 3]; for (int i = 0; i < N; ++i) { diff --git a/enzyme/test/Integration/Sparse/ringspring3Dextenddatarestlengthone.cpp b/enzyme/test/Integration/Sparse/ringspring3Dextenddatarestlengthone.cpp index 20a08a7ddcf6..b5bb2f259135 100644 --- a/enzyme/test/Integration/Sparse/ringspring3Dextenddatarestlengthone.cpp +++ b/enzyme/test/Integration/Sparse/ringspring3Dextenddatarestlengthone.cpp @@ -1,15 +1,15 @@ // This should work on LLVM 7, 8, 9, 
however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... 
+// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // everything should be always inline -// XFAIL: * - #include #include @@ -28,8 +28,6 @@ struct triple { }; -size_t N; - extern int enzyme_dup; extern int enzyme_dupnoneed; extern int enzyme_out; @@ -42,8 +40,8 @@ extern void __enzyme_fwddiff(void *, ...); extern double* __enzyme_todense(void *, ...) 
noexcept; -/// Compute energy -double f(size_t N, double* pos) { +__attribute__((always_inline)) +static double f(size_t N, double* pos) { double e = 0.; for (size_t i = 0; i < N; i += 3) { double vx = pos[i]; @@ -61,48 +59,49 @@ double f(size_t N, double* pos) { } -/// Perform dinput += gradient(f) -void grad_f(size_t N, double* input, double* dinput) { +__attribute__((always_inline)) +static void grad_f(size_t N, double* input, double* dinput) { __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } -void ident_store(double , int64_t idx, size_t i) { +__attribute__((always_inline)) +static void ident_store(double , int64_t idx, size_t i) { assert(0 && "should never load"); } __attribute__((always_inline)) -double ident_load(int64_t idx, size_t i, size_t N) { +static double ident_load(int64_t idx, size_t i, size_t N) { idx /= sizeof(double); return (double)(idx == i);// ? 1.0 : 0.0; } __attribute__((enzyme_sparse_accumulate)) -void inner_store(int64_t row, int64_t col, double val, std::vector &triplets) { +static void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { printf("row=%d col=%d val=%f\n", row, col % N, val); // assert(abs(val) > 0.00001); triplets.emplace_back(row % N, col % N, val); } __attribute__((always_inline)) -void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { +static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { if (val == 0.0) return; idx /= sizeof(double); - inner_store(i, idx, val, triplets); + inner_store(i, idx, N, val, triplets); } __attribute__((always_inline)) -double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { +static double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { return 0.0; } __attribute__((always_inline)) -void never_store(double val, int64_t idx, double* input, size_t N) { +static void never_store(double val, int64_t idx, double* input, size_t N) 
{ assert(0 && "this is a read only input, why are you storing here..."); } __attribute__((always_inline)) -double mod_load(int64_t idx, double* input, size_t N) { +static double mod_load(int64_t idx, double* input, size_t N) { idx /= sizeof(double); return input[idx % N]; } @@ -136,25 +135,26 @@ std::vector hess_f2(size_t N, double* input) { } */ -int __attribute__((always_inline)) main(int argc, char** argv) { - std::mt19937 generator(0); // Seed the random number generator - std::uniform_real_distribution normal(0, 0.05); +// int argc, char** argv +int __attribute__((always_inline)) main() { + //std::mt19937 generator(0); // Seed the random number generator + //std::uniform_real_distribution normal(0, 0.05); - if (argc != 2) { - printf("Usage: %s \n", argv[0]); - return 1; - } + // if (argc != 2) { + // printf("Usage: %s \n", argv[0]); + // return 1; + // } - size_t N = atoi(argv[1]); - // size_t N = 16; + // size_t N = atoi(argv[1]); + size_t N = 30; double x[3 * N + 3]; for (int i = 0; i < N; ++i) { double angle = 2 * M_PI * i / N; - x[3 * i] = cos(angle) + normal(generator); - x[3 * i + 1] = sin(angle) + normal(generator); - x[3 * i + 2] = normal(generator); + x[3 * i] = cos(angle) ;//+ normal(generator); + x[3 * i + 1] = sin(angle) ;//+ normal(generator); + x[3 * i + 2] = 0;//normal(generator); } x[3 * N] = x[0]; x[3 * N + 1] = x[1]; diff --git a/enzyme/test/Integration/Sparse/ringspring3Drestlengthone.cpp b/enzyme/test/Integration/Sparse/ringspring3Drestlengthone.cpp index 0c9e864c9256..49896b2cbc62 100644 --- a/enzyme/test/Integration/Sparse/ringspring3Drestlengthone.cpp +++ b/enzyme/test/Integration/Sparse/ringspring3Drestlengthone.cpp @@ -1,15 +1,13 @@ // This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... 
-// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli 
- ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // everything should be always inline -// XFAIL: * - #include #include @@ -28,8 +26,6 @@ struct triple { }; -size_t N; - extern int enzyme_dup; extern int enzyme_dupnoneed; extern int enzyme_out; @@ -42,14 +38,14 @@ extern void __enzyme_fwddiff(void *, ...); extern double* __enzyme_todense(void *, ...) noexcept; -/// Compute energy -double f(size_t N, double* pos) { +__attribute__((always_inline)) +static double f(size_t N, double* pos) { double e = 0.; for (size_t i = 0; i < N; i += 3) { double vx = pos[i]; double vy = pos[i + 1]; double vz = pos[i + 2]; - + double wx = pos[i + 3]; double wy = pos[i + 4]; double wz = pos[i + 5]; @@ -61,48 +57,49 @@ double f(size_t N, double* pos) { } -/// Perform dinput += gradient(f) -void grad_f(size_t N, double* input, double* dinput) { +__attribute__((always_inline)) +static void grad_f(size_t N, double* input, double* dinput) { __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } -void ident_store(double , int64_t idx, size_t i) { +__attribute__((always_inline)) +static void ident_store(double , int64_t idx, size_t i) { assert(0 && "should never load"); } __attribute__((always_inline)) -double ident_load(int64_t idx, size_t i, size_t N) { +static double ident_load(int64_t idx, size_t i, size_t N) { idx /= sizeof(double); return (double)(idx == i);// ? 
1.0 : 0.0; } __attribute__((enzyme_sparse_accumulate)) -void inner_store(int64_t row, int64_t col, double val, std::vector &triplets) { +static void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { printf("row=%d col=%d val=%f\n", row, col % N, val); // assert(abs(val) > 0.00001); triplets.emplace_back(row % N, col % N, val); } __attribute__((always_inline)) -void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { +static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { if (val == 0.0) return; idx /= sizeof(double); - inner_store(i, idx, val, triplets); + inner_store(i, idx, N, val, triplets); } __attribute__((always_inline)) -double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { +static double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { return 0.0; } __attribute__((always_inline)) -void never_store(double val, int64_t idx, double* input, size_t N) { +static void never_store(double val, int64_t idx, double* input, size_t N) { assert(0 && "this is a read only input, why are you storing here..."); } __attribute__((always_inline)) -double mod_load(int64_t idx, double* input, size_t N) { +static double mod_load(int64_t idx, double* input, size_t N) { idx /= sizeof(double); return input[idx % N]; } @@ -136,25 +133,26 @@ std::vector hess_f2(size_t N, double* input) { } */ -int __attribute__((always_inline)) main(int argc, char** argv) { +// int argc, char** argv +int __attribute__((always_inline)) main() { std::mt19937 generator(0); // Seed the random number generator std::uniform_real_distribution normal(0, 0.05); - if (argc != 2) { - printf("Usage: %s \n", argv[0]); - return 1; - } + // if (argc != 2) { + // printf("Usage: %s \n", argv[0]); + // return 1; + // } - size_t N = atoi(argv[1]); - // size_t N = 16; + // size_t N = atoi(argv[1]); + size_t N = 30; double x[3 * N]; for (int i = 0; i < N; ++i) { double angle 
= 2 * M_PI * i / N; - x[3 * i] = cos(angle) + normal(generator); - x[3 * i + 1] = sin(angle) + normal(generator); - x[3 * i + 2] = normal(generator); + x[3 * i] = cos(angle) ;//+ normal(generator); + x[3 * i + 1] = sin(angle) ;//+ normal(generator); + x[3 * i + 2] = 0;//normal(generator); } auto res = hess_f(N, &x[0]); diff --git a/enzyme/test/Integration/Sparse/sqrtspring.cpp b/enzyme/test/Integration/Sparse/sqrtspring.cpp index fdec1f99cdde..a9750409b37d 100644 --- a/enzyme/test/Integration/Sparse/sqrtspring.cpp +++ b/enzyme/test/Integration/Sparse/sqrtspring.cpp @@ -1,12 +1,14 @@ // This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi - +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli 
- ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi + +#include +#include #include #include #include @@ -33,8 +35,8 @@ extern void __enzyme_fwddiff(void *, ...); extern double* __enzyme_todense(void *, ...) 
noexcept; -/// Compute energy -double f(size_t N, double* input) { +__attribute__((always_inline)) +static double f(size_t N, double* input) { double out = 0; __builtin_assume(!((N-1) == 0)); for (size_t i=0; i &triplets) { +static void inner_store(int64_t row, int64_t col, double val, std::vector &triplets) { printf("row=%d col=%d val=%f\n", row, col, val); assert(abs(val) > 0.00001); triplets.emplace_back(row, col, val); } __attribute__((always_inline)) -void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { +static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { if (val == 0.0) return; idx /= sizeof(double); inner_store(i, idx, val, triplets); } __attribute__((always_inline)) -double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { +static double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { return 0.0; } From 9117cf3488cd1a36282348b3bea2a2a59ba09614 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 16 Jan 2024 23:34:06 -0500 Subject: [PATCH 004/131] Add eigen analysis test (#1612) --- enzyme/Enzyme/Clang/EnzymeClang.cpp | 1 - enzyme/Enzyme/FunctionUtils.cpp | 479 +++++++++--------- .../Integration/Sparse/eigen_analysis.cpp | 261 ++++++++++ enzyme/test/Integration/Sparse/matrix.h | 203 ++++++++ 4 files changed, 714 insertions(+), 230 deletions(-) create mode 100644 enzyme/test/Integration/Sparse/eigen_analysis.cpp create mode 100644 enzyme/test/Integration/Sparse/matrix.h diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index 8dbf0587e807..ee5794cb13e2 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -631,7 +631,6 @@ struct EnzymeSparseAccumulateAttrInfo : public ParsedAttrInfo { return AttributeNotApplied; } V->setInit(expr); - V->dump(); S.MarkVariableReferenced(loc, V); S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); return AttributeApplied; diff 
--git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 0e8a5019fdfb..84d9c7ea2b72 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -7198,17 +7198,20 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, } } - llvm::errs() << " pre fix inner: " << F << "\n"; + // llvm::errs() << " pre fix inner: " << F << "\n"; // Full simplification while (!Q.empty()) { auto cur = Q.pop_back_val(); + /* std::set prev; for (auto v : Q) prev.insert(v); llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n"; + */ auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL); (void)changed; + /* if (changed) { llvm::errs() << "changed: " << *changed << "\n"; @@ -7217,9 +7220,10 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, llvm::errs() << " + " << *I << "\n"; llvm::errs() << F << "\n\n"; } + */ } - llvm::errs() << " post fix inner " << F << "\n"; + // llvm::errs() << " post fix inner " << F << "\n"; SmallVector, 1> sparseBlocks; bool legalToSparse = true; @@ -7422,8 +7426,9 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, } if (forSparsification.size() == 0) { - llvm::errs() << " found no stores for sparsification\n"; - assert(0); + auto context = &F.getEntryBlock().front(); + EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F, + "\n Found no stores for sparsification"); return; } @@ -7652,267 +7657,269 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, } } -bool LowerSparsification(llvm::Function *F, bool replaceAll) { - auto &DL = F->getParent()->getDataLayout(); - bool changed = false; - SmallVector todo; - SetVector toDenseBlocks; - for (auto &BB : *F) { - for (auto &I : BB) { - if (auto CI = dyn_cast(&I)) { - if (getFuncNameFromCall(CI).contains("__enzyme_todense")) { - todo.push_back(CI); - toDenseBlocks.insert(&BB); - } - } - } - } - for (auto CI : todo) { - changed = true; - 
auto load_fn = cast(getBaseObject(CI->getArgOperand(0))); - auto store_fn = cast(getBaseObject(CI->getArgOperand(1))); - size_t argstart = 2; +void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F, + const llvm::DataLayout &DL) { + auto load_fn = cast(getBaseObject(CI->getArgOperand(0))); + auto store_fn = cast(getBaseObject(CI->getArgOperand(1))); + size_t argstart = 2; #if LLVM_VERSION_MAJOR >= 14 - size_t num_args = CI->arg_size(); + size_t num_args = CI->arg_size(); #else - size_t num_args = CI->getNumArgOperands(); + size_t num_args = CI->getNumArgOperands(); #endif - SmallVector, 1> users; + SmallVector, 1> users; - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - IntegerType *intTy = IntegerType::get(CI->getContext(), 64); - auto toInt = [&](IRBuilder<> &B, llvm::Value *V) { - if (auto PT = dyn_cast(V->getType())) { - if (PT->getAddressSpace() != 0) { + for (auto U : CI->users()) { + users.push_back(std::make_pair(cast(U), CI)); + } + IntegerType *intTy = IntegerType::get(CI->getContext(), 64); + auto toInt = [&](IRBuilder<> &B, llvm::Value *V) { + if (auto PT = dyn_cast(V->getType())) { + if (PT->getAddressSpace() != 0) { #if LLVM_VERSION_MAJOR < 17 #if LLVM_VERSION_MAJOR >= 15 - if (CI->getContext().supportsTypedPointers()) { + if (CI->getContext().supportsTypedPointers()) { #endif - V = B.CreateAddrSpaceCast( - V, PointerType::getUnqual(PT->getPointerElementType())); + V = B.CreateAddrSpaceCast( + V, PointerType::getUnqual(PT->getPointerElementType())); #if LLVM_VERSION_MAJOR >= 15 - } else { - V = B.CreateAddrSpaceCast(V, - PointerType::getUnqual(PT->getContext())); - } -#endif -#else + } else { V = B.CreateAddrSpaceCast(V, PointerType::getUnqual(PT->getContext())); -#endif } - return B.CreatePtrToInt(V, intTy); +#endif +#else + V = B.CreateAddrSpaceCast(V, PointerType::getUnqual(PT->getContext())); +#endif } - auto IT = cast(V->getType()); - if (IT == intTy) - return V; - return 
B.CreateZExtOrTrunc(V, intTy); - }; - SmallVector toErase; - - ValueToValueMapTy replacements; - replacements[CI] = Constant::getNullValue(CI->getType()); - Instruction *remaining = nullptr; - while (users.size()) { - auto pair = users.back(); - users.pop_back(); - auto U = pair.first; - auto val = pair.second; - if (replacements.count(U)) - continue; + return B.CreatePtrToInt(V, intTy); + } + auto IT = cast(V->getType()); + if (IT == intTy) + return V; + return B.CreateZExtOrTrunc(V, intTy); + }; + SmallVector toErase; - IRBuilder B(U); - if (auto CI = dyn_cast(U)) { - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - auto rep = - B.CreateCast(CI->getOpcode(), replacements[val], CI->getDestTy()); - if (auto I = dyn_cast(rep)) - I->setDebugLoc(CI->getDebugLoc()); - replacements[CI] = rep; - continue; + ValueToValueMapTy replacements; + replacements[CI] = Constant::getNullValue(CI->getType()); + Instruction *remaining = nullptr; + while (users.size()) { + auto pair = users.back(); + users.pop_back(); + auto U = pair.first; + auto val = pair.second; + if (replacements.count(U)) + continue; + + IRBuilder B(U); + if (auto CI = dyn_cast(U)) { + for (auto U : CI->users()) { + users.push_back(std::make_pair(cast(U), CI)); } - if (auto SI = dyn_cast(U)) { - for (auto U : SI->users()) { - users.push_back(std::make_pair(cast(U), SI)); - } - auto tval = SI->getTrueValue(); - auto fval = SI->getFalseValue(); - auto rep = B.CreateSelect( - SI->getCondition(), - replacements.count(tval) ? (Value *)replacements[tval] : tval, - replacements.count(fval) ? 
(Value *)replacements[fval] : fval); - if (auto I = dyn_cast(rep)) - I->setDebugLoc(SI->getDebugLoc()); - replacements[SI] = rep; - continue; + auto rep = + B.CreateCast(CI->getOpcode(), replacements[val], CI->getDestTy()); + if (auto I = dyn_cast(rep)) + I->setDebugLoc(CI->getDebugLoc()); + replacements[CI] = rep; + continue; + } + if (auto SI = dyn_cast(U)) { + for (auto U : SI->users()) { + users.push_back(std::make_pair(cast(U), SI)); + } + auto tval = SI->getTrueValue(); + auto fval = SI->getFalseValue(); + auto rep = B.CreateSelect( + SI->getCondition(), + replacements.count(tval) ? (Value *)replacements[tval] : tval, + replacements.count(fval) ? (Value *)replacements[fval] : fval); + if (auto I = dyn_cast(rep)) + I->setDebugLoc(SI->getDebugLoc()); + replacements[SI] = rep; + continue; + } + /* + if (auto CI = dyn_cast(U)) { + for (auto U : CI->users()) { + users.push_back(std::make_pair(cast(U), CI)); } - /* - if (auto CI = dyn_cast(U)) { + continue; + } + */ + if (auto CI = dyn_cast(U)) { + auto funcName = getFuncNameFromCall(CI); + if (funcName == "julia.pointer_from_objref") { for (auto U : CI->users()) { users.push_back(std::make_pair(cast(U), CI)); } - continue; - } - */ - if (auto CI = dyn_cast(U)) { - auto funcName = getFuncNameFromCall(CI); - if (funcName == "julia.pointer_from_objref") { - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - auto *F = CI->getCalledOperand(); + auto *F = CI->getCalledOperand(); - SmallVector args; + SmallVector args; #if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : CI->args()) + for (auto &arg : CI->args()) #else - for (auto &arg : CI->arg_operands()) + for (auto &arg : CI->arg_operands()) #endif - args.push_back(replacements[arg]); + args.push_back(replacements[arg]); - auto FT = CI->getFunctionType(); + auto FT = CI->getFunctionType(); - auto cal = cast(B.CreateCall(FT, F, args)); - cal->setCallingConv(CI->getCallingConv()); - cal->setDebugLoc(CI->getDebugLoc()); - replacements[CI] = 
cal; - continue; - } - } - if (auto CI = dyn_cast(U)) { - for (auto U : CI->users()) { - users.push_back(std::make_pair(cast(U), CI)); - } - SmallVector inds; - bool allconst = true; - for (auto &ind : CI->indices()) { - if (!isa(ind)) { - allconst = false; - } - inds.push_back(ind); - } - Value *gep; - - if (inds.size() == 1) { - gep = ConstantInt::get( - intTy, - (DL.getTypeSizeInBits(CI->getSourceElementType()) + 7) / 8); - gep = B.CreateMul(intTy == inds[0]->getType() - ? inds[0] - : B.CreateZExtOrTrunc(inds[0], intTy), - gep, "", true, true); - gep = B.CreateAdd(B.CreatePtrToInt(replacements[val], intTy), gep); - gep = B.CreateIntToPtr(gep, CI->getType()); - } else if (!allconst) { - gep = - B.CreateGEP(CI->getSourceElementType(), replacements[val], inds); - if (auto ge = cast(gep)) - ge->setIsInBounds(CI->isInBounds()); - } else { - APInt ai(64, 0); - CI->accumulateConstantOffset(DL, ai); - gep = B.CreateIntToPtr(ConstantInt::get(intTy, ai), CI->getType()); - } - if (auto I = dyn_cast(gep)) - I->setDebugLoc(CI->getDebugLoc()); - replacements[CI] = gep; + auto cal = cast(B.CreateCall(FT, F, args)); + cal->setCallingConv(CI->getCallingConv()); + cal->setDebugLoc(CI->getDebugLoc()); + replacements[CI] = cal; continue; } - if (auto LI = dyn_cast(U)) { - auto diff = toInt(B, replacements[LI->getPointerOperand()]); - SmallVector args; - args.push_back(diff); - for (size_t i = argstart; i < num_args; i++) - args.push_back(CI->getArgOperand(i)); - if (load_fn->getFunctionType()->getNumParams() != args.size()) { - auto fnName = load_fn->getName(); - auto found_numargs = load_fn->getFunctionType()->getNumParams(); - auto expected_numargs = args.size(); - EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, - " incorrect number of arguments to loader function ", - fnName, " expected ", expected_numargs, " found ", - found_numargs, " - ", *load_fn->getFunctionType()); - continue; - } else { - bool tocontinue = false; - for (size_t i = 0; i < args.size(); i++) { - if 
(load_fn->getFunctionType()->getParamType(i) != - args[i]->getType()) { - auto fnName = load_fn->getName(); - EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, - " incorrect type of argument ", i, - " to loader function ", fnName, " expected ", - *args[i]->getType(), " found ", - load_fn->getFunctionType()->params()[i]); - tocontinue = true; - break; - } - } - if (tocontinue) - continue; - } - CallInst *call = B.CreateCall(load_fn, args); - call->setDebugLoc(LI->getDebugLoc()); - Value *tmp = call; - if (tmp->getType() != LI->getType()) - tmp = B.CreateBitCast(tmp, LI->getType()); - LI->replaceAllUsesWith(tmp); - - if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) { - InlineFunctionInfo IFI; - InlineFunction(*call, IFI); - } - toErase.push_back(LI); - continue; + } + if (auto CI = dyn_cast(U)) { + for (auto U : CI->users()) { + users.push_back(std::make_pair(cast(U), CI)); + } + SmallVector inds; + bool allconst = true; + for (auto &ind : CI->indices()) { + if (!isa(ind)) { + allconst = false; + } + inds.push_back(ind); + } + Value *gep; + + if (inds.size() == 1) { + gep = ConstantInt::get( + intTy, (DL.getTypeSizeInBits(CI->getSourceElementType()) + 7) / 8); + gep = B.CreateMul(intTy == inds[0]->getType() + ? 
inds[0] + : B.CreateZExtOrTrunc(inds[0], intTy), + gep, "", true, true); + gep = B.CreateAdd(B.CreatePtrToInt(replacements[val], intTy), gep); + gep = B.CreateIntToPtr(gep, CI->getType()); + } else if (!allconst) { + gep = B.CreateGEP(CI->getSourceElementType(), replacements[val], inds); + if (auto ge = cast(gep)) + ge->setIsInBounds(CI->isInBounds()); + } else { + APInt ai(64, 0); + CI->accumulateConstantOffset(DL, ai); + gep = B.CreateIntToPtr(ConstantInt::get(intTy, ai), CI->getType()); } - if (auto SI = dyn_cast(U)) { - assert(SI->getValueOperand() != val); - auto diff = toInt(B, replacements[SI->getPointerOperand()]); - SmallVector args; - args.push_back(SI->getValueOperand()); - auto sty = store_fn->getFunctionType()->getParamType(0); - if (args[0]->getType() != - store_fn->getFunctionType()->getParamType(0)) { - if (CastInst::castIsValid(Instruction::BitCast, args[0], sty)) - args[0] = B.CreateBitCast(args[0], sty); - else { - auto args0ty = args[0]->getType(); + if (auto I = dyn_cast(gep)) + I->setDebugLoc(CI->getDebugLoc()); + replacements[CI] = gep; + continue; + } + if (auto LI = dyn_cast(U)) { + auto diff = toInt(B, replacements[LI->getPointerOperand()]); + SmallVector args; + args.push_back(diff); + for (size_t i = argstart; i < num_args; i++) + args.push_back(CI->getArgOperand(i)); + if (load_fn->getFunctionType()->getNumParams() != args.size()) { + auto fnName = load_fn->getName(); + auto found_numargs = load_fn->getFunctionType()->getNumParams(); + auto expected_numargs = args.size(); + EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, + " incorrect number of arguments to loader function ", + fnName, " expected ", expected_numargs, " found ", + found_numargs, " - ", *load_fn->getFunctionType()); + continue; + } else { + bool tocontinue = false; + for (size_t i = 0; i < args.size(); i++) { + if (load_fn->getFunctionType()->getParamType(i) != + args[i]->getType()) { + auto fnName = load_fn->getName(); EmitFailure("IllegalSparse", CI->getDebugLoc(), 
CI, - " first argument of store function must be the type of " - "the store found fn arg type ", - sty, " expected ", args0ty); + " incorrect type of argument ", i, + " to loader function ", fnName, " expected ", + *args[i]->getType(), " found ", + load_fn->getFunctionType()->params()[i]); + tocontinue = true; + break; } } - args.push_back(diff); - for (size_t i = argstart; i < num_args; i++) - args.push_back(CI->getArgOperand(i)); - auto call = B.CreateCall(store_fn, args); - call->setDebugLoc(SI->getDebugLoc()); - if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) { - InlineFunctionInfo IFI; - InlineFunction(*call, IFI); + if (tocontinue) + continue; + } + CallInst *call = B.CreateCall(load_fn, args); + call->setDebugLoc(LI->getDebugLoc()); + Value *tmp = call; + if (tmp->getType() != LI->getType()) + tmp = B.CreateBitCast(tmp, LI->getType()); + LI->replaceAllUsesWith(tmp); + + if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) { + InlineFunctionInfo IFI; + InlineFunction(*call, IFI); + } + toErase.push_back(LI); + continue; + } + if (auto SI = dyn_cast(U)) { + assert(SI->getValueOperand() != val); + auto diff = toInt(B, replacements[SI->getPointerOperand()]); + SmallVector args; + args.push_back(SI->getValueOperand()); + auto sty = store_fn->getFunctionType()->getParamType(0); + if (args[0]->getType() != store_fn->getFunctionType()->getParamType(0)) { + if (CastInst::castIsValid(Instruction::BitCast, args[0], sty)) + args[0] = B.CreateBitCast(args[0], sty); + else { + auto args0ty = args[0]->getType(); + EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, + " first argument of store function must be the type of " + "the store found fn arg type ", + sty, " expected ", args0ty); } - toErase.push_back(SI); - continue; } - remaining = U; + args.push_back(diff); + for (size_t i = argstart; i < num_args; i++) + args.push_back(CI->getArgOperand(i)); + auto call = B.CreateCall(store_fn, args); + call->setDebugLoc(SI->getDebugLoc()); + if 
(load_fn->hasFnAttribute(Attribute::AlwaysInline)) { + InlineFunctionInfo IFI; + InlineFunction(*call, IFI); + } + toErase.push_back(SI); + continue; } - for (auto U : toErase) - U->eraseFromParent(); + remaining = U; + } + for (auto U : toErase) + U->eraseFromParent(); - if (!remaining) { - CI->replaceAllUsesWith(Constant::getNullValue(CI->getType())); - CI->eraseFromParent(); - } else if (replaceAll) { - EmitFailure("IllegalSparse", remaining->getDebugLoc(), remaining, - " Illegal remaining use (", *remaining, ") of todense (", *CI, - ") in function ", *F); + if (!remaining) { + CI->replaceAllUsesWith(Constant::getNullValue(CI->getType())); + CI->eraseFromParent(); + } else if (replaceAll) { + EmitFailure("IllegalSparse", remaining->getDebugLoc(), remaining, + " Illegal remaining use (", *remaining, ") of todense (", *CI, + ") in function ", *F); + } +} + +bool LowerSparsification(llvm::Function *F, bool replaceAll) { + auto &DL = F->getParent()->getDataLayout(); + bool changed = false; + SmallVector todo; + SetVector toDenseBlocks; + for (auto &BB : *F) { + for (auto &I : BB) { + if (auto CI = dyn_cast(&I)) { + if (getFuncNameFromCall(CI).contains("__enzyme_todense")) { + todo.push_back(CI); + toDenseBlocks.insert(&BB); + } + } } } + for (auto CI : todo) { + changed = true; + replaceToDense(CI, replaceAll, F, DL); + } + todo.clear(); if (changed && EnzymeAutoSparsity) { PassBuilder PB; @@ -7932,5 +7939,19 @@ bool LowerSparsification(llvm::Function *F, bool replaceAll) { LoopSimplifyPass().run(*F, FAM); fixSparseIndices(*F, FAM, toDenseBlocks); } + + for (auto &BB : *F) { + for (auto &I : BB) { + if (auto CI = dyn_cast(&I)) { + if (getFuncNameFromCall(CI).contains("__enzyme_post_sparse_todense")) { + todo.push_back(CI); + } + } + } + } + for (auto CI : todo) { + changed = true; + replaceToDense(CI, replaceAll, F, DL); + } return changed; } diff --git a/enzyme/test/Integration/Sparse/eigen_analysis.cpp b/enzyme/test/Integration/Sparse/eigen_analysis.cpp new file 
mode 100644 index 000000000000..d8dc311f957a --- /dev/null +++ b/enzyme/test/Integration/Sparse/eigen_analysis.cpp @@ -0,0 +1,261 @@ +// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi + +#include +#include +#include +#include +#include +#include +#include + +#include "matrix.h" + + +template +__attribute__((always_inline)) +static T face_load(unsigned long long offset, T* x, const int* faces) { + offset /= sizeof(T); + return x[faces[offset]]; +} + +template +__attribute__((always_inline)) +static void face_store(unsigned long long offset, T* x, const int* faces) { + assert(0 && "store is not 
legal"); +} + + +template +__attribute__((always_inline)) +static T area_load(unsigned long long offset, T* pos0, const int* faces) { + offset /= sizeof(T); + + int idx = offset / 9; + + int inc = offset % 9; + + int i = faces[3*idx]; + int j = faces[3*idx+1]; + int k = faces[3*idx+2]; + + /// pos_data[0:3] -> pos[3*faces[i]:3*faces[i]+3] + /// pos_data[3:6] -> pos[3*faces[j]:3*faces[j]+3] + /// pos_data[6:9] -> pos[3*faces[k]:3*faces[k]+3] + + if (inc < 3) { + return pos0[3*i+inc]; + } else if (inc < 6) { + return pos0[3*j+inc-3]; + } else { + return pos0[3*k+inc-6]; + } +} + +template +__attribute__((always_inline)) +static void area_store(unsigned long long offset, T* pos0, const int* faces) { + assert(0 && "store is not legal"); +} + +template +__attribute__((always_inline)) +static T eigenstuffM(const T *__restrict__ pos0, size_t n, const int *__restrict__ faces, const T *__restrict__ x) { + T sum = 0; + __builtin_assume(n != 0); + for (size_t idx=0; idx real_x[faces[i]] + T xj = x[3 * idx + 1]; + T xk = x[3 * idx + 2]; + + const T* pos_data = &pos0[9 * idx]; + /// + /// pos_data[0:3] -> pos[3*faces[i]:3*faces[i]+3] + /// pos_data[3:6] -> pos[3*faces[j]:3*faces[j]+3] + /// pos_data[6:9] -> pos[3*faces[k]:3*faces[k]+3] + + T tri_area = area(&pos_data[0], &pos_data[3], &pos_data[6]); + + sum += (xi * xi + xj * xj + xk * xk) * (1.0 / 3.0) * tri_area; // barycentric mass lumping + } + return sum; +} + + +// Calculate total energy for all faces in 3D +template +__attribute__((always_inline)) +static T eigenstuffL(const T *__restrict__ x, size_t num_faces, const int *__restrict__ faces, const T *__restrict__ pos0) { + T sum = 0; + __builtin_assume(num_faces != 0); + for (size_t idx=0; idx(g, g) * area(&pos0[3*i], &pos0[3*j], &pos0[3*k]); + } + + return sum; +} + + +template +__attribute__((always_inline)) +static void gradient_ip(const T *__restrict__ pos0, const size_t num_faces, const int* faces, const T *__restrict__ x, T *__restrict__ out) +{ + 
__enzyme_autodiff((void *)eigenstuffM, + enzyme_const, pos0, + enzyme_const, num_faces, + enzyme_const, faces, + enzyme_dup, x, out); +} + + +template +__attribute__((always_inline)) +static T ident_load(unsigned long long offset, size_t i) { + return (offset / sizeof(T) == i) ? T(1) : T(0); +} + + +template +__attribute__((always_inline)) +static void err_store(T val, unsigned long long offset, size_t i) { + assert(0 && "store is not legal"); +} + + +template +__attribute__((always_inline)) +static T zero_load(unsigned long long offset, size_t i, std::vector> &hess) { + return T(0); +} + + +__attribute__((enzyme_sparse_accumulate)) +void inner_store(size_t offset, size_t i, float val, std::vector> &hess) { + hess.push_back(Triple(offset, i, val)); +} + +template +__attribute__((always_inline)) +static void csr_store(T val, unsigned long long offset, size_t i, std::vector> &hess) { + if (val == 0.0) return; + offset /= sizeof(T); + inner_store(offset, i, val, hess); +} + +template +__attribute__((noinline)) +std::vector> hessian(const T*__restrict__ pos0, size_t num_faces, const int* faces, const T*__restrict__ x, size_t x_pts) +{ + float* x2 = __enzyme_post_sparse_todense(face_load, face_store, x, faces); + + /* + float* x3 = (float*)malloc(sizeof(float)*9*num_faces); + for (size_t idx=0; idx(area_load, area_store, pos0, faces); + std::vector> hess; + __builtin_assume(x_pts != 0); + for (size_t i=0; i<3*x_pts; i++) + __enzyme_fwddiff((void *)gradient_ip, + enzyme_const, pos02, + enzyme_const, num_faces, + enzyme_const, faces, + enzyme_dup, x2, __enzyme_todense(ident_load, err_store, i), + enzyme_dupnoneed, nullptr, __enzyme_todense(zero_load, csr_store, i, &hess)); + return hess; +} + +int main() { + const size_t x_pts = 1; + const float x[] = {0.0, 1.0, 0.0}; + + + const size_t num_faces = 1; + const int faces[] = {0, 1, 2}; + + const float pos0[] = {1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 3.0, 1.0, 3.0}; + + // Call eigenstuffM_simple + const float resultM = 
eigenstuffM(pos0, num_faces, faces, x); + printf("Result for eigenstuffM_simple: %f\n", resultM); + + // Call eigenstuffL_simple + const float resultL = eigenstuffL(pos0, num_faces, faces, x); + printf("Result for eigenstuffL_simple: %f\n", resultL); + + float dx[sizeof(x)/sizeof(x[0])]; + for (size_t i=0; i +#include + +template +struct Triple { + size_t row; + size_t col; + T val; + Triple(Triple&&) = default; + Triple(size_t row, size_t col, T val) : row(row), col(col), val(val) {} +}; + +extern int enzyme_width; +extern int enzyme_dup; +extern int enzyme_dupv; +extern int enzyme_const; +extern int enzyme_dupnoneed; + +template +extern T __enzyme_autodiff(void*, Tys...); + +template +extern T __enzyme_fwddiff(void *, Tys...); + +template +extern T __enzyme_todense(Tys...); + +template +extern T __enzyme_post_sparse_todense(Tys...); + +template +__attribute__((always_inline)) +static void elementwise_difference(T (&out)[n], const T x[n], const T y[n]) { + #pragma clang loop unroll(full) + for (int i=0; i +__attribute__((always_inline)) +static void elementwise_sum(T (&out)[n], const T x[n], const T y[n]) { + #pragma clang loop unroll(full) + for (int i=0; i +__attribute__((always_inline)) +static T dot_product(const T a[n], const T b[n]) { + T result = 0.0; + #pragma clang loop unroll(full) + for (size_t i = 0; i < n; ++i) { + result += a[i] * b[i]; + } + return result; +} + + +template +__attribute__((always_inline)) +static T norm(const T v[n]) { + T sum_squares = 0.0; + #pragma clang loop unroll(full) + for (size_t i=0; i +__attribute__((always_inline)) +static void transpose(T (&out)[n][m], const T in[m][n]) { + #pragma clang loop unroll(full) + for (int i=0; i +__attribute__((always_inline)) +static void matrix_multiply(T (&result)[m][k], const T matrix1[m][n], const T matrix2[n][k]) { + #pragma clang loop unroll(full) + for (int i = 0; i < m; ++i) { + #pragma clang loop unroll(full) + for (int j = 0; j < k; ++j) { + result[i][j] = 0.0; + #pragma clang loop 
unroll(full) + for (int z = 0; z < n; ++z) { + result[i][j] += matrix1[i][z] * matrix2[z][j]; + } + } + } +} + + +template +__attribute__((always_inline)) +static void inv(T (&out)[3][3], const T (&F)[3][3]) { + T det = F[0][0] * (F[1][1] * F[2][2] - F[1][2] * F[2][1]) + - F[0][1] * (F[1][0] * F[2][2] - F[1][2] * F[2][0]) + + F[0][2] * (F[1][0] * F[2][1] - F[1][1] * F[2][0]); + + T inv_det = 1 / det; + + out[0][0] = (F[1][1] * F[2][2] - F[1][2] * F[2][1]) * inv_det; + out[0][1] = (F[0][2] * F[2][1] - F[0][1] * F[2][2]) * inv_det; + out[0][2] = (F[0][1] * F[1][2] - F[0][2] * F[1][1]) * inv_det; + + out[1][0] = (F[1][2] * F[2][0] - F[1][0] * F[2][2]) * inv_det; + out[1][1] = (F[0][0] * F[2][2] - F[0][2] * F[2][0]) * inv_det; + out[1][2] = (F[0][2] * F[1][0] - F[0][0] * F[1][2]) * inv_det; + + out[2][0] = (F[1][0] * F[2][1] - F[1][1] * F[2][0]) * inv_det; + out[2][1] = (F[0][1] * F[2][0] - F[0][0] * F[2][1]) * inv_det; + out[2][2] = (F[0][0] * F[1][1] - F[0][1] * F[1][0]) * inv_det; +} + + +template +__attribute__((always_inline)) +static void inv(T (&out)[2][2], const T (&F)[2][2]) { + T det = F[0][0] * F[1][1] - F[0][1] * F[1][0]; + + T inv_det = 1 / det; + + out[0][0] = F[1][1] * inv_det; + out[0][1] = -F[0][1] * inv_det; + out[1][0] = -F[1][0] * inv_det; + out[1][1] = F[0][0] * inv_det; +} + +template +__attribute__((always_inline)) +static void pseudo_inverse(T (&matTsqrinv)[n][m], const T mat[m][n]) { + T matT[n][m]; + transpose(matT, mat); + T matmatT[m][m]; + matrix_multiply(matmatT, mat, matT); + T sqrinv[m][m]; + inv(sqrinv, matmatT); + matrix_multiply(matTsqrinv, matT, sqrinv); +} + +// m is 2 n is 3 +template +__attribute__((always_inline)) +static void get_pos( + T (&__restrict__ out)[n][m], + const float *__restrict__ pos, + const int idx[n]) { + + static_assert(m == 3, "Only Vector3 is supported"); + + // extract the 3d points at idx[0], idx[1], idx[2], idx[3] + #pragma clang loop unroll(full) + for (int i = 0; i < n; ++i) { + out[i][0] = pos[m * 
idx[i]]; + out[i][1] = pos[m * idx[i] + 1]; + out[i][2] = pos[m * idx[i] + 2]; + } +} + + +// m is 2 n is 3 +template +__attribute__((always_inline)) +static void get_pos_affine( + T (&__restrict__ out)[n][m], + const float *__restrict__ pos) { + + static_assert(m == 3, "Only Vector3 is supported"); + + // extract the 3d points at idx[0], idx[1], idx[2], idx[3] + #pragma clang loop unroll(full) + for (int i = 0; i < n; ++i) { + out[i][0] = pos[m * i]; + out[i][1] = pos[m * i + 1]; + out[i][2] = pos[m * i + 2]; + } +} + +template +__attribute__((always_inline)) +static void cross(T (&out)[3], const T v1[3], const T v2[3]) { + out[0] = v1[1] * v2[2] - v1[2] * v2[1]; + out[1] = v1[2] * v2[0] - v1[0] * v2[2]; + out[2] = v1[0] * v2[1] - v1[1] * v2[0]; +} + + +template +__attribute__((always_inline)) +static T area(const T *__restrict__ u, const T *__restrict__ v, const T *__restrict__ w) { + T diff1[3]; + elementwise_difference(diff1, v, u); + T diff2[3]; + elementwise_difference(diff2, w, u); + T cross_product[3]; + cross(cross_product, diff1, diff2); + return 0.5 * norm(cross_product); +} \ No newline at end of file From 04d60edca5277dc3b9be35ee7d3bbcfc0d26e35a Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 17 Jan 2024 17:28:25 -0500 Subject: [PATCH 005/131] make fpic work on windows (#1614) --- enzyme/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index da45a63b4319..13a9b4973c54 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -13,7 +13,8 @@ add_definitions(-DENZYME_VERSION_MAJOR=${ENZYME_MAJOR_VERSION}) add_definitions(-DENZYME_VERSION_MINOR=${ENZYME_MINOR_VERSION}) add_definitions(-DENZYME_VERSION_PATCH=${ENZYME_PATCH_VERSION}) -SET(CMAKE_CXX_FLAGS "-Wall -fPIC -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else") +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable 
-Werror=dangling-else") SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -ggdb") SET(CMAKE_CXX_FLAGS_RELEASE "-O2") From 116082715902fc06c7828da84e3a55f60b290e69 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 18 Jan 2024 00:44:30 -0500 Subject: [PATCH 006/131] Some function ordering (#1616) --- enzyme/Enzyme/FunctionUtils.cpp | 283 +++++++++++++++++++------------- 1 file changed, 170 insertions(+), 113 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 84d9c7ea2b72..0e14c20d9ef2 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2867,6 +2867,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, I->eraseFromParent(); return candidate; } + break; } } return push(I); @@ -2972,6 +2973,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, replaceAndErase(cur, candidate); return "CSE"; } + break; } } } @@ -3034,7 +3036,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, return with; } if (isNot(val, orig)) { - return B.CreateNot(with); + return pushcse(B.CreateNot(with)); } if (isa(val)) return val; @@ -3092,7 +3094,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (lhs == I->getOperand(0) && rhs == I->getOperand(1)) return val; push(I); - return push(B.CreateOr(lhs, rhs, "sel." + I->getName())); + return pushcse(B.CreateOr(lhs, rhs, "sel." 
+ I->getName())); } if (I->getOpcode() == Instruction::Xor) { @@ -3321,6 +3323,24 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, changed = true; continue; } + if (auto op = dyn_cast(v)) { + if (auto tc = dyn_cast(op->getTrueValue())) + if (tc->isZero()) { + operands.push_back(pushcse(B.CreateUIToFP( + pushcse(B.CreateNot(op->getCondition())), op->getType()))); + operands.push_back(op->getFalseValue()); + changed = true; + continue; + } + if (auto tc = dyn_cast(op->getFalseValue())) + if (tc->isZero()) { + operands.push_back( + pushcse(B.CreateUIToFP(op->getCondition(), op->getType()))); + operands.push_back(op->getTrueValue()); + changed = true; + continue; + } + } operands.push_back(v); } if (constval) @@ -3464,6 +3484,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (auto P = isSum(cur)) { // whether negated SmallVector, 1> conditions; + bool legal = true; for (auto &v : callOperands(P)) { // z = uitofp i1 c to float -> select c, (prod withot z), 0 if (auto op = dyn_cast(v)) { @@ -3491,24 +3512,27 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, continue; } } + legal = false; + break; } Value *condition = nullptr; - for (size_t i = 0; i < conditions.size(); i++) { - size_t count = 0; - for (size_t j = 0; j < conditions.size(); j++) { - if (((conditions[i].first == conditions[j].first) && - (conditions[i].second == conditions[i].second)) || - ((isNot(conditions[i].first, conditions[j].first) && - (conditions[i].second != conditions[i].second)))) - count++; - } - if (count == conditions.size() && count > 1) { - condition = conditions[i].first; - if (conditions[i].second) - condition = pushcse(B.CreateNot(condition, "sumpnot")); - break; + if (legal) + for (size_t i = 0; i < conditions.size(); i++) { + size_t count = 0; + for (size_t j = 0; j < conditions.size(); j++) { + if (((conditions[i].first == conditions[j].first) && + (conditions[i].second == conditions[i].second)) || + 
((isNot(conditions[i].first, conditions[j].first) && + (conditions[i].second != conditions[i].second)))) + count++; + } + if (count == conditions.size() && count > 1) { + condition = conditions[i].first; + if (conditions[i].second) + condition = pushcse(B.CreateNot(condition, "sumpnot")); + break; + } } - } if (condition) { @@ -3540,6 +3564,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, continue; } } + llvm::errs() << " unhandled call op sumselect: " << *v << "\n"; assert(0); } @@ -3837,24 +3862,23 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, Value *tmp = nullptr; if (isa(a0)) - tmp = B.CreateSExt(a0->getOperand(0), a0->getType()); + tmp = pushcse(B.CreateSExt(a0->getOperand(0), a0->getType())); else if (isa(a0)) - tmp = B.CreateZExt(a0->getOperand(0), a0->getType()); + tmp = pushcse(B.CreateZExt(a0->getOperand(0), a0->getType())); else assert(0); - tmp = pushcse(tmp); replaceAndErase(cur, tmp); return "NegSZExtI1"; } - // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if C2 - // divides C1 if ((cur->getOpcode() == Instruction::LShr || cur->getOpcode() == Instruction::SDiv || cur->getOpcode() == Instruction::UDiv) && cur->isExact()) if (auto C2 = dyn_cast(cur->getOperand(1))) - if (auto mul = dyn_cast(cur->getOperand(0))) + if (auto mul = dyn_cast(cur->getOperand(0))) { + // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if C2 + // divides C1 if (mul->getOpcode() == Instruction::Mul) for (int i0 = 0; i0 < 2; i0++) if (auto C1 = dyn_cast(mul->getOperand(i0))) { @@ -3871,16 +3895,49 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, else APInt::sdivrem(lhs, rhs, div, rem); if (rem == 0) { - auto res = B.CreateMul(mul->getOperand(1 - i0), - ConstantInt::get(cur->getType(), div), - "mdiv." + cur->getName(), - mul->hasNoUnsignedWrap(), - mul->hasNoSignedWrap()); + auto res = pushcse(B.CreateMul( + mul->getOperand(1 - i0), + ConstantInt::get(cur->getType(), div), + "mdiv." 
+ cur->getName(), mul->hasNoUnsignedWrap(), + mul->hasNoSignedWrap())); push(mul); replaceAndErase(cur, res); return "IMulDivConst"; } } + // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if C2 + if (mul->getOpcode() == Instruction::Add) + for (int i0 = 0; i0 < 2; i0++) + if (auto C1 = dyn_cast(mul->getOperand(i0))) { + auto lhs = C1->getValue(); + APInt rhs = C2->getValue(); + if (cur->getOpcode() == Instruction::LShr) { + rhs = APInt(rhs.getBitWidth(), 1) << rhs; + } + + APInt div, rem; + if (cur->getOpcode() == Instruction::LShr || + cur->getOpcode() == Instruction::UDiv) + APInt::udivrem(lhs, rhs, div, rem); + else + APInt::sdivrem(lhs, rhs, div, rem); + if (rem == 0 && ((mul->hasNoUnsignedWrap() && + (cur->getOpcode() == Instruction::LShr || + cur->getOpcode() == Instruction::UDiv)) || + (mul->hasNoSignedWrap() && + (cur->getOpcode() == Instruction::AShr || + cur->getOpcode() == Instruction::SDiv)))) { + auto res = pushcse(B.CreateAdd( + mul->getOperand(1 - i0), + ConstantInt::get(cur->getType(), div), + "madd." + cur->getName(), mul->hasNoUnsignedWrap(), + mul->hasNoSignedWrap())); + push(mul); + replaceAndErase(cur, res); + return "IAddDivConst"; + } + } + } // mul (mul a, const1), (mul b, const2) -> mul (mul a, b), (const1, const2) if (cur->getOpcode() == Instruction::FMul) @@ -4599,46 +4656,6 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } - // and a, b -> and a b[with a true] - if (cur->getOpcode() == Instruction::And) { - auto lhs = replace(cur->getOperand(0), cur->getOperand(1), - ConstantInt::getTrue(cur->getContext())); - if (lhs != cur->getOperand(0)) { - auto res = pushcse( - B.CreateAnd(lhs, cur->getOperand(1), "postand." + cur->getName())); - replaceAndErase(cur, res); - return "AndReplaceLHS"; - } - auto rhs = replace(cur->getOperand(1), cur->getOperand(0), - ConstantInt::getTrue(cur->getContext())); - if (rhs != cur->getOperand(1)) { - auto res = pushcse( - B.CreateAnd(cur->getOperand(0), rhs, "postand." 
+ cur->getName())); - replaceAndErase(cur, res); - return "AndReplaceRHS"; - } - } - - // or a, b -> or a b[with a false] - if (cur->getOpcode() == Instruction::Or) { - auto lhs = replace(cur->getOperand(0), cur->getOperand(1), - ConstantInt::getFalse(cur->getContext())); - if (lhs != cur->getOperand(0)) { - auto res = pushcse( - B.CreateOr(lhs, cur->getOperand(1), "postor." + cur->getName())); - replaceAndErase(cur, res); - return "OrReplaceLHS"; - } - auto rhs = replace(cur->getOperand(1), cur->getOperand(0), - ConstantInt::getFalse(cur->getContext())); - if (rhs != cur->getOperand(1)) { - auto res = pushcse( - B.CreateOr(cur->getOperand(0), rhs, "postor." + cur->getName())); - replaceAndErase(cur, res); - return "OrReplaceRHS"; - } - } - /* // and (i == c), (i != d) -> and (i == c) && (c != d) if (cur->getOpcode() == Instruction::And) { @@ -4869,7 +4886,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (isSum(cur)) opcode = Instruction::FAdd; auto Ty = B.getInt64Ty(); - SmallVector temporaries; + SmallPtrSet temporaries; SmallVector precasts; Value *lhs = nullptr; @@ -4930,7 +4947,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, else lhs = B.CreateSExt(ext->getOperand(0), Ty); if (auto I = dyn_cast(lhs)) - temporaries.push_back(I); + if (I != ext->getOperand(0)) + temporaries.insert(I); } } } @@ -5028,13 +5046,18 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, else rhs = B.CreateSExt(ext->getOperand(0), Ty); if (auto I = dyn_cast(rhs)) - temporaries.push_back(I); + if (I != ext->getOperand(0)) + temporaries.insert(I); } } } if (lhs && rhs) { Value *res = nullptr; + if (temporaries.count(dyn_cast(lhs))) + lhs = pushcse(lhs); + if (temporaries.count(dyn_cast(rhs))) + rhs = pushcse(rhs); switch (opcode) { case Instruction::FAdd: res = B.CreateAdd(lhs, rhs, "", false, true); @@ -5050,8 +5073,6 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, llvm_unreachable("Illegal opcode"); 
} res = pushcse(res); - for (auto I : temporaries) - push(I); for (auto I : precasts) push(I); /* @@ -5065,7 +5086,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, MDNode::get(I->getContext(), vals)); } */ - auto ext = B.CreateSIToFP(res, cur->getType()); + auto ext = pushcse(B.CreateSIToFP(res, cur->getType())); replaceAndErase(cur, ext); return "BinopExtToExtBinop"; @@ -5429,19 +5450,6 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } - if (auto SI = dyn_cast(cur)) { - auto tval = replace(SI->getTrueValue(), SI->getCondition(), - ConstantInt::getTrue(SI->getContext())); - auto fval = replace(SI->getFalseValue(), SI->getCondition(), - ConstantInt::getFalse(SI->getContext())); - if (tval != SI->getTrueValue() || fval != SI->getFalseValue()) { - auto res = pushcse(B.CreateSelect(SI->getCondition(), tval, fval, - "postsel." + SI->getName())); - replaceAndErase(cur, res); - return "SelectReplace"; - } - } - // select cmp, (ext tval), (ext fval) -> (cmp & tval) | (!cmp & fval) if (auto SI = dyn_cast(cur)) { @@ -5492,6 +5500,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, replaceAndErase(cur, ext); return "SelectI1"; } + if (auto PN = dyn_cast(cur)) { B.SetInsertPoint(PN->getParent()->getFirstNonPHI()); if (SE.isSCEVable(PN->getType())) { @@ -5519,7 +5528,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, push(U); } auto point = PN->getParent()->getFirstNonPHI(); - auto tmp = B.CreatePHI(cur->getType(), 1); + auto tmp = cast(pushcse(B.CreatePHI(cur->getType(), 1))); cur->replaceAllUsesWith(tmp); cur->eraseFromParent(); @@ -5640,15 +5649,15 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, legal = false; } if (legal) { - auto PN2 = B.CreatePHI(B.getInt64Ty(), PN->getNumIncomingValues()); + auto PN2 = cast( + pushcse(B.CreatePHI(B.getInt64Ty(), PN->getNumIncomingValues()))); PN2->takeName(PN); for (auto val : llvm::enumerate(negOps)) 
PN2->addIncoming(val.value(), PN->getIncomingBlock(val.index())); push(PN2); - auto fneg = B.CreateSIToFP(PN2, PN->getType()); - push(fneg); + auto fneg = pushcse(B.CreateSIToFP(PN2, PN->getType())); for (auto I : prevNegOps) push(I); @@ -5665,7 +5674,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { auto v = PN->getIncomingValue(i); if (auto C = dyn_cast(v)) { - negOps.push_back(C->isZero() ? C : B.CreateFNeg(C)); + negOps.push_back(C->isZero() ? C : pushcse(B.CreateFNeg(C))); continue; } if (auto fneg = dyn_cast(v)) { @@ -5683,8 +5692,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, push(PN); - auto fneg = B.CreateFNeg(PN); - push(fneg); + auto fneg = pushcse(B.CreateFNeg(PN)); for (auto &U : cur->uses()) { if (U.getUser() == fneg) @@ -5706,7 +5714,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { auto v = PN->getIncomingValue(i); if (auto C = dyn_cast(v)) { - negOps.push_back(B.CreateNeg(C)); + negOps.push_back(pushcse(B.CreateNeg(C))); continue; } if (auto fneg = dyn_cast(v)) { @@ -5726,8 +5734,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, push(PN); - auto fneg = B.CreateNeg(PN); - push(fneg); + auto fneg = pushcse(B.CreateNeg(PN)); for (auto &U : cur->uses()) { if (U.getUser() == fneg) @@ -5909,23 +5916,20 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } if (legal && changed) { - auto lhsPN = - B.CreatePHI(lhsOps[0]->getType(), PN->getNumIncomingValues()); + auto lhsPN = cast(pushcse( + B.CreatePHI(lhsOps[0]->getType(), PN->getNumIncomingValues()))); PHINode *rhsPN = nullptr; if (numOps == 2) - rhsPN = - B.CreatePHI(rhsOps[0]->getType(), PN->getNumIncomingValues()); + rhsPN = cast(pushcse( + B.CreatePHI(rhsOps[0]->getType(), PN->getNumIncomingValues()))); for (auto val : llvm::enumerate(lhsOps)) lhsPN->addIncoming(val.value(), 
PN->getIncomingBlock(val.index())); - push(lhsPN); - if (numOps == 2) { for (auto val : llvm::enumerate(rhsOps)) rhsPN->addIncoming(val.value(), PN->getIncomingBlock(val.index())); - push(rhsPN); } Value *fneg = nullptr; @@ -6039,16 +6043,69 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, iter++) { (*iter)->moveBefore(br); } - auto sel = B.CreateSelect( + auto sel = pushcse(B.CreateSelect( br->getCondition(), PN->getIncomingValueForBlock(prev), PN->getIncomingValueForBlock(br->getSuccessor(1)), - "tphisel." + cur->getName()); + "tphisel." + cur->getName())); replaceAndErase(cur, sel); return "TPhiSel"; } } } + + if (auto SI = dyn_cast(cur)) { + auto tval = replace(SI->getTrueValue(), SI->getCondition(), + ConstantInt::getTrue(SI->getContext())); + auto fval = replace(SI->getFalseValue(), SI->getCondition(), + ConstantInt::getFalse(SI->getContext())); + if (tval != SI->getTrueValue() || fval != SI->getFalseValue()) { + auto res = pushcse(B.CreateSelect(SI->getCondition(), tval, fval, + "postsel." + SI->getName())); + replaceAndErase(cur, res); + return "SelectReplace"; + } + } + + // and a, b -> and a b[with a true] + if (cur->getOpcode() == Instruction::And) { + auto lhs = replace(cur->getOperand(0), cur->getOperand(1), + ConstantInt::getTrue(cur->getContext())); + if (lhs != cur->getOperand(0)) { + auto res = pushcse( + B.CreateAnd(lhs, cur->getOperand(1), "postand." + cur->getName())); + replaceAndErase(cur, res); + return "AndReplaceLHS"; + } + auto rhs = replace(cur->getOperand(1), cur->getOperand(0), + ConstantInt::getTrue(cur->getContext())); + if (rhs != cur->getOperand(1)) { + auto res = pushcse( + B.CreateAnd(cur->getOperand(0), rhs, "postand." 
+ cur->getName())); + replaceAndErase(cur, res); + return "AndReplaceRHS"; + } + } + + // or a, b -> or a b[with a false] + if (cur->getOpcode() == Instruction::Or) { + auto lhs = replace(cur->getOperand(0), cur->getOperand(1), + ConstantInt::getFalse(cur->getContext())); + if (lhs != cur->getOperand(0)) { + auto res = pushcse( + B.CreateOr(lhs, cur->getOperand(1), "postor." + cur->getName())); + replaceAndErase(cur, res); + return "OrReplaceLHS"; + } + auto rhs = replace(cur->getOperand(1), cur->getOperand(0), + ConstantInt::getFalse(cur->getContext())); + if (rhs != cur->getOperand(1)) { + auto res = pushcse( + B.CreateOr(cur->getOperand(0), rhs, "postor." + cur->getName())); + replaceAndErase(cur, res); + return "OrReplaceRHS"; + } + } return {}; } @@ -6121,6 +6178,7 @@ bool cannotDependOnLoopIV(const SCEV *S, const Loop *L) { return !L->contains(I->getParent()); } if (auto addrec = dyn_cast(S)) { + return false; if (addrec->getLoop() == L) return false; for (auto o : addrec->operands()) @@ -7104,7 +7162,9 @@ getSparseConditions(bool &legal, Value *val, } if (scope) EmitFailure("NoSparsification", I->getDebugLoc(), I, - " No sparsification: not sparse solvable(icmp): ", *sub1); + "F: ", *I->getParent()->getParent(), "\n", + " No sparsification: not sparse solvable(icmp): ", *I, + " via ", *sub1); legal = false; return defaultFloat; } @@ -7203,24 +7263,21 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, // Full simplification while (!Q.empty()) { auto cur = Q.pop_back_val(); - /* std::set prev; for (auto v : Q) prev.insert(v); - llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n"; - */ + // llvm::errs() << "\n\n\n\n" << F << "\n"; + llvm::errs() << "cur: " << *cur << "\n"; auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL); (void)changed; - /* if (changed) { llvm::errs() << "changed: " << *changed << "\n"; for (auto I : Q) if (!prev.count(I)) llvm::errs() << " + " << *I << "\n"; - llvm::errs() << F << "\n\n"; + // 
llvm::errs() << F << "\n\n"; } - */ } // llvm::errs() << " post fix inner " << F << "\n"; From 14f45263be7482c65fdfa254eb5034b216e4fd46 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Thu, 18 Jan 2024 17:27:56 +0100 Subject: [PATCH 007/131] [mlir] proper "undefined" state in pointer analysis and correct handling of fresh state (#1572) * [mlir] add a proper "undefined" state to pointer analyses Historically, `AliasClassSet` has been using empty set as the bottom state of the join-semilattice. This made it impossible to differentiate between "not yet analyzed" and "known not to alias with anything" states. While not necessarily problematic for the naive alias analysis, this is crucial for points-to-pointer analysis that was forced to become unnecessarily conservative by treating "not yet analyzed" points-to state as "points to unknown" (top state of the join-semilattice). Worse, some transfer functions weren't trivially monotonic because of that. Introduce a new "undefined" set to `AliasClassSet` and lattice. Update PointsToLattice to explicitly store the cases when a known alias class points to unknown alias class and interpret unknown (not yet listed) alias class as pointing to undefined alias class. Make all transfer functions monotonic by only processing known points-to entries. Unknown points-to classes can be treated by the client as potentially pointing to unknown, but so far this wasn't needed as, at the fixpoint state, all available classes should have been analyzed. Assertions are put in place in activity analysis to catch undefined classes. Additionally, make it more difficult to move the join-semilattice to a previous (<= current) value by exposing mostly the APIs that can only advance the lattice. This is not yet complete due to partially incorrect handling of distinct identifiers that get recreated every time in transfer functions. 
* [mlir] correct handling of fresh alias sets Previously, dataflow initialization and transfer functions would allocate a fresh alias class for allocation-like cases that are known not to alias with anything, and do so _on every invocation_. This is unsound in the general case as the class may keep changing, which could affect convergence to a fixpoint. Instead, add a cache that associates a class with a value on the first invocation of the dataflow-related functions and use the cached value on each subsequent invocation. Fix a couple of issues uncovered due to the change: - do not recreate the common scope that function results may be pointing to for each capturable operand; - invert the condition of handling nocapture in function results; - add a TODO for better handling of alias classes of function results in presence of function-level attributes. * [mlir] better handling of function result attributes Specifically, handle the situation where some alias classes are associated with pointer operands that cannot be written into. In such a case, even if the results of the function may alias the operands, the alias classes of the operands are known not to point to anything that the function could have written. Additionally, don't include results marked as `noalias` into the common alias class of function results. 
* Update DataFlowActivityAnalysis.cpp --------- Co-authored-by: William Moses --- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp | 577 +++++++++++------- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h | 187 ++++-- .../Analysis/DataFlowActivityAnalysis.cpp | 158 +++-- .../Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp | 26 +- .../MLIR/AliasAnalysis/func_attributes.mlir | 224 ++++++- 5 files changed, 810 insertions(+), 362 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp index 691be34eb25c..b584e8e64a1c 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp @@ -40,6 +40,7 @@ // TODO: remove this once aliasing interface is factored out. #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "llvm/ADT/SetOperations.h" using namespace mlir; using namespace mlir::dataflow; @@ -48,75 +49,64 @@ static bool isPointerLike(Type type) { return isa(type); } -const enzyme::AliasClassSet enzyme::AliasClassSet::emptySet = AliasClassSet(); +const enzyme::AliasClassSet enzyme::AliasClassSet::undefinedSet = + AliasClassSet(enzyme::AliasClassSet::State::Undefined); const enzyme::AliasClassSet enzyme::AliasClassSet::unknownSet = - AliasClassSet(true); + AliasClassSet(enzyme::AliasClassSet::State::Unknown); ChangeResult enzyme::AliasClassSet::join(const AliasClassSet &other) { - if (unknown) { + if (isUnknown()) return ChangeResult::NoChange; - } - if (other.unknown) { - unknown = true; + if (isUndefined() && other.isUndefined()) + return ChangeResult::NoChange; + if (other.isUnknown()) { + state = State::Unknown; return ChangeResult::Change; } - return insert(other.aliasClasses); + ChangeResult result = updateStateToDefined(); + return insert(other.aliasClasses) | result; } ChangeResult enzyme::AliasClassSet::insert(const DenseSet &classes) { - if (unknown) + if (isUnknown()) return ChangeResult::NoChange; size_t oldSize = aliasClasses.size(); 
aliasClasses.insert(classes.begin(), classes.end()); - return aliasClasses.size() == oldSize ? ChangeResult::NoChange - : ChangeResult::Change; -} - -ChangeResult enzyme::AliasClassSet::markFresh(Attribute debugLabel) { - reset(); - - auto freshClass = AliasClassLattice::getFresh(debugLabel); - aliasClasses.insert(freshClass); - return ChangeResult::Change; + ChangeResult result = aliasClasses.size() == oldSize ? ChangeResult::NoChange + : ChangeResult::Change; + return updateStateToDefined() | result; } ChangeResult enzyme::AliasClassSet::markUnknown() { - if (unknown) + if (isUnknown()) return ChangeResult::NoChange; - unknown = true; - aliasClasses.clear(); - return ChangeResult::Change; -} - -ChangeResult enzyme::AliasClassSet::reset() { - if (aliasClasses.empty() && !unknown) { - return ChangeResult::NoChange; - } - unknown = false; + state = State::Unknown; aliasClasses.clear(); return ChangeResult::Change; } bool enzyme::AliasClassSet::isCanonical() const { - return !unknown || aliasClasses.empty(); + return state == State::Defined || aliasClasses.empty(); } bool enzyme::AliasClassSet::operator==( const enzyme::AliasClassSet &other) const { assert(isCanonical() && other.isCanonical()); - return unknown == other.unknown && - llvm::equal(aliasClasses, other.aliasClasses); + return state == other.state && llvm::equal(aliasClasses, other.aliasClasses); } ChangeResult enzyme::AliasClassSet::foreachClass( - function_ref callback) const { + function_ref callback) const { + if (state != State::Defined) + return callback(nullptr, state); + ChangeResult result = ChangeResult::NoChange; for (DistinctAttr attr : aliasClasses) - result |= callback(attr); + result |= callback(attr, state); return result; } @@ -132,131 +122,116 @@ static ChangeResult mergeSets(DenseSet &dest, const DenseSet &src) { } void enzyme::PointsToSets::print(raw_ostream &os) const { + if (pointsTo.empty()) { + os << "\n"; + return; + } for (const auto &[srcClass, destClasses] : pointsTo) { os << " 
" << srcClass << " points to {"; if (destClasses.isUnknown()) { os << ""; + } else if (destClasses.isUndefined()) { + os << ""; } else { llvm::interleaveComma(destClasses.getAliasClasses(), os); } os << "}\n"; } - os << "other points to unknown: " << otherPointToUnknown << "\n"; + // os << "other points to unknown: " << otherPointToUnknown << "\n"; } /// Union for every variable. ChangeResult enzyme::PointsToSets::join(const AbstractDenseLattice &lattice) { const auto &rhs = static_cast(lattice); + llvm::SmallDenseSet keys; + auto lhsRange = llvm::make_first_range(pointsTo); + auto rhsRange = llvm::make_first_range(rhs.pointsTo); + keys.insert(lhsRange.begin(), lhsRange.end()); + keys.insert(rhsRange.begin(), rhsRange.end()); - // Both are exact, just join and carry over pointer classes from RHS. - if (!otherPointToUnknown && !rhs.otherPointToUnknown) { - ChangeResult result = ChangeResult::NoChange; - for (const auto &[otherPointer, otherPointee] : rhs.pointsTo) { - result |= pointsTo[otherPointer].join(otherPointee); + ChangeResult result = ChangeResult::NoChange; + for (DistinctAttr key : keys) { + auto lhsIt = pointsTo.find(key); + auto rhsIt = rhs.pointsTo.find(key); + assert(lhsIt != pointsTo.end() || rhsIt != rhs.pointsTo.end()); + + // If present in both, join. + if (lhsIt != pointsTo.end() && rhsIt != rhs.pointsTo.end()) { + result |= lhsIt->getSecond().join(rhsIt->getSecond()); + continue; } - return result; - } - // If this has other pointers pointing to unknown, only join in the RHS - // pointers that are known on the LHS. If some LHS pointers are not present in - // RHS, keep them as is because RHS is "exact". - if (otherPointToUnknown && !rhs.otherPointToUnknown) { - ChangeResult result = ChangeResult::NoChange; - for (DistinctAttr pointer : llvm::make_first_range(pointsTo)) { - auto it = rhs.pointsTo.find(pointer); - if (it != rhs.pointsTo.end()) - result |= pointsTo[pointer].join(it->getSecond()); + // Copy from RHS if available only there. 
+ if (lhsIt == pointsTo.end()) { + pointsTo.try_emplace(rhsIt->getFirst(), rhsIt->getSecond()); + result = ChangeResult::Change; } - return result; - } - // If both have other pointers pointing to unknown, only join the pointers - // that are present simultaneously in LHS and RHS. Drop LHS pointers that - // are not present in RHS from the list (they would explicitly point to - // unknown on individual join, but this is implied by the otherPointsToUnknown - // flag). Create a temporary vector for iteration as we will be erasing from - // the map in the loop. - if (otherPointToUnknown && rhs.otherPointToUnknown) { - ChangeResult result = ChangeResult::NoChange; - for (DistinctAttr pointer : - llvm::to_vector(llvm::make_first_range(pointsTo))) { - auto it = rhs.pointsTo.find(pointer); - if (it != rhs.pointsTo.end()) { - result |= pointsTo[pointer].join(it->getSecond()); - } else { - pointsTo.erase(pointer); - result = ChangeResult::Change; - } - } - return result; + // Do nothing if available only in LHS. } + return result; +} - // If RHS has other pointers pointing to unknown, only join the pointers that - // are present in both simultaneously. Drop LHS pointers that are not present - // in RHS (they would explicitly point to unknown on individual join but this - // is implied by the otherPointsToUnknown flag). Set LHS to also indicate - // other pointers pointing to unknown. - assert(!otherPointToUnknown && rhs.otherPointToUnknown); - otherPointToUnknown = true; - for (DistinctAttr pointer : - llvm::to_vector(llvm::make_first_range(pointsTo))) { - auto it = rhs.pointsTo.find(pointer); - if (it != rhs.pointsTo.end()) - pointsTo[pointer].join(it->getSecond()); - else - pointsTo.erase(pointer); - } +ChangeResult +enzyme::PointsToSets::joinPotentiallyMissing(DistinctAttr key, + const AliasClassSet &value) { + // Don't store explicitly undefined values in the mapping, keys absent from + // the mapping are treated as implicitly undefined. 
+ if (value.isUndefined()) + return ChangeResult::NoChange; + + bool inserted; + decltype(pointsTo.begin()) iterator; + std::tie(iterator, inserted) = pointsTo.try_emplace(key, value); + if (!inserted) + return iterator->second.join(value); return ChangeResult::Change; } ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate, const AliasClassSet &values, bool replace) { - // If updating the unknown alias class to point to something, we have reached - // the pessimistic fixpoint. if (keysToUpdate.isUnknown()) return markAllPointToUnknown(); - // If updating to point to unknown, and we already know others are pointing to - // unknown, just erase the known information. - if (values.isUnknown() && otherPointToUnknown) { - return keysToUpdate.foreachClass([&](DistinctAttr dest) { - return pointsTo.erase(dest) ? ChangeResult::Change - : ChangeResult::NoChange; - }); - } - - // Otherwise just set the result. - if (replace) { - return keysToUpdate.foreachClass([&](DistinctAttr dest) { - auto it = pointsTo.find(dest); - if (it != pointsTo.end() && it->getSecond() == values) - return ChangeResult::NoChange; - if (it == pointsTo.end()) - pointsTo.try_emplace(dest, values); - else - it->second = values; - return ChangeResult::Change; - }); - } - - return keysToUpdate.foreachClass([&](DistinctAttr dest) { - // If pointers stored in "other" are pointing to unknown alias class, don't - // override that. - if (otherPointToUnknown && !pointsTo.count(dest)) - return ChangeResult::NoChange; + // Don't yet know what to update. 
+ if (keysToUpdate.isUndefined()) + return ChangeResult::NoChange; - if (values.isUnknown()) - return pointsTo[dest].markUnknown(); - return pointsTo[dest].insert(values.getAliasClasses()); - }); + return keysToUpdate.foreachClass( + [&](DistinctAttr dest, AliasClassSet::State state) { + assert(state == AliasClassSet::State::Defined && + "unknown must have been handled above"); +#ifndef NDEBUG + if (replace) { + auto it = pointsTo.find(dest); + if (it != pointsTo.end()) { + // Check that we are updating to a state that's >= in the + // lattice. + // TODO: consider a stricter check that we only replace unknown + // values or a value with itself, currently blocked by memalign. + AliasClassSet valuesCopy(values); + valuesCopy.join(it->getSecond()); + values.print(llvm::errs()); + llvm::errs() << "\n"; + it->getSecond().print(llvm::errs()); + llvm::errs() << "\n"; + valuesCopy.print(llvm::errs()); + llvm::errs() << "\n"; + assert(valuesCopy == values && + "attempting to replace a pointsTo entry with an alias class " + "set that is ordered _before_ the existing one -> " + "non-monotonous update "); + } + } +#endif // NDEBUG + return joinPotentiallyMissing(dest, values); + }); } ChangeResult -enzyme::PointsToSets::setPointingToFresh(const AliasClassSet &destClasses, - StringAttr debugLabel) { - return update(destClasses, AliasClassLattice::getFresh(debugLabel), - /*replace=*/true); +enzyme::PointsToSets::setPointingToEmpty(const AliasClassSet &destClasses) { + return update(destClasses, AliasClassSet::getEmpty(), /*replace=*/true); } ChangeResult @@ -264,46 +239,69 @@ enzyme::PointsToSets::addSetsFrom(const AliasClassSet &destClasses, const AliasClassSet &srcClasses) { if (destClasses.isUnknown()) return markAllPointToUnknown(); + if (destClasses.isUndefined()) + return ChangeResult::NoChange; - return destClasses.foreachClass([&](DistinctAttr dest) { - return srcClasses.foreachClass( - [&](DistinctAttr src) { return pointsTo[dest].join(pointsTo[src]); }); - }); + return 
destClasses.foreachClass( + [&](DistinctAttr dest, AliasClassSet::State destState) { + assert(destState == AliasClassSet::State::Defined); + return srcClasses.foreachClass( + [&](DistinctAttr src, AliasClassSet::State srcState) { + const AliasClassSet *srcClasses = &AliasClassSet::getUndefined(); + if (srcState == AliasClassSet::State::Unknown) + srcClasses = &AliasClassSet::getUnknown(); + else if (srcState == AliasClassSet::State::Defined) { + auto it = pointsTo.find(src); + if (it != pointsTo.end()) + srcClasses = &it->getSecond(); + } + return joinPotentiallyMissing(dest, *srcClasses); + }); + }); } ChangeResult enzyme::PointsToSets::markPointToUnknown(const AliasClassSet &destClasses) { if (destClasses.isUnknown()) return markAllPointToUnknown(); + if (destClasses.isUndefined()) + return ChangeResult::NoChange; - return destClasses.foreachClass( - [&](DistinctAttr dest) { return pointsTo[dest].markUnknown(); }); + return destClasses.foreachClass([&](DistinctAttr dest, AliasClassSet::State) { + return joinPotentiallyMissing(dest, AliasClassSet::getUnknown()); + }); } ChangeResult enzyme::PointsToSets::markAllPointToUnknown() { - if (otherPointToUnknown && pointsTo.empty()) - return ChangeResult::NoChange; - - otherPointToUnknown = true; - pointsTo.clear(); - return ChangeResult::Change; + ChangeResult result = ChangeResult::NoChange; + for (auto &it : pointsTo) + result |= it.getSecond().join(AliasClassSet::getUnknown()); + return result; } ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown( const AliasClassSet &destClasses) { - bool wasOtherPointingToUnknown = otherPointToUnknown; - otherPointToUnknown = true; + if (destClasses.isUndefined()) + return ChangeResult::NoChange; - llvm::SmallDenseSet keysToDelete; - for (DistinctAttr key : llvm::make_first_range(pointsTo)) { - if (!destClasses.getAliasClasses().contains(key)) - keysToDelete.insert(key); + ChangeResult result = ChangeResult::NoChange; + for (auto &[key, value] : pointsTo) { + if 
(destClasses.isUnknown() || + !destClasses.getAliasClasses().contains(key)) { + result |= value.markUnknown(); + } } - for (DistinctAttr key : keysToDelete) - pointsTo.erase(key); - return (wasOtherPointingToUnknown && keysToDelete.empty()) - ? ChangeResult::NoChange - : ChangeResult::Change; + +#ifndef NDEBUG + (void)destClasses.foreachClass( + [&](DistinctAttr dest, AliasClassSet::State state) { + if (state == AliasClassSet::State::Defined) + assert(pointsTo.contains(dest) && "unknown dest cannot be preserved"); + return ChangeResult::NoChange; + }); +#endif // NDEBUG + + return result; } // TODO: Reduce code duplication with activity analysis @@ -501,6 +499,17 @@ getFunctionOtherModRef(FunctionOpInterface func) { return std::nullopt; } +/// Returns information indicating whether the function may read or write into +/// memory previously inaccessible in the calling context. When unknown, returns +/// `nullopt`. +static std::optional +getFunctionInaccessibleModRef(FunctionOpInterface func) { + if (auto memoryAttr = + func->getAttrOfType(kLLVMMemoryAttrName)) + return memoryAttr.getInaccessibleMem(); + return std::nullopt; +} + void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( CallOpInterface call, CallControlFlowAction action, const PointsToSets &before, PointsToSets *after) { @@ -520,14 +529,15 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( // memalign points to a value OperandRange arguments = call.getArgOperands(); auto *memPtr = getOrCreateFor(call, arguments[0]); - for (DistinctAttr memPtrClass : memPtr->getAliasClasses()) { - // Note that this is a "must write" kind of situation, so we can - // directly set the classes pointed to, rather than inserting them. 
- auto debugLabel = StringAttr::get(call.getContext(), "memalign"); - propagateIfChanged(after, - after->setPointingToFresh(memPtrClass, debugLabel)); - } - return; + + // Note that this is a "must write" kind of situation, so we can + // directly set the classes pointed to, rather than inserting them. + auto single = AliasClassLattice::single( + arguments[0], + originalClasses.getOriginalClass(arguments[0], "memalign")); + return propagateIfChanged( + after, after->setPointingToClasses(memPtr->getAliasClassesObject(), + single.getAliasClassesObject())); } // Analyze the callee generically. @@ -556,7 +566,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( // TODO: consider a more advanced lattice that can encode "may point to // any class _except_ the given classes"; this is mathematically possible // but needs careful programmatic encoding. - AliasClassSet functionMayCapture; + AliasClassSet functionMayCapture = AliasClassSet::getUndefined(); bool funcMayReadOther = modRefMayRef(otherModRef); unsigned numArguments = callee.getNumArguments(); if (funcMayReadOther) { @@ -577,30 +587,39 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( } } - AliasClassSet pointerOperandClasses; + // For each alias class the function may write to, indicate potentially + // stored classes. Keep the set of writable alias classes for future. + AliasClassSet writableClasses = AliasClassSet::getUndefined(); + AliasClassSet nonWritableOperandClasses = AliasClassSet::getUndefined(); ChangeResult changed = ChangeResult::NoChange; for (int pointerOperand : pointerLikeOperands) { auto *destClasses = getOrCreateFor( call, call.getArgOperands()[pointerOperand]); - pointerOperandClasses.join(destClasses->getAliasClassesObject()); // If the argument cannot be stored into, just preserve it as is. 
- if (!mayWriteArg(callee, pointerOperand, argModRef)) + if (!mayWriteArg(callee, pointerOperand, argModRef)) { + nonWritableOperandClasses.join(destClasses->getAliasClassesObject()); continue; + } + writableClasses.join(destClasses->getAliasClassesObject()); - // If the destination class is unknown, we reached the pessimistic - // fixpoint. + // If the destination class is unknown, mark all known classes + // pessimistic (alias classes that have not beed analyzed and thus are + // absent from pointsTo are treated as "undefined" at this point). if (destClasses->isUnknown()) { - pointerOperandClasses.reset(); + writableClasses.markUnknown(); changed |= after->markAllPointToUnknown(); break; } + if (destClasses->isUndefined()) + continue; + // Otherwise, indicate that a pointer that belongs to any of the // classes captured by this function may be stored into the // destination class. changed |= destClasses->getAliasClassesObject().foreachClass( - [&](DistinctAttr dest) { + [&](DistinctAttr dest, AliasClassSet::State) { return after->insert(dest, functionMayCapture); }); } @@ -608,19 +627,15 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( // If the function may write to "other", that is any potential other // pointer, record that. if (modRefMayMod(otherModRef)) { - // All other alias classes that are not present as arguments should - // point to unknown. - // Since: - // - `after` was joined with `before` at the beginning; and - // - pre-existing keys in `after` (and in `before` since no new keys - // were added) have their values: preserved, joined with another - // alias set (->insert is a join), or removed here with default value - // being set to "any" (lattice top); - // this transfer function is monotonic with respect to its input, i.e, - // the `before` lattice. - // TODO(zinenko): consider monotonicity more carefully wrt to - // `destClasses` change. 
- changed |= after->markAllExceptPointToUnknown(pointerOperandClasses); + // Classes that have been analyzed, and therefore present in the `after` + // lattice after joining it with `before` are marked as pointing to + // "unknown", except the classes that are associated with operands for + // which we have more specific information. Classes that haven't been + // analyzed, and therefore absent from the `after` lattice, are left + // unmodified and thus assumed to be "undefined". This makes this + // transfer function monotonic as opposed to marking the latter classes + // as "unknown" eagerly, which would require rolling that marking back. + changed |= after->markAllExceptPointToUnknown(writableClasses); } // Pointer-typed results may be pointing to any other pointer. The @@ -641,24 +656,49 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( if (!isPointerLike(result.getType())) continue; + // Result alias classes may contain operand alias classes because + // results may alias with those operands. However, if the operands are + // not writable, they cannot be updated to point to other classes + // even though results can be. To handle this, only update the alias + // classes associated with the results that are not also associated + // with non-writable operands. + // + // This logic is a bit more conservative than the theoretical optimum to + // ensure monotonicity of the transfer function: if additional alias + // classes are discovered for non-writable operands at a later stage + // after these classes have already been associated with the result and + // marked as potentially pointing to some other classes, this marking + // is *not* rolled back. Since points-to-pointer analysis is a may- + // analysis, this is not problematic. 
const auto *destClasses = getOrCreateFor(call, result); + AliasClassSet resultWithoutNonWritableOperands = + AliasClassSet::getUndefined(); + if (destClasses->isUnknown() || nonWritableOperandClasses.isUnknown()) { + resultWithoutNonWritableOperands.markUnknown(); + } else if (!destClasses->isUndefined() && + !nonWritableOperandClasses.isUndefined()) { + DenseSet nonOperandClasses = + llvm::set_difference(destClasses->getAliasClasses(), + nonWritableOperandClasses.getAliasClasses()); + resultWithoutNonWritableOperands.insert(nonOperandClasses); + } else { + resultWithoutNonWritableOperands.join( + destClasses->getAliasClassesObject()); + } // If reading from other memory, the results may point to anything. if (funcMayReadOther) { propagateIfChanged(after, after->markPointToUnknown( - destClasses->getAliasClassesObject())); + resultWithoutNonWritableOperands)); continue; } - AliasClassSet commonReturnScope; - (void)commonReturnScope.markFresh( - StringAttr::get(call->getContext(), "function-return-common")); for (int operandNo : pointerLikeOperands) { const auto *srcClasses = getOrCreateFor( call, call.getArgOperands()[operandNo]); if (mayReadArg(callee, operandNo, argModRef)) { - changed |= after->addSetsFrom(destClasses->getAliasClassesObject(), + changed |= after->addSetsFrom(resultWithoutNonWritableOperands, srcClasses->getAliasClassesObject()); } @@ -666,12 +706,10 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( (operandNo < numArguments && !!callee.getArgAttr(operandNo, LLVM::LLVMDialect::getNoCaptureAttrName())); - if (isNoCapture) { - changed |= after->insert(destClasses->getAliasClassesObject(), - srcClasses->getAliasClassesObject()); - } - after->insert(destClasses->getAliasClassesObject(), - commonReturnScope); + if (isNoCapture) + continue; + changed |= after->insert(resultWithoutNonWritableOperands, + srcClasses->getAliasClassesObject()); } } return propagateIfChanged(after, changed); @@ -689,9 +727,22 @@ void 
enzyme::PointsToPointerAnalysis::setToEntryState(PointsToSets *lattice) {} // AliasClassLattice //===----------------------------------------------------------------------===// +void enzyme::AliasClassSet::print(raw_ostream &os) const { + if (isUnknown()) { + os << ""; + } else if (isUndefined()) { + os << ""; + } else { + llvm::interleaveComma(aliasClasses, os << "{"); + os << "}"; + } +} + void enzyme::AliasClassLattice::print(raw_ostream &os) const { if (aliasClasses.isUnknown()) { os << "Unknown AC"; + } else if (aliasClasses.isUndefined()) { + os << "Undefined AC"; } else { os << "size: " << aliasClasses.getAliasClasses().size() << ":\n"; for (auto aliasClass : aliasClasses.getAliasClasses()) { @@ -703,6 +754,9 @@ void enzyme::AliasClassLattice::print(raw_ostream &os) const { AliasResult enzyme::AliasClassLattice::alias(const AbstractSparseLattice &other) const { const auto *rhs = reinterpret_cast(&other); + + assert(!isUndefined() && !rhs->isUndefined() && "incomplete alias analysis"); + if (getPoint() == rhs->getPoint()) return AliasResult::MustAlias; @@ -731,20 +785,10 @@ enzyme::AliasClassLattice::alias(const AbstractSparseLattice &other) const { ChangeResult enzyme::AliasClassLattice::join(const AbstractSparseLattice &other) { // Set union of the alias classes - const auto *otherAliasClass = - reinterpret_cast(&other); + const auto *otherAliasClass = static_cast(&other); return aliasClasses.join(otherAliasClass->aliasClasses); } -ChangeResult enzyme::AliasClassLattice::markFresh(Attribute debugLabel) { - reset(); - - Value value = getPoint(); - if (!debugLabel) - debugLabel = UnitAttr::get(value.getContext()); - return aliasClasses.markFresh(debugLabel); -} - //===----------------------------------------------------------------------===// // AliasAnalysis //===----------------------------------------------------------------------===// @@ -757,17 +801,28 @@ void enzyme::AliasAnalysis::setToEntryState(AliasClassLattice *lattice) { 
LLVM::LLVMDialect::getNoAliasAttrName())) { Attribute debugLabel = funcOp.getArgAttr(arg.getArgNumber(), "enzyme.tag"); - propagateIfChanged(lattice, lattice->markFresh(debugLabel)); - } else { - // TODO: Not safe in general, integers can be a result of ptrtoint. We - // need a type analysis here I guess? - if (isPointerLike(arg.getType())) - propagateIfChanged(lattice, lattice->insert({entryClass})); + // TODO: this may currently be failing because `setToEntryState` + // is used by the framework to set the pessimistic fixpoint (top), which + // isn't correct for pessimistic analysis for which `setToEntryState` is + // the undefined state (bottom). + assert(lattice->isUndefined() && "resetting lattice point"); + + DistinctAttr noaliasClass = + originalClasses.getOriginalClass(lattice->getPoint(), debugLabel); + return propagateIfChanged(lattice, + lattice->join(AliasClassLattice::single( + lattice->getPoint(), noaliasClass))); } + // TODO: Not safe in general, integers can be a result of ptrtoint. We + // need a type analysis here I guess? + if (isPointerLike(arg.getType())) + return propagateIfChanged(lattice, lattice->insert({entryClass})); } - } else { - propagateIfChanged(lattice, lattice->reset()); } + if (!lattice->isUndefined()) + llvm::errs() << *lattice << "\n"; + assert(lattice->isUndefined()); + // The default state is "undefined", no need to explicitly (re)set it. } /// Returns `true` if the alias transfer function of the operation is fully @@ -796,11 +851,22 @@ void enzyme::AliasAnalysis::transfer( } if (isa(effect.getEffect())) { - // Mark the result of the allocation as a fresh memory location + // Mark the result of the allocation as a fresh memory location. 
for (AliasClassLattice *result : results) { if (result->getPoint() == value) { Attribute debugLabel = op->getAttr("tag"); - propagateIfChanged(result, result->markFresh(debugLabel)); + auto fresh = AliasClassLattice::single( + result->getPoint(), + originalClasses.getOriginalClass(result->getPoint(), debugLabel)); + propagateIfChanged(result, result->join(fresh)); + + // The pointer to freshly allocated memory is known not to point to + // anything. + // TODO(zinenko): this is a bit strange to update _another_ lattice + // here. + auto *pointsTo = getOrCreate(op); + propagateIfChanged(pointsTo, pointsTo->setPointingToEmpty( + fresh.getAliasClassesObject())); } } } else if (isa(effect.getEffect())) { @@ -810,6 +876,8 @@ void enzyme::AliasAnalysis::transfer( for (AliasClassLattice *result : results) { propagateIfChanged(result, result->markUnknown()); } + } else if (latticeElement->isUndefined()) { + // Do nothing unless we know something about the value. } else { for (auto srcClass : latticeElement->getAliasClasses()) { const auto &srcPointsTo = pointsToSets->getPointsTo(srcClass); @@ -818,11 +886,11 @@ void enzyme::AliasAnalysis::transfer( // doesn't require a conditional here. if (srcPointsTo.isUnknown()) { propagateIfChanged(result, result->markUnknown()); + } else if (srcPointsTo.isUndefined()) { + continue; } else { - // TODO: this looks potentially non-monotonous. - ChangeResult r = result->reset() | - result->insert(srcPointsTo.getAliasClasses()); - propagateIfChanged(result, r); + propagateIfChanged(result, + result->insert(srcPointsTo.getAliasClasses())); } } } @@ -938,6 +1006,8 @@ void enzyme::AliasAnalysis::visitExternalCall( // Even if a function is marked as not reading from memory or arguments, it // may still create pointers "out of the thin air", e.g., by "ptrtoint" from a // constant or an argument. + // TODO: consider "ptrtoint" here, for now assuming it is covered by + // inaccessible and other mem. 
auto symbol = dyn_cast(call.getCallableForCallee()); if (!symbol) return markResultsUnknown(); @@ -946,13 +1016,84 @@ void enzyme::AliasAnalysis::visitExternalCall( if (!callee) return markResultsUnknown(); + // Collect alias classes that can be read through the arguments. + std::optional argModRef = getFunctionArgModRef(callee); + std::optional otherModRef = getFunctionOtherModRef(callee); + std::optional inaccessibleModRef = + getFunctionInaccessibleModRef(callee); + auto operandAliasClasses = AliasClassSet::getEmpty(); + for (auto [operandNo, operand] : llvm::enumerate(call.getArgOperands())) { + if (!isPointerLike(operand.getType())) + continue; + + const AliasClassLattice *srcClasses = operands[operandNo]; + operandAliasClasses.join(srcClasses->getAliasClassesObject()); + + if (!mayReadArg(callee, operandNo, argModRef)) + continue; + + // If can read from argument, collect the alias classes that can this + // argument may be pointing to. + const auto *pointsToLattice = getOrCreateFor(call, call); + srcClasses->getAliasClassesObject().foreachClass( + [&](DistinctAttr srcClass, AliasClassSet::State state) { + // Nothing to do in top/bottom case. In the top case, we have already + // set `operandAliasClasses` to top above. + if (srcClass == nullptr) + return ChangeResult::NoChange; + operandAliasClasses.join(pointsToLattice->getPointsTo(srcClass)); + return ChangeResult::NoChange; + }); + } + + auto debugLabel = call->getAttrOfType("tag"); + DistinctAttr commonResultAttr = nullptr; + + // Collect all results that are not marked noalias so we can put them in a + // common alias group. 
+ SmallVector aliasGroupResults; + for (OpResult result : call->getResults()) { + if (!callee.getResultAttr(result.getResultNumber(), + LLVM::LLVMDialect::getNoAliasAttrName())) + aliasGroupResults.push_back(result); + } + for (OpResult result : call->getResults()) { AliasClassLattice *resultLattice = results[result.getResultNumber()]; - if (callee.getResultAttr(result.getResultNumber(), - LLVM::LLVMDialect::getNoAliasAttrName())) { - propagateIfChanged( - resultLattice, - resultLattice->markFresh(call->getAttrOfType("tag"))); + if (!llvm::is_contained(aliasGroupResults, result)) { + Attribute individualDebugLabel = + debugLabel + ? StringAttr::get(debugLabel.getContext(), + debugLabel.getValue().str() + + std::to_string(result.getResultNumber())) + : nullptr; + auto individualAlloc = AliasClassLattice::single( + resultLattice->getPoint(), + originalClasses.getOriginalClass(resultLattice->getPoint(), + individualDebugLabel)); + propagateIfChanged(resultLattice, resultLattice->join(individualAlloc)); + } else if (!modRefMayRef(otherModRef) && + !modRefMayRef(inaccessibleModRef)) { + // Put results that are not marked as noalias into one common group. + if (!commonResultAttr) { + std::string label = !debugLabel + ? "func-result-common" + : debugLabel.getValue().str() + "-common"; + commonResultAttr = + originalClasses.getSameOriginalClass(aliasGroupResults, label); + } + AliasClassSet commonClass(commonResultAttr); + ChangeResult changed = resultLattice->join( + AliasClassLattice(resultLattice->getPoint(), std::move(commonClass))); + + // If the function is known not to read other (or inaccessible mem), its + // results may only alias what we know it can read, e.g. other arguments + // or anything stored in those arguments. + // FIXME: note the explicit copy, we need to simplify the relation between + // AliasClassSet and AliasClassLattice. 
+ changed |= resultLattice->join(AliasClassLattice( + resultLattice->getPoint(), AliasClassSet(operandAliasClasses))); + propagateIfChanged(resultLattice, changed); } else { propagateIfChanged(resultLattice, resultLattice->markUnknown()); } diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h index be91b1f90705..57297f1242a7 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h @@ -29,6 +29,7 @@ #include "mlir/Analysis/AliasAnalysis.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { @@ -38,34 +39,58 @@ class CallableOpInterface; namespace enzyme { /// A set of alias class identifiers to be treated as a single union. May be -/// marked as "unknown", which is a conservative pessimistic state. -struct AliasClassSet { - AliasClassSet() = default; - AliasClassSet(DistinctAttr single) { aliasClasses.insert(single); } +/// marked as "unknown", which is a conservative pessimistic state, or as +/// "undefined", which is a "not-yet-analyzed" initial state. Undefined state is +/// different from an empty alias set. +class AliasClassSet { +public: + enum class State { + Undefined, ///< Has not been analyzed yet (lattice bottom). + Defined, ///< Has specific alias classes. + Unknown ///< Analyzed and may point to any class (lattice top). + }; + + AliasClassSet() : state(State::Undefined) {} + + AliasClassSet(DistinctAttr single) : state(State::Defined) { + aliasClasses.insert(single); + } + // TODO(zinenko): deprecate this and use a visitor instead. 
DenseSet &getAliasClasses() { - assert(!unknown); + assert(state == State::Defined); return aliasClasses; } const DenseSet &getAliasClasses() const { return const_cast(this)->getAliasClasses(); } - bool isUnknown() const { return unknown; } + bool isUnknown() const { return state == State::Unknown; } + bool isUndefined() const { return state == State::Undefined; } ChangeResult join(const AliasClassSet &other); ChangeResult insert(const DenseSet &classes); ChangeResult markUnknown(); - ChangeResult markFresh(Attribute debugLabel); - ChangeResult reset(); - /// Returns true if this set is in the canonical form, i.e. has either the - /// unknown bit or the explicit list of classes, but not both. + /// Returns true if this set is in the canonical form, i.e. either the state + /// is `State::Defined` or the explicit list of classes is empty, but not + /// both. bool isCanonical() const; - /// Returns an empty instance of AliasClassSet. The instance is *not* a - /// classical singleton, there are other ways of obtaining it. - static const AliasClassSet &getEmpty() { return emptySet; } + /// Returns an instance of AliasClassSet known not to alias with anything. + /// This is different from "undefined" and "unknown". The instance is *not* a + /// classical singleton. + static const AliasClassSet &getEmpty() { + static const AliasClassSet empty(State::Defined); + return empty; + } + + /// Returns an instance of AliasClassSet in "undefined" state, i.e. without a + /// set of alias classes. This is different from empty alias set, which + /// indicates that the value is known not to alias with any alias class. The + /// instance is *not* a classical singleton, there are other ways of obtaining + /// it. + static const AliasClassSet &getUndefined() { return undefinedSet; } /// Returns an instance of AliasClassSet for the "unknown" class. 
The instance /// is *not* a classical singleton, there are other ways of obtaining an @@ -74,21 +99,81 @@ struct AliasClassSet { bool operator==(const AliasClassSet &other) const; + void print(llvm::raw_ostream &os) const; + ChangeResult - foreachClass(function_ref callback) const; + foreachClass(function_ref callback) const; private: - explicit AliasClassSet(bool unknown) : unknown(unknown) {} + explicit AliasClassSet(State state) : state(state) {} + + ChangeResult updateStateToDefined() { + assert(state != State::Unknown && "cannot go back from unknown state"); + ChangeResult result = state == State::Undefined ? ChangeResult::Change + : ChangeResult::NoChange; + state = State::Defined; + return result; + } const static AliasClassSet unknownSet; - const static AliasClassSet emptySet; + const static AliasClassSet undefinedSet; DenseSet aliasClasses; - bool unknown = false; + State state; }; //===----------------------------------------------------------------------===// -// PointsToAnalysis +// OriginalClasses +//===----------------------------------------------------------------------===// + +/// Alias classes for freshly created, e.g., allocated values. These must +/// be used instead of allocating a fresh distinct attribute every time. +/// Allocation may only happen when the mapping is not already present here. 
+class OriginalClasses { +public: + DistinctAttr getOriginalClass(Value value, StringRef debugLabel) { + return getOriginalClass(value, + StringAttr::get(value.getContext(), debugLabel)); + } + DistinctAttr getOriginalClass(Value value, Attribute referenced = nullptr) { + DistinctAttr &aliasClass = originalClasses[value]; + if (!aliasClass) { + if (!referenced) + referenced = UnitAttr::get(value.getContext()); + aliasClass = DistinctAttr::create(referenced); + } + return aliasClass; + } + + DistinctAttr getSameOriginalClass(ValueRange values, StringRef debugLabel) { + if (values.empty()) + return nullptr; + + auto label = StringAttr::get(values.front().getContext(), debugLabel); + + DistinctAttr common = nullptr; + for (Value v : values) { + DistinctAttr &aliasClass = originalClasses[v]; + if (!aliasClass) { + if (!common) + common = DistinctAttr::create(label); + aliasClass = common; + } else { + if (!common) + common = aliasClass; + else + assert(aliasClass == common && "original alias class mismatch"); + } + } + return common; + } + +private: + DenseMap originalClasses; +}; + +//===----------------------------------------------------------------------===// +// PointsToSets // // Specifically for pointers to pointers. This tracks alias information through // pointers stored/loaded through memory. @@ -121,10 +206,7 @@ class PointsToSets : public dataflow::AbstractDenseLattice { ChangeResult addSetsFrom(const AliasClassSet &destClasses, const AliasClassSet &srcClasses); - /// For every alias class in `dest`, record that it is pointing to the _same_ - /// new alias set. - ChangeResult setPointingToFresh(const AliasClassSet &destClasses, - StringAttr debugLabel); + ChangeResult setPointingToEmpty(const AliasClassSet &destClasses); /// Mark `dest` as pointing to "unknown" alias set, that is, any possible /// other pointer. This is partial pessimistic fixpoint. 
@@ -141,20 +223,37 @@ class PointsToSets : public dataflow::AbstractDenseLattice { const AliasClassSet &getPointsTo(DistinctAttr id) const { auto it = pointsTo.find(id); if (it == pointsTo.end()) - return otherPointToUnknown ? AliasClassSet::getUnknown() - : AliasClassSet::getEmpty(); + return AliasClassSet::getUndefined(); return it->getSecond(); } private: + /// Update all alias classes in `keysToUpdate` to additionally point to alias + /// classes in `values`. Handle undefined keys optimistically (ignore) and + /// unknown keys pessimistically (update all existing keys). `replace` is a + /// debugging aid that indicates whether the update is intended to replace the + /// pre-existing state, it has no effect in NoAsserts build. Since we don't + /// want to forcefully reset pointsTo value as that is not guaranteed to make + /// monotonous progress on the lattice and therefore convergence to fixpoint, + /// replacement is only expected for a previously "unknown" value (absent from + /// the mapping) or for a value with itself. Replacement is therefore handled + /// as a regular update, i.e. join, with additional assertions. Note that + /// currently an update is possible to _any_ value that is >= the current one + /// in the lattice, not only the replacements described above. ChangeResult update(const AliasClassSet &keysToUpdate, const AliasClassSet &values, bool replace); + ChangeResult joinPotentiallyMissing(DistinctAttr key, + const AliasClassSet &value); + /// Indicates that alias classes not listed as keys in `pointsTo` point to /// unknown alias set (when true) or an empty alias set (when false). // TODO: consider also differentiating between pointing to known-empty vs. // not-yet-computed. - bool otherPointToUnknown = false; + // bool otherPointToUnknown = false; + + // missing from map always beings "undefined", "unknown"s are stored + // explicitly. /// Maps an identifier of an alias set to the set of alias sets its value may /// belong to. 
When an identifier is not present in this map, it is considered @@ -163,10 +262,15 @@ class PointsToSets : public dataflow::AbstractDenseLattice { DenseMap pointsTo; }; +//===----------------------------------------------------------------------===// +// PointsToPointerAnalysis +//===----------------------------------------------------------------------===// + class PointsToPointerAnalysis : public dataflow::DenseForwardDataFlowAnalysis { public: - using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis; + PointsToPointerAnalysis(DataFlowSolver &solver) + : DenseForwardDataFlowAnalysis(solver) {} void setToEntryState(PointsToSets *lattice) override; @@ -181,6 +285,13 @@ class PointsToPointerAnalysis void processCapturingStore(ProgramPoint dependent, PointsToSets *after, Value capturedValue, Value destinationAddress, bool isMustStore = false); + +private: + /// Alias classes originally assigned to known-distinct values, e.g., fresh + /// allocations, by this analysis. This does NOT necessarily need to be shared + /// with the other analysis as they may assign different classes, e.g., for + /// results of the same call. 
+ OriginalClasses originalClasses; }; //===----------------------------------------------------------------------===// @@ -190,6 +301,9 @@ class PointsToPointerAnalysis class AliasClassLattice : public dataflow::AbstractSparseLattice { public: using AbstractSparseLattice::AbstractSparseLattice; + AliasClassLattice(Value value, AliasClassSet &&classes) + : dataflow::AbstractSparseLattice(value), + aliasClasses(std::move(classes)) {} void print(raw_ostream &os) const override; @@ -201,20 +315,19 @@ class AliasClassLattice : public dataflow::AbstractSparseLattice { return aliasClasses.insert(classes); } - ChangeResult markFresh(/*optional=*/Attribute debugLabel); + static AliasClassLattice single(Value point, DistinctAttr value) { + return AliasClassLattice(point, AliasClassSet(value)); + } ChangeResult markUnknown() { return aliasClasses.markUnknown(); } - ChangeResult reset() { return aliasClasses.reset(); } + // ChangeResult reset() { return aliasClasses.reset(); } - static DistinctAttr getFresh(Attribute debugLabel) { - return DistinctAttr::create(debugLabel); - } - - /// We don't know anything about the aliasing of this value. TODO: re-evaluate - /// if we need this. + /// We don't know anything about the aliasing of this value. bool isUnknown() const { return aliasClasses.isUnknown(); } + bool isUndefined() const { return aliasClasses.isUndefined(); } + const DenseSet &getAliasClasses() const { return aliasClasses.getAliasClasses(); } @@ -254,6 +367,12 @@ class AliasAnalysis /// A special alias class to denote unannotated pointer arguments. const DistinctAttr entryClass; + + /// Alias classes originally assigned to known-distinct values, e.g., fresh + /// allocations, by this analysis. This does NOT necessarily need to be shared + /// with the other analysis as they may assign different classes, e.g., for + /// results of the same call. 
+ OriginalClasses originalClasses; }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index db04f280123e..608bf75f40a5 100644 --- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -485,15 +485,15 @@ std::optional getCopySource(Operation *op) { /// Value corresponding to its allocation. /// The callback may receive null allocation when the class alias set is /// unknown. +/// If the classes are undefined, the callback will not be called at all. void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass, - function_ref forEachFn) { - if (ptrAliasClass->isUnknown()) { - // Unknown pointers alias with the unknown entry arguments and all - // known allocations - return forEachFn(nullptr); - } - for (DistinctAttr alloc : ptrAliasClass->getAliasClasses()) - forEachFn(alloc); + function_ref forEachFn) { + ptrAliasClass->getAliasClassesObject().foreachClass( + [&](DistinctAttr alloc, enzyme::AliasClassSet::State state) { + if (state != enzyme::AliasClassSet::State::Undefined) + forEachFn(alloc); + return ChangeResult::NoChange; + }); } class DenseForwardActivityAnalysis @@ -616,20 +616,23 @@ class DenseForwardActivityAnalysis /// Initialize the entry block with the supplied argument activities. 
void setToEntryState(ForwardMemoryActivity *lattice) override { - if (auto *block = dyn_cast_if_present(lattice->getPoint())) { - if (block == entryBlock) { - for (const auto &[arg, activity] : - llvm::zip(block->getArguments(), argumentActivity)) { - if (activity == enzyme::Activity::enzyme_dup || - activity == enzyme::Activity::enzyme_dupnoneed) { - auto *argAliasClasses = - getOrCreateFor(block, arg); - for (DistinctAttr argAliasClass : - argAliasClasses->getAliasClasses()) { - propagateIfChanged(lattice, lattice->setActiveIn(argAliasClass)); - } - } - } + if (auto *block = dyn_cast_if_present(lattice->getPoint()); + block && block == entryBlock) { + for (const auto &[arg, activity] : + llvm::zip(block->getArguments(), argumentActivity)) { + if (activity != enzyme::Activity::enzyme_dup && + activity != enzyme::Activity::enzyme_dupnoneed) + continue; + auto *argAliasClasses = getOrCreateFor(block, arg); + ChangeResult changed = + argAliasClasses->getAliasClassesObject().foreachClass( + [lattice](DistinctAttr argAliasClass, + enzyme::AliasClassSet::State state) { + if (state == enzyme::AliasClassSet::State::Undefined) + return ChangeResult::NoChange; + return lattice->setActiveIn(argAliasClass); + }); + propagateIfChanged(lattice, changed); } } } @@ -656,29 +659,39 @@ class DenseBackwardActivityAnalysis BackwardMemoryActivity *before) override { // Initialize the return activity of arguments. 
if (op->hasTrait() && op->getParentOp() == parentOp) { - for (const auto &[arg, argActivity] : llvm::zip( - parentOp->getRegions().front().getArguments(), argumentActivity)) - if (argActivity == enzyme::Activity::enzyme_dup || - argActivity == enzyme::Activity::enzyme_dupnoneed) { - auto *argAliasClasses = getOrCreateFor(op, arg); - for (DistinctAttr argAliasClass : - argAliasClasses->getAliasClasses()) { - propagateIfChanged(before, before->setActiveOut(argAliasClass)); - } + for (const auto &[arg, argActivity] : + llvm::zip(parentOp->getRegions().front().getArguments(), + argumentActivity)) { + if (argActivity != enzyme::Activity::enzyme_dup && + argActivity != enzyme::Activity::enzyme_dupnoneed) { + continue; } + auto *argAliasClasses = getOrCreateFor(op, arg); + ChangeResult changed = + argAliasClasses->getAliasClassesObject().foreachClass( + [before](DistinctAttr argAliasClass, + enzyme::AliasClassSet::State state) { + if (state == enzyme::AliasClassSet::State::Undefined) + return ChangeResult::NoChange; + return before->setActiveOut(argAliasClass); + }); + propagateIfChanged(before, changed); + } // Initialize the return activity of the operands for (Value operand : op->getOperands()) { if (isa(operand.getType())) { auto *retAliasClasses = getOrCreateFor(op, operand); - if (retAliasClasses->isUnknown()) { - propagateIfChanged(before, before->setActiveOut()); - } else { - for (DistinctAttr retAliasClass : - retAliasClasses->getAliasClasses()) - propagateIfChanged(before, before->setActiveOut(retAliasClass)); - } + ChangeResult changed = + retAliasClasses->getAliasClassesObject().foreachClass( + [before](DistinctAttr retAliasClass, + enzyme::AliasClassSet::State state) { + if (state == enzyme::AliasClassSet::State::Undefined) + return ChangeResult::NoChange; + return before->setActiveOut(retAliasClass); + }); + propagateIfChanged(before, changed); } } } @@ -794,20 +807,6 @@ void traverseCallGraph(FunctionOpInterface root, } } -static const enzyme::AliasClassSet 
& -getDefaultPointsTo(const enzyme::PointsToSets &pointsToSets) { - // Get the default points-to alias class set, which is where the - // "unknown" and any other unlisted class set points to. - const enzyme::AliasClassSet &defaultPointsTo = - pointsToSets.getPointsTo(nullptr); - // Unknown class can point to unknown or nothing, unless further - // refined. - assert((defaultPointsTo.isUnknown() || - defaultPointsTo.getAliasClasses().empty()) && - "new case introduced for AliasClassSet?"); - return defaultPointsTo; -} - void printActivityAnalysisResults(const DataFlowSolver &solver, FunctionOpInterface callee, const SmallPtrSet &returnOps, @@ -835,27 +834,31 @@ void printActivityAnalysisResults(const DataFlowSolver &solver, // Traverse the points-to sets in a simple BFS std::deque frontier; DenseSet visited; - auto scheduleVisit = [&](auto range) { - for (DistinctAttr neighbor : range) { + auto scheduleVisit = [&](const enzyme::AliasClassSet &aliasClasses) { + aliasClasses.foreachClass([&](DistinctAttr neighbor, + enzyme::AliasClassSet::State state) { + assert(neighbor && "unhandled undefined/unknown case before visit"); if (!visited.contains(neighbor)) { visited.insert(neighbor); frontier.push_back(neighbor); } - } + return ChangeResult::NoChange; + }); }; + // If this triggers, investigate why the alias classes weren't computed. + // If they weren't computed legitimately, treat the value as + // conservatively non-constant or change the return type to be tri-state. + assert(!aliasClassLattice->isUndefined() && + "didn't compute alias classes"); + if (aliasClassLattice->isUnknown()) { - // If this pointer is in unknown alias class, it may point to active - // data if the unknown alias class is known to point to something and - // may not point to active data if the unknown alias class is known not - // to point to anything. 
- auto &defaultPointsTo = getDefaultPointsTo(*pointsToSets); - return !defaultPointsTo.isUnknown() && - defaultPointsTo.getAliasClasses().empty(); + // Pointers of unknown class may point to active data. + // TODO: is this overly conservative? Should we rather check + // if listed classes may point to non-constants? + return false; } else { - const DenseSet &aliasClasses = - aliasClassLattice->getAliasClasses(); - scheduleVisit(aliasClasses); + scheduleVisit(aliasClassLattice->getAliasClassesObject()); } while (!frontier.empty()) { DistinctAttr aliasClass = frontier.front(); @@ -867,21 +870,16 @@ void printActivityAnalysisResults(const DataFlowSolver &solver, bma->activeDataFlowsOut(aliasClass)) return false; - // Or if it points to a pointer that points to active data. - if (pointsToSets->getPointsTo(aliasClass).isUnknown()) { - // If a pointer points to an unknown alias set, query the default - // points-to alias set (which also applies to the unknown alias set). - auto &defaultPointsTo = getDefaultPointsTo(*pointsToSets); - // If it is in turn unknown, conservatively assume the pointer may be - // pointing to some active data. - if (defaultPointsTo.isUnknown()) - return false; - // Otherwise look at classes pointed to by unknown (which can only be - // an empty set as of time of writing). - scheduleVisit(defaultPointsTo.getAliasClasses()); - continue; - } - scheduleVisit(pointsToSets->getPointsTo(aliasClass).getAliasClasses()); + // If this triggers, investigate why points-to sets couldn't be + // computed. Treat conservatively as "unknown" if necessary. + assert(!pointsToSets->getPointsTo(aliasClass).isUndefined() && + "couldn't compute points-to sets"); + + // Pointers to unknown classes may (transitively) point to active data. 
+ if (pointsToSets->getPointsTo(aliasClass).isUnknown()) + return false; + + scheduleVisit(pointsToSets->getPointsTo(aliasClass)); } // Otherwise, it's constant return true; diff --git a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp index 57694bc38f18..99a41f804593 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp @@ -87,10 +87,25 @@ struct PrintAliasAnalysisPass if (auto funcOp = dyn_cast(op)) { for (auto arg : funcOp.getArguments()) { auto *state = solver.lookupState(arg); - if (state) { - for (auto aliasClass : state->getAliasClasses()) - funcOp.setArgAttr(arg.getArgNumber(), "enzyme.ac", aliasClass); - } + if (!state) + continue; + // TODO(zinenko): this has been overriding the argument... + // Use an array attr instead (will break syntactic tests). + state->getAliasClassesObject().foreachClass( + [&](DistinctAttr aliasClass, enzyme::AliasClassSet::State state) { + if (state == enzyme::AliasClassSet::State::Undefined) + funcOp.setArgAttr( + arg.getArgNumber(), "enzyme.ac", + StringAttr::get(arg.getContext(), "undefined")); + else if (state == enzyme::AliasClassSet::State::Unknown) + funcOp.setArgAttr( + arg.getArgNumber(), "enzyme.ac", + StringAttr::get(arg.getContext(), "unknown")); + else + funcOp.setArgAttr(arg.getArgNumber(), "enzyme.ac", + aliasClass); + return ChangeResult::NoChange; + }); } } else if (op->hasTrait() && isa(op->getParentOp())) { @@ -107,6 +122,9 @@ struct PrintAliasAnalysisPass if (state->isUnknown()) { op->setAttr("ac", StringAttr::get(result.getContext(), "")); + } else if (state->isUndefined()) { + op->setAttr("ac", + StringAttr::get(result.getContext(), "")); } else { for (auto aliasClass : state->getAliasClasses()) { op->setAttr("ac", aliasClass); diff --git a/enzyme/test/MLIR/AliasAnalysis/func_attributes.mlir b/enzyme/test/MLIR/AliasAnalysis/func_attributes.mlir index d77a8c0988e0..87c7e4105a2b 100644 --- 
a/enzyme/test/MLIR/AliasAnalysis/func_attributes.mlir +++ b/enzyme/test/MLIR/AliasAnalysis/func_attributes.mlir @@ -3,7 +3,7 @@ func.func private @callee(%ptr : !llvm.ptr) // CHECK: points-to-pointer sets -// CHECK-NEXT: other points to unknown: 1 +// CHECK-NEXT: distinct[{{.*}}]<"entry"> points to {} // CHECK-LABEL @fully_opaque_call func.func @fully_opaque_call(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () @@ -20,7 +20,6 @@ func.func private @callee(%ptr : !llvm.ptr) attributes { // CHECK: points-to-pointer sets // CHECK-NEXT: distinct{{\[}}[[ID:.+]]]<"entry"> points to {distinct{{\[}}[[ID]]]<"entry">} -// CHECK-NEXT: other points to unknown: 0 // CHECK-LABEL @call_other_none_arg_rw func.func @call_other_none_arg_rw(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () @@ -36,7 +35,7 @@ func.func private @callee(%ptr : !llvm.ptr) attributes { } // CHECK: points-to-pointer sets -// CHECK-NEXT: other points to unknown: 0 +// CHECK-NEXT: // CHECK-LABEL @call_other_none_arg_ro func.func @call_other_none_arg_ro(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () @@ -53,9 +52,8 @@ func.func private @callee(%ptr : !llvm.ptr) attributes { // CHECK: points-to-pointer sets // CHECK-NEXT: distinct{{\[}}[[ID:.+]]]<"entry"> points to {distinct{{\[}}[[ID]]]<"entry">} -// CHECK-NEXT: other points to unknown: 0 // CHECK-LABEL @call_other_none_arg_wo -func.func @caller(%input: !llvm.ptr {enzyme.tag = "input"}) { +func.func @call_other_none_arg_wo(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () return } @@ -69,7 +67,7 @@ func.func private @callee(%ptr : !llvm.ptr {llvm.nocapture}) attributes { } // CHECK: points-to-pointer sets -// CHECK-NEXT: other points to unknown: 0 +// CHECK-NEXT: // CHECK-LABEL @call_other_none_arg_wo_nocapture func.func @call_other_none_arg_wo_nocapture(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : 
(!llvm.ptr) -> () @@ -86,7 +84,6 @@ func.func private @callee(%ptr : !llvm.ptr {llvm.nocapture}) attributes { // CHECK: points-to-pointer sets // CHECK-NEXT: distinct{{\[}}[[ID:.+]]]<"entry"> points to {} -// CHECK-NEXT: other points to unknown: 0 // CHECK-LABEL @call_other_read_arg_wo_nocapture func.func @call_other_read_arg_wo_nocapture(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () @@ -103,7 +100,6 @@ func.func private @callee(%ptr : !llvm.ptr) attributes { // CHECK: points-to-pointer sets // CHECK-NEXT: distinct{{\[}}[[ID:.+]]]<"entry"> points to {} -// CHECK-NEXT: other points to unknown: 0 // CHECK-LABEL @call_other_read_arg_wo func.func @call_other_read_arg_wo(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () @@ -119,7 +115,7 @@ func.func private @callee(%ptr : !llvm.ptr {llvm.readonly}) attributes { } // CHECK: points-to-pointer sets -// CHECK-NEXT: other points to unknown: 0 +// CHECK-NEXT: // CHECK-LABEL @call_other_none_arg_rw_readonly func.func @call_other_none_arg_rw_readonly(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () @@ -137,7 +133,6 @@ func.func private @callee(%ptr : !llvm.ptr {llvm.writeonly}) attributes { // CHECK: points-to-pointer sets // CHECK-NEXT: distinct{{\[}}[[ID:.+]]]<"entry"> points to {distinct{{\[}}[[ID]]]<"entry">} -// CHECK-NEXT: other points to unknown: 0 // CHECK-LABEL @call_other_none_arg_rw_writeonly func.func @call_other_none_arg_rw_writeonly(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () @@ -157,7 +152,6 @@ func.func private @callee(%ptr1 : !llvm.ptr, %ptr2 : !llvm.ptr) attributes { // CHECK: points-to-pointer sets // CHECK-DAG: distinct{{\[}}[[ID:.+]]]<"alloca-2"> points to {distinct{{.*}}, distinct{{.*}}} // CHECK-DAG: distinct{{\[}}[[ID:.+]]]<"alloca-1"> points to {distinct{{.*}}, distinct{{.*}}} -// CHECK-NEXT: other points to unknown: 0 func.func 
@call_two_pointers_other_none_arg_rw_simple(%sz: i64) { %0 = llvm.alloca %sz x i8 { tag = "alloca-1" } : (i64) -> !llvm.ptr %1 = llvm.alloca %sz x i8 { tag = "alloca-2" } : (i64) -> !llvm.ptr @@ -178,7 +172,6 @@ func.func private @callee(%ptr1 : !llvm.ptr, %ptr2 : !llvm.ptr {llvm.nocapture}) // CHECK: points-to-pointer sets // CHECK-DAG: distinct{{\[}}[[ID:.+]]]<"alloca-2"> points to {distinct{{\[}}[[ID]]]<"alloca-1">} // CHECK-DAG: distinct{{\[}}[[ID]]]<"alloca-1"> points to {distinct{{\[}}[[ID]]]<"alloca-1">} -// CHECK-NEXT: other points to unknown: 0 // CHECK-LABEL: @call_two_pointers_other_none_arg_rw_nocapture func.func @call_two_pointers_other_none_arg_rw_nocapture(%sz: i64) { %0 = llvm.alloca %sz x i8 { tag = "alloca-1" } : (i64) -> !llvm.ptr @@ -195,9 +188,11 @@ func.func private @callee(%ptr1 : !llvm.ptr {llvm.readonly}, %ptr2 : !llvm.ptr) inaccessibleMem = none> } +// TODO: the DAG below is due to using DenseMap and printing in no particular +// order, this should be fixed to have a deterministic order in tests. 
// CHECK: points-to-pointer sets -// CHECK-NEXT: distinct[{{.+}}]<"alloca-2"> points to {} -// CHECK-NEXT: other points to unknown: 0 +// CHECK-DAG: distinct[{{.+}}]<"alloca-2"> points to {} +// CHECK-DAG: distinct[{{.+}}]<"alloca-1"> points to {} func.func @call_two_pointers_other_read_arg_rw(%sz: i64) { %0 = llvm.alloca %sz x i8 { tag = "alloca-1" } : (i64) -> !llvm.ptr %1 = llvm.alloca %sz x i8 { tag = "alloca-2" } : (i64) -> !llvm.ptr @@ -214,7 +209,7 @@ func.func private @callee() -> !llvm.ptr attributes { } // CHECK: points-to-pointer sets -// CHECK: other points to unknown: 1 +// CHECK: func.func @func_return_simple() -> !llvm.ptr { %0 = call @callee() {tag = "func-return"} : () -> !llvm.ptr return %0 : !llvm.ptr @@ -229,8 +224,7 @@ func.func private @callee() -> (!llvm.ptr {llvm.noalias}) attributes { } // CHECK: points-to-pointer sets -// CHECK-NEXT: distinct[{{.+}}]<"func-return"> points to {} -// CHECK-NEXT: other points to unknown: 0 +// CHECK-NEXT: distinct[{{.+}}]<"func-return0"> points to {} func.func @func_return_noalias() -> !llvm.ptr { %0 = call @callee() {tag = "func-return"} : () -> !llvm.ptr return %0 : !llvm.ptr @@ -241,7 +235,7 @@ func.func @func_return_noalias() -> !llvm.ptr { // CHECK: tag "func-return" Unknown AC // CHECK: "func-return" and "func-return": MayAlias // CHECK: points-to-pointer sets -// CHECK-NEXT: other points to unknown: 1 +// CHECK-NEXT: distinct[{{.*}}]<"func-return0"> points to {} func.func private @callee() -> (!llvm.ptr {llvm.noalias}, !llvm.ptr) attributes { memory = #llvm.memory_effects (!llvm.ptr {llvm.noalias}) attributes { inaccessibleMem = none> } -// CHECK: "func-return-1" and "func-return-2": NoAlias +// CHECK: "func-1-return" and "func-2-return": NoAlias // CHECK: points-to-pointer sets -// CHECK-DAG: distinct[{{.+}}]<"func-return-1"> points to {} -// CHECK-DAG: distinct[{{.+}}]<"func-return-2"> points to {} -// CHECK-NEXT: other points to unknown: 0 +// CHECK-DAG: distinct[{{.+}}]<"func-1-return0"> points to 
{} +// CHECK-DAG: distinct[{{.+}}]<"func-2-return0"> points to {} func.func @caller() -> !llvm.ptr { - %0 = call @callee() {tag = "func-return-1"} : () -> !llvm.ptr - %1 = call @callee() {tag = "func-return-2"} : () -> !llvm.ptr + %0 = call @callee() {tag = "func-1-return"} : () -> !llvm.ptr + %1 = call @callee() {tag = "func-2-return"} : () -> !llvm.ptr return %0 : !llvm.ptr } @@ -276,8 +269,7 @@ func.func @caller() -> !llvm.ptr { // ----- // CHECK: points-to-pointer sets -// CHECK-NOT: points to {} -// CHECK: other points to unknown: 0 +// CHECK: func.func private @callee(!llvm.ptr {llvm.readnone}) attributes { memory = #llvm.memory_effects () return } + +// ----- + +func.func private @callee() -> (!llvm.ptr {llvm.noalias}) attributes { + memory = #llvm.memory_effects +} + +// CHECK: points-to-pointer sets +// CHECK-DAG: distinct[{{.+}}]<"alloca"> points to {distinct[{{.+}}]<"func-return0">} +// CHECK-DAG: distinct[{{.+}}]<"func-return0"> points to {} +func.func @func_return_noalias_stored() -> !llvm.ptr { + %0 = call @callee() {tag = "func-return"} : () -> !llvm.ptr + %c1 = arith.constant 1 : i64 + %1 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca"}: (i64) -> !llvm.ptr + llvm.store %0, %1 : !llvm.ptr, !llvm.ptr + return %0 : !llvm.ptr +} + +// ----- + +func.func private @callee() -> (!llvm.ptr) attributes { + memory = #llvm.memory_effects +} + +// CHECK: points-to-pointer sets +// CHECK: distinct[{{.+}}]<"alloca"> points to {} +func.func @func_return_stored() -> !llvm.ptr { + %0 = call @callee() {tag = "func-return"} : () -> !llvm.ptr + %c1 = arith.constant 1 : i64 + %1 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca"}: (i64) -> !llvm.ptr + llvm.store %0, %1 : !llvm.ptr, !llvm.ptr + return %0 : !llvm.ptr +} + +// ----- + + +func.func private @callee() -> (!llvm.ptr, !llvm.ptr) attributes { + memory = #llvm.memory_effects +} + +// TODO: The two results may alias, but we can't really +// differentiate them with current printing. 
+// CHECK: "func-return" and "func-return": MayAlias + +// CHECK: points-to-pointer sets +// CHECK-DAG: distinct[{{.+}}]<"alloca-1"> points to {distinct[{{.+}}]<"func-return-common">} +// CHECK-DAG: distinct[{{.+}}]<"alloca-2"> points to {distinct[{{.+}}]<"func-return-common">} +func.func @func_return_multiple() -> !llvm.ptr { + %0:2 = call @callee() {tag = "func-return"} : () -> (!llvm.ptr, !llvm.ptr) + %c1 = arith.constant 1 : i64 + %1 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-1"}: (i64) -> !llvm.ptr + llvm.store %0#0, %1 : !llvm.ptr, !llvm.ptr + %2 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-2"}: (i64) -> !llvm.ptr + llvm.store %0#1, %2 : !llvm.ptr, !llvm.ptr + return %0 : !llvm.ptr +} + +// ----- + +func.func private @callee() -> (!llvm.ptr {llvm.noalias}, !llvm.ptr {llvm.noalias}) attributes { + memory = #llvm.memory_effects +} + +// TODO: The two results are known not to alias, but we can't really +// differentiate them with current printing. +// CHECK: "func-return" and "func-return": NoAlias + +// CHECK: points-to-pointer sets +// CHECK-DAG: distinct[{{.+}}]<"alloca-1"> points to {distinct[{{.+}}]<"func-return0">} +// CHECK-DAG: distinct[{{.+}}]<"alloca-2"> points to {distinct[{{.+}}]<"func-return1">} +func.func @func_return_noalias() -> !llvm.ptr { + %0:2 = call @callee() {tag = "func-return"} : () -> (!llvm.ptr, !llvm.ptr) + %c1 = arith.constant 1 : i64 + %1 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-1"}: (i64) -> !llvm.ptr + llvm.store %0#0, %1 : !llvm.ptr, !llvm.ptr + %2 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-2"}: (i64) -> !llvm.ptr + llvm.store %0#1, %2 : !llvm.ptr, !llvm.ptr + return %0 : !llvm.ptr +} + +// ----- + + +func.func private @callee(!llvm.ptr, !llvm.ptr {llvm.nocapture}) -> (!llvm.ptr, !llvm.ptr) attributes { + memory = #llvm.memory_effects +} + +// CHECK: "func-return" and "func-return": MayAlias + +// Returned value can only point to the classes of captured pointers, i.e. arg0. 
+// However, returned value itself may alias with any argument, so pointers that +// stored the return value may point to any of the arg0, arg1 and the returned +// value itself. +// +// CHECK: points-to-pointer sets +// CHECK-DAG: distinct[{{.*}}]<"func-return-common"> points to {distinct[{{.*}}]<"arg0">} +// CHECK-DAG: distinct[{{.*}}]<"alloca-1"> points to { +// CHECK-DAG: distinct[{{.*}}]<"func-return-common"> +// CHECK-DAG: distinct[{{.*}}]<"arg0"> +// CHECK-DAG: distinct[{{.*}}]<"arg1"> +func.func @multi_operand_result(%arg0: !llvm.ptr {enzyme.tag = "arg0", llvm.noalias}, + %arg1: !llvm.ptr {enzyme.tag = "arg1", llvm.nocapture, llvm.noalias}) -> !llvm.ptr { + %0:2 = call @callee(%arg0, %arg1) {tag = "func-return"} : (!llvm.ptr, !llvm.ptr) -> (!llvm.ptr, !llvm.ptr) + %c1 = arith.constant 1 : i64 + %1 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-1"}: (i64) -> !llvm.ptr + llvm.store %0#0, %1 : !llvm.ptr, !llvm.ptr + %2 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-2"}: (i64) -> !llvm.ptr + llvm.store %0#1, %2 : !llvm.ptr, !llvm.ptr + return %0#0 : !llvm.ptr +} + +// ----- + + +func.func private @callee(!llvm.ptr, !llvm.ptr {llvm.nocapture}) -> (!llvm.ptr {llvm.noalias}, !llvm.ptr {llvm.noalias}) attributes { + memory = #llvm.memory_effects +} + +// Returned values can pointer to something that was captured, but belong to +// diferent classes and don't alias operand pointers. 
+// +// CHECK: points-to-pointer sets +// CHECK-DAG: distinct[{{.*}}]<"alloca-1"> points to {distinct[{{.*}}]<"func-return0">} +// CHECK-DAG: distinct[{{.*}}]<"func-return1"> points to {distinct[{{.*}}]<"arg0">} +// CHECK-DAG: distinct[{{.*}}]<"func-return0"> points to {distinct[{{.*}}]<"arg0">} +// CHECK-DAG: distinct[{{.*}}]<"alloca-2"> points to {distinct[{{.*}}]<"func-return1">} +func.func @multi_operand_result(%arg0: !llvm.ptr {enzyme.tag = "arg0", llvm.noalias}, + %arg1: !llvm.ptr {enzyme.tag = "arg1", llvm.nocapture, llvm.noalias}) -> !llvm.ptr { + %0:2 = call @callee(%arg0, %arg1) {tag = "func-return"} : (!llvm.ptr, !llvm.ptr) -> (!llvm.ptr, !llvm.ptr) + %c1 = arith.constant 1 : i64 + %1 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-1"}: (i64) -> !llvm.ptr + llvm.store %0#0, %1 : !llvm.ptr, !llvm.ptr + %2 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-2"}: (i64) -> !llvm.ptr + llvm.store %0#1, %2 : !llvm.ptr, !llvm.ptr + return %0#0 : !llvm.ptr +} + +// ----- + +func.func private @callee(!llvm.ptr, !llvm.ptr {llvm.nocapture}) + -> (!llvm.ptr, !llvm.ptr {llvm.noalias}) attributes { + memory = #llvm.memory_effects +} + +// Returned values can pointer to something that was captured, but belong to +// diferent classes and don't alias operand pointers. +// +// CHECK: points-to-pointer sets +// CHECK-DAG: distinct[{{.*}}]<"func-return-common"> points to {distinct[{{.*}}]<"arg0">} +// CHECK-DAG: distinct[{{.*}}]<"func-return1"> points to {distinct[{{.*}}]<"arg0">} +// CHECK-DAG: distinct[{{.*}}]<"alloca-1"> points to +// TODO: the current way of checking is fundamentally broken because of printing +// in hashmap order, we'd need a nested CHECK-DAG for this. 
+// CHECK-DAG: distinct[{{.*}}]<"alloca-2"> points to {distinct[{{.*}}]<"func-return1">} +// CHECK: #distinct +func.func @multi_operand_result(%arg0: !llvm.ptr {enzyme.tag = "arg0", llvm.noalias}, + %arg1: !llvm.ptr {enzyme.tag = "arg1", llvm.nocapture, llvm.noalias}) -> !llvm.ptr { + %0:2 = call @callee(%arg0, %arg1) {tag = "func-return"} : (!llvm.ptr, !llvm.ptr) -> (!llvm.ptr, !llvm.ptr) + %c1 = arith.constant 1 : i64 + %1 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-1"}: (i64) -> !llvm.ptr + llvm.store %0#0, %1 : !llvm.ptr, !llvm.ptr + %2 = llvm.alloca %c1 x !llvm.ptr {tag = "alloca-2"}: (i64) -> !llvm.ptr + llvm.store %0#1, %2 : !llvm.ptr, !llvm.ptr + return %0#0 : !llvm.ptr +} From 11cc0f13264671094bf41ff280ba2086f47d2f40 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 18 Jan 2024 20:15:22 -0500 Subject: [PATCH 008/131] Benchmark (#1617) * Benchmark wip * fix * prune prints --- enzyme/Enzyme/FunctionUtils.cpp | 104 +++++++++--- .../Integration/Sparse/eigen_analysis.cpp | 70 ++++----- enzyme/test/Integration/Sparse/matrix.h | 65 +++++++- enzyme/test/Integration/Sparse/ringspring.cpp | 106 +++++-------- .../Sparse/ringspring2Dextenddata.cpp | 124 +++++---------- .../Sparse/ringspring3Dextenddata.cpp | 142 ++++++++--------- .../ringspring3Dextenddatarestlengthone.cpp | 148 ++++++------------ .../Sparse/ringspring3Drestlengthone.cpp | 133 ++++++---------- enzyme/test/Integration/Sparse/sqrtspring.cpp | 104 +++++------- 9 files changed, 432 insertions(+), 564 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 0e14c20d9ef2..4c3370cf4cf9 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -7063,6 +7063,7 @@ Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP, return {}; } +constexpr bool SparseDebug = false; std::shared_ptr getSparseConditions(bool &legal, Value *val, std::shared_ptr defaultFloat, @@ -7077,11 +7078,13 @@ getSparseConditions(bool &legal, 
Value *val, auto res = lhs->andB(rhs, ctx); assert(res); assert(ctx.seen.size() == 0); - llvm::errs() << " getSparse(and, " << *I << "), lhs(" << *I->getOperand(0) - << ") = " << *lhs << "\n"; - llvm::errs() << " getSparse(and, " << *I << "), rhs(" << *I->getOperand(1) - << ") = " << *rhs << "\n"; - llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n"; + if (SparseDebug) { + llvm::errs() << " getSparse(and, " << *I << "), lhs(" + << *I->getOperand(0) << ") = " << *lhs << "\n"; + llvm::errs() << " getSparse(and, " << *I << "), rhs(" + << *I->getOperand(1) << ") = " << *rhs << "\n"; + llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n"; + } return res; } @@ -7092,11 +7095,13 @@ getSparseConditions(bool &legal, Value *val, auto rhs = getSparseConditions(legal, I->getOperand(1), Constraints::none(), I, ctx); auto res = lhs->orB(rhs, ctx); - llvm::errs() << " getSparse(or, " << *I << "), lhs(" << *I->getOperand(0) - << ") = " << *lhs << "\n"; - llvm::errs() << " getSparse(or, " << *I << "), rhs(" << *I->getOperand(1) - << ") = " << *rhs << "\n"; - llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n"; + if (SparseDebug) { + llvm::errs() << " getSparse(or, " << *I << "), lhs(" + << *I->getOperand(0) << ") = " << *lhs << "\n"; + llvm::errs() << " getSparse(or, " << *I << "), rhs(" + << *I->getOperand(1) << ") = " << *rhs << "\n"; + llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n"; + } return res; } @@ -7108,9 +7113,12 @@ getSparseConditions(bool &legal, Value *val, getSparseConditions(legal, I->getOperand(1 - i), defaultFloat->notB(ctx), scope, ctx); auto res = pres->notB(ctx); - llvm::errs() << " getSparse(not, " << *I << "), prev (" - << *I->getOperand(0) << ") = " << *pres << "\n"; - llvm::errs() << " getSparse(not, " << *I << ") = " << *res << "\n"; + if (SparseDebug) { + llvm::errs() << " getSparse(not, " << *I << "), prev (" + << *I->getOperand(0) << ") = " << *pres << "\n"; + llvm::errs() << " getSparse(not, 
" << *I << ") = " << *res + << "\n"; + } return res; } } @@ -7120,8 +7128,10 @@ getSparseConditions(bool &legal, Value *val, auto L = ctx.loopToSolve; auto lhs = ctx.SE.getSCEVAtScope(icmp->getOperand(0), L); auto rhs = ctx.SE.getSCEVAtScope(icmp->getOperand(1), L); - llvm::errs() << " lhs: " << *lhs << "\n"; - llvm::errs() << " rhs: " << *rhs << "\n"; + if (SparseDebug) { + llvm::errs() << " lhs: " << *lhs << "\n"; + llvm::errs() << " rhs: " << *rhs << "\n"; + } auto sub1 = ctx.SE.getMinusSCEV(lhs, rhs); @@ -7145,8 +7155,10 @@ getSparseConditions(bool &legal, Value *val, auto res = Constraints::make_compare( div, icmp->getPredicate() == ICmpInst::ICMP_EQ, add->getLoop(), ctx); - llvm::errs() - << " getSparse(icmp, " << *I << ") = " << *res << "\n"; + if (SparseDebug) { + llvm::errs() + << " getSparse(icmp, " << *I << ") = " << *res << "\n"; + } return res; } } @@ -7172,7 +7184,9 @@ getSparseConditions(bool &legal, Value *val, // cmp x, 1.0 -> false/true if (auto fcmp = dyn_cast(I)) { auto res = defaultFloat; - llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n"; + if (SparseDebug) { + llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n"; + } return res; if (fcmp->getPredicate() == CmpInst::FCMP_OEQ || @@ -7263,13 +7277,16 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, // Full simplification while (!Q.empty()) { auto cur = Q.pop_back_val(); + /* std::set prev; for (auto v : Q) prev.insert(v); // llvm::errs() << "\n\n\n\n" << F << "\n"; llvm::errs() << "cur: " << *cur << "\n"; + */ auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL); (void)changed; + /* if (changed) { llvm::errs() << "changed: " << *changed << "\n"; @@ -7278,6 +7295,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, llvm::errs() << " + " << *I << "\n"; // llvm::errs() << F << "\n\n"; } + */ } // llvm::errs() << " post fix inner " << F << "\n"; @@ -7872,6 +7890,7 @@ void replaceToDense(llvm::CallBase 
*CI, bool replaceAll, llvm::Function *F, args.push_back(diff); for (size_t i = argstart; i < num_args; i++) args.push_back(CI->getArgOperand(i)); + if (load_fn->getFunctionType()->getNumParams() != args.size()) { auto fnName = load_fn->getName(); auto found_numargs = load_fn->getFunctionType()->getNumParams(); @@ -7893,7 +7912,7 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F, *args[i]->getType(), " found ", load_fn->getFunctionType()->params()[i]); tocontinue = true; - break; + args[i] = UndefValue::get(args[i]->getType()); } } if (tocontinue) @@ -7902,8 +7921,18 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F, CallInst *call = B.CreateCall(load_fn, args); call->setDebugLoc(LI->getDebugLoc()); Value *tmp = call; - if (tmp->getType() != LI->getType()) - tmp = B.CreateBitCast(tmp, LI->getType()); + if (tmp->getType() != LI->getType()) { + if (CastInst::castIsValid(Instruction::BitCast, tmp, LI->getType())) + tmp = B.CreateBitCast(tmp, LI->getType()); + else { + auto fnName = load_fn->getName(); + EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, + " incorrect return type of loader function ", fnName, + " expected ", *LI->getType(), " found ", + *call->getType()); + tmp = UndefValue::get(LI->getType()); + } + } LI->replaceAllUsesWith(tmp); if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) { @@ -7927,15 +7956,44 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F, EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, " first argument of store function must be the type of " "the store found fn arg type ", - sty, " expected ", args0ty); + *sty, " expected ", *args0ty); + args[0] = UndefValue::get(sty); } } args.push_back(diff); for (size_t i = argstart; i < num_args; i++) args.push_back(CI->getArgOperand(i)); + + if (store_fn->getFunctionType()->getNumParams() != args.size()) { + auto fnName = store_fn->getName(); + auto found_numargs = store_fn->getFunctionType()->getNumParams(); 
+ auto expected_numargs = args.size(); + EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, + " incorrect number of arguments to store function ", fnName, + " expected ", expected_numargs, " found ", found_numargs, + " - ", *store_fn->getFunctionType()); + continue; + } else { + bool tocontinue = false; + for (size_t i = 0; i < args.size(); i++) { + if (store_fn->getFunctionType()->getParamType(i) != + args[i]->getType()) { + auto fnName = store_fn->getName(); + EmitFailure("IllegalSparse", CI->getDebugLoc(), CI, + " incorrect type of argument ", i, + " to storeer function ", fnName, " expected ", + *args[i]->getType(), " found ", + store_fn->getFunctionType()->params()[i]); + tocontinue = true; + args[i] = UndefValue::get(args[i]->getType()); + } + } + if (tocontinue) + continue; + } auto call = B.CreateCall(store_fn, args); call->setDebugLoc(SI->getDebugLoc()); - if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) { + if (store_fn->hasFnAttribute(Attribute::AlwaysInline)) { InlineFunctionInfo IFI; InlineFunction(*call, IFI); } diff --git a/enzyme/test/Integration/Sparse/eigen_analysis.cpp b/enzyme/test/Integration/Sparse/eigen_analysis.cpp index d8dc311f957a..434ebcf48ddd 100644 --- a/enzyme/test/Integration/Sparse/eigen_analysis.cpp +++ b/enzyme/test/Integration/Sparse/eigen_analysis.cpp @@ -150,41 +150,6 @@ static void gradient_ip(const T *__restrict__ pos0, const size_t num_faces, cons enzyme_dup, x, out); } - -template -__attribute__((always_inline)) -static T ident_load(unsigned long long offset, size_t i) { - return (offset / sizeof(T) == i) ? 
T(1) : T(0); -} - - -template -__attribute__((always_inline)) -static void err_store(T val, unsigned long long offset, size_t i) { - assert(0 && "store is not legal"); -} - - -template -__attribute__((always_inline)) -static T zero_load(unsigned long long offset, size_t i, std::vector> &hess) { - return T(0); -} - - -__attribute__((enzyme_sparse_accumulate)) -void inner_store(size_t offset, size_t i, float val, std::vector> &hess) { - hess.push_back(Triple(offset, i, val)); -} - -template -__attribute__((always_inline)) -static void csr_store(T val, unsigned long long offset, size_t i, std::vector> &hess) { - if (val == 0.0) return; - offset /= sizeof(T); - inner_store(offset, i, val, hess); -} - template __attribute__((noinline)) std::vector> hessian(const T*__restrict__ pos0, size_t num_faces, const int* faces, const T*__restrict__ x, size_t x_pts) @@ -217,13 +182,20 @@ std::vector> hessian(const T*__restrict__ pos0, size_t num_faces, cons enzyme_const, pos02, enzyme_const, num_faces, enzyme_const, faces, - enzyme_dup, x2, __enzyme_todense(ident_load, err_store, i), - enzyme_dupnoneed, nullptr, __enzyme_todense(zero_load, csr_store, i, &hess)); + enzyme_dup, x2, __enzyme_todense(ident_load, ident_store, i), + enzyme_dupnoneed, nullptr, __enzyme_todense(sparse_load, sparse_store, i, &hess)); return hess; } -int main() { - const size_t x_pts = 1; +int main(int argc, char** argv) { + size_t x_pts = 8; + + if (argc >= 2) { + x_pts = atoi(argv[1]); + } + + // TODO generate data for more inputs + assert(x_pts == 8); const float x[] = {0.0, 1.0, 0.0}; @@ -233,25 +205,37 @@ int main() { const float pos0[] = {1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 3.0, 1.0, 3.0}; // Call eigenstuffM_simple + struct timeval start, end; + gettimeofday(&start, NULL); const float resultM = eigenstuffM(pos0, num_faces, faces, x); - printf("Result for eigenstuffM_simple: %f\n", resultM); + gettimeofday(&end, NULL); + printf("Result for eigenstuffM_simple: %f, runtime:%f\n", resultM, tdiff(&start, 
&end)); // Call eigenstuffL_simple + gettimeofday(&start, NULL); const float resultL = eigenstuffL(pos0, num_faces, faces, x); - printf("Result for eigenstuffL_simple: %f\n", resultL); + gettimeofday(&end, NULL); + printf("Result for eigenstuffL_simple: %f, runtime:%f\n", resultL, tdiff(&start, &end)); float dx[sizeof(x)/sizeof(x[0])]; for (size_t i=0; i #include +#include +float tdiff(struct timeval *start, struct timeval *end) { + return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec); +} + template struct Triple { size_t row; @@ -10,6 +15,56 @@ struct Triple { Triple(size_t row, size_t col, T val) : row(row), col(col), val(val) {} }; +__attribute__((enzyme_sparse_accumulate)) +static void inner_storeflt(int64_t row, int64_t col, float val, std::vector> &triplets) { +#ifdef BENCHMARK + if (val == 0.0) return; +#else +#warning "Compiling for debug/verfication, performance may be slowed" +#endif + triplets.emplace_back(row, col, val); +} + +__attribute__((enzyme_sparse_accumulate)) +static void inner_storedbl(int64_t row, int64_t col, double val, std::vector> &triplets) { +#ifdef BENCHMARK + if (val == 0.0) return; +#else +#warning "Compiling for debug/verfication, performance may be slowed" +#endif + triplets.emplace_back(row, col, val); +} + +template +__attribute__((always_inline)) +static void sparse_store(T val, int64_t idx, size_t i, std::vector> &triplets) { + if (val == 0.0) return; + idx /= sizeof(T); + if constexpr (sizeof(T) == 4) + inner_storeflt(i, idx, val, triplets); + else + inner_storedbl(i, idx, val, triplets); +} + +template +__attribute__((always_inline)) +static T sparse_load(int64_t idx, size_t i, std::vector> &triplets) { + return 0.0; +} + +template +__attribute__((always_inline)) +static void ident_store(T, int64_t idx, size_t i) { + assert(0 && "should never load"); +} + +template +__attribute__((always_inline)) +static T ident_load(int64_t idx, size_t i) { + idx /= sizeof(T); + return (T)(idx == i);// ? 
1.0 : 0.0; +} + extern int enzyme_width; extern int enzyme_dup; extern int enzyme_dupv; @@ -17,16 +72,16 @@ extern int enzyme_const; extern int enzyme_dupnoneed; template -extern T __enzyme_autodiff(void*, Tys...); +extern T __enzyme_autodiff(void*, Tys...) noexcept; template -extern T __enzyme_fwddiff(void *, Tys...); +extern T __enzyme_fwddiff(void *, Tys...) noexcept; template -extern T __enzyme_todense(Tys...); +extern T __enzyme_todense(Tys...) noexcept; template -extern T __enzyme_post_sparse_todense(Tys...); +extern T __enzyme_post_sparse_todense(Tys...) noexcept; template __attribute__((always_inline)) @@ -200,4 +255,4 @@ static T area(const T *__restrict__ u, const T *__restrict__ v, const T *__restr T cross_product[3]; cross(cross_product, diff1, diff2); return 0.5 * norm(cross_product); -} \ No newline at end of file +} diff --git a/enzyme/test/Integration/Sparse/ringspring.cpp b/enzyme/test/Integration/Sparse/ringspring.cpp index 0ecae72bef5e..dd4242a1bcc5 100644 --- a/enzyme/test/Integration/Sparse/ringspring.cpp +++ b/enzyme/test/Integration/Sparse/ringspring.cpp @@ -17,123 +17,97 @@ #include -struct triple { - size_t row; - size_t col; - double val; - triple(triple&&) = default; - triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {} -}; - - -size_t N = 8; - -extern int enzyme_dup; -extern int enzyme_dupnoneed; -extern int enzyme_out; -extern int enzyme_const; - -extern void __enzyme_autodiff(void *, ...); - -extern void __enzyme_fwddiff(void *, ...); - -extern double* __enzyme_todense(void *, ...) 
noexcept; - +#include "matrix.h" +template __attribute__((always_inline)) -static double f(size_t N, double* input) { +static T f(size_t N, T* input) { double out = 0; // __builtin_assume(!((N-1) == 0)); for (size_t i=0; i __attribute__((always_inline)) -static void grad_f(size_t N, double* input, double* dinput) { - __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); -} - -__attribute__((always_inline)) -static void ident_store(double , int64_t idx, size_t i) { - assert(0 && "should never load"); +static void grad_f(size_t N, T* input, T* dinput) { + __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } +template __attribute__((always_inline)) -double ident_load(int64_t idx, size_t i, size_t N) { +double ringident_load(int64_t idx, size_t i, size_t N) { idx /= sizeof(double); // return (double)( ( (idx == N) ? 0 : idx) == i); return (double)((idx != N && idx == i) || (idx == N && 0 == i)); // return (double)( idx % N == i); } - -__attribute__((enzyme_sparse_accumulate)) -void inner_store(int64_t row, int64_t col, double val, std::vector &triplets) { - printf("row=%d col=%d val=%f\n", row, col % N, val); - // assert(abs(val) > 0.00001); - triplets.emplace_back(row % N, col % N, val); -} - -__attribute__((always_inline)) -void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { - if (val == 0.0) return; - idx /= sizeof(double); - inner_store(i, idx, val, triplets); -} - +template __attribute__((always_inline)) -double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { - return 0.0; -} - -__attribute__((always_inline)) -void never_store(double val, int64_t idx, double* input, size_t N) { +void never_store(T val, int64_t idx, T* input, size_t N) { assert(0 && "this is a read only input, why are you storing here..."); } +template __attribute__((always_inline)) double mod_load(int64_t idx, double* input, size_t N) { idx /= sizeof(double); return input[idx % N]; } +template 
__attribute__((noinline)) -std::vector hess_f(size_t N, double* input) { - std::vector triplets; - input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); +std::vector> hess_f(size_t N, T* input) { + std::vector> triplets; + input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); __builtin_assume(N != 1); for (size_t i=0; i((void*)ringident_load, (void*)never_store, i, N); + T* d_dinput = __enzyme_todense((void*)sparse_load, (void*)sparse_store, i, &triplets); - __enzyme_fwddiff((void*)grad_f, + __enzyme_fwddiff((void*)grad_f, enzyme_const, N, enzyme_dup, input, d_input, - enzyme_dupnoneed, (double*)0x1, d_dinput); + enzyme_dupnoneed, (T*)0x1, d_dinput); } return triplets; } -int main() { - // size_t N = 8; - double x[N]; - for (int i=0; i= 2) { + N = atoi(argv[1]); + } + + double *x = (double*)malloc(sizeof(double) * N); + for (int i=0; i -struct triple { - size_t row; - size_t col; - double val; - triple(triple&&) = default; - triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {} -}; - - -extern int enzyme_dup; -extern int enzyme_dupnoneed; -extern int enzyme_out; -extern int enzyme_const; - -extern void __enzyme_autodiff(void *, ...); - -extern void __enzyme_fwddiff(void *, ...); - -extern double* __enzyme_todense(void *, ...) 
noexcept; - +#include "matrix.h" +template __attribute__((always_inline)) static double f(size_t N, double* pos) { double e = 0.; @@ -52,94 +34,52 @@ static double f(size_t N, double* pos) { return e; } - +template __attribute__((always_inline)) -static void grad_f(size_t N, double* input, double* dinput) { - __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); +static void grad_f(size_t N, T* input, T* dinput) { + __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } +template __attribute__((always_inline)) -void ident_store(double , int64_t idx, size_t i) { - assert(0 && "should never load"); -} - -__attribute__((always_inline)) -double ident_load(size_t idx, size_t i, size_t N) { - idx /= sizeof(double); - return (double)(idx == i);// ? 1.0 : 0.0; -} - -__attribute__((enzyme_sparse_accumulate)) -void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { - printf("row=%d col=%d val=%f\n", row, col % N, val); - // assert(abs(val) > 0.00001); - triplets.emplace_back(row % N, col % N, val); -} - -__attribute__((always_inline)) -void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { - if (val == 0.0) return; - idx /= sizeof(double); - inner_store(i, idx, N, val, triplets); -} - -__attribute__((always_inline)) -double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { - return 0.0; -} - -__attribute__((always_inline)) -void never_store(double val, int64_t idx, double* input, size_t N) { +static void never_store(T val, int64_t idx, T* input, size_t N) { assert(0 && "this is a read only input, why are you storing here..."); } __attribute__((always_inline)) -double mod_load(int64_t idx, double* input, size_t N) { +static double mod_load(int64_t idx, double* input, size_t N) { idx /= sizeof(double); return input[idx % N]; } +template __attribute__((noinline)) -std::vector hess_f(size_t N, double* input) { - std::vector triplets; +std::vector> 
hess_f(size_t N, T* input) { + std::vector> triplets; // input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); for (size_t i=0; i((void*)ident_load, (void*)ident_store, i); + T* d_dinput = __enzyme_todense((void*)sparse_load, (void*)sparse_store, i, &triplets); - __enzyme_fwddiff((void*)grad_f, + __enzyme_fwddiff((void*)grad_f, enzyme_const, N, enzyme_dup, input, d_input, - enzyme_dupnoneed, (double*)0x1, d_dinput); + enzyme_dupnoneed, (T*)0x1, d_dinput); } return triplets; } -/* -__attribute__((noinline)) -std::vector hess_f2(size_t N, double* input) { - std::vector triplets; - input = - ((void*)mod_load, (void*)never_store, input, N); - hess_f(N, input); -} -*/ -// int argc, char** argv -int __attribute__((always_inline)) main() { - - - // if (argc != 2) { - // printf("Usage: %s \n", argv[0]); - // return 1; - // } +int main(int argc, char** argv) { + size_t N = 30; - // size_t N = atoi(argv[1]); - size_t N = 16; + if (argc >= 2) { + N = atoi(argv[1]); + } - double x[2 * N + 2]; + double *x = (double*)malloc(sizeof(double) * (2 * N + 2)); for (int i = 0; i < N; ++i) { double angle = 2 * M_PI * i / N; x[2 * i] = cos(angle) ;//+ normal(generator); @@ -147,13 +87,23 @@ int __attribute__((always_inline)) main() { } x[2 * N] = x[0]; x[2 * N + 1] = x[1]; - auto res = hess_f(N, &x[0]); - printf("%ld\n", res.size()); + + struct timeval start, end; + gettimeofday(&start, NULL); - for (auto & tup : res) - printf("%ld, %ld = %f\n", tup.row, tup.col, tup.val); + auto res = hess_f(N, x); + + gettimeofday(&end, NULL); + + printf("Number of elements %ld\n", res.size()); - return 0; -} + printf("Runtime %0.6f\n", tdiff(&start, &end)); + if (N <= 30) { + for (auto & tup : res) + printf("%ld, %ld = %f\n", tup.row, tup.col, tup.val); + } + + return 0; +} diff --git a/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp b/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp index 72408a73df27..6b59d27bca64 100644 --- 
a/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp +++ b/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp @@ -38,132 +38,114 @@ extern void __enzyme_fwddiff(void *, ...); extern double* __enzyme_todense(void *, ...) noexcept; +// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi + +#include +#include +#include +#include + +#include + +#include "matrix.h" + +template __attribute__((always_inline)) -static double f(size_t N, double* pos) { - double e = 0.; +static T f(size_t N, T* pos) { + T e = 0.; __builtin_assume(N != 0); for (size_t i = 0; i < N; i+=3) { - __builtin_assume(i < 
1000000000); - double vx = pos[i]; - double vy = pos[i + 1]; - double vz = pos[i + 2]; - - double wx = pos[i + 3]; - double wy = pos[i + 4]; - double wz = pos[i + 5]; + T vx = pos[i]; + T vy = pos[i + 1]; + T vz = pos[i + 2]; + + T wx = pos[i + 3]; + T wy = pos[i + 4]; + T wz = pos[i + 5]; e += (wx - vx) * (wx - vx) + (wy - vy) * (wy - vy) + (wz - vz) * (wz - vz); } return e; } - -__attribute__((always_inline)) -static void grad_f(size_t N, double* input, double* dinput) { - __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); -} - -__attribute__((always_inline)) -static void ident_store(double , int64_t idx, size_t i) { - assert(0 && "should never load"); -} - -__attribute__((always_inline)) -static double ident_load(int64_t idx, size_t i, size_t N) { - idx /= sizeof(double); - return (double)(idx == i);// ? 1.0 : 0.0; -} - -__attribute__((enzyme_sparse_accumulate)) -static void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { - printf("row=%d col=%d val=%f\n", row, col % N, val); - // assert(abs(val) > 0.00001); - triplets.emplace_back(row % N, col % N, val); -} - +template __attribute__((always_inline)) -static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { - if (val == 0.0) return; - idx /= sizeof(double); - inner_store(i, idx, N, val, triplets); +static void grad_f(size_t N, T* input, T* dinput) { + __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } +template __attribute__((always_inline)) -double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { - return 0.0; -} - -__attribute__((always_inline)) -void never_store(double val, int64_t idx, double* input, size_t N) { +static void never_store(T val, int64_t idx, T* input, size_t N) { assert(0 && "this is a read only input, why are you storing here..."); } __attribute__((always_inline)) -double mod_load(int64_t idx, double* input, size_t N) { +static double mod_load(int64_t 
idx, double* input, size_t N) { idx /= sizeof(double); return input[idx % N]; } +template __attribute__((noinline)) -std::vector hess_f(size_t N, double* input) { - std::vector triplets; +std::vector> hess_f(size_t N, T* input) { + std::vector> triplets; // input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); - __builtin_assume(N < 10000000000); for (size_t i=0; i((void*)ident_load, (void*)ident_store, i); + T* d_dinput = __enzyme_todense((void*)sparse_load, (void*)sparse_store, i, &triplets); - __enzyme_fwddiff((void*)grad_f, + __enzyme_fwddiff((void*)grad_f, enzyme_const, N, enzyme_dup, input, d_input, - enzyme_dupnoneed, (double*)0x1, d_dinput); + enzyme_dupnoneed, (T*)0x1, d_dinput); } return triplets; } -/* -__attribute__((noinline)) -std::vector hess_f2(size_t N, double* input) { - std::vector triplets; - input = - ((void*)mod_load, (void*)never_store, input, N); - hess_f(N, input); -} -*/ -// int argc, char** argv -int __attribute__((always_inline)) main() { - std::mt19937 generator(0); // Seed the random number generator - std::uniform_real_distribution normal(0, 0.05); - - - // if (argc != 2) { - // printf("Usage: %s \n", argv[0]); - // return 1; - // } - - // size_t N = atoi(argv[1]); +int main(int argc, char** argv) { size_t N = 30; - double x[3 * N + 3]; + if (argc >= 2) { + N = atoi(argv[1]); + } + + double *x = (double*)malloc(sizeof(double) * (3 * N + 3)); for (int i = 0; i < N; ++i) { double angle = 2 * M_PI * i / N; - x[3 * i] = cos(angle) + normal(generator); - x[3 * i + 1] = sin(angle) + normal(generator); - x[3 * i + 2] = normal(generator); + x[3 * i] = cos(angle) ;//+ normal(generator); + x[3 * i + 1] = sin(angle) ;//+ normal(generator); + x[3 * i + 2] = 0;//normal(generator); } x[3 * N] = x[0]; x[3 * N + 1] = x[1]; x[3 * N + 2] = x[2]; - auto res = hess_f(N, &x[0]); - - printf("%ld\n", res.size()); + struct timeval start, end; + gettimeofday(&start, NULL); + + auto res = hess_f(N, x); + + 
gettimeofday(&end, NULL); + + printf("Number of elements %ld\n", res.size()); + + printf("Runtime %0.6f\n", tdiff(&start, &end)); + if (N <= 30) { for (auto & tup : res) printf("%ld, %ld = %f\n", tup.row, tup.col, tup.val); + } return 0; } diff --git a/enzyme/test/Integration/Sparse/ringspring3Dextenddatarestlengthone.cpp b/enzyme/test/Integration/Sparse/ringspring3Dextenddatarestlengthone.cpp index b5bb2f259135..54f9b0fbd8c1 100644 --- a/enzyme/test/Integration/Sparse/ringspring3Dextenddatarestlengthone.cpp +++ b/enzyme/test/Integration/Sparse/ringspring3Dextenddatarestlengthone.cpp @@ -1,7 +1,5 @@ // This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... -// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load -// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... 
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi @@ -9,94 +7,45 @@ // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// everything should be always inline - #include #include #include #include - #include -struct triple { - size_t row; - size_t col; - double val; - triple(triple&&) = default; - triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {} -}; - - -extern int enzyme_dup; -extern int enzyme_dupnoneed; -extern int enzyme_out; -extern int enzyme_const; - -extern void __enzyme_autodiff(void *, ...); - -extern void __enzyme_fwddiff(void *, ...); - -extern double* __enzyme_todense(void *, ...) 
noexcept; - +#include "matrix.h" +template __attribute__((always_inline)) -static double f(size_t N, double* pos) { +static double f(size_t N, T* __restrict__ pos) { double e = 0.; - for (size_t i = 0; i < N; i += 3) { - double vx = pos[i]; - double vy = pos[i + 1]; - double vz = pos[i + 2]; - - double wx = pos[i + 3]; - double wy = pos[i + 4]; - double wz = pos[i + 5]; - double distance = (wx - vx) * (wx - vx) + (wy - vy) * (wy - vy) + (wz - vz) * (wz - vz); - double rest_len_one_dist = (sqrt(distance) - 1) * (sqrt(distance) - 1); + __builtin_assume(N != 0); + for (size_t j = 0; j < N/3; j ++) { + size_t i = 3 * j; + T vx = pos[i]; + T vy = pos[i + 1]; + T vz = pos[i + 2]; + + T wx = pos[i + 3]; + T wy = pos[i + 4]; + T wz = pos[i + 5]; + T distance = (wx - vx) * (wx - vx) + (wy - vy) * (wy - vy) + (wz - vz) * (wz - vz); + T rest_len_one_dist = (sqrt(distance) - 1) * (sqrt(distance) - 1); e += rest_len_one_dist; } return e; } - -__attribute__((always_inline)) -static void grad_f(size_t N, double* input, double* dinput) { - __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); -} - - +template __attribute__((always_inline)) -static void ident_store(double , int64_t idx, size_t i) { - assert(0 && "should never load"); +static void grad_f(size_t N, T* input, T* dinput) { + __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } +template __attribute__((always_inline)) -static double ident_load(int64_t idx, size_t i, size_t N) { - idx /= sizeof(double); - return (double)(idx == i);// ? 
1.0 : 0.0; -} - -__attribute__((enzyme_sparse_accumulate)) -static void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { - printf("row=%d col=%d val=%f\n", row, col % N, val); - // assert(abs(val) > 0.00001); - triplets.emplace_back(row % N, col % N, val); -} - -__attribute__((always_inline)) -static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { - if (val == 0.0) return; - idx /= sizeof(double); - inner_store(i, idx, N, val, triplets); -} - -__attribute__((always_inline)) -static double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { - return 0.0; -} - -__attribute__((always_inline)) -static void never_store(double val, int64_t idx, double* input, size_t N) { +static void never_store(T val, int64_t idx, T* input, size_t N) { assert(0 && "this is a read only input, why are you storing here..."); } @@ -106,50 +55,34 @@ static double mod_load(int64_t idx, double* input, size_t N) { return input[idx % N]; } +template __attribute__((noinline)) -std::vector hess_f(size_t N, double* input) { - std::vector triplets; +std::vector> hess_f(size_t N, T* input) { + std::vector> triplets; // input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); for (size_t i=0; i((void*)ident_load, (void*)ident_store, i); + T* d_dinput = __enzyme_todense((void*)sparse_load, (void*)sparse_store, i, &triplets); - __enzyme_fwddiff((void*)grad_f, + __enzyme_fwddiff((void*)grad_f, enzyme_const, N, enzyme_dup, input, d_input, - enzyme_dupnoneed, (double*)0x1, d_dinput); + enzyme_dupnoneed, (T*)0x1, d_dinput); } return triplets; } -/* -__attribute__((noinline)) -std::vector hess_f2(size_t N, double* input) { - std::vector triplets; - input = - ((void*)mod_load, (void*)never_store, input, N); - hess_f(N, input); -} -*/ - -// int argc, char** argv -int __attribute__((always_inline)) main() { - //std::mt19937 generator(0); // Seed the random number generator - 
//std::uniform_real_distribution normal(0, 0.05); - - - // if (argc != 2) { - // printf("Usage: %s \n", argv[0]); - // return 1; - // } - - // size_t N = atoi(argv[1]); +int main(int argc, char** argv) { size_t N = 30; - double x[3 * N + 3]; + if (argc >= 2) { + N = atoi(argv[1]); + } + + double *x = (double*)malloc(sizeof(double) * (3 * N + 3)); for (int i = 0; i < N; ++i) { double angle = 2 * M_PI * i / N; x[3 * i] = cos(angle) ;//+ normal(generator); @@ -159,14 +92,23 @@ int __attribute__((always_inline)) main() { x[3 * N] = x[0]; x[3 * N + 1] = x[1]; x[3 * N + 2] = x[2]; - auto res = hess_f(N, &x[0]); - - printf("%ld\n", res.size()); + struct timeval start, end; + gettimeofday(&start, NULL); + + auto res = hess_f(N, x); + + gettimeofday(&end, NULL); + + printf("Number of elements %ld\n", res.size()); + + printf("Runtime %0.6f\n", tdiff(&start, &end)); + if (N <= 30) { for (auto & tup : res) printf("%ld, %ld = %f\n", tup.row, tup.col, tup.val); + } return 0; } diff --git a/enzyme/test/Integration/Sparse/ringspring3Drestlengthone.cpp b/enzyme/test/Integration/Sparse/ringspring3Drestlengthone.cpp index 49896b2cbc62..cae8bdad5708 100644 --- a/enzyme/test/Integration/Sparse/ringspring3Drestlengthone.cpp +++ b/enzyme/test/Integration/Sparse/ringspring3Drestlengthone.cpp @@ -7,94 +7,45 @@ // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi -// everything should be always inline - #include #include #include #include - #include -struct triple { - size_t row; - size_t col; - double val; - triple(triple&&) = default; - triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {} -}; - - -extern int 
enzyme_dup; -extern int enzyme_dupnoneed; -extern int enzyme_out; -extern int enzyme_const; - -extern void __enzyme_autodiff(void *, ...); - -extern void __enzyme_fwddiff(void *, ...); - -extern double* __enzyme_todense(void *, ...) noexcept; - +#include "matrix.h" +template __attribute__((always_inline)) -static double f(size_t N, double* pos) { +static double f(size_t N, T* __restrict__ pos) { double e = 0.; - for (size_t i = 0; i < N; i += 3) { - double vx = pos[i]; - double vy = pos[i + 1]; - double vz = pos[i + 2]; + __builtin_assume(N != 0); + for (size_t j = 0; j < N/3; j ++) { + size_t i = 3 * j; + T vx = pos[i]; + T vy = pos[i + 1]; + T vz = pos[i + 2]; - double wx = pos[i + 3]; - double wy = pos[i + 4]; - double wz = pos[i + 5]; - double distance = (wx - vx) * (wx - vx) + (wy - vy) * (wy - vy) + (wz - vz) * (wz - vz); - double rest_len_one_dist = (sqrt(distance) - 1) * (sqrt(distance) - 1); + T wx = pos[i + 3]; + T wy = pos[i + 4]; + T wz = pos[i + 5]; + T distance = (wx - vx) * (wx - vx) + (wy - vy) * (wy - vy) + (wz - vz) * (wz - vz); + T rest_len_one_dist = (sqrt(distance) - 1) * (sqrt(distance) - 1); e += rest_len_one_dist; } return e; } - -__attribute__((always_inline)) -static void grad_f(size_t N, double* input, double* dinput) { - __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); -} - - +template __attribute__((always_inline)) -static void ident_store(double , int64_t idx, size_t i) { - assert(0 && "should never load"); +static void grad_f(size_t N, T* input, T* dinput) { + __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } +template __attribute__((always_inline)) -static double ident_load(int64_t idx, size_t i, size_t N) { - idx /= sizeof(double); - return (double)(idx == i);// ? 
1.0 : 0.0; -} - -__attribute__((enzyme_sparse_accumulate)) -static void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { - printf("row=%d col=%d val=%f\n", row, col % N, val); - // assert(abs(val) > 0.00001); - triplets.emplace_back(row % N, col % N, val); -} - -__attribute__((always_inline)) -static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { - if (val == 0.0) return; - idx /= sizeof(double); - inner_store(i, idx, N, val, triplets); -} - -__attribute__((always_inline)) -static double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { - return 0.0; -} - -__attribute__((always_inline)) -static void never_store(double val, int64_t idx, double* input, size_t N) { +static void never_store(T val, int64_t idx, T* input, size_t N) { assert(0 && "this is a read only input, why are you storing here..."); } @@ -104,20 +55,21 @@ static double mod_load(int64_t idx, double* input, size_t N) { return input[idx % N]; } +template __attribute__((noinline)) -std::vector hess_f(size_t N, double* input) { - std::vector triplets; +std::vector> hess_f(size_t N, T* input) { + std::vector> triplets; // input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); for (size_t i=0; i((void*)ident_load, (void*)ident_store, i); + T* d_dinput = __enzyme_todense((void*)sparse_load, (void*)sparse_store, i, &triplets); - __enzyme_fwddiff((void*)grad_f, + __enzyme_fwddiff((void*)grad_f, enzyme_const, N, enzyme_dup, input, d_input, - enzyme_dupnoneed, (double*)0x1, d_dinput); + enzyme_dupnoneed, (T*)0x1, d_dinput); } return triplets; @@ -134,34 +86,37 @@ std::vector hess_f2(size_t N, double* input) { */ // int argc, char** argv -int __attribute__((always_inline)) main() { - std::mt19937 generator(0); // Seed the random number generator - std::uniform_real_distribution normal(0, 0.05); - - - // if (argc != 2) { - // printf("Usage: %s \n", argv[0]); - // return 1; - 
// } - - // size_t N = atoi(argv[1]); +int main(int argc, char** argv) { size_t N = 30; - double x[3 * N]; + if (argc >= 2) { + N = atoi(argv[1]); + } + + double *x = (double*)malloc(sizeof(double) * 3 * N); for (int i = 0; i < N; ++i) { double angle = 2 * M_PI * i / N; x[3 * i] = cos(angle) ;//+ normal(generator); x[3 * i + 1] = sin(angle) ;//+ normal(generator); x[3 * i + 2] = 0;//normal(generator); } - auto res = hess_f(N, &x[0]); - - printf("%ld\n", res.size()); + struct timeval start, end; + gettimeofday(&start, NULL); + + auto res = hess_f(N, x); + + gettimeofday(&end, NULL); + + printf("Number of elements %ld\n", res.size()); + + printf("Runtime %0.6f\n", tdiff(&start, &end)); + if (N <= 30) { for (auto & tup : res) printf("%ld, %ld = %f\n", tup.row, tup.col, tup.val); + } return 0; } diff --git a/enzyme/test/Integration/Sparse/sqrtspring.cpp b/enzyme/test/Integration/Sparse/sqrtspring.cpp index a9750409b37d..9645593e8fe9 100644 --- a/enzyme/test/Integration/Sparse/sqrtspring.cpp +++ b/enzyme/test/Integration/Sparse/sqrtspring.cpp @@ -15,105 +15,73 @@ #include -struct triple { - size_t row; - size_t col; - double val; - triple(triple&&) = default; - triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {} -}; - -extern int enzyme_dup; -extern int enzyme_dupnoneed; -extern int enzyme_out; -extern int enzyme_const; - -extern void __enzyme_autodiff(void *, ...); - -extern void __enzyme_fwddiff(void *, ...); - -extern double* __enzyme_todense(void *, ...) 
noexcept; - +#include "matrix.h" +template __attribute__((always_inline)) -static double f(size_t N, double* input) { - double out = 0; +static T f(size_t N, T* input) { + T out = 0; __builtin_assume(!((N-1) == 0)); for (size_t i=0; i __attribute__((always_inline)) -static void grad_f(size_t N, double* input, double* dinput) { - __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); -} - - -__attribute__((always_inline)) -static void ident_store(double , int64_t idx, size_t i) { - assert(0 && "should never load"); -} - -__attribute__((always_inline)) -static double ident_load(int64_t idx, size_t i, size_t N) { - idx /= sizeof(double); - return (double)(idx == i);// ? 1.0 : 0.0; -} - -__attribute__((enzyme_sparse_accumulate)) -static void inner_store(int64_t row, int64_t col, double val, std::vector &triplets) { - printf("row=%d col=%d val=%f\n", row, col, val); - assert(abs(val) > 0.00001); - triplets.emplace_back(row, col, val); -} - -__attribute__((always_inline)) -static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { - if (val == 0.0) return; - idx /= sizeof(double); - inner_store(i, idx, val, triplets); -} - -__attribute__((always_inline)) -static double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { - return 0.0; +static void grad_f(size_t N, T* input, T* dinput) { + __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput); } +template __attribute__((noinline)) -std::vector hess_f(size_t N, double* input) { - std::vector triplets; +std::vector> hess_f(size_t N, T* input) { + std::vector> triplets; __builtin_assume(N > 0); for (size_t i=0; i((void*)ident_load, (void*)ident_store, i); + T* d_dinput = __enzyme_todense((void*)sparse_load, (void*)sparse_store, i, &triplets); - __enzyme_fwddiff((void*)grad_f, + __enzyme_fwddiff((void*)grad_f, enzyme_const, N, enzyme_dup, input, d_input, - enzyme_dupnoneed, (double*)0x1, d_dinput); - + enzyme_dupnoneed, (T*)0x1, 
d_dinput); } return triplets; } -int main() { - size_t N = 8; - double x[N]; - for (int i=0; i= 2) { + N = atoi(argv[1]); + } + + double *x = (double*)malloc(sizeof(double) * N); + for (int i=0; i Date: Sat, 20 Jan 2024 13:19:35 -0500 Subject: [PATCH 009/131] Update benchmark.yml to use openstack22 (#1543) * Update benchmark.yml to use openstack22 * Update benchmark.yml * Update benchmark.yml * Update benchmark.yml --- .github/workflows/benchmark.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 68b610ffaa5d..2dadff967d32 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -21,17 +21,17 @@ jobs: strategy: fail-fast: false matrix: - llvm: ["11", "12", "13", "14", "15", "16"] + llvm: ["13", "14", "15", "16"] build: ["Release", "Debug"] # "RelWithDebInfo" - os: [openstack18] + os: [openstack22] timeout-minutes: 120 steps: - name: add llvm run: | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true - sudo apt-get install -y python-pip autoconf cmake gcc g++ libtool gfortran libblas-dev llvm-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libeigen3-dev libboost-dev - sudo pip install lit pathlib + sudo apt-get install -y python3-pip autoconf cmake gcc g++ libtool gfortran libblas-dev llvm-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libeigen3-dev libboost-dev + sudo python3 -m pip install lit pathlib sudo touch /usr/lib/llvm-${{ matrix.llvm }}/bin/yaml-bench if [[ '${{ matrix.llvm }}' == '13' ]]; then sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake From 0d8c77cad4f0888d4dea4e9b32143c7332f65d5e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 22 Jan 2024 11:48:54 -0500 Subject: 
[PATCH 010/131] Try github actions for nightly bazel build (#1620) * Try github actions for nightly bazel build * don't install ninja --- .github/workflows/enzyme-bazel.yml | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/.github/workflows/enzyme-bazel.yml b/.github/workflows/enzyme-bazel.yml index 57ec22251ec2..273bb9b08082 100644 --- a/.github/workflows/enzyme-bazel.yml +++ b/.github/workflows/enzyme-bazel.yml @@ -21,31 +21,18 @@ jobs: matrix: build: ["Release"] llbuild: ["Release"] - os: [openstack22] + os: [ubuntu-latest] timeout-minutes: 500 steps: - - name: add llvm - run: | - sudo rm -f /etc/apt/sources.list.d/*llvm*.list - sudo apt-get update - sudo apt-get install -y ninja-build git autoconf cmake gcc g++ libtool python3 python3-dev - - uses: actions/checkout@v3 - - uses: actions/checkout@v3 + + - uses: actions/checkout@v4 + - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' path: 'llvm-project' - submodules: true - - - name: Install bazelisk - run: | - curl -LO "https://github.com/bazelbuild/bazelisk/releases/download/v1.1.0/bazelisk-linux-amd64" - mkdir -p "${GITHUB_WORKSPACE}/bin/" - mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel" - chmod +x "${GITHUB_WORKSPACE}/bin/bazel" - - name: cmake run: | cd enzyme - "${GITHUB_WORKSPACE}/bin/bazel" build :EnzymeStatic + bazel build :EnzymeStatic From 304e21bc0e4ccf3efa3a62aefe94df2eaaeba232 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 23 Jan 2024 23:12:06 -0500 Subject: [PATCH 011/131] MLIR bazel build and test (#1618) * MLIR bazel build and test * Fix MLIR memory bug * fix * print errs * fix --- .github/workflows/enzyme-bazel.yml | 12 +- .github/workflows/enzyme-mlir.yml | 2 +- enzyme/BUILD | 257 ++++++++++++++++++ .../LinalgAutoDiffOpInterfaceImpl.cpp | 4 +- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 28 +- enzyme/WORKSPACE | 18 +- .../test/Integration/BatchMode/char-ptr.cpp | 2 +- enzyme/test/Integration/BatchMode/sqaure.cpp | 2 +- 
.../test/Integration/BatchMode/test_utils.h | 32 --- enzyme/test/Integration/ForwardMode/binops.c | 2 +- .../test/Integration/ForwardMode/customfwd.c | 2 +- enzyme/test/Integration/ForwardMode/eigen.cpp | 2 +- .../Integration/ForwardMode/fwdandrev.cpp | 2 +- enzyme/test/Integration/ForwardMode/loops.c | 2 +- .../Integration/ForwardMode/loopsdouble.c | 2 +- .../Integration/ForwardMode/loopstriple.c | 2 +- .../Integration/ForwardMode/mpi/mpi_bcast.c | 2 +- .../Integration/ForwardMode/mpi/mpi_reduce.c | 2 +- enzyme/test/Integration/ForwardMode/nofn.cpp | 2 +- .../Integration/ForwardMode/rosenbrock.cpp | 2 +- enzyme/test/Integration/ForwardMode/rwrloop.c | 2 +- enzyme/test/Integration/ForwardMode/sumtil.c | 2 +- enzyme/test/Integration/ForwardMode/sumtil2.c | 2 +- .../ForwardModeVector/test_utils.h | 32 --- .../Integration/ReverseMode/allocatedtape.c | 2 +- .../ReverseMode/allocatedtape_err.c | 2 +- .../test/Integration/ReverseMode/blas_gemm.c | 2 +- .../test/Integration/ReverseMode/blas_gemm2.c | 2 +- .../test/Integration/ReverseMode/boundissue.c | 2 +- .../test/Integration/ReverseMode/cachefwd.c | 2 +- enzyme/test/Integration/ReverseMode/cerr.cpp | 2 +- enzyme/test/Integration/ReverseMode/cin.cpp | 2 +- enzyme/test/Integration/ReverseMode/cmplx.cpp | 2 +- enzyme/test/Integration/ReverseMode/cout.cpp | 2 +- .../Integration/ReverseMode/customalloc.c | 2 +- .../Integration/ReverseMode/customcombined.c | 2 +- .../Integration/ReverseMode/customglob.cpp | 2 +- .../Integration/ReverseMode/customlog1p.c | 2 +- enzyme/test/Integration/ReverseMode/dbginfo.c | 2 +- .../ReverseMode/differential_pointer_return.c | 2 +- .../Integration/ReverseMode/eigensumsq.cpp | 2 +- .../ReverseMode/eigensumsqdyn-notmp.cpp | 2 +- .../Integration/ReverseMode/eigensumsqdyn.cpp | 2 +- .../Integration/ReverseMode/eigentensor.cpp | 2 +- .../ReverseMode/eigentensorfull.cpp | 2 +- .../ReverseMode/eigentensorreal.cpp | 2 +- enzyme/test/Integration/ReverseMode/fbuff.cpp | 2 +- 
.../test/Integration/ReverseMode/forrealloc.c | 2 +- enzyme/test/Integration/ReverseMode/frexp.c | 2 +- .../test/Integration/ReverseMode/fwdsolve.c | 2 +- .../ReverseMode/gradient-struct-return.c | 2 +- .../Integration/ReverseMode/headerremat.c | 2 +- .../test/Integration/ReverseMode/inactivefn.c | 2 +- .../Integration/ReverseMode/insertsort_sum.c | 2 +- .../ReverseMode/insertsort_sum_alt.c | 2 +- .../ReverseMode/insertsort_sum_min.c | 2 +- .../ReverseMode/integrateconst.cpp | 2 +- .../Integration/ReverseMode/integrateexp.cpp | 2 +- enzyme/test/Integration/ReverseMode/invsqrt.c | 2 +- enzyme/test/Integration/ReverseMode/loops.c | 2 +- .../Integration/ReverseMode/loopsdouble.c | 2 +- .../Integration/ReverseMode/loopsnested.c | 2 +- .../Integration/ReverseMode/loopstriple.c | 2 +- enzyme/test/Integration/ReverseMode/manydiv.c | 2 +- enzyme/test/Integration/ReverseMode/manymax.c | 2 +- enzyme/test/Integration/ReverseMode/map.cpp | 2 +- .../test/Integration/ReverseMode/metamalloc.c | 2 +- enzyme/test/Integration/ReverseMode/metarwr.c | 2 +- .../ReverseMode/mixedstruct1-old.c | 2 +- .../ReverseMode/mixedstruct1-simple.c | 2 +- .../ReverseMode/mixedstruct1-simplefda.c | 2 +- .../ReverseMode/mixedstruct1-simpleps.c | 2 +- .../ReverseMode/mixedstruct1-simpler.c | 2 +- .../ReverseMode/mixedstruct1-simplest.c | 2 +- .../Integration/ReverseMode/mixedstruct1-sp.c | 2 +- .../Integration/ReverseMode/mixedstruct1.c | 2 +- .../test/Integration/ReverseMode/mpi_bcast.c | 2 +- .../test/Integration/ReverseMode/mpi_reduce.c | 2 +- .../Integration/ReverseMode/multivecmax.cpp | 2 +- .../Integration/ReverseMode/multivecmaxC.c | 2 +- enzyme/test/Integration/ReverseMode/mycos.c | 4 +- enzyme/test/Integration/ReverseMode/omp.c | 2 +- enzyme/test/Integration/ReverseMode/omp2.c | 2 +- enzyme/test/Integration/ReverseMode/omp3.c | 2 +- enzyme/test/Integration/ReverseMode/omp6.c | 2 +- .../test/Integration/ReverseMode/omp_crit.c | 2 +- .../ReverseMode/omp_firstprivate.c | 2 +- 
enzyme/test/Integration/ReverseMode/omp_two.c | 2 +- .../test/Integration/ReverseMode/ompbound.c | 2 +- .../Integration/ReverseMode/perturbation.cpp | 2 +- .../Integration/ReverseMode/posix_memalign.c | 2 +- .../ReverseMode/posix_memalignfor.c | 2 +- .../Integration/ReverseMode/readwriteread.c | 2 +- enzyme/test/Integration/ReverseMode/recurse.c | 2 +- enzyme/test/Integration/ReverseMode/remat.c | 2 +- .../test/Integration/ReverseMode/remat2.cpp | 2 +- .../Integration/ReverseMode/rematSimple.c | 2 +- enzyme/test/Integration/ReverseMode/rwrloop.c | 2 +- enzyme/test/Integration/ReverseMode/rwrmeta.c | 2 +- .../ReverseMode/simpleeigen-made-part.cpp | 2 +- .../ReverseMode/simpleeigen-made.cpp | 2 +- .../Integration/ReverseMode/simpleeigen.cpp | 2 +- .../simpleeigenstatic-made-odd.cpp | 2 +- .../ReverseMode/simpleeigenstatic-made.cpp | 2 +- .../ReverseMode/simpleeigenstatic-sum.cpp | 2 +- .../ReverseMode/simpleeigenstatic-sumsq.cpp | 2 +- .../ReverseMode/simpleeigenstatic-vec.cpp | 2 +- .../ReverseMode/simpleeigenstatic.cpp | 2 +- .../Integration/ReverseMode/smallrealloc.c | 2 +- enzyme/test/Integration/ReverseMode/sret.cpp | 2 +- .../Integration/ReverseMode/subdoublestore.c | 2 +- enzyme/test/Integration/ReverseMode/sumtil.c | 2 +- enzyme/test/Integration/ReverseMode/sumtil2.c | 2 +- .../test/Integration/ReverseMode/taylorlog.c | 2 +- .../test/Integration/ReverseMode/test_utils.h | 39 --- enzyme/test/Integration/ReverseMode/vecmax.c | 2 +- .../test/Integration/ReverseMode/vecmax.cpp | 2 +- .../Integration/ReverseMode/virtualshadow.cpp | 2 +- .../ReverseMode/virtualshadow2.cpp | 2 +- .../ReverseMode/virtualshadow3.cpp | 2 +- enzyme/test/Integration/ReverseMode/wcout.cpp | 2 +- enzyme/test/Integration/Sparse/test_utils.h | 38 --- .../{ForwardMode => }/test_utils.h | 0 enzyme/test/lit.cfg.py | 2 + enzyme/test/lit.site.cfg.py.in | 36 ++- 125 files changed, 451 insertions(+), 275 deletions(-) delete mode 100644 enzyme/test/Integration/BatchMode/test_utils.h delete mode 
100644 enzyme/test/Integration/ForwardModeVector/test_utils.h delete mode 100644 enzyme/test/Integration/ReverseMode/test_utils.h delete mode 100644 enzyme/test/Integration/Sparse/test_utils.h rename enzyme/test/Integration/{ForwardMode => }/test_utils.h (100%) diff --git a/.github/workflows/enzyme-bazel.yml b/.github/workflows/enzyme-bazel.yml index 273bb9b08082..38e27e4edcc1 100644 --- a/.github/workflows/enzyme-bazel.yml +++ b/.github/workflows/enzyme-bazel.yml @@ -26,13 +26,21 @@ jobs: timeout-minutes: 500 steps: + - name: Prep + run: | + python -m pip install lit + - uses: actions/checkout@v4 - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' path: 'llvm-project' - - name: cmake + - name: Build + run: | + cd enzyme + bazel build :EnzymeStatic :enzymemlir-opt + - name: Test run: | cd enzyme - bazel build :EnzymeStatic + bazel test --test_output=errors ... diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index ecf909bb5b32..b4c581a3a720 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/checkout@v3 with: repository: 'llvm/llvm-project' - ref: '5ed11e767c0c39a3bc8e035588e7a383849d46a8' + ref: 'bc82cfb38d83f1afeb2c290aa472c2e2e88919cb' path: 'llvm-project' - name: Get MLIR commit hash diff --git a/enzyme/BUILD b/enzyme/BUILD index 3c97327ad1b5..42903efaef60 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -1,4 +1,7 @@ load("@llvm-project//llvm:tblgen.bzl", "gentbl") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") licenses(["notice"]) @@ -191,6 +194,7 @@ cc_library( "@llvm-project//llvm:TransformUtils", "@llvm-project//llvm:config", ], + alwayslink = 1 ) cc_binary( @@ -213,3 +217,256 @@ genrule( cmd = "cp $< $@", output_to_bindir = 1, ) + +td_library( + name = 
"EnzymeDialectTdFiles", + srcs = [ + "Enzyme/MLIR/Dialect/Dialect.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + ] +) + +gentbl_cc_library( + name = "EnzymeOpsIncGen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "Enzyme/MLIR/Dialect/EnzymeOps.h.inc", + ), + ( + ["-gen-op-defs"], + "Enzyme/MLIR/Dialect/EnzymeOps.cpp.inc", + ), + ( + [ + "-gen-dialect-decls", + "-dialect=enzyme", + ], + "Enzyme/MLIR/Dialect/EnzymeOpsDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect=enzyme", + ], + "Enzyme/MLIR/Dialect/EnzymeOpsDialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/EnzymeOps.td", + deps = [":EnzymeDialectTdFiles"], +) + +td_library( + name = "EnzymePassesTdFiles", + srcs = [ + ], + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ] +) + +gentbl_cc_library( + name = "EnzymePassesIncGen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=enzyme", + ], + "Enzyme/MLIR/Passes/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Passes/Passes.td", + deps = [":EnzymePassesTdFiles"], +) + +gentbl_cc_library( + name = "EnzymeTypesIncGen", + tbl_outs = [ + ( + ["-gen-typedef-decls"], + "Enzyme/MLIR/Dialect/EnzymeOpsTypes.h.inc", + ), + ( + ["-gen-typedef-defs"], + "Enzyme/MLIR/Dialect/EnzymeOpsTypes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/EnzymeOps.td", + deps = [":EnzymeDialectTdFiles"], +) + +gentbl_cc_library( + name = "EnzymeEnumsIncGen", + tbl_outs = [ + ( + ["-gen-enum-decls"], + "Enzyme/MLIR/Dialect/EnzymeEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "Enzyme/MLIR/Dialect/EnzymeEnums.cpp.inc", + ), + ], + tblgen = 
"@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/EnzymeOps.td", + deps = [":EnzymeDialectTdFiles"], +) + +gentbl_cc_library( + name = "EnzymeAttributesIncGen", + tbl_outs = [ + ( + ["-gen-attrdef-decls"], + "Enzyme/MLIR/Dialect/EnzymeAttributes.h.inc", + ), + ( + ["-gen-attrdef-defs"], + "Enzyme/MLIR/Dialect/EnzymeAttributes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/EnzymeOps.td", + deps = [":EnzymeDialectTdFiles"], +) + + +gentbl_cc_library( + name = "EnzymeTypeInterfacesIncGen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td", + deps = [":EnzymeDialectTdFiles"], +) + +gentbl_cc_library( + name = "EnzymeOpInterfacesIncGen", + tbl_outs = [ + ( + ["--gen-op-interface-decls"], + "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h.inc", + ), + ( + ["--gen-op-interface-defs"], + "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td", + deps = [":EnzymeDialectTdFiles"], +) + +cc_library( + name = "EnzymeMLIR", + srcs = glob([ + "Enzyme/MLIR/Dialect/*.cpp", + "Enzyme/MLIR/Passes/*.cpp", + "Enzyme/MLIR/Interfaces/*.cpp", + "Enzyme/MLIR/Analysis/*.cpp", + "Enzyme/MLIR/Implementations/*.cpp", + ]), + hdrs = glob([ + "Enzyme/MLIR/Dialect/*.h", + "Enzyme/MLIR/Passes/*.h", + "Enzyme/MLIR/Interfaces/*.h", + "Enzyme/MLIR/Analysis/*.h", + "Enzyme/MLIR/Implementations/*.h", + "Enzyme/Utils.h", + "Enzyme/TypeAnalysis/*.h" + ]), + includes = ["Enzyme/MLIR", "Enzyme"], + visibility = ["//visibility:public"], + deps = [ + ":EnzymeOpsIncGen", + ":EnzymePassesIncGen", + ":EnzymeTypesIncGen", + ":EnzymeEnumsIncGen", + 
":EnzymeAttributesIncGen", + ":EnzymeTypeInterfacesIncGen", + ":EnzymeOpInterfacesIncGen", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + ], +) + +cc_binary( + name = "enzymemlir-opt", + srcs = ["Enzyme/MLIR/enzymemlir-opt.cpp"], + visibility = ["//visibility:public"], + includes = ["Enzyme/MLIR"], + deps = [ + ":EnzymeMLIR", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:AllPassesAndDialects", + ], +) + +# Generates lit config input file by applying path placeholder substitutions +# similar to the configure_lit_site_cfg CMake macro. +expand_template( + name = "lit_site_cfg_py", + testonly = True, + out = "test/lit.site.cfg.py", + substitutions = { + "@LLVM_VERSION_MAJOR@": "18", + "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", + "@LLVM_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@ENZYME_SOURCE_DIR@": "", + "@ENZYME_BINARY_DIR@": "", + "@TARGET_TRIPLE@": "", + "@TARGETS_TO_BUILD@": "ALL", + "@LLVM_SHLIBEXT@": ".so", + }, + template = "test/lit.site.cfg.py.in", + visibility = ["//visibility:private"], +) + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + ":test/lit.cfg.py", + ":test/lit.site.cfg.py", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:not", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:opt", + "@llvm-project//clang:builtin_headers_gen", + ":enzyme-clang", + ":enzyme-clang++", + ":enzymemlir-opt" + ] + glob(["test/**/*.h"]) + ) + for src in glob(["test/**/*.mlir"]) +] diff 
--git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 237e66837513..93488a07fd6e 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -73,7 +73,7 @@ struct GenericOpInterfaceReverse MGradientUtilsReverse *gutils, SmallVector caches) const { auto linalgOp = cast(op); - assert(linalgOp.hasBufferSemantics() && + assert(linalgOp.hasPureBufferSemantics() && "Linalg op with tensor semantics not yet supported"); linalg::LinalgOp newOp = @@ -278,4 +278,4 @@ void mlir::enzyme::registerLinalgDialectAutoDiffInterface( #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" >(context); }); -} \ No newline at end of file +} diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index d36abbf2dc53..e4af7b760ce6 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -222,11 +222,12 @@ void MEnzymeLogic::handlePredecessors( } else { SmallVector blocks; SmallVector indices; - SmallVector arguments; + SmallVector> arguments; SmallVector defaultArguments; - Block *defaultBlock; - int i = 1; - for (Block *predecessor : oBB->getPredecessors()) { + Block *defaultBlock = nullptr; + for (auto pair : llvm::enumerate(oBB->getPredecessors())) { + auto predecessor = pair.value(); + auto idx = pair.index(); Block *predecessorRevMode = gutils->mapReverseModeBlocks.lookupOrNull(predecessor); @@ -250,10 +251,10 @@ void MEnzymeLogic::handlePredecessors( } } } - if (predecessor != *(oBB->getPredecessors().begin())) { + if (idx != 0) { blocks.push_back(predecessorRevMode); - indices.push_back(APInt(32, i++)); - arguments.push_back(operands); + indices.push_back(APInt(32, idx - 1)); + arguments.emplace_back(std::move(operands)); } else { defaultBlock = 
predecessorRevMode; defaultArguments = operands; @@ -275,15 +276,19 @@ void MEnzymeLogic::handlePredecessors( oBB->getPredecessors().end()) { // If there is only one block we can directly create a branch for // simplicity sake - revBuilder.create(loc, defaultBlock, defaultArguments); + auto bop = + revBuilder.create(loc, defaultBlock, defaultArguments); } else { Value cache = gutils->insertInit(gutils->getIndexCacheType()); Value flag = revBuilder.create(loc, gutils->getIndexType(), cache); - revBuilder.create( + SmallVector argumentRanges; + for (const auto &a : arguments) + argumentRanges.emplace_back(a); + auto bop = revBuilder.create( loc, flag, defaultBlock, defaultArguments, ArrayRef(indices), - ArrayRef(blocks), ArrayRef(arguments)); + ArrayRef(blocks), argumentRanges); Value origin = newBB->addArgument(gutils->getIndexType(), loc); @@ -356,7 +361,6 @@ void MEnzymeLogic::differentiate( Block *oBB = *it; Block *newBB = gutils->getNewFromOriginal(oBB); Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(oBB); - mapInvertArguments(oBB, reverseBB, gutils); handleReturns(oBB, newBB, reverseBB, gutils, parentRegion); visitChildren(oBB, reverseBB, gutils); @@ -401,4 +405,4 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( delete gutils; return nf; -} \ No newline at end of file +} diff --git a/enzyme/WORKSPACE b/enzyme/WORKSPACE index 4ec0b10d4007..7fabb1dea266 100644 --- a/enzyme/WORKSPACE +++ b/enzyme/WORKSPACE @@ -7,7 +7,7 @@ new_local_repository( load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") -llvm_configure(name = "llvm-project", targets = ["X86"]) +llvm_configure(name = "llvm-project", targets = ["X86", "NVPTX"]) load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") @@ -44,3 +44,19 @@ maybe( "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" ], ) + +PYRULES_COMMIT = "fe33a4582c37499f3caeb49a07a78fc7948a8949" +PYRULES_SHA256 = 
"cfa6957832ae0e0c7ee2ccf455a888a291e8419ed8faf45f4420dd7414d5dd96" + +http_archive( + name = "rules_python", + sha256 = PYRULES_SHA256, + strip_prefix = "rules_python-" + PYRULES_COMMIT, + urls = ["https://github.com/bazelbuild/rules_python/archive/{commit}.tar.gz".format(commit = PYRULES_COMMIT)] +) + + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() + diff --git a/enzyme/test/Integration/BatchMode/char-ptr.cpp b/enzyme/test/Integration/BatchMode/char-ptr.cpp index 6f5236e6602c..0d727b601073 100644 --- a/enzyme/test/Integration/BatchMode/char-ptr.cpp +++ b/enzyme/test/Integration/BatchMode/char-ptr.cpp @@ -10,7 +10,7 @@ #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_batch(void *, int, int, char *, char *, char *, char *); extern int enzyme_dup; diff --git a/enzyme/test/Integration/BatchMode/sqaure.cpp b/enzyme/test/Integration/BatchMode/sqaure.cpp index f9c97ed54eaa..acd73ccbe973 100644 --- a/enzyme/test/Integration/BatchMode/sqaure.cpp +++ b/enzyme/test/Integration/BatchMode/sqaure.cpp @@ -8,7 +8,7 @@ // RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include "test_utils.h" +#include "../test_utils.h" #include struct Vector { diff --git a/enzyme/test/Integration/BatchMode/test_utils.h b/enzyme/test/Integration/BatchMode/test_utils.h deleted file mode 100644 index 3f96e761fb2b..000000000000 --- a/enzyme/test/Integration/BatchMode/test_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -#include -#include -#include -#include - -/* -#ifdef __cplusplus -extern "C" { -#endif -static inline bool approx_fp_equality_float(float f1, float f2, double -threshold) { if (fabs(f1-f2) > threshold) return false; return true; -} - -static inline bool approx_fp_equality_double(double f1, double f2, double -threshold) { if (fabs(f1-f2) > threshold) return false; return 
true; -} -#ifdef __cplusplus -} -#endif -*/ - -#define APPROX_EQ(LHS, RHS, THRES) \ - { \ - if (__builtin_fabs(LHS - RHS) > THRES) { \ - fprintf(stderr, \ - "Assertion Failed: fabs( [%s = %g] - [%s = %g] ) > %g at %s:%d " \ - "(%s)\n", \ - #LHS, LHS, #RHS, RHS, THRES, __FILE__, __LINE__, \ - __PRETTY_FUNCTION__); \ - abort(); \ - } \ - }; diff --git a/enzyme/test/Integration/ForwardMode/binops.c b/enzyme/test/Integration/ForwardMode/binops.c index 11dfbe01bc09..936c433eca73 100644 --- a/enzyme/test/Integration/ForwardMode/binops.c +++ b/enzyme/test/Integration/ForwardMode/binops.c @@ -7,7 +7,7 @@ // RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include "test_utils.h" +#include "../test_utils.h" extern double __enzyme_fwddiff(double (*)(double, double), double,...); diff --git a/enzyme/test/Integration/ForwardMode/customfwd.c b/enzyme/test/Integration/ForwardMode/customfwd.c index 78f8673add1d..e9d1f5d0c130 100644 --- a/enzyme/test/Integration/ForwardMode/customfwd.c +++ b/enzyme/test/Integration/ForwardMode/customfwd.c @@ -17,7 +17,7 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/eigen.cpp b/enzyme/test/Integration/ForwardMode/eigen.cpp index 0eff17398bf6..89504b26303d 100644 --- a/enzyme/test/Integration/ForwardMode/eigen.cpp +++ b/enzyme/test/Integration/ForwardMode/eigen.cpp @@ -5,7 +5,7 @@ // RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme 
-enzyme-inline=1 -S | %lli - // RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/ForwardMode/fwdandrev.cpp b/enzyme/test/Integration/ForwardMode/fwdandrev.cpp index 997b08d80a67..8da8b91472a9 100644 --- a/enzyme/test/Integration/ForwardMode/fwdandrev.cpp +++ b/enzyme/test/Integration/ForwardMode/fwdandrev.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -std=c++14 -fno-exceptions -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" template struct tensor; diff --git a/enzyme/test/Integration/ForwardMode/loops.c b/enzyme/test/Integration/ForwardMode/loops.c index 1c498b137bd7..612839248cc2 100644 --- a/enzyme/test/Integration/ForwardMode/loops.c +++ b/enzyme/test/Integration/ForwardMode/loops.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/loopsdouble.c b/enzyme/test/Integration/ForwardMode/loopsdouble.c index 549c60a4da4c..8c34740c80d5 100644 --- a/enzyme/test/Integration/ForwardMode/loopsdouble.c +++ b/enzyme/test/Integration/ForwardMode/loopsdouble.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/loopstriple.c b/enzyme/test/Integration/ForwardMode/loopstriple.c index e25a81d09e38..060e74d009f7 100644 --- a/enzyme/test/Integration/ForwardMode/loopstriple.c +++ b/enzyme/test/Integration/ForwardMode/loopstriple.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/mpi/mpi_bcast.c 
b/enzyme/test/Integration/ForwardMode/mpi/mpi_bcast.c index ec354f8ee285..33118c32942d 100644 --- a/enzyme/test/Integration/ForwardMode/mpi/mpi_bcast.c +++ b/enzyme/test/Integration/ForwardMode/mpi/mpi_bcast.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #include diff --git a/enzyme/test/Integration/ForwardMode/mpi/mpi_reduce.c b/enzyme/test/Integration/ForwardMode/mpi/mpi_reduce.c index 019441d65f66..394e64a96d6f 100644 --- a/enzyme/test/Integration/ForwardMode/mpi/mpi_reduce.c +++ b/enzyme/test/Integration/ForwardMode/mpi/mpi_reduce.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #include diff --git a/enzyme/test/Integration/ForwardMode/nofn.cpp b/enzyme/test/Integration/ForwardMode/nofn.cpp index 8092bb5f5446..6c0716d4e31a 100644 --- a/enzyme/test/Integration/ForwardMode/nofn.cpp +++ b/enzyme/test/Integration/ForwardMode/nofn.cpp @@ -1,6 +1,6 @@ // RUN: if [ %llvmver -ge 10 ]; then %clang -g -O0 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify; fi -#include "test_utils.h" +#include "../test_utils.h" #include diff --git a/enzyme/test/Integration/ForwardMode/rosenbrock.cpp b/enzyme/test/Integration/ForwardMode/rosenbrock.cpp index 434ae89af89d..8b52d2ff31c1 100644 --- a/enzyme/test/Integration/ForwardMode/rosenbrock.cpp +++ b/enzyme/test/Integration/ForwardMode/rosenbrock.cpp @@ -7,7 +7,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang++ -std=c++11 -fno-exceptions -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/ForwardMode/rwrloop.c b/enzyme/test/Integration/ForwardMode/rwrloop.c index d1b0a1a10b4b..dc71b0eead20 100644 --- a/enzyme/test/Integration/ForwardMode/rwrloop.c +++ b/enzyme/test/Integration/ForwardMode/rwrloop.c @@ -11,7 
+11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/sumtil.c b/enzyme/test/Integration/ForwardMode/sumtil.c index 20ce00a4b735..9e9286f97157 100644 --- a/enzyme/test/Integration/ForwardMode/sumtil.c +++ b/enzyme/test/Integration/ForwardMode/sumtil.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern double __enzyme_fwddiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ForwardMode/sumtil2.c b/enzyme/test/Integration/ForwardMode/sumtil2.c index c8ae8741bef1..32d289703e69 100644 --- a/enzyme/test/Integration/ForwardMode/sumtil2.c +++ b/enzyme/test/Integration/ForwardMode/sumtil2.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern double __enzyme_fwddiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ForwardModeVector/test_utils.h b/enzyme/test/Integration/ForwardModeVector/test_utils.h deleted file mode 100644 index 3f96e761fb2b..000000000000 --- a/enzyme/test/Integration/ForwardModeVector/test_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -#include -#include -#include -#include - -/* -#ifdef __cplusplus -extern "C" { -#endif -static inline bool approx_fp_equality_float(float f1, float f2, double -threshold) { if (fabs(f1-f2) > threshold) return false; return true; -} - -static inline bool approx_fp_equality_double(double f1, double f2, double -threshold) { if (fabs(f1-f2) > threshold) return false; return true; -} -#ifdef __cplusplus -} -#endif -*/ - -#define APPROX_EQ(LHS, RHS, THRES) \ - { \ - if (__builtin_fabs(LHS - RHS) > THRES) { \ - fprintf(stderr, \ - "Assertion Failed: fabs( [%s = %g] - [%s = %g] ) > %g at %s:%d " \ - "(%s)\n", \ - #LHS, LHS, #RHS, RHS, THRES, __FILE__, __LINE__, \ - __PRETTY_FUNCTION__); \ - abort(); \ - } \ - }; diff --git a/enzyme/test/Integration/ReverseMode/allocatedtape.c 
b/enzyme/test/Integration/ReverseMode/allocatedtape.c index 0fd02521e0dd..15bce1ed71cb 100644 --- a/enzyme/test/Integration/ReverseMode/allocatedtape.c +++ b/enzyme/test/Integration/ReverseMode/allocatedtape.c @@ -17,7 +17,7 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include "test_utils.h" +#include "../test_utils.h" void __enzyme_autodiff(void*, ...); void* __enzyme_augmentfwd(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/allocatedtape_err.c b/enzyme/test/Integration/ReverseMode/allocatedtape_err.c index f3ee53a65d09..1236a86388e7 100644 --- a/enzyme/test/Integration/ReverseMode/allocatedtape_err.c +++ b/enzyme/test/Integration/ReverseMode/allocatedtape_err.c @@ -8,7 +8,7 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi #include -#include "test_utils.h" +#include "../test_utils.h" void __enzyme_reverse(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/blas_gemm.c b/enzyme/test/Integration/ReverseMode/blas_gemm.c index 195b7efea775..bdebc0a180e9 100644 --- a/enzyme/test/Integration/ReverseMode/blas_gemm.c +++ b/enzyme/test/Integration/ReverseMode/blas_gemm.c @@ -8,7 +8,7 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi -#include "test_utils.h" +#include "../test_utils.h" #include "../blas_inline.h" #include diff --git a/enzyme/test/Integration/ReverseMode/blas_gemm2.c b/enzyme/test/Integration/ReverseMode/blas_gemm2.c index e1fd7947308e..a417e0418cd9 100644 --- a/enzyme/test/Integration/ReverseMode/blas_gemm2.c +++ 
b/enzyme/test/Integration/ReverseMode/blas_gemm2.c @@ -11,7 +11,7 @@ #include #include #include -#include "test_utils.h" +#include "../test_utils.h" #include "../blas_inline.h" extern int enzyme_dup; diff --git a/enzyme/test/Integration/ReverseMode/boundissue.c b/enzyme/test/Integration/ReverseMode/boundissue.c index 5df279890c7b..f297ebbdda97 100644 --- a/enzyme/test/Integration/ReverseMode/boundissue.c +++ b/enzyme/test/Integration/ReverseMode/boundissue.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/cachefwd.c b/enzyme/test/Integration/ReverseMode/cachefwd.c index 65050b3b32e8..ab56e3dd0853 100644 --- a/enzyme/test/Integration/ReverseMode/cachefwd.c +++ b/enzyme/test/Integration/ReverseMode/cachefwd.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); /*double max(double x, double y) { diff --git a/enzyme/test/Integration/ReverseMode/cerr.cpp b/enzyme/test/Integration/ReverseMode/cerr.cpp index e85bed4ff4df..b4bc9febd3d6 100644 --- a/enzyme/test/Integration/ReverseMode/cerr.cpp +++ b/enzyme/test/Integration/ReverseMode/cerr.cpp @@ -7,7 +7,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -#include "test_utils.h" +#include "../test_utils.h" #include #include #include diff --git a/enzyme/test/Integration/ReverseMode/cin.cpp b/enzyme/test/Integration/ReverseMode/cin.cpp index 6a59655838c5..e357708bdae1 100644 --- a/enzyme/test/Integration/ReverseMode/cin.cpp +++ b/enzyme/test/Integration/ReverseMode/cin.cpp @@ -7,7 +7,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - 
%OPloadEnzyme %enzyme -enzyme-inline=1 -S // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -#include "test_utils.h" +#include "../test_utils.h" #include #include #include diff --git a/enzyme/test/Integration/ReverseMode/cmplx.cpp b/enzyme/test/Integration/ReverseMode/cmplx.cpp index 5970d4cf94ae..73559beb94fb 100644 --- a/enzyme/test/Integration/ReverseMode/cmplx.cpp +++ b/enzyme/test/Integration/ReverseMode/cmplx.cpp @@ -9,7 +9,7 @@ //#include -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/ReverseMode/cout.cpp b/enzyme/test/Integration/ReverseMode/cout.cpp index 9c0a611e1224..fadb93f17b23 100644 --- a/enzyme/test/Integration/ReverseMode/cout.cpp +++ b/enzyme/test/Integration/ReverseMode/cout.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S #include -#include "test_utils.h" +#include "../test_utils.h" extern double __enzyme_autodiff(void*, double); diff --git a/enzyme/test/Integration/ReverseMode/customalloc.c b/enzyme/test/Integration/ReverseMode/customalloc.c index 4bd8f88bcf4e..eca4deac5ee2 100644 --- a/enzyme/test/Integration/ReverseMode/customalloc.c +++ b/enzyme/test/Integration/ReverseMode/customalloc.c @@ -22,7 +22,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/customcombined.c b/enzyme/test/Integration/ReverseMode/customcombined.c index e5dca475da5d..5ce2c1df2b07 100644 --- a/enzyme/test/Integration/ReverseMode/customcombined.c +++ b/enzyme/test/Integration/ReverseMode/customcombined.c @@ -17,7 +17,7 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 
%s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, double); diff --git a/enzyme/test/Integration/ReverseMode/customglob.cpp b/enzyme/test/Integration/ReverseMode/customglob.cpp index e323b40aebce..c3e494cdfbfc 100644 --- a/enzyme/test/Integration/ReverseMode/customglob.cpp +++ b/enzyme/test/Integration/ReverseMode/customglob.cpp @@ -15,7 +15,7 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang++ -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/customlog1p.c b/enzyme/test/Integration/ReverseMode/customlog1p.c index 995cd2c051aa..95c0697fff77 100644 --- a/enzyme/test/Integration/ReverseMode/customlog1p.c +++ b/enzyme/test/Integration/ReverseMode/customlog1p.c @@ -21,7 +21,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/dbginfo.c b/enzyme/test/Integration/ReverseMode/dbginfo.c index dbf9201d1ccd..06767e624849 100644 --- a/enzyme/test/Integration/ReverseMode/dbginfo.c +++ b/enzyme/test/Integration/ReverseMode/dbginfo.c @@ -9,7 +9,7 @@ //#include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/differential_pointer_return.c b/enzyme/test/Integration/ReverseMode/differential_pointer_return.c index ffafffe58962..b05a7264ae3c 100644 --- a/enzyme/test/Integration/ReverseMode/differential_pointer_return.c +++ b/enzyme/test/Integration/ReverseMode/differential_pointer_return.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include 
"../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/eigensumsq.cpp b/enzyme/test/Integration/ReverseMode/eigensumsq.cpp index f601255fb273..ee4a8e9c76f6 100644 --- a/enzyme/test/Integration/ReverseMode/eigensumsq.cpp +++ b/enzyme/test/Integration/ReverseMode/eigensumsq.cpp @@ -13,7 +13,7 @@ #define EIGEN_UNROLLING_LIMIT 0 #define EIGEN_DONT_VECTORIZE 1 -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/ReverseMode/eigensumsqdyn-notmp.cpp b/enzyme/test/Integration/ReverseMode/eigensumsqdyn-notmp.cpp index d5f806d7d62b..7ed4132deac0 100644 --- a/enzyme/test/Integration/ReverseMode/eigensumsqdyn-notmp.cpp +++ b/enzyme/test/Integration/ReverseMode/eigensumsqdyn-notmp.cpp @@ -14,7 +14,7 @@ #define EIGEN_UNROLLING_LIMIT 0 #define EIGEN_DONT_VECTORIZE 1 -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/ReverseMode/eigensumsqdyn.cpp b/enzyme/test/Integration/ReverseMode/eigensumsqdyn.cpp index e382f07a4a13..4076978d1af6 100644 --- a/enzyme/test/Integration/ReverseMode/eigensumsqdyn.cpp +++ b/enzyme/test/Integration/ReverseMode/eigensumsqdyn.cpp @@ -13,7 +13,7 @@ #define EIGEN_UNROLLING_LIMIT 0 #define EIGEN_DONT_VECTORIZE 1 -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/ReverseMode/eigentensor.cpp b/enzyme/test/Integration/ReverseMode/eigentensor.cpp index f813a8381ab4..28feb236f4db 100644 --- a/enzyme/test/Integration/ReverseMode/eigentensor.cpp +++ b/enzyme/test/Integration/ReverseMode/eigentensor.cpp @@ -14,7 +14,7 @@ #define EIGEN_UNROLLING_LIMIT 0 #define EIGEN_DONT_VECTORIZE 1 -#include "test_utils.h" +#include "../test_utils.h" void memcpy(float* __restrict dst, float* __restrict src, size_t count) { diff --git a/enzyme/test/Integration/ReverseMode/eigentensorfull.cpp 
b/enzyme/test/Integration/ReverseMode/eigentensorfull.cpp index d6d4c0dd0f2f..b61aebd47ef4 100644 --- a/enzyme/test/Integration/ReverseMode/eigentensorfull.cpp +++ b/enzyme/test/Integration/ReverseMode/eigentensorfull.cpp @@ -14,7 +14,7 @@ #define EIGEN_UNROLLING_LIMIT 0 #define EIGEN_DONT_VECTORIZE 1 -#include "test_utils.h" +#include "../test_utils.h" /* void memcpy(float* __restrict dst, float* __restrict src, size_t count) { diff --git a/enzyme/test/Integration/ReverseMode/eigentensorreal.cpp b/enzyme/test/Integration/ReverseMode/eigentensorreal.cpp index f2ce5fbaa591..a5780d03c9b6 100644 --- a/enzyme/test/Integration/ReverseMode/eigentensorreal.cpp +++ b/enzyme/test/Integration/ReverseMode/eigentensorreal.cpp @@ -14,7 +14,7 @@ #define EIGEN_UNROLLING_LIMIT 0 #define EIGEN_DONT_VECTORIZE 1 -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/ReverseMode/fbuff.cpp b/enzyme/test/Integration/ReverseMode/fbuff.cpp index bb1d701fa0a1..2c9e31584aa8 100644 --- a/enzyme/test/Integration/ReverseMode/fbuff.cpp +++ b/enzyme/test/Integration/ReverseMode/fbuff.cpp @@ -7,7 +7,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -#include "test_utils.h" +#include "../test_utils.h" #include #include #include diff --git a/enzyme/test/Integration/ReverseMode/forrealloc.c b/enzyme/test/Integration/ReverseMode/forrealloc.c index f40f82cb54b5..9fc9ec5798d7 100644 --- a/enzyme/test/Integration/ReverseMode/forrealloc.c +++ b/enzyme/test/Integration/ReverseMode/forrealloc.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" float __enzyme_autodiff(void*, float, int); diff --git a/enzyme/test/Integration/ReverseMode/frexp.c b/enzyme/test/Integration/ReverseMode/frexp.c index 
27480621b26f..4b917e7e282f 100644 --- a/enzyme/test/Integration/ReverseMode/frexp.c +++ b/enzyme/test/Integration/ReverseMode/frexp.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double f(double x) { int exp; diff --git a/enzyme/test/Integration/ReverseMode/fwdsolve.c b/enzyme/test/Integration/ReverseMode/fwdsolve.c index 0b9fe7db3ca2..c7a6ad63040e 100644 --- a/enzyme/test/Integration/ReverseMode/fwdsolve.c +++ b/enzyme/test/Integration/ReverseMode/fwdsolve.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/gradient-struct-return.c b/enzyme/test/Integration/ReverseMode/gradient-struct-return.c index 718f2adcd32a..605d6bb5045f 100644 --- a/enzyme/test/Integration/ReverseMode/gradient-struct-return.c +++ b/enzyme/test/Integration/ReverseMode/gradient-struct-return.c @@ -8,7 +8,7 @@ // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" typedef struct { double dx,dy; diff --git a/enzyme/test/Integration/ReverseMode/headerremat.c b/enzyme/test/Integration/ReverseMode/headerremat.c index 4eb9f673a991..b3397ad8bead 100644 --- a/enzyme/test/Integration/ReverseMode/headerremat.c +++ b/enzyme/test/Integration/ReverseMode/headerremat.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #include diff --git a/enzyme/test/Integration/ReverseMode/inactivefn.c b/enzyme/test/Integration/ReverseMode/inactivefn.c index e27f1069f863..e48edd3486ed 100644 --- a/enzyme/test/Integration/ReverseMode/inactivefn.c +++ b/enzyme/test/Integration/ReverseMode/inactivefn.c @@ -21,7 +21,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum.c 
b/enzyme/test/Integration/ReverseMode/insertsort_sum.c index 65740a389ec5..7e7d6b0a3ae4 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c b/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c index 0a486cc0b83d..3a643492d30f 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c b/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c index fa168b4eb1d3..b62150353581 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c @@ -7,7 +7,7 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include "test_utils.h" +#include "../test_utils.h" #include #include #include diff --git a/enzyme/test/Integration/ReverseMode/integrateconst.cpp b/enzyme/test/Integration/ReverseMode/integrateconst.cpp index 0f6ce4056818..2086ab8b525d 100644 --- a/enzyme/test/Integration/ReverseMode/integrateconst.cpp +++ b/enzyme/test/Integration/ReverseMode/integrateconst.cpp @@ -10,7 +10,7 @@ // RUN: %clang++ -fno-use-cxa-atexit -ffast-math -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang++ -fno-use-cxa-atexit -ffast-math -fno-unroll-loops 
-fno-vectorize -fno-slp-vectorize -fno-exceptions %O0TBAA %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include "test_utils.h" +#include "../test_utils.h" #define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS #define BOOST_NO_EXCEPTIONS diff --git a/enzyme/test/Integration/ReverseMode/integrateexp.cpp b/enzyme/test/Integration/ReverseMode/integrateexp.cpp index ffa23597a1ec..b0af87422627 100644 --- a/enzyme/test/Integration/ReverseMode/integrateexp.cpp +++ b/enzyme/test/Integration/ReverseMode/integrateexp.cpp @@ -12,7 +12,7 @@ // RUN: %clang++ -fno-use-cxa-atexit -ffast-math -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // TODO: %clang++ -fno-use-cxa-atexit -ffast-math -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include "test_utils.h" +#include "../test_utils.h" #define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS #define BOOST_NO_EXCEPTIONS diff --git a/enzyme/test/Integration/ReverseMode/invsqrt.c b/enzyme/test/Integration/ReverseMode/invsqrt.c index a9bad760745d..c14ce79ab88d 100644 --- a/enzyme/test/Integration/ReverseMode/invsqrt.c +++ b/enzyme/test/Integration/ReverseMode/invsqrt.c @@ -23,7 +23,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" // Fast inverse sqrt // Code taken from https://en.wikipedia.org/wiki/Fast_inverse_square_root diff --git a/enzyme/test/Integration/ReverseMode/loops.c b/enzyme/test/Integration/ReverseMode/loops.c index 482f00eabc24..3a2794963eb5 100644 --- a/enzyme/test/Integration/ReverseMode/loops.c +++ b/enzyme/test/Integration/ReverseMode/loops.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loopsdouble.c 
b/enzyme/test/Integration/ReverseMode/loopsdouble.c index 9bd5796eb57d..8108a9813c63 100644 --- a/enzyme/test/Integration/ReverseMode/loopsdouble.c +++ b/enzyme/test/Integration/ReverseMode/loopsdouble.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/loopsnested.c b/enzyme/test/Integration/ReverseMode/loopsnested.c index 73790f6f3e02..8e1eb7c3fa9a 100644 --- a/enzyme/test/Integration/ReverseMode/loopsnested.c +++ b/enzyme/test/Integration/ReverseMode/loopsnested.c @@ -7,7 +7,7 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include "test_utils.h" +#include "../test_utils.h" __attribute__((always_inline)) inline void doA(double x[2]) diff --git a/enzyme/test/Integration/ReverseMode/loopstriple.c b/enzyme/test/Integration/ReverseMode/loopstriple.c index 6cc3d168fd56..490502145810 100644 --- a/enzyme/test/Integration/ReverseMode/loopstriple.c +++ b/enzyme/test/Integration/ReverseMode/loopstriple.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/manydiv.c b/enzyme/test/Integration/ReverseMode/manydiv.c index 642de1211e99..a0da1f6499b3 100644 --- a/enzyme/test/Integration/ReverseMode/manydiv.c +++ b/enzyme/test/Integration/ReverseMode/manydiv.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/manymax.c b/enzyme/test/Integration/ReverseMode/manymax.c index 2bb0c1a0a7ee..90b61041ec1e 100644 --- 
a/enzyme/test/Integration/ReverseMode/manymax.c +++ b/enzyme/test/Integration/ReverseMode/manymax.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/map.cpp b/enzyme/test/Integration/ReverseMode/map.cpp index 36357c293860..71cba7a0f8ef 100644 --- a/enzyme/test/Integration/ReverseMode/map.cpp +++ b/enzyme/test/Integration/ReverseMode/map.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/ReverseMode/metamalloc.c b/enzyme/test/Integration/ReverseMode/metamalloc.c index 81ebdb799b47..96a848fd9698 100644 --- a/enzyme/test/Integration/ReverseMode/metamalloc.c +++ b/enzyme/test/Integration/ReverseMode/metamalloc.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/metarwr.c b/enzyme/test/Integration/ReverseMode/metarwr.c index 923a710dd6b8..4efcd475f68f 100644 --- a/enzyme/test/Integration/ReverseMode/metarwr.c +++ b/enzyme/test/Integration/ReverseMode/metarwr.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-old.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-old.c index 57f2b42c5656..b016f796e255 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-old.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-old.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git 
a/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c index 916784faf4e8..5dea68e793da 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c index 53b4cfeb446a..e2b3f9788fca 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c index 6bf81d36ceb2..b4857429b6fe 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c index c0dc02d63e56..61b81d0a57bf 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c 
b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c index 4a1de376bbb7..bcf3ea958485 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c index 6321f42b5169..eb783098cf02 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1.c b/enzyme/test/Integration/ReverseMode/mixedstruct1.c index 85e315a0ce93..3ab79b584ef9 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/mpi_bcast.c b/enzyme/test/Integration/ReverseMode/mpi_bcast.c index 81fdedf1b3ca..d1edb2796236 100644 --- a/enzyme/test/Integration/ReverseMode/mpi_bcast.c +++ b/enzyme/test/Integration/ReverseMode/mpi_bcast.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #include diff --git a/enzyme/test/Integration/ReverseMode/mpi_reduce.c b/enzyme/test/Integration/ReverseMode/mpi_reduce.c index c87e6e0c85e5..d40d191bb7d1 100644 --- a/enzyme/test/Integration/ReverseMode/mpi_reduce.c +++ b/enzyme/test/Integration/ReverseMode/mpi_reduce.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include 
"../test_utils.h" #include diff --git a/enzyme/test/Integration/ReverseMode/multivecmax.cpp b/enzyme/test/Integration/ReverseMode/multivecmax.cpp index 1d24a81bd023..372f8841a059 100644 --- a/enzyme/test/Integration/ReverseMode/multivecmax.cpp +++ b/enzyme/test/Integration/ReverseMode/multivecmax.cpp @@ -15,7 +15,7 @@ #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); /*double max(double x, double y) { diff --git a/enzyme/test/Integration/ReverseMode/multivecmaxC.c b/enzyme/test/Integration/ReverseMode/multivecmaxC.c index 3153059573f2..c39559adb2de 100644 --- a/enzyme/test/Integration/ReverseMode/multivecmaxC.c +++ b/enzyme/test/Integration/ReverseMode/multivecmaxC.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); /*double max(double x, double y) { diff --git a/enzyme/test/Integration/ReverseMode/mycos.c b/enzyme/test/Integration/ReverseMode/mycos.c index f690d32defe6..1d7d4c8e85f3 100644 --- a/enzyme/test/Integration/ReverseMode/mycos.c +++ b/enzyme/test/Integration/ReverseMode/mycos.c @@ -1,4 +1,4 @@ -// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then ls && %clang -std=c11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi @@ -19,7 +19,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" __attribute__((noinline)) uint64_t factorial(uint64_t x) { diff --git a/enzyme/test/Integration/ReverseMode/omp.c b/enzyme/test/Integration/ReverseMode/omp.c 
index b51e2abe6149..4d6f0bd164c5 100644 --- a/enzyme/test/Integration/ReverseMode/omp.c +++ b/enzyme/test/Integration/ReverseMode/omp.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp2.c b/enzyme/test/Integration/ReverseMode/omp2.c index 7e49df928e37..8cdd06545fd1 100644 --- a/enzyme/test/Integration/ReverseMode/omp2.c +++ b/enzyme/test/Integration/ReverseMode/omp2.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp3.c b/enzyme/test/Integration/ReverseMode/omp3.c index d6431f3e3953..59bdca34cf9a 100644 --- a/enzyme/test/Integration/ReverseMode/omp3.c +++ b/enzyme/test/Integration/ReverseMode/omp3.c @@ -11,7 +11,7 @@ # include # include -#include "test_utils.h" +#include "../test_utils.h" void msg(double* in, int *len, unsigned int slen) { if (slen != 0) { diff --git a/enzyme/test/Integration/ReverseMode/omp6.c b/enzyme/test/Integration/ReverseMode/omp6.c index 9c74e1b6ae96..8a6f34fe235f 100644 --- a/enzyme/test/Integration/ReverseMode/omp6.c +++ b/enzyme/test/Integration/ReverseMode/omp6.c @@ -13,7 +13,7 @@ # include #include -#include "test_utils.h" +#include "../test_utils.h" __attribute__((noinline)) void set(double *a, double x){ diff --git a/enzyme/test/Integration/ReverseMode/omp_crit.c b/enzyme/test/Integration/ReverseMode/omp_crit.c index a633cf37551f..794040c35bef 100644 --- a/enzyme/test/Integration/ReverseMode/omp_crit.c +++ b/enzyme/test/Integration/ReverseMode/omp_crit.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp_firstprivate.c b/enzyme/test/Integration/ReverseMode/omp_firstprivate.c index 4d5a25fb869e..fe284a2cdbbd 100644 --- 
a/enzyme/test/Integration/ReverseMode/omp_firstprivate.c +++ b/enzyme/test/Integration/ReverseMode/omp_firstprivate.c @@ -15,7 +15,7 @@ extern int omp_get_max_threads(); #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp_two.c b/enzyme/test/Integration/ReverseMode/omp_two.c index 6e9a70411574..1f95e93f2520 100644 --- a/enzyme/test/Integration/ReverseMode/omp_two.c +++ b/enzyme/test/Integration/ReverseMode/omp_two.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/ompbound.c b/enzyme/test/Integration/ReverseMode/ompbound.c index b51e2abe6149..4d6f0bd164c5 100644 --- a/enzyme/test/Integration/ReverseMode/ompbound.c +++ b/enzyme/test/Integration/ReverseMode/ompbound.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/perturbation.cpp b/enzyme/test/Integration/ReverseMode/perturbation.cpp index 34639d614499..1bbb8ea70f98 100644 --- a/enzyme/test/Integration/ReverseMode/perturbation.cpp +++ b/enzyme/test/Integration/ReverseMode/perturbation.cpp @@ -9,7 +9,7 @@ //#include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff1(void*, int, double, double); double __enzyme_autodiff2(void*, double); diff --git a/enzyme/test/Integration/ReverseMode/posix_memalign.c b/enzyme/test/Integration/ReverseMode/posix_memalign.c index 4217a69366dd..48aab315b95a 100644 --- a/enzyme/test/Integration/ReverseMode/posix_memalign.c +++ b/enzyme/test/Integration/ReverseMode/posix_memalign.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" int posix_memalign(void **memptr, size_t alignment, size_t size); diff --git a/enzyme/test/Integration/ReverseMode/posix_memalignfor.c 
b/enzyme/test/Integration/ReverseMode/posix_memalignfor.c index e938ebe44065..0a01ef095142 100644 --- a/enzyme/test/Integration/ReverseMode/posix_memalignfor.c +++ b/enzyme/test/Integration/ReverseMode/posix_memalignfor.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" int posix_memalign(void **memptr, size_t alignment, size_t size); diff --git a/enzyme/test/Integration/ReverseMode/readwriteread.c b/enzyme/test/Integration/ReverseMode/readwriteread.c index 84ca7955a1fc..ecfdce54d27e 100644 --- a/enzyme/test/Integration/ReverseMode/readwriteread.c +++ b/enzyme/test/Integration/ReverseMode/readwriteread.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/recurse.c b/enzyme/test/Integration/ReverseMode/recurse.c index 491d96f0c896..a53041e77188 100644 --- a/enzyme/test/Integration/ReverseMode/recurse.c +++ b/enzyme/test/Integration/ReverseMode/recurse.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/remat.c b/enzyme/test/Integration/ReverseMode/remat.c index aec85a951f9b..b228bcc13c99 100644 --- a/enzyme/test/Integration/ReverseMode/remat.c +++ b/enzyme/test/Integration/ReverseMode/remat.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, ...); void square(double** p_delv, double** p_e, int ** idx, int numReg, int numElemReg) { diff --git a/enzyme/test/Integration/ReverseMode/remat2.cpp b/enzyme/test/Integration/ReverseMode/remat2.cpp index a0b38526cdb1..67e720f6f159 100644 --- a/enzyme/test/Integration/ReverseMode/remat2.cpp +++ b/enzyme/test/Integration/ReverseMode/remat2.cpp @@ -21,7 +21,7 @@ #include -#include 
"test_utils.h" +#include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/rematSimple.c b/enzyme/test/Integration/ReverseMode/rematSimple.c index 6ab78f2252cd..7b95f9cb9506 100644 --- a/enzyme/test/Integration/ReverseMode/rematSimple.c +++ b/enzyme/test/Integration/ReverseMode/rematSimple.c @@ -7,7 +7,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, ...); void square(double* __restrict__ delv, double* __restrict__ e, unsigned long long numReg) { diff --git a/enzyme/test/Integration/ReverseMode/rwrloop.c b/enzyme/test/Integration/ReverseMode/rwrloop.c index 547d990a04c4..cdf9e3774553 100644 --- a/enzyme/test/Integration/ReverseMode/rwrloop.c +++ b/enzyme/test/Integration/ReverseMode/rwrloop.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rwrmeta.c b/enzyme/test/Integration/ReverseMode/rwrmeta.c index 7ccf2114e930..34a15d5c93d4 100644 --- a/enzyme/test/Integration/ReverseMode/rwrmeta.c +++ b/enzyme/test/Integration/ReverseMode/rwrmeta.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/simpleeigen-made-part.cpp b/enzyme/test/Integration/ReverseMode/simpleeigen-made-part.cpp index 9b1ca58070ad..eda0d4d6f715 100644 --- a/enzyme/test/Integration/ReverseMode/simpleeigen-made-part.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigen-made-part.cpp @@ -8,7 +8,7 @@ // TODO: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions %O0TBAA %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; using Eigen::VectorXd; diff --git a/enzyme/test/Integration/ReverseMode/simpleeigen-made.cpp 
b/enzyme/test/Integration/ReverseMode/simpleeigen-made.cpp index 211662ca14d9..4ffd79ba41a5 100644 --- a/enzyme/test/Integration/ReverseMode/simpleeigen-made.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigen-made.cpp @@ -10,7 +10,7 @@ // RUN: if [ %llvmver -lt 16 ]; then %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions %O0TBAA %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - ; fi #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; using Eigen::VectorXd; diff --git a/enzyme/test/Integration/ReverseMode/simpleeigen.cpp b/enzyme/test/Integration/ReverseMode/simpleeigen.cpp index cd098649d6b2..bcf3639a200e 100644 --- a/enzyme/test/Integration/ReverseMode/simpleeigen.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigen.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions %O0TBAA %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; using Eigen::VectorXd; diff --git a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-made-odd.cpp b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-made-odd.cpp index acf5b53ba0c0..2065c72bdf12 100644 --- a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-made-odd.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-made-odd.cpp @@ -8,7 +8,7 @@ // TODO: %clang++ -Xclang -new-struct-path-tbaa -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions %O0TBAA %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; diff --git a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-made.cpp b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-made.cpp index aecd4e7b73af..ceb60f239d50 100644 --- 
a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-made.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-made.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; diff --git a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-sum.cpp b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-sum.cpp index 247892c628e7..cff685397281 100644 --- a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-sum.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-sum.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; diff --git a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-sumsq.cpp b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-sumsq.cpp index ea5b310c4c6f..df01ab911767 100644 --- a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-sumsq.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-sumsq.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; diff --git a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-vec.cpp b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-vec.cpp index c30372e31135..b6aa27e3c6ed 100644 --- a/enzyme/test/Integration/ReverseMode/simpleeigenstatic-vec.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigenstatic-vec.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize 
-fno-exceptions -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; diff --git a/enzyme/test/Integration/ReverseMode/simpleeigenstatic.cpp b/enzyme/test/Integration/ReverseMode/simpleeigenstatic.cpp index 618a86f353a7..16cef553a8d6 100644 --- a/enzyme/test/Integration/ReverseMode/simpleeigenstatic.cpp +++ b/enzyme/test/Integration/ReverseMode/simpleeigenstatic.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" using Eigen::MatrixXd; diff --git a/enzyme/test/Integration/ReverseMode/smallrealloc.c b/enzyme/test/Integration/ReverseMode/smallrealloc.c index 9b65f134d1d5..51e29ccdfce5 100644 --- a/enzyme/test/Integration/ReverseMode/smallrealloc.c +++ b/enzyme/test/Integration/ReverseMode/smallrealloc.c @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" float __enzyme_autodiff(void*, float); diff --git a/enzyme/test/Integration/ReverseMode/sret.cpp b/enzyme/test/Integration/ReverseMode/sret.cpp index 0240e396ecb7..2b5315d5d35c 100644 --- a/enzyme/test/Integration/ReverseMode/sret.cpp +++ b/enzyme/test/Integration/ReverseMode/sret.cpp @@ -7,7 +7,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -#include "test_utils.h" +#include "../test_utils.h" #include #include #include diff --git a/enzyme/test/Integration/ReverseMode/subdoublestore.c b/enzyme/test/Integration/ReverseMode/subdoublestore.c index 4f9915284175..f411d98f3e7f 100644 --- a/enzyme/test/Integration/ReverseMode/subdoublestore.c +++ 
b/enzyme/test/Integration/ReverseMode/subdoublestore.c @@ -13,7 +13,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/sumtil.c b/enzyme/test/Integration/ReverseMode/sumtil.c index 4f6bc31c82ee..0a2b0502c2bc 100644 --- a/enzyme/test/Integration/ReverseMode/sumtil.c +++ b/enzyme/test/Integration/ReverseMode/sumtil.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); /*double max(double x, double y) { diff --git a/enzyme/test/Integration/ReverseMode/sumtil2.c b/enzyme/test/Integration/ReverseMode/sumtil2.c index d4af27b3e15f..aac316c7c4ea 100644 --- a/enzyme/test/Integration/ReverseMode/sumtil2.c +++ b/enzyme/test/Integration/ReverseMode/sumtil2.c @@ -11,7 +11,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); /*double max(double x, double y) { diff --git a/enzyme/test/Integration/ReverseMode/taylorlog.c b/enzyme/test/Integration/ReverseMode/taylorlog.c index 56c97e006c39..649dd4fff243 100644 --- a/enzyme/test/Integration/ReverseMode/taylorlog.c +++ b/enzyme/test/Integration/ReverseMode/taylorlog.c @@ -9,7 +9,7 @@ //#include -#include "test_utils.h" +#include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/test_utils.h b/enzyme/test/Integration/ReverseMode/test_utils.h deleted file mode 100644 index 8d15bfe10719..000000000000 --- a/enzyme/test/Integration/ReverseMode/test_utils.h +++ /dev/null @@ -1,39 +0,0 @@ -#include -#include -#include -#include - -extern -#ifdef __cplusplus -"C" -#endif -int enzyme_allocated, enzyme_const, enzyme_dup, enzyme_dupnoneed, enzyme_out, - enzyme_tape; - -/* -#ifdef __cplusplus -extern "C" { -#endif -static inline bool 
approx_fp_equality_float(float f1, float f2, double threshold) { - if (fabs(f1-f2) > threshold) return false; - return true; -} - -static inline bool approx_fp_equality_double(double f1, double f2, double threshold) { - if (fabs(f1-f2) > threshold) return false; - return true; -} -#ifdef __cplusplus -} -#endif -*/ - -#define APPROX_EQ(LHS, RHS, THRES) \ - { \ - if (__builtin_fabs((LHS) - (RHS)) > THRES) { \ - fprintf(stderr, "Assertion Failed: fabs( [%s = %g] - [%s = %g] ) > %g at %s:%d (%s)\n", \ - #LHS, (double)(LHS), #RHS, (double)(RHS), THRES, \ - __FILE__, __LINE__, __PRETTY_FUNCTION__); \ - abort(); \ - } \ - }; diff --git a/enzyme/test/Integration/ReverseMode/vecmax.c b/enzyme/test/Integration/ReverseMode/vecmax.c index 4900a7995739..24ae2793dbc9 100644 --- a/enzyme/test/Integration/ReverseMode/vecmax.c +++ b/enzyme/test/Integration/ReverseMode/vecmax.c @@ -7,7 +7,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); /*double max(double x, double y) { diff --git a/enzyme/test/Integration/ReverseMode/vecmax.cpp b/enzyme/test/Integration/ReverseMode/vecmax.cpp index 586fae3fe0f7..eda1af0acf9e 100644 --- a/enzyme/test/Integration/ReverseMode/vecmax.cpp +++ b/enzyme/test/Integration/ReverseMode/vecmax.cpp @@ -12,7 +12,7 @@ #include #include -#include "test_utils.h" +#include "../test_utils.h" extern void __enzyme_autodiff(void*, std::vector*, std::vector*); /*double max(double x, double y) { diff --git a/enzyme/test/Integration/ReverseMode/virtualshadow.cpp b/enzyme/test/Integration/ReverseMode/virtualshadow.cpp index 24bfa680cfe7..3511f839c8df 100644 --- a/enzyme/test/Integration/ReverseMode/virtualshadow.cpp +++ b/enzyme/test/Integration/ReverseMode/virtualshadow.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include 
"../test_utils.h" struct S { double (*fn)(double); double val; diff --git a/enzyme/test/Integration/ReverseMode/virtualshadow2.cpp b/enzyme/test/Integration/ReverseMode/virtualshadow2.cpp index 74b7fa88c474..0852b3d082dd 100644 --- a/enzyme/test/Integration/ReverseMode/virtualshadow2.cpp +++ b/enzyme/test/Integration/ReverseMode/virtualshadow2.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" struct S { double (*fn)(double); diff --git a/enzyme/test/Integration/ReverseMode/virtualshadow3.cpp b/enzyme/test/Integration/ReverseMode/virtualshadow3.cpp index 96366a6a3b98..53972e1851e0 100644 --- a/enzyme/test/Integration/ReverseMode/virtualshadow3.cpp +++ b/enzyme/test/Integration/ReverseMode/virtualshadow3.cpp @@ -4,7 +4,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include -#include "test_utils.h" +#include "../test_utils.h" void __enzyme_autodiff(...); diff --git a/enzyme/test/Integration/ReverseMode/wcout.cpp b/enzyme/test/Integration/ReverseMode/wcout.cpp index 7252b4dccae0..bc47cadb4dc8 100644 --- a/enzyme/test/Integration/ReverseMode/wcout.cpp +++ b/enzyme/test/Integration/ReverseMode/wcout.cpp @@ -8,7 +8,7 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S #include -#include "test_utils.h" +#include "../test_utils.h" #include #include diff --git a/enzyme/test/Integration/Sparse/test_utils.h b/enzyme/test/Integration/Sparse/test_utils.h deleted file mode 100644 index afcf87d4471b..000000000000 --- a/enzyme/test/Integration/Sparse/test_utils.h +++ /dev/null @@ -1,38 +0,0 @@ -#include -#include -#include -#include - -extern -#ifdef __cplusplus -"C" -#endif -int enzyme_allocated, enzyme_const, enzyme_dup, 
enzyme_dupnoneed, enzyme_out, - enzyme_tape; - -/* -#ifdef __cplusplus -extern "C" { -#endif -static inline bool approx_fp_equality_float(float f1, float f2, double threshold) { - if (fabs(f1-f2) > threshold) return false; - return true; -} - -static inline bool approx_fp_equality_double(double f1, double f2, double threshold) { - if (fabs(f1-f2) > threshold) return false; - return true; -} -#ifdef __cplusplus -} -#endif -*/ - -#define APPROX_EQ(LHS, RHS, THRES) \ - { \ - if (__builtin_fabs((LHS) - (RHS)) > THRES) { \ - fprintf(stderr, "Assertion Failed: fabs( [%s = %g] - [%s = %g] ) > %g at %s:%d (%s)\n", #LHS, (double)(LHS), #RHS, (double)(RHS), THRES, \ - __FILE__, __LINE__, __PRETTY_FUNCTION__); \ - abort(); \ - } \ - }; diff --git a/enzyme/test/Integration/ForwardMode/test_utils.h b/enzyme/test/Integration/test_utils.h similarity index 100% rename from enzyme/test/Integration/ForwardMode/test_utils.h rename to enzyme/test/Integration/test_utils.h diff --git a/enzyme/test/lit.cfg.py b/enzyme/test/lit.cfg.py index a9e799701eed..527dc9ce8e44 100644 --- a/enzyme/test/lit.cfg.py +++ b/enzyme/test/lit.cfg.py @@ -45,6 +45,7 @@ #llvm_config.add_tool_substitutions(tools, config.llvm_tools_dir) # opt knows whether it is compiled with -DNDEBUG. 
+""" import subprocess try: opt_cmd = subprocess.Popen([os.path.join(config.llvm_tools_dir, 'opt'), '-version'], @@ -68,3 +69,4 @@ except OSError: print("Could not find llvm-config in " + config.llvm_tools_dir) exit(42) +""" diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 3f4439e09584..bb6e6e7ac0c2 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -1,9 +1,13 @@ @LIT_SITE_CFG_IN_HEADER@ +import os + #config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_ver = "@LLVM_VERSION_MAJOR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_BINARY_DIR@" +if len("@ENZYME_BINARY_DIR@") == 0: + config.llvm_tools_dir = os.getcwd() + "/" + config.llvm_tools_dir config.llvm_libs_dir = "@LLVM_LIBS_DIR@" config.enzyme_obj_root = "@ENZYME_BINARY_DIR@" config.target_triple = "@TARGET_TRIPLE@" @@ -40,10 +44,22 @@ config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) config.substitutions.append(('%lli', config.llvm_tools_dir + "/lli" + (" --jit-kind=mcjit" if int(config.llvm_ver) >= 13 else "") )) config.substitutions.append(('%opt', config.llvm_tools_dir + "/opt")) -config.substitutions.append(('%eopt', config.enzyme_obj_root + "/Enzyme/MLIR/enzymemlir-opt")) + +eopt = config.enzyme_obj_root + "/Enzyme/MLIR/enzymemlir-opt" +if len("@ENZYME_BINARY_DIR@") == 0: + eopt = os.path.dirname(os.path.abspath(__file__)) + "/../enzymemlir-opt" + +eclang = config.llvm_tools_dir + "/clang" +if len("@ENZYME_BINARY_DIR@") == 0: + eclang = os.path.dirname(os.path.abspath(__file__)) + "/../enzyme-clang" + resource = config.llvm_tools_dir + "/../clang/staging" + eclang += " -resource-dir " + resource + " " + eclang += "-I " + os.path.dirname(os.path.abspath(__file__)) + "/Integration" + +config.substitutions.append(('%eopt', eopt)) config.substitutions.append(('%llvmver', config.llvm_ver)) config.substitutions.append(('%FileCheck', config.llvm_tools_dir + "/FileCheck")) 
-config.substitutions.append(('%clang', config.llvm_tools_dir + "/clang")) +config.substitutions.append(('%clang', eclang)) config.substitutions.append(('%O0TBAA', "-O1 -Xclang -disable-llvm-passes")) oldPM = ((" --enable-new-pm=0" if int(config.llvm_ver) >= 13 else "") @@ -53,6 +69,12 @@ newPM = ((" --enable-new-pm=1" if int(config.llvm_ver) in (12,13) else "") + ' -load-pass-plugin=@ENZYME_BINARY_DIR@/Enzyme/LLVMEnzyme-' + config.llvm_ver + config.llvm_shlib_ext + ' -load=@ENZYME_BINARY_DIR@/Enzyme/LLVMEnzyme-' + config.llvm_ver + config.llvm_shlib_ext + (" --enzyme-attributor=0" if int(config.llvm_ver) >= 13 else "")) +if len("@ENZYME_BINARY_DIR@") == 0: + oldPM = ((" --enable-new-pm=0" if int(config.llvm_ver) >= 13 else "") + + (" --enzyme-attributor=0" if int(config.llvm_ver) >= 13 else "")) + newPM = ((" --enable-new-pm=1" if int(config.llvm_ver) in (12,13) else "") + + (" --enzyme-attributor=0" if int(config.llvm_ver) >= 13 else "")) + oldPMOP = oldPM newPMOP = newPM if int(config.llvm_ver) >= 16: @@ -77,8 +99,16 @@ oldPM = (((" -fno-experimental-new-pass-manager" if int(config.llvm_ver) < 14 el newPM = ((" -fexperimental-new-pass-manager" if int(config.llvm_ver) < 13 else "") + ' -fpass-plugin=@ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext + ' -Xclang -load -Xclang @ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext) + +if len("@ENZYME_BINARY_DIR@") == 0: + oldPM = ((" -fno-experimental-new-pass-manager" if int(config.llvm_ver) < 14 else "-flegacy-pass-manager") if int(config.llvm_ver) >= 13 else "") + newPM = (" -fexperimental-new-pass-manager" if int(config.llvm_ver) < 13 else "") + config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < 15 else newPM)) config.substitutions.append(('%newLoadClangEnzyme', newPM)) # Let the main config do the real work. 
-lit_config.load_config(config, "@ENZYME_SOURCE_DIR@/test/lit.cfg.py") +cfgfile = "@ENZYME_SOURCE_DIR@/test/lit.cfg.py" +if len("@ENZYME_SOURCE_DIR@") == 0: + cfgfile = os.path.dirname(os.path.abspath(__file__)) + "/lit.cfg.py" +lit_config.load_config(config, cfgfile) From ff91980c5d4726b4dd7fc78d4d80dae0a0744569 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 24 Jan 2024 23:41:42 -0500 Subject: [PATCH 012/131] Bazel CI: cancel in progress --- .github/workflows/enzyme-bazel.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/enzyme-bazel.yml b/.github/workflows/enzyme-bazel.yml index 38e27e4edcc1..52260fa8b1d6 100644 --- a/.github/workflows/enzyme-bazel.yml +++ b/.github/workflows/enzyme-bazel.yml @@ -11,6 +11,10 @@ on: - main merge_group: +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: build-linux: name: Bazel ${{ matrix.build }} ${{ matrix.os }} From 8157f52200406deaf480838ae72b21aa3daed4ca Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 24 Jan 2024 23:42:20 -0500 Subject: [PATCH 013/131] Re-disable benchmark.yml (#1626) --- .github/workflows/benchmark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 2dadff967d32..46fa9edb1bed 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -23,7 +23,7 @@ jobs: matrix: llvm: ["13", "14", "15", "16"] build: ["Release", "Debug"] # "RelWithDebInfo" - os: [openstack22] + os: [openstack18] timeout-minutes: 120 steps: - name: add llvm From 22b693cb467a2f7a66812a36fb10e0dc43b03d77 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 24 Jan 2024 23:43:17 -0500 Subject: [PATCH 014/131] Invertpointer nicer error for global (#1625) --- enzyme/Enzyme/AdjointGenerator.h | 26 ++++++++++++++++++++++---- enzyme/Enzyme/EnzymeLogic.cpp | 18 +++++++----------- enzyme/Enzyme/GradientUtils.cpp | 22 
++++++++++++++-------- enzyme/Enzyme/Utils.cpp | 11 ++++++++--- enzyme/Enzyme/Utils.h | 13 +++++++++++++ 5 files changed, 64 insertions(+), 26 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 7dc91bec08d5..c6d558cf5ba0 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -5695,11 +5695,29 @@ class AdjointGenerator auto callval = call.getCalledOperand(); if (gutils->isConstantValue(callval)) { - llvm::errs() << *gutils->newFunc->getParent() << "\n"; - llvm::errs() << " orig: " << call << " callval: " << *callval << "\n"; + std::string s; + llvm::raw_string_ostream ss(s); + ss << *gutils->oldFunc << "\n"; + ss << "in Mode: " << to_string(Mode) << "\n"; + ss << " orig: " << call << " callval: " << *callval << "\n"; + ss << " constant function being called, but active call instruction\n"; + if (CustomErrorHandler) { + auto val = unwrap(CustomErrorHandler(ss.str().c_str(), wrap(&call), + ErrorType::NoDerivative, gutils, + nullptr, wrap(&Builder2))); + if (val) + newcalled = val; + else + newcalled = + UndefValue::get(gutils->getShadowType(callval->getType())); + } else { + EmitFailure("NoDerivative", call.getDebugLoc(), &call, ss.str()); + newcalled = + UndefValue::get(gutils->getShadowType(callval->getType())); + } + } else { + newcalled = lookup(gutils->invertPointerM(callval, Builder2), Builder2); } - assert(!gutils->isConstantValue(callval)); - newcalled = lookup(gutils->invertPointerM(callval, Builder2), Builder2); auto ft = call.getFunctionType(); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 5000ed34bcaa..2810b5370ac4 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3252,22 +3252,18 @@ void createInvertedTerminator(DiffeGradientUtils *gutils, } } if (!PNfloatType) { + std::string str; + raw_string_ostream ss(str); + ss << "Cannot deduce type of phi " << *orig; if (CustomErrorHandler) { - std::string str; - 
raw_string_ostream ss(str); - ss << "Cannot deduce type of phi " << *orig; CustomErrorHandler(str.c_str(), wrap(orig), ErrorType::NoType, &gutils->TR.analyzer, nullptr, wrap(&Builder)); continue; } else { - llvm::errs() << *gutils->oldFunc->getParent() << "\n"; - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() - << " for orig " << *orig << " saw " - << gutils->TR.intType(size, orig, /*necessary*/ false).str() - << " - " - << "\n"; - gutils->TR.intType(size, orig, /*necessary*/ true); + ss << "\n"; + gutils->TR.dump(ss); + EmitFailure("CannotDeduceType", orig->getDebugLoc(), orig, ss.str()); + continue; } } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index fbe293beaefe..423b6401e64f 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -5423,14 +5423,20 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, return shadow; } - llvm::errs() << *oldFunc->getParent() << "\n"; - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *arg << "\n"; - assert(0 && "cannot compute with global variable that doesn't have " - "marked shadow global"); - report_fatal_error("cannot compute with global variable that doesn't " - "have marked shadow global"); + std::string s; + llvm::raw_string_ostream ss(s); + ss << "cannot compute with global variable that doesn't have marked " + "shadow global\n"; + ss << *arg << "\n"; + if (CustomErrorHandler) { + return unwrap(CustomErrorHandler(ss.str().c_str(), wrap(arg), + ErrorType::NoShadow, this, nullptr, + wrap(&BuilderM))); + } else { + EmitFailure("InvertGlobal", BuilderM.getCurrentDebugLocation(), oldFunc, + ss.str()); + } + return UndefValue::get(getShadowType(arg->getType())); } auto md = arg->getMetadata("enzyme_shadow"); if (!isa(md)) { diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index a177dfc7fad4..25994fb2bd06 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ 
-439,8 +439,12 @@ CallInst *CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree) { EnzymeFailure::EnzymeFailure(const llvm::Twine &RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion) - : DiagnosticInfoUnsupported(*CodeRegion->getParent()->getParent(), - RemarkName, Loc) {} + : EnzymeFailure(RemarkName, Loc, CodeRegion->getParent()->getParent()) {} + +EnzymeFailure::EnzymeFailure(const llvm::Twine &RemarkName, + const llvm::DiagnosticLocation &Loc, + const llvm::Function *CodeRegion) + : DiagnosticInfoUnsupported(*CodeRegion, RemarkName, Loc) {} /// Convert a floating type to a string static inline std::string tofltstr(Type *T) { @@ -2627,7 +2631,8 @@ llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V, bool cublas) { CustomErrorHandler(ss.str().c_str(), nullptr, ErrorType::NoDerivative, nullptr, nullptr, nullptr); } else { - EmitFailure("unknown trans blas value", nullptr, nullptr, ss.str()); + EmitFailure("unknown trans blas value", B.getCurrentDebugLocation(), + B.GetInsertBlock()->getParent(), ss.str()); } return V; } diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 90043bdd3db3..7b46fd9a2e83 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -168,6 +168,8 @@ class EnzymeFailure final : public llvm::DiagnosticInfoUnsupported { public: EnzymeFailure(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion); + EnzymeFailure(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, + const llvm::Function *CodeRegion); }; template @@ -181,6 +183,17 @@ void EmitFailure(llvm::StringRef RemarkName, (EnzymeFailure("Enzyme: " + ss.str(), Loc, CodeRegion))); } +template +void EmitFailure(llvm::StringRef RemarkName, + const llvm::DiagnosticLocation &Loc, + const llvm::Function *CodeRegion, Args &...args) { + std::string *str = new std::string(); + llvm::raw_string_ostream ss(*str); + (ss << ... 
<< args); + CodeRegion->getContext().diagnose( + (EnzymeFailure("Enzyme: " + ss.str(), Loc, CodeRegion))); +} + static inline llvm::Function *isCalledFunction(llvm::Value *val) { if (llvm::CallInst *CI = llvm::dyn_cast(val)) { return CI->getCalledFunction(); From 2189e8948f39a39096a0238568e7333846675825 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 24 Jan 2024 23:46:36 -0500 Subject: [PATCH 015/131] [MLIR] Add tablegen derivative infrastructure (#1623) --- enzyme/BUILD | 15 + .../ArithAutoDiffOpInterfaceImpl.cpp | 141 +-- .../MLIR/Implementations/ArithDerivatives.td | 68 ++ .../MLIR/Implementations/CMakeLists.txt | 6 + .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 5 +- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 3 + enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 1 + .../MLIR/Interfaces/GradientUtilsReverse.cpp | 25 +- .../MLIR/Interfaces/GradientUtilsReverse.h | 5 + enzyme/test/MLIR/ReverseMode/pow.mlir | 4 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 1012 ++++++++++------- enzyme/tools/enzyme-tblgen/enzyme-tblgen.h | 1 + 12 files changed, 757 insertions(+), 529 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td diff --git a/enzyme/BUILD b/enzyme/BUILD index 42903efaef60..6278ec604b7e 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -375,6 +375,20 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) +gentbl( + name = "arith-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/ArithDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -396,6 +410,7 @@ cc_library( includes = ["Enzyme/MLIR", "Enzyme"], visibility = ["//visibility:public"], deps = [ + ":arith-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git 
a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp index b89b7a868e72..b6ceeb13e35d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp @@ -27,149 +27,12 @@ using namespace mlir; using namespace mlir::enzyme; namespace { -struct MulFOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - // Derivative of r = a * b -> dr = a * db + da * b - auto mulOp = cast(op); - if (!gutils->isConstantValue(mulOp)) { - mlir::Value res = nullptr; - for (int i = 0; i < 2; i++) { - if (!gutils->isConstantValue(mulOp.getOperand(i))) { - mlir::Value tmp = builder.create( - mulOp.getLoc(), - gutils->invertPointerM(mulOp.getOperand(i), builder), - gutils->getNewFromOriginal(mulOp.getOperand(1 - i))); - if (res == nullptr) - res = tmp; - else - res = builder.create(mulOp.getLoc(), res, tmp); - } - } - gutils->setDiffe(mulOp, res, builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - -struct AddFOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - // Derivative of r = a + b -> dr = da + db - auto addOp = cast(op); - if (!gutils->isConstantValue(addOp)) { - mlir::Value res = nullptr; - for (int i = 0; i < 2; i++) { - if (!gutils->isConstantValue(addOp.getOperand(i))) { - mlir::Value tmp = - gutils->invertPointerM(addOp.getOperand(i), builder); - if (res == nullptr) - res = tmp; - else - res = builder.create(addOp.getLoc(), res, tmp); - } - } - gutils->setDiffe(addOp, res, builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - -void addToGradient(Value oldGradient, Value addedGradient, OpBuilder &builder, - MGradientUtilsReverse 
*gutils) { - Value gradient = addedGradient; - if (gutils->hasInvertPointer(oldGradient)) { - Value operandGradient = gutils->invertPointerM(oldGradient, builder); - gradient = builder.create(oldGradient.getLoc(), - operandGradient, addedGradient); - } - gutils->mapInvertPointer(oldGradient, gradient, builder); -} - -struct AddFOpInterfaceReverse - : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { - // Derivative of r = a + b -> dr = da + db - auto addOp = cast(op); - - if (gutils->hasInvertPointer(addOp)) { - Value addedGradient = gutils->invertPointerM(addOp, builder); - addToGradient(addOp.getLhs(), addedGradient, builder, gutils); - addToGradient(addOp.getRhs(), addedGradient, builder, gutils); - } - } - - SmallVector cacheValues(Operation *op, - MGradientUtilsReverse *gutils) const { - return SmallVector(); - } - - void createShadowValues(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) const {} -}; - -struct MulFOpInterfaceReverse - : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { - auto mulOp = cast(op); - - if (gutils->hasInvertPointer(mulOp)) { - Value own_gradient = gutils->invertPointerM(mulOp, builder); - for (int i = 0; i < 2; i++) { - if (!gutils->isConstantValue(mulOp.getOperand(i))) { - Value cache = caches[i]; - Value retrievedValue = gutils->popCache(cache, builder); - Value addedGradient = builder.create( - mulOp.getLoc(), own_gradient, retrievedValue); - - addToGradient(mulOp.getOperand(i), addedGradient, builder, gutils); - } - } - } - } - - SmallVector cacheValues(Operation *op, - MGradientUtilsReverse *gutils) const { - auto mulOp = cast(op); - if (gutils->hasInvertPointer(mulOp)) { - OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); - SmallVector caches; 
- for (int i = 0; i < 2; i++) { - Value otherOperand = mulOp.getOperand((i + 1) % 2); - Value cache = gutils->initAndPushCache( - gutils->getNewFromOriginal(otherOperand), cacheBuilder); - caches.push_back(cache); - } - return caches; - } - return SmallVector(); - } - - void createShadowValues(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) const {} -}; - +#include "Implementations/ArithDerivatives.inc" } // namespace void mlir::enzyme::registerArithDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, arith::ArithDialect *) { - arith::AddFOp::attachInterface(*context); - arith::MulFOp::attachInterface(*context); - - arith::AddFOp::attachInterface(*context); - arith::MulFOp::attachInterface(*context); + registerInterfaces(context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td new file mode 100644 index 000000000000..bb713ef61799 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -0,0 +1,68 @@ +class MLIRDerivative resultOps> { + string dialect = dialect_; + string opName = opName_; + dag PatternToMatch = patternToMatch; + list ArgDerivatives = resultOps; +} + +class Operation { + bit usesPrimal = usesPrimal_; + bit usesShadow = usesShadow_; + bit usesCustom = usesCustom_; +} + +class DiffeRetIndex indices_> { + list indices = indices_; +} +def DiffeRet : DiffeRetIndex<[-1]>; + +class Inst : Operation { + string name = mnemonic; + string dialect = dialect_; +} +class ArithInst : Inst; + +def AddF : ArithInst<"arith::AddFOp">; +def SubF : ArithInst<"arith::SubFOp">; +def NegF : ArithInst<"arith::NegFOp">; +def MulF : ArithInst<"arith::MulFOp">; +def DivF : ArithInst<"arith::DivFOp">; +def RemF : ArithInst<"arith::RemFOp">; + +def CheckedMulF : ArithInst<"arith::MulFOp">; +def CheckedDivF : ArithInst<"arith::DivFOp">; + +def Op { +} + +def : MLIRDerivative<"arith", "AddFOp", (Op $x, $y), + [ 
+ (DiffeRet), + (DiffeRet), + ] + >; + +def : MLIRDerivative<"arith", "SubFOp", (Op $x, $y), + [ + (DiffeRet), + (NegF (DiffeRet)), + ] + >; +def : MLIRDerivative<"arith", "NegFOp", (Op $x), + [ + (NegF (DiffeRet)), + ] + >; +def : MLIRDerivative<"arith", "MulFOp", (Op $x, $y), + [ + (CheckedMulF (DiffeRet), $y), + (CheckedMulF (DiffeRet), $x) + ] + >; +def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y), + [ + (CheckedDivF (DiffeRet), $y), + (NegF (MulF (CheckedDivF (DiffeRet), $y), (DivF $x, $y))) + ] + // (CheckedDiv (FSub (SelectIfActive $x, (FMul (Shadow $x), $y), (Zero $x)), (SelectIfActive $y, (FMul (Shadow $y), $x), (Zero $y))), (FMul $y, $y)) + >; diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index fde10de2b02b..a41ee2133c68 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -1,3 +1,8 @@ + +set(LLVM_TARGET_DEFINITIONS ArithDerivatives.td) +enzyme_tablegen(ArithDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(ArithDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations ArithAutoDiffOpInterfaceImpl.cpp LLVMAutoDiffOpInterfaceImpl.cpp @@ -8,6 +13,7 @@ add_mlir_library(MLIREnzymeImplementations DEPENDS MLIRAutoDiffOpInterfaceIncGen + ArithDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index e4af7b760ce6..b9cbfe3e6913 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -276,8 +276,7 @@ void MEnzymeLogic::handlePredecessors( oBB->getPredecessors().end()) { // If there is only one block we can directly create a branch for // simplicity sake - auto bop = - revBuilder.create(loc, defaultBlock, defaultArguments); + revBuilder.create(loc, defaultBlock, defaultArguments); } else { Value cache = 
gutils->insertInit(gutils->getIndexCacheType()); Value flag = @@ -286,7 +285,7 @@ void MEnzymeLogic::handlePredecessors( SmallVector argumentRanges; for (const auto &a : arguments) argumentRanges.emplace_back(a); - auto bop = revBuilder.create( + revBuilder.create( loc, flag, defaultBlock, defaultArguments, ArrayRef(indices), ArrayRef(blocks), argumentRanges); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index fb29a2b7e83a..286456b2d039 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -140,6 +140,9 @@ Operation *mlir::enzyme::MGradientUtils::cloneWithNewOperands(OpBuilder &B, return B.clone(*op, map); } +bool mlir::enzyme::MGradientUtils::isConstantInstruction(Operation *op) const { + return activityAnalyzer->isConstantOperation(TR, op); +} bool mlir::enzyme::MGradientUtils::isConstantValue(Value v) const { return activityAnalyzer->isConstantValue(TR, v); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 8ab7a36d0066..9b9509b3ec22 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -57,6 +57,7 @@ class MGradientUtils { void eraseIfUnused(Operation *op, bool erase = true, bool check = true) { // TODO } + bool isConstantInstruction(mlir::Operation *v) const; bool isConstantValue(mlir::Value v) const; mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2); void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index c19dfc6f3cae..02a3d3d956d6 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -182,6 +182,11 @@ mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, 
return B.clone(*op, map); } +bool mlir::enzyme::MGradientUtilsReverse::isConstantInstruction( + Operation *op) const { + return false; +} + bool mlir::enzyme::MGradientUtilsReverse::isConstantValue(Value v) const { if (isa(v.getType())) return true; @@ -202,6 +207,24 @@ bool mlir::enzyme::MGradientUtilsReverse::requiresShadow(Type t) { return false; } +void mlir::enzyme::MGradientUtilsReverse::addToDiffe(Value oldGradient, + Value addedGradient, + OpBuilder &builder) { + // TODO + Value gradient = addedGradient; + if (hasInvertPointer(oldGradient)) { + Value operandGradient = invertPointerM(oldGradient, builder); + auto iface = cast(addedGradient.getType()); + gradient = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient, + addedGradient); + } + mapInvertPointer(oldGradient, gradient, builder); +} + +Value mlir::enzyme::MGradientUtilsReverse::diffe(Value v, OpBuilder &builder) { + return invertPointerM(v, builder); +} + /* The value v must have an invert pointer */ @@ -420,4 +443,4 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( Logic, newFunc, todiff, TA, invertedPointers, constant_values, nonconstant_values, retType, constant_args, originalToNew, originalToNewOps, mode_, width, symbolTable_); -} \ No newline at end of file +} diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index b88090481aae..474badb034c9 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -66,8 +66,13 @@ class MGradientUtilsReverse { // TODO } bool isConstantValue(mlir::Value v) const; + bool isConstantInstruction(mlir::Operation *v) const; bool hasInvertPointer(mlir::Value v); mlir::Value invertPointerM(mlir::Value v, OpBuilder &builder); + mlir::Value diffe(mlir::Value v, OpBuilder &builder); + + void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, + OpBuilder &builder); void 
mapInvertPointer(mlir::Value v, mlir::Value invertValue, OpBuilder &builder); diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir index 2db45ce570b3..5c5596ec389e 100644 --- a/enzyme/test/MLIR/ReverseMode/pow.mlir +++ b/enzyme/test/MLIR/ReverseMode/pow.mlir @@ -24,8 +24,8 @@ module { // Make sure the right values are being cached in the primal // CHECK: %[[one:.+]] = arith.constant 1.0 // CHECK: scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]]) // CHECK-NEXT: "enzyme.push"(%[[rcache:.+]], %[[r_it]]) +// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]]) // Ensure the right value is yielded in the adjoint // CHECK: "enzyme.set"(%[[rshadow:.+]], %[[dr]]) @@ -33,10 +33,10 @@ module { // CHECK: scf.for %[[iv:.+]] = %[[lb:.+]] to %[[ub:.+]] step %[[step:.+]] iter_args(%[[dr_it:.+]] = %[[dr]]) // CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_it]]) // CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) +// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) // CHECK-NEXT: %[[x:.+]] = "enzyme.pop"(%[[xcache]]) // CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] // CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_next]]) -// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) // CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] // CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow:.+]]) : // CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index d1733b52fa72..11384a2284e8 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -62,6 +62,8 @@ static cl::opt "Generate binaryoperator derivative")), cl::values(clEnumValN(InstDerivatives, "gen-inst-derivatives", "Generate instruction derivative")), + cl::values(clEnumValN(MLIRDerivatives, 
"gen-mlir-derivatives", + "Generate MLIR derivative")), cl::values(clEnumValN(CallDerivatives, "gen-call-derivatives", "Generate call derivative"))); @@ -195,15 +197,15 @@ struct VariableSetting { bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, Record *pattern, Init *resultTree, StringRef builder, VariableSetting &nameToOrdinal, bool lookup, - ArrayRef retidx, StringRef origName, - bool newFromOriginal); + ArrayRef retidx, StringRef origName, bool newFromOriginal, + ActionType intrinsic); SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, const Twine &argName, Record *pattern, DagInit *resultRoot, StringRef builder, VariableSetting &nameToOrdinal, bool lookup, ArrayRef retidx, StringRef origName, - bool newFromOriginal) { + bool newFromOriginal, ActionType intrinsic) { SmallVector vectorValued; size_t idx = 0; @@ -215,25 +217,32 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, auto [ord, vecValue] = nameToOrdinal.lookup(names->getValue(), pattern, resultRoot); if (!vecValue && !startsWith(ord, "local")) { - if (lookup) + if (lookup && intrinsic != MLIRDerivatives) os << "lookup("; - if (newFromOriginal) + if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) os << "gutils->getNewFromOriginal("; } - os << ord; + if (lookup && !vecValue && !startsWith(ord, "local") && + intrinsic == MLIRDerivatives) { + auto start = ord.find('(') + 1; + auto end = ord.find(')'); + os << "operands[" << ord.substr(start, end - start) << "]"; + } else { + os << ord; + } if (!vecValue && !startsWith(ord, "local")) { - if (newFromOriginal) + if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) os << ")"; - if (lookup) + if (lookup && intrinsic != MLIRDerivatives) os << ", " << builder << ")"; } os << ";\n"; vectorValued.push_back(vecValue); continue; } - vectorValued.push_back(handle(curIndent, argName + "_" + Twine(idx), os, - pattern, args, builder, nameToOrdinal, lookup, - retidx, origName, 
newFromOriginal)); + vectorValued.push_back(handle( + curIndent, argName + "_" + Twine(idx), os, pattern, args, builder, + nameToOrdinal, lookup, retidx, origName, newFromOriginal, intrinsic)); os << ";\n"; if (names) { auto name = names->getAsUnquotedString(); @@ -249,8 +258,8 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, Record *pattern, Init *resultTree, StringRef builder, VariableSetting &nameToOrdinal, bool lookup, - ArrayRef retidx, StringRef origName, - bool newFromOriginal) { + ArrayRef retidx, StringRef origName, bool newFromOriginal, + ActionType intrinsic) { if (DagInit *resultRoot = dyn_cast(resultTree)) { auto opName = resultRoot->getOperator()->getAsString(); auto Def = cast(resultRoot->getOperator())->getDef(); @@ -345,7 +354,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, } else handle(curIndent + INDENT, argPattern + "_vs", os, pattern, resultRoot->getArg(0), builder, nameToOrdinal, lookup, retidx, - origName, newFromOriginal); + origName, newFromOriginal, intrinsic); os << ")"; os << "->getElementCount()"; @@ -388,7 +397,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, vector = handle(curIndent + INDENT + INDENT, argPattern + "_sia_" + Twine(i), os, pattern, resultRoot->getArg(i), builder, nameToOrdinal, lookup, - retidx, origName, newFromOriginal); + retidx, origName, newFromOriginal, intrinsic); os << ";\n"; if (!vector) { @@ -591,7 +600,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, SmallVector vectorValued = prepareArgs( curIndent + INDENT, os, argPattern, pattern, resultRoot, builder, - nameToOrdinal, lookup, retidx, origName, newFromOriginal); + nameToOrdinal, lookup, retidx, origName, newFromOriginal, intrinsic); bool anyVector = false; for (auto b : vectorValued) anyVector |= b; @@ -674,7 +683,7 @@ bool handle(const Twine &curIndent, 
const Twine &argPattern, raw_ostream &os, os << curIndent << INDENT << "// Computing subroutine " << opName << "\n"; SmallVector vectorValued = prepareArgs( curIndent + INDENT, os, argPattern, pattern, resultRoot, builder, - nameToOrdinal, lookup, retidx, origName, newFromOriginal); + nameToOrdinal, lookup, retidx, origName, newFromOriginal, intrinsic); bool anyVector = false; for (auto b : vectorValued) anyVector |= b; @@ -769,7 +778,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, bool anyVector2 = handle(curIndent + INDENT, argPattern + "_sr", os, pattern, insts, builder, nnameToOrdinal, /*lookup*/ false, nretidx, - "", /*newFromOriginal*/ false); + "", /*newFromOriginal*/ false, intrinsic); (void)anyVector2; assert(anyVector == anyVector2); os << ";\n"; @@ -782,7 +791,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << curIndent << INDENT << "// Computing " << opName << "\n"; SmallVector vectorValued = prepareArgs( curIndent + INDENT, os, argPattern, pattern, resultRoot, builder, - nameToOrdinal, lookup, retidx, origName, newFromOriginal); + nameToOrdinal, lookup, retidx, origName, newFromOriginal, intrinsic); bool anyVector = false; for (auto b : vectorValued) anyVector |= b; @@ -795,18 +804,18 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, Def->getValueInit("func"), origName); } - if (anyVector) { + if (anyVector && intrinsic != MLIRDerivatives) { os << curIndent << INDENT << "Value *res = nullptr;\n"; os << curIndent << INDENT << "for(unsigned int idx=0, W=gutils->getWidth(); idx(op.getLoc(), "; } else { os << builder << ".Create" << opName << "("; } for (size_t i = 0; i < vectorValued.size(); i++) { if (i > 0) os << ", "; - if (vectorValued[i]) + if (vectorValued[i] && intrinsic != MLIRDerivatives) os << "(gutils->getWidth() == 1) ? 
" << argPattern << "_" << i << " : gutils->extractMeta(" << builder << ", " << argPattern << "_" << i << ", idx)"; @@ -857,53 +868,56 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, if (isCall) { os << curIndent << INDENT; - if (anyVector) + if (anyVector && intrinsic != MLIRDerivatives) os << INDENT; - os << "V->setDebugLoc(gutils->getNewFromOriginal(" << origName - << ".getDebugLoc()));" - "\n"; - os << curIndent << INDENT; - if (anyVector) - os << INDENT; - os << "V->setCallingConv(cconv);\n"; - for (auto *attr : *cast(Def->getValueAsListInit("fnattrs"))) { - auto attrDef = cast(attr)->getDef(); - auto attrName = attrDef->getValueAsString("name"); - if (attrName == "ReadNone") { - os << "#if LLVM_VERSION_MAJOR >= 16\n"; + if (intrinsic != MLIRDerivatives) { + os << "V->setDebugLoc(gutils->getNewFromOriginal(" << origName + << ".getDebugLoc()));" + "\n"; + os << curIndent << INDENT; + if (anyVector) + os << INDENT; + os << "V->setCallingConv(cconv);\n"; + for (auto *attr : + *cast(Def->getValueAsListInit("fnattrs"))) { + auto attrDef = cast(attr)->getDef(); + auto attrName = attrDef->getValueAsString("name"); + if (attrName == "ReadNone") { + os << "#if LLVM_VERSION_MAJOR >= 16\n"; + os << curIndent << INDENT; + if (anyVector) + os << INDENT; + os << "V->setOnlyReadsMemory();\n"; + os << "V->setOnlyWritesMemory();\n"; + os << "#elif LLVM_VERSION_MAJOR >= 14\n"; + } else if (attrName == "ReadOnly") { + os << "#if LLVM_VERSION_MAJOR >= 16\n"; + os << curIndent << INDENT; + if (anyVector) + os << INDENT; + os << "V->setOnlyReadsMemory();\n"; + os << "#elif LLVM_VERSION_MAJOR >= 14\n"; + } else + os << "#if LLVM_VERSION_MAJOR >= 14\n"; os << curIndent << INDENT; if (anyVector) os << INDENT; - os << "V->setOnlyReadsMemory();\n"; - os << "V->setOnlyWritesMemory();\n"; - os << "#elif LLVM_VERSION_MAJOR >= 14\n"; - } else if (attrName == "ReadOnly") { - os << "#if LLVM_VERSION_MAJOR >= 16\n"; + os << 
"V->addAttributeAtIndex(AttributeList::FunctionIndex, " + "Attribute::" + << attrName << ");\n"; + os << "#else \n"; + os << curIndent << INDENT; if (anyVector) os << INDENT; - os << "V->setOnlyReadsMemory();\n"; - os << "#elif LLVM_VERSION_MAJOR >= 14\n"; - } else - os << "#if LLVM_VERSION_MAJOR >= 14\n"; - os << curIndent << INDENT; - if (anyVector) - os << INDENT; - os << "V->addAttributeAtIndex(AttributeList::FunctionIndex, " - "Attribute::" - << attrName << ");\n"; - os << "#else \n"; - - os << curIndent << INDENT; - if (anyVector) - os << INDENT; - os << "V->addAttribute(AttributeList::FunctionIndex, " - "Attribute::" - << attrName << ");\n"; - os << "#endif \n"; + os << "V->addAttribute(AttributeList::FunctionIndex, " + "Attribute::" + << attrName << ");\n"; + os << "#endif \n"; + } } } - if (anyVector) { + if (anyVector && intrinsic != MLIRDerivatives) { os << curIndent << INDENT << INDENT << "if (gutils->getWidth() == 1) res = " "V;\n"; @@ -929,11 +943,259 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, PrintFatalError(pattern->getLoc(), Twine("unknown operation")); } +void handleUse( + DagInit *root, DagInit *resultTree, std::string &foundPrimalUse, + std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, + DagInit *tree, + StringMap> &varNameToCondition) { + auto opName = resultTree->getOperator()->getAsString(); + auto Def = cast(resultTree->getOperator())->getDef(); + if (opName == "DiffeRetIndex" || Def->isSubClassOf("DiffeRetIndex")) { + foundDiffRet = true; + return; + } + assert(Def->isSubClassOf("Operation")); + bool usesPrimal = Def->getValueAsBit("usesPrimal"); + bool usesShadow = Def->getValueAsBit("usesShadow"); + bool usesCustom = Def->getValueAsBit("usesCustom"); + + // We don't handle any custom primal/shadow + (void)usesCustom; + assert(!usesCustom); + + for (auto argEn : llvm::enumerate(resultTree->getArgs())) { + auto name = resultTree->getArgNameStr(argEn.index()); + + auto arg2 = 
dyn_cast(argEn.value()); + + if (arg2) { + // Recursive use of shadow is unhandled + assert(!usesShadow); + + std::string foundPrimalUse2 = ""; + std::string foundShadowUse2 = ""; + + bool foundDiffRet2 = false; + // We set precondition to be false (aka "") if we do not need the + // primal, since we are now only recurring to set variables + // correctly. + if (name.size() || usesPrimal) + handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse, + name.size() ? foundShadowUse2 : foundShadowUse, + name.size() ? foundDiffRet2 : foundDiffRet, + usesPrimal ? precondition : "", tree, varNameToCondition); + + if (name.size()) { + if (foundPrimalUse2.size() && + !(startsWith(foundPrimalUse, foundPrimalUse2) || + endsWith(foundPrimalUse, foundPrimalUse2))) { + if (foundPrimalUse.size() == 0) + foundPrimalUse = foundPrimalUse2; + else + foundPrimalUse += " || " + foundPrimalUse2; + } + if (foundShadowUse2.size() && + !(startsWith(foundShadowUse, foundShadowUse2) || + endsWith(foundShadowUse, foundShadowUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundShadowUse2; + else + foundShadowUse += " || " + foundShadowUse2; + } + foundDiffRet |= foundDiffRet2; + + varNameToCondition[name] = + std::make_tuple(foundPrimalUse2, foundShadowUse2, foundDiffRet2); + } + } else { + assert(name.size()); + + if (name.size()) { + auto found = varNameToCondition.find(name); + if (found == varNameToCondition.end()) { + llvm::errs() << "tree scope: " << *tree << "\n"; + llvm::errs() << "root scope: " << *root << "\n"; + llvm::errs() << "could not find var name: " << name << "\n"; + } + assert(found != varNameToCondition.end()); + } + + if (precondition.size()) { + auto [foundPrimalUse2, foundShadowUse2, foundDiffRet2] = + varNameToCondition[name]; + if (precondition != "true") { + if (foundPrimalUse2.size()) { + foundPrimalUse2 = + "((" + foundPrimalUse2 + ")&&(" + precondition + ")"; + } + if (foundShadowUse2.size()) { + foundShadowUse2 = + "((" + foundShadowUse2 + 
")&&(" + precondition + ")"; + } + } + if (usesPrimal) { + if (foundPrimalUse2.size() && + !(startsWith(foundPrimalUse, foundPrimalUse2) || + endsWith(foundPrimalUse, foundPrimalUse2))) { + if (foundPrimalUse.size() == 0) + foundPrimalUse = foundPrimalUse2; + else + foundPrimalUse += " || " + foundPrimalUse2; + } + if (foundShadowUse2.size() && + !(startsWith(foundShadowUse, foundShadowUse2) || + endsWith(foundShadowUse, foundShadowUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundShadowUse2; + else + foundShadowUse += " || " + foundShadowUse2; + } + foundDiffRet |= foundDiffRet2; + } + if (usesShadow) { + if (foundPrimalUse2.size() && + !(startsWith(foundShadowUse, foundPrimalUse2) || + endsWith(foundShadowUse, foundPrimalUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundPrimalUse2; + else + foundShadowUse += " || " + foundPrimalUse2; + } + assert(!foundDiffRet2); + assert(foundShadowUse2 == ""); + } + } + } + } +} + +void printDiffUse( + raw_ostream &os, Twine prefix, ListInit *argOps, StringRef origName, + ActionType intrinsic, DagInit *tree, + StringMap> &varNameToCondition) { + os << prefix << " // Rule " << *tree << "\n"; + + for (auto argOpEn : enumerate(*argOps)) { + size_t argIdx = argOpEn.index(); + if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + continue; + } + } + + // The condition necessary to require the use of the arg + std::string foundPrimalUse = ""; + std::string foundShadowUse = ""; + bool foundDiffRet = false; + + DagInit *resultTree = cast(argOpEn.value()); + + // hasDiffeRet(resultTree) + handleUse(resultTree, resultTree, foundPrimalUse, foundShadowUse, + foundDiffRet, /*precondition*/ "true", tree, varNameToCondition); + + os << prefix << " // Arg " << argIdx << " : " << *resultTree << "\n"; + + if 
(foundPrimalUse != "") { + + if (intrinsic == MLIRDerivatives) + os << prefix << " if (!gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))"; + else + os << prefix + << " if (!shadow && !gutils->isConstantValue(const_cast(" + << origName << "->getOperand(" << argIdx << ")))"; + + if (foundDiffRet) { + if (intrinsic == MLIRDerivatives) + os << " && !gutils->isConstantValue(" << origName + << "->getResult(0))"; + else + os << " && !gutils->isConstantValue(const_cast((const Value*)" + << origName << "))"; + } else { + if (intrinsic == MLIRDerivatives) + os << " && !gutils->isConstantInstruction(" << origName << ")"; + else + os << " && !gutils->isConstantInstruction(const_cast( " + << origName << "))"; + } + + os << ") {\n"; + os << prefix << " if (" << foundPrimalUse << ") {\n"; + if (intrinsic == MLIRDerivatives) + os << prefix << " used = true;\n"; + else { + os << prefix << " if (EnzymePrintDiffUse)\n"; + os << prefix + << " llvm::errs() << \"Need direct primal of \" << *val << "; + os << "\"in reverse from \" << *user << \" from condition " + << foundPrimalUse; + os << "\";\n"; + os << prefix << " return true;\n"; + } + os << prefix << " }\n"; + + os << prefix << " }\n"; + } + + if (intrinsic != MLIRDerivatives) { + os << prefix << " if (shadow && !gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))"; + + if (foundDiffRet) { + os << " && !gutils->isConstantValue(const_cast((const Value*)" + << origName << "))"; + } else { + os << " && !gutils->isConstantInstruction(const_cast( " + << origName << "))"; + } + + os << ") {\n"; + + os << prefix + << " if (qtype == QueryType::Shadow && (mode == " + "DerivativeMode::ForwardMode || mode == " + "DerivativeMode::ForwardModeSplit)) {\n"; + os << prefix + << " if (EnzymePrintDiffUse) llvm::errs() << \"Need forward " + "shadow of \" << *val << \" from condition \" << *user << " + "\"\\n\";\n"; + os << prefix << " return true;\n"; + os << prefix << " }\n"; + + if (foundShadowUse 
!= "") { + os << prefix << " if (" << foundShadowUse << ") {\n"; + os << prefix << " if (EnzymePrintDiffUse)\n"; + os << " llvm::errs() << \"Need direct shadow of \" << *val " + "<< "; + os << "\"in reverse from \" << *user << \" from condition " + << foundShadowUse; + os << "\";\n"; + os << prefix << " return true;\n"; + os << prefix << " }\n"; + } + + os << prefix << " }\n"; + } + } + + if (intrinsic != MLIRDerivatives) { + os << prefix << " return false;\n"; + os << prefix << "}\n"; + } +} + static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { emitSourceFileHeader("Rewriters", os); const char *patternNames = ""; switch (intrinsic) { + case MLIRDerivatives: + patternNames = "MLIRDerivative"; + break; case CallDerivatives: patternNames = "CallPattern"; break; @@ -957,7 +1219,9 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, for (Record *pattern : patterns) { DagInit *tree = pattern->getValueAsDag("PatternToMatch"); - DagInit *duals = pattern->getValueAsDag("ArgDuals"); + DagInit *duals = nullptr; + if (intrinsic != MLIRDerivatives) + duals = pattern->getValueAsDag("ArgDuals"); // Emit RewritePattern for Pattern. 
ListInit *argOps = pattern->getValueAsListInit("ArgDerivatives"); @@ -977,6 +1241,18 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasTA: case GenBlasDiffUse: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); + case MLIRDerivatives: { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "FwdDerivative : \n"; + os << " public AutoDiffOpInterface::ExternalModel<" + << opName << "FwdDerivative, " << dialect << "::" << opName << "> {\n"; + os << " LogicalResult createForwardModeTangent(Operation *op0, " + "OpBuilder &builder, MGradientUtils *gutils) const {\n"; + os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; + origName = "op"; + break; + } case CallDerivatives: { os << " if (("; bool prev = false; @@ -1147,7 +1423,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, nameToOrdinal.insert(tree->getNameStr(), (Twine("(&") + origName + ")").str(), false); - if (intrinsic != BinopDerivatives && intrinsic != InstDerivatives) { + if (intrinsic != BinopDerivatives && intrinsic != InstDerivatives && + intrinsic != MLIRDerivatives) { os << " if (gutils->knownRecomputeHeuristic.find(&" << origName << ") !=\n"; os << " gutils->knownRecomputeHeuristic.end()) {\n"; @@ -1160,30 +1437,46 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " }\n"; os << " }\n"; } - os << " eraseIfUnused(" << origName << ");\n"; - os << " if (gutils->isConstantInstruction(&" << origName << "))\n"; - if (intrinsic == IntrDerivatives || intrinsic == CallDerivatives) - os << " return true;\n"; + if (intrinsic != MLIRDerivatives) + os << " eraseIfUnused(" << origName << ");\n"; else - os << " return;\n"; + os << " gutils->eraseIfUnused(" << origName << ");\n"; + + if (intrinsic == MLIRDerivatives) { + os << " if (gutils->isConstantInstruction(op))\n"; + os << " return 
success();\n"; + } else { + os << " if (gutils->isConstantInstruction(&" << origName << "))\n"; + if (intrinsic == IntrDerivatives || intrinsic == CallDerivatives) + os << " return true;\n"; + else + os << " return;\n"; - os << " switch (Mode) {\n"; - os << " case DerivativeMode::ForwardModeSplit:\n"; - os << " case DerivativeMode::ForwardMode:{\n"; - os << " IRBuilder<> Builder2(&" << origName << ");\n"; - os << " getForwardBuilder(Builder2);\n"; + os << " switch (Mode) {\n"; + os << " case DerivativeMode::ForwardModeSplit:\n"; + os << " case DerivativeMode::ForwardMode:{\n"; + os << " IRBuilder<> Builder2(&" << origName << ");\n"; + os << " getForwardBuilder(Builder2);\n"; + } // TODO - if (duals->getOperator()->getAsString() == + if (!duals || + duals->getOperator()->getAsString() == "ForwardFromSummedReverseInternal" || cast(duals->getOperator()) ->getDef() ->isSubClassOf("ForwardFromSummedReverseInternal")) { - os << " Value *res = Constant::getNullValue(gutils->getShadowType(" - << origName - << "." - "getType()));\n"; + + if (intrinsic == MLIRDerivatives) { + os << " mlir::Value res = nullptr;\n"; + } else { + os << " Value *res = " + "Constant::getNullValue(gutils->getShadowType(" + << origName + << "." 
+ "getType()));\n"; + } for (auto argOpEn : enumerate(*argOps)) { size_t argIdx = argOpEn.index(); @@ -1201,12 +1494,19 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } } - os << curIndent << "if (!gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; - os << curIndent << INDENT << "Value *dif = diffe(" << origName - << ".getOperand(" << argIdx << "), Builder2);\n"; - os << curIndent << INDENT - << "Value *arg_diff_tmp = UndefValue::get(res->getType());\n"; + if (intrinsic == MLIRDerivatives) { + os << curIndent << "if (!gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + os << curIndent << INDENT << "auto dif = gutils->invertPointerM(" + << origName << "->getOperand(" << argIdx << "), builder);\n"; + } else { + os << curIndent << "if (!gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; + os << curIndent << INDENT << "Value *dif = diffe(" << origName + << ".getOperand(" << argIdx << "), Builder2);\n"; + os << curIndent << INDENT + << "Value *arg_diff_tmp = UndefValue::get(res->getType());\n"; + } initializeNames(Twine(curIndent) + INDENT, os, argOpEn.value(), "local"); @@ -1225,29 +1525,47 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, return; } os << curIndent << INDENT << "{\n"; - os << curIndent << INDENT << INDENT << "Value *itmp = "; + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << INDENT << "mlir::Value itmp = "; + else + os << curIndent << INDENT << INDENT << "Value *itmp = "; ArrayRef retidx{}; bool vectorValued = handle( Twine(curIndent) + INDENT + INDENT, "fwdarg", os, pattern, - resultTree, "Builder2", nameToOrdinal, /*lookup*/ false, - retidx, origName, /*newFromOriginal*/ true); + resultTree, + (intrinsic == MLIRDerivatives) ? 
"builder" : "Builder2", + nameToOrdinal, /*lookup*/ false, retidx, origName, + /*newFromOriginal*/ true, intrinsic); os << ";\n"; (void)vectorValued; assert(vectorValued); - os << curIndent << INDENT << INDENT - << "arg_diff_tmp = GradientUtils::recursiveFAdd(Builder2,"; - os << "res, itmp, {"; - { - bool seen = false; - for (auto i : idx) { - if (seen) - os << ", "; - os << i; - seen = true; + if (intrinsic == MLIRDerivatives) { + os << curIndent << INDENT << INDENT + << "if (!res) res = itmp;\n"; + os << curIndent << INDENT << INDENT << "else {\n"; + os << curIndent << INDENT << INDENT << INDENT + << "auto operandType = " + "cast(res.getType());\n"; + os << curIndent << INDENT << INDENT << INDENT + << "res = operandType.createAddOp(builder, op.getLoc(), " + "res, itmp);\n"; + os << curIndent << INDENT << INDENT << "}\n"; + } else { + os << curIndent << INDENT << INDENT + << "arg_diff_tmp = GradientUtils::recursiveFAdd(Builder2,"; + os << "res, itmp, {"; + { + bool seen = false; + for (auto i : idx) { + if (seen) + os << ", "; + os << i; + seen = true; + } } - } - os << "}, {}, arg_diff_tmp, gutils->getWidth() != 1);\n"; + os << "}, {}, arg_diff_tmp, gutils->getWidth() != 1);\n"; + } os << curIndent << INDENT << "}\n"; } else if (ListInit *lst = dyn_cast(ival)) { unsigned i = 0; @@ -1262,7 +1580,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, Twine("Unknown subinitialization")); }; fwdres({}, argOpEn.value()); - os << curIndent << INDENT << "res = arg_diff_tmp;\n"; + if (intrinsic != MLIRDerivatives) + os << curIndent << INDENT << "res = arg_diff_tmp;\n"; os << " }\n"; } } else { @@ -1272,23 +1591,105 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, bool vectorValued = handle(" ", "fwdnsrarg", os, pattern, duals, "Builder2", nameToOrdinal, /*lookup*/ false, retidx, origName, - /*newFromOriginal*/ true); + /*newFromOriginal*/ true, intrinsic); (void)vectorValued; assert(vectorValued); os << ";\n"; 
} os << " assert(res);\n"; - os << " setDiffe(&" << origName << ", res, Builder2);\n"; - os << " break;\n"; + if (intrinsic == MLIRDerivatives) { + os << " gutils->setDiffe(" << origName + << "->getResult(0), res, builder);\n"; + os << " return success();\n"; + } else { + os << " setDiffe(&" << origName << ", res, Builder2);\n"; + os << " break;\n"; + } os << " }\n"; - os << " case DerivativeMode::ReverseModeGradient:\n"; - os << " case DerivativeMode::ReverseModeCombined:{\n"; - os << " IRBuilder<> Builder2(&" << origName << ");\n"; - os << " getReverseBuilder(Builder2);\n"; + if (intrinsic != MLIRDerivatives) { + os << " case DerivativeMode::ReverseModeGradient:\n"; + os << " case DerivativeMode::ReverseModeCombined:{\n"; + os << " IRBuilder<> Builder2(&" << origName << ");\n"; + os << " getReverseBuilder(Builder2);\n"; + os << " Value *dif = nullptr;\n"; + } else { + os << "};\n"; + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "RevDerivative : \n"; + os << " public " + "ReverseAutoDiffOpInterface::ExternalModel<" + << opName << "RevDerivative, " << dialect << "::" << opName << "> {\n"; + os << " SmallVector cachedArguments(Operation *op,\n"; + os << " MGradientUtilsReverse *gutils) " + "const {\n"; + os << " SmallVector toret(op->getNumOperands(), false);\n"; + StringMap> varNameToCondition; + + std::function)> insert = + [&](DagInit *ptree, ArrayRef prev) { + for (auto treeEn : llvm::enumerate(ptree->getArgs())) { + auto tree = treeEn.value(); + auto name = ptree->getArgNameStr(treeEn.index()); + SmallVector next(prev.begin(), prev.end()); + next.push_back(treeEn.index()); + if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (name.size()) { + varNameToCondition[name] = std::make_tuple( + "idx == " + std::to_string(treeEn.index()), "", false); + } + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + varNameToCondition[tree->getNameStr()] = + 
std::make_tuple("ILLEGAL", "ILLEGAL", false); + + os << " for (size_t idx=0; idxgetNumOperands(); idx++) {\n"; + os << " bool used = false;\n"; + printDiffUse(os, " ", argOps, origName, intrinsic, tree, + varNameToCondition); + os << " toret[idx] = used;\n"; + os << " }\n"; + os << " return toret;\n"; + os << " }\n"; + + os << " SmallVector cacheValues(Operation *op,\n"; + os << " MGradientUtilsReverse *gutils) " + "const {\n"; + os << " if (gutils->isConstantInstruction(op) || " + "gutils->isConstantValue(op->getResult(0))) return {};\n"; + os << " auto neededArgs = cachedArguments(op, gutils);\n"; + os << " SmallVector toret;\n"; + os << " OpBuilder builder(gutils->getNewFromOriginal(op));\n"; + os << " for (auto en : llvm::enumerate(neededArgs))\n"; + os << " if (en.value()) {\n"; + os << " Value cache = " + "gutils->initAndPushCache(gutils->getNewFromOriginal(op->" + "getOperand(en.index())), builder);\n"; + os << " toret.push_back(cache);\n"; + os << " }\n"; + os << " return toret;\n"; + os << " }\n"; + os << "\n"; + os << " void createShadowValues(Operation *op, OpBuilder &builder,\n"; + os << " MGradientUtilsReverse *gutils) const " + "{}\n"; + + os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " + "&builder,\n"; + os << " MGradientUtilsReverse *gutils,\n"; + os << " SmallVector caches) const {\n"; + os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; + os << " mlir::Value dif = nullptr;\n"; + } // TODO vector - os << " Value *dif = nullptr;\n"; bool seen = false; for (auto argOpEn : enumerate(*argOps)) { size_t argIdx = argOpEn.index(); @@ -1308,22 +1709,44 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, if (seen) os << "} else "; seen = true; - os << "if (!dif && !gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; + if (intrinsic == MLIRDerivatives) { + os << "if (!dif && !gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + } 
else { + os << "if (!dif && !gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; + } DagInit *resultTree = cast(argOpEn.value()); if (hasDiffeRet(resultTree)) { - os << " dif = diffe(&" << origName << ", Builder2);\n"; - os << " setDiffe(&" << origName - << ", " - "Constant::getNullValue(gutils->getShadowType(" - << origName - << ".getType())), " - "Builder2);\n"; + if (intrinsic == MLIRDerivatives) { + os << " dif = gutils->diffe(" << origName << ", builder);\n"; + os << " gutils->clearValue(" << origName << ", builder);\n"; + } else { + os << " dif = diffe(&" << origName << ", Builder2);\n"; + os << " setDiffe(&" << origName + << ", " + "Constant::getNullValue(gutils->getShadowType(" + << origName + << ".getType())), " + "Builder2);\n"; + } } } if (seen) os << " }\n"; + if (intrinsic == MLIRDerivatives) { + os << " SmallVector operands(op->getNumOperands(), nullptr);\n"; + os << " auto neededArgs = cachedArguments(op, gutils);\n"; + os << " size_t count = 0;\n"; + os << " for (auto en : llvm::enumerate(neededArgs))\n"; + os << " if (en.value()) {\n"; + os << " operands[en.index()] = " + "gutils->popCache(caches[count], builder);\n"; + os << " count++;\n"; + os << " }\n"; + } + std::function, Init *)> revres = [&](size_t argIdx, ArrayRef idx, Init *ival) { if (DagInit *resultTree = dyn_cast(ival)) { @@ -1340,40 +1763,47 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } const char *curIndent = " "; os << curIndent << "{\n"; - os << curIndent << INDENT << "Value *tmp = "; - bool vectorValued = - handle(Twine(curIndent) + INDENT, "revarg", os, pattern, - resultTree, "Builder2", nameToOrdinal, /*lookup*/ true, - idx, origName, /*newFromOriginal*/ true); + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value tmp = "; + else + os << curIndent << INDENT << "Value *tmp = "; + bool vectorValued = handle( + Twine(curIndent) + INDENT, "revarg", os, pattern, resultTree, + (intrinsic == 
MLIRDerivatives) ? "builder" : "Builder2", + nameToOrdinal, /*lookup*/ true, idx, origName, + /*newFromOriginal*/ true, intrinsic); os << ";\n"; - os << curIndent << INDENT - << "Value *out = " - "UndefValue::get(gutils->getShadowType(" - << origName << ".getOperand(" << argIdx << ")->getType()));\n"; - - os << curIndent << INDENT - << "for(unsigned int idx=0, W=gutils->getWidth(); " - "idxgetWidth() == " - "1 ? toadd : gutils->extractMeta(Builder2, toadd, idx)) : " - "nullptr;\n"; - os << curIndent << INDENT << INDENT << "Value *next = tmp;\n"; - if (vectorValued) - os << curIndent << INDENT << INDENT - << "if (gutils->getWidth() > 1) next = " - "gutils->extractMeta(Builder2, next, idx);\n"; - os << curIndent << INDENT << INDENT - << "if (prev) next = Builder2.CreateFAdd(prev, " - "next);\n"; - os << curIndent << INDENT << INDENT - << "out = (gutils->getWidth() > 1) ? " - "Builder2.CreateInsertValue(out, next, idx) : next;\n"; - os << curIndent << INDENT << "}\n"; - os << curIndent << INDENT << "toadd = out;\n"; + if (intrinsic == MLIRDerivatives) { + os << "assert(toadd == nullptr); toadd = tmp;\n"; + } else { + os << curIndent << INDENT + << "Value *out = " + "UndefValue::get(gutils->getShadowType(" + << origName << ".getOperand(" << argIdx << ")->getType()));\n"; + os << curIndent << INDENT + << "for(unsigned int idx=0, W=gutils->getWidth(); " + "idxgetWidth() == " + "1 ? toadd : gutils->extractMeta(Builder2, toadd, idx)) : " + "nullptr;\n"; + os << curIndent << INDENT << INDENT << "Value *next = tmp;\n"; + if (vectorValued) + os << curIndent << INDENT << INDENT + << "if (gutils->getWidth() > 1) next = " + "gutils->extractMeta(Builder2, next, idx);\n"; + os << curIndent << INDENT << INDENT + << "if (prev) next = Builder2.CreateFAdd(prev, " + "next);\n"; + os << curIndent << INDENT << INDENT + << "out = (gutils->getWidth() > 1) ? 
" + "Builder2.CreateInsertValue(out, next, idx) : next;\n"; + os << curIndent << INDENT << "}\n"; + os << curIndent << INDENT << "toadd = out;\n"; + } os << curIndent << "}\n"; } else if (ListInit *lst = dyn_cast(ival)) { @@ -1400,32 +1830,61 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } const char *curIndent = " "; - os << curIndent << "if (!gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; + if (intrinsic == MLIRDerivatives) + os << curIndent << "if (!gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + else + os << curIndent << "if (!gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; initializeNames(Twine(curIndent) + INDENT, os, argOpEn.value(), "local"); - os << curIndent << INDENT << "Value *toadd = nullptr;\n"; + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value toadd = nullptr;\n"; + else + os << curIndent << INDENT << "Value *toadd = nullptr;\n"; revres(argIdx, {}, argOpEn.value()); - os << curIndent << INDENT << "if (toadd) addToDiffe(" << origName - << ".getOperand(" << argIdx << "), toadd"; - os << ", Builder2, " << origName << ".getOperand(" << argIdx - << ")->getType());\n"; + if (intrinsic == MLIRDerivatives) { + os << curIndent << INDENT << "if (toadd) gutils->addToDiffe(" + << origName << "->getOperand(" << argIdx << "), toadd, builder);\n"; + } else { + os << curIndent << INDENT << "if (toadd) addToDiffe(" << origName + << ".getOperand(" << argIdx << "), toadd"; + os << ", Builder2, " << origName << ".getOperand(" << argIdx + << ")->getType());\n"; + } os << curIndent << "}\n"; } - os << " break;\n"; - os << " }\n"; + if (intrinsic != MLIRDerivatives) { + os << " break;\n"; + os << " }\n"; - os << " case DerivativeMode::ReverseModePrimal:{\n"; - // TODO - os << " break;\n"; - os << " }\n"; - os << " }\n"; + os << " case DerivativeMode::ReverseModePrimal:{\n"; + // TODO + os << " break;\n"; 
+ os << " }\n"; + os << " }\n"; + } if (intrinsic == IntrDerivatives || intrinsic == CallDerivatives) os << " return true;\n }\n"; else os << " return;\n }\n"; + if (intrinsic == MLIRDerivatives) + os << "};\n\n"; + } + + if (intrinsic == MLIRDerivatives) { + os << "void registerInterfaces(MLIRContext* context) {\n"; + for (Record *pattern : patterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "FwdDerivative>(*context);\n"; + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "RevDerivative>(*context);\n"; + } + os << "}\n"; } } @@ -1433,6 +1892,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { const char *patternNames; switch (intrinsic) { + case MLIRDerivatives: case GenBlasDerivatives: case UpdateBlasDecl: case UpdateBlasTA: @@ -1470,6 +1930,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, std::string origName; std::string prefix; switch (intrinsic) { + case MLIRDerivatives: case GenBlasDerivatives: case UpdateBlasDecl: case UpdateBlasTA: @@ -1592,228 +2053,8 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, varNameToCondition[tree->getNameStr()] = std::make_tuple("ILLEGAL", "ILLEGAL", false); - std::function - handleUse = [&](DagInit *root, DagInit *resultTree, - StringTy &foundPrimalUse, StringTy &foundShadowUse, - bool &foundDiffRet, std::string precondition) { - auto opName = resultTree->getOperator()->getAsString(); - auto Def = cast(resultTree->getOperator())->getDef(); - if (opName == "DiffeRetIndex" || Def->isSubClassOf("DiffeRetIndex")) { - foundDiffRet = true; - return; - } - assert(Def->isSubClassOf("Operation")); - bool usesPrimal = Def->getValueAsBit("usesPrimal"); - bool usesShadow = Def->getValueAsBit("usesShadow"); - bool usesCustom = Def->getValueAsBit("usesCustom"); - - // We don't 
handle any custom primal/shadow - (void)usesCustom; - assert(!usesCustom); - - for (auto argEn : llvm::enumerate(resultTree->getArgs())) { - auto name = resultTree->getArgNameStr(argEn.index()); - - auto arg2 = dyn_cast(argEn.value()); - - if (arg2) { - // Recursive use of shadow is unhandled - assert(!usesShadow); - - StringTy foundPrimalUse2 = ""; - StringTy foundShadowUse2 = ""; - - bool foundDiffRet2 = false; - // We set precondition to be false (aka "") if we do not need the - // primal, since we are now only recurring to set variables - // correctly. - if (name.size() || usesPrimal) - handleUse(root, arg2, - name.size() ? foundPrimalUse2 : foundPrimalUse, - name.size() ? foundShadowUse2 : foundShadowUse, - name.size() ? foundDiffRet2 : foundDiffRet, - usesPrimal ? precondition : ""); - - if (name.size()) { - if (foundPrimalUse2.size() && - !(startsWith(foundPrimalUse, foundPrimalUse2) || - endsWith(foundPrimalUse, foundPrimalUse2))) { - if (foundPrimalUse.size() == 0) - foundPrimalUse = foundPrimalUse2; - else - foundPrimalUse += " || " + foundPrimalUse2; - } - if (foundShadowUse2.size() && - !(startsWith(foundShadowUse, foundShadowUse2) || - endsWith(foundShadowUse, foundShadowUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundShadowUse2; - else - foundShadowUse += " || " + foundShadowUse2; - } - foundDiffRet |= foundDiffRet2; - - varNameToCondition[name] = std::make_tuple( - foundPrimalUse2, foundShadowUse2, foundDiffRet2); - } - } else { - assert(name.size()); - - if (name.size()) { - auto found = varNameToCondition.find(name); - if (found == varNameToCondition.end()) { - llvm::errs() << "tree scope: " << *tree << "\n"; - llvm::errs() << "root scope: " << *root << "\n"; - llvm::errs() << "could not find var name: " << name << "\n"; - } - assert(found != varNameToCondition.end()); - } - - if (precondition.size()) { - auto [foundPrimalUse2, foundShadowUse2, foundDiffRet2] = - varNameToCondition[name]; - if (precondition != "true") { - if 
(foundPrimalUse2.size()) { - foundPrimalUse2 = - "((" + foundPrimalUse2 + ")&&(" + precondition + ")"; - } - if (foundShadowUse2.size()) { - foundShadowUse2 = - "((" + foundShadowUse2 + ")&&(" + precondition + ")"; - } - } - if (usesPrimal) { - if (foundPrimalUse2.size() && - !(startsWith(foundPrimalUse, foundPrimalUse2) || - endsWith(foundPrimalUse, foundPrimalUse2))) { - if (foundPrimalUse.size() == 0) - foundPrimalUse = foundPrimalUse2; - else - foundPrimalUse += " || " + foundPrimalUse2; - } - if (foundShadowUse2.size() && - !(startsWith(foundShadowUse, foundShadowUse2) || - endsWith(foundShadowUse, foundShadowUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundShadowUse2; - else - foundShadowUse += " || " + foundShadowUse2; - } - foundDiffRet |= foundDiffRet2; - } - if (usesShadow) { - if (foundPrimalUse2.size() && - !(startsWith(foundShadowUse, foundPrimalUse2) || - endsWith(foundShadowUse, foundPrimalUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundPrimalUse2; - else - foundShadowUse += " || " + foundPrimalUse2; - } - assert(!foundDiffRet2); - assert(foundShadowUse2 == ""); - } - } - } - } - }; - - os << prefix << " // Rule " << *tree << "\n"; - - for (auto argOpEn : enumerate(*argOps)) { - size_t argIdx = argOpEn.index(); - if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { - auto opName = resultRoot->getOperator()->getAsString(); - auto Def = cast(resultRoot->getOperator())->getDef(); - if (opName == "InactiveArgSpec" || - Def->isSubClassOf("InactiveArgSpec")) { - continue; - } - } - - // The condition necessary to require the use of the arg - StringTy foundPrimalUse = ""; - StringTy foundShadowUse = ""; - bool foundDiffRet = false; - - DagInit *resultTree = cast(argOpEn.value()); - - // hasDiffeRet(resultTree) - handleUse(resultTree, resultTree, foundPrimalUse, foundShadowUse, - foundDiffRet, /*precondition*/ "true"); - - os << prefix << " // Arg " << argIdx << " : " << *resultTree << "\n"; - - if (foundPrimalUse != 
"") { - - os << prefix - << " if (!shadow && !gutils->isConstantValue(const_cast(" - << origName << "->getOperand(" << argIdx << ")))"; - - if (foundDiffRet) { - os << " && !gutils->isConstantValue(const_cast((const Value*)" - << origName << "))"; - } else { - os << " && !gutils->isConstantInstruction(const_cast( " - << origName << "))"; - } - - os << ") {\n"; - os << prefix << " if (" << foundPrimalUse << ") {\n"; - os << prefix << " if (EnzymePrintDiffUse)\n"; - os << prefix - << " llvm::errs() << \"Need direct primal of \" << *val << "; - os << "\"in reverse from \" << *user << \" from condition " - << foundPrimalUse; - os << "\";\n"; - os << prefix << " return true;\n"; - os << prefix << " }\n"; - - os << prefix << " }\n"; - } - - os << prefix << " if (shadow && !gutils->isConstantValue(" << origName - << "->getOperand(" << argIdx << "))"; - - if (foundDiffRet) { - os << " && !gutils->isConstantValue(const_cast((const Value*)" - << origName << "))"; - } else { - os << " && !gutils->isConstantInstruction(const_cast( " - << origName << "))"; - } - - os << ") {\n"; - - os << prefix - << " if (qtype == QueryType::Shadow && (mode == " - "DerivativeMode::ForwardMode || mode == " - "DerivativeMode::ForwardModeSplit)) {\n"; - os << prefix - << " if (EnzymePrintDiffUse) llvm::errs() << \"Need forward " - "shadow of \" << *val << \" from condition \" << *user << " - "\"\\n\";\n"; - os << prefix << " return true;\n"; - os << prefix << " }\n"; - - if (foundShadowUse != "") { - os << prefix << " if (" << foundShadowUse << ") {\n"; - os << prefix << " if (EnzymePrintDiffUse)\n"; - os << " llvm::errs() << \"Need direct shadow of \" << *val " - "<< "; - os << "\"in reverse from \" << *user << \" from condition " - << foundShadowUse; - os << "\";\n"; - os << prefix << " return true;\n"; - os << prefix << " }\n"; - } - - os << prefix << " }\n"; - } - - os << prefix << " return false;\n"; - os << prefix << "}\n"; + printDiffUse(os, prefix, argOps, origName, intrinsic, tree, + 
varNameToCondition); } } @@ -1821,8 +2062,11 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, #include "blasDiffUseUpdater.h" #include "blasTAUpdater.h" +void emitMLIRDerivatives(RecordKeeper &records, raw_ostream &os); + static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { switch (action) { + case MLIRDerivatives: case CallDerivatives: case InstDerivatives: case IntrDerivatives: diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h index dd3484d7b151..368644ba0b5d 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h @@ -15,6 +15,7 @@ #include enum ActionType { + MLIRDerivatives, CallDerivatives, InstDerivatives, BinopDerivatives, From 5caf99ae20f9ceb3d2fd76e874f76866bae07d66 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 01:01:52 -0500 Subject: [PATCH 016/131] [TypeAnalysisPrinter] support opaque pointers on 16 (#1629) --- .../TypeAnalysis/TypeAnalysisPrinter.cpp | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp index 00bf133ef1b7..38cde29e38d5 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp @@ -82,12 +82,18 @@ bool printTypeAnalyses(llvm::Function &F) { dt = ConcreteType(a.getType()->getScalarType()); } else if (a.getType()->isPointerTy()) { #if LLVM_VERSION_MAJOR < 17 - auto et = cast(a.getType())->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); +#if LLVM_VERSION_MAJOR >= 15 + if (F.getContext().supportsTypedPointers()) { +#endif + auto et = cast(a.getType())->getPointerElementType(); + if (et->isFPOrFPVectorTy()) 
{ + dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); + } else if (et->isPointerTy()) { + dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); + } +#if LLVM_VERSION_MAJOR >= 15 } +#endif #endif dt.insert({}, BaseType::Pointer); } else if (a.getType()->isIntOrIntVectorTy()) { @@ -106,12 +112,18 @@ bool printTypeAnalyses(llvm::Function &F) { dt = ConcreteType(F.getReturnType()->getScalarType()); } else if (F.getReturnType()->isPointerTy()) { #if LLVM_VERSION_MAJOR < 17 - auto et = cast(F.getReturnType())->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); +#if LLVM_VERSION_MAJOR >= 15 + if (F.getContext().supportsTypedPointers()) { +#endif + auto et = cast(F.getReturnType())->getPointerElementType(); + if (et->isFPOrFPVectorTy()) { + dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); + } else if (et->isPointerTy()) { + dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); + } +#if LLVM_VERSION_MAJOR >= 15 } +#endif #endif dt.insert({}, BaseType::Pointer); } else if (F.getReturnType()->isIntOrIntVectorTy()) { From b2151f8f6dc9c3a807a7216fdfab612be12ab2cc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 01:19:48 -0500 Subject: [PATCH 017/131] Fix struct containing attributes (#1627) --- enzyme/Enzyme/Clang/EnzymeClang.cpp | 22 ++++++- .../Integration/ReverseMode/inactiveglob.cpp | 62 +++++++++++++++++++ 2 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 enzyme/test/Integration/ReverseMode/inactiveglob.cpp diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index ee5794cb13e2..ed01f1bf5739 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -240,6 +240,11 @@ struct EnzymeFunctionLikeAttrInfo : public ParsedAttrInfo { // if 
(FD->isLateTemplateParsed()) return; auto &AST = S.getASTContext(); DeclContext *declCtx = FD->getDeclContext(); + for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { + if (tmpCtx->isRecord()) { + declCtx = tmpCtx->getParent(); + } + } auto loc = FD->getLocation(); RecordDecl *RD; if (S.getLangOpts().CPlusPlus) @@ -369,6 +374,11 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo { auto &AST = S.getASTContext(); DeclContext *declCtx = D->getDeclContext(); + for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { + if (tmpCtx->isRecord()) { + declCtx = tmpCtx->getParent(); + } + } auto loc = D->getLocation(); RecordDecl *RD; if (S.getLangOpts().CPlusPlus) @@ -425,7 +435,6 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo { return AttributeNotApplied; } V->setInit(expr); - V->dump(); S.MarkVariableReferenced(loc, V); S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); return AttributeApplied; @@ -479,6 +488,11 @@ struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo { auto &AST = S.getASTContext(); DeclContext *declCtx = D->getDeclContext(); + for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { + if (tmpCtx->isRecord()) { + declCtx = tmpCtx->getParent(); + } + } auto loc = D->getLocation(); RecordDecl *RD; if (S.getLangOpts().CPlusPlus) @@ -534,7 +548,6 @@ struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo { return AttributeNotApplied; } V->setInit(expr); - V->dump(); S.MarkVariableReferenced(loc, V); S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); return AttributeApplied; @@ -584,6 +597,11 @@ struct EnzymeSparseAccumulateAttrInfo : public ParsedAttrInfo { auto &AST = S.getASTContext(); DeclContext *declCtx = D->getDeclContext(); + for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { + if (tmpCtx->isRecord()) { + declCtx = tmpCtx->getParent(); + } + } auto loc = D->getLocation(); RecordDecl *RD; if (S.getLangOpts().CPlusPlus) diff --git 
a/enzyme/test/Integration/ReverseMode/inactiveglob.cpp b/enzyme/test/Integration/ReverseMode/inactiveglob.cpp new file mode 100644 index 000000000000..118ab257acd2 --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/inactiveglob.cpp @@ -0,0 +1,62 @@ +// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver 
-ge 12 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi + +#include +#include +#include + +#include "../test_utils.h" + +double __enzyme_autodiff(void*, ...); + +struct Temp { +private: + __attribute__((enzyme_inactive)) + static double tmp; +public: + +__attribute__((noinline)) +static void f(bool cond, double a, double *c) { + if (cond) + tmp *= a; + else + *c *= a; +} + +static double get() { return tmp; } +static void set(double v) { tmp = v; } + +}; + +double Temp::tmp = 0; + +double test(bool cond, double a) { + double dat = a; + Temp::f(cond, a, &dat); + return dat; +} + +int main(int argc, char** argv) { + Temp::set(5.5); + double out = __enzyme_autodiff((void*)test, false, 3.0); + printf("out=%f\n", out); + APPROX_EQ(out, 6.0, 1e-10); + APPROX_EQ(Temp::get(), 5.5, 1e-10); + return 0; +} From 1b8aff2b7492110add8e5cbe9d8e576b089b9c8f Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 01:19:57 -0500 Subject: [PATCH 018/131] Mark more as nofree/inactive (#1628) --- enzyme/Enzyme/ActivityAnalysis.cpp | 3 ++- enzyme/Enzyme/EnzymeLogic.cpp | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 245be4a2958c..d751ee2b565e 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -409,7 +409,8 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { "std::__1::discard_block_engine", "std::__1::independent_bits_engine", 
"std::__1::shuffle_order_engine", - + "std::__1::basic_streambuf", + "std::__1::basic_stringbuf", "std::__detail::_Prime_rehash_policy", "std::__detail::_Hash_code_base", diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 2810b5370ac4..960692445c50 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5846,6 +5846,9 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { } if (F->empty()) { + if (EnzymeAssumeUnknownNoFree) { + return F; + } if (EnzymeEmptyFnInactive) { return F; } From b7e88c9892c09e946fdd6558ba39355ab3de9fdc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 01:50:34 -0500 Subject: [PATCH 019/131] Fix integration CI bug (#1630) --- enzyme/Enzyme/EnzymeLogic.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 960692445c50..3b1670522fc2 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3234,7 +3234,7 @@ void createInvertedTerminator(DiffeGradientUtils *gutils, for (size_t i = 1; i < size; i++) { if (!PNtypeT[{(int)i}].isFloat()) continue; - PNtypeT[{(int)i}].checkedOrIn(PNtype, /*pointerIntSame*/ true, legal); + PNtype.checkedOrIn(PNtypeT[{(int)i}], /*pointerIntSame*/ true, legal); if (!legal) { break; } @@ -3254,7 +3254,8 @@ void createInvertedTerminator(DiffeGradientUtils *gutils, if (!PNfloatType) { std::string str; raw_string_ostream ss(str); - ss << "Cannot deduce type of phi " << *orig; + ss << "Cannot deduce type of phi " << *orig << PNtypeT.str() + << " sz: " << size << "\n"; if (CustomErrorHandler) { CustomErrorHandler(str.c_str(), wrap(orig), ErrorType::NoType, &gutils->TR.analyzer, nullptr, wrap(&Builder)); From 4b67823aaf0e5b9715690ce198a51e97b7b22229 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 12:13:19 -0500 Subject: [PATCH 020/131] [Bazel] run integration tests (#1631) --- enzyme/BUILD | 
15 ++++++++++++--- enzyme/Enzyme/eopt.cpp | 17 +++++++++++++++++ enzyme/test/lit.site.cfg.py.in | 13 +++++++++---- 3 files changed, 38 insertions(+), 7 deletions(-) create mode 100644 enzyme/Enzyme/eopt.cpp diff --git a/enzyme/BUILD b/enzyme/BUILD index 6278ec604b7e..c9ae1c3cdb71 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -149,7 +149,7 @@ cc_library( "Enzyme/*.cpp", "Enzyme/TypeAnalysis/*.cpp", "Enzyme/Clang/EnzymeClang.cpp", - ]), + ], exclude=["Enzyme/eopt.cpp"]), hdrs = glob([ "Enzyme/*.h", "Enzyme/TypeAnalysis/*.h", @@ -218,6 +218,15 @@ genrule( output_to_bindir = 1, ) +cc_binary( + name = "enzyme-opt", + srcs = ["Enzyme/eopt.cpp"], + deps = [ + ":EnzymeStatic", + "@llvm-project//llvm:opt-driver", + ], +) + td_library( name = "EnzymeDialectTdFiles", srcs = [ @@ -476,12 +485,12 @@ expand_template( "@llvm-project//llvm:count", "@llvm-project//llvm:not", "@llvm-project//llvm:lli", - "@llvm-project//llvm:opt", + ":enzyme-opt", "@llvm-project//clang:builtin_headers_gen", ":enzyme-clang", ":enzyme-clang++", ":enzymemlir-opt" ] + glob(["test/**/*.h"]) ) - for src in glob(["test/**/*.mlir"]) + for src in glob(["test/**/*.mlir", "test/Integration/**/*.c", "test/Integration/**/.cpp"], exclude=["test/**/*omp*.c"]) ] diff --git a/enzyme/Enzyme/eopt.cpp b/enzyme/Enzyme/eopt.cpp new file mode 100644 index 000000000000..ab8792d34996 --- /dev/null +++ b/enzyme/Enzyme/eopt.cpp @@ -0,0 +1,17 @@ +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Passes/PassBuilder.h" + +#include + +using namespace llvm; + +void registerEnzyme(llvm::PassBuilder &PB); + +extern "C" int optMain(int argc, char **argv, + llvm::ArrayRef> + PassBuilderCallbacks); + +int main(int argc, char **argv) { + std::function plugins[] = {registerEnzyme}; + return optMain(argc, argv, plugins); +} diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index bb6e6e7ac0c2..481fce924bcc 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -43,11 +43,16 @@ 
config.excludes = ['Inputs'] config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) config.substitutions.append(('%lli', config.llvm_tools_dir + "/lli" + (" --jit-kind=mcjit" if int(config.llvm_ver) >= 13 else "") )) -config.substitutions.append(('%opt', config.llvm_tools_dir + "/opt")) -eopt = config.enzyme_obj_root + "/Enzyme/MLIR/enzymemlir-opt" +opt = config.llvm_tools_dir + "/opt" if len("@ENZYME_BINARY_DIR@") == 0: - eopt = os.path.dirname(os.path.abspath(__file__)) + "/../enzymemlir-opt" + opt = os.path.dirname(os.path.abspath(__file__)) + "/../enzyme-opt" + +config.substitutions.append(('%opt', opt)) + +emopt = config.enzyme_obj_root + "/Enzyme/MLIR/enzymemlir-opt" +if len("@ENZYME_BINARY_DIR@") == 0: + emopt = os.path.dirname(os.path.abspath(__file__)) + "/../enzymemlir-opt" eclang = config.llvm_tools_dir + "/clang" if len("@ENZYME_BINARY_DIR@") == 0: @@ -56,7 +61,7 @@ if len("@ENZYME_BINARY_DIR@") == 0: eclang += " -resource-dir " + resource + " " eclang += "-I " + os.path.dirname(os.path.abspath(__file__)) + "/Integration" -config.substitutions.append(('%eopt', eopt)) +config.substitutions.append(('%eopt', emopt)) config.substitutions.append(('%llvmver', config.llvm_ver)) config.substitutions.append(('%FileCheck', config.llvm_tools_dir + "/FileCheck")) config.substitutions.append(('%clang', eclang)) From e44812319a67d566da1c05ffcca8a23429ec143c Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 13:04:56 -0500 Subject: [PATCH 021/131] Handle invert pointer of add expr (#1632) --- enzyme/Enzyme/GradientUtils.cpp | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 423b6401e64f..034bf00142a7 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -5505,6 +5505,26 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, return shadow; } else if (auto arg = dyn_cast(oval)) { IRBuilder<> 
bb(inversionAllocs); + if (arg->getOpcode() == Instruction::Add) { + if (isa(arg->getOperand(0))) { + auto rule = [&bb, &arg](Value *ip) { + Constant *invops[2] = {arg->getOperand(0), cast(ip)}; + return arg->getWithOperands(invops); + }; + + auto ip = invertPointerM(arg->getOperand(1), bb, nullShadow); + return applyChainRule(arg->getType(), bb, rule, ip); + } + if (isa(arg->getOperand(1))) { + auto rule = [&bb, &arg](Value *ip) { + Constant *invops[2] = {cast(ip), arg->getOperand(1)}; + return arg->getWithOperands(invops); + }; + + auto ip = invertPointerM(arg->getOperand(0), bb, nullShadow); + return applyChainRule(arg->getType(), bb, rule, ip); + } + } auto ip = invertPointerM(arg->getOperand(0), bb, nullShadow); if (arg->isCast()) { From 5d8d29464601f81dcd0b4cd6d2b49af6f283aa49 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 14:20:09 -0500 Subject: [PATCH 022/131] Use specified name for known inactive function insts (#1633) --- enzyme/Enzyme/ActivityAnalysis.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index d751ee2b565e..fa4dd06bb68c 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -742,10 +742,10 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, InsertConstantInstruction(TR, I); return true; } - if (KnownInactiveFunctionInsts.count(called->getName())) { - InsertConstantInstruction(TR, I); - return true; - } + } + if (KnownInactiveFunctionInsts.count(getFuncNameFromCall(CI))) { + InsertConstantInstruction(TR, I); + return true; } } From 3c559ca31796b593eda67e5ed921f61de2a81d55 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 18:01:54 -0500 Subject: [PATCH 023/131] [Julia] permit caches of addresspace-10 even if potentially indirect (#1635) --- enzyme/Enzyme/EnzymeLogic.h | 1 + enzyme/Enzyme/GradientUtils.cpp | 7 +++++++ 2 files changed, 8 insertions(+) 
diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 4ce25e8ae465..562aa12459a8 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -51,6 +51,7 @@ extern "C" { extern llvm::cl::opt EnzymePrint; +extern llvm::cl::opt EnzymeJuliaAddrLoad; } enum class AugmentedStruct { Tape, Return, DifferentialReturn }; diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 034bf00142a7..6adfc1bd6b78 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -9325,6 +9325,13 @@ bool GradientUtils::needsCacheWholeAllocation( if (found == knownRecomputeHeuristic.end()) continue; + // If caching a julia base object, this is fine as + // GC will deal with any issues with. + if (auto PT = dyn_cast(cur->getType())) + if (PT->getAddressSpace() == 10) + if (EnzymeJuliaAddrLoad) + continue; + // If caching this user, it cannot be a gep/cast of original if (!found->second) { llvm::errs() << " mod: " << *oldFunc->getParent() << "\n"; From 0c98e26433ae768578f75c658839ca1fdaec4e66 Mon Sep 17 00:00:00 2001 From: "Ivan R. 
Ivanov" Date: Fri, 26 Jan 2024 09:27:06 +0900 Subject: [PATCH 024/131] Add functions to truncate and expand fp values (#1615) * Precompute to/from types * Value truncation * non opaque --- enzyme/Enzyme/Enzyme.cpp | 64 +++++++++-- enzyme/Enzyme/EnzymeLogic.cpp | 134 +++++++++++++++++------ enzyme/Enzyme/EnzymeLogic.h | 9 +- enzyme/test/Enzyme/Truncate/cmp.ll | 4 +- enzyme/test/Enzyme/Truncate/intrinsic.ll | 4 +- enzyme/test/Enzyme/Truncate/select.ll | 4 +- enzyme/test/Enzyme/Truncate/simple.ll | 4 +- enzyme/test/Enzyme/Truncate/value.ll | 37 +++++++ 8 files changed, 202 insertions(+), 58 deletions(-) create mode 100644 enzyme/test/Enzyme/Truncate/value.ll diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index a514a826351b..e402c1656248 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1314,14 +1314,14 @@ class EnzymeBase { return type_args; } - bool HandleTruncate(CallInst *CI) { + bool HandleTruncateFunc(CallInst *CI) { IRBuilder<> Builder(CI); Function *F = parseFunctionParameter(CI); if (!F) return false; if (CI->arg_size() != 3) { EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, - "Had incorrect number of args to __enzyme_truncate", *CI, + "Had incorrect number of args to __enzyme_truncate_func", *CI, " - expected 3"); return false; } @@ -1330,7 +1330,7 @@ class EnzymeBase { auto Cto = cast(CI->getArgOperand(2)); assert(Cto); RequestContext context(CI, &Builder); - llvm::Value *res = Logic.CreateTruncate( + llvm::Value *res = Logic.CreateTruncateFunc( context, F, (unsigned)Cfrom->getValue().getZExtValue(), (unsigned)Cto->getValue().getZExtValue()); if (!res) @@ -1341,6 +1341,28 @@ class EnzymeBase { return true; } + bool HandleTruncateValue(CallInst *CI, bool isTruncate) { + IRBuilder<> Builder(CI); + if (CI->arg_size() != 3) { + EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, + "Had incorrect number of args to __enzyme_truncate_value", + *CI, " - expected 3"); + return false; + } + auto Cfrom = 
cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto = cast(CI->getArgOperand(2)); + assert(Cto); + auto Addr = CI->getArgOperand(0); + RequestContext context(CI, &Builder); + bool res = Logic.CreateTruncateValue( + context, Addr, (unsigned)Cfrom->getValue().getZExtValue(), + (unsigned)Cto->getValue().getZExtValue(), isTruncate); + if (!res) + return false; + return true; + } + bool HandleBatch(CallInst *CI) { unsigned width = 1; unsigned truei = 0; @@ -2088,7 +2110,9 @@ class EnzymeBase { MapVector toVirtual; MapVector toSize; SmallVector toBatch; - SmallVector toTruncate; + SmallVector toTruncateFunc; + SmallVector toTruncateValue; + SmallVector toExpandValue; MapVector toProbProg; SetVector InactiveCalls; SetVector IterCalls; @@ -2398,7 +2422,9 @@ class EnzymeBase { bool virtualCall = false; bool sizeOnly = false; bool batch = false; - bool truncate = false; + bool truncateFunc = false; + bool truncateValue = false; + bool expandValue = false; bool probProg = false; DerivativeMode derivativeMode; ProbProgMode probProgMode; @@ -2428,9 +2454,15 @@ class EnzymeBase { } else if (Fn->getName().contains("__enzyme_batch")) { enableEnzyme = true; batch = true; - } else if (Fn->getName().contains("__enzyme_truncate")) { + } else if (Fn->getName().contains("__enzyme_truncate_func")) { enableEnzyme = true; - truncate = true; + truncateFunc = true; + } else if (Fn->getName().contains("__enzyme_truncate_value")) { + enableEnzyme = true; + truncateValue = true; + } else if (Fn->getName().contains("__enzyme_expand_value")) { + enableEnzyme = true; + expandValue = true; } else if (Fn->getName().contains("__enzyme_likelihood")) { enableEnzyme = true; probProgMode = ProbProgMode::Likelihood; @@ -2488,8 +2520,12 @@ class EnzymeBase { toSize[CI] = derivativeMode; else if (batch) toBatch.push_back(CI); - else if (truncate) - toTruncate.push_back(CI); + else if (truncateFunc) + toTruncateFunc.push_back(CI); + else if (truncateValue) + toTruncateValue.push_back(CI); + else if 
(expandValue) + toExpandValue.push_back(CI); else if (probProg) { toProbProg[CI] = probProgMode; } else @@ -2583,8 +2619,14 @@ class EnzymeBase { for (auto call : toBatch) { HandleBatch(call); } - for (auto call : toTruncate) { - HandleTruncate(call); + for (auto call : toTruncateFunc) { + HandleTruncateFunc(call); + } + for (auto call : toTruncateValue) { + HandleTruncateValue(call, true); + } + for (auto call : toExpandValue) { + HandleTruncateValue(call, false); } for (auto &&[call, mode] : toProbProg) { diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 3b1670522fc2..9ad444e4bd3a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4813,11 +4813,53 @@ Function *EnzymeLogic::CreateForwardDiff( return nf; } +static Type *getTypeForWidth(LLVMContext &ctx, unsigned width) { + switch (width) { + default: + return llvm::Type::getIntNTy(ctx, width); + case 64: + return llvm::Type::getDoubleTy(ctx); + case 32: + return llvm::Type::getFloatTy(ctx); + case 16: + return llvm::Type::getHalfTy(ctx); + } +} + +static Value *floatTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, + unsigned fromwidth, unsigned towidth) { + Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); + Type *toTy = getTypeForWidth(B.getContext(), towidth); + if (!tmpBlock) + tmpBlock = B.CreateAlloca(fromTy); + B.CreateStore( + v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); + return B.CreateLoad( + toTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(toTy))); +} + +static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, + unsigned fromwidth, unsigned towidth) { + Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); + if (!tmpBlock) + tmpBlock = B.CreateAlloca(fromTy); + auto c0 = + Constant::getNullValue(llvm::Type::getIntNTy(B.getContext(), fromwidth)); + B.CreateStore( + c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); + B.CreateStore( + v, 
B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); + return B.CreateLoad( + fromTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(fromTy))); +} + class TruncateGenerator : public llvm::InstVisitor { private: ValueToValueMapTy &originalToNewFn; unsigned fromwidth; unsigned towidth; + Type *fromType; + Type *toType; Function *oldFunc; Function *newFunc; AllocaInst *tmpBlock; @@ -4830,7 +4872,11 @@ class TruncateGenerator : public llvm::InstVisitor { : originalToNewFn(originalToNewFn), fromwidth(fromwidth), towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) { IRBuilder<> B(&newFunc->getEntryBlock().front()); - tmpBlock = B.CreateAlloca(getTypeForWidth(fromwidth)); + + fromType = getTypeForWidth(B.getContext(), fromwidth); + toType = getTypeForWidth(B.getContext(), towidth); + + tmpBlock = B.CreateAlloca(fromType); } void visitInstruction(llvm::Instruction &inst) { @@ -4848,42 +4894,16 @@ class TruncateGenerator : public llvm::InstVisitor { todo(inst); } - Type *getTypeForWidth(unsigned width) { - switch (width) { - default: - return llvm::Type::getIntNTy(oldFunc->getContext(), width); - case 64: - return llvm::Type::getDoubleTy(oldFunc->getContext()); - case 32: - return llvm::Type::getFloatTy(oldFunc->getContext()); - case 16: - return llvm::Type::getHalfTy(oldFunc->getContext()); - } - } + Type *getFromType() { return fromType; } - Type *getFromType() { return getTypeForWidth(fromwidth); } - - Type *getToType() { return getTypeForWidth(towidth); } + Type *getToType() { return toType; } Value *truncate(IRBuilder<> &B, Value *v) { - Type *nextType = getTypeForWidth(towidth); - B.CreateStore( - v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); - return B.CreateLoad( - nextType, - B.CreatePointerCast(tmpBlock, PointerType::getUnqual(nextType))); + return floatTruncate(B, v, tmpBlock, fromwidth, towidth); } Value *expand(IRBuilder<> &B, Value *v) { - Type *origT = getFromType(); - auto c0 = 
Constant::getNullValue( - llvm::Type::getIntNTy(oldFunc->getContext(), fromwidth)); - B.CreateStore(c0, B.CreatePointerCast( - tmpBlock, PointerType::getUnqual(c0->getType()))); - B.CreateStore( - v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); - return B.CreateLoad( - origT, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(origT))); + return floatExpand(B, v, tmpBlock, fromwidth, towidth); } void todo(llvm::Instruction &I) { @@ -5180,7 +5200,7 @@ class TruncateGenerator : public llvm::InstVisitor { Value *GetShadow(RequestContext &ctx, Value *v) { if (auto F = dyn_cast(v)) - return Logic.CreateTruncate(ctx, F, fromwidth, towidth); + return Logic.CreateTruncateFunc(ctx, F, fromwidth, towidth); llvm::errs() << " unknown get truncated func: " << *v << "\n"; llvm_unreachable("unknown get truncated func"); return v; @@ -5203,10 +5223,52 @@ class TruncateGenerator : public llvm::InstVisitor { } }; -llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context, - llvm::Function *totrunc, - unsigned fromwidth, - unsigned towidth) { +bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, + unsigned fromwidth, unsigned towidth, + bool isTruncate) { + assert(context.req && context.ip); + + if (fromwidth == towidth) { + context.req->eraseFromParent(); + return true; + } + + if (fromwidth < towidth) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Cannot truncate into a large width\n"; + if (context.req) { + ss << " at context: " << *context.req; + EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, + ss.str()); + return false; + } + llvm_unreachable("failed to truncate value"); + } + + IRBuilderBase &B = *context.ip; + Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); + Type *toTy = getTypeForWidth(B.getContext(), towidth); + + Value *converted = nullptr; + if (isTruncate) + converted = + floatExpand(B, B.CreateFPTrunc(v, toTy), nullptr, fromwidth, towidth); + else + converted = + 
B.CreateFPExt(floatTruncate(B, v, nullptr, fromwidth, towidth), fromTy); + assert(converted); + + context.req->replaceAllUsesWith(converted); + context.req->eraseFromParent(); + + return true; +} + +llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, + llvm::Function *totrunc, + unsigned fromwidth, + unsigned towidth) { if (fromwidth == towidth) return totrunc; diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 562aa12459a8..2f1ac9fde496 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -513,9 +513,12 @@ class EnzymeLogic { using TruncateCacheKey = std::tuple; std::map TruncateCachedFunctions; - llvm::Function *CreateTruncate(RequestContext context, - llvm::Function *tobatch, unsigned fromwidth, - unsigned towidth); + llvm::Function *CreateTruncateFunc(RequestContext context, + llvm::Function *tobatch, + unsigned fromwidth, unsigned towidth); + bool CreateTruncateValue(RequestContext context, llvm::Value *addr, + unsigned fromwidth, unsigned towidth, + bool isTruncate); /// Create a traced version of a function /// \p context the instruction which requested this trace (or null). diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index 3c2cffec9979..e9c61ebd773e 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -6,11 +6,11 @@ define i1 @f(double %x, double %y) { ret i1 %res } -declare i1 (double, double)* @__enzyme_truncate(...) +declare i1 (double, double)* @__enzyme_truncate_func(...) define i1 @tester(double %x, double %y) { entry: - %ptr = call i1 (double, double)* (...) @__enzyme_truncate(i1 (double, double)* @f, i64 64, i64 32) + %ptr = call i1 (double, double)* (...) 
@__enzyme_truncate_func(i1 (double, double)* @f, i64 64, i64 32) %res = call i1 %ptr(double %x, double %y) ret i1 %res } diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index ea92f5d96bbc..da4457492ce2 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -13,11 +13,11 @@ define double @f(double %x, double %y) { ret double %res } -declare double (double, double)* @__enzyme_truncate(...) +declare double (double, double)* @__enzyme_truncate_func(...) define double @tester(double %x, double %y) { entry: - %ptr = call double (double, double)* (...) @__enzyme_truncate(double (double, double)* @f, i64 64, i64 32) + %ptr = call double (double, double)* (...) @__enzyme_truncate_func(double (double, double)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y) ret double %res } diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll index ae539469b9a2..58b4a58ef91b 100644 --- a/enzyme/test/Enzyme/Truncate/select.ll +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -6,11 +6,11 @@ define double @f(double %x, double %y, i1 %cond) { ret double %res } -declare double (double, double, i1)* @__enzyme_truncate(...) +declare double (double, double, i1)* @__enzyme_truncate_func(...) define double @tester(double %x, double %y, i1 %cond) { entry: - %ptr = call double (double, double, i1)* (...) @__enzyme_truncate(double (double, double, i1)* @f, i64 64, i64 32) + %ptr = call double (double, double, i1)* (...) 
@__enzyme_truncate_func(double (double, double, i1)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y, i1 %cond) ret double %res } diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 69990236a29e..0f346a26f0d2 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -8,11 +8,11 @@ define void @f(double* %x) { ret void } -declare void (double*)* @__enzyme_truncate(...) +declare void (double*)* @__enzyme_truncate_func(...) define void @tester(double* %data) { entry: - %ptr = call void (double*)* (...) @__enzyme_truncate(void (double*)* @f, i64 64, i64 32) + %ptr = call void (double*)* (...) @__enzyme_truncate_func(void (double*)* @f, i64 64, i64 32) call void %ptr(double* %data) ret void } diff --git a/enzyme/test/Enzyme/Truncate/value.ll b/enzyme/test/Enzyme/Truncate/value.ll new file mode 100644 index 000000000000..51f00401078d --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/value.ll @@ -0,0 +1,37 @@ +; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi +; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi + +declare double @__enzyme_truncate_value(double, i64, i64) +declare double @__enzyme_expand_value(double, i64, i64) + +define double @expand_tester(double %a, double * %c) { +entry: + %b = call double @__enzyme_expand_value(double %a, i64 64, i64 32) + ret double %b +} + +define double @truncate_tester(double %a) { +entry: + %b = call double @__enzyme_truncate_value(double %a, i64 64, i64 32) + ret double %b +} + +; CHECK: define double @expand_tester(double %a, double* %c) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = alloca double, align 8 +; CHECK-NEXT: store double %a, double* %0, align 8 +; CHECK-NEXT: %1 = bitcast double* %0 to float* +; CHECK-NEXT: %2 = load float, float* %1, align 4 +; CHECK-NEXT: %3 = fpext float %2 to double +; CHECK-NEXT: 
ret double %3 + +; CHECK: define double @truncate_tester(double %a) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fptrunc double %a to float +; CHECK-NEXT: %1 = alloca double, align 8 +; CHECK-NEXT: %2 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %2, align 4 +; CHECK-NEXT: %3 = bitcast double* %1 to float* +; CHECK-NEXT: store float %0, float* %3, align 4 +; CHECK-NEXT: %4 = load double, double* %1, align 8 +; CHECK-NEXT: ret double %4 From af10cc97ecfca36fec8c185ddf66d2c35ffa0059 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 20:15:31 -0500 Subject: [PATCH 025/131] Allow disabling memmove warning (#1637) --- enzyme/Enzyme/Utils.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 25994fb2bd06..ff44cbaa715d 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -87,6 +87,9 @@ llvm::cl::opt EnzymeStrongZero("enzyme-strong-zero", cl::init(false), cl::Hidden, cl::desc("Use additional checks to ensure correct " "behavior when handling functions with inf")); +llvm::cl::opt EnzymeMemmoveWarning( + "enzyme-memmove-warning", cl::init(true), cl::Hidden, + cl::desc("Warn if using memmove implementation as a fallback for memmove")); } void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, @@ -1240,8 +1243,10 @@ Function * getOrInsertDifferentialFloatMemmove(Module &M, Type *T, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth) { - llvm::errs() << "warning: didn't implement memmove, using memcpy as fallback " - "which can result in errors\n"; + if (EnzymeMemmoveWarning) + llvm::errs() + << "warning: didn't implement memmove, using memcpy as fallback " + "which can result in errors\n"; return getOrInsertDifferentialFloatMemcpy(M, T, dstalign, srcalign, dstaddr, srcaddr, bitwidth); } From cce389e7e71d01243ce3899e959c725b50d87bad Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 
26 Jan 2024 12:11:14 -0500 Subject: [PATCH 026/131] Fix julia 16 llvm (#1638) --- .packaging/build_tarballs.jl | 5 +++-- enzyme/BCLoad/CMakeLists.txt | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.packaging/build_tarballs.jl b/.packaging/build_tarballs.jl index 3c60696102d1..67d214bf2c86 100644 --- a/.packaging/build_tarballs.jl +++ b/.packaging/build_tarballs.jl @@ -117,11 +117,12 @@ for llvm_version in llvm_versions, llvm_assertions in (false, true) for platform in platforms augmented_platform = deepcopy(platform) augmented_platform[LLVM.platform_name] = LLVM.platform(llvm_version, llvm_assertions) - + gcc_version = version > v"15" ? v"10" : v"8" should_build_platform(triplet(augmented_platform)) || continue push!(builds, (; dependencies, products, platforms=[augmented_platform], + gcc_version, )) end end @@ -137,6 +138,6 @@ for (i,build) in enumerate(builds) build_tarballs(i == lastindex(builds) ? non_platform_ARGS : non_reg_ARGS, name, version, sources, script, build.platforms, build.products, build.dependencies; - preferred_gcc_version=v"8", julia_compat="1.6", + preferred_gcc_version=build.gcc_version, julia_compat="1.6", augment_platform_block, lazy_artifacts=true) # drop when julia_compat >= 1.7 end diff --git a/enzyme/BCLoad/CMakeLists.txt b/enzyme/BCLoad/CMakeLists.txt index 0bac59a04dfa..f2b04238fb67 100644 --- a/enzyme/BCLoad/CMakeLists.txt +++ b/enzyme/BCLoad/CMakeLists.txt @@ -40,7 +40,7 @@ ExternalProject_Add(gsl64 BUILD_IN_SOURCE 1 INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/gsl64/install CONFIGURE_COMMAND sh -c ${CMAKE_CURRENT_SOURCE_DIR}/fixgsl64.sh - BUILD_COMMAND sh -c "$ cblas/*.c ${BC_LOAD_FLAGS2} -I . -I .. -S -emit-llvm -O1" + BUILD_COMMAND sh -c "rm cblas/xerbla.c && $ cblas/*.c ${BC_LOAD_FLAGS2} -I . -I .. 
-S -emit-llvm -O1" INSTALL_COMMAND "" UPDATE_COMMAND "" TEST_COMMAND "" From 878c34a59dac0daec2ebc16d4bcbb9240e177a41 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jan 2024 13:17:32 -0500 Subject: [PATCH 027/131] Fix linking issue (#1639) --- enzyme/BCLoad/BCLoader.cpp | 2 -- enzyme/Enzyme/CMakeLists.txt | 8 ++++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index cec3317d0fd2..c27377d213ba 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -28,7 +28,6 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { std::vector todo; bool seen32 = false; bool seen64 = false; - bool seenGemm = false; for (auto &F : M) { if (!F.empty()) continue; @@ -52,7 +51,6 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { seen32 = true; if (index == 2) seen64 = true; - if (endsWith(str, "gemm")) seenGemm = true; break; } index++; diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index d932f9d21dd3..1cd6e84c5be1 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -44,6 +44,7 @@ set(LLVM_LINK_COMPONENTS Demangle) file(GLOB ENZYME_SRC CONFIGURE_DEPENDS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp" ) +list(REMOVE_ITEM ENZYME_SRC "eopt.cpp") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -54,6 +55,7 @@ list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp T if ((WIN32 OR CYGWIN) AND LLVM_LINK_LLVM_DYLIB) add_llvm_library( LLVMEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} + PARTIAL_SOURCES_INTENDED MODULE DEPENDS intrinsics_gen @@ -64,6 +66,7 @@ if (${Clang_FOUND}) add_llvm_library( ClangEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymeClang.cpp Clang/EnzymePassLoader.cpp + PARTIAL_SOURCES_INTENDED MODULE DEPENDS intrinsics_gen @@ -74,6 +77,7 @@ target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPA endif() add_llvm_library( 
LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp + PARTIAL_SOURCES_INTENDED MODULE DEPENDS intrinsics_gen @@ -84,6 +88,7 @@ target_compile_definitions(LLDEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS else() add_llvm_library( LLVMEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} + PARTIAL_SOURCES_INTENDED MODULE DEPENDS intrinsics_gen @@ -94,6 +99,7 @@ if (${Clang_FOUND}) add_llvm_library( ClangEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymeClang.cpp Clang/EnzymePassLoader.cpp + PARTIAL_SOURCES_INTENDED MODULE DEPENDS intrinsics_gen @@ -104,6 +110,7 @@ target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPA endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp + PARTIAL_SOURCES_INTENDED MODULE DEPENDS intrinsics_gen @@ -116,6 +123,7 @@ endif() if (${ENZYME_STATIC_LIB}) add_llvm_library( EnzymeStatic-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} + PARTIAL_SOURCES_INTENDED STATIC DEPENDS intrinsics_gen From 5a7d8468628bdc4b33cb8f3aa2b81a2942691079 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jan 2024 21:32:38 -0500 Subject: [PATCH 028/131] Fix bcload on llvm16 (#1642) --- enzyme/BCLoad/BCLoader.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index c27377d213ba..7f7cf6cfe2c4 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -69,9 +69,15 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { SMDiagnostic Err; MemoryBufferRef buf(mod, StringRef("bcloader")); +#if LLVM_VERSION_MAJOR >= 16 + auto BC = llvm::parseIR(buf, Err, M.getContext(), [&](StringRef, StringRef) { + return Optional(M.getDataLayout().getStringRepresentation()); + }); +#else auto BC = llvm::parseIR(buf, Err, M.getContext(), [&](StringRef) { return Optional(M.getDataLayout().getStringRepresentation()); }); +#endif if (!BC) Err.print("bcloader", llvm::errs()); From 
050f2b2e6bbacbaec28b4948342b5f9b2dafb42e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jan 2024 23:10:58 -0500 Subject: [PATCH 029/131] BCLoad 16 CI (#1643) --- enzyme/BCLoad/BCLoader.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index 7f7cf6cfe2c4..ee045e7b596a 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -71,7 +71,7 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { #if LLVM_VERSION_MAJOR >= 16 auto BC = llvm::parseIR(buf, Err, M.getContext(), [&](StringRef, StringRef) { - return Optional(M.getDataLayout().getStringRepresentation()); + return std::optional(M.getDataLayout().getStringRepresentation()); }); #else auto BC = llvm::parseIR(buf, Err, M.getContext(), [&](StringRef) { From 900789e6c36d5b9efe43b5c5c839f1d6d43dbb0f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 27 Jan 2024 00:26:56 -0500 Subject: [PATCH 030/131] More bc load attempts (#1644) * BCLoad 16 CI * trying again --- enzyme/BCLoad/BCLoader.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index ee045e7b596a..6aad07eb6739 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -70,9 +70,11 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { MemoryBufferRef buf(mod, StringRef("bcloader")); #if LLVM_VERSION_MAJOR >= 16 - auto BC = llvm::parseIR(buf, Err, M.getContext(), [&](StringRef, StringRef) { - return std::optional(M.getDataLayout().getStringRepresentation()); - }); + auto BC = llvm::parseIR(buf, Err, M.getContext(), + llvm::ParserCallbacks([&](StringRef, StringRef) { + return std::optional( + M.getDataLayout().getStringRepresentation()); + })); #else auto BC = llvm::parseIR(buf, Err, M.getContext(), [&](StringRef) { return Optional(M.getDataLayout().getStringRepresentation()); From ee0c6783adc01e3d61d67c7a4b660e014e9fb224 
Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 29 Jan 2024 13:08:38 -0500 Subject: [PATCH 031/131] Add nv_ceil handler (#1647) --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 381cd0729114..f2a3dc56cdf8 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -168,6 +168,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"tgamma", Intrinsic::not_intrinsic}, {"lgamma", Intrinsic::not_intrinsic}, {"ceil", Intrinsic::ceil}, + {"__nv_ceil", Intrinsic::ceil}, {"floor", Intrinsic::floor}, {"fmod", Intrinsic::not_intrinsic}, {"trunc", Intrinsic::trunc}, From ff15cd8eec441dfd6ea563bc4f4eb7a2bd1774db Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 30 Jan 2024 13:58:10 -0500 Subject: [PATCH 032/131] [MLIR][ActivityAnalysis] create activity interface (#1648) * [MLIR][ActivityAnalysis] create activity interface * wip * fixup --- enzyme/BUILD | 32 +++- .../Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 164 ++++++------------ .../Analysis/DataFlowActivityAnalysis.cpp | 19 ++ .../MLIR/Implementations/ArithDerivatives.td | 25 +-- .../MLIR/Implementations/CMakeLists.txt | 11 ++ enzyme/Enzyme/MLIR/Implementations/Common.td | 30 ++++ .../CoreDialectsAutoDiffImplementations.h | 1 + .../LLVMAutoDiffOpInterfaceImpl.cpp | 18 ++ .../MLIR/Implementations/LLVMDerivatives.td | 17 ++ .../NVVMAutoDiffOpInterfaceImpl.cpp | 34 ++++ .../MLIR/Implementations/NVVMDerivatives.td | 4 + .../SCFAutoDiffOpInterfaceImpl.cpp | 16 ++ .../MLIR/Interfaces/AutoDiffOpInterface.td | 21 +++ enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 1 + enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 19 ++ 15 files changed, 272 insertions(+), 140 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/Common.td create mode 100644 enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td create mode 100644 
enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td diff --git a/enzyme/BUILD b/enzyme/BUILD index c9ae1c3cdb71..8954d5433e0a 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -392,7 +392,35 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td"], + td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "llvm-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/LLVMDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/LLVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "nvvm-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/NVVMDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/NVVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], deps = [ ":enzyme-tblgen", ], @@ -420,6 +448,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":arith-derivatives", + ":llvm-derivatives", + ":nvvm-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index c1a471277915..bb8c8f2ccc9c 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -3,7 +3,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include 
"mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" @@ -19,6 +18,8 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/Support/ModRef.h" +#include "Interfaces/AutoDiffOpInterface.h" + const char *KnownInactiveFunctionsStartingWith[] = { "f90io", "$ss5print", @@ -467,9 +468,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (isa(I)) return true; + if (auto ifaceOp = dyn_cast(I)) { + if (ifaceOp.isInactive()) { + return true; + } + } + // Branch, unreachable, and previously computed constants are inactive - if (isa(I) /*|| isa(I)*/ || - ConstantOperations.contains(I)) { + if (/*|| isa(I)*/ ConstantOperations.contains(I)) { return true; } @@ -592,12 +598,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // *I // << "\n"; - if (isa(I)) { - InsertConstantOperation(TR, I); - } - // if (auto II = dyn_cast(I)) { // switch (II->getIntrinsicID()) { // case Intrinsic::nvvm_barrier0: @@ -1121,13 +1121,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (Val.getType().isa()) return true; - // All function pointers are considered active in case an augmented primal - // or reverse is needed - if (Val.getDefiningOp() && - isa(Val.getDefiningOp())) { - return false; - } - /// If we've already shown this value to be inactive if (ConstantValues.find(Val) != ConstantValues.end()) { return true; @@ -1142,44 +1135,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (matchPattern(Val, m_Constant())) return true; - // if (auto CD = dyn_cast(Val)) { - // // inductively assume inactive - // ConstantValues.insert(CD); - // for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { - // if (!isConstantValue(TR, CD->getElementAsConstant(i))) { - // ConstantValues.erase(CD); - // ActiveValues.insert(CD); - // return false; - // } - // } - // return true; - // 
} - // if (auto CD = dyn_cast(Val)) { - // // inductively assume inactive - // ConstantValues.insert(CD); - // for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - // if (!isConstantValue(TR, CD->getOperand(i))) { - // ConstantValues.erase(CD); - // ActiveValues.insert(CD); - // return false; - // } - // } - // return true; - // } - if (Operation *definingOp = Val.getDefiningOp()) { - // Undef and non-global constants are inactive. - if (isa(definingOp)) { - return true; - } - - // Ops derived from intrinsics. - // NOTE: this was written with the assumption that Value is-a Operation, - // which is not the case in MLIR. - if (isa(definingOp)) { - return true; + if (auto ifaceOp = dyn_cast(definingOp)) { + if (ifaceOp.isInactive()) { + return true; + } } } @@ -1494,6 +1454,17 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } } + if (auto op = TmpOrig.getDefiningOp()) + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { + InsertConstantValue(TR, Val); + if (TmpOrig != Val) { + InsertConstantValue(TR, TmpOrig); + } + return true; + } + } + UpHypothesis = std::shared_ptr( new mlir::enzyme::ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantValues.insert(Val); @@ -1828,16 +1799,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (notForAnalysis.count(op->getBlock())) return false; - if (auto iasm = dyn_cast(op)) { - if (iasm.getAsmString().contains("exit") || - iasm.getAsmString().contains("cpuid")) - return false; - } - if (isa(op)) { - return true; - } + if (auto op = TmpOrig.getDefiningOp()) + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { + return false; + } + } // If this is a malloc or free, this doesn't impact the activity if (auto CI = dyn_cast(op)) { @@ -2537,21 +2504,16 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( return false; } - // if (EnzymePrintActivity) - // llvm::errs() << " < UPSEARCH" << (int)directions << ">" << 
*inst << - // "\n"; - - // cpuid is explicitly an inactive instruction - if (auto iasm = dyn_cast(op)) { - if (iasm.getAsmString().contains("cpuid")) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from known cpuid instruction - // " - // << *inst << "\n"; + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { return true; } } + // if (EnzymePrintActivity) + // llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst << + // "\n"; + if (auto store = dyn_cast(op)) { if (isConstantValue(TR, store.getValue()) || isConstantValue(TR, store.getAddr())) { @@ -2643,15 +2605,6 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( return true; } } - // Intrinsics known always to be inactive - if (isa(op)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-intrinsic " - // << *inst << "\n"; - return true; - } if (auto gep = dyn_cast(op)) { // A gep's only args that could make it active is the pointer operand @@ -2731,13 +2684,7 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( return false; } - if (isa(op)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-fpcst:" << - // *inst - // << "\n"; - return true; - } else { + { bool seenuse = false; //! TODO does not consider reading from global memory that is active and not //! 
an argument @@ -2871,6 +2818,13 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // } } + if (UA != UseActivity::AllStores) { + if (auto ifaceOp = dyn_cast(a)) { + if (ifaceOp.isArgInactive(parent)) + return true; + } + } + // if (EnzymePrintActivity) // llvm::errs() << " considering use of " << *val << " - " << *a // << "\n"; @@ -3078,14 +3032,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( continue; } - if (isa( - a)) { - // if (EnzymePrintActivity) - // llvm::errs() << "found constant(" << (int)directions - // << ") si-fp use:" << *val << " user " << *a << "\n"; - continue; - } - // // TODO: this should not happen in valid MLIR... // @@ -3367,10 +3313,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( LLVM::SExtOp, LLVM::ZExtOp, LLVM::TruncOp, - LLVM::SIToFPOp, - LLVM::UIToFPOp, - LLVM::FPToSIOp, - LLVM::FPToUIOp, LLVM::FPExtOp, LLVM::FPTruncOp // clang-format on @@ -3451,6 +3393,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( continue; } + if (auto ifaceOp = dyn_cast(a)) { + if (ifaceOp.isArgInactive(val)) + return true; + } + if (isa(a)) { if (ActiveReturns == DIFFE_TYPE::CONSTANT) continue; @@ -3509,19 +3456,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // TODO: in MLIR, users are always operations // if (Operation *inst = a) { - auto mayWriteToMemory = [](Operation *op) { - auto iface = dyn_cast(op); - if (!iface) - return true; - - SmallVector effects; - iface.getEffects(effects); - return llvm::any_of( - effects, [](const MemoryEffects::EffectInstance &effect) { - return isa(effect.getEffect()); - }); - }; - if (!mayWriteToMemory(inst) /*|| (isa(inst) && AA.onlyReadsMemory(cast(inst)))*/) { // // if not written to memory and returning a known constant, this diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index 608bf75f40a5..4f922d6eb5d5 100644 --- 
a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -43,6 +43,8 @@ #include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" +#include "Interfaces/AutoDiffOpInterface.h" + using namespace mlir; using namespace mlir::dataflow; using enzyme::AliasClassLattice; @@ -508,6 +510,15 @@ class DenseForwardActivityAnalysis ForwardMemoryActivity *after) override { join(after, before); ChangeResult result = ChangeResult::NoChange; + + // TODO If we know this is inactive by definition + // if (auto ifaceOp = dyn_cast(op)) { + // if (ifaceOp.isInactive()) { + // propagateIfChanged(after, result); + // return; + // } + // } + auto memory = dyn_cast(op); // If we can't reason about the memory effects, then conservatively assume // we can't deduce anything about activity via side-effects. @@ -657,6 +668,14 @@ class DenseBackwardActivityAnalysis void visitOperation(Operation *op, const BackwardMemoryActivity &after, BackwardMemoryActivity *before) override { + + // TODO: If we know this is inactive by definition + // if (auto ifaceOp = dyn_cast(op)) { + // if (ifaceOp.isInactive()) { + // return; + // } + // } + // Initialize the return activity of arguments. 
if (op->hasTrait() && op->getParentOp() == parentOp) { for (const auto &[arg, argActivity] : diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index bb713ef61799..fb7f113f16fe 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -1,25 +1,5 @@ -class MLIRDerivative resultOps> { - string dialect = dialect_; - string opName = opName_; - dag PatternToMatch = patternToMatch; - list ArgDerivatives = resultOps; -} +include "Common.td" -class Operation { - bit usesPrimal = usesPrimal_; - bit usesShadow = usesShadow_; - bit usesCustom = usesCustom_; -} - -class DiffeRetIndex indices_> { - list indices = indices_; -} -def DiffeRet : DiffeRetIndex<[-1]>; - -class Inst : Operation { - string name = mnemonic; - string dialect = dialect_; -} class ArithInst : Inst; def AddF : ArithInst<"arith::AddFOp">; @@ -32,9 +12,6 @@ def RemF : ArithInst<"arith::RemFOp">; def CheckedMulF : ArithInst<"arith::MulFOp">; def CheckedDivF : ArithInst<"arith::DivFOp">; -def Op { -} - def : MLIRDerivative<"arith", "AddFOp", (Op $x, $y), [ (DiffeRet), diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index a41ee2133c68..521ba76c22bc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -3,9 +3,18 @@ set(LLVM_TARGET_DEFINITIONS ArithDerivatives.td) enzyme_tablegen(ArithDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(ArithDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS LLVMDerivatives.td) +enzyme_tablegen(LLVMDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(LLVMDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS NVVMDerivatives.td) +enzyme_tablegen(NVVMDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(NVVMDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations 
ArithAutoDiffOpInterfaceImpl.cpp LLVMAutoDiffOpInterfaceImpl.cpp + NVVMAutoDiffOpInterfaceImpl.cpp MemRefAutoDiffOpInterfaceImpl.cpp LinalgAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -14,6 +23,8 @@ add_mlir_library(MLIREnzymeImplementations DEPENDS MLIRAutoDiffOpInterfaceIncGen ArithDerivativesIncGen + LLVMDerivativesIncGen + NVVMDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td new file mode 100644 index 000000000000..3909405320d1 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -0,0 +1,30 @@ +class InactiveOp { + string dialect = dialect_; + string opName = opName_; +} + +class MLIRDerivative resultOps> { + string dialect = dialect_; + string opName = opName_; + dag PatternToMatch = patternToMatch; + list ArgDerivatives = resultOps; +} + +class Operation { + bit usesPrimal = usesPrimal_; + bit usesShadow = usesShadow_; + bit usesCustom = usesCustom_; +} + +class DiffeRetIndex indices_> { + list indices = indices_; +} +def DiffeRet : DiffeRetIndex<[-1]>; + +class Inst : Operation { + string name = mnemonic; + string dialect = dialect_; +} + +def Op { +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 669b028998c6..56af04b30133 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -19,6 +19,7 @@ namespace enzyme { void registerArithDialectAutoDiffInterface(DialectRegistry &registry); void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry); void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry); +void registerNVVMDialectAutoDiffInterface(DialectRegistry &registry); void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry); void 
registerSCFDialectAutoDiffInterface(DialectRegistry &registry); void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 3c60fd421c7a..079dd1cb64e9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -23,6 +23,23 @@ using namespace mlir; using namespace mlir::enzyme; namespace { +#include "Implementations/LLVMDerivatives.inc" +} // namespace + +namespace { +struct InlineAsmActivityInterface + : public ActivityOpInterface::ExternalModel { + bool isInactive(Operation *op) const { + auto asmOp = cast(op); + auto str = asmOp.getAsmString(); + return str.contains("cpuid") || str.contains("exit"); + } + bool isArgInactive(Operation *op, mlir::Value) const { + return isInactive(op); + } +}; + struct LoadOpInterface : public AutoDiffOpInterface::ExternalModel { LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, @@ -102,5 +119,6 @@ void mlir::enzyme::registerLLVMDialectAutoDiffInterface( LLVM::StoreOp::attachInterface(*context); LLVM::AllocaOp::attachInterface(*context); LLVM::LLVMPointerType::attachInterface(*context); + registerInterfaces(context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td new file mode 100644 index 000000000000..9e5f28e41665 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -0,0 +1,17 @@ +include "Common.td" + +def : InactiveOp<"LLVM", "SIToFPOp">; +def : InactiveOp<"LLVM", "UIToFPOp">; +def : InactiveOp<"LLVM", "FPToSIOp">; +def : InactiveOp<"LLVM", "FPToUIOp">; +def : InactiveOp<"LLVM", "AssumeOp">; +def : InactiveOp<"LLVM", "StackSaveOp">; +def : InactiveOp<"LLVM", "StackRestoreOp">; +def : InactiveOp<"LLVM", "LifetimeStartOp">; +def : InactiveOp<"LLVM", 
"LifetimeEndOp">; +def : InactiveOp<"LLVM", "Prefetch">; +def : InactiveOp<"LLVM", "MemsetOp">; + +def : InactiveOp<"LLVM", "UndefOp">; +def : InactiveOp<"LLVM", "ConstantOp">; +def : InactiveOp<"LLVM", "UnreachableOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..4d8116ce011b --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,34 @@ +//===- NVVMAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream NVVM dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/NVVMDerivatives.inc" +} // namespace + +void mlir::enzyme::registerNVVMDialectAutoDiffInterface( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *context, NVVM::NVVMDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td new file mode 100644 index 000000000000..f34dfb564cbc --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td @@ -0,0 +1,4 @@ +include "Common.td" + +// TODO: in reverse mode, replicate this op in the reverse pass +def : InactiveOp<"NVVM", "Barrier0Op">; diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 52f48ce7e2d1..71d3db0c2905 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -32,6 +32,8 @@ struct ForOpInterface MGradientUtils *gutils) const { auto forOp = cast(op); auto nFor = cast(gutils->getNewFromOriginal(op)); + // Get a list of all the return types, which is the original return types + // alongside any shadow return types SmallVector nTypes; for (auto r : forOp->getResults()) { // TODO only if used @@ -43,6 +45,8 @@ struct ForOpInterface nTypes.push_back(adTypeIface.getShadowType()); } } + + // Get a list of all args, which is original args, and any shadows SmallVector nArgs; for (const auto &[initVal,
iterArg] : llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs())) { @@ -51,12 +55,16 @@ struct ForOpInterface if (!gutils->isConstantValue(iterArg)) nArgs.push_back(gutils->invertPointerM(initVal, builder)); } + + // Create the new modified for loop auto repFor = builder.create( forOp.getLoc(), gutils->getNewFromOriginal(forOp.getLowerBound()), gutils->getNewFromOriginal(forOp.getUpperBound()), gutils->getNewFromOriginal(forOp.getStep()), nArgs); repFor.getRegion().takeBody(nFor.getRegion()); + // Inject the mapping for the new results into GradientUtil's shadow + // table SmallVector reps; size_t idx = 0; for (Value r : forOp.getResults()) { @@ -72,13 +80,19 @@ struct ForOpInterface idx++; } } + + // Replace all uses of original results nFor.replaceAllUsesWith(reps); gutils->erase(nFor); + + // differentiate body for (Operation &o : llvm::make_early_inc_range(forOp.getBody()->without_terminator())) { if (failed(gutils->visitChild(&o))) return failure(); } + + // Fix terminator (yield) operations Operation *oldYield = repFor.getBody()->getTerminator(); builder.setInsertionPointToEnd(repFor.getBody()); SmallVector nYields; @@ -93,6 +107,8 @@ struct ForOpInterface Operation *newYield = builder.clone(*oldYield); newYield->setOperands(nYields); gutils->erase(oldYield); + + // Done return success(); } }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 61099c316cd2..f45b641f9cb8 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -76,4 +76,25 @@ def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> { ]; } +def ActivityOpInterface + : OpInterface<"ActivityOpInterface"> { + let cppNamespace = "::mlir::enzyme"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + }], + /*retTy=*/"bool", + /*methodName=*/"isInactive" + >, + InterfaceMethod< + /*desc=*/[{ + }], + /*retTy=*/"bool", + 
/*methodName=*/"isArgInactive", + /*args=*/(ins "::mlir::Value":$val) + > + ]; +} + #endif // ENZYME_MLIR_INTERFACES_AUTODIFFOPINTERFACES diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 589ffa610a28..b6bb33c51df9 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -100,6 +100,7 @@ int main(int argc, char **argv) { enzyme::registerArithDialectAutoDiffInterface(registry); enzyme::registerBuiltinDialectAutoDiffInterface(registry); enzyme::registerLLVMDialectAutoDiffInterface(registry); + enzyme::registerNVVMDialectAutoDiffInterface(registry); enzyme::registerMemRefDialectAutoDiffInterface(registry); enzyme::registerSCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 11384a2284e8..1f4a3dc1b892 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1875,6 +1875,19 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } if (intrinsic == MLIRDerivatives) { + const auto &actpatterns = + recordKeeper.getAllDerivedDefinitions("InactiveOp"); + for (auto &pattern : actpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "Activity : \n"; + os << " public ActivityOpInterface::ExternalModel<" + << opName << "Activity, " << dialect << "::" << opName << "> {\n"; + os << " bool isInactive(mlir::Operation*) const { return true; }\n"; + os << " bool isArgInactive(mlir::Operation*, mlir::Value) const { " + "return true; }\n"; + os << "};\n"; + } os << "void registerInterfaces(MLIRContext* context) {\n"; for (Record *pattern : patterns) { auto opName = pattern->getValueAsString("opName"); @@ -1884,6 +1897,12 @@ static void emitDerivatives(const RecordKeeper 
&recordKeeper, raw_ostream &os, os << " " << dialect << "::" << opName << "::attachInterface<" << opName << "RevDerivative>(*context);\n"; } + for (Record *pattern : actpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "Activity>(*context);\n"; + } os << "}\n"; } } From 53a0bd8aef1dd49808420f19ef16847ad2d0ef6f Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 31 Jan 2024 13:10:00 -0500 Subject: [PATCH 033/131] [WIP] Simplify MLIR (#1646) --- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 4 +- .../MLIR/Interfaces/GradientUtilsReverse.cpp | 64 ++----------------- .../MLIR/Interfaces/GradientUtilsReverse.h | 29 +-------- 3 files changed, 10 insertions(+), 87 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 9b9509b3ec22..094228232346 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -74,14 +74,14 @@ class MDiffeGradientUtils : public MGradientUtils { FunctionOpInterface oldFunc_, MTypeAnalysis &TA, MTypeResults TR, IRMapping &invertedPointers_, const SmallPtrSetImpl &constantvalues_, - const SmallPtrSetImpl &returnvals_, + const SmallPtrSetImpl &activevals_, DIFFE_TYPE ActiveReturn, ArrayRef constant_values, IRMapping &origToNew_, std::map &origToNewOps_, DerivativeMode mode, unsigned width, bool omp) : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, - constantvalues_, returnvals_, ActiveReturn, + constantvalues_, activevals_, ActiveReturn, constant_values, origToNew_, origToNewOps_, mode, width, omp) {} diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 02a3d3d956d6..65d0349f1476 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ 
b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -36,10 +36,11 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_) - : newFunc(newFunc_), oldFunc(oldFunc_), Logic(Logic), mode(mode_), - originalToNewFn(originalToNewFn_), - originalToNewFnOps(originalToNewFnOps_), TA(TA_), width(width), - ArgDiffeTypes(ArgDiffeTypes_), symbolTable(symbolTable_) { + : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, + invertedPointers_, constantvalues_, activevals_, + ReturnActivity, ArgDiffeTypes_, originalToNewFn_, + originalToNewFnOps_, mode_, width, /*omp*/ false), + symbolTable(symbolTable_) { initInitializationBlock(invertedPointers_, ArgDiffeTypes_); } @@ -136,43 +137,6 @@ Value mlir::enzyme::MGradientUtilsReverse::insertInitShadowedGradient( return gradient; } -Value mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - const mlir::Value originst) const { - if (!originalToNewFn.contains(originst)) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - llvm::errs() << originst << "\n"; - llvm_unreachable("Could not get new val from original"); - } - return originalToNewFn.lookupOrNull(originst); -} - -Block *mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - mlir::Block *originst) const { - if (!originalToNewFn.contains(originst)) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - llvm::errs() << originst << "\n"; - llvm_unreachable("Could not get new blk from original"); - } - return originalToNewFn.lookupOrNull(originst); -} - -Operation *mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - Operation *originst) const { - auto found = originalToNewFnOps.find(originst); - if (found == originalToNewFnOps.end()) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - for (auto &pair : 
originalToNewFnOps) { - llvm::errs() << " map[" << pair.first << "] = " << pair.second << "\n"; - } - llvm::errs() << originst << " - " << *originst << "\n"; - llvm_unreachable("Could not get new op from original"); - } - return found->second; -} - Operation * mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, Operation *op) { @@ -182,24 +146,6 @@ mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, return B.clone(*op, map); } -bool mlir::enzyme::MGradientUtilsReverse::isConstantInstruction( - Operation *op) const { - return false; -} - -bool mlir::enzyme::MGradientUtilsReverse::isConstantValue(Value v) const { - if (isa(v.getType())) - return true; - if (isa(v.getType())) - return true; - - if (matchPattern(v, m_Constant())) - return true; - - // TODO - return false; -} - bool mlir::enzyme::MGradientUtilsReverse::requiresShadow(Type t) { if (auto iface = dyn_cast(t)) { return iface.requiresShadow(); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index 474badb034c9..3d5d4d147269 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -15,10 +15,12 @@ #include +#include "GradientUtils.h" + namespace mlir { namespace enzyme { -class MGradientUtilsReverse { +class MGradientUtilsReverse : public MDiffeGradientUtils { public: MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, @@ -32,13 +34,6 @@ class MGradientUtilsReverse { DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_); - // From CacheUtility - FunctionOpInterface newFunc; - FunctionOpInterface oldFunc; - - MEnzymeLogic &Logic; - bool AtomicAdd; - DerivativeMode mode; IRMapping invertedPointersGlobal; IRMapping invertedPointersShadow; IRMapping shadowValues; @@ -47,26 +42,8 @@ class MGradientUtilsReverse { IRMapping mapReverseModeBlocks; 
DenseMap>> mapBlockArguments; - IRMapping originalToNewFn; - std::map originalToNewFnOps; - - MTypeAnalysis &TA; - - unsigned width; - ArrayRef ArgDiffeTypes; - SymbolTableCollection &symbolTable; - mlir::Value getNewFromOriginal(const mlir::Value originst) const; - mlir::Block *getNewFromOriginal(mlir::Block *originst) const; - Operation *getNewFromOriginal(Operation *originst) const; - - void erase(Operation *op) { op->erase(); } - void eraseIfUnused(Operation *op, bool erase = true, bool check = true) { - // TODO - } - bool isConstantValue(mlir::Value v) const; - bool isConstantInstruction(mlir::Operation *v) const; bool hasInvertPointer(mlir::Value v); mlir::Value invertPointerM(mlir::Value v, OpBuilder &builder); mlir::Value diffe(mlir::Value v, OpBuilder &builder); From 724756e5d770f3cb7cee6dc17ea292443a55eb05 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 31 Jan 2024 19:55:18 +0100 Subject: [PATCH 034/131] Reimplement MLIR tangent for scf.for generically (#1649) * Reimplement MLIR tangent for scf.for generically Introduce an interface for ops with region control flow and use that interface inside the tangent implementation for scf.for. This interface can also work for other kinds of control flow ops as long as they don't transfer control flow between regions (e.g. like scf.while does). Further cleanup is necessary to clearly delegate implementation of AutoDiffOpInterface to this new interface for any compatible op via traits. * Generalize * fixup * Decrease the amount of tablegen'd code * Add some documentation and use better types * add forward diff support for affine.for and affine.if * bazel * add forward diff rule for scf.while Fix a minor bug discovered in the process: the interface was designed for constant folding and expects a list of null attributes of the appropriate size... * add executeop * add cf * fixup --------- Co-authored-by: William S. 
Moses --- enzyme/BUILD | 32 +++ .../AffineAutoDiffOpInterfaceImpl.cpp | 64 +++++ .../MLIR/Implementations/AffineDerivatives.td | 24 ++ .../CFAutoDiffOpInterfaceImpl.cpp | 41 +++ .../MLIR/Implementations/CFDerivatives.td | 4 + .../MLIR/Implementations/CMakeLists.txt | 17 ++ enzyme/Enzyme/MLIR/Implementations/Common.td | 16 ++ .../CoreDialectsAutoDiffImplementations.cpp | 246 ++++++++++++++++++ .../CoreDialectsAutoDiffImplementations.h | 87 +++++++ .../SCFAutoDiffOpInterfaceImpl.cpp | 94 +------ .../MLIR/Implementations/SCFDerivatives.td | 50 ++++ .../MLIR/Interfaces/AutoDiffOpInterface.td | 36 +++ enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 37 +-- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 4 +- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 8 + enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 2 + enzyme/test/MLIR/ForwardMode/affine.mlir | 104 ++++++++ .../branch-self-recursive.mlir | 0 .../test/MLIR/{ => ForwardMode}/branch.mlir | 0 enzyme/test/MLIR/ForwardMode/executeop.mlir | 60 +++++ enzyme/test/MLIR/{ => ForwardMode}/for.mlir | 0 enzyme/test/MLIR/ForwardMode/for2.mlir | 30 +++ enzyme/test/MLIR/ForwardMode/if1.mlir | 41 +++ .../test/MLIR/{ => ForwardMode}/inactive.mlir | 0 .../test/MLIR/{ => ForwardMode}/invalid.mlir | 0 enzyme/test/MLIR/{ => ForwardMode}/llvm.mlir | 0 .../test/MLIR/{ => ForwardMode}/memref.mlir | 0 enzyme/test/MLIR/{ => ForwardMode}/test.mlir | 0 enzyme/test/MLIR/ForwardMode/while.mlir | 44 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 39 +++ 30 files changed, 957 insertions(+), 123 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td create mode 100644 enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td create mode 100644 enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp create mode 100644 
enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td create mode 100644 enzyme/test/MLIR/ForwardMode/affine.mlir rename enzyme/test/MLIR/{ => ForwardMode}/branch-self-recursive.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/branch.mlir (100%) create mode 100644 enzyme/test/MLIR/ForwardMode/executeop.mlir rename enzyme/test/MLIR/{ => ForwardMode}/for.mlir (100%) create mode 100644 enzyme/test/MLIR/ForwardMode/for2.mlir create mode 100644 enzyme/test/MLIR/ForwardMode/if1.mlir rename enzyme/test/MLIR/{ => ForwardMode}/inactive.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/invalid.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/llvm.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/memref.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/test.mlir (100%) create mode 100644 enzyme/test/MLIR/ForwardMode/while.mlir diff --git a/enzyme/BUILD b/enzyme/BUILD index 8954d5433e0a..f091b06bd8f5 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -384,6 +384,22 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) +gentbl( + name = "affine-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/AffineDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/AffineDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/AffineDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + gentbl( name = "arith-derivatives", tbl_outs = [( @@ -426,6 +442,20 @@ gentbl( ], ) +gentbl( + name = "scf-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/SCFDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/SCFDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/SCFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -447,9 +477,11 @@ cc_library( includes = ["Enzyme/MLIR", 
"Enzyme"], visibility = ["//visibility:public"], deps = [ + ":affine-derivatives", ":arith-derivatives", ":llvm-derivatives", ":nvvm-derivatives", + ":scf-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..c27f0d60d129 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,64 @@ +//===- AffineAutoDiffOpInterfaceImpl.cpp - Interface external model -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR Affine dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/IntegerSet.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +affine::AffineForOp +createAffineForWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, TypeRange rettys) { + affine::AffineForOpAdaptor adaptor(remappedOperands, + cast(original)); + auto repFor = builder.create( + original->getLoc(), adaptor.getLowerBoundOperands(), + adaptor.getLowerBoundMap(), adaptor.getUpperBoundOperands(), + adaptor.getUpperBoundMap(), adaptor.getStep().getZExtValue(), + // This dance is necessary because the adaptor accessors are based on the + // internal attribute containing the number of operands associated with + // each named operand group. This attribute is carried over from the + // original operation and does not account for the shadow-related iter + // args. Instead, assume lower/upper bound operands must not have shadows + // since they are integer-typed and take the result of operands as iter + // args. 
+ remappedOperands.drop_front(adaptor.getLowerBoundOperands().size() + + adaptor.getUpperBoundOperands().size())); + return repFor; +} + +affine::AffineIfOp createAffineIfWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, + affine::AffineIfOp original, + ValueRange remappedOperands, + TypeRange rettys) { + affine::AffineIfOpAdaptor adaptor(remappedOperands, original); + return builder.create( + original->getLoc(), rettys, original.getIntegerSet(), + adaptor.getOperands(), !original.getElseRegion().empty()); +} + +#include "Implementations/AffineDerivatives.inc" +} // namespace + +void mlir::enzyme::registerAffineDialectAutoDiffInterface( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *context, affine::AffineDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td new file mode 100644 index 000000000000..d6e866f3089e --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td @@ -0,0 +1,24 @@ +include "Common.td" + +def : ControlFlowOp<"affine", "AffineForOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return createAffineForWithShadows(op, builder, gutils, original, + remappedOperands, rettys); + } +}]>; + +def : ControlFlowOp<"affine", "AffineIfOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return createAffineIfWithShadows(op, builder, gutils, + cast(original), + remappedOperands, rettys); + } +}]>; + +def : RegionTerminatorOp<"affine", "AffineYieldOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp new file mode 100644 index
000000000000..8f40db9d834d --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,41 @@ +//===- CFAutoDiffOpInterfaceImpl.cpp - Interface external model -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR ControlFlow dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/EnzymeLogic.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/CFDerivatives.inc" +} // namespace + +void mlir::enzyme::registerCFDialectAutoDiffInterface( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *context, cf::ControlFlowDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td new file mode 100644 index 000000000000..8b4f41696a9d --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td @@ -0,0 +1,4 @@ +include "Common.td" + +def : BranchOp<"cf", "BranchOp">; +def : BranchOp<"cf", "SwitchOp">;
diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index 521ba76c22bc..8e62f5849be1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -1,3 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS AffineDerivatives.td) +enzyme_tablegen(AffineDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(AffineDerivativesIncGen) set(LLVM_TARGET_DEFINITIONS ArithDerivatives.td) enzyme_tablegen(ArithDerivatives.inc -gen-mlir-derivatives) @@ -11,20 +14,34 @@ set(LLVM_TARGET_DEFINITIONS NVVMDerivatives.td) enzyme_tablegen(NVVMDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(NVVMDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS SCFDerivatives.td) +enzyme_tablegen(SCFDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(SCFDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS CFDerivatives.td) +enzyme_tablegen(CFDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(CFDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations + AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp + CoreDialectsAutoDiffImplementations.cpp LLVMAutoDiffOpInterfaceImpl.cpp NVVMAutoDiffOpInterfaceImpl.cpp MemRefAutoDiffOpInterfaceImpl.cpp LinalgAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp SCFAutoDiffOpInterfaceImpl.cpp + CFAutoDiffOpInterfaceImpl.cpp DEPENDS MLIRAutoDiffOpInterfaceIncGen + AffineDerivativesIncGen ArithDerivativesIncGen LLVMDerivativesIncGen NVVMDerivativesIncGen + SCFDerivativesIncGen + CFDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 3909405320d1..c09cac32f734 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -3,6 +3,22 @@ class InactiveOp { string opName = opName_; } +class ControlFlowOp { + string dialect = 
dialect_; + string opName = opName_; + string impl = impl_; +} + +class BranchOp { + string dialect = dialect_; + string opName = opName_; +} + +class RegionTerminatorOp { + string dialect = dialect_; + string opName = opName_; +} + class MLIRDerivative resultOps> { string dialect = dialect_; string opName = opName_; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp new file mode 100644 index 000000000000..2a50d1c481c0 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -0,0 +1,246 @@ +//===- CoreDialectsAutoDiffImplementations.cpp ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains common utilities for the external model implementation of +// the automatic differentiation op interfaces for upstream MLIR dialects. 
+// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" + +using namespace mlir; +using namespace mlir::enzyme; + +void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, + OpBuilder &builder, + MGradientUtils *gutils) { + auto newInst = gutils->getNewFromOriginal(inst); + + auto binst = cast(inst); + + // TODO generalize to cloneWithNewBlockArgs interface + SmallVector newVals; + + SmallVector segSizes; + // Keep non-differentiated, non-forwarded operands + size_t non_forwarded = 0; + for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { + auto ops = binst.getSuccessorOperands(i).getForwardedOperands(); + if (ops.empty()) + continue; + non_forwarded = ops.getBeginOperandIndex(); + break; + } + + for (size_t i = 0; i < non_forwarded; i++) + newVals.push_back(gutils->getNewFromOriginal(binst->getOperand(i))); + + segSizes.push_back(newVals.size()); + for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { + size_t cur = newVals.size(); + auto ops = binst.getSuccessorOperands(i).getForwardedOperands(); + for (auto &&[idx, op] : llvm::enumerate(ops)) { + auto arg = + *binst.getSuccessorBlockArgument(ops.getBeginOperandIndex() + idx); + newVals.push_back(gutils->getNewFromOriginal(op)); + if (!gutils->isConstantValue(arg)) { + if (!gutils->isConstantValue(op)) { + newVals.push_back(gutils->invertPointerM(op, builder)); + } else { + Type retTy = + arg.getType().cast().getShadowType(); + auto toret = retTy.cast().createNullValue( + builder, op.getLoc()); + newVals.push_back(toret); + } + } + } + cur = newVals.size() - cur; + segSizes.push_back(cur); + } + + SmallVector attrs(newInst->getAttrs()); + bool has_cases = false; + for (auto &attr : attrs) { + if (attr.getName() == "case_operand_segments") { + has_cases = true; + } + 
} + for (auto &attr : attrs) { + if (attr.getName() == "operandSegmentSizes") { + if (!has_cases) { + attr.setValue(builder.getDenseI32ArrayAttr(segSizes)); + } else { + SmallVector segSlices2(segSizes.begin(), segSizes.begin() + 2); + segSlices2.push_back(0); + for (size_t i = 2; i < segSizes.size(); i++) + segSlices2[2] += segSizes[i]; + attr.setValue(builder.getDenseI32ArrayAttr(segSlices2)); + } + } + if (attr.getName() == "case_operand_segments") { + SmallVector segSlices2(segSizes.begin() + 2, segSizes.end()); + attr.setValue(builder.getDenseI32ArrayAttr(segSlices2)); + } + } + + gutils->getNewFromOriginal(inst->getBlock()) + ->push_back( + newInst->create(newInst->getLoc(), newInst->getName(), TypeRange(), + newVals, attrs, OpaqueProperties(nullptr), + newInst->getSuccessors(), newInst->getNumRegions())); + gutils->erase(newInst); + return; +} + +void mlir::enzyme::detail::regionTerminatorForwardHandler( + Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { + auto termIface = cast(origTerminator); + auto parentOp = termIface->getParentOp(); + + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + llvm::SmallDenseSet operandsToShadow; + for (auto &successor : successors) { + OperandRange operandRange = termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? 
parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[i, target] : llvm::enumerate(targetValues)) { + if (!gutils->isConstantValue(target)) + operandsToShadow.insert(operandRange.getBeginOperandIndex() + i); + } + } + SmallVector newOperands; + newOperands.reserve(termIface->getNumOperands() + operandsToShadow.size()); + for (OpOperand &operand : termIface->getOpOperands()) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + if (operandsToShadow.contains(operand.getOperandNumber())) + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + + // Assuming shadows following the originals are fine. + // TODO: consider extending to have a ShadowableTerminatorOpInterface + Operation *replTerminator = gutils->getNewFromOriginal(origTerminator); + Operation *newTerminator = builder.clone(*replTerminator); + newTerminator->setOperands(newOperands); + gutils->erase(replTerminator); +} + +LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils) { + // For all active results, add shadow types. + // For now, assuming all results are relevant. + Operation *newOp = gutils->getNewFromOriginal(op); + SmallVector newOpResultTypes; + newOpResultTypes.reserve(op->getNumResults() * 2); + for (Value result : op->getResults()) { + // TODO only if used (can we DCE the primal after having done the + // derivative). + newOpResultTypes.push_back(result.getType()); + if (gutils->isConstantValue(result)) + continue; + auto typeIface = dyn_cast(result.getType()); + if (!typeIface) + return failure(); + newOpResultTypes.push_back(typeIface.getShadowType()); + } + + // For all operands that are forwarded to the body, if they are active, also + // add the shadow as operand. 
+ auto regionBranchOp = dyn_cast(op); + if (!regionBranchOp) + return failure(); + + SmallVector successors; + // TODO: we may need to record, for every successor, which of its inputs + // need a shadow to recreate the body correctly. + llvm::SmallDenseSet operandPositionsToShadow; + regionBranchOp.getEntrySuccessorRegions( + SmallVector(op->getNumOperands(), Attribute()), successors); + for (const RegionSuccessor &successor : successors) { + if (!successor.isParent() && successor.getSuccessor()->empty()) + continue; + + OperandRange operandRange = + regionBranchOp.getEntrySuccessorOperands(successor); + + // Need to know which of the arguments are being forwarded to from + // operands. + for (auto &&[i, regionValue, operand] : + llvm::enumerate(successor.getSuccessorInputs(), operandRange)) { + if (gutils->isConstantValue(regionValue)) + continue; + operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i); + } + } + SmallVector newOperands; + newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size()); + for (OpOperand &operand : op->getOpOperands()) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + if (operandPositionsToShadow.contains(operand.getOperandNumber())) + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + // We are assuming the op can forward additional operands, listed + // immediately after the original operands, to the same regions. + // ^^ + // Our interface guarantees this. + // We also assume that the region-holding op returns all of the values + // yielded by terminators, and only those values. 
+ + auto iface = dyn_cast(op); + if (!iface) + return failure(); + Operation *replacement = iface.createWithShadows( + builder, gutils, op, newOperands, newOpResultTypes); + for (auto &&[region, replacementRegion] : + llvm::zip(newOp->getRegions(), replacement->getRegions())) { + replacementRegion.takeBody(region); + } + + // Inject the mapping for the new results into GradientUtil's shadow + // table. + SmallVector reps; + size_t idx = 0; + for (Value r : op->getResults()) { + // TODO only if used + reps.push_back(replacement->getResult(idx)); + idx++; + if (!gutils->isConstantValue(r)) { + auto inverted = gutils->invertedPointers.lookupOrNull(r); + assert(inverted); + gutils->invertedPointers.map(r, replacement->getResult(idx)); + inverted.replaceAllUsesWith(replacement->getResult(idx)); + gutils->erase(inverted.getDefiningOp()); + idx++; + } + } + + // Differentiate body. + for (auto &origRegion : op->getRegions()) { + for (auto &origBlock : origRegion) { + for (Operation &o : origBlock) { + if (failed(gutils->visitChild(&o))) + return failure(); + } + } + } + + // Replace all uses of original results + gutils->replaceOrigOpWith(op, reps); + gutils->erase(newOp); + + return success(); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 56af04b30133..b98a6ef10c39 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -12,16 +12,103 @@ // //===----------------------------------------------------------------------===// +#include "Interfaces/AutoDiffOpInterface.h" +#include "mlir/Support/LogicalResult.h" + namespace mlir { class DialectRegistry; +class Operation; +class OpBuilder; namespace enzyme { +class MGradientUtils; + +namespace detail { +// Non-template implementation of +// AutoDiffUsingControlFlow::createForwardModeTangent. 
+LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +// Implements forward-mode differentiation of branching operations. +// Assumes that successive shadows are legal +void branchingForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +// Implements forward-mode differentiation of region-terminator operations. +// Assumes that successive shadows are legal +void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +// Implements the forward autodiff interface for operations whose derivatives +// can be inferred by analyzing their control flow and differentiating the +// nested operations. +template +class AutoDiffUsingControlFlow + : public AutoDiffOpInterface::ExternalModel, + OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + return controlFlowForwardHandler(op, builder, gutils); + } +}; + +// Implements the forward autodiff interface for operations whose derivatives +// can be inferred by analyzing their branching properties. +template +class AutoDiffUsingBranch + : public AutoDiffOpInterface::ExternalModel, + OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + branchingForwardHandler(op, builder, gutils); + return success(); + } +}; + +// Implements the forward autodiff interface for operations whose derivatives +// can be inferred by analyzing their region terminator properties. 
+template +class AutoDiffUsingRegionTerminator + : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingRegionTerminator, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + regionTerminatorForwardHandler(op, builder, gutils); + return success(); + } +}; +} // namespace detail + +// Registers AutoDiffUsingControlFlow for the given op. +template +void registerAutoDiffUsingControlFlowInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); +} +// Registers AutoDiffUsingBranch for the given op. +template +void registerAutoDiffUsingBranchInterface(MLIRContext &context) { + OpTy::template attachInterface>(context); +} +// Registers AutoDiffUsingRegionTerminator for the given op. +template +void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); +} + +// Interface registration hooks for individual upstream dialects. +void registerAffineDialectAutoDiffInterface(DialectRegistry &registry); void registerArithDialectAutoDiffInterface(DialectRegistry &registry); void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry); void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry); void registerNVVMDialectAutoDiffInterface(DialectRegistry &registry); void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry); void registerSCFDialectAutoDiffInterface(DialectRegistry &registry); +void registerCFDialectAutoDiffInterface(DialectRegistry &registry); void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry); } // namespace enzyme } // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 71d3db0c2905..bab415ff247f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -19,99 +19,18 @@ 
#include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include using namespace mlir; using namespace mlir::enzyme; namespace { -struct ForOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto forOp = cast(op); - auto nFor = cast(gutils->getNewFromOriginal(op)); - // Get a list of all the return types, which is the original return types - // alongside any shadow return types - SmallVector nTypes; - for (auto r : forOp->getResults()) { - // TODO only if used - nTypes.push_back(r.getType()); - if (!gutils->isConstantValue(r)) { - auto adTypeIface = r.getType().dyn_cast(); - if (!adTypeIface) - return failure(); - nTypes.push_back(adTypeIface.getShadowType()); - } - } - - // Get a list of all args, which is original args, and any shadows - SmallVector nArgs; - for (const auto &[initVal, iterArg] : - llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs())) { - // TODO only if used - nArgs.push_back(gutils->getNewFromOriginal(initVal)); - if (!gutils->isConstantValue(iterArg)) - nArgs.push_back(gutils->invertPointerM(initVal, builder)); - } - - // Create the new modified for loop - auto repFor = builder.create( - forOp.getLoc(), gutils->getNewFromOriginal(forOp.getLowerBound()), - gutils->getNewFromOriginal(forOp.getUpperBound()), - gutils->getNewFromOriginal(forOp.getStep()), nArgs); - repFor.getRegion().takeBody(nFor.getRegion()); - - // Inject the mapping for the new results into GradientUtil's shadow - // table - SmallVector reps; - size_t idx = 0; - for (Value r : forOp.getResults()) { - // TODO only if used - reps.push_back(repFor.getResult(idx)); - idx++; - if (!gutils->isConstantValue(r)) { - 
auto inverted = gutils->invertedPointers.lookupOrNull(r); - assert(inverted); - gutils->invertedPointers.map(r, repFor.getResult(idx)); - inverted.replaceAllUsesWith(repFor.getResult(idx)); - gutils->erase(inverted.getDefiningOp()); - idx++; - } - } - - // Replace all uses of original results - nFor.replaceAllUsesWith(reps); - gutils->erase(nFor); - - // differentiate body - for (Operation &o : - llvm::make_early_inc_range(forOp.getBody()->without_terminator())) { - if (failed(gutils->visitChild(&o))) - return failure(); - } - - // Fix terminator (yield) operations - Operation *oldYield = repFor.getBody()->getTerminator(); - builder.setInsertionPointToEnd(repFor.getBody()); - SmallVector nYields; - for (const auto &[result, yieldOperand] : - llvm::zip(forOp.getResults(), - forOp.getBody()->getTerminator()->getOperands())) { - // TODO only if used - nYields.push_back(gutils->getNewFromOriginal(yieldOperand)); - if (!gutils->isConstantValue(result)) - nYields.push_back(gutils->invertPointerM(yieldOperand, builder)); - } - Operation *newYield = builder.clone(*oldYield); - newYield->setOperands(nYields); - gutils->erase(oldYield); - - // Done - return success(); - } -}; +#include "Implementations/SCFDerivatives.inc" struct ForOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel(*context); - + registerInterfaces(context); scf::ForOp::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td new file mode 100644 index 000000000000..4c9ee09abcd2 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td @@ -0,0 +1,50 @@ +include "Common.td" + +def : ControlFlowOp<"scf", "ForOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + scf::ForOpAdaptor adaptor(remappedOperands); + auto repFor = builder.create( + op->getLoc(), 
adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), adaptor.getInitArgs()); + return repFor; + } +}]>; + +def : ControlFlowOp<"scf", "IfOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + scf::IfOpAdaptor adaptor(remappedOperands); + auto repIf = builder.create( + op->getLoc(), rettys, adaptor.getCondition()); + return repIf; + } +}]>; + +def : ControlFlowOp<"scf", "WhileOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return builder.create(original->getLoc(), rettys, + remappedOperands, original->getAttrs()); + } +}]>; + +def : ControlFlowOp<"scf", "ExecuteRegionOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + auto repIf = builder.create( + op->getLoc(), rettys); + return repIf; + } +}]>; + +def : RegionTerminatorOp<"scf", "YieldOp">; +def : RegionTerminatorOp<"scf", "ConditionOp">; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index f45b641f9cb8..5d764c8dc7bb 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -40,6 +40,42 @@ def AutoDiffOpInterface : OpInterface<"AutoDiffOpInterface"> { ]; } +def ControlFlowAutoDiffOpInterface + : OpInterface<"ControlFlowAutoDiffOpInterface"> { + let description = [{ + A differentiable MLIR operation whose forward differentiation rules are + driven by how control flows through the operation. 
+ + There are two general assumptions: + - the operation can communicate additional values along the control flow + edges, which is used to put shadow values immediately after the primal + values; + - all values returned by the operation are yielded by all region-exiting + terminators. + }]; + let cppNamespace = "::mlir::enzyme"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Creates a copy of this operation additionally carrying required shadow + values along control flow edges using the given builder. The `original` + is the operation in the original primal code prior to differentiation, + and this method is supposed to be called on the operation in the cloned + function being constructed. Remapped operands contains a flat list of + operands usable in the cloned function and can be fed to the Adaptor + constructor. + }], + /*retTy=*/"::mlir::Operation *", + /*methodName=*/"createWithShadows", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::enzyme::MGradientUtils *":$gutils, + "::mlir::Operation *":$original, + "::mlir::ValueRange":$remappedOperands, + "::mlir::TypeRange":$returnTypes) + > + ]; +} + def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> { let description = [{ A differentiable MLIR operation that is able to emit reverse mode adjoints for itself. 
diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index a92a5f3386ff..1da6904893f8 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -1,4 +1,5 @@ #include "Dialect/Ops.h" +#include "Implementations/CoreDialectsAutoDiffImplementations.h" #include "Interfaces/AutoDiffOpInterface.h" #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" @@ -21,7 +22,7 @@ using namespace mlir; using namespace mlir::enzyme; -void createTerminator(MDiffeGradientUtils *gutils, mlir::Block *oBB, +void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, DIFFE_TYPE retType, ReturnType retVal) { auto inst = oBB->getTerminator(); @@ -33,39 +34,7 @@ void createTerminator(MDiffeGradientUtils *gutils, mlir::Block *oBB, nBuilder.setInsertionPointToEnd(nBB); if (auto binst = dyn_cast(inst)) { - // TODO generalize to cloneWithNewBlockArgs interface - SmallVector newVals; - - SmallVector segSizes; - for (size_t i = 0, len = binst.getSuccessorOperands(0) - .getForwardedOperands() - .getBeginOperandIndex(); - i < len; i++) - newVals.push_back(gutils->getNewFromOriginal(binst->getOperand(i))); - segSizes.push_back(newVals.size()); - for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { - size_t cur = newVals.size(); - for (auto op : binst.getSuccessorOperands(i).getForwardedOperands()) { - newVals.push_back(gutils->getNewFromOriginal(op)); - if (!gutils->isConstantValue(op)) { - newVals.push_back(gutils->invertPointerM(op, nBuilder)); - } - } - cur = newVals.size() - cur; - segSizes.push_back(cur); - } - - SmallVector attrs(newInst->getAttrs()); - for (auto &attr : attrs) { - if (attr.getName() == "operandSegmentSizes") - attr.setValue(nBuilder.getDenseI32ArrayAttr(segSizes)); - } - - nBB->push_back( - newInst->create(newInst->getLoc(), newInst->getName(), TypeRange(), - newVals, attrs, OpaqueProperties(nullptr), - newInst->getSuccessors(), 
newInst->getNumRegions())); - gutils->erase(newInst); + mlir::enzyme::detail::branchingForwardHandler(inst, nBuilder, gutils); return; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 286456b2d039..20ed156d312d 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -306,7 +306,9 @@ LogicalResult MGradientUtils::visitChild(Operation *op) { // In absence of a proper activity analysis, approximate it by treating any // side effect-free operation producing constants as inactive. // if (auto iface = dyn_cast(op)) { - if (llvm::all_of(op->getResults(), + if (!isa(op) && + !isa(op) && + llvm::all_of(op->getResults(), [this](Value v) { return isConstantValue(v); }) && /*iface.hasNoEffect()*/ activityAnalyzer->isConstantOperation(TR, op)) { return success(); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 094228232346..7ef99a1014fd 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -54,6 +54,14 @@ class MGradientUtils { std::map &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp); void erase(Operation *op) { op->erase(); } + void replaceOrigOpWith(Operation *op, ValueRange vals) { + for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) { + originalToNewFn.map(res, rep); + } + auto newOp = getNewFromOriginal(op); + newOp->replaceAllUsesWith(vals); + originalToNewFnOps.erase(op); + } void eraseIfUnused(Operation *op, bool erase = true, bool check = true) { // TODO } diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index b6bb33c51df9..e7af0b01267f 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -97,12 +97,14 @@ int main(int argc, char **argv) { }); // Register the autodiff interface implementations for 
upstream dialects. + enzyme::registerAffineDialectAutoDiffInterface(registry); enzyme::registerArithDialectAutoDiffInterface(registry); enzyme::registerBuiltinDialectAutoDiffInterface(registry); enzyme::registerLLVMDialectAutoDiffInterface(registry); enzyme::registerNVVMDialectAutoDiffInterface(registry); enzyme::registerMemRefDialectAutoDiffInterface(registry); enzyme::registerSCFDialectAutoDiffInterface(registry); + enzyme::registerCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( diff --git a/enzyme/test/MLIR/ForwardMode/affine.mlir b/enzyme/test/MLIR/ForwardMode/affine.mlir new file mode 100644 index 000000000000..a091d6d774ac --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/affine.mlir @@ -0,0 +1,104 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @loop(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %r = affine.for %arg1 = 0 to 10 step 1 iter_args(%arg2 = %cst) -> (f64) { + %n = arith.addf %arg2, %x : f64 + affine.yield %n : f64 + } + return %r : f64 + } + func.func @dloop(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @loop(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeloop + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) + // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, f64) { + // CHECK: %[[v1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 + // CHECK: %[[v2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 + // CHECK: affine.yield %[[v2]], %[[v1]] : f64, f64 + // CHECK: } + // CHECK: return %[[r0]]#1 : f64 + + func.func @if_then_else(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { 
+ %mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dif_then_else(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @if_then_else(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i1) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeif_then_else + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) + // CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, f64, f64) { + // CHECK: %[[v3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v5:.+]] = arith.addf %[[v3]], %[[v4]] : f64 + // CHECK: %[[v6:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v6]], %[[v5]], %[[cst]] : f64, f64, f64 + // CHECK: } else { + // CHECK: %[[v3:.+]] = arith.addf %[[arg1]], %[[arg1]] : f64 + // CHECK: %[[v4:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v4]], %[[v3]], %[[cst_0]] : f64, f64, f64 + // CHECK: } + // CHECK: %[[v1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 + // CHECK: %[[v2:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 + // CHECK: return %[[v1]] : f64 + + func.func @if_then(%x : f64, %c : i1) -> f64 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %mem = memref.alloc() : memref<1xf64> + memref.store %c2, %mem[%c0] : memref<1xf64> + scf.if %c { + %mul = arith.mulf %x, %x : f64 + memref.store %mul, %mem[%c0] : memref<1xf64> + } + %r = memref.load %mem[%c0] : memref<1xf64> + %res = arith.mulf %c2, %r : f64 + return %res : f64 + } + func.func @dif_then(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @if_then(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, 
f64, i1) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeif_then + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) -> f64 { + // CHECK: %[[c0:.+]] = arith.constant 0 : index + // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[cst_1:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[alloc:.+]] = memref.alloc() : memref<1xf64> + // CHECK: %[[alloc_2:.+]] = memref.alloc() : memref<1xf64> + // CHECK: memref.store %[[cst]], %[[alloc]][%[[c0]]] : memref<1xf64> + // CHECK: memref.store %[[cst_0]], %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: scf.if %[[arg2]] { + // CHECK: %[[v4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v5:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v6:.+]] = arith.addf %[[v4]], %[[v5]] : f64 + // CHECK: %[[v7:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 + // CHECK: memref.store %[[v6]], %[[alloc]][%[[c0]]] : memref<1xf64> + // CHECK: memref.store %[[v7]], %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: } + // CHECK: %[[v0:.+]] = memref.load %[[alloc]][%[[c0]]] : memref<1xf64> + // CHECK: %[[v1:.+]] = memref.load %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst_0]] : f64 + // CHECK: %[[v3:.+]] = arith.mulf %[[cst_0]], %[[v1]] : f64 + // CHECK: return %[[v2]] : f64 +} diff --git a/enzyme/test/MLIR/branch-self-recursive.mlir b/enzyme/test/MLIR/ForwardMode/branch-self-recursive.mlir similarity index 100% rename from enzyme/test/MLIR/branch-self-recursive.mlir rename to enzyme/test/MLIR/ForwardMode/branch-self-recursive.mlir diff --git a/enzyme/test/MLIR/branch.mlir b/enzyme/test/MLIR/ForwardMode/branch.mlir similarity index 100% rename from enzyme/test/MLIR/branch.mlir rename to enzyme/test/MLIR/ForwardMode/branch.mlir diff --git a/enzyme/test/MLIR/ForwardMode/executeop.mlir b/enzyme/test/MLIR/ForwardMode/executeop.mlir new file mode 100644 index 
000000000000..696b07b5f7ea --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/executeop.mlir @@ -0,0 +1,60 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i32) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.execute_region -> (f64, f64) { + cf.switch %c : i32, [ + default: ^bb5, + 104: ^bb3, + 113: ^bb4(%c10 : f64) + ] + ^bb4(%z : f64): // pred: ^bb2 + %x2 = arith.mulf %x, %x : f64 + scf.yield %x2, %z : f64, f64 + ^bb3: + %x3 = arith.addf %x, %x : f64 + scf.yield %x3, %c2 : f64, f64 + ^bb5: + cf.br ^bb4(%x : f64) + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dsq(%x : f64, %dx : f64, %c : i32) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i32) -> (f64) + return %r : f64 + } +} + +// CHECK: func.func private @fwddiffesquare(%[[x:.+]]: f64, %[[dx:.+]]: f64, %[[c:.+]]: i32) -> f64 { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-DAG: %[[cst0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[r0:.+]]:4 = scf.execute_region -> (f64, f64, f64, f64) { +// CHECK-NEXT: %[[cst02:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: cf.switch %[[c]] : i32, [ +// CHECK-NEXT: default: ^bb3, +// CHECK-NEXT: 104: ^bb2, +// CHECK-NEXT: 113: ^bb1(%[[cst10]], %[[cst02]] : f64, f64) +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1(%[[a3:.+]]: f64, %[[da3:.+]]: f64): // 2 preds: ^bb0, ^bb3 +// CHECK-NEXT: %[[a4:.+]] = arith.mulf %[[dx]], %[[x]] : f64 +// CHECK-NEXT: %[[a5:.+]] = arith.mulf %[[dx]], %[[x]] : f64 +// CHECK-NEXT: %[[a6:.+]] = arith.addf %[[a4]], %[[a5]] : f64 +// CHECK-NEXT: %[[a7:.+]] = arith.mulf %[[x]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[a7]], %[[a6]], %[[a3]], %[[da3]] : f64, f64, f64, f64 +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: %[[b8:.+]] = arith.addf %[[dx]], %[[dx]] : f64 
+// CHECK-NEXT: %[[b9:.+]] = arith.addf %[[x]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[b9]], %[[b8]], %[[cst2]], %[[cst0]] : f64, f64, f64, f64 +// CHECK-NEXT: ^bb3: // pred: ^bb0 +// CHECK-NEXT: cf.br ^bb1(%[[x]], %[[dx]] : f64, f64) +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 +// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#3, %[[r0]]#0 : f64 +// CHECK-NEXT: %[[r3:.+]] = arith.addf %[[r1]], %[[r2]] : f64 +// CHECK-NEXT: %[[r4:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r3]] : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/for.mlir b/enzyme/test/MLIR/ForwardMode/for.mlir similarity index 100% rename from enzyme/test/MLIR/for.mlir rename to enzyme/test/MLIR/ForwardMode/for.mlir diff --git a/enzyme/test/MLIR/ForwardMode/for2.mlir b/enzyme/test/MLIR/ForwardMode/for2.mlir new file mode 100644 index 000000000000..7d4e7608b98e --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/for2.mlir @@ -0,0 +1,30 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %x) -> (f64) { + %n = arith.addf %arg2, %x : f64 + scf.yield %n : f64 + } + return %r : f64 + } + func.func @dsq(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[arg0]], %[[arg4:.+]] = %[[arg1]]) -> (f64, f64) { +// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 +// CHECK-NEXT: 
%[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[i0]]#1 : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/if1.mlir b/enzyme/test/MLIR/ForwardMode/if1.mlir new file mode 100644 index 000000000000..3187bac56fa7 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/if1.mlir @@ -0,0 +1,41 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { + %mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dsq(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i1) -> (f64) + return %r : f64 + } +} + + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) -> f64 { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, f64, f64) { +// CHECK-NEXT: %[[t3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[t4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[t5:.+]] = arith.addf %[[t3]], %[[t4]] : f64 +// CHECK-NEXT: %[[t6:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[t6]], %[[t5]], %[[cst2]] : f64, f64, f64 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[e3:.+]] = arith.addf %arg1, %arg1 : f64 +// CHECK-NEXT: %[[e4:.+]] = arith.addf %arg0, %arg0 : f64 +// CHECK-NEXT: scf.yield %[[e4]], %[[e3]], %[[cst10]] : f64, f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 +// CHECK-NEXT: %[[r2:.+]] = arith.mulf 
%[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r1]] : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/inactive.mlir b/enzyme/test/MLIR/ForwardMode/inactive.mlir similarity index 100% rename from enzyme/test/MLIR/inactive.mlir rename to enzyme/test/MLIR/ForwardMode/inactive.mlir diff --git a/enzyme/test/MLIR/invalid.mlir b/enzyme/test/MLIR/ForwardMode/invalid.mlir similarity index 100% rename from enzyme/test/MLIR/invalid.mlir rename to enzyme/test/MLIR/ForwardMode/invalid.mlir diff --git a/enzyme/test/MLIR/llvm.mlir b/enzyme/test/MLIR/ForwardMode/llvm.mlir similarity index 100% rename from enzyme/test/MLIR/llvm.mlir rename to enzyme/test/MLIR/ForwardMode/llvm.mlir diff --git a/enzyme/test/MLIR/memref.mlir b/enzyme/test/MLIR/ForwardMode/memref.mlir similarity index 100% rename from enzyme/test/MLIR/memref.mlir rename to enzyme/test/MLIR/ForwardMode/memref.mlir diff --git a/enzyme/test/MLIR/test.mlir b/enzyme/test/MLIR/ForwardMode/test.mlir similarity index 100% rename from enzyme/test/MLIR/test.mlir rename to enzyme/test/MLIR/ForwardMode/test.mlir diff --git a/enzyme/test/MLIR/ForwardMode/while.mlir b/enzyme/test/MLIR/ForwardMode/while.mlir new file mode 100644 index 000000000000..cbe7da34769b --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/while.mlir @@ -0,0 +1,44 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @while(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + %r:2 = scf.while (%arg1 = %c0, %arg2 = %cst) : (index, f64) -> (index, f64) { + %1 = arith.cmpi slt, %arg1, %c10 : index + scf.condition(%1) %arg1, %arg2 : index, f64 + } do { + ^bb0(%arg1: index, %arg2: f64): + %1 = arith.addi %arg1, %c1 : index + %2 = arith.addf %arg2, %x : f64 + scf.yield %1, %2 : index, f64 + } + return %r#1 : f64 + } + func.func @dwhile(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @while(%x, %dx) { activity=[#enzyme] } : (f64, 
f64) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffewhile + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { + // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[c0:.+]] = arith.constant 0 : index + // CHECK: %[[c1:.+]] = arith.constant 1 : index + // CHECK: %[[c10:.+]] = arith.constant 10 : index + // CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) : (index, f64, f64) -> (index, f64, f64) { + // CHECK: %[[v1:.+]] = arith.cmpi slt, %[[arg2]], %[[c10]] : index + // CHECK: scf.condition(%[[v1]]) %[[arg2]], %[[arg3]], %[[arg4]] : index, f64, f64 + // CHECK: } do { + // CHECK: ^bb0(%[[arg2:.+]]: index, %[[arg3:.+]]: f64, %[[arg4:.+]]: f64): + // CHECK: %[[v1:.+]] = arith.addi %[[arg2]], %[[c1]] : index + // CHECK: %[[v2:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 + // CHECK: %[[v3:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v1]], %[[v3]], %[[v2]] : index, f64, f64 + // CHECK: } + // CHECK: return %[[r0]]#2 : f64 + // CHECK: } +} diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 1f4a3dc1b892..58802117bf5b 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1888,6 +1888,25 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, "return true; }\n"; os << "};\n"; } + const auto &cfpatterns = + recordKeeper.getAllDerivedDefinitions("ControlFlowOp"); + for (auto &pattern : cfpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + auto impl = pattern->getValueAsString("impl"); + os << "struct " << opName << "CF : \n"; + os << " public " + "ControlFlowAutoDiffOpInterface::ExternalModel<" + << opName << "CF, " << dialect << "::" << opName << "> {\n"; + os << impl << "\n"; + os << 
"};\n"; + } + + const auto &brpatterns = recordKeeper.getAllDerivedDefinitions("BranchOp"); + + const auto ®tpatterns = + recordKeeper.getAllDerivedDefinitions("RegionTerminatorOp"); + os << "void registerInterfaces(MLIRContext* context) {\n"; for (Record *pattern : patterns) { auto opName = pattern->getValueAsString("opName"); @@ -1903,6 +1922,26 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " " << dialect << "::" << opName << "::attachInterface<" << opName << "Activity>(*context);\n"; } + for (Record *pattern : cfpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "CF>(*context);\n"; + os << " registerAutoDiffUsingControlFlowInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } + for (Record *pattern : brpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingBranchInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } + for (Record *pattern : regtpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } os << "}\n"; } } From c8b777ef79a4151685bc0bbd148b149c49ceacfc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 1 Feb 2024 00:26:45 -0500 Subject: [PATCH 035/131] [MLIR] Add read interface fwd (#1652) --- enzyme/BUILD | 30 ++++++++++++++++ .../Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 24 ++++++++++--- .../MLIR/Implementations/AffineDerivatives.td | 2 ++ .../MLIR/Implementations/CMakeLists.txt | 5 +++ enzyme/Enzyme/MLIR/Implementations/Common.td | 6 ++++ .../CoreDialectsAutoDiffImplementations.cpp | 31 +++++++++++++++++ .../CoreDialectsAutoDiffImplementations.h | 24 
+++++++++++++ .../LLVMAutoDiffOpInterfaceImpl.cpp | 23 +------------ .../MLIR/Implementations/LLVMDerivatives.td | 8 +++++ .../MemRefAutoDiffOpInterfaceImpl.cpp | 22 ++---------- .../MLIR/Implementations/MemRefDerivatives.td | 13 +++++++ .../MLIR/Interfaces/AutoDiffOpInterface.td | 2 +- enzyme/test/MLIR/ForwardMode/affine.mlir | 6 ++-- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 34 ++++++++++++++++++- 14 files changed, 179 insertions(+), 51 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td diff --git a/enzyme/BUILD b/enzyme/BUILD index f091b06bd8f5..7645c3a7ecf9 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -456,6 +456,34 @@ gentbl( ], ) +gentbl( + name = "cf-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/CFDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/CFDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/CFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "memref-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/MemRefDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/MemRefDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/MemRefDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -482,6 +510,8 @@ cc_library( ":llvm-derivatives", ":nvvm-derivatives", ":scf-derivatives", + ":cf-derivatives", + ":memref-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index bb8c8f2ccc9c..5f3cd22db72d 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -2820,8 +2820,16 @@ bool 
mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (UA != UseActivity::AllStores) { if (auto ifaceOp = dyn_cast(a)) { - if (ifaceOp.isArgInactive(parent)) - return true; + bool allInactive = true; + for (OpOperand &operand : a->getOpOperands()) { + if (parent == operand.get() && + !ifaceOp.isArgInactive(operand.getOperandNumber())) { + allInactive = false; + break; + } + } + if (allInactive) + continue; } } @@ -3394,8 +3402,16 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( } if (auto ifaceOp = dyn_cast(a)) { - if (ifaceOp.isArgInactive(val)) - return true; + bool allInactive = true; + for (OpOperand &operand : a->getOpOperands()) { + if (operand.get() == val && + !ifaceOp.isArgInactive(operand.getOperandNumber())) { + allInactive = false; + break; + } + } + if (allInactive) + continue; } if (isa(a)) { diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td index d6e866f3089e..aaad67ff367a 100644 --- a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td @@ -22,3 +22,5 @@ def : ControlFlowOp<"affine", "AffineIfOp", [{ }]>; def : RegionTerminatorOp<"affine", "AffineYieldOp">; +def : ReadOnlyIdentityOp<"affine", "AffineLoadOp", [0]>; +def : ReadOnlyIdentityOp<"affine", "AffineVectorLoadOp", [0]>; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index 8e62f5849be1..464740d48d4b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -22,6 +22,10 @@ set(LLVM_TARGET_DEFINITIONS CFDerivatives.td) enzyme_tablegen(CFDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(CFDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS MemRefDerivatives.td) +enzyme_tablegen(MemRefDerivatives.inc -gen-mlir-derivatives) 
+add_public_tablegen_target(MemRefDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp @@ -42,6 +46,7 @@ add_mlir_library(MLIREnzymeImplementations NVVMDerivativesIncGen SCFDerivativesIncGen CFDerivativesIncGen + MemRefDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index c09cac32f734..4eff61d7380f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -9,6 +9,12 @@ class ControlFlowOp { string impl = impl_; } +class ReadOnlyIdentityOp diffargs_> { + string dialect = dialect_; + string opName = opName_; + list diffargs = diffargs_; +} + class BranchOp { string dialect = dialect_; string opName = opName_; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 2a50d1c481c0..0ee33c44f01c 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -101,6 +101,37 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, return; } +LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler( + Operation *orig, OpBuilder &builder, MGradientUtils *gutils) { + + auto iface = cast(orig); + + SmallVector newOperands; + newOperands.reserve(orig->getNumOperands()); + for (OpOperand &operand : orig->getOpOperands()) { + if (iface.isArgInactive(operand.getOperandNumber())) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + } else { + if (gutils->isConstantValue(operand.get())) + return failure(); + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + } + + // Assuming shadows following the originals are fine. 
+ // TODO: consider extending to have a ShadowableTerminatorOpInterface + Operation *primal = gutils->getNewFromOriginal(orig); + Operation *shadow = builder.clone(*primal); + shadow->setOperands(newOperands); + for (auto &&[oval, sval] : + llvm::zip(orig->getResults(), shadow->getResults())) { + gutils->setDiffe(oval, sval, builder); + } + llvm::errs() << " shadow load: " << *shadow << "\n"; + + return success(); +} + void mlir::enzyme::detail::regionTerminatorForwardHandler( Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { auto termIface = cast(origTerminator); diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index b98a6ef10c39..4c54baa24ed2 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -39,6 +39,11 @@ void branchingForwardHandler(Operation *op, OpBuilder &builder, void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils); +// Implements forward-mode differentiation of read-only (including read-none) +// operations which do not perform computatoin +LogicalResult readOnlyIdentityForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + // Implements the forward autodiff interface for operations whose derivatives // are can be inferred by analyzing their control flow and differentiating the // nested operations. @@ -80,6 +85,19 @@ class AutoDiffUsingRegionTerminator return success(); } }; + +// Implements the forward autodiff interface for operations which are +// read only and identity like (aka not computing sin of mem read). 
+template +class AutoDiffUsingReadOnlyIdentity + : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingReadOnlyIdentity, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + return readOnlyIdentityForwardHandler(op, builder, gutils); + } +}; } // namespace detail // Registers AutoDiffUsingControlFlow for the given op. @@ -99,6 +117,12 @@ void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) { OpTy::template attachInterface>( context); } +// Registers AutoDiffUsingRegionTerminator for the given op. +template +void registerAutoDiffUsingReadOnlyIdentityInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); +} // Interface registration hooks for individual upstream dialects. void registerAffineDialectAutoDiffInterface(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 079dd1cb64e9..d7884fcc305e 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -35,27 +35,7 @@ struct InlineAsmActivityInterface auto str = asmOp.getAsmString(); return str.contains("cpuid") || str.contains("exit"); } - bool isArgInactive(Operation *op, mlir::Value) const { - return isInactive(op); - } -}; - -struct LoadOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto loadOp = cast(op); - if (!gutils->isConstantValue(loadOp)) { - Type shadowType = - cast(loadOp.getType()).getShadowType(); - mlir::Value res = builder.create( - loadOp.getLoc(), shadowType, - gutils->invertPointerM(loadOp.getAddr(), builder)); - gutils->setDiffe(loadOp, res, builder); - } - gutils->eraseIfUnused(op); - return success(); - } + bool 
isArgInactive(Operation *op, size_t) const { return isInactive(op); } }; struct StoreOpInterface @@ -115,7 +95,6 @@ class PointerTypeInterface void mlir::enzyme::registerLLVMDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) { - LLVM::LoadOp::attachInterface(*context); LLVM::StoreOp::attachInterface(*context); LLVM::AllocaOp::attachInterface(*context); LLVM::LLVMPointerType::attachInterface(*context); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td index 9e5f28e41665..c9b8fe76ca56 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -15,3 +15,11 @@ def : InactiveOp<"LLVM", "MemsetOp">; def : InactiveOp<"LLVM", "UndefOp">; def : InactiveOp<"LLVM", "ConstantOp">; def : InactiveOp<"LLVM", "UnreachableOp">; + + +def : ReadOnlyIdentityOp<"LLVM", "LoadOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "AddrSpaceCastOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "BitcastOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "GEPOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 010f9d997005..04f8c3a60e33 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -28,25 +28,7 @@ using namespace mlir; using namespace mlir::enzyme; namespace { -struct LoadOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto loadOp = cast(op); - if (!gutils->isConstantValue(loadOp)) { - SmallVector inds; - for (auto ind : 
loadOp.getIndices()) - inds.push_back(gutils->getNewFromOriginal(ind)); - mlir::Value res = builder.create( - loadOp.getLoc(), gutils->invertPointerM(loadOp.getMemref(), builder), - inds); - gutils->setDiffe(loadOp, res, builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; +#include "Implementations/MemRefDerivatives.inc" struct StoreOpInterface : public AutoDiffOpInterface::ExternalModel(*context); + registerInterfaces(context); memref::StoreOp::attachInterface(*context); memref::AllocOp::attachInterface(*context); MemRefType::attachInterface(*context); diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td new file mode 100644 index 000000000000..7bf269199ca4 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td @@ -0,0 +1,13 @@ +include "Common.td" + +def : ReadOnlyIdentityOp<"memref", "LoadOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "CastOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "CollapseShapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ExpandShapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ReinterpretCastOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ReshapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "TransposeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ViewOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "SubViewOp", [0]>; + +def : InactiveOp<"memref", "DimOp">; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 5d764c8dc7bb..9123dfcaa22e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -128,7 +128,7 @@ def ActivityOpInterface }], /*retTy=*/"bool", /*methodName=*/"isArgInactive", - /*args=*/(ins "::mlir::Value":$val) + /*args=*/(ins "size_t":$opidx) > ]; } diff --git a/enzyme/test/MLIR/ForwardMode/affine.mlir 
b/enzyme/test/MLIR/ForwardMode/affine.mlir index a091d6d774ac..2dcb61441251 100644 --- a/enzyme/test/MLIR/ForwardMode/affine.mlir +++ b/enzyme/test/MLIR/ForwardMode/affine.mlir @@ -70,7 +70,7 @@ module { %mul = arith.mulf %x, %x : f64 memref.store %mul, %mem[%c0] : memref<1xf64> } - %r = memref.load %mem[%c0] : memref<1xf64> + %r = affine.load %mem[0] : memref<1xf64> %res = arith.mulf %c2, %r : f64 return %res : f64 } @@ -96,8 +96,8 @@ module { // CHECK: memref.store %[[v6]], %[[alloc]][%[[c0]]] : memref<1xf64> // CHECK: memref.store %[[v7]], %[[alloc_2]][%[[c0]]] : memref<1xf64> // CHECK: } - // CHECK: %[[v0:.+]] = memref.load %[[alloc]][%[[c0]]] : memref<1xf64> - // CHECK: %[[v1:.+]] = memref.load %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: %[[v0:.+]] = affine.load %[[alloc]][0] : memref<1xf64> + // CHECK: %[[v1:.+]] = affine.load %[[alloc_2]][0] : memref<1xf64> // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst_0]] : f64 // CHECK: %[[v3:.+]] = arith.mulf %[[cst_0]], %[[v1]] : f64 // CHECK: return %[[v2]] : f64 diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 58802117bf5b..f3f8f5940d27 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1884,12 +1884,16 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " public ActivityOpInterface::ExternalModel<" << opName << "Activity, " << dialect << "::" << opName << "> {\n"; os << " bool isInactive(mlir::Operation*) const { return true; }\n"; - os << " bool isArgInactive(mlir::Operation*, mlir::Value) const { " + os << " bool isArgInactive(mlir::Operation*, size_t) const { " "return true; }\n"; os << "};\n"; } const auto &cfpatterns = recordKeeper.getAllDerivedDefinitions("ControlFlowOp"); + + const auto &ropatterns = + recordKeeper.getAllDerivedDefinitions("ReadOnlyIdentityOp"); + for (auto &pattern : cfpatterns) { auto opName = 
pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); @@ -1902,6 +1906,26 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << "};\n"; } + for (auto &pattern : ropatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + auto diffargs = pattern->getValueAsListOfInts("diffargs"); + os << "struct " << opName << "ROActivity : \n"; + os << " public ActivityOpInterface::ExternalModel<" << opName + << "ROActivity, " << dialect << "::" << opName << "> {\n"; + os << " bool isInactive(mlir::Operation* op) const {\n"; + os << " for (size_t i=0, len=op->getNumOperands(); i(*context);\n"; } + for (Record *pattern : ropatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "ROActivity>(*context);\n"; + os << " registerAutoDiffUsingReadOnlyIdentityInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } for (Record *pattern : brpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); From 9ff56bc24fa1d483ac1abb08bfaaddfa6272d366 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 2 Feb 2024 23:26:01 -0500 Subject: [PATCH 036/131] [MLIR] General store / allocation / tensor / math interfaces (#1657) * Generalize mem to include stores * alloca * Add tensor/math --- enzyme/BUILD | 15 ++++ .../MLIR/Implementations/AffineDerivatives.td | 4 +- .../MLIR/Implementations/ArithDerivatives.td | 12 --- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 37 ++++++++ .../MLIR/Implementations/CMakeLists.txt | 6 ++ enzyme/Enzyme/MLIR/Implementations/Common.td | 30 ++++++- .../CoreDialectsAutoDiffImplementations.cpp | 59 +++++++++++- .../CoreDialectsAutoDiffImplementations.h | 73 ++++++++++++--- .../LLVMAutoDiffOpInterfaceImpl.cpp | 41 ++------- 
.../MLIR/Implementations/LLVMDerivatives.td | 5 +- .../MathAutoDiffOpInterfaceImpl.cpp | 38 ++++++++ .../MLIR/Implementations/MathDerivatives.td | 17 ++++ .../MemRefAutoDiffOpInterfaceImpl.cpp | 90 ++++--------------- .../MLIR/Implementations/MemRefDerivatives.td | 5 +- .../MLIR/Interfaces/AutoDiffTypeInterface.h | 1 + .../MLIR/Interfaces/AutoDiffTypeInterface.td | 8 ++ .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 7 +- enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 1 + enzyme/test/MLIR/ForwardMode/affine.mlir | 24 +++-- enzyme/test/MLIR/ForwardMode/tensorsin.mlir | 19 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 43 ++++++--- 21 files changed, 370 insertions(+), 165 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td create mode 100644 enzyme/test/MLIR/ForwardMode/tensorsin.mlir diff --git a/enzyme/BUILD b/enzyme/BUILD index 7645c3a7ecf9..3b4d03c001e0 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -484,6 +484,20 @@ gentbl( ], ) +gentbl( + name = "math-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/MathDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/MathDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/MathDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -512,6 +526,7 @@ cc_library( ":scf-derivatives", ":cf-derivatives", ":memref-derivatives", + ":math-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td index aaad67ff367a..9f22d00cecdb 100644 --- a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td @@ -23,4 +23,6 @@ def : 
ControlFlowOp<"affine", "AffineIfOp", [{ def : RegionTerminatorOp<"affine", "AffineYieldOp">; def : ReadOnlyIdentityOp<"affine", "AffineLoadOp", [0]>; -def : ReadOnlyIdentityOp<"affine", "AffineVectorLoadOp", [0]>; \ No newline at end of file +def : ReadOnlyIdentityOp<"affine", "AffineVectorLoadOp", [0]>; +def : MemoryIdentityOp<"affine", "AffineStoreOp", [1], [0]>; +def : MemoryIdentityOp<"affine", "AffineVectorStoreOp", [1], [0]>; diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index fb7f113f16fe..3ed038fa4cb6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -1,17 +1,5 @@ include "Common.td" -class ArithInst : Inst; - -def AddF : ArithInst<"arith::AddFOp">; -def SubF : ArithInst<"arith::SubFOp">; -def NegF : ArithInst<"arith::NegFOp">; -def MulF : ArithInst<"arith::MulFOp">; -def DivF : ArithInst<"arith::DivFOp">; -def RemF : ArithInst<"arith::RemFOp">; - -def CheckedMulF : ArithInst<"arith::MulFOp">; -def CheckedDivF : ArithInst<"arith::DivFOp">; - def : MLIRDerivative<"arith", "AddFOp", (Op $x, $y), [ (DiffeRet), diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 30feca7ba82d..613e0ef2ad7e 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -45,6 +45,37 @@ class FloatTypeInterface } bool requiresShadow(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } +}; + +class TensorTypeInterface + : public AutoDiffTypeInterface::ExternalModel { +public: + Value createNullValue(Type self, OpBuilder &builder, Location loc) const { + auto tenType = self.cast(); + auto attr = 
DenseElementsAttr::get(tenType, 0); + return builder.create(loc, tenType, attr); + } + + Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a, + Value b) const { + return builder.create(loc, a, b); + } + + Type getShadowType(Type self, unsigned width) const { + assert(width == 1 && "unsupported width != 1"); + return self; + } + + bool requiresShadow(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } }; template @@ -69,6 +100,10 @@ class IntegerTypeInterface } bool requiresShadow(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } }; } // namespace @@ -81,5 +116,7 @@ void mlir::enzyme::registerBuiltinDialectAutoDiffInterface( Float64Type::attachInterface(*context); IntegerType::attachInterface>(*context); IndexType::attachInterface>(*context); + UnrankedTensorType::attachInterface(*context); + RankedTensorType::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index 464740d48d4b..ab156fa1b36f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -26,6 +26,10 @@ set(LLVM_TARGET_DEFINITIONS MemRefDerivatives.td) enzyme_tablegen(MemRefDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(MemRefDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS MathDerivatives.td) +enzyme_tablegen(MathDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(MathDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp @@ -37,6 +41,7 @@ add_mlir_library(MLIREnzymeImplementations BuiltinAutoDiffTypeInterfaceImpl.cpp SCFAutoDiffOpInterfaceImpl.cpp CFAutoDiffOpInterfaceImpl.cpp + MathAutoDiffOpInterfaceImpl.cpp DEPENDS 
MLIRAutoDiffOpInterfaceIncGen @@ -47,6 +52,7 @@ add_mlir_library(MLIREnzymeImplementations SCFDerivativesIncGen CFDerivativesIncGen MemRefDerivativesIncGen + MathDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 4eff61d7380f..1451d99f17ca 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -3,18 +3,26 @@ class InactiveOp { string opName = opName_; } +class AllocationOp { + string dialect = dialect_; + string opName = opName_; +} + class ControlFlowOp { string dialect = dialect_; string opName = opName_; string impl = impl_; } -class ReadOnlyIdentityOp diffargs_> { +class MemoryIdentityOp ptrargs_, list storedargs_ = []> { string dialect = dialect_; string opName = opName_; - list diffargs = diffargs_; + list ptrargs = ptrargs_; + list storedargs = storedargs_; } +class ReadOnlyIdentityOp ptrargs_> : MemoryIdentityOp; + class BranchOp { string dialect = dialect_; string opName = opName_; @@ -50,3 +58,21 @@ class Inst : Operation : Inst; +class MathInst : Inst; + +def AddF : ArithInst<"arith::AddFOp">; +def SubF : ArithInst<"arith::SubFOp">; +def NegF : ArithInst<"arith::NegFOp">; +def MulF : ArithInst<"arith::MulFOp">; +def DivF : ArithInst<"arith::DivFOp">; +def RemF : ArithInst<"arith::RemFOp">; + +def CheckedMulF : ArithInst<"arith::MulFOp">; +def CheckedDivF : ArithInst<"arith::DivFOp">; + +def CosF : MathInst<"math::CosOp">; +def SinF : MathInst<"math::SinOp">; +def ExpF : MathInst<"math::ExpOp">; +def SqrtF : MathInst<"math::SqrtOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 0ee33c44f01c..857784472211 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ 
b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -101,9 +101,18 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, return; } -LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler( - Operation *orig, OpBuilder &builder, MGradientUtils *gutils) { +static bool contains(ArrayRef ar, int v) { + for (auto a : ar) { + if (a == v) { + return true; + } + } + return false; +} +LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( + Operation *orig, OpBuilder &builder, MGradientUtils *gutils, + ArrayRef storedVals) { auto iface = cast(orig); SmallVector newOperands; @@ -112,8 +121,27 @@ LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler( if (iface.isArgInactive(operand.getOperandNumber())) { newOperands.push_back(gutils->getNewFromOriginal(operand.get())); } else { - if (gutils->isConstantValue(operand.get())) + if (gutils->isConstantValue(operand.get())) { + + if (contains(storedVals, operand.getOperandNumber())) { + if (auto iface = + dyn_cast(operand.get().getType())) { + if (!iface.requiresShadow()) { + // TODO only do if mutable + Type retTy = iface.getShadowType(); + auto toret = retTy.cast().createNullValue( + builder, operand.get().getLoc()); + newOperands.push_back(toret); + continue; + } + } + } + orig->emitWarning() + << "Unsupported constant arg to memory identity forward " + "handler(opidx=" + << operand.getOperandNumber() << ", op=" << operand.get() << ")\n"; return failure(); + } newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); } } @@ -127,11 +155,34 @@ LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler( llvm::zip(orig->getResults(), shadow->getResults())) { gutils->setDiffe(oval, sval, builder); } - llvm::errs() << " shadow load: " << *shadow << "\n"; return success(); } +LogicalResult mlir::enzyme::detail::allocationForwardHandler( + Operation *orig, OpBuilder &builder, MGradientUtils *gutils, bool zero) { + + Operation *primal = 
gutils->getNewFromOriginal(orig); + Operation *shadow = builder.clone(*primal); + + Value shadowRes = shadow->getResult(0); + + gutils->setDiffe(orig->getResult(0), shadowRes, builder); + gutils->eraseIfUnused(orig); + + if (zero) { + // Fill with zeros + if (auto iface = dyn_cast(shadowRes.getType())) { + return iface.zeroInPlace(builder, orig->getLoc(), shadowRes); + } else { + orig->emitWarning() << "memref.alloc element type does not implement " + "AutoDiffTypeInterface"; + return failure(); + } + } + return success(); +} + void mlir::enzyme::detail::regionTerminatorForwardHandler( Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { auto termIface = cast(origTerminator); diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 4c54baa24ed2..5a1fbc136708 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -22,6 +22,7 @@ class OpBuilder; namespace enzyme { class MGradientUtils; +class MGradientUtilsReverse; namespace detail { // Non-template implementation of @@ -40,9 +41,14 @@ void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils); // Implements forward-mode differentiation of read-only (including read-none) -// operations which do not perform computatoin -LogicalResult readOnlyIdentityForwardHandler(Operation *op, OpBuilder &builder, - MGradientUtils *gutils); +// operations which do not perform computation +LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, + ArrayRef storedVals); + +// Implements shadow initialization differentiation of allocation +LogicalResult allocationForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, bool zero); // Implements the forward autodiff interface for operations whose 
derivatives // are can be inferred by analyzing their control flow and differentiating the @@ -88,14 +94,52 @@ class AutoDiffUsingRegionTerminator // Implements the forward autodiff interface for operations which are // read only and identity like (aka not computing sin of mem read). -template -class AutoDiffUsingReadOnlyIdentity +template +class AutoDiffUsingMemoryIdentity : public AutoDiffOpInterface::ExternalModel< - AutoDiffUsingReadOnlyIdentity, OpTy> { + AutoDiffUsingMemoryIdentity, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + + return memoryIdentityForwardHandler( + op, builder, gutils, std::initializer_list{storedvals...}); + } +}; + +// Implements the forward autodiff interface for operations which are +// allocation like +template +class AutoDiffUsingAllocationFwd : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingAllocationFwd, OpTy> { public: LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const { - return readOnlyIdentityForwardHandler(op, builder, gutils); + + return allocationForwardHandler(op, builder, gutils, /*zero*/ false); + } +}; + +// Implements the reverse autodiff interface for operations which are +// allocation like +template +class AutoDiffUsingAllocationRev + : public ReverseAutoDiffOpInterface::ExternalModel< + AutoDiffUsingAllocationRev, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const {} + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const { + (void)allocationForwardHandler(op, builder, (MGradientUtils *)gutils, + /*zero*/ true); } }; } // namespace detail @@ -117,10 +161,18 @@ void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext 
&context) { OpTy::template attachInterface>( context); } -// Registers AutoDiffUsingRegionTerminator for the given op. +// Registers AutoDiffUsingMemoryIdentity for the given op. +template +void registerAutoDiffUsingMemoryIdentityInterface(MLIRContext &context) { + OpTy::template attachInterface< + detail::AutoDiffUsingMemoryIdentity>(context); +} +// Registers AutoDiffUsingAllocation for the given op. template -void registerAutoDiffUsingReadOnlyIdentityInterface(MLIRContext &context) { - OpTy::template attachInterface>( +void registerAutoDiffUsingAllocationInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); + OpTy::template attachInterface>( context); } @@ -134,5 +186,6 @@ void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry); void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); +void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); } // namespace enzyme } // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index d7884fcc305e..472cd0a43d61 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -24,9 +24,7 @@ using namespace mlir::enzyme; namespace { #include "Implementations/LLVMDerivatives.inc" -} // namespace -namespace { struct InlineAsmActivityInterface : public ActivityOpInterface::ExternalModel { @@ -38,37 +36,6 @@ struct InlineAsmActivityInterface bool isArgInactive(Operation *op, size_t) const { return isInactive(op); } }; -struct StoreOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto storeOp = cast(op); - if 
(!gutils->isConstantValue(storeOp.getAddr())) { - builder.create( - storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder), - gutils->invertPointerM(storeOp.getAddr(), builder)); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - -struct AllocaOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto allocOp = cast(op); - if (!gutils->isConstantValue(allocOp)) { - Operation *nop = gutils->cloneWithNewOperands(builder, op); - gutils->setDiffe(allocOp, nop->getResult(0), builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - class PointerTypeInterface : public AutoDiffTypeInterface::ExternalModel { @@ -89,14 +56,18 @@ class PointerTypeInterface } bool requiresShadow(Type self) const { return true; } + + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + // TODO inspect val and memset corresponding size + return failure(); + } }; } // namespace void mlir::enzyme::registerLLVMDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) { - LLVM::StoreOp::attachInterface(*context); - LLVM::AllocaOp::attachInterface(*context); LLVM::LLVMPointerType::attachInterface(*context); registerInterfaces(context); }); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td index c9b8fe76ca56..e77e88aea47f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -1,5 +1,6 @@ include "Common.td" +def : MemoryIdentityOp<"LLVM", "StoreOp", [1], [0]>; def : InactiveOp<"LLVM", "SIToFPOp">; def : InactiveOp<"LLVM", "UIToFPOp">; def : InactiveOp<"LLVM", "FPToSIOp">; @@ -22,4 +23,6 @@ def : ReadOnlyIdentityOp<"LLVM", "AddrSpaceCastOp", [0]>; def : ReadOnlyIdentityOp<"LLVM", "BitcastOp", 
[0]>; def : ReadOnlyIdentityOp<"LLVM", "GEPOp", [0]>; def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>; -def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; \ No newline at end of file +def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; + +def : AllocationOp<"LLVM", "AllocaOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..2833eeb44726 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,38 @@ +//===- ArithAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR arithmetic dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/Ops.h" +#include "mlir/IR/TypeSupport.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/MathDerivatives.inc" +} // namespace + +void mlir::enzyme::registerMathDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, math::MathDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td new file mode 100644 index 000000000000..71db1a8574ac --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td @@ -0,0 +1,17 @@ +include "Common.td" + +def : MLIRDerivative<"math", "CosOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (NegF (SinF $x))) + ] + >; +def : MLIRDerivative<"math", "ExpOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (ExpF $x)) + ] + >; +def : MLIRDerivative<"math", "SinOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (CosF $x)) + ] + >; diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 04f8c3a60e33..1b811caf1b81 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -17,38 +17,20 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" -// TODO: We need a way to zero out a memref (which linalg.fill does), but -// ideally we wouldn't depend on the 
linalg dialect. -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" +// TODO: We need a way to zero out a memref (which linalg.fill does), but +// ideally we wouldn't depend on the linalg dialect. +#include "mlir/Dialect/Linalg/IR/Linalg.h" + using namespace mlir; using namespace mlir::enzyme; namespace { #include "Implementations/MemRefDerivatives.inc" -struct StoreOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto storeOp = cast(op); - if (!gutils->isConstantValue(storeOp.getMemref())) { - SmallVector inds; - for (auto ind : storeOp.getIndices()) - inds.push_back(gutils->getNewFromOriginal(ind)); - builder.create( - storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder), - gutils->invertPointerM(storeOp.getMemref(), builder), inds); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - struct LoadOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { @@ -177,38 +159,6 @@ struct StoreOpInterfaceReverse } }; -struct AllocOpInterfaceReverse - : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const {} - - SmallVector cacheValues(Operation *op, - MGradientUtilsReverse *gutils) const { - return SmallVector(); - } - - void createShadowValues(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) const { - auto allocOp = cast(op); - auto newAllocOp = cast(gutils->getNewFromOriginal(op)); - - Value shadow = builder.create( - op->getLoc(), newAllocOp.getType(), newAllocOp.getDynamicSizes()); - // Fill with zeros - if (auto iface = dyn_cast( - allocOp.getType().getElementType())) { - Value zero = iface.createNullValue(builder, op->getLoc()); - 
builder.create(op->getLoc(), zero, shadow); - } else { - op->emitWarning() << "memref.alloc element type does not implement " - "AutoDiffTypeInterface"; - } - gutils->mapShadowValue(allocOp, shadow, builder); - } -}; - struct SubViewOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel< SubViewOpInterfaceReverse, memref::SubViewOp> { @@ -236,21 +186,6 @@ struct SubViewOpInterfaceReverse } }; -struct AllocOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto allocOp = cast(op); - if (!gutils->isConstantValue(allocOp)) { - Operation *nop = gutils->cloneWithNewOperands(builder, op); - gutils->setDiffe(allocOp, nop->getResult(0), builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - class MemRefTypeInterface : public AutoDiffTypeInterface::ExternalModel { @@ -271,6 +206,20 @@ class MemRefTypeInterface } bool requiresShadow(Type self) const { return true; } + + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + auto MT = cast(self); + if (auto iface = dyn_cast(MT.getElementType())) { + if (!iface.requiresShadow()) { + Value zero = iface.createNullValue(builder, loc); + builder.create(loc, zero, val); + } + } else { + return failure(); + } + return success(); + } }; } // namespace @@ -278,13 +227,10 @@ void mlir::enzyme::registerMemRefDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) { registerInterfaces(context); - memref::StoreOp::attachInterface(*context); - memref::AllocOp::attachInterface(*context); MemRefType::attachInterface(*context); memref::LoadOp::attachInterface(*context); memref::StoreOp::attachInterface(*context); - memref::AllocOp::attachInterface(*context); memref::SubViewOp::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td 
b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td index 7bf269199ca4..173a22a6e2f1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td @@ -1,5 +1,6 @@ include "Common.td" +def : MemoryIdentityOp<"memref", "StoreOp", [1], [0]>; def : ReadOnlyIdentityOp<"memref", "LoadOp", [0]>; def : ReadOnlyIdentityOp<"memref", "CastOp", [0]>; def : ReadOnlyIdentityOp<"memref", "CollapseShapeOp", [0]>; @@ -10,4 +11,6 @@ def : ReadOnlyIdentityOp<"memref", "TransposeOp", [0]>; def : ReadOnlyIdentityOp<"memref", "ViewOp", [0]>; def : ReadOnlyIdentityOp<"memref", "SubViewOp", [0]>; -def : InactiveOp<"memref", "DimOp">; \ No newline at end of file +def : InactiveOp<"memref", "DimOp">; +def : AllocationOp<"memref", "AllocOp">; +def : AllocationOp<"memref", "AllocaOp">; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h index 7c405bde0a99..04d0186c4b1e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h @@ -16,6 +16,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" namespace mlir { class OpBuilder; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td index 5fec2643a98c..956f616751dc 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td @@ -41,6 +41,14 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> { /*methodName=*/"createAddOp", /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc, "::mlir::Value":$a, "::mlir::Value":$b) >, + InterfaceMethod< + /*desc=*/[{ + Zero the operation in place + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"zeroInPlace", + /*args=*/(ins "::mlir::OpBuilder &":$builder, 
"::mlir::Location":$loc, "::mlir::Value":$val) + >, InterfaceMethod< /*desc=*/[{ Returns the type that can contain the adjoint value for this type. If diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 84b8b8b5d3b3..2b2760dfa8ff 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -4,7 +4,10 @@ using namespace mlir; using namespace mlir::enzyme; Type getShadowType(Type type, unsigned width) { - return type.cast().getShadowType(width); + if (auto iface = type.dyn_cast()) + return iface.getShadowType(width); + llvm::errs() << " type does not have autodifftypeinterface: " << type << "\n"; + exit(1); } mlir::FunctionType getFunctionTypeForClone( @@ -263,4 +266,4 @@ FunctionOpInterface CloneFunctionWithReturns( } return NewF; -} \ No newline at end of file +} diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index e7af0b01267f..90038c1e3236 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -102,6 +102,7 @@ int main(int argc, char **argv) { enzyme::registerBuiltinDialectAutoDiffInterface(registry); enzyme::registerLLVMDialectAutoDiffInterface(registry); enzyme::registerNVVMDialectAutoDiffInterface(registry); + enzyme::registerMathDialectAutoDiffInterface(registry); enzyme::registerMemRefDialectAutoDiffInterface(registry); enzyme::registerSCFDialectAutoDiffInterface(registry); enzyme::registerCFDialectAutoDiffInterface(registry); diff --git a/enzyme/test/MLIR/ForwardMode/affine.mlir b/enzyme/test/MLIR/ForwardMode/affine.mlir index 2dcb61441251..d0e587409713 100644 --- a/enzyme/test/MLIR/ForwardMode/affine.mlir +++ b/enzyme/test/MLIR/ForwardMode/affine.mlir @@ -61,14 +61,13 @@ module { // CHECK: return %[[v1]] : f64 func.func @if_then(%x : f64, %c : i1) -> f64 { - %c0 = arith.constant 0 : index %c2 = arith.constant 2.000000e+00 : f64 %c10 = arith.constant 
10.000000e+00 : f64 %mem = memref.alloc() : memref<1xf64> - memref.store %c2, %mem[%c0] : memref<1xf64> + affine.store %c2, %mem[0] : memref<1xf64> scf.if %c { %mul = arith.mulf %x, %x : f64 - memref.store %mul, %mem[%c0] : memref<1xf64> + affine.store %mul, %mem[0] : memref<1xf64> } %r = affine.load %mem[0] : memref<1xf64> %res = arith.mulf %c2, %r : f64 @@ -80,25 +79,24 @@ module { } // CHECK: @fwddiffeif_then // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) -> f64 { - // CHECK: %[[c0:.+]] = arith.constant 0 : index - // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 - // CHECK: %[[cst_0:.+]] = arith.constant 2.000000e+00 : f64 - // CHECK: %[[cst_1:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 + // CHECK-DAG: %[[cst1:.+]] = arith.constant 1.000000e+01 : f64 // CHECK: %[[alloc:.+]] = memref.alloc() : memref<1xf64> // CHECK: %[[alloc_2:.+]] = memref.alloc() : memref<1xf64> - // CHECK: memref.store %[[cst]], %[[alloc]][%[[c0]]] : memref<1xf64> - // CHECK: memref.store %[[cst_0]], %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK-DAG: %[[cst0:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: affine.store %[[cst0]], %[[alloc]][0] : memref<1xf64> + // CHECK: affine.store %[[cst2]], %[[alloc_2]][0] : memref<1xf64> // CHECK: scf.if %[[arg2]] { // CHECK: %[[v4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 // CHECK: %[[v5:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 // CHECK: %[[v6:.+]] = arith.addf %[[v4]], %[[v5]] : f64 // CHECK: %[[v7:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 - // CHECK: memref.store %[[v6]], %[[alloc]][%[[c0]]] : memref<1xf64> - // CHECK: memref.store %[[v7]], %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: affine.store %[[v6]], %[[alloc]][0] : memref<1xf64> + // CHECK: affine.store %[[v7]], %[[alloc_2]][0] : memref<1xf64> // CHECK: } // CHECK: %[[v0:.+]] = affine.load %[[alloc]][0] : memref<1xf64> // CHECK: %[[v1:.+]] = affine.load %[[alloc_2]][0] : 
memref<1xf64> - // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst_0]] : f64 - // CHECK: %[[v3:.+]] = arith.mulf %[[cst_0]], %[[v1]] : f64 + // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst2]] : f64 + // CHECK: %[[v3:.+]] = arith.mulf %[[cst2]], %[[v1]] : f64 // CHECK: return %[[v2]] : f64 } diff --git a/enzyme/test/MLIR/ForwardMode/tensorsin.mlir b/enzyme/test/MLIR/ForwardMode/tensorsin.mlir new file mode 100644 index 000000000000..5ab208712b73 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/tensorsin.mlir @@ -0,0 +1,19 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : tensor<2xf64>) -> tensor<2xf64> { + %y = math.sin %x : tensor<2xf64> + return %y : tensor<2xf64> + } + func.func @dsq(%x : tensor<2xf64>, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (tensor<2xf64>, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffesquare(%arg0: tensor<2xf64>, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[a0:.+]] = math.cos %arg0 : tensor<2xf64> +// CHECK-NEXT: %[[a1:.+]] = arith.mulf %arg1, %[[a0]] : tensor<2xf64> +// CHECK-NEXT: %[[a2:.+]] = math.sin %arg0 : tensor<2xf64> +// CHECK-NEXT: return %[[a1]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index f3f8f5940d27..c30f8870063e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1398,8 +1398,11 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, insert(dg, next); if (ptree->getArgNameStr(i).size()) { - auto op = - (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + std::string op; + if (intrinsic != MLIRDerivatives) + op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + else + op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); if (prev.size() > 0) { op = 
"gutils->extractMeta(Builder2, " + op + ", ArrayRef({"; @@ -1891,8 +1894,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, const auto &cfpatterns = recordKeeper.getAllDerivedDefinitions("ControlFlowOp"); - const auto &ropatterns = - recordKeeper.getAllDerivedDefinitions("ReadOnlyIdentityOp"); + const auto &mempatterns = + recordKeeper.getAllDerivedDefinitions("MemoryIdentityOp"); for (auto &pattern : cfpatterns) { auto opName = pattern->getValueAsString("opName"); @@ -1906,13 +1909,14 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << "};\n"; } - for (auto &pattern : ropatterns) { + for (auto &pattern : mempatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); - auto diffargs = pattern->getValueAsListOfInts("diffargs"); - os << "struct " << opName << "ROActivity : \n"; + auto diffargs = pattern->getValueAsListOfInts("ptrargs"); + auto storedargs = pattern->getValueAsListOfInts("storedargs"); + os << "struct " << opName << "MemActivity : \n"; os << " public ActivityOpInterface::ExternalModel<" << opName - << "ROActivity, " << dialect << "::" << opName << "> {\n"; + << "MemActivity, " << dialect << "::" << opName << "> {\n"; os << " bool isInactive(mlir::Operation* op) const {\n"; os << " for (size_t i=0, len=op->getNumOperands(); igetValueAsString("opName"); @@ -1954,13 +1964,16 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " registerAutoDiffUsingControlFlowInterface<" << dialect << "::" << opName << ">(*context);\n"; } - for (Record *pattern : ropatterns) { + for (Record *pattern : mempatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " " << dialect << "::" << opName << "::attachInterface<" << opName - << "ROActivity>(*context);\n"; - os << " registerAutoDiffUsingReadOnlyIdentityInterface<" << dialect - << "::" << opName << 
">(*context);\n"; + << "MemActivity>(*context);\n"; + os << " registerAutoDiffUsingMemoryIdentityInterface<" << dialect + << "::" << opName; + for (auto storedarg : pattern->getValueAsListOfInts("storedargs")) + os << ", " << storedarg; + os << ">(*context);\n"; } for (Record *pattern : brpatterns) { auto opName = pattern->getValueAsString("opName"); @@ -1974,6 +1987,12 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect << "::" << opName << ">(*context);\n"; } + for (Record *pattern : allocpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingAllocationInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } os << "}\n"; } } From d8ec31545895507d2e04e3650c5a912d2f20a52e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 5 Feb 2024 22:00:34 +0100 Subject: [PATCH 037/131] Fix differential use analysis of ivi (#1661) --- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 41 +++++++++++++++++++ .../test/Enzyme/ForwardMode/insertdiffuse.ll | 25 +++++++++++ enzyme/test/lit.site.cfg.py.in | 2 +- 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 4a2cff5d79a9..3be0cdfc2895 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -293,6 +293,47 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( return false; } + if (shadow) { + if (auto IVI = dyn_cast(user)) + if (isa(IVI) || isa(IVI)) { + if (IVI->getOperand(1) == val) { + SmallVector todo; + todo.push_back(IVI); + SmallVector, 1> + done; + while (todo.size()) { + auto cur = todo.pop_back_val(); + for (auto u : cur->users()) { + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + 
continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + done.emplace_back(cast(u), cur); + } + } + for (auto &pair : done) { + + bool direct = is_use_directly_needed_in_reverse( + gutils, pair.second, mode, pair.first, oldUnreachable, + QueryType::Shadow, recursiveUse); + if (direct) { + + if (EnzymePrintDiffUse) + llvm::errs() + << " Need (partial) direct " << to_string(qtype) << " of " + << *val << " in reverse from insertelem " << *user + << " via " << *pair.second << " in " << *pair.first << "\n"; + return true; + } + } + } + } + } + if (!shadow) if (auto IEI = dyn_cast(user)) { // Only need the index in the reverse, so if the value is not diff --git a/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll b/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll new file mode 100644 index 000000000000..44f246cdfbe0 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll @@ -0,0 +1,25 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +define { double, i64 } @julia_logabsgamma_3264_inner.1(double %x, i64 %z) { +entry: + %iadd = add i64 %z, 1 + %.fca.0.insert = insertvalue { double, i64 } undef, double %x, 0 + %.fca.1.insert = insertvalue { double, i64 } %.fca.0.insert, i64 %iadd, 1 + ret { double, i64 } %.fca.1.insert +} + +declare { double, i64 } @__enzyme_fwddiff(...) + +define { double, i64 } @ad(double %x, double %dx) { + %m = call { double, i64 } (...) 
@__enzyme_fwddiff({ double, i64 } (double, i64)* @julia_logabsgamma_3264_inner.1, double %x, double %dx, i64 1) + ret { double, i64 } %m +} + +; CHECK: define internal { double, i64 } @fwddiffejulia_logabsgamma_3264_inner.1(double %x, double %"x'", i64 %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %iadd = add i64 %z, 1 +; CHECK-NEXT: %".fca.0.insert'ipiv" = insertvalue { double, i64 } zeroinitializer, double %"x'", 0 +; CHECK-NEXT: %".fca.1.insert'ipiv" = insertvalue { double, i64 } %".fca.0.insert'ipiv", i64 %iadd, 1 +; CHECK-NEXT: ret { double, i64 } %".fca.1.insert'ipiv" +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 481fce924bcc..0b8a0f831d6e 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -82,7 +82,7 @@ if len("@ENZYME_BINARY_DIR@") == 0: oldPMOP = oldPM newPMOP = newPM -if int(config.llvm_ver) >= 16: +if int(config.llvm_ver) == 16: newPM += " -opaque-pointers=0" oldPM += " -opaque-pointers=0" From 5b6b4ac44a343676aff768dcf0c50ee0a6055fd5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 02:08:17 +0000 Subject: [PATCH 038/131] Fix integration tests on main (#1662) --- enzyme/Enzyme/AdjointGenerator.h | 24 +++++++++++++++++++ .../test/Integration/ReverseMode/forrealloc.c | 8 +++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index c6d558cf5ba0..80d6c851e2a2 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -3353,6 +3353,29 @@ class AdjointGenerator } } } + // If the type is known, but outside of the known range + // (but the memcpy size is a variable), attempt to use + // the first type out of range as the memcpy type. 
+ if (size == 1 && !isa(new_size)) { + for (auto ptr : {orig_dst, orig_src}) { + vd = TR.query(ptr).Data0().ShiftIndices(DL, 0, -1, 0); + if (vd.isKnownPastPointer()) { + ConcreteType mv(BaseType::Unknown); + size_t minInt = 0xFFFFFFFF; + for (const auto &pair : vd.getMapping()) { + if (pair.first.size() != 1) + continue; + if (minInt < (size_t)pair.first[0]) + continue; + minInt = pair.first[0]; + mv = pair.second; + } + assert(mv != BaseType::Unknown); + vd.insert({0}, mv); + goto known; + } + } + } if (errorIfNoType) EmitWarning("CannotDeduceType", MTI, "failed to deduce type of copy ", MTI); @@ -3368,6 +3391,7 @@ class AdjointGenerator &TR.analyzer, nullptr, wrap(&BuilderZ)); } else { ss << "\n"; + ss << *gutils->oldFunc << "\n"; TR.dump(ss); EmitFailure("CannotDeduceType", MTI.getDebugLoc(), &MTI, ss.str()); } diff --git a/enzyme/test/Integration/ReverseMode/forrealloc.c b/enzyme/test/Integration/ReverseMode/forrealloc.c index 9fc9ec5798d7..9034ff68f0bb 100644 --- a/enzyme/test/Integration/ReverseMode/forrealloc.c +++ b/enzyme/test/Integration/ReverseMode/forrealloc.c @@ -1,11 +1,11 @@ // RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S -enzyme-loose-types | %lli - +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S -enzyme-loose-types | %lli - // RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - 
%OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - +// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -enzyme-loose-types | %lli - +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -enzyme-loose-types | %lli - #include #include From dd02ac51a56d8761d87d4dd3ebef6f730d88e37f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 11:01:05 +0000 Subject: [PATCH 039/131] Fix DT error (#1665) --- enzyme/Enzyme/FunctionUtils.cpp | 28 +++++++++++------ enzyme/test/Enzyme/ReverseMode/logabsgamma.ll | 31 +++++++++++++++++++ .../test/Integration/ReverseMode/forrealloc.c | 16 +++++----- .../ReverseMode/integrateconst.cpp | 1 - enzyme/test/Integration/ReverseMode/sret.cpp | 3 -- 5 files changed, 57 insertions(+), 22 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/logabsgamma.ll diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 4c3370cf4cf9..71ceb9688eea 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2495,19 +2495,23 @@ void ReplaceFunctionImplementation(Module &M) { } void PreProcessCache::optimizeIntermediate(Function *F) { - PromotePass().run(*F, FAM); + PreservedAnalyses PA; + PA = PromotePass().run(*F, FAM); + FAM.invalidate(*F, PA); #if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) - GVNPass().run(*F, FAM); + PA = GVNPass().run(*F, FAM); #else - GVN().run(*F, FAM); + PA = GVN().run(*F, FAM); #endif + FAM.invalidate(*F, PA); #if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG) - SROAPass(llvm::SROAOptions::PreserveCFG).run(*F, FAM); + PA = SROAPass(llvm::SROAOptions::PreserveCFG).run(*F, FAM); #elif LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) - SROAPass().run(*F, FAM); + PA = SROAPass().run(*F, FAM); #else - SROA().run(*F, FAM); + PA = SROA().run(*F, FAM); #endif + 
FAM.invalidate(*F, PA); if (EnzymeSelectOpt) { #if LLVM_VERSION_MAJOR >= 12 @@ -2518,8 +2522,10 @@ void PreProcessCache::optimizeIntermediate(Function *F) { /*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true, /*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr); #endif - SimplifyCFGPass(scfgo).run(*F, FAM); - CorrelatedValuePropagationPass().run(*F, FAM); + PA = SimplifyCFGPass(scfgo).run(*F, FAM); + FAM.invalidate(*F, PA); + PA = CorrelatedValuePropagationPass().run(*F, FAM); + FAM.invalidate(*F, PA); SelectOptimization(F); } // EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM); @@ -2529,8 +2535,10 @@ void PreProcessCache::optimizeIntermediate(Function *F) { ReplaceFunctionImplementation(*F->getParent()); - PreservedAnalyses PA; - FAM.invalidate(*F, PA); + { + PreservedAnalyses PA; + FAM.invalidate(*F, PA); + } #if LLVM_VERSION_MAJOR < 14 using OptimizationLevel = llvm::PassBuilder::OptimizationLevel; diff --git a/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll new file mode 100644 index 000000000000..02e0f40cd5bc --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s + +; XFAIL: * + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %a = call { double, i64 } @logabsgamma(double %x) + %b = extractvalue { double, i64 } %a, 0 + ret double %b +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +declare { double, i64 } @logabsgamma(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) 
+ +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @cosh(double %x) +; CHECK-NEXT: %1 = fmul fast double %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0 +; CHECK-NEXT: ret { double } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Integration/ReverseMode/forrealloc.c b/enzyme/test/Integration/ReverseMode/forrealloc.c index 9034ff68f0bb..924f7f04a9f1 100644 --- a/enzyme/test/Integration/ReverseMode/forrealloc.c +++ b/enzyme/test/Integration/ReverseMode/forrealloc.c @@ -1,11 +1,11 @@ -// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S -enzyme-loose-types | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S -enzyme-loose-types | %lli - -// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -enzyme-loose-types | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -enzyme-loose-types | %lli - +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %loadClangEnzyme %s -S -emit-llvm -o - | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %loadClangEnzyme %s -S -emit-llvm -o - | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-loose-types | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-loose-types | %lli 
- ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi #include #include diff --git a/enzyme/test/Integration/ReverseMode/integrateconst.cpp b/enzyme/test/Integration/ReverseMode/integrateconst.cpp index 2086ab8b525d..c55247d65e55 100644 --- a/enzyme/test/Integration/ReverseMode/integrateconst.cpp +++ b/enzyme/test/Integration/ReverseMode/integrateconst.cpp @@ -14,7 +14,6 @@ #define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS #define BOOST_NO_EXCEPTIONS -#include #include #include diff --git a/enzyme/test/Integration/ReverseMode/sret.cpp b/enzyme/test/Integration/ReverseMode/sret.cpp index 2b5315d5d35c..16751fe3e318 100644 --- a/enzyme/test/Integration/ReverseMode/sret.cpp +++ b/enzyme/test/Integration/ReverseMode/sret.cpp @@ -8,9 +8,6 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S #include "../test_utils.h" -#include -#include -#include typedef struct { double df[3]; From cf728efae5bab49fead2bd74c891361f73f408ca Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 11:29:56 +0000 Subject: [PATCH 040/131] Fixups for external bazel integration tests (#1668) --- enzyme/BUILD | 184 +++++++++++++----- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp | 28 +-- .../Analysis/DataFlowActivityAnalysis.cpp | 27 +-- .../LinalgAutoDiffOpInterfaceImpl.cpp | 1 - .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 10 +- 
.../Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp | 2 +- enzyme/Enzyme/TypeAnalysis/BaseType.h | 8 +- enzyme/test/Integration/ForwardMode/loops.c | 4 - .../Integration/ForwardMode/loopsdouble.c | 4 - .../Integration/ForwardMode/loopstriple.c | 4 - enzyme/test/Integration/ForwardMode/rwrloop.c | 4 - enzyme/test/Integration/ForwardMode/sumtil.c | 4 - enzyme/test/Integration/ForwardMode/sumtil2.c | 4 - .../Integration/ForwardModeVector/binops.c | 17 +- .../ReverseMode/allocatedtape_err.c | 5 +- .../test/Integration/ReverseMode/boundissue.c | 4 - .../test/Integration/ReverseMode/cachefwd.c | 4 - .../Integration/ReverseMode/customcombined.c | 2 +- enzyme/test/Integration/ReverseMode/dbginfo.c | 2 - .../ReverseMode/differential_pointer_return.c | 5 - .../test/Integration/ReverseMode/forrealloc.c | 6 - enzyme/test/Integration/ReverseMode/frexp.c | 6 +- .../test/Integration/ReverseMode/fwdsolve.c | 9 +- .../ReverseMode/gradient-struct-return.c | 1 - .../Integration/ReverseMode/headerremat.c | 7 - .../Integration/ReverseMode/insertsort_sum.c | 4 - .../ReverseMode/insertsort_sum_alt.c | 5 - .../ReverseMode/insertsort_sum_min.c | 4 - enzyme/test/Integration/ReverseMode/loops.c | 4 - .../Integration/ReverseMode/loopsdouble.c | 4 - .../Integration/ReverseMode/loopstriple.c | 4 - enzyme/test/Integration/ReverseMode/manydiv.c | 4 - enzyme/test/Integration/ReverseMode/manymax.c | 4 - .../test/Integration/ReverseMode/metamalloc.c | 10 +- enzyme/test/Integration/ReverseMode/metarwr.c | 6 +- .../ReverseMode/mixedstruct1-old.c | 5 - .../ReverseMode/mixedstruct1-simple.c | 7 +- .../ReverseMode/mixedstruct1-simplefda.c | 5 - .../ReverseMode/mixedstruct1-simpleps.c | 5 - .../ReverseMode/mixedstruct1-simpler.c | 5 - .../ReverseMode/mixedstruct1-simplest.c | 5 - .../Integration/ReverseMode/mixedstruct1-sp.c | 5 - .../Integration/ReverseMode/mixedstruct1.c | 5 - .../Integration/ReverseMode/multivecmaxC.c | 5 +- .../Integration/ReverseMode/posix_memalign.c | 8 +- 
.../ReverseMode/posix_memalignfor.c | 8 +- .../Integration/ReverseMode/readwriteread.c | 5 - enzyme/test/Integration/ReverseMode/recurse.c | 4 - enzyme/test/Integration/ReverseMode/remat.c | 4 - .../Integration/ReverseMode/rematSimple.c | 4 - enzyme/test/Integration/ReverseMode/rwrloop.c | 5 - enzyme/test/Integration/ReverseMode/rwrmeta.c | 4 - .../Integration/ReverseMode/smallrealloc.c | 5 - .../Integration/ReverseMode/subdoublestore.c | 6 - enzyme/test/Integration/ReverseMode/sumtil.c | 4 - enzyme/test/Integration/ReverseMode/sumtil2.c | 4 - .../test/Integration/ReverseMode/taylorlog.c | 2 - enzyme/test/Integration/blas_inline.h | 3 + enzyme/test/Integration/test_utils.h | 23 ++- enzyme/test/MLIR/ForwardMode/inactive.mlir | 8 +- 60 files changed, 212 insertions(+), 333 deletions(-) diff --git a/enzyme/BUILD b/enzyme/BUILD index 3b4d03c001e0..0ad691e5aae0 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -1,6 +1,6 @@ +load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") load("@llvm-project//llvm:tblgen.bzl", "gentbl") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") load("@bazel_skylib//rules:expand_template.bzl", "expand_template") licenses(["notice"]) @@ -145,11 +145,14 @@ gentbl( cc_library( name = "EnzymeStatic", - srcs = glob([ - "Enzyme/*.cpp", - "Enzyme/TypeAnalysis/*.cpp", - "Enzyme/Clang/EnzymeClang.cpp", - ], exclude=["Enzyme/eopt.cpp"]), + srcs = glob( + [ + "Enzyme/*.cpp", + "Enzyme/TypeAnalysis/*.cpp", + "Enzyme/Clang/EnzymeClang.cpp", + ], + exclude = ["Enzyme/eopt.cpp"], + ), hdrs = glob([ "Enzyme/*.h", "Enzyme/TypeAnalysis/*.h", @@ -194,7 +197,7 @@ cc_library( "@llvm-project//llvm:TransformUtils", "@llvm-project//llvm:config", ], - alwayslink = 1 + alwayslink = 1, ) cc_binary( @@ -223,6 +226,8 @@ cc_binary( srcs = ["Enzyme/eopt.cpp"], deps = [ ":EnzymeStatic", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", 
"@llvm-project//llvm:opt-driver", ], ) @@ -230,16 +235,16 @@ cc_binary( td_library( name = "EnzymeDialectTdFiles", srcs = [ - "Enzyme/MLIR/Dialect/Dialect.td", + "Enzyme/MLIR/Dialect/Dialect.td", ], - deps = [ + deps = [ + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", "@llvm-project//mlir:ViewLikeInterfaceTdFiles", - "@llvm-project//mlir:FunctionInterfacesTdFiles", - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:LoopLikeInterfaceTdFiles", - ] + ], ) gentbl_cc_library( @@ -277,9 +282,9 @@ td_library( name = "EnzymePassesTdFiles", srcs = [ ], - deps = [ + deps = [ "@llvm-project//mlir:PassBaseTdFiles", - ] + ], ) gentbl_cc_library( @@ -349,7 +354,6 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) - gentbl_cc_library( name = "EnzymeTypeInterfacesIncGen", tbl_outs = [ @@ -394,7 +398,8 @@ gentbl( td_file = "Enzyme/MLIR/Implementations/AffineDerivatives.td", td_srcs = [ "Enzyme/MLIR/Implementations/AffineDerivatives.td", - "Enzyme/MLIR/Implementations/Common.td"], + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -408,7 +413,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/ArithDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -422,7 +430,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/LLVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/LLVMDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps 
= [ ":enzyme-tblgen", ], @@ -436,7 +447,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/NVVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/NVVMDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -450,7 +464,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/SCFDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/SCFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/SCFDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -464,7 +481,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/CFDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/CFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/CFDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -478,7 +498,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/MemRefDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/MemRefDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/MemRefDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -492,7 +515,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/MathDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/MathDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/MathDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -514,48 +540,95 @@ cc_library( "Enzyme/MLIR/Analysis/*.h", "Enzyme/MLIR/Implementations/*.h", "Enzyme/Utils.h", - 
"Enzyme/TypeAnalysis/*.h" + "Enzyme/TypeAnalysis/*.h", ]), - includes = ["Enzyme/MLIR", "Enzyme"], + includes = [ + "Enzyme", + "Enzyme/MLIR", + ], visibility = ["//visibility:public"], deps = [ + ":EnzymeAttributesIncGen", + ":EnzymeEnumsIncGen", + ":EnzymeOpInterfacesIncGen", + ":EnzymeOpsIncGen", + ":EnzymePassesIncGen", + ":EnzymeTypeInterfacesIncGen", + ":EnzymeTypesIncGen", ":affine-derivatives", ":arith-derivatives", + ":cf-derivatives", ":llvm-derivatives", + ":math-derivatives", + ":memref-derivatives", ":nvvm-derivatives", ":scf-derivatives", - ":cf-derivatives", - ":memref-derivatives", - ":math-derivatives", - ":EnzymeOpsIncGen", - ":EnzymePassesIncGen", - ":EnzymeTypesIncGen", - ":EnzymeEnumsIncGen", - ":EnzymeAttributesIncGen", - ":EnzymeTypeInterfacesIncGen", - ":EnzymeOpInterfacesIncGen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//llvm:config", "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:LLVMCommonConversion", - "@llvm-project//mlir:ConversionPasses", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:CastInterfaces", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgStructuredOpsIncGen", + "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", 
"@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:ViewLikeInterface", ], ) cc_binary( name = "enzymemlir-opt", srcs = ["Enzyme/MLIR/enzymemlir-opt.cpp"], - visibility = ["//visibility:public"], includes = ["Enzyme/MLIR"], + visibility = ["//visibility:public"], deps = [ ":EnzymeMLIR", - "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Transforms", ], ) @@ -586,18 +659,25 @@ expand_template( name = "%s.test" % src, srcs = [src], data = [ + ":enzyme-clang", + ":enzyme-clang++", + ":enzyme-opt", + ":enzymemlir-opt", ":test/lit.cfg.py", ":test/lit.site.cfg.py", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:count", - "@llvm-project//llvm:not", - "@llvm-project//llvm:lli", - ":enzyme-opt", "@llvm-project//clang:builtin_headers_gen", - ":enzyme-clang", - ":enzyme-clang++", - ":enzymemlir-opt" - ] + glob(["test/**/*.h"]) + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:not", + ] + glob(["test/**/*.h"]), + ) + for src in glob( + 
[ + "test/**/*.mlir", + "test/Integration/**/*.c", + "test/Integration/**/.cpp", + ], + exclude = ["test/**/*omp*.c"], ) - for src in glob(["test/**/*.mlir", "test/Integration/**/*.c", "test/Integration/**/.cpp"], exclude=["test/**/*omp*.c"]) ] diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp index b584e8e64a1c..767a60a3ed2b 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp @@ -211,7 +211,7 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate, // TODO: consider a stricter check that we only replace unknown // values or a value with itself, currently blocked by memalign. AliasClassSet valuesCopy(values); - valuesCopy.join(it->getSecond()); + (void)valuesCopy.join(it->getSecond()); values.print(llvm::errs()); llvm::errs() << "\n"; it->getSecond().print(llvm::errs()); @@ -572,7 +572,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( if (funcMayReadOther) { // If a function may read from other, it may be storing pointers from // unknown alias sets into any writable pointer. - functionMayCapture.markUnknown(); + (void)functionMayCapture.markUnknown(); } else { for (int pointerAsData : pointerLikeOperands) { // If not captured, it cannot be stored in anything. @@ -583,7 +583,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( const auto *srcClasses = getOrCreateFor( call, call.getArgOperands()[pointerAsData]); - functionMayCapture.join(srcClasses->getAliasClassesObject()); + (void)functionMayCapture.join(srcClasses->getAliasClassesObject()); } } @@ -598,16 +598,17 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( // If the argument cannot be stored into, just preserve it as is. 
if (!mayWriteArg(callee, pointerOperand, argModRef)) { - nonWritableOperandClasses.join(destClasses->getAliasClassesObject()); + (void)nonWritableOperandClasses.join( + destClasses->getAliasClassesObject()); continue; } - writableClasses.join(destClasses->getAliasClassesObject()); + (void)writableClasses.join(destClasses->getAliasClassesObject()); // If the destination class is unknown, mark all known classes // pessimistic (alias classes that have not beed analyzed and thus are // absent from pointsTo are treated as "undefined" at this point). if (destClasses->isUnknown()) { - writableClasses.markUnknown(); + (void)writableClasses.markUnknown(); changed |= after->markAllPointToUnknown(); break; } @@ -675,15 +676,15 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( AliasClassSet resultWithoutNonWritableOperands = AliasClassSet::getUndefined(); if (destClasses->isUnknown() || nonWritableOperandClasses.isUnknown()) { - resultWithoutNonWritableOperands.markUnknown(); + (void)resultWithoutNonWritableOperands.markUnknown(); } else if (!destClasses->isUndefined() && !nonWritableOperandClasses.isUndefined()) { DenseSet nonOperandClasses = llvm::set_difference(destClasses->getAliasClasses(), nonWritableOperandClasses.getAliasClasses()); - resultWithoutNonWritableOperands.insert(nonOperandClasses); + (void)resultWithoutNonWritableOperands.insert(nonOperandClasses); } else { - resultWithoutNonWritableOperands.join( + (void)resultWithoutNonWritableOperands.join( destClasses->getAliasClassesObject()); } @@ -973,7 +974,7 @@ void enzyme::AliasAnalysis::visitOperation( if (!isPointerLike(result.getType())) continue; - results[result.getResultNumber()]->markUnknown(); + (void)results[result.getResultNumber()]->markUnknown(); } return; } @@ -1027,7 +1028,7 @@ void enzyme::AliasAnalysis::visitExternalCall( continue; const AliasClassLattice *srcClasses = operands[operandNo]; - operandAliasClasses.join(srcClasses->getAliasClassesObject()); + 
(void)operandAliasClasses.join(srcClasses->getAliasClassesObject()); if (!mayReadArg(callee, operandNo, argModRef)) continue; @@ -1035,13 +1036,14 @@ void enzyme::AliasAnalysis::visitExternalCall( // If can read from argument, collect the alias classes that can this // argument may be pointing to. const auto *pointsToLattice = getOrCreateFor(call, call); - srcClasses->getAliasClassesObject().foreachClass( + (void)srcClasses->getAliasClassesObject().foreachClass( [&](DistinctAttr srcClass, AliasClassSet::State state) { // Nothing to do in top/bottom case. In the top case, we have already // set `operandAliasClasses` to top above. if (srcClass == nullptr) return ChangeResult::NoChange; - operandAliasClasses.join(pointsToLattice->getPointsTo(srcClass)); + (void)operandAliasClasses.join( + pointsToLattice->getPointsTo(srcClass)); return ChangeResult::NoChange; }); } diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index 4f922d6eb5d5..9c540064a49a 100644 --- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -330,7 +330,7 @@ class MemoryActivity : public AbstractDenseLattice { const MemoryActivityState *rhsActivity = isKnownInRHS ? &rhsIt->getSecond() : &rhs.otherMemoryActivity; MemoryActivityState updatedActivity(*lhsActivity); - updatedActivity.merge(*rhsActivity); + (void)updatedActivity.merge(*rhsActivity); if ((lhsIt != activityStates.end() && updatedActivity != lhsIt->getSecond()) || (lhsIt == activityStates.end() && @@ -490,7 +490,7 @@ std::optional getCopySource(Operation *op) { /// If the classes are undefined, the callback will not be called at all. 
void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass, function_ref forEachFn) { - ptrAliasClass->getAliasClassesObject().foreachClass( + (void)ptrAliasClass->getAliasClassesObject().foreachClass( [&](DistinctAttr alloc, enzyme::AliasClassSet::State state) { if (state != enzyme::AliasClassSet::State::Undefined) forEachFn(alloc); @@ -854,15 +854,16 @@ void printActivityAnalysisResults(const DataFlowSolver &solver, std::deque frontier; DenseSet visited; auto scheduleVisit = [&](const enzyme::AliasClassSet &aliasClasses) { - aliasClasses.foreachClass([&](DistinctAttr neighbor, - enzyme::AliasClassSet::State state) { - assert(neighbor && "unhandled undefined/unknown case before visit"); - if (!visited.contains(neighbor)) { - visited.insert(neighbor); - frontier.push_back(neighbor); - } - return ChangeResult::NoChange; - }); + (void)aliasClasses.foreachClass( + [&](DistinctAttr neighbor, enzyme::AliasClassSet::State state) { + assert(neighbor && + "unhandled undefined/unknown case before visit"); + if (!visited.contains(neighbor)) { + visited.insert(neighbor); + frontier.push_back(neighbor); + } + return ChangeResult::NoChange; + }); }; // If this triggers, investigate why the alias classes weren't computed. @@ -1071,7 +1072,7 @@ void enzyme::runDataFlowActivityAnalysis( // analyses, enzyme_const is the default. 
if (activity == enzyme::Activity::enzyme_out) { auto *argLattice = solver.getOrCreateState(arg); - argLattice->join(ValueActivity::getActiveVal()); + (void)argLattice->join(ValueActivity::getActiveVal()); } } @@ -1086,7 +1087,7 @@ void enzyme::runDataFlowActivityAnalysis( solver.getOrCreateState(operand); // Very basic type inference of the type if (isa(operand.getType())) { - returnLattice->meet(ValueActivity::getActiveVal()); + (void)returnLattice->meet(ValueActivity::getActiveVal()); } } } diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 93488a07fd6e..334d7325d7ec 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -28,7 +28,6 @@ #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 20ed156d312d..467a8f59ec6e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -36,14 +36,12 @@ mlir::enzyme::MGradientUtils::MGradientUtils( ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp) - : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), TA(TA_), - TR(TR_), omp(omp), blocksNotForAnalysis(), + : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), + invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_), + originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - 
width(width), ArgDiffeTypes(ArgDiffeTypes_), - originalToNewFn(originalToNewFn_), - originalToNewFnOps(originalToNewFnOps_), - invertedPointers(invertedPointers_) { + TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_) { /* for (BasicBlock &BB : *oldFunc) { diff --git a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp index 99a41f804593..c2680ae6db72 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp @@ -91,7 +91,7 @@ struct PrintAliasAnalysisPass continue; // TODO(zinenko): this has been overriding the argument... // Use an array attr instead (will break syntactic tests). - state->getAliasClassesObject().foreachClass( + (void)state->getAliasClassesObject().foreachClass( [&](DistinctAttr aliasClass, enzyme::AliasClassSet::State state) { if (state == enzyme::AliasClassSet::State::Undefined) funcOp.setArgAttr( diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index 9ba5bd2bec5c..f73948a703ad 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -25,8 +25,6 @@ #ifndef ENZYME_TYPE_ANALYSIS_BASE_TYPE_H #define ENZYME_TYPE_ANALYSIS_BASE_TYPE_H 1 -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/ErrorHandling.h" #include /// Categories of potential types @@ -57,11 +55,11 @@ static inline std::string to_string(BaseType t) { case BaseType::Unknown: return "Unknown"; } - llvm_unreachable("unknown inttype"); + assert(0 && "unknown inttype"); } /// Convert string to BaseType -static inline BaseType parseBaseType(llvm::StringRef str) { +template static inline BaseType parseBaseType(T str) { if (str == "Integer") return BaseType::Integer; if (str == "Float") @@ -72,6 +70,6 @@ static inline BaseType parseBaseType(llvm::StringRef str) { return BaseType::Anything; if (str == "Unknown") return BaseType::Unknown; - llvm_unreachable("Unknown BaseType string"); 
+ assert(0 && "Unknown BaseType string"); } #endif diff --git a/enzyme/test/Integration/ForwardMode/loops.c b/enzyme/test/Integration/ForwardMode/loops.c index 612839248cc2..33bef07a2d16 100644 --- a/enzyme/test/Integration/ForwardMode/loops.c +++ b/enzyme/test/Integration/ForwardMode/loops.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/loopsdouble.c b/enzyme/test/Integration/ForwardMode/loopsdouble.c index 8c34740c80d5..4faa0cad74f3 100644 --- a/enzyme/test/Integration/ForwardMode/loopsdouble.c +++ b/enzyme/test/Integration/ForwardMode/loopsdouble.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/loopstriple.c b/enzyme/test/Integration/ForwardMode/loopstriple.c index 060e74d009f7..84e98ab1694e 100644 --- a/enzyme/test/Integration/ForwardMode/loopstriple.c +++ b/enzyme/test/Integration/ForwardMode/loopstriple.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/rwrloop.c b/enzyme/test/Integration/ForwardMode/rwrloop.c index dc71b0eead20..32e548029f6d 100644 --- 
a/enzyme/test/Integration/ForwardMode/rwrloop.c +++ b/enzyme/test/Integration/ForwardMode/rwrloop.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/sumtil.c b/enzyme/test/Integration/ForwardMode/sumtil.c index 9e9286f97157..5d6369df4909 100644 --- a/enzyme/test/Integration/ForwardMode/sumtil.c +++ b/enzyme/test/Integration/ForwardMode/sumtil.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern double __enzyme_fwddiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ForwardMode/sumtil2.c b/enzyme/test/Integration/ForwardMode/sumtil2.c index 32d289703e69..18428ca371d6 100644 --- a/enzyme/test/Integration/ForwardMode/sumtil2.c +++ b/enzyme/test/Integration/ForwardMode/sumtil2.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern double __enzyme_fwddiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ForwardModeVector/binops.c b/enzyme/test/Integration/ForwardModeVector/binops.c index 27c786a86980..a3b66f27ea31 100644 --- a/enzyme/test/Integration/ForwardModeVector/binops.c +++ b/enzyme/test/Integration/ForwardModeVector/binops.c @@ -7,10 +7,7 @@ // RUN: %clang 
-O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include +#include "../test_utils.h" /* #ifdef __cplusplus @@ -28,18 +25,6 @@ threshold) { if (fabs(f1-f2) > threshold) return false; return true; #endif */ -#define APPROX_EQ(LHS, RHS, THRES) \ - { \ - if (__builtin_fabs(LHS - RHS) > THRES) { \ - fprintf(stderr, \ - "Assertion Failed: fabs( [%s = %g] - [%s = %g] ) > %g at %s:%d " \ - "(%s)\n", \ - #LHS, LHS, #RHS, RHS, THRES, __FILE__, __LINE__, \ - __PRETTY_FUNCTION__); \ - abort(); \ - } \ - }; - typedef struct { double dx, dy; diff --git a/enzyme/test/Integration/ReverseMode/allocatedtape_err.c b/enzyme/test/Integration/ReverseMode/allocatedtape_err.c index 1236a86388e7..27bd89f36e40 100644 --- a/enzyme/test/Integration/ReverseMode/allocatedtape_err.c +++ b/enzyme/test/Integration/ReverseMode/allocatedtape_err.c @@ -7,8 +7,9 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi -#include -#include "../test_utils.h" +extern int enzyme_allocated; +extern int enzyme_tape; +double sin(double); void __enzyme_reverse(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/boundissue.c b/enzyme/test/Integration/ReverseMode/boundissue.c index f297ebbdda97..4d41f2bf482a 100644 --- a/enzyme/test/Integration/ReverseMode/boundissue.c +++ b/enzyme/test/Integration/ReverseMode/boundissue.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, 
...); diff --git a/enzyme/test/Integration/ReverseMode/cachefwd.c b/enzyme/test/Integration/ReverseMode/cachefwd.c index ab56e3dd0853..6e1b058a7004 100644 --- a/enzyme/test/Integration/ReverseMode/cachefwd.c +++ b/enzyme/test/Integration/ReverseMode/cachefwd.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ReverseMode/customcombined.c b/enzyme/test/Integration/ReverseMode/customcombined.c index 5ce2c1df2b07..cbd86d0452ef 100644 --- a/enzyme/test/Integration/ReverseMode/customcombined.c +++ b/enzyme/test/Integration/ReverseMode/customcombined.c @@ -32,7 +32,7 @@ void* augment_square_(const double* src, const double *d_src, double* dest, doub // intentionally incorrect for debugging *dest = 7.0; *d_dest = 11.0; - return NULL; + return (void*)0; } int gradient = 0; diff --git a/enzyme/test/Integration/ReverseMode/dbginfo.c b/enzyme/test/Integration/ReverseMode/dbginfo.c index 06767e624849..31d9bafe4439 100644 --- a/enzyme/test/Integration/ReverseMode/dbginfo.c +++ b/enzyme/test/Integration/ReverseMode/dbginfo.c @@ -7,8 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - -g | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - -g | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -//#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/differential_pointer_return.c b/enzyme/test/Integration/ReverseMode/differential_pointer_return.c index b05a7264ae3c..42daecbaa7ed 100644 --- 
a/enzyme/test/Integration/ReverseMode/differential_pointer_return.c +++ b/enzyme/test/Integration/ReverseMode/differential_pointer_return.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/forrealloc.c b/enzyme/test/Integration/ReverseMode/forrealloc.c index 924f7f04a9f1..5daf309c6c6a 100644 --- a/enzyme/test/Integration/ReverseMode/forrealloc.c +++ b/enzyme/test/Integration/ReverseMode/forrealloc.c @@ -7,14 +7,8 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi -#include -#include -#include -#include - #include "../test_utils.h" - float __enzyme_autodiff(void*, float, int); float foo(float inp, int n) { diff --git a/enzyme/test/Integration/ReverseMode/frexp.c b/enzyme/test/Integration/ReverseMode/frexp.c index 4b917e7e282f..08858d51f796 100644 --- a/enzyme/test/Integration/ReverseMode/frexp.c +++ b/enzyme/test/Integration/ReverseMode/frexp.c @@ -7,12 +7,10 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" +extern double frexp ( double num, int* exp ); + double f(double x) { int exp; return frexp(x, &exp); diff --git a/enzyme/test/Integration/ReverseMode/fwdsolve.c b/enzyme/test/Integration/ReverseMode/fwdsolve.c index 
c7a6ad63040e..5a7ab4cc67e8 100644 --- a/enzyme/test/Integration/ReverseMode/fwdsolve.c +++ b/enzyme/test/Integration/ReverseMode/fwdsolve.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" @@ -25,10 +20,10 @@ void forward_sub(int N, double* __restrict__ L, double * __restrict__ b, double b must be a vector of the same leading dimension as L """ */ - for (size_t i=0; i 1) - for (size_t j=0; j #include "../test_utils.h" typedef struct { diff --git a/enzyme/test/Integration/ReverseMode/headerremat.c b/enzyme/test/Integration/ReverseMode/headerremat.c index b3397ad8bead..a324a519557f 100644 --- a/enzyme/test/Integration/ReverseMode/headerremat.c +++ b/enzyme/test/Integration/ReverseMode/headerremat.c @@ -7,15 +7,8 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" - -#include - __attribute__((noinline)) int evaluate_integrand(const int nr, const int dtheta) diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum.c b/enzyme/test/Integration/ReverseMode/insertsort_sum.c index 7e7d6b0a3ae4..7b42f7613890 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum.c @@ -6,10 +6,6 @@ // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include 
-#include -#include -#include #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c b/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c index 3a643492d30f..321438af8b9a 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c b/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c index b62150353581..f6bfd1a4c5c0 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c @@ -8,10 +8,6 @@ // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include "../test_utils.h" -#include -#include -#include -#include #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loops.c b/enzyme/test/Integration/ReverseMode/loops.c index 3a2794963eb5..57746125be68 100644 --- a/enzyme/test/Integration/ReverseMode/loops.c +++ b/enzyme/test/Integration/ReverseMode/loops.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loopsdouble.c b/enzyme/test/Integration/ReverseMode/loopsdouble.c index 8108a9813c63..c9ffc8b6bde6 100644 --- 
a/enzyme/test/Integration/ReverseMode/loopsdouble.c +++ b/enzyme/test/Integration/ReverseMode/loopsdouble.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loopstriple.c b/enzyme/test/Integration/ReverseMode/loopstriple.c index 490502145810..a1aa76444893 100644 --- a/enzyme/test/Integration/ReverseMode/loopstriple.c +++ b/enzyme/test/Integration/ReverseMode/loopstriple.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/manydiv.c b/enzyme/test/Integration/ReverseMode/manydiv.c index a0da1f6499b3..e3937c6bb956 100644 --- a/enzyme/test/Integration/ReverseMode/manydiv.c +++ b/enzyme/test/Integration/ReverseMode/manydiv.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/manymax.c b/enzyme/test/Integration/ReverseMode/manymax.c index 90b61041ec1e..4cf31bf1cfa7 100644 --- a/enzyme/test/Integration/ReverseMode/manymax.c +++ b/enzyme/test/Integration/ReverseMode/manymax.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme 
--enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/metamalloc.c b/enzyme/test/Integration/ReverseMode/metamalloc.c index 96a848fd9698..3765a0c803fd 100644 --- a/enzyme/test/Integration/ReverseMode/metamalloc.c +++ b/enzyme/test/Integration/ReverseMode/metamalloc.c @@ -7,20 +7,16 @@ // RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); struct { int count; -void* (*allocfn)(size_t); +void* (*allocfn)(unsigned long); } tup = {0, malloc}; __attribute__((noinline)) -void* metamalloc(size_t size) { +void* metamalloc(unsigned long size) { void* ret = tup.allocfn(size); //if (ret != 0) // tup.count++; @@ -38,7 +34,7 @@ double alldiv(double x) { } -static void* (*sallocfn)(size_t) = malloc; +static void* (*sallocfn)(unsigned long) = malloc; __attribute__((noinline)) void* smetamalloc(int size) { return sallocfn(size); diff --git a/enzyme/test/Integration/ReverseMode/metarwr.c b/enzyme/test/Integration/ReverseMode/metarwr.c index 4efcd475f68f..3e3310c634bd 100644 --- a/enzyme/test/Integration/ReverseMode/metarwr.c +++ b/enzyme/test/Integration/ReverseMode/metarwr.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, ...); @@ -19,7 +15,7 @@ void 
call(double* __restrict__ a, long** data) { long* segment = data[0]; long size = segment[1] - segment[0]; printf("seg[1]=%d seg[0]=%d\n", segment[1], segment[0]); - for (size_t i=0; i -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c index 5dea68e793da..40444dd8306f 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c @@ -5,12 +5,7 @@ // RUN: %clang -std=c11 %O0TBAA %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - - -#include -#include -#include -#include +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - q #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c index e2b3f9788fca..22e6fc4dedc7 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c 
b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c index b4857429b6fe..7de76117e765 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c index 61b81d0a57bf..6c602b32da3a 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c index bcf3ea958485..5417fce6b8ea 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c 
b/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c index eb783098cf02..d64e9a70da9a 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1.c b/enzyme/test/Integration/ReverseMode/mixedstruct1.c index 3ab79b584ef9..e1218dbc0541 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/multivecmaxC.c b/enzyme/test/Integration/ReverseMode/multivecmaxC.c index c39559adb2de..460036659ffd 100644 --- a/enzyme/test/Integration/ReverseMode/multivecmaxC.c +++ b/enzyme/test/Integration/ReverseMode/multivecmaxC.c @@ -10,9 +10,6 @@ // RUN: %clang++ -ffast-math -O2 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -fno-exceptions %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang++ -ffast-math -O3 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -fno-exceptions %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); @@ -21,7 +18,7 @@ extern void 
__enzyme_autodiff(void*, double*, double*, int); }*/ double reduce_max(double* vec, int size) { - double ret = -INFINITY; + double ret = -10000000; double *maxes = (double*)malloc(sizeof(double)*size); int count = 0; for (int i = 0; i < size; i++) { diff --git a/enzyme/test/Integration/ReverseMode/posix_memalign.c b/enzyme/test/Integration/ReverseMode/posix_memalign.c index 48aab315b95a..b50a405dd70c 100644 --- a/enzyme/test/Integration/ReverseMode/posix_memalign.c +++ b/enzyme/test/Integration/ReverseMode/posix_memalign.c @@ -7,15 +7,9 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" -int posix_memalign(void **memptr, size_t alignment, size_t size); +int posix_memalign(void **memptr, unsigned long alignment, unsigned long size); float __enzyme_autodiff(void*, float, int); diff --git a/enzyme/test/Integration/ReverseMode/posix_memalignfor.c b/enzyme/test/Integration/ReverseMode/posix_memalignfor.c index 0a01ef095142..7336444421a8 100644 --- a/enzyme/test/Integration/ReverseMode/posix_memalignfor.c +++ b/enzyme/test/Integration/ReverseMode/posix_memalignfor.c @@ -7,15 +7,9 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" -int posix_memalign(void **memptr, size_t alignment, size_t size); +int posix_memalign(void **memptr, unsigned long alignment, unsigned long size); float __enzyme_autodiff(void*, float, int); diff --git a/enzyme/test/Integration/ReverseMode/readwriteread.c b/enzyme/test/Integration/ReverseMode/readwriteread.c index ecfdce54d27e..adb5afca594a 100644 --- 
a/enzyme/test/Integration/ReverseMode/readwriteread.c +++ b/enzyme/test/Integration/ReverseMode/readwriteread.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/recurse.c b/enzyme/test/Integration/ReverseMode/recurse.c index a53041e77188..ca06c745411f 100644 --- a/enzyme/test/Integration/ReverseMode/recurse.c +++ b/enzyme/test/Integration/ReverseMode/recurse.c @@ -9,10 +9,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/remat.c b/enzyme/test/Integration/ReverseMode/remat.c index b228bcc13c99..c9b771b17d1b 100644 --- a/enzyme/test/Integration/ReverseMode/remat.c +++ b/enzyme/test/Integration/ReverseMode/remat.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -// test.c -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rematSimple.c b/enzyme/test/Integration/ReverseMode/rematSimple.c index 7b95f9cb9506..788e0c258c8e 100644 --- a/enzyme/test/Integration/ReverseMode/rematSimple.c +++ b/enzyme/test/Integration/ReverseMode/rematSimple.c @@ -3,10 +3,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm 
-o - | %opt - %OPloadEnzyme %enzyme -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - ; fi -// test.c -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rwrloop.c b/enzyme/test/Integration/ReverseMode/rwrloop.c index cdf9e3774553..74b9acb7b897 100644 --- a/enzyme/test/Integration/ReverseMode/rwrloop.c +++ b/enzyme/test/Integration/ReverseMode/rwrloop.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rwrmeta.c b/enzyme/test/Integration/ReverseMode/rwrmeta.c index 34a15d5c93d4..985ccdd44f1a 100644 --- a/enzyme/test/Integration/ReverseMode/rwrmeta.c +++ b/enzyme/test/Integration/ReverseMode/rwrmeta.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/smallrealloc.c b/enzyme/test/Integration/ReverseMode/smallrealloc.c index 51e29ccdfce5..1244a2f12cf1 100644 --- a/enzyme/test/Integration/ReverseMode/smallrealloc.c +++ b/enzyme/test/Integration/ReverseMode/smallrealloc.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include 
-#include -#include - #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/subdoublestore.c b/enzyme/test/Integration/ReverseMode/subdoublestore.c index f411d98f3e7f..130f79c99979 100644 --- a/enzyme/test/Integration/ReverseMode/subdoublestore.c +++ b/enzyme/test/Integration/ReverseMode/subdoublestore.c @@ -7,12 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/sumtil.c b/enzyme/test/Integration/ReverseMode/sumtil.c index 0a2b0502c2bc..f2a34b228afe 100644 --- a/enzyme/test/Integration/ReverseMode/sumtil.c +++ b/enzyme/test/Integration/ReverseMode/sumtil.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ReverseMode/sumtil2.c b/enzyme/test/Integration/ReverseMode/sumtil2.c index aac316c7c4ea..85cea6c94095 100644 --- a/enzyme/test/Integration/ReverseMode/sumtil2.c +++ b/enzyme/test/Integration/ReverseMode/sumtil2.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git 
a/enzyme/test/Integration/ReverseMode/taylorlog.c b/enzyme/test/Integration/ReverseMode/taylorlog.c index 649dd4fff243..fdecb8aac7af 100644 --- a/enzyme/test/Integration/ReverseMode/taylorlog.c +++ b/enzyme/test/Integration/ReverseMode/taylorlog.c @@ -7,8 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -//#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/blas_inline.h b/enzyme/test/Integration/blas_inline.h index 8ac0c16f51b3..b1b988190a5e 100644 --- a/enzyme/test/Integration/blas_inline.h +++ b/enzyme/test/Integration/blas_inline.h @@ -1,5 +1,8 @@ #include #include +#include +#include +#include typedef int32_t integer; typedef double doublereal; diff --git a/enzyme/test/Integration/test_utils.h b/enzyme/test/Integration/test_utils.h index afcf87d4471b..3226b61874ef 100644 --- a/enzyme/test/Integration/test_utils.h +++ b/enzyme/test/Integration/test_utils.h @@ -1,7 +1,22 @@ -#include -#include -#include -#include + + +#ifdef __cplusplus +extern "C" { +#endif +struct _IO_FILE; +extern struct _IO_FILE* stderr; +extern int fprintf(struct _IO_FILE *, const char*, ...); +extern int fflush(struct _IO_FILE *stream); +extern int printf(const char*, ...); +extern void abort(); +extern void free(void *); +extern void* malloc(unsigned long); +extern void *realloc( void *ptr, unsigned long new_size ); +extern void* memcpy( void* dest, const void* src, unsigned long count ); +extern void* memset( void* dest, int, unsigned long count ); +#ifdef __cplusplus +} +#endif extern #ifdef __cplusplus diff --git a/enzyme/test/MLIR/ForwardMode/inactive.mlir b/enzyme/test/MLIR/ForwardMode/inactive.mlir index 6ccb8d7033ee..10b7fcd61b27 100644 --- a/enzyme/test/MLIR/ForwardMode/inactive.mlir +++ 
b/enzyme/test/MLIR/ForwardMode/inactive.mlir @@ -1,11 +1,11 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme %s -allow-unregistered-dialect | FileCheck %s module { func.func @inactive(%x : f64) -> f64 { - // We don't have an interface implementation for "func", + // We don't have an interface implementation for "foo", // but we can see it's inactive from its lack of operands // and results. - func.func private @foo() + "test.foo"() : () -> () return %x : f64 } func.func @diff(%x : f64, %dx : f64) -> f64 { @@ -17,4 +17,4 @@ module { // Just check that we didn't trigger the error on there not being an interface // implementation. // CHECK-LABEL: func private @fwddiffeinactive -// CHECK: func private @foo +// CHECK: "test.foo"() From dc5eaa56b9fbb64aad911ffed27e7a59e88a9b32 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 13:12:23 +0000 Subject: [PATCH 041/131] Logabsgamma support (#1669) * logabsgamma * Fix type analysis * fixed * Add logabsgammaf --- enzyme/Enzyme/InstructionDerivatives.td | 30 +++++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 62 ++++++++++++++----- enzyme/test/Enzyme/ForwardMode/logabsgamma.ll | 28 +++++++++ enzyme/test/Enzyme/ReverseMode/logabsgamma.ll | 19 ++++-- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 42 +++++++++++++ 5 files changed, 159 insertions(+), 22 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/logabsgamma.ll diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 53fad8160993..00be8fe77e55 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -186,6 +186,12 @@ class PrependArgTypesFunc pretys_> { list pretys = pretys_; } +// Set return to arg[0] +// Same argument types +class ArgAsRetTypesFunc { + string name = name_; +} + // Specify that a given argument is inactive, aka not differentiable // By default this argument tells Enzyme that it must always be inactive // from the 
function semantics. @@ -461,6 +467,30 @@ def : CallPattern<(Op $x), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["logabsgamma"], + [ + ( + ArrayRet (FMul (Call<(ArgAsRetTypesFunc<"digamma">), [ReadNone,NoUnwind]> $x), (DiffeRet) ), + (InactiveArg) + ) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["logabsgammaf"], + [ + ( + ArrayRet (FMul (Call<(ArgAsRetTypesFunc<"digammaf">), [ReadNone,NoUnwind]> $x), (DiffeRet) ), + (InactiveArg) + ) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def : CallPattern<(Op $x), ["sinpi", "sinpif", "sinpil", "cospi", "cospif", "cospil"], [ diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index f2a3dc56cdf8..7a75daef2501 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -167,6 +167,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"yn", Intrinsic::not_intrinsic}, {"tgamma", Intrinsic::not_intrinsic}, {"lgamma", Intrinsic::not_intrinsic}, + {"logabsgamma", Intrinsic::not_intrinsic}, {"ceil", Intrinsic::ceil}, {"__nv_ceil", Intrinsic::ceil}, {"floor", Intrinsic::floor}, @@ -5305,23 +5306,52 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } else if (T->isVoidTy()) { } else if (auto ST = dyn_cast(T)) { assert(ST->getNumElements() >= 1); - for (size_t i = 1; i < ST->getNumElements(); ++i) { - assert(ST->getTypeAtIndex((unsigned)0) == ST->getTypeAtIndex(i)); - } - if (ST->getTypeAtIndex((unsigned)0)->isFloatingPointTy()) - updateAnalysis( - &call, - TypeTree(ConcreteType( - ST->getTypeAtIndex((unsigned)0)->getScalarType())) - .Only(-1, &call), - &call); - else if (ST->getTypeAtIndex((unsigned)0)->isIntegerTy()) { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), - &call); - } else { - llvm::errs() << *T << " - " << call << "\n"; - llvm_unreachable("Unknown type for libm"); + TypeTree TT; + auto &DL = 
call.getParent()->getParent()->getParent()->getDataLayout(); + for (size_t i = 0; i < ST->getNumElements(); ++i) { + auto T = ST->getTypeAtIndex(i); + ConcreteType CT(BaseType::Unknown); + + Value *vec[2] = { + ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), + ConstantInt::get(Type::getInt32Ty(call.getContext()), i)}; + auto ud = UndefValue::get(PointerType::getUnqual(ST)); + auto g2 = GetElementPtrInst::Create(ST, ud, vec); + APInt ai(DL.getIndexSizeInBits(0), 0); + g2->accumulateConstantOffset(DL, ai); + delete g2; + size_t Offset = ai.getZExtValue(); + + size_t nextOffset; + if (i + 1 == ST->getNumElements()) + nextOffset = (DL.getTypeSizeInBits(ST) + 7) / 8; + else { + Value *vec[2] = { + ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), + ConstantInt::get(Type::getInt32Ty(call.getContext()), i + 1)}; + auto ud = UndefValue::get(PointerType::getUnqual(ST)); + auto g2 = GetElementPtrInst::Create(ST, ud, vec); + APInt ai(DL.getIndexSizeInBits(0), 0); + g2->accumulateConstantOffset(DL, ai); + delete g2; + nextOffset = ai.getZExtValue(); + } + + if (T->isFloatingPointTy()) { + CT = T; + } else if (T->isIntegerTy()) { + CT = BaseType::Integer; + } + if (CT != BaseType::Unknown) { + TypeTree mid = TypeTree(CT).Only(-1, &call); + TT |= mid.ShiftIndices(DL, /*init offset*/ 0, + /*maxSize*/ nextOffset - Offset, + /*addOffset*/ Offset); + } } + auto Size = (DL.getTypeSizeInBits(ST) + 7) / 8; + TT.CanonicalizeInPlace(Size, DL); + updateAnalysis(&call, TT, &call); } else if (auto AT = dyn_cast(T)) { assert(AT->getNumElements() >= 1); if (AT->getElementType()->isFloatingPointTy()) diff --git a/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll b/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll new file mode 100644 index 000000000000..280d20db590e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s 
%newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define { double, i64 } @tester(double %x) { +entry: + %a = call { double, i64 } @logabsgamma(double %x) + ret { double, i64 } %a +} + +define { double, i64 } @test_derivative(double %x, double %dx) { +entry: + %0 = tail call { double, i64 } (...) @__enzyme_fwddiff({ double, i64 } (double)* nonnull @tester, double %x, double %dx) + ret { double, i64 } %0 +} + +declare { double, i64 } @logabsgamma(double) + +; Function Attrs: nounwind +declare { double, i64 } @__enzyme_fwddiff(...) + +; CHECK: define internal { double, i64 } @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @digamma(double %x) +; CHECK-NEXT: %1 = fmul fast double %0, %"x'" +; CHECK-NEXT: %2 = insertvalue { double, i64 } undef, double %1, 0 +; CHECK-NEXT: ret { double, i64 } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll index 02e0f40cd5bc..7682abf13dd9 100644 --- a/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll +++ b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll @@ -1,8 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s -; XFAIL: * - ; Function Attrs: nounwind readnone uwtable define double @tester(double %x) { entry: @@ -24,8 +22,17 @@ declare double @__enzyme_autodiff(double (double)*, ...) 
; CHECK: define internal { double } @diffetester(double %x, double %differeturn) ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call fast double @cosh(double %x) -; CHECK-NEXT: %1 = fmul fast double %differeturn, %0 -; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0 -; CHECK-NEXT: ret { double } %2 +; CHECK-NEXT: %"a'de" = alloca { double, i64 }, align 8 +; CHECK-NEXT: store { double, i64 } zeroinitializer, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: %0 = getelementptr inbounds { double, i64 }, { double, i64 }* %"a'de", i32 0, i32 0 +; CHECK-NEXT: %1 = load double, double* %0, align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %differeturn +; CHECK-NEXT: store double %2, double* %0, align 8 +; CHECK-NEXT: %3 = load { double, i64 }, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: store { double, i64 } zeroinitializer, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: %4 = call fast double @digamma(double %x) +; CHECK-NEXT: %5 = extractvalue { double, i64 } %3, 0 +; CHECK-NEXT: %6 = fmul fast double %4, %5 +; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0 +; CHECK-NEXT: ret { double } %7 ; CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index c30f8870063e..bf6a00d07cd2 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -119,6 +119,21 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, << ")->getCallingConv();\n"; return; } + if (opName == "ArgAsRetTypesFunc" || + Def->isSubClassOf("ArgAsRetTypesFunc")) { + os << curIndent << "auto " << FT << "_old = cast(&" << origName + << ")->getFunctionType();\n"; + os << curIndent << "auto " << FT << " = FunctionType::get(" << FT + << "_old->params()[0], " << FT << "_old->params(), " << FT + << "_old->isVarArg());\n"; + os << curIndent << "auto " << callval + << " = gutils->oldFunc->getParent()->getOrInsertFunction("; + os << 
Def->getValueInit("name")->getAsString(); + os << ", " << FT << ", called->getAttributes()).getCallee();\n"; + os << curIndent << "auto " << cconv << " = cast(&" << origName + << ")->getCallingConv();\n"; + return; + } } assert(0 && "Unhandled function"); } @@ -954,6 +969,13 @@ void handleUse( foundDiffRet = true; return; } + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + return; + } + if (!Def->isSubClassOf("Operation")) { + errs() << *resultTree << "\n"; + errs() << opName << " " << *Def << "\n"; + } assert(Def->isSubClassOf("Operation")); bool usesPrimal = Def->getValueAsBit("usesPrimal"); bool usesShadow = Def->getValueAsBit("usesShadow"); @@ -1527,6 +1549,9 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } return; } + if (Def->isSubClassOf("InactiveArgSpec")) { + return; + } os << curIndent << INDENT << "{\n"; if (intrinsic == MLIRDerivatives) os << curIndent << INDENT << INDENT << "mlir::Value itmp = "; @@ -1764,6 +1789,9 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } return; } + if (Def->isSubClassOf("InactiveArgSpec")) { + return; + } const char *curIndent = " "; os << curIndent << "{\n"; if (intrinsic == MLIRDerivatives) @@ -1859,10 +1887,24 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } if (intrinsic != MLIRDerivatives) { + os << " auto found = gutils->invertedPointers.find(&(" << origName + << "));\n"; + os << " if (found != gutils->invertedPointers.end()) {\n"; + os << " PHINode* PN = cast(&*found->second);\n"; + os << " gutils->invertedPointers.erase(found);\n"; + os << " gutils->erase(PN);\n"; + os << " }\n"; os << " break;\n"; os << " }\n"; os << " case DerivativeMode::ReverseModePrimal:{\n"; + os << " auto found = gutils->invertedPointers.find(&(" << origName + << "));\n"; + os << " if (found != gutils->invertedPointers.end()) {\n"; + os << " PHINode* PN = cast(&*found->second);\n"; + os << " 
gutils->invertedPointers.erase(found);\n"; + os << " gutils->erase(PN);\n"; + os << " }\n"; // TODO os << " break;\n"; os << " }\n"; From 3c60e4317595d32b87353019b48f62a32657cd1e Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 16:53:58 +0000 Subject: [PATCH 042/131] Fix symbol link collision (#1671) --- enzyme/Enzyme/ActivityAnalysis.cpp | 14 +++++++------- enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index fa4dd06bb68c..7b251e154f94 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -104,7 +104,7 @@ cl::opt EnzymeEnableRecursiveHypotheses( #include // clang-format off -const char *KnownInactiveFunctionsStartingWith[] = { +static const char *KnownInactiveFunctionsStartingWith[] = { "f90io", "$ss5print", "_ZTv0_n24_NSoD", //"1Ev, 0Ev @@ -113,11 +113,11 @@ const char *KnownInactiveFunctionsStartingWith[] = { "_ZNSaIcEC1Ev", }; -const char *KnownInactiveFunctionsContains[] = { +static const char *KnownInactiveFunctionsContains[] = { "__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"}; -const StringSet<> InactiveGlobals = { +static const StringSet<> InactiveGlobals = { "small_typeof", "ompi_request_null", "ompi_mpi_double", @@ -171,7 +171,7 @@ const llvm::StringMap MPIInactiveCommAllocators = { // Instructions which themselves are inactive // the returned value, however, may still be active -const StringSet<> KnownInactiveFunctionInsts = { +static const StringSet<> KnownInactiveFunctionInsts = { "__dynamic_cast", "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", @@ -180,7 +180,7 @@ const StringSet<> KnownInactiveFunctionInsts = { "jl_ptr_to_array", "jl_ptr_to_array_1d"}; -const StringSet<> KnownInactiveFunctions = { +static const StringSet<> KnownInactiveFunctions = { 
"mpfr_greater_p", "__nv_isnand", "__nv_isnanf", @@ -293,7 +293,7 @@ const StringSet<> KnownInactiveFunctions = { "floorl" }; -const std::set KnownInactiveIntrinsics = { +static const std::set KnownInactiveIntrinsics = { #if LLVM_VERSION_MAJOR >= 12 Intrinsic::experimental_noalias_scope_decl, #endif @@ -342,7 +342,7 @@ const std::set KnownInactiveIntrinsics = { Intrinsic::is_constant, Intrinsic::memset}; -const char *DemangledKnownInactiveFunctionsStartingWith[] = { +static const char *DemangledKnownInactiveFunctionsStartingWith[] = { // TODO this returns allocated memory and thus can be an active value // "std::allocator", "std::chrono::_V2::steady_clock::now", diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 5f3cd22db72d..349213ac2fbf 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -20,7 +20,7 @@ #include "Interfaces/AutoDiffOpInterface.h" -const char *KnownInactiveFunctionsStartingWith[] = { +static const char *KnownInactiveFunctionsStartingWith[] = { "f90io", "$ss5print", "_ZTv0_n24_NSoD", //"1Ev, 0Ev @@ -29,11 +29,11 @@ const char *KnownInactiveFunctionsStartingWith[] = { "_ZNSaIcEC1Ev", }; -const char *KnownInactiveFunctionsContains[] = { +static const char *KnownInactiveFunctionsContains[] = { "__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"}; -const std::set InactiveGlobals = { +static const std::set InactiveGlobals = { "ompi_request_null", "ompi_mpi_double", "ompi_mpi_comm_world", "stderr", "stdout", "stdin", "_ZSt3cin", "_ZSt4cout", "_ZSt5wcout", "_ZSt4cerr", "_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE", @@ -57,7 +57,7 @@ const std::set InactiveGlobals = { "_ZTVN10__cxxabiv117__class_type_infoE", "_ZTVN10__cxxabiv121__vmi_class_type_infoE"}; -const std::map MPIInactiveCommAllocators = { +static const std::map MPIInactiveCommAllocators = { {"MPI_Graph_create", 5}, 
{"MPI_Comm_split", 2}, {"MPI_Intercomm_create", 6}, @@ -75,7 +75,7 @@ const std::map MPIInactiveCommAllocators = { // Instructions which themselves are inactive // the returned value, however, may still be active -const std::set KnownInactiveFunctionInsts = { +static const std::set KnownInactiveFunctionInsts = { "__dynamic_cast", "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", @@ -84,7 +84,7 @@ const std::set KnownInactiveFunctionInsts = { "jl_ptr_to_array", "jl_ptr_to_array_1d"}; -const std::set KnownInactiveFunctions = { +static const std::set KnownInactiveFunctions = { "abort", "time", "memcmp", @@ -165,7 +165,7 @@ const std::set KnownInactiveFunctions = { "logbl", }; -const char *DemangledKnownInactiveFunctionsStartingWith[] = { +static const char *DemangledKnownInactiveFunctionsStartingWith[] = { // TODO this returns allocated memory and thus can be an active value // "std::allocator", "std::string", From b8c1936802aa8ec75115f37b60f50b4482fa5748 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 17:51:02 +0000 Subject: [PATCH 043/131] Fix integration tests (#1670) --- enzyme/test/Integration/ReverseMode/cmplx.cpp | 4 ---- enzyme/test/Integration/ReverseMode/eigentensor.cpp | 1 + enzyme/test/Integration/test_utils.h | 8 ++++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/enzyme/test/Integration/ReverseMode/cmplx.cpp b/enzyme/test/Integration/ReverseMode/cmplx.cpp index 73559beb94fb..bec8aa0d831b 100644 --- a/enzyme/test/Integration/ReverseMode/cmplx.cpp +++ b/enzyme/test/Integration/ReverseMode/cmplx.cpp @@ -11,10 +11,6 @@ #include "../test_utils.h" -#include -#include - -#include #include // std::complex, std::abs, std::arg void __enzyme_autodiff(...); diff --git a/enzyme/test/Integration/ReverseMode/eigentensor.cpp b/enzyme/test/Integration/ReverseMode/eigentensor.cpp index 28feb236f4db..6e2da6d65818 100644 --- a/enzyme/test/Integration/ReverseMode/eigentensor.cpp +++ 
b/enzyme/test/Integration/ReverseMode/eigentensor.cpp @@ -16,6 +16,7 @@ #include "../test_utils.h" +#include void memcpy(float* __restrict dst, float* __restrict src, size_t count) { for(size_t i=0; i +#include +#include +#else struct _IO_FILE; extern struct _IO_FILE* stderr; extern int fprintf(struct _IO_FILE *, const char*, ...); @@ -14,8 +16,6 @@ extern void* malloc(unsigned long); extern void *realloc( void *ptr, unsigned long new_size ); extern void* memcpy( void* dest, const void* src, unsigned long count ); extern void* memset( void* dest, int, unsigned long count ); -#ifdef __cplusplus -} #endif extern From 66a3bb6b394b8824e94913898368d0bb34550d75 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 18:12:07 +0000 Subject: [PATCH 044/131] Packaging fixes for julia on 16 (#1672) --- .packaging/build_tarballs.jl | 4 +- enzyme/BCLoad/CMakeLists.txt | 6 +-- enzyme/Enzyme/AdjointGenerator.h | 47 ++++----------------- enzyme/Enzyme/CallDerivatives.cpp | 6 +-- enzyme/Enzyme/DiffeGradientUtils.cpp | 25 +++-------- enzyme/Enzyme/Enzyme.cpp | 22 ++-------- enzyme/Enzyme/EnzymeLogic.cpp | 7 +-- enzyme/Enzyme/GradientUtils.cpp | 13 ++---- enzyme/Enzyme/LibraryFuncs.h | 6 +-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 11 +---- 10 files changed, 30 insertions(+), 117 deletions(-) diff --git a/.packaging/build_tarballs.jl b/.packaging/build_tarballs.jl index 67d214bf2c86..009b48e756e4 100644 --- a/.packaging/build_tarballs.jl +++ b/.packaging/build_tarballs.jl @@ -28,7 +28,7 @@ platforms = expand_cxxstring_abis(supported_platforms(; experimental=true)) script = raw""" cd Enzyme -if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15.asserts* ]]; then +if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15.asserts* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16.asserts* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+17.asserts* ]]; then # LLVM 15 requires macOS SDK 10.14. 
pushd $WORKSPACE/srcdir/MacOSX10.*.sdk rm -rf /opt/${target}/${target}/sys-root/System @@ -117,7 +117,7 @@ for llvm_version in llvm_versions, llvm_assertions in (false, true) for platform in platforms augmented_platform = deepcopy(platform) augmented_platform[LLVM.platform_name] = LLVM.platform(llvm_version, llvm_assertions) - gcc_version = version > v"15" ? v"10" : v"8" + gcc_version = llvm_version > v"15" ? v"10" : v"8" should_build_platform(triplet(augmented_platform)) || continue push!(builds, (; dependencies, products, diff --git a/enzyme/BCLoad/CMakeLists.txt b/enzyme/BCLoad/CMakeLists.txt index f2b04238fb67..6516a46dfbd5 100644 --- a/enzyme/BCLoad/CMakeLists.txt +++ b/enzyme/BCLoad/CMakeLists.txt @@ -4,10 +4,10 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(BC_LOAD_FLAGS "" CACHE STRING "") set(BC_LOAD_HEADER "" CACHE STRING "") -if (${LLVM_VERSION_MAJOR} LESS 15) - set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS}") -else() +if (${LLVM_VERSION_MAJOR} EQUAL 16) set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS} -Xclang -no-opaque-pointers") +else() + set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS}") endif() if (APPLE) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 80d6c851e2a2..debac4dd496f 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4859,18 +4859,9 @@ class AdjointGenerator } } Value *tape = nullptr; -#if LLVM_VERSION_MAJOR >= 16 - if (tapeIdx.has_value()) -#else - if (tapeIdx.hasValue()) -#endif - { + if (tapeIdx) { -#if LLVM_VERSION_MAJOR >= 16 - auto idx = tapeIdx.value(); -#else - auto idx = tapeIdx.getValue(); -#endif + auto idx = *tapeIdx; FunctionType *FT = subdata->fn->getFunctionType(); tape = BuilderZ.CreatePHI( @@ -5416,17 +5407,8 @@ class AdjointGenerator if (!augmentcall->getType()->isVoidTy()) augmentcall->setName(call.getName() + "_augmented"); -#if LLVM_VERSION_MAJOR >= 16 - if (tapeIdx.has_value()) -#else - if (tapeIdx.hasValue()) -#endif - { -#if LLVM_VERSION_MAJOR >= 16 - auto tval = 
tapeIdx.value(); -#else - auto tval = tapeIdx.getValue(); -#endif + if (tapeIdx) { + auto tval = *tapeIdx; tape = (tval == -1) ? augmentcall : BuilderZ.CreateExtractValue( augmentcall, {(unsigned)tval}, "subcache"); @@ -5445,11 +5427,7 @@ class AdjointGenerator Value *dcall = nullptr; assert(returnIdx); assert(augmentcall); -#if LLVM_VERSION_MAJOR >= 16 - auto rval = returnIdx.value(); -#else - auto rval = returnIdx.getValue(); -#endif + auto rval = *returnIdx; dcall = (rval < 0) ? augmentcall : BuilderZ.CreateExtractValue(augmentcall, {(unsigned)rval}); @@ -5520,13 +5498,8 @@ class AdjointGenerator // assert(!tape); // assert(subdata); if (!tape) { -#if LLVM_VERSION_MAJOR >= 16 - assert(tapeIdx.has_value()); - auto tval = tapeIdx.value(); -#else - assert(tapeIdx.hasValue()); - auto tval = tapeIdx.getValue(); -#endif + assert(tapeIdx); + auto tval = *tapeIdx; tape = BuilderZ.CreatePHI( (tapeIdx == -1) ? FT->getReturnType() : cast(FT->getReturnType()) @@ -5585,11 +5558,7 @@ class AdjointGenerator if (Mode == DerivativeMode::ReverseModeCombined || Mode == DerivativeMode::ReverseModePrimal) { -#if LLVM_VERSION_MAJOR >= 16 - auto drval = differetIdx.value(); -#else - auto drval = differetIdx.getValue(); -#endif + auto drval = *differetIdx; newip = (drval < 0) ? 
augmentcall : BuilderZ.CreateExtractValue(augmentcall, diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index efd802db406b..02719a9a4660 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -2254,11 +2254,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( } if (auto blas = extractBLAS(funcName)) { -#if LLVM_VERSION_MAJOR >= 16 - if (handleBLAS(call, called, blas.value(), overwritten_args)) -#else - if (handleBLAS(call, called, blas.getValue(), overwritten_args)) -#endif + if (handleBLAS(call, called, *blas, overwritten_args)) return true; } diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 68a7e02d048d..f3ca78dacea1 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -966,14 +966,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, if (alignv) { if (start != 0) { // todo make better alignment calculation -#if LLVM_VERSION_MAJOR >= 16 - assert(alignv.value().value() != 0); - if (start % alignv.value().value() != 0) -#else - assert(alignv.getValue().value() != 0); - if (start % alignv.getValue().value() != 0) -#endif - { + assert((*alignv).value() != 0); + if (start % (*alignv).value() != 0) { alignv = Align(1); } } @@ -1007,13 +1001,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, if (alignv) { if (start != 0) { // todo make better alignment calculation -#if LLVM_VERSION_MAJOR >= 16 - assert(alignv.value().value() != 0); - if (start % alignv.value().value() != 0) { -#else - assert(alignv.getValue().value() != 0); - if (start % alignv.getValue().value() != 0) { -#endif + assert((*alignv).value() != 0); + if (start % (*alignv).value() != 0) { alignv = Align(1); } } @@ -1093,11 +1082,7 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, st->setDebugLoc(getNewFromOriginal(orig->getDebugLoc())); if (align) { -#if LLVM_VERSION_MAJOR >= 16 - auto alignv = 
align ? align.value().value() : 0; -#else - auto alignv = align ? align.getValue().value() : 0; -#endif + auto alignv = align ? (*align).value() : 0; if (alignv != 0) { if (start != 0) { // todo make better alignment calculation diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index e402c1656248..c9b73e598222 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1887,15 +1887,9 @@ class EnzymeBase { #endif } -#if LLVM_VERSION_MAJOR >= 16 return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, options.value(), sizeOnly, + byVal, constants, fn, mode, *options, sizeOnly, calls); -#else - return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, options.getValue(), - sizeOnly, calls); -#endif } bool HandleProbProg(CallInst *CI, ProbProgMode mode, @@ -2025,17 +2019,9 @@ class EnzymeBase { #endif } -#if LLVM_VERSION_MAJOR >= 16 - bool status = - HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, - constants, newFunc, DerivativeMode::ReverseModeCombined, - opt.value(), false, calls); -#else - bool status = - HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, - constants, newFunc, DerivativeMode::ReverseModeCombined, - opt.getValue(), false, calls); -#endif + bool status = HandleAutoDiff( + CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants, + newFunc, DerivativeMode::ReverseModeCombined, *opt, false, calls); delete interface; diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 9ad444e4bd3a..580b59f3c557 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4282,13 +4282,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } auto store = entryBuilder.CreateStore( Constant::getNullValue(g.getValueType()), &g); -#if LLVM_VERSION_MAJOR >= 16 - if (g.getAlign()) - store->setAlignment(g.getAlign().value()); -#else if (g.getAlign()) - 
store->setAlignment(g.getAlign().getValue()); -#endif + store->setAlignment(*g.getAlign()); } } if (sharedBlock) { diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 6adfc1bd6b78..0b6947887f6f 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -534,17 +534,10 @@ DebugLoc GradientUtils::getNewFromOriginal(const DebugLoc L) const { return L; assert(originalToNewFn.hasMD()); auto opt = originalToNewFn.getMappedMD(L.getAsMDNode()); -#if LLVM_VERSION_MAJOR >= 16 - if (!opt.has_value()) - return L; - assert(opt.has_value()); - return DebugLoc(cast(opt.value())); -#else - if (!opt.hasValue()) + if (!opt) return L; - assert(opt.hasValue()); - return DebugLoc(cast(*opt.getPointer())); -#endif + assert(opt); + return DebugLoc(cast(*opt)); } Value *GradientUtils::getNewFromOriginal(const Value *originst) const { diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h index 5fd3992cfad5..1a8ebce38eee 100644 --- a/enzyme/Enzyme/LibraryFuncs.h +++ b/enzyme/Enzyme/LibraryFuncs.h @@ -209,11 +209,7 @@ static inline void zeroKnownAllocation(llvm::IRBuilder<> &bb, } if (funcName == "enzyme_allocator") { auto index = getAllocationIndexFromCall(orig); -#if LLVM_VERSION_MAJOR >= 16 - allocSize = argValues[index.value()]; -#else - allocSize = argValues[index.getValue()]; -#endif + allocSize = argValues[*index]; } Value *dst_arg = toZero; diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 7a75daef2501..a5db987b5a74 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4304,17 +4304,10 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { StringRef funcName = getFuncNameFromCall(&call); auto blasMetaData = extractBLAS(funcName); -#if LLVM_VERSION_MAJOR >= 16 - if (blasMetaData.has_value()) { - BlasInfo blas = blasMetaData.value(); + if (blasMetaData) { + BlasInfo blas = *blasMetaData; 
#include "BlasTA.inc" } -#else - if (blasMetaData.hasValue()) { - BlasInfo blas = blasMetaData.getValue(); -#include "BlasTA.inc" - } -#endif // When compiling Enzyme against standard LLVM, and not Intel's // modified version of LLVM, the intrinsic `llvm.intel.subscript` is From ed28cb68ccf47b5ff2594421ad62f878be562b03 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Feb 2024 17:32:21 +0000 Subject: [PATCH 045/131] [MLIR] make wrapping pass (#1673) --- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 1 + enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 107 +++++++++++++++++++ enzyme/Enzyme/MLIR/Passes/Passes.h | 6 ++ enzyme/Enzyme/MLIR/Passes/Passes.td | 58 ++++++++++ enzyme/test/MLIR/ForwardMode/wrap.mlir | 16 +++ 5 files changed, 188 insertions(+) create mode 100644 enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp create mode 100644 enzyme/test/MLIR/ForwardMode/wrap.mlir diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 03c91683adff..8a18dc33d195 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_doc(Passes EnzymePasses ./ -gen-pass-doc) add_mlir_dialect_library(MLIREnzymeTransforms EnzymeMLIRPass.cpp + EnzymeWrapPass.cpp PrintActivityAnalysis.cpp PrintAliasAnalysis.cpp EnzymeToMemRef.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp new file mode 100644 index 000000000000..235a061c6088 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -0,0 +1,107 @@ +//===- EnzymeWrapPass.cpp - Replace calls with their derivatives ------------ // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to create wrapper functions which differentiate +// ops. +//===----------------------------------------------------------------------===// + +#include "Dialect/Ops.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "PassDetails.h" +#include "Passes/Passes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::enzyme; +using namespace enzyme; + +namespace { +struct DifferentiateWrapperPass + : public DifferentiateWrapperPassBase { + + void runOnOperation() override { + MEnzymeLogic Logic; + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(getOperation()); + + Operation *symbolOp = nullptr; + if (infn != "") + symbolOp = symbolTable.lookupSymbolIn( + getOperation(), StringAttr::get(getOperation()->getContext(), infn)); + else { + for (auto &op : getOperation()->getRegion(0).front()) { + auto fn = dyn_cast<FunctionOpInterface>(&op); // test the visited op; symbolOp is still null here + if (!fn) + continue; + assert(symbolOp == nullptr); + symbolOp = &op; + } + } + auto fn = cast(symbolOp); + SmallVector split; + StringRef(argTys.getValue().data(), argTys.getValue().size()) + .split(split, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); + std::vector constants; + for (auto &str : split) { + if (str == "enzyme_dup") + constants.push_back(DIFFE_TYPE::DUP_ARG); + else if (str == "enzyme_const") + constants.push_back(DIFFE_TYPE::CONSTANT); + else if (str == "enzyme_dupnoneed") + constants.push_back(DIFFE_TYPE::DUP_NONEED); + else if (str == "enzyme_out") + constants.push_back(DIFFE_TYPE::OUT_DIFF); + else + assert(0 && " unknown constant"); + } + + DIFFE_TYPE retType = retTy.getValue(); + MTypeAnalysis TA; + 
auto type_args = TA.getAnalyzedTypeInfo(fn); + + bool freeMemory = true; + size_t width = 1; + + std::vector volatile_args; + for (auto &a : fn.getFunctionBody().getArguments()) { + (void)a; + volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined)); + } + + FunctionOpInterface newFunc = Logic.CreateForwardDiff( + fn, retType, constants, TA, + /*should return*/ false, mode, freeMemory, width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + if (outfn == "") { + fn->erase(); + } else { + SymbolTable::setSymbolName(cast(newFunc), + (std::string)outfn); + } + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace enzyme { +std::unique_ptr createDifferentiateWrapperPass() { + // The returned unique_ptr is the sole owner of the newly created pass. + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index d5e821bac643..7e64f7718988 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -8,9 +8,13 @@ #ifndef ENZYME_PASSES_H #define ENZYME_PASSES_H +#include "../../Utils.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Pass/Pass.h" #include + +#include "Enzyme/MLIR/Dialect/Dialect.h" + namespace mlir { class PatternRewriter; class RewritePatternSet; @@ -18,6 +22,8 @@ class DominanceInfo; namespace enzyme { std::unique_ptr createDifferentiatePass(); +std::unique_ptr createDifferentiateWrapperPass(); + std::unique_ptr createPrintActivityAnalysisPass(); std::unique_ptr createPrintAliasAnalysisPass(); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 1289ddca9bf4..432b0938f763 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -19,6 +19,64 @@ def DifferentiatePass : Pass<"enzyme"> { let constructor = "mlir::enzyme::createDifferentiatePass()"; } +def DifferentiateWrapperPass : Pass<"enzyme-wrap"> { + let 
summary = "Add wrapper function to be differentiated"; + let dependentDialects = [ + "cf::ControlFlowDialect", + "enzyme::EnzymeDialect" + ]; + let constructor = "mlir::enzyme::createDifferentiateWrapperPass()"; + let options = [ + Option< + /*C++ variable name=*/"infn", + /*CLI argument=*/"infn", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Input function to differentiate" + >, + Option< + /*C++ variable name=*/"outfn", + /*CLI argument=*/"outfn", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Output function to differentiate" + >, + Option< + /*C++ variable name=*/"mode", + /*CLI argument=*/"mode", + /*type=*/"DerivativeMode", + /*default=*/"DerivativeMode::ForwardMode", + /*description=*/"mode to differentiate", +[{::llvm::cl::values( + clEnumValN(DerivativeMode::ForwardMode, "ForwardMode", "ForwardMode (default)"), + clEnumValN(DerivativeMode::ReverseModeCombined, "ReverseModeCombined", "Combined ReverseMode"), + clEnumValN(DerivativeMode::ReverseModePrimal, "ReverseModePrimal", "Forward Pass of ReverseMode"), + clEnumValN(DerivativeMode::ReverseModeGradient, "ReverseModeGradient", "Backward Pass of ReverseMode") + )}] + >, + Option< + /*C++ variable name=*/"retTy", + /*CLI argument=*/"retTy", + /*type=*/"DIFFE_TYPE", + /*default=*/"DIFFE_TYPE::DUP_ARG", + /*description=*/"activity of the return", +[{::llvm::cl::values( + clEnumValN(DIFFE_TYPE::DUP_ARG, "enzyme_dup", "Duplicated (default)"), + clEnumValN(DIFFE_TYPE::OUT_DIFF, "enzyme_out", "Active"), + clEnumValN(DIFFE_TYPE::CONSTANT, "enzyme_const", "Constant"), + clEnumValN(DIFFE_TYPE::DUP_NONEED, "enzyme_dupnoneed", "Duplicated noneed") + )}] + >, + Option< + /*C++ variable name=*/"argTys", + /*CLI argument=*/"argTys", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"The activity of the arguments" + >, + ]; +} + def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> { let summary = "Print the results of activity analysis"; let constructor = 
"mlir::enzyme::createPrintActivityAnalysisPass()"; diff --git a/enzyme/test/MLIR/ForwardMode/wrap.mlir b/enzyme/test/MLIR/ForwardMode/wrap.mlir new file mode 100644 index 000000000000..7cb0eb82bf4b --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/wrap.mlir @@ -0,0 +1,16 @@ +// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTy=enzyme_dup argTys=enzyme_dup, mode=ForwardMode" %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64{ + %y = arith.mulf %x, %x : f64 + return %y : f64 + } +} + +// CHECK: func.func private @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : f64 +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: return %[[i2]] : f64 +// CHECK-NEXT: } From 3dd455b5c5009fda6782e0307b4c10f6af741227 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Feb 2024 19:31:04 +0000 Subject: [PATCH 046/131] Fix tablegen for ll16 without optional support (#1674) --- enzyme/tools/enzyme-tblgen/blasDeclUpdater.h | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h index aef73e479f2f..aebf6d7aec0d 100644 --- a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h @@ -195,17 +195,10 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) { os << " auto name = getFuncName(&F);\n"; os << " auto changed = false;\n"; os << " auto blasMetaData = extractBLAS(name);\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " if (F.empty() && blasMetaData.has_value()) {\n"; - os << " attributeBLAS(blasMetaData.value(), &F);\n"; - os << " changed = true;\n"; - os << " }\n"; - os << " #else\n"; - os << " if (F.empty() && blasMetaData.hasValue()) {\n"; - os << " 
attributeBLAS(blasMetaData.getValue(), &F);\n"; - os << " changed = true;\n"; - os << " }\n"; - os << " #endif\n"; + os << " if (F.empty() && blasMetaData) {\n"; + os << " attributeBLAS(*blasMetaData, &F);\n"; + os << " changed = true;\n"; + os << " }\n"; { const auto &patterns = RK.getAllDerivedDefinitions("CallPattern"); for (Record *pattern : patterns) { From e243d05c8a2f938c89b5e917a2a37f64aa3b8bfe Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Feb 2024 21:28:38 +0000 Subject: [PATCH 047/131] Restore mlir build (#1675) --- enzyme/Enzyme/MLIR/Passes/Passes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 7e64f7718988..25362d01294a 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -13,7 +13,7 @@ #include "mlir/Pass/Pass.h" #include -#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Dialect/Dialect.h" namespace mlir { class PatternRewriter; From ea07b775faeca2d63ab71df74d8269c6ef9f4ef6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 8 Feb 2024 10:11:52 -0500 Subject: [PATCH 048/131] More macos on llvm16 fix (#1676) --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 6 +----- enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h | 12 ++---------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index a5db987b5a74..b9e670326fd4 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4937,11 +4937,7 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } if (auto opidx = getAllocationIndexFromCall(&call)) { auto ptr = TypeTree(BaseType::Pointer); -#if LLVM_VERSION_MAJOR >= 15 - unsigned index = (size_t)opidx.value(); -#else - unsigned index = (size_t)opidx.getValue(); -#endif + unsigned index = (size_t)*opidx; if (auto CI = dyn_cast(call.getOperand(index))) { 
auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); auto LoadSize = CI->getZExtValue(); diff --git a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h index b3d64f2c7ab9..51e883e3b804 100644 --- a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h @@ -182,11 +182,7 @@ void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { emitDiffUse(RK, os, CallDerivatives); os << " auto blasMetaData = extractBLAS(funcName);\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " if (blasMetaData.has_value())\n"; - os << " #else\n"; - os << " if (blasMetaData.hasValue())\n"; - os << " #endif\n"; + os << " if (blasMetaData)\n"; os << " {\n"; os << " auto Mode = gutils->mode;\n"; os << " const bool cacheMode = (Mode != DerivativeMode::ForwardMode);\n"; @@ -198,11 +194,7 @@ void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { os << " assert(found != gutils->overwritten_args_map_ptr->end());\n"; os << " overwritten_args_ptr = &found->second;\n"; os << " }\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " BlasInfo blas = blasMetaData.value();\n"; - os << " #else\n"; - os << " BlasInfo blas = blasMetaData.getValue();\n"; - os << " #endif\n"; + os << " BlasInfo blas = *blasMetaData;\n"; for (auto &&newPattern : newBlasPatterns) { emit_BLASDiffUse(newPattern, os); } From 21053d4122e90cbac3a121286302864627050067 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 8 Feb 2024 10:40:35 -0500 Subject: [PATCH 049/131] bcload 15 or 16 fix (#1677) --- enzyme/BCLoad/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/BCLoad/CMakeLists.txt b/enzyme/BCLoad/CMakeLists.txt index 6516a46dfbd5..f88f515ec77f 100644 --- a/enzyme/BCLoad/CMakeLists.txt +++ b/enzyme/BCLoad/CMakeLists.txt @@ -4,7 +4,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(BC_LOAD_FLAGS "" CACHE STRING "") set(BC_LOAD_HEADER "" CACHE STRING 
"") -if (${LLVM_VERSION_MAJOR} EQUAL 16) +if (${LLVM_VERSION_MAJOR} EQUAL 15 OR ${LLVM_VERSION_MAJOR} EQUAL 16) set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS} -Xclang -no-opaque-pointers") else() set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS}") From bd36baefd93d1a42f31f0ae0f8d4df1165188289 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Fri, 9 Feb 2024 08:25:39 +0900 Subject: [PATCH 050/131] [Truncate] Handle libm calls (#1636) * Handle instrinsics * Add integration test * Add common handler for math intrinsics * Format EnzymeLogic.cpp --- enzyme/Enzyme/EnzymeLogic.cpp | 54 +++++---- enzyme/test/Integration/CMakeLists.txt | 1 + .../test/Integration/Truncate/CMakeLists.txt | 9 ++ enzyme/test/Integration/Truncate/simple.cpp | 111 ++++++++++++++++++ 4 files changed, 152 insertions(+), 23 deletions(-) create mode 100644 enzyme/test/Integration/Truncate/CMakeLists.txt create mode 100644 enzyme/test/Integration/Truncate/simple.cpp diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 580b59f3c557..dc5e511b73fa 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -29,6 +29,7 @@ //===----------------------------------------------------------------------===// #include "ActivityAnalysis.h" #include "AdjointGenerator.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/Intrinsics.h" #if LLVM_VERSION_MAJOR >= 16 @@ -5075,18 +5076,17 @@ class TruncateGenerator : public llvm::InstVisitor { return; } void visitFenceInst(llvm::FenceInst &FI) { return; } - void visitIntrinsicInst(llvm::IntrinsicInst &II) { - SmallVector orig_ops(II.arg_size()); - for (unsigned i = 0; i < II.arg_size(); ++i) - orig_ops[i] = II.getOperand(i); - if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops)) - return; + + bool handleIntrinsic(llvm::CallInst &CI, Intrinsic::ID ID) { + SmallVector orig_ops(CI.arg_size()); + for (unsigned i = 0; i < CI.arg_size(); ++i) + orig_ops[i] = CI.getOperand(i); bool hasFromType = false; - auto newI = 
cast(getNewFromOriginal(&II)); + auto newI = cast(getNewFromOriginal(&CI)); IRBuilder<> B(newI); - SmallVector new_ops(II.arg_size()); - for (unsigned i = 0; i < II.arg_size(); ++i) { + SmallVector new_ops(CI.arg_size()); + for (unsigned i = 0; i < CI.arg_size(); ++i) { if (orig_ops[i]->getType() == getFromType()) { new_ops[i] = truncate(B, getNewFromOriginal(orig_ops[i])); hasFromType = true; @@ -5094,27 +5094,29 @@ class TruncateGenerator : public llvm::InstVisitor { new_ops[i] = getNewFromOriginal(orig_ops[i]); } } - Type *retTy = II.getType(); - if (II.getType() == getFromType()) { + Type *retTy = CI.getType(); + if (CI.getType() == getFromType()) { hasFromType = true; retTy = getToType(); } if (!hasFromType) - return; + return false; // TODO check that the intrinsic is overloaded CallInst *intr; - Value *nres = intr = createIntrinsicCall(B, II.getIntrinsicID(), retTy, - new_ops, &II, II.getName()); - if (II.getType() == getFromType()) + Value *nres = intr = + createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); + if (CI.getType() == getFromType()) nres = expand(B, nres); intr->copyIRFlags(newI); newI->replaceAllUsesWith(nres); newI->eraseFromParent(); - - return; + return true; + } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { + handleIntrinsic(II, II.getIntrinsicID()); } void visitReturnInst(llvm::ReturnInst &I) { return; } @@ -5201,18 +5203,24 @@ class TruncateGenerator : public llvm::InstVisitor { return v; } // Return - void visitCallInst(llvm::CallInst &call) { + void visitCallInst(llvm::CallInst &CI) { + Intrinsic::ID ID; + StringRef funcName = getFuncNameFromCall(const_cast(&CI)); + if (isMemFreeLibMFunction(funcName, &ID)) + if (handleIntrinsic(CI, ID)) + return; + using namespace llvm; - CallInst *const newCall = cast(getNewFromOriginal(&call)); + CallInst *const newCall = cast(getNewFromOriginal(&CI)); IRBuilder<> BuilderZ(newCall); - if (auto called = call.getCalledFunction()) - if (handleKnownCalls(call, called, 
getFuncNameFromCall(&call), newCall)) + if (auto called = CI.getCalledFunction()) + if (handleKnownCalls(CI, called, getFuncNameFromCall(&CI), newCall)) return; - RequestContext ctx(&call, &BuilderZ); - auto val = GetShadow(ctx, getNewFromOriginal(call.getCalledOperand())); + RequestContext ctx(&CI, &BuilderZ); + auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); newCall->setCalledOperand(val); return; } diff --git a/enzyme/test/Integration/CMakeLists.txt b/enzyme/test/Integration/CMakeLists.txt index 7a14214b46cd..98171f188d86 100644 --- a/enzyme/test/Integration/CMakeLists.txt +++ b/enzyme/test/Integration/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(ForwardModeVector) add_subdirectory(ReverseMode) add_subdirectory(BatchMode) add_subdirectory(Sparse) +add_subdirectory(Truncate) # Run regression and unit tests add_lit_testsuite(check-enzyme-integration "Running enzyme integration tests" diff --git a/enzyme/test/Integration/Truncate/CMakeLists.txt b/enzyme/test/Integration/Truncate/CMakeLists.txt new file mode 100644 index 000000000000..bfd5e99064ac --- /dev/null +++ b/enzyme/test/Integration/Truncate/CMakeLists.txt @@ -0,0 +1,9 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-integration-truncate "Running enzyme batch mode integration tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-enzyme-integration-truncate PROPERTIES FOLDER "Tests") + diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp new file mode 100644 index 000000000000..749d8fded7c1 --- /dev/null +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -0,0 +1,111 @@ +// COM: %clang -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// COM: %clang -O2 -ffast-math %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// COM: %clang -O1 -g %s -S 
-emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - + +#include + +#include "../test_utils.h" + +#define N 10 + +double simple_add(double a, double b) { + return a + b; +} +double intrinsics(double a, double b) { + return sqrt(a) * pow(b, 2); +} +// TODO +double constt(double a, double b) { + return 2; +} +double compute(double *A, double *B, double *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] * 2; + } + return C[0]; +} + +typedef double (*fty)(double *, double *, double *, int); + +typedef double (*fty2)(double, double); + +extern fty __enzyme_truncate_func_2(...); +extern fty2 __enzyme_truncate_func(...); +extern double __enzyme_truncate_value(...); +extern double __enzyme_expand_value(...); + +#define FROM 64 +#define TO 32 + +#define TEST(F) do { + + +int main() { + + { + double a = 1; + APPROX_EQ( + __enzyme_expand_value( + __enzyme_truncate_value(a, FROM, TO) , FROM, TO), + a, 1e-10); + } + + { + double a = 2; + double b = 3; + double truth = simple_add(a, b); + a = __enzyme_truncate_value(a, FROM, TO); + b = __enzyme_truncate_value(b, FROM, TO); + double trunc = __enzyme_expand_value(__enzyme_truncate_func(simple_add, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } + { + double a = 2; + double b = 3; + double truth = intrinsics(a, b); + a = __enzyme_truncate_value(a, FROM, TO); + b = __enzyme_truncate_value(b, FROM, TO); + double trunc = __enzyme_expand_value(__enzyme_truncate_func(intrinsics, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } + // { + // double a = 2; + // double b = 3; + // double truth = intrinsics(a, b); + // a = __enzyme_truncate_value(a, FROM, TO); + // b = __enzyme_truncate_value(b, FROM, TO); + // double trunc = __enzyme_expand_value(__enzyme_truncate_func(constt, FROM, TO)(a, b), FROM, TO); + // APPROX_EQ(trunc, truth, 1e-5); + // } + + // double A[N]; + // double B[N]; + // double C[N]; + // double D[N]; + + + // for (int i = 0; i < N; i++) { + // A[i] = 1 + i % 5; + // B[i] = 1 + 
i % 3; + // } + + // compute(A, B, D, N); + + // for (int i = 0; i < N; i++) { + // A[i] = __enzyme_truncate_value(A[i], 64, 32); + // B[i] = __enzyme_truncate_value(B[i], 64, 32); + // } + + // __enzyme_truncate_func_2(compute, 64, 32)(A, B, C, N); + + // for (int i = 0; i < N; i++) { + // C[i] = __enzyme_expand_value(C[i], 64, 32); + // } + + // for (int i = 0; i < N; i++) { + // printf("%d\n", i); + // APPROX_EQ(D[i], C[i], 1e-5); + // } + +} From 534f79fa18134e5f1882de4fe4e1782b17952250 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Fri, 9 Feb 2024 15:46:32 +0900 Subject: [PATCH 051/131] Add a float op only truncation mode (#1651) * add TODO FP trunc ext conversion insts * WIP value trunc mode * WIP * Fixes * Remove alloc from op trunc * Support vector types * test * Add constructor for FloatRepr * Update CMakeLists.txt * Update simple.cpp --- enzyme/Enzyme/Enzyme.cpp | 56 ++++-- enzyme/Enzyme/EnzymeLogic.cpp | 184 +++++++++++------- enzyme/Enzyme/EnzymeLogic.h | 76 +++++++- enzyme/test/Enzyme/Truncate/cmp.ll | 23 ++- enzyme/test/Enzyme/Truncate/intrinsic.ll | 29 ++- enzyme/test/Enzyme/Truncate/select.ll | 22 ++- enzyme/test/Enzyme/Truncate/simple.ll | 31 ++- enzyme/test/Enzyme/Truncate/value.ll | 8 +- .../test/Integration/Truncate/CMakeLists.txt | 3 +- enzyme/test/Integration/Truncate/simple.cpp | 90 +++++---- 10 files changed, 371 insertions(+), 151 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index c9b73e598222..99c429efa08e 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1314,7 +1314,20 @@ class EnzymeBase { return type_args; } - bool HandleTruncateFunc(CallInst *CI) { + static FloatRepresentation getDefaultFloatRepr(unsigned width) { + switch (width) { + case 16: + return FloatRepresentation(5, 10); + case 32: + return FloatRepresentation(8, 23); + case 64: + return FloatRepresentation(11, 52); + default: + llvm_unreachable("Invalid float width"); + } + }; + + bool HandleTruncateFunc(CallInst 
*CI, TruncateMode mode) { IRBuilder<> Builder(CI); Function *F = parseFunctionParameter(CI); if (!F) @@ -1331,8 +1344,9 @@ class EnzymeBase { assert(Cto); RequestContext context(CI, &Builder); llvm::Value *res = Logic.CreateTruncateFunc( - context, F, (unsigned)Cfrom->getValue().getZExtValue(), - (unsigned)Cto->getValue().getZExtValue()); + context, F, + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), mode); if (!res) return false; res = Builder.CreatePointerCast(res, CI->getType()); @@ -1356,8 +1370,10 @@ class EnzymeBase { auto Addr = CI->getArgOperand(0); RequestContext context(CI, &Builder); bool res = Logic.CreateTruncateValue( - context, Addr, (unsigned)Cfrom->getValue().getZExtValue(), - (unsigned)Cto->getValue().getZExtValue(), isTruncate); + context, Addr, + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), + isTruncate); if (!res) return false; return true; @@ -2096,7 +2112,8 @@ class EnzymeBase { MapVector toVirtual; MapVector toSize; SmallVector toBatch; - SmallVector toTruncateFunc; + SmallVector toTruncateFuncMem; + SmallVector toTruncateFuncOp; SmallVector toTruncateValue; SmallVector toExpandValue; MapVector toProbProg; @@ -2408,7 +2425,8 @@ class EnzymeBase { bool virtualCall = false; bool sizeOnly = false; bool batch = false; - bool truncateFunc = false; + bool truncateFuncOp = false; + bool truncateFuncMem = false; bool truncateValue = false; bool expandValue = false; bool probProg = false; @@ -2440,13 +2458,16 @@ class EnzymeBase { } else if (Fn->getName().contains("__enzyme_batch")) { enableEnzyme = true; batch = true; - } else if (Fn->getName().contains("__enzyme_truncate_func")) { + } else if (Fn->getName().contains("__enzyme_truncate_mem_func")) { enableEnzyme = true; - truncateFunc = true; - } else if (Fn->getName().contains("__enzyme_truncate_value")) { + truncateFuncMem = true; + 
} else if (Fn->getName().contains("__enzyme_truncate_op_func")) { + enableEnzyme = true; + truncateFuncOp = true; + } else if (Fn->getName().contains("__enzyme_truncate_mem_value")) { enableEnzyme = true; truncateValue = true; - } else if (Fn->getName().contains("__enzyme_expand_value")) { + } else if (Fn->getName().contains("__enzyme_expand_mem_value")) { enableEnzyme = true; expandValue = true; } else if (Fn->getName().contains("__enzyme_likelihood")) { @@ -2506,8 +2527,10 @@ class EnzymeBase { toSize[CI] = derivativeMode; else if (batch) toBatch.push_back(CI); - else if (truncateFunc) - toTruncateFunc.push_back(CI); + else if (truncateFuncOp) + toTruncateFuncOp.push_back(CI); + else if (truncateFuncMem) + toTruncateFuncMem.push_back(CI); else if (truncateValue) toTruncateValue.push_back(CI); else if (expandValue) @@ -2605,8 +2628,11 @@ class EnzymeBase { for (auto call : toBatch) { HandleBatch(call); } - for (auto call : toTruncateFunc) { - HandleTruncateFunc(call); + for (auto call : toTruncateFuncMem) { + HandleTruncateFunc(call, TruncMem); + } + for (auto call : toTruncateFuncOp) { + HandleTruncateFunc(call, TruncOp); } for (auto call : toTruncateValue) { HandleTruncateValue(call, true); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index dc5e511b73fa..468c4e5ccba2 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -27,8 +27,10 @@ // primal pass. 
// //===----------------------------------------------------------------------===// +#include "EnzymeLogic.h" #include "ActivityAnalysis.h" #include "AdjointGenerator.h" +#include "EnzymeLogic.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Intrinsics.h" @@ -4809,23 +4811,31 @@ Function *EnzymeLogic::CreateForwardDiff( return nf; } -static Type *getTypeForWidth(LLVMContext &ctx, unsigned width) { - switch (width) { - default: - return llvm::Type::getIntNTy(ctx, width); - case 64: - return llvm::Type::getDoubleTy(ctx); - case 32: - return llvm::Type::getFloatTy(ctx); - case 16: - return llvm::Type::getHalfTy(ctx); - } +static Value *floatValTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatRepresentation from, + FloatRepresentation to) { + Type *toTy = to.getType(B.getContext()); + if (auto vty = dyn_cast(v->getType())) + toTy = VectorType::get(toTy, vty->getElementCount()); + return B.CreateFPTrunc(v, toTy, "enzyme_trunc"); +} + +static Value *floatValExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatRepresentation from, FloatRepresentation to) { + Type *fromTy = from.getBuiltinType(B.getContext()); + if (auto vty = dyn_cast(v->getType())) + fromTy = VectorType::get(fromTy, vty->getElementCount()); + return B.CreateFPExt(v, fromTy, "enzyme_exp"); } -static Value *floatTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, - unsigned fromwidth, unsigned towidth) { - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); - Type *toTy = getTypeForWidth(B.getContext(), towidth); +static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatRepresentation from, + FloatRepresentation to) { + if (isa(v->getType())) + report_fatal_error("vector operations not allowed in mem trunc mode"); + + Type *fromTy = from.getBuiltinType(B.getContext()); + Type *toTy = to.getType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); B.CreateStore( @@ -4834,13 +4844,16 @@ static Value *floatTruncate(IRBuilderBase &B, Value 
*v, Value *tmpBlock, toTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(toTy))); } -static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - unsigned fromwidth, unsigned towidth) { - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); +static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatRepresentation from, FloatRepresentation to) { + if (isa(v->getType())) + report_fatal_error("vector operations not allowed in mem trunc mode"); + + Type *fromTy = from.getBuiltinType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); - auto c0 = - Constant::getNullValue(llvm::Type::getIntNTy(B.getContext(), fromwidth)); + auto c0 = Constant::getNullValue( + llvm::Type::getIntNTy(B.getContext(), from.getTypeWidth())); B.CreateStore( c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); B.CreateStore( @@ -4852,27 +4865,38 @@ static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, class TruncateGenerator : public llvm::InstVisitor { private: ValueToValueMapTy &originalToNewFn; - unsigned fromwidth; - unsigned towidth; + FloatRepresentation from; + FloatRepresentation to; Type *fromType; Type *toType; Function *oldFunc; Function *newFunc; AllocaInst *tmpBlock; + TruncateMode mode; EnzymeLogic &Logic; public: - TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, - unsigned towidth, Function *oldFunc, Function *newFunc, + TruncateGenerator(ValueToValueMapTy &originalToNewFn, + FloatRepresentation from, FloatRepresentation to, + Function *oldFunc, Function *newFunc, TruncateMode mode, EnzymeLogic &Logic) - : originalToNewFn(originalToNewFn), fromwidth(fromwidth), - towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) { + : originalToNewFn(originalToNewFn), from(from), to(to), oldFunc(oldFunc), + newFunc(newFunc), mode(mode), Logic(Logic) { IRBuilder<> B(&newFunc->getEntryBlock().front()); - fromType = getTypeForWidth(B.getContext(), 
fromwidth); - toType = getTypeForWidth(B.getContext(), towidth); + fromType = from.getBuiltinType(B.getContext()); + toType = to.getType(B.getContext()); - tmpBlock = B.CreateAlloca(fromType); + if (mode == TruncMem) + tmpBlock = B.CreateAlloca(fromType); + else + tmpBlock = nullptr; + } + + void checkHandled(llvm::Instruction &inst) { + // if (all_of(inst.getOperandList(), + // [&](Use *use) { return use->get()->getType() == fromType; })) + // todo(inst); } void visitInstruction(llvm::Instruction &inst) { @@ -4887,7 +4911,7 @@ class TruncateGenerator : public llvm::InstVisitor { break; } - todo(inst); + checkHandled(inst); } Type *getFromType() { return fromType; } @@ -4895,11 +4919,25 @@ class TruncateGenerator : public llvm::InstVisitor { Type *getToType() { return toType; } Value *truncate(IRBuilder<> &B, Value *v) { - return floatTruncate(B, v, tmpBlock, fromwidth, towidth); + switch (mode) { + case TruncMem: + return floatMemTruncate(B, v, tmpBlock, from, to); + case TruncOp: + return floatValTruncate(B, v, tmpBlock, from, to); + default: + llvm_unreachable("Unknown trunc mode"); + } } Value *expand(IRBuilder<> &B, Value *v) { - return floatExpand(B, v, tmpBlock, fromwidth, towidth); + switch (mode) { + case TruncMem: + return floatMemExpand(B, v, tmpBlock, from, to); + case TruncOp: + return floatValExpand(B, v, tmpBlock, from, to); + default: + llvm_unreachable("Unknown trunc mode"); + } } void todo(llvm::Instruction &I) { @@ -4967,17 +5005,25 @@ class TruncateGenerator : public llvm::InstVisitor { return; } void visitSelectInst(llvm::SelectInst &SI) { - auto newI = getNewFromOriginal(&SI); - IRBuilder<> B(newI); - auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); - auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); - auto nres = cast( - B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - 
return; + switch (mode) { + case TruncMem: { + auto newI = getNewFromOriginal(&SI); + IRBuilder<> B(newI); + auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); + auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); + auto nres = cast( + B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + return; + } + case TruncOp: + return; + default: + llvm_unreachable(""); + } } void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } @@ -5005,7 +5051,7 @@ class TruncateGenerator : public llvm::InstVisitor { return; } - if (towidth == 32 || towidth == 16 || towidth == 64) { + if (to.getBuiltinType(BO.getContext())) { auto newI = getNewFromOriginal(&BO); IRBuilder<> B(newI); auto newLHS = truncate(B, getNewFromOriginal(BO.getOperand(0))); @@ -5197,7 +5243,7 @@ class TruncateGenerator : public llvm::InstVisitor { Value *GetShadow(RequestContext &ctx, Value *v) { if (auto F = dyn_cast(v)) - return Logic.CreateTruncateFunc(ctx, F, fromwidth, towidth); + return Logic.CreateTruncateFunc(ctx, F, from, to, mode); llvm::errs() << " unknown get truncated func: " << *v << "\n"; llvm_unreachable("unknown get truncated func"); return v; @@ -5224,19 +5270,26 @@ class TruncateGenerator : public llvm::InstVisitor { newCall->setCalledOperand(val); return; } + void visitFPTruncInst(FPTruncInst &I) { return; } + void visitFPExtInst(FPExtInst &I) { return; } + void visitFPToUIInst(FPToUIInst &I) { return; } + void visitFPToSIInst(FPToSIInst &I) { return; } + void visitUIToFPInst(UIToFPInst &I) { return; } + void visitSIToFPInst(SIToFPInst &I) { return; } }; bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, - unsigned fromwidth, unsigned towidth, - bool isTruncate) { + FloatRepresentation from, + FloatRepresentation to, bool 
isTruncate) { assert(context.req && context.ip); - if (fromwidth == towidth) { + if (from == to) { + context.req->replaceAllUsesWith(context.req->getOperand(0)); context.req->eraseFromParent(); return true; } - if (fromwidth < towidth) { + if (from < to) { std::string s; llvm::raw_string_ostream ss(s); ss << "Cannot truncate into a large width\n"; @@ -5250,16 +5303,15 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, } IRBuilderBase &B = *context.ip; - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); - Type *toTy = getTypeForWidth(B.getContext(), towidth); + Type *fromTy = from.getBuiltinType(B.getContext()); + Type *toTy = to.getType(B.getContext()); Value *converted = nullptr; if (isTruncate) - converted = - floatExpand(B, B.CreateFPTrunc(v, toTy), nullptr, fromwidth, towidth); + converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, from, to); else converted = - B.CreateFPExt(floatTruncate(B, v, nullptr, fromwidth, towidth), fromTy); + B.CreateFPExt(floatMemTruncate(B, v, nullptr, from, to), fromTy); assert(converted); context.req->replaceAllUsesWith(converted); @@ -5270,12 +5322,13 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm::Function *totrunc, - unsigned fromwidth, - unsigned towidth) { - if (fromwidth == towidth) + FloatRepresentation from, + FloatRepresentation to, + TruncateMode mode) { + if (from == to) return totrunc; - TruncateCacheKey tup(totrunc, fromwidth, towidth); + TruncateCacheKey tup(totrunc, from, to, mode); if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { return TruncateCachedFunctions.find(tup)->second; } @@ -5290,11 +5343,12 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, Type *NewTy = totrunc->getReturnType(); FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); - Function *NewF = - Function::Create(FTy, 
totrunc->getLinkage(), - "trunc_" + std::to_string(fromwidth) + "_" + - std::to_string(towidth) + totrunc->getName(), - totrunc->getParent()); + std::string truncName = std::string("__enzyme_done_truncate_") + + (mode == TruncMem ? "mem" : "op") + "_func_" + + from.to_string() + "_" + to.to_string() + "_" + + totrunc->getName().str(); + Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, + totrunc->getParent()); NewF->setLinkage(Function::LinkageTypes::InternalLinkage); @@ -5327,7 +5381,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm_unreachable("attempting to truncate function without definition"); } - if (fromwidth < towidth) { + if (from < to) { std::string s; llvm::raw_string_ostream ss(s); ss << "Cannot truncate into a large width\n"; @@ -5375,7 +5429,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - TruncateGenerator handle(originalToNewFn, fromwidth, towidth, totrunc, NewF, + TruncateGenerator handle(originalToNewFn, from, to, totrunc, NewF, mode, *this); for (auto &BB : *totrunc) for (auto &I : BB) diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 2f1ac9fde496..6f311fcc85da 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -253,6 +253,74 @@ struct RequestContext { : req(req), ip(ip) {} }; +[[maybe_unused]] static llvm::Type * +getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { + switch (width) { + default: + if (builtinFloat) + llvm::report_fatal_error("Invalid float width requested"); + else + llvm::report_fatal_error( + "Truncation to non builtin float width unsupported"); + case 64: + return llvm::Type::getDoubleTy(ctx); + case 32: + return llvm::Type::getFloatTy(ctx); + case 16: + return llvm::Type::getHalfTy(ctx); + } +} + +enum TruncateMode { TruncMem, TruncOp }; + +struct FloatRepresentation { + // |_|__________|_________________| 
+ // ^ ^ ^ + // sign bit exponent significand + // + // value = (sign) * significand * 2 ^ exponent + unsigned exponentWidth; + unsigned significandWidth; + + FloatRepresentation(unsigned e, unsigned s) + : exponentWidth(e), significandWidth(s) {} + + unsigned getTypeWidth() const { return 1 + exponentWidth + significandWidth; } + + bool canBeBuiltin() const { + unsigned w = getTypeWidth(); + return (w == 16 && significandWidth == 10) || + (w == 32 && significandWidth == 23) || + (w == 64 && significandWidth == 52); + } + + llvm::Type *getBuiltinType(llvm::LLVMContext &ctx) const { + if (!canBeBuiltin()) + return nullptr; + return getTypeForWidth(ctx, getTypeWidth(), /*builtinFloat=*/true); + } + + llvm::Type *getType(llvm::LLVMContext &ctx) const { + llvm::Type *builtinType = getBuiltinType(ctx); + if (builtinType) + return builtinType; + llvm_unreachable("TODO MPFR"); + } + + bool operator==(const FloatRepresentation &other) const { + return other.exponentWidth == exponentWidth && + other.significandWidth == significandWidth; + } + bool operator<(const FloatRepresentation &other) const { + return std::tuple(exponentWidth, significandWidth) < + std::tuple(other.exponentWidth, other.significandWidth); + } + std::string to_string() const { + return std::to_string(getTypeWidth()) + "_" + + std::to_string(significandWidth); + } +}; + class EnzymeLogic { public: PreProcessCache PPC; @@ -511,13 +579,15 @@ class EnzymeLogic { llvm::ArrayRef arg_types, BATCH_TYPE ret_type); - using TruncateCacheKey = std::tuple; + using TruncateCacheKey = std::tuple; std::map TruncateCachedFunctions; llvm::Function *CreateTruncateFunc(RequestContext context, llvm::Function *tobatch, - unsigned fromwidth, unsigned towidth); + FloatRepresentation from, + FloatRepresentation to, TruncateMode mode); bool CreateTruncateValue(RequestContext context, llvm::Value *addr, - unsigned fromwidth, unsigned towidth, + FloatRepresentation from, FloatRepresentation to, bool isTruncate); /// Create a 
traced version of a function diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index e9c61ebd773e..c96efa70660a 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -6,22 +6,28 @@ define i1 @f(double %x, double %y) { ret i1 %res } -declare i1 (double, double)* @__enzyme_truncate_func(...) +declare i1 (double, double)* @__enzyme_truncate_mem_func(...) +declare i1 (double, double)* @__enzyme_truncate_op_func(...) define i1 @tester(double %x, double %y) { entry: - %ptr = call i1 (double, double)* (...) @__enzyme_truncate_func(i1 (double, double)* @f, i64 64, i64 32) + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_mem_func(i1 (double, double)* @f, i64 64, i64 32) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} +define i1 @tester_op(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_op_func(i1 (double, double)* @f, i64 64, i64 32) %res = call i1 %ptr(double %x, double %y) ret i1 %res } ; CHECK: define i1 @tester(double %x, double %y) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call i1 @trunc_64_32f(double %x, double %y) +; CHECK-NEXT: %res = call i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) ; CHECK-NEXT: ret i1 %res -; CHECK-NEXT: } -; CHECK: define internal i1 @trunc_64_32f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -31,4 +37,9 @@ entry: ; CHECK-DAG: %5 = load float, float* %4, align 4 ; CHECK-DAG: %res = fcmp olt float %3, %5 ; CHECK-DAG: ret i1 %res -; CHECK-NEXT:} + +; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; 
CHECK-DAG: %res = fcmp olt float %enzyme_trunc, %enzyme_trunc1 +; CHECK-DAG: ret i1 %res diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index da4457492ce2..99568539c3f3 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -13,16 +13,23 @@ define double @f(double %x, double %y) { ret double %res } -declare double (double, double)* @__enzyme_truncate_func(...) +declare double (double, double)* @__enzyme_truncate_mem_func(...) +declare double (double, double)* @__enzyme_truncate_op_func(...) define double @tester(double %x, double %y) { entry: - %ptr = call double (double, double)* (...) @__enzyme_truncate_func(double (double, double)* @f, i64 64, i64 32) + %ptr = call double (double, double)* (...) @__enzyme_truncate_mem_func(double (double, double)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y) + ret double %res +} +define double @tester2(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) 
@__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y) ret double %res } -; CHECK: define internal double @trunc_64_32f(double %x, double %y) { +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { ; CHECK-NEXT: %1 = alloca double, align 8 ; CHECK-NEXT: store double %x, double* %1, align 8 ; CHECK-NEXT: %2 = bitcast double* %1 to float* @@ -59,4 +66,18 @@ entry: ; CHECK-NEXT: %20 = load double, double* %1, align 8 ; CHECK-NEXT: call void @llvm.nvvm.barrier0() ; CHECK-NEXT: ret double %20 -; CHECK-NEXT: } + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) +; CHECK-DAG: %enzyme_exp = fpext float %res12 to double +; CHECK-DAG: %enzyme_trunc3 = fptrunc double %x to float +; CHECK-DAG: %res24 = call float @llvm.powi.f32.i16(float %enzyme_trunc3, i16 2) +; CHECK-DAG: %enzyme_exp5 = fpext float %res24 to double +; CHECK-DAG: %enzyme_trunc6 = fptrunc double %enzyme_exp to float +; CHECK-DAG: %enzyme_trunc7 = fptrunc double %enzyme_exp5 to float +; CHECK-DAG: %res = fadd float %enzyme_trunc6, %enzyme_trunc7 +; CHECK-DAG: %enzyme_exp8 = fpext float %res to double +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %enzyme_exp8 diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll index 58b4a58ef91b..365d21ab5913 100644 --- a/enzyme/test/Enzyme/Truncate/select.ll +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -6,22 +6,29 @@ define double @f(double %x, double %y, i1 %cond) { ret double %res } -declare double (double, double, i1)* @__enzyme_truncate_func(...) +declare double (double, double, i1)* @__enzyme_truncate_mem_func(...) 
+declare double (double, double, i1)* @__enzyme_truncate_op_func(...) define double @tester(double %x, double %y, i1 %cond) { entry: - %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_func(double (double, double, i1)* @f, i64 64, i64 32) + %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_mem_func(double (double, double, i1)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y, i1 %cond) + ret double %res +} + +define double @tester2(double %x, double %y, i1 %cond) { +entry: + %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_op_func(double (double, double, i1)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y, i1 %cond) ret double %res } ; CHECK: define double @tester(double %x, double %y, i1 %cond) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call double @trunc_64_32f(double %x, double %y, i1 %cond) +; CHECK-NEXT: %res = call double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) ; CHECK-NEXT: ret double %res -; CHECK-NEXT: } -; CHECK: define internal double @trunc_64_32f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -36,4 +43,7 @@ entry: ; CHECK-DAG: store float %res, float* %7, align 4 ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: ret double %8 -; CHECK-NEXT: } + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y, i1 %cond) { +; CHECK-DAG: %res = select i1 %cond, double %x, double %y +; CHECK-DAG: ret double %res diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 0f346a26f0d2..19d6cf1f3a23 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -8,22 +8,34 @@ 
define void @f(double* %x) { ret void } -declare void (double*)* @__enzyme_truncate_func(...) +declare void (double*)* @__enzyme_truncate_mem_func(...) +declare void (double*)* @__enzyme_truncate_op_func(...) define void @tester(double* %data) { entry: - %ptr = call void (double*)* (...) @__enzyme_truncate_func(void (double*)* @f, i64 64, i64 32) + %ptr = call void (double*)* (...) @__enzyme_truncate_mem_func(void (double*)* @f, i64 64, i64 32) + call void %ptr(double* %data) + ret void +} + +define void @tester2(double* %data) { +entry: + %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 32) call void %ptr(double* %data) ret void } ; CHECK: define void @tester(double* %data) ; CHECK-NEXT: entry: -; CHECK-NEXT: call void @trunc_64_32f(double* %data) +; CHECK-NEXT: call void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %data) ; CHECK-NEXT: ret void -; CHECK-NEXT: } -; CHECK: define internal void @trunc_64_32f(double* %x) +; CHECK: define void @tester2(double* %data) { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %data) +; CHECK-NEXT: ret void + +; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %x) ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: store double %y, double* %1, align 8 @@ -40,3 +52,12 @@ entry: ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: store double %8, double* %x, align 8 ; CHECK-DAG: ret void + +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %x) { +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: %enzyme_trunc = fptrunc double %y to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; CHECK-DAG: %m = fmul float %enzyme_trunc, %enzyme_trunc1 +; CHECK-DAG: %enzyme_exp = fpext float %m to double +; CHECK-DAG: store double %enzyme_exp, double* %x, align 8 +; CHECK-DAG: 
ret void diff --git a/enzyme/test/Enzyme/Truncate/value.ll b/enzyme/test/Enzyme/Truncate/value.ll index 51f00401078d..9f87d00d2173 100644 --- a/enzyme/test/Enzyme/Truncate/value.ll +++ b/enzyme/test/Enzyme/Truncate/value.ll @@ -1,18 +1,18 @@ ; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi ; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi -declare double @__enzyme_truncate_value(double, i64, i64) -declare double @__enzyme_expand_value(double, i64, i64) +declare double @__enzyme_truncate_mem_value(double, i64, i64) +declare double @__enzyme_expand_mem_value(double, i64, i64) define double @expand_tester(double %a, double * %c) { entry: - %b = call double @__enzyme_expand_value(double %a, i64 64, i64 32) + %b = call double @__enzyme_expand_mem_value(double %a, i64 64, i64 32) ret double %b } define double @truncate_tester(double %a) { entry: - %b = call double @__enzyme_truncate_value(double %a, i64 64, i64 32) + %b = call double @__enzyme_truncate_mem_value(double %a, i64 64, i64 32) ret double %b } diff --git a/enzyme/test/Integration/Truncate/CMakeLists.txt b/enzyme/test/Integration/Truncate/CMakeLists.txt index bfd5e99064ac..65187869f1ae 100644 --- a/enzyme/test/Integration/Truncate/CMakeLists.txt +++ b/enzyme/test/Integration/Truncate/CMakeLists.txt @@ -1,5 +1,4 @@ -# Run regression and unit tests -add_lit_testsuite(check-enzyme-integration-truncate "Running enzyme batch mode integration tests" +add_lit_testsuite(check-enzyme-integration-truncate "Running enzyme fp truncation integration tests" ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${ENZYME_TEST_DEPS} ARGS -v diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 749d8fded7c1..792366d687e0 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -1,7 +1,7 @@ -// COM: %clang -O0 %s -S 
-emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// COM: %clang -O2 -ffast-math %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// COM: %clang -O1 -g %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -DTRUNC_OP -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -DTRUNC_OP -O2 -ffast-math %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -O1 -g %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - #include @@ -21,7 +21,8 @@ double constt(double a, double b) { } double compute(double *A, double *B, double *C, int n) { for (int i = 0; i < n; i++) { - C[i] = A[i] * 2; + C[i] = A[i] * 2; + // C[i] A=[i] * 2 + B[i] * sqrt(A[i]) ; } return C[0]; } @@ -30,10 +31,12 @@ typedef double (*fty)(double *, double *, double *, int); typedef double (*fty2)(double, double); -extern fty __enzyme_truncate_func_2(...); -extern fty2 __enzyme_truncate_func(...); -extern double __enzyme_truncate_value(...); -extern double __enzyme_expand_value(...); +extern fty __enzyme_truncate_mem_func_2(...); +extern fty2 __enzyme_truncate_mem_func(...); +extern fty __enzyme_truncate_op_func_2(...); +extern fty2 __enzyme_truncate_op_func(...); +extern double __enzyme_truncate_mem_value(...); +extern double __enzyme_expand_mem_value(...); #define FROM 64 #define TO 32 @@ -43,11 +46,12 @@ extern double __enzyme_expand_value(...); int main() { + #ifdef TRUNC_MEM { double a = 1; APPROX_EQ( - __enzyme_expand_value( - __enzyme_truncate_value(a, FROM, TO) , FROM, TO), + __enzyme_expand_mem_value( + __enzyme_truncate_mem_value(a, FROM, TO) , FROM, TO), a, 1e-10); } @@ -55,57 +59,61 @@ int main() { double a = 2; double b = 3; double truth = simple_add(a, b); - a = 
__enzyme_truncate_value(a, FROM, TO); - b = __enzyme_truncate_value(b, FROM, TO); - double trunc = __enzyme_expand_value(__enzyme_truncate_func(simple_add, FROM, TO)(a, b), FROM, TO); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(simple_add, FROM, TO)(a, b), FROM, TO); APPROX_EQ(trunc, truth, 1e-5); } { double a = 2; double b = 3; double truth = intrinsics(a, b); - a = __enzyme_truncate_value(a, FROM, TO); - b = __enzyme_truncate_value(b, FROM, TO); - double trunc = __enzyme_expand_value(__enzyme_truncate_func(intrinsics, FROM, TO)(a, b), FROM, TO); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); APPROX_EQ(trunc, truth, 1e-5); } + #endif // { // double a = 2; // double b = 3; // double truth = intrinsics(a, b); - // a = __enzyme_truncate_value(a, FROM, TO); - // b = __enzyme_truncate_value(b, FROM, TO); - // double trunc = __enzyme_expand_value(__enzyme_truncate_func(constt, FROM, TO)(a, b), FROM, TO); + // a = __enzyme_truncate_mem_value(a, FROM, TO); + // b = __enzyme_truncate_mem_value(b, FROM, TO); + // double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); // APPROX_EQ(trunc, truth, 1e-5); // } - // double A[N]; - // double B[N]; - // double C[N]; - // double D[N]; + #ifdef TRUNC_OP + { + double A[N]; + double B[N]; + double C[N]; + double D[N]; - // for (int i = 0; i < N; i++) { - // A[i] = 1 + i % 5; - // B[i] = 1 + i % 3; - // } + for (int i = 0; i < N; i++) { + A[i] = 1 + i % 5; + B[i] = 1 + i % 3; + } - // compute(A, B, D, N); + compute(A, B, D, N); - // for (int i = 0; i < N; i++) { - // A[i] = __enzyme_truncate_value(A[i], 64, 32); - // B[i] = __enzyme_truncate_value(B[i], 64, 32); - // } + // for (int i = 0; i < N; i++) { + // A[i] 
= __enzyme_truncate_mem_value(A[i], 64, 32); + // B[i] = __enzyme_truncate_mem_value(B[i], 64, 32); + // } - // __enzyme_truncate_func_2(compute, 64, 32)(A, B, C, N); + __enzyme_truncate_op_func_2(compute, 64, 32)(A, B, C, N); - // for (int i = 0; i < N; i++) { - // C[i] = __enzyme_expand_value(C[i], 64, 32); - // } + // for (int i = 0; i < N; i++) { + // C[i] = __enzyme_expand_mem_value(C[i], 64, 32); + // } - // for (int i = 0; i < N; i++) { - // printf("%d\n", i); - // APPROX_EQ(D[i], C[i], 1e-5); - // } + for (int i = 0; i < N; i++) { + APPROX_EQ(D[i], C[i], 1e-5); + } + } + #endif } From 60fd0a1635ffde1a962d266a6d77e4fa50884754 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Feb 2024 09:29:57 -0500 Subject: [PATCH 052/131] MDTuple better error handler (#1684) --- enzyme/Enzyme/GradientUtils.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 0b6947887f6f..ec15cd2e76ff 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -5435,10 +5435,21 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, if (!isa(md)) { llvm::errs() << *arg << "\n"; llvm::errs() << *md << "\n"; - assert(0 && "cannot compute with global variable that doesn't have " - "marked shadow global"); - report_fatal_error("cannot compute with global variable that doesn't " - "have marked shadow global (metadata incorrect type)"); + std::string s; + llvm::raw_string_ostream ss(s); + ss << "cannot compute with global variable that doesn't have marked " + "shadow global as mdtuple\n"; + ss << *arg << "\n"; + ss << " md: " << *md << "\n"; + if (CustomErrorHandler) { + return unwrap(CustomErrorHandler(ss.str().c_str(), wrap(arg), + ErrorType::NoShadow, this, nullptr, + wrap(&BuilderM))); + } else { + EmitFailure("InvertGlobal", BuilderM.getCurrentDebugLocation(), oldFunc, + ss.str()); + } + return 
UndefValue::get(getShadowType(arg->getType())); } auto md2 = cast(md); assert(md2->getNumOperands() == 1); From e3c62c80cd4932af73f3e4d876940a44c985da4c Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 9 Feb 2024 21:34:11 +0100 Subject: [PATCH 053/131] Remove unnecessary uses of templating (#1681) * remove unnecessary uses of templating * fix build --- .devcontainer/devcontainer.json | 2 +- enzyme/Enzyme/AdjointGenerator.h | 16 +++++------ enzyme/Enzyme/CallDerivatives.cpp | 25 +++--------------- enzyme/Enzyme/EnzymeLogic.cpp | 38 +++++++++++++-------------- enzyme/Enzyme/TypeAnalysis/BaseType.h | 3 ++- 5 files changed, 33 insertions(+), 51 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 368b7933e347..d201b73d532d 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -2,7 +2,7 @@ // available llvm versions: [9, 10, 11, 12, 13, 14, 15] { "name": "Enzyme", - "image": "ghcr.io/enzymead/enzyme-dev-docker/ubuntu-20-llvm-12:latest", + "image": "ghcr.io/enzymead/enzyme-dev-docker/ubuntu-22-llvm-16:latest", "mounts": [ "source=enzyme-bashhistory,target=/commandhistory,type=volume", "source=enzyme-extensions,target=/home/vscode/.vscode-server/extensions,type=volume", diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index debac4dd496f..de9bbf8e38a6 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -47,9 +47,7 @@ #define DEBUG_TYPE "enzyme" // Helper instruction visitor that generates adjoints -template -class AdjointGenerator - : public llvm::InstVisitor> { +class AdjointGenerator : public llvm::InstVisitor { private: // Type of code being generated (forward, reverse, or both) const DerivativeMode Mode; @@ -63,7 +61,7 @@ class AdjointGenerator const std::map> overwritten_args_map; const llvm::SmallPtrSetImpl *returnuses; - AugmentedReturnType augmentedReturn; + const AugmentedReturn *augmentedReturn; const std::map 
*replacedReturns; const llvm::SmallPtrSetImpl &unnecessaryValues; @@ -83,7 +81,7 @@ class AdjointGenerator const std::map> overwritten_args_map, const llvm::SmallPtrSetImpl *returnuses, - AugmentedReturnType augmentedReturn, + const AugmentedReturn *augmentedReturn, const std::map *replacedReturns, const llvm::SmallPtrSetImpl &unnecessaryValues, const llvm::SmallPtrSetImpl @@ -3040,8 +3038,8 @@ class AdjointGenerator } if (!vd.isKnownPastPointer()) { if (looseTypeAnalysis) { - if (auto CI = dyn_cast(MS.getOperand(0))) { #if LLVM_VERSION_MAJOR < 17 + if (auto CI = dyn_cast(MS.getOperand(0))) { if (auto PT = dyn_cast(CI->getSrcTy())) { auto ET = PT->getPointerElementType(); while (1) { @@ -3070,8 +3068,8 @@ class AdjointGenerator goto known; } } -#endif } +#endif if (auto gep = dyn_cast(MS.getOperand(0))) { if (auto AT = dyn_cast(gep->getSourceElementType())) { if (AT->getElementType()->isIntegerTy()) { @@ -3312,8 +3310,8 @@ class AdjointGenerator if (!vd.isKnownPastPointer()) { if (looseTypeAnalysis) { for (auto val : {orig_dst, orig_src}) { - if (auto CI = dyn_cast(val)) { #if LLVM_VERSION_MAJOR < 17 + if (auto CI = dyn_cast(val)) { if (auto PT = dyn_cast(CI->getSrcTy())) { auto ET = PT->getPointerElementType(); while (1) { @@ -3342,8 +3340,8 @@ class AdjointGenerator goto known; } } -#endif } +#endif if (auto gep = dyn_cast(val)) { if (auto AT = dyn_cast(gep->getSourceElementType())) { if (AT->getElementType()->isIntegerTy()) { diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 02719a9a4660..c8b2cdaa4e4b 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -32,10 +32,8 @@ extern "C" { void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *) = nullptr; } -template -void AdjointGenerator::handleMPI(llvm::CallInst &call, - llvm::Function *called, - llvm::StringRef funcName) { +void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, + llvm::StringRef funcName) { using 
namespace llvm; assert(called); @@ -2214,8 +2212,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm_unreachable("Unhandled MPI FUNCTION"); } -template -bool AdjointGenerator::handleKnownCallDerivatives( +bool AdjointGenerator::handleKnownCallDerivatives( CallInst &call, Function *called, StringRef funcName, const std::vector &overwritten_args, CallInst *const newCall) { bool subretused = false; @@ -3231,7 +3228,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( llvm_unreachable("Unknown allocation to upgrade"); Size = gutils->getNewFromOriginal(Size); - if (auto CI = dyn_cast(Size)) { + if (isa(Size)) { B.SetInsertPoint(gutils->inversionAllocs); } Type *elTy = Type::getInt8Ty(call.getContext()); @@ -4155,17 +4152,3 @@ bool AdjointGenerator::handleKnownCallDerivatives( return false; } - -template bool AdjointGenerator::handleKnownCallDerivatives( - CallInst &call, Function *called, StringRef funcName, - const std::vector &overwritten_args, CallInst *const newCall); -template bool -AdjointGenerator::handleKnownCallDerivatives( - CallInst &call, Function *called, StringRef funcName, - const std::vector &overwritten_args, CallInst *const newCall); - -template void -AdjointGenerator::handleMPI(CallInst &call, Function *called, - StringRef funcName); -template void AdjointGenerator::handleMPI( - CallInst &call, Function *called, StringRef funcName); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 468c4e5ccba2..4d6b0fef39d9 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2416,12 +2416,12 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( } } - AdjointGenerator maker( - DerivativeMode::ReverseModePrimal, gutils, constant_args, retType, - getIndex, overwritten_args_map, &returnuses, - &AugmentedCachedFunctions.find(tup)->second, nullptr, unnecessaryValues, - unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, - nullptr); + AdjointGenerator 
maker(DerivativeMode::ReverseModePrimal, gutils, + constant_args, retType, getIndex, overwritten_args_map, + &returnuses, + &AugmentedCachedFunctions.find(tup)->second, nullptr, + unnecessaryValues, unnecessaryInstructions, + unnecessaryStores, guaranteedUnreachable, nullptr); for (BasicBlock &oBB : *gutils->oldFunc) { auto term = oBB.getTerminator(); @@ -4174,12 +4174,12 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } } - AdjointGenerator maker( - key.mode, gutils, key.constant_args, key.retType, getIndex, - overwritten_args_map, - /*returnuses*/ nullptr, augmenteddata, &replacedReturns, - unnecessaryValues, unnecessaryInstructions, unnecessaryStores, - guaranteedUnreachable, dretAlloca); + AdjointGenerator maker(key.mode, gutils, key.constant_args, key.retType, + getIndex, overwritten_args_map, + /*returnuses*/ nullptr, augmenteddata, + &replacedReturns, unnecessaryValues, + unnecessaryInstructions, unnecessaryStores, + guaranteedUnreachable, dretAlloca); for (BasicBlock &oBB : *gutils->oldFunc) { // Don't create derivatives for code that results in termination @@ -4643,7 +4643,7 @@ Function *EnzymeLogic::CreateForwardDiff( SmallPtrSet unnecessaryInstructions; SmallPtrSet unnecessaryStores; - AdjointGenerator *maker; + AdjointGenerator *maker; std::unique_ptr> can_modref_map; if (mode == DerivativeMode::ForwardModeSplit) { @@ -4683,7 +4683,7 @@ Function *EnzymeLogic::CreateForwardDiff( calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, unnecessaryInstructions, gutils, TLI); - maker = new AdjointGenerator( + maker = new AdjointGenerator( mode, gutils, constant_args, retType, getIndex, overwritten_args_map, /*returnuses*/ nullptr, augmenteddata, nullptr, unnecessaryValues, unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, @@ -4736,11 +4736,11 @@ Function *EnzymeLogic::CreateForwardDiff( calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, unnecessaryInstructions, gutils, TLI); - maker = new AdjointGenerator( - 
mode, gutils, constant_args, retType, nullptr, {}, - /*returnuses*/ nullptr, nullptr, nullptr, unnecessaryValues, - unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, - nullptr); + maker = + new AdjointGenerator(mode, gutils, constant_args, retType, nullptr, {}, + /*returnuses*/ nullptr, nullptr, nullptr, + unnecessaryValues, unnecessaryInstructions, + unnecessaryStores, guaranteedUnreachable, nullptr); } for (BasicBlock &oBB : *gutils->oldFunc) { diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index f73948a703ad..2f5275c4b6f8 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -25,6 +25,7 @@ #ifndef ENZYME_TYPE_ANALYSIS_BASE_TYPE_H #define ENZYME_TYPE_ANALYSIS_BASE_TYPE_H 1 +#include "llvm/ADT/StringRef.h" #include /// Categories of potential types @@ -59,7 +60,7 @@ static inline std::string to_string(BaseType t) { } /// Convert string to BaseType -template static inline BaseType parseBaseType(T str) { +static inline BaseType parseBaseType(llvm::StringRef str) { if (str == "Integer") return BaseType::Integer; if (str == "Float") From abb4ee5fbf292d2e31fac2bc97bbf77aae3c0ba3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Feb 2024 20:04:37 -0500 Subject: [PATCH 054/131] [MLIR] fix mlir activity arg parsing (#1678) --- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 5 ++++- enzyme/test/MLIR/ForwardMode/wrap.mlir | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 235a061c6088..d96c18a68c09 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -64,8 +64,11 @@ struct DifferentiateWrapperPass constants.push_back(DIFFE_TYPE::DUP_NONEED); else if (str == "enzyme_out") constants.push_back(DIFFE_TYPE::OUT_DIFF); - else + else { + llvm::errs() << "unknown argument activity to parse, found: '" 
<< str + << "'\n"; assert(0 && " unknown constant"); + } } DIFFE_TYPE retType = retTy.getValue(); diff --git a/enzyme/test/MLIR/ForwardMode/wrap.mlir b/enzyme/test/MLIR/ForwardMode/wrap.mlir index 7cb0eb82bf4b..7cef99c691f8 100644 --- a/enzyme/test/MLIR/ForwardMode/wrap.mlir +++ b/enzyme/test/MLIR/ForwardMode/wrap.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTy=enzyme_dup argTys=enzyme_dup, mode=ForwardMode" %s | FileCheck %s +// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTy=enzyme_dup argTys=enzyme_dup mode=ForwardMode" %s | FileCheck %s module { func.func @square(%x : f64) -> f64{ From 49cc903a7c4d4f3a781a9307db4c63b68b3ddd97 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 9 Feb 2024 20:30:29 -0500 Subject: [PATCH 055/131] add a little bit more docs (#1664) * add a little bit more docs * update * adressing feedback * fmt * fmt --- enzyme/Enzyme/CApi.h | 11 +++++++---- enzyme/Enzyme/EnzymeLogic.h | 34 +++++++++++++++++++--------------- enzyme/Enzyme/Utils.h | 11 +++++++---- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index d8e04724631a..fe3760457926 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -113,11 +113,14 @@ struct CFnTypeInfo { }; typedef enum { - DFT_OUT_DIFF = 0, // add differential to an output struct - DFT_DUP_ARG = 1, // duplicate the argument and store differential inside - DFT_CONSTANT = 2, // no differential + DFT_OUT_DIFF = 0, // add differential to an output struct. Only for scalar + // values in ReverseMode variants. + DFT_DUP_ARG = 1, // duplicate the argument and store differential inside. + // For references, pointers, or integers in ReverseMode + // variants. For all types in ForwardMode variants. + DFT_CONSTANT = 2, // no differential. Usable everywhere. DFT_DUP_NONEED = 3 // duplicate this argument and store differential inside, - // but don't need the forward + // but don't need the forward. 
Same as DUP_ARG otherwise. } CDIFFE_TYPE; typedef enum { BT_SCALAR = 0, BT_VECTOR = 1 } CBATCH_TYPE; diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 6f311fcc85da..3923f205037e 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -132,6 +132,21 @@ class AugmentedReturn { isComplete(false) {} }; +/// \p todiff is the function to differentiate +/// \p retType is the activity info of the return. +/// Only allowed to be DUP_ARG or CONSTANT. DUP_NONEED is not allowed, +/// set returnValue to false instead. +/// \p constant_args is the activity info of the arguments +/// \p returnValue is whether the primal's return should also be returned. +/// \p dretUsed is whether the shadow return value should also be returned. +/// Only allowed to be true if retType is CDIFFE_TYPE::DUP_ARG. +/// \p additionalArg is the type (or null) of an additional type in the +/// signature to hold the tape. +/// \p typeInfo is the type info information about the calling context +/// \p _overwritten_args marks whether an argument may be overwritten +/// before loads in the generated function (and thus cannot be cached). +/// \p AtomicAdd is whether to perform all adjoint +/// updates to memory in an atomic way struct ReverseCacheKey { llvm::Function *todiff; DIFFE_TYPE retType; @@ -427,9 +442,10 @@ class EnzymeLogic { /// \p returnUsed is whether the primal's return should also be returned /// \p typeInfo is the type info information about the calling context /// \p _overwritten_args marks whether an argument may be rewritten before - /// loads in the generated function (and thus cannot be cached). \p - /// forceAnonymousTape forces the tape to be an i8* rather than the true tape - /// structure \p AtomicAdd is whether to perform all adjoint updates to + /// loads in the generated function (and thus cannot be cached). 
+ /// \p forceAnonymousTape forces the tape to be an i8* rather than the true + /// tape structure + /// \p AtomicAdd is whether to perform all adjoint updates to /// memory in an atomic way const AugmentedReturn &CreateAugmentedPrimal( RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, @@ -521,20 +537,8 @@ class EnzymeLogic { /// Create the reverse pass, or combined forward+reverse derivative function. /// \p context the instruction which requested this derivative (or null). - /// \p todiff is the function to differentiate - /// \p retType is the activity info of the return - /// \p constant_args is the activity info of the arguments - /// \p returnValue is whether the primal's return should also be returned - /// \p dretUsed is whether the shadow return value should also be returned - /// \p additionalArg is the type (or null) of an additional type in the - /// signature to hold the tape. - /// \p typeInfo is the type info information about the calling context - /// \p _overwritten_args marks whether an argument may be rewritten - /// before loads in the generated function (and thus cannot be cached). /// \p augmented is the data structure created by prior call to an /// augmented forward pass - /// \p AtomicAdd is whether to perform all adjoint - /// updates to memory in an atomic way llvm::Function *CreatePrimalAndGradient(RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA, diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 7b46fd9a2e83..f4f72f8cf52e 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -318,11 +318,14 @@ enum class ReturnType { /// Potential differentiable argument classifications enum class DIFFE_TYPE { - OUT_DIFF = 0, // add differential to an output struct - DUP_ARG = 1, // duplicate the argument and store differential inside - CONSTANT = 2, // no differential + OUT_DIFF = 0, // add differential to an output struct. Only for scalar values + // in ReverseMode variants. 
+ DUP_ARG = 1, // duplicate the argument and store differential inside. + // For references, pointers, or integers in ReverseMode variants. + // For all types in ForwardMode variants. + CONSTANT = 2, // no differential. Usable everywhere. DUP_NONEED = 3 // duplicate this argument and store differential inside, but - // don't need the forward + // don't need the forward. Same as DUP_ARG otherwise. }; enum class BATCH_TYPE { From b5b7d10a6852e9535be768b70d25349339e4cf4f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Feb 2024 20:49:45 -0500 Subject: [PATCH 056/131] Update build_tarballs.jl (#1688) --- .packaging/build_tarballs.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/.packaging/build_tarballs.jl b/.packaging/build_tarballs.jl index 009b48e756e4..0b41f026ba38 100644 --- a/.packaging/build_tarballs.jl +++ b/.packaging/build_tarballs.jl @@ -28,7 +28,7 @@ platforms = expand_cxxstring_abis(supported_platforms(; experimental=true)) script = raw""" cd Enzyme -if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15.asserts* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16.asserts* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+17.asserts* ]]; then +if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16* ]]; then # LLVM 15 requires macOS SDK 10.14. pushd $WORKSPACE/srcdir/MacOSX10.*.sdk rm -rf /opt/${target}/${target}/sys-root/System @@ -54,11 +54,11 @@ cmake -B build-native -S enzyme -GNinja "${NATIVE_CMAKE_FLAGS[@]}" # Only build blasheaders and tblgen ninja -C build-native -j ${nproc} blasheaders enzyme-tblgen - # 2. 
Cross-compile CMAKE_FLAGS=() CMAKE_FLAGS+=(-DENZYME_EXTERNAL_SHARED_LIB=ON) CMAKE_FLAGS+=(-DBC_LOAD_HEADER=`pwd`/build-native/BCLoad/gsl/blas_headers.h) +CMAKE_FLAGS+=(-DEnzyme_TABLEGEN=`pwd`/build-native/tools/enzyme-tblgen/enzyme-tblgen) CMAKE_FLAGS+=(-DEnzyme_TABLEGEN_EXE=`pwd`/build-native/tools/enzyme-tblgen/enzyme-tblgen) CMAKE_FLAGS+=(-DENZYME_CLANG=OFF) # RelWithDebInfo for decent performance, with debugability @@ -66,7 +66,11 @@ CMAKE_FLAGS+=(-DCMAKE_BUILD_TYPE=RelWithDebInfo) # Install things into $prefix CMAKE_FLAGS+=(-DCMAKE_INSTALL_PREFIX=${prefix}) # Explicitly use our cmake toolchain file and tell CMake we're cross-compiling -CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN}) +if [[ "${target}" == *mingw* && "${bb_full_target}" == *llvm_version+16* ]]; then + CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN%.*}_clang.cmake) +else + CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN}) +fi CMAKE_FLAGS+=(-DCMAKE_CROSSCOMPILING:BOOL=ON) # Tell CMake where LLVM is CMAKE_FLAGS+=(-DLLVM_DIR="${prefix}/lib/cmake/llvm") @@ -74,10 +78,18 @@ CMAKE_FLAGS+=(-DLLVM_DIR="${prefix}/lib/cmake/llvm") CMAKE_FLAGS+=(-DLLVM_LINK_LLVM_DYLIB=ON) # Build the library CMAKE_FLAGS+=(-DBUILD_SHARED_LIBS=ON) + +if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16* ]]; then +if [[ "${target}" == x86_64-apple* ]]; then + CMAKE_FLAGS+=(-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=10.14) +fi +else if [[ "${target}" == x86_64-apple* ]]; then CMAKE_FLAGS+=(-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=10.12) fi +fi +echo ${CMAKE_FLAGS[@]} cmake -B build -S enzyme -GNinja ${CMAKE_FLAGS[@]} ninja -C build -j ${nproc} install @@ -113,7 +125,7 @@ for llvm_version in llvm_versions, llvm_assertions in (false, true) # We don't build LLVM 15 for i686-linux-musl. 
filter!(p -> !(arch(p) == "i686" && libc(p) == "musl"), platforms) end - + for platform in platforms augmented_platform = deepcopy(platform) augmented_platform[LLVM.platform_name] = LLVM.platform(llvm_version, llvm_assertions) From 56bb0a62f49ee47e5ceb8bdd4785dd392201c91c Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Feb 2024 22:35:27 -0500 Subject: [PATCH 057/131] Demangle and improve nofree (#1687) * Demangle and improve nofree * f --- enzyme/Enzyme/ActivityAnalysis.cpp | 2 + enzyme/Enzyme/EnzymeLogic.cpp | 170 +++++++++++++++++------------ 2 files changed, 102 insertions(+), 70 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 7b251e154f94..7921d3ff65c0 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -395,6 +395,8 @@ static const char *DemangledKnownInactiveFunctionsStartingWith[] = { // libc++ + "std::__1::locale", + "std::__1::ios_base", "std::__1::basic_string", "std::__1::__do_string_hash", "std::__1::hash", diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 4d6b0fef39d9..d60ac934634b 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -60,6 +60,8 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Demangle/Demangle.h" + #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -5888,73 +5890,101 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if (isAllocationFunction(F->getName(), TLI)) return F; + // clang-format off + StringSet<> NoFreeDemangles = { + "std::basic_ostream>& std::__ostream_insert >(std::basic_ostream >&)", + "std::basic_ostream>::put(char)", + + "std::basic_filebuf>::open(char const*, std::_Ios_Openmode)", + "std::basic_filebuf>::basic_filebuf()", + "std::basic_filebuf>::close()", + + "std::basic_ios>::clear(std::_Ios_Iostate)", + 
"std::__detail::_Prime_rehash_policy::_M_need_rehash(unsigned long, unsigned long, unsigned long) const", + + "std::basic_streambuf >::xsputn(char const*, long)", + + "std::basic_ios >::init(std::basic_streambuf >*)", + + "std::_Hash_bytes(void const*, unsigned long, unsigned long)", + "unsigned long std::__1::__do_string_hash(char const*, char const*)", + "std::__1::hash::operator()(char const*) const", + + "std::allocator::allocator()", + "std::allocator::~allocator()", + + + "std::__cxx11::basic_string, std::allocator>::basic_string(char const*, std::allocator const&)", + "std::__cxx11::basic_string, std::allocator>::basic_string(std::__cxx11::basic_string, std::allocator>&&)", + "std::__cxx11::basic_string, std::allocator>::_M_construct(unsigned long, char)", + "std::__cxx11::basic_string, std::allocator>::_M_append(char const*, unsigned long)", + "std::__cxx11::basic_string, std::allocator>::_M_assign(std::__cxx11::basic_string, std::allocator> const&)", + "std::__cxx11::basic_string, std::allocator>::_M_replace(unsigned long, unsigned long, char const*, unsigned long)", + "std::__cxx11::basic_string, std::allocator>::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)", + "std::__cxx11::basic_string, std::allocator>::length() const", + "std::__cxx11::basic_string, std::allocator>::data() const", + "std::__cxx11::basic_string, std::allocator>::size() const", + "std::__cxx11::basic_string, std::allocator>::~basic_string()", + "std::__cxx11::basic_string, std::allocator>::compare(char const*) const", + "std::__cxx11::basic_string, std::allocator>::compare(std::__cxx11::basic_string, std::allocator> const&) const", + "std::__cxx11::basic_string, std::allocator>::reserve(unsigned long)", + + "std::__cxx11::basic_string, std::allocator>::~basic_string()", + "std::__cxx11::basic_stringbuf, std::allocator>::overflow(int)", + "std::__cxx11::basic_stringbuf, std::allocator>::pbackfail(int)", + "std::__cxx11::basic_stringbuf, std::allocator>::underflow()", 
+ "std::__cxx11::basic_stringbuf, std::allocator>::_M_sync(char*, unsigned long, unsigned long)", + + "std::__basic_file::~__basic_file()", + + "std::basic_ostream>::flush()", + "std::basic_streambuf>::xsgetn(char*, long)", + + "std::locale::~locale()", + "std::ios_base::ios_base()", + "std::basic_ostream>& " + "std::basic_ostream " + ">::_M_insert(double)", + + // libc++ + "std::__1::basic_string, std::__1::allocator>::basic_string(std::__1::basic_string, std::__1::allocator> const&)", + "std::__1::basic_string, std::__1::allocator>::~basic_string()", + "std::__1::basic_string, std::__1::allocator>::__init(char const*, unsigned long)", + "std::__1::basic_string, std::__1::allocator>::append(char const*, unsigned long)", + "std::__1::basic_string, std::__1::allocator>::data() const", + "std::__1::basic_ostream>::sentry::sentry(std::__1::basic_ostream>&)", + "std::__1::basic_ostream>::sentry::~sentry()", + "std::__1::ios_base::__set_badbit_and_consider_rethrow()", + "char* std::__1::addressof(char&)", + "char const* std::__1::addressof(char const&)", + "std::__1::random_device::operator()()", + + "std::__1::locale::~locale()", + "std::__1::locale::use_facet(std::__1::locale::id&) const", + "std::__1::ios_base::ios_base()", + "std::__1::ios_base::getloc() const", + "std::__1::ios_base::clear(unsigned int)", + }; + // clang-format on + StringSet<> NoFrees = { - "mpfr_greater_p", - "memchr", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEC1EPKcRKS3_", - "_ZSt16__ostream_insertIcSt11char_traitsIcEERSt13basic_ostreamIT_T0_ES6_" - "PKS3_l", - "_ZNSo3putEc", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE7_M_syncEPcmm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE10_M_" - "replaceEmmPKcm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_appendEPKcm", - "_ZNSt13basic_filebufIcSt11char_traitsIcEE4openEPKcSt13_Ios_Openmode", - "_ZNSt9basic_iosIcSt11char_traitsIcEE5clearESt12_Ios_Iostate", - 
"_ZNSt13basic_filebufIcSt11char_traitsIcEE5closeEv", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE14_M_replace_" - "auxEmmmc", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE12_M_constructEmc", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7reserveEm", - "time", - "strlen", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7compareERKS4_", - "_ZNKSt8__detail20_Prime_rehash_policy14_M_need_rehashEmmm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEC1EOS4_", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE6lengthEv", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE4dataEv", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE4sizeEv", - "_ZNKSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE4dataEv" - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEED1Ev", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEED1Ev", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE6__" - "initEPKcm", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEEC1ERKS5_", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_" - "9allocatorIcEEE6appendEPKcm", - "_ZNSt12__basic_fileIcED1Ev", - "__cxa_begin_catch", - "__cxa_end_catch", - "_ZNSo5flushEv", - "compress2", - "_ZNSt6localeD1Ev", - "_ZNSt8ios_baseC2Ev", - "_ZNSo9_M_insertIdEERSoT_", - "malloc_usable_size", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEED1Ev", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7compareEPKc", - "_ZNSt13basic_filebufIcSt11char_traitsIcEEC1Ev", - "_ZNSt15basic_streambufIcSt11char_traitsIcEE6xsputnEPKcl", - "_ZNSt9basic_iosIcSt11char_traitsIcEE4initEPSt15basic_streambufIcS1_E", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE8overflowEi", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE9pbackfailEi", - "_ZNSt15basic_streambufIcSt11char_traitsIcEE6xsgetnEPcl", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE9underflowEv", - 
"_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_assignERKS4_", - "_ZNSaIcED1Ev", - "_ZNSaIcEC1Ev", - "_ZSt11_Hash_bytesPKvmm", - "_ZNSt3__116__do_string_hashIPKcEEmT_S3_", - "_ZNKSt3__14hashIPKcEclES2_", - "_ZNSt3__19addressofIcEEPT_RS1_", - "_ZNSt3__19addressofIKcEEPT_RS2_", - "_ZNSt3__113random_deviceclEv", + "mpfr_greater_p", "memchr", "time", "strlen", + "__cxa_begin_catch", "__cxa_end_catch", "compress2", "malloc_usable_size", "MPI_Allreduce", }; if (startsWith(F->getName(), "_ZNSolsE") || NoFrees.count(F->getName())) return F; + std::string demangledName = llvm::demangle(F->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); + } + if (NoFreeDemangles.count(demangledName)) + return F; + switch (F->getIntrinsicID()) { case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: @@ -5972,23 +6002,23 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if (EnzymeEmptyFnInactive) { return F; } + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No create nofree of empty function (" << demangledName << ") " + << F->getName() << ")\n"; + if (context.req) { + ss << " at context: " << *context.req; + } else { + ss << *F << "\n"; + } if (CustomErrorHandler) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No create nofree of empty function " << F->getName() << "\n"; - if (context.req) { - ss << " at context: " << *context.req; - } else { - ss << *F << "\n"; - } CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(F), wrap(context.ip)); return F; } if (context.req) { - EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, - "Cannot create nofree of empty function: ", *F); + EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, s); return F; } llvm::errs() << " unhandled, create no free of empty 
function: " << *F From f1f269a567b09db936b3eac8fe885c7dee7e3003 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 10 Feb 2024 10:42:03 -0500 Subject: [PATCH 058/131] Fix cached gemv (#1689) * Fix cached gemv * fix --- enzyme/Enzyme/BlasDerivatives.td | 6 +++--- .../Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll | 2 +- .../ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll | 2 +- .../Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll | 2 +- enzyme/test/Integration/ReverseMode/blas.cpp | 6 +++--- enzyme/test/Integration/blasinfra.h | 2 +- enzyme/tools/enzyme-tblgen/caching.cpp | 2 +- 9 files changed, 15 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 70682e540ef5..01894e87e453 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -192,7 +192,7 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ ["y"], [cblas_layout, trans, len, len, fp, mld<["m", "n"]>, vinc<["transa", "n", "m"]>, fp, vinc<["transa", "m", "n"]>], [ /* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]> - (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), + (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $m), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), (b<"dot"> (Rows $transa, $m, $n), adj<"y">, use<"Ax">, ConstantInt<1>)), //if (is_normal $transa) { @@ -201,7 +201,7 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ // call sger(m, n, alpha, x, incx, ya, incy, Aa, lda) //} /* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, (Concat adj<"y">, $x), (Concat $x, adj<"y">)), adj<"A">), - /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, 
$A, (ld $A, Char<"N">, $lda, $m, $n), adj<"y">, Constant<"1.0">, adj<"x">), + /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $m), adj<"y">, Constant<"1.0">, adj<"x">), /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, input<"y">), /* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">) ] @@ -218,7 +218,7 @@ def ger : CallBlasPattern<(Op $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, >; //(ld $A, $transa, $lda, $m, $k) // if (cache_A) { -// ld_A = (arg_transa == 'N') ? arg_m : arg_k; +// ld_A = (arg_transa == 'N') ? arg_k : arg_m; // } else { // ld_A = arg_lda; // } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index 0918ca43f83f..8f1c4f3e566e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -109,7 +109,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %[[byrefgarbage]] -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i3]]) ; CHECK-NEXT: %[[i10:.+]] = bitcast i8* %m_p to i64* ; CHECK-NEXT: %[[i11:.+]] = load i64, i64* %[[i10]] ; CHECK-NEXT: %[[i12:.+]] = bitcast i8* %n_p to i64* @@ -119,7 +119,7 @@ entry: ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i64 %mallocsize1) ; CHECK-NEXT: %cache.C = bitcast i8* %malloccall2 to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage3 -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage3, i8* %m_p, i8* %n_p, i8* %C, i8* %ldc_p, double* %cache.C, i8* %n_p) +; CHECK-NEXT: call 
void @dlacpy_64_(i8* %byref.copy.garbage3, i8* %m_p, i8* %n_p, i8* %C, i8* %ldc_p, double* %cache.C, i8* %m_p) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index f58a4b2e209a..64d862b240ff 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -117,7 +117,7 @@ entry: ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall10 to double* ; CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %21) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) ; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index 2a7bbed40c24..f7164d4926e8 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -117,7 +117,7 @@ entry: ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall10 to double* ; CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call 
void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %21) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) ; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index 964b53b1b925..d0f9ed397b0d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -105,7 +105,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %[[byrefgarbage]] -; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage]], i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage]], i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i3]]) ; CHECK-NEXT: %loaded.trans1 = load i8, i8* %transb ; CHECK-DAG: %[[i10:.+]] = icmp eq i8 %loaded.trans1, 78 ; CHECK-DAG: %[[i11:.+]] = icmp eq i8 %loaded.trans1, 110 @@ -121,7 +121,7 @@ entry: ; CHECK-NEXT: %[[malloccall2:.+]] = tail call noalias nonnull i8* @malloc(i64 %[[mallocsize1]]) ; CHECK-NEXT: %cache.B = bitcast i8* %[[malloccall2]] to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage4 -; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage2]], i8* %[[i13]], i8* %[[i14]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[i14]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage2]], i8* %[[i13]], i8* %[[i14]], i8* %B, i8* %ldb_p, 
double* %cache.B, i8* %[[i13]]) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index a4c00d61bbc4..09189c1fe8d1 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -103,7 +103,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.B = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[z3]], i8* %[[z4]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[z4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[z3]], i8* %[[z4]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[z3]]) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %ptr = bitcast i8* %B to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8 diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index 8c929ed054d0..1c01178b874f 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -226,8 +226,8 @@ static void gemvTests() { assert(foundCalls.size() > 2); auto A_cache = (double*)foundCalls[0].pout_arg1; - cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, N); - inputs[4] = BlasInfo(A_cache, layout, M, N, N); + cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, M); + inputs[4] = BlasInfo(A_cache, layout, M, N, 
M); auto B_cache = (double*)foundCalls[1].pout_arg1; cblas_dcopy(trans ? M : N, B, incB, B_cache, 1); inputs[5] = BlasInfo(B_cache, trans ? M : N, 1); @@ -244,7 +244,7 @@ static void gemvTests() { lda); // dB = alpha * trans(A) * dC + dB - cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A_cache, N, dC, incC, 1.0, dB, incB); + cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A_cache, M, dC, incC, 1.0, dB, incB); // dY = beta * dY cblas_dscal(trans ? N : M, beta, dC, incC); diff --git a/enzyme/test/Integration/blasinfra.h b/enzyme/test/Integration/blasinfra.h index 07d540a9b326..cc9d2cc01dd9 100644 --- a/enzyme/test/Integration/blasinfra.h +++ b/enzyme/test/Integration/blasinfra.h @@ -1,5 +1,5 @@ -#include +#include #include #include #include diff --git a/enzyme/tools/enzyme-tblgen/caching.cpp b/enzyme/tools/enzyme-tblgen/caching.cpp index cdb7b3a92910..7a98d944b8ab 100644 --- a/enzyme/tools/enzyme-tblgen/caching.cpp +++ b/enzyme/tools/enzyme-tblgen/caching.cpp @@ -293,7 +293,7 @@ os << " if (byRef) valueTypes[" << len_pos << "] = ValueType::Primal;\n"; os << " if (EnzymeLapackCopy) {\n" << " Value *uplo = llvm::ConstantInt::get(charTy, 0);\n" // garbage data, just should not match U or L << " uplo = to_blas_callconv(BuilderZ, uplo, byRef, cublas, nullptr, allocationBuilder, \"copy.garbage\");\n" -<< " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, N};\n" +<< " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, M};\n" << " if (!byRef) {\n" << " args.insert(args.begin(), arg_layout); valueTypes.insert(valueTypes.begin(), ValueType::Primal); }\n" << " callMemcpyStridedLapack(BuilderZ, *gutils->oldFunc->getParent(), blas, args, gutils->getInvertedBundles(&call, valueTypes, BuilderZ, /*lookup*/false));\n" From c199e36cf3cb7db48a425fdfb2b351013b0db4af Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 10 Feb 2024 15:30:36 -0500 Subject: [PATCH 059/131] Turn mask assertion error 
into nice compile error (#1690) --- enzyme/Enzyme/DiffeGradientUtils.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index f3ca78dacea1..43a14051c4fd 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -928,11 +928,20 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, applyChainRule(PointerType::get(addingType, 1), BuilderM, rule, ptr); } - assert(!mask); if (mask) { - llvm::errs() << "unhandled masked atomic fadd on llvm version " << *ptr - << " " << *dif << " mask: " << *mask << "\n"; - llvm_unreachable("unhandled masked atomic fadd"); + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Unimplemented masked atomic fadd for ptr:" << *ptr + << " dif:" << *dif << " mask: " << *mask << " orig: " << *orig << "\n"; + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(orig), + ErrorType::NoDerivative, this, nullptr, + wrap(&BuilderM)); + return; + } else { + EmitFailure("NoDerivative", orig->getDebugLoc(), orig, ss.str()); + return; + } } /* From 90d057f8bc5412bdbd7693266de8db289073685e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 16:37:34 -0500 Subject: [PATCH 060/131] [ActivityAnalysis] simplify and cleanup (#1694) --- enzyme/Enzyme/ActivityAnalysis.cpp | 376 +++++++++++------------------ 1 file changed, 139 insertions(+), 237 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 7921d3ff65c0..e6b5a565c112 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -104,19 +104,6 @@ cl::opt EnzymeEnableRecursiveHypotheses( #include // clang-format off -static const char *KnownInactiveFunctionsStartingWith[] = { - "f90io", - "$ss5print", - "_ZTv0_n24_NSoD", //"1Ev, 0Ev - "_ZNSt16allocator_traitsISaIdEE10deallocate", - "_ZNSaIcED1Ev", - "_ZNSaIcEC1Ev", -}; - 
-static const char *KnownInactiveFunctionsContains[] = { - "__enzyme_float", "__enzyme_double", "__enzyme_integer", - "__enzyme_pointer"}; - static const StringSet<> InactiveGlobals = { "small_typeof", "ompi_request_null", @@ -168,19 +155,26 @@ const llvm::StringMap MPIInactiveCommAllocators = { {"MPI_Comm_idup", 1}, {"MPI_Comm_join", 1}, }; +// clang-format on -// Instructions which themselves are inactive -// the returned value, however, may still be active -static const StringSet<> KnownInactiveFunctionInsts = { - "__dynamic_cast", - "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base", - "jl_ptr_to_array", - "jl_ptr_to_array_1d"}; +/// Return whether the call is always inactive by definition. +bool isInactiveCall(CallBase &CI) { + + // clang-format off +const char *KnownInactiveFunctionsStartingWith[] = { + "f90io", + "$ss5print", + "_ZTv0_n24_NSoD", //"1Ev, 0Ev + "_ZNSt16allocator_traitsISaIdEE10deallocate", + "_ZNSaIcED1Ev", + "_ZNSaIcEC1Ev", +}; + +const char *KnownInactiveFunctionsContains[] = { + "__enzyme_float", "__enzyme_double", "__enzyme_integer", + "__enzyme_pointer"}; -static const StringSet<> KnownInactiveFunctions = { +const StringSet<> KnownInactiveFunctions = { "mpfr_greater_p", "__nv_isnand", "__nv_isnanf", @@ -293,7 +287,7 @@ static const StringSet<> KnownInactiveFunctions = { "floorl" }; -static const std::set KnownInactiveIntrinsics = { +const std::set KnownInactiveIntrinsics = { #if LLVM_VERSION_MAJOR >= 12 Intrinsic::experimental_noalias_scope_decl, #endif @@ -342,7 +336,7 @@ static const std::set KnownInactiveIntrinsics = { Intrinsic::is_constant, Intrinsic::memset}; -static const char *DemangledKnownInactiveFunctionsStartingWith[] = { +const char *DemangledKnownInactiveFunctionsStartingWith[] = { // TODO this returns allocated memory and thus can be an active value // 
"std::allocator", "std::chrono::_V2::steady_clock::now", @@ -417,13 +411,110 @@ static const char *DemangledKnownInactiveFunctionsStartingWith[] = { "std::__detail::_Prime_rehash_policy", "std::__detail::_Hash_code_base", }; -// clang-format on + // clang-format on + + if (CI.hasFnAttr("enzyme_inactive")) + return true; + + if (auto iasm = dyn_cast(CI.getCalledOperand())) { + if (StringRef(iasm->getAsmString()).contains("exit") || + StringRef(iasm->getAsmString()).contains("cpuid")) + return false; + } + + auto F = getFunctionFromCall(&CI); + + if (F == nullptr) + return false; + + if (F->hasFnAttribute("enzyme_inactive")) { + return true; + } + + auto Name = getFuncNameFromCall(&CI); + + std::string demangledName = llvm::demangle(Name.str()); + auto dName = StringRef(demangledName); + for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { + if (startsWith(dName, FuncName)) { + return true; + } + } + + for (auto FuncName : KnownInactiveFunctionsStartingWith) { + if (startsWith(Name, FuncName)) { + return true; + } + } + + for (auto FuncName : KnownInactiveFunctionsContains) { + if (Name.contains(FuncName)) { + return true; + } + } + if (KnownInactiveFunctions.count(Name)) { + return true; + } + + if (MPIInactiveCommAllocators.find(Name) != MPIInactiveCommAllocators.end()) { + return true; + } + Intrinsic::ID ID; + if (isMemFreeLibMFunction(Name, &ID)) + if (KnownInactiveIntrinsics.count(ID)) { + return true; + } + + if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { + return true; + } + + return false; +} + +bool isInactiveCallInst(CallBase &CB, llvm::TargetLibraryInfo &TLI) { + // clang-format off +// Instructions which themselves are inactive +// the returned value, however, may still be active +static const StringSet<> KnownInactiveFunctionInsts = { + "__dynamic_cast", + "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", + "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", + "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base", + 
"_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base", + "jl_ptr_to_array", + "jl_ptr_to_array_1d"}; + // clang-format on + if (isInactiveCall(CB)) + return true; + if (CB.hasFnAttr("enzyme_inactive_inst")) { + return true; + } + auto called = getFunctionFromCall(&CB); + + if (called) { + if (called->hasFnAttribute("enzyme_inactive_inst")) { + return true; + } + } + + auto funcName = getFuncNameFromCall(&CB); + if (KnownInactiveFunctionInsts.count(funcName)) + return true; + + if (isAllocationFunction(funcName, TLI) || + isDeallocationFunction(funcName, TLI)) { + return true; + } + + return false; +} /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { assert(directions & DOWN); - if (CI->hasFnAttr("enzyme_inactive")) + if (isInactiveCall(*CI)) return true; auto F = getFunctionFromCall(CI); @@ -453,10 +544,6 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { if (F == nullptr) return false; - if (F->hasFnAttribute("enzyme_inactive")) { - return true; - } - auto Name = getFuncNameFromCall(CI); // Only the 1-th arg impacts activity @@ -468,43 +555,6 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI)) return true; - std::string demangledName = llvm::demangle(Name.str()); - auto dName = StringRef(demangledName); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return true; - } - } - if (demangledName == Name.str()) { - // Either demangeling failed - // or they are equal but matching failed - // if (!startsWith(Name, "llvm.")) - // llvm::errs() << "matching failed: " << Name.str() << " " - // << demangledName << "\n"; - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(Name, FuncName)) { - return true; - } - 
} - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (Name.contains(FuncName)) { - return true; - } - } - if (KnownInactiveFunctions.count(Name)) { - return true; - } - - if (MPIInactiveCommAllocators.find(Name) != MPIInactiveCommAllocators.end()) { - return true; - } - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return true; - } - /// Only the first argument (magnitude) of copysign is active if (F->getIntrinsicID() == Intrinsic::copysign && CI->getArgOperand(0) != val) { @@ -551,6 +601,8 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { static inline void propagateArgumentInformation( TargetLibraryInfo &TLI, CallInst &CI, llvm::function_ref propagateFromOperand) { + if (isInactiveCall(CI)) + return; // These functions are known to only have the first argument impact // the activity of the call instruction @@ -598,12 +650,6 @@ static inline void propagateArgumentInformation( return; } - // Certain intrinsics are inactive by definition - // and have nothing to propagate. 
- if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return; - } - if (F->getIntrinsicID() == Intrinsic::memcpy || F->getIntrinsicID() == Intrinsic::memmove) { propagateFromOperand(CI.getOperand(0)); @@ -720,13 +766,6 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, ActiveInstructions.insert(I); return false; } - if (CI->hasFnAttr("enzyme_inactive") || - CI->hasFnAttr("enzyme_inactive_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } auto called = getFunctionFromCall(CI); if (called) { @@ -737,27 +776,17 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, ActiveInstructions.insert(I); return false; } - if (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } } - if (KnownInactiveFunctionInsts.count(getFuncNameFromCall(CI))) { + if (isInactiveCallInst(*CI, TLI)) { + if (EnzymePrintActivity) + llvm::errs() << "known inactive instruction from call " << *I << "\n"; InsertConstantInstruction(TR, I); return true; } } if (auto II = dyn_cast(I)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - if (EnzymePrintActivity) - llvm::errs() << "known inactive intrinsic " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } else if (isIntelSubscriptIntrinsic(*II)) { + if (isIntelSubscriptIntrinsic(*II)) { // The intrinsic "llvm.intel.subscript" does not propogate deriviative // information directly. But its returned pointer may be active. 
InsertConstantInstruction(TR, I); @@ -1066,13 +1095,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } - if (auto II = dyn_cast(Val)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - InsertConstantValue(TR, Val); - return true; - } - } - // All arguments must be marked constant/nonconstant ahead of time if (isa(Val) && !cast(Val)->hasByValAttr()) { llvm::errs() << *(cast(Val)->getParent()) << "\n"; @@ -1348,6 +1370,12 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } } + if (isInactiveCall(*CI)) { + if (EnzymePrintActivity) + llvm::errs() << "known inactive val from call" << *Val << "\n"; + InsertConstantValue(TR, Val); + return true; + } } if (auto BO = dyn_cast(Val)) { // x & 0b100000 is definitionally inactive @@ -1536,8 +1564,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } } } else if (auto op = dyn_cast(TmpOrig)) { - if (op->hasFnAttr("enzyme_inactive") || - op->hasFnAttr("enzyme_inactive_val") || + if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") || op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex, "enzyme_inactive")) { InsertConstantValue(TR, Val); @@ -1549,8 +1576,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { StringRef funcName = getFuncNameFromCall(op); if (called && - (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_val") || + (called->hasFnAttribute("enzyme_inactive_val") || called->getAttributes().hasAttribute( llvm::AttributeList::ReturnIndex, "enzyme_inactive"))) { InsertConstantValue(TR, Val); @@ -1564,45 +1590,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, 
*UpHypothesis); - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - - if (called && called->getIntrinsicID() == Intrinsic::trap) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - // If requesting empty unknown functions to be considered inactive, // abide by those rules if (called && EnzymeEmptyFnInactive && called->empty() && @@ -1857,57 +1844,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { // If this is a malloc or free, this doesn't impact the activity if (auto CI = dyn_cast(I)) { - if (CI->hasFnAttr("enzyme_inactive") || - CI->hasFnAttr("enzyme_inactive_inst")) + if (isInactiveCallInst(*CI, TLI)) return false; - if (auto iasm = dyn_cast(CI->getCalledOperand())) { - if (StringRef(iasm->getAsmString()).contains("exit") || - StringRef(iasm->getAsmString()).contains("cpuid")) - return false; - } - - auto F = getFunctionFromCall(CI); StringRef funcName = getFuncNameFromCall(CI); - - if (F && (F->hasFnAttribute("enzyme_inactive") || - F->hasFnAttribute("enzyme_inactive_inst"))) { - return false; - } - if (isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return false; - } - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - return false; - } - if (KnownInactiveFunctionInsts.count(funcName)) { - 
return false; - } if (isMemFreeLibMFunction(funcName)) { return false; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return false; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - return false; - } - } - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - return false; - } - } - if (funcName == "__cxa_guard_acquire" || funcName == "__cxa_guard_release" || funcName == "__cxa_guard_abort" || funcName == "posix_memalign" || @@ -1917,12 +1861,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { funcName == "cudaMallocFromPoolAsync") { return false; } - - if (F) { - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return false; - } - } } Value *memval = Val; @@ -2538,8 +2476,10 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, } if (auto op = dyn_cast(inst)) { - if (op->hasFnAttr("enzyme_inactive") || - op->hasFnAttr("enzyme_inactive_val")) { + if (isInactiveCall(*op)) + return true; + + if (op->hasFnAttr("enzyme_inactive_val")) { return true; } // Calls to print/assert/cxa guard are definitionally inactive @@ -2548,8 +2488,7 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, StringRef funcName = getFuncNameFromCall(op); auto called = getFunctionFromCall(op); - if (called && (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_val"))) { + if (called && (called->hasFnAttribute("enzyme_inactive_val"))) { return true; } if (funcName == "free" || funcName == "_ZdlPv" || funcName == "_ZdlPvm" || @@ -2557,37 +2496,6 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, return true; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if 
(startsWith(dName, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - return true; - } - } - - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions - << ") up-knowninactivecall " << *inst << "\n"; - return true; - } - - if (called && called->getIntrinsicID() == Intrinsic::trap) - return true; - // If requesting empty unknown functions to be considered inactive, abide // by those rules if (called && EnzymeEmptyFnInactive && called->empty() && @@ -2609,12 +2517,6 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, } // Intrinsics known always to be inactive if (auto II = dyn_cast(inst)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-intrinsic " - << *inst << "\n"; - return true; - } if (isIntelSubscriptIntrinsic(*II)) { // The only argument that can make an llvm.intel.subscript intrinsic // active is the pointer operand From 1d7347865e88c1e6a306bb665493d35948401a9a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 16:57:58 -0500 Subject: [PATCH 061/131] Mark nvfmodf and improve error message for wrong arg num (#1695) --- enzyme/Enzyme/EnzymeLogic.cpp | 49 +++++++++++++++++++++++-- enzyme/Enzyme/InstructionDerivatives.td | 2 +- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index d60ac934634b..17ab0c2f4eca 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -1961,11 +1961,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( 
assert(!todiff->getReturnType()->isEmptyTy() && !todiff->getReturnType()->isVoidTy()); - assert(_overwritten_args.size() == todiff->arg_size()); - FnTypeInfo oldTypeInfo = preventTypeAnalysisLoops(oldTypeInfo_, todiff); - - assert(constant_args.size() == todiff->getFunctionType()->getNumParams()); AugmentedCacheKey tup = {todiff, retType, constant_args, _overwritten_args, returnUsed, shadowReturnUsed, @@ -1973,6 +1969,51 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( AtomicAdd, omp, width}; + if (_overwritten_args.size() != todiff->arg_size()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << " overwritten_args.size() [" << _overwritten_args.size() + << "] != todiff->arg_size()\n"; + ss << "todiff: " << *todiff << "\n"; + llvm::Value *toshow = todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *todiff << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(todiff), + wrap(context.ip)); + auto newFunc = todiff; + std::map returnMapping; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; + } + if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + auto newFunc = todiff; + std::map returnMapping; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; + } + llvm::errs() << "mod: " << *todiff->getParent() << "\n"; + llvm::errs() << *todiff << "\n"; + llvm_unreachable( + "attempting to differentiate function with wrong overwritten count"); + } + + assert(_overwritten_args.size() == todiff->arg_size()); + assert(constant_args.size() == todiff->getFunctionType()->getNumParams()); + auto found = AugmentedCachedFunctions.find(tup); if (found != 
AugmentedCachedFunctions.end()) { return found->second; diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 00be8fe77e55..58731be4a532 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -434,7 +434,7 @@ def : CallPattern<(Op $x), >; def : CallPattern<(Op $x, $y), - ["fmod", "fmodf", "fmodl"], + ["fmod", "fmodf", "fmodl", "__nv_fmod", "__nv_fmodf", "__nv_fmodl"], [ (DiffeRet), (CheckedMul (DiffeRet), (FNeg (Intrinsic<"copysign"> (Intrinsic<"floor"> (Intrinsic<"fabs"> (FDiv $x, $y):$div)), $div))) From 6b8960d17a76e5f7bbbfdb44a57f153fee1c0681 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 18:06:26 -0500 Subject: [PATCH 062/131] Simplify and improve differential use analysis (#1693) * Simplify and improve differential use analysis * fix lambda --------- Co-authored-by: Tim Gymnich --- enzyme/Enzyme/AdjointGenerator.h | 16 +- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 78 ++-------- enzyme/Enzyme/DifferentialUseAnalysis.h | 147 +++++++++++++++---- enzyme/Enzyme/EnzymeLogic.cpp | 4 +- enzyme/Enzyme/GradientUtils.cpp | 52 ++++--- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 56 +++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h | 10 ++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 21 ++- 8 files changed, 252 insertions(+), 132 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index de9bbf8e38a6..ec9cd5c51c9a 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -756,7 +756,7 @@ class AdjointGenerator : public llvm::InstVisitor { auto alignment = LI.getAlign(); auto &DL = gutils->newFunc->getParent()->getDataLayout(); - bool constantval = parseTBAA(LI, DL, nullptr).Inner0().isIntegral(); + bool constantval = parseTBAA(LI, DL, nullptr)[{-1}].isIntegral(); visitLoadLike(LI, alignment, constantval); eraseIfUnused(LI); } @@ -992,7 +992,7 @@ class AdjointGenerator : public 
llvm::InstVisitor { NewI->setMetadata(LLVMContext::MD_noalias, noscope); bool constantval = gutils->isConstantValue(orig_val) || - parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); IRBuilder<> BuilderZ(NewI); BuilderZ.setFastMathFlags(getFast()); @@ -3272,7 +3272,7 @@ class AdjointGenerator : public llvm::InstVisitor { // copying into nullptr is invalid (not sure why it exists here), but we // shouldn't do it in reverse pass or shadow if (isa(orig_dst) || - TR.query(orig_dst).Inner0() == BaseType::Anything) { + TR.query(orig_dst)[{-1}] == BaseType::Anything) { eraseIfUnused(MTI); return; } @@ -3689,7 +3689,7 @@ class AdjointGenerator : public llvm::InstVisitor { auto align0 = cast(I.getOperand(1))->getZExtValue(); auto align = MaybeAlign(align0); auto &DL = gutils->newFunc->getParent()->getDataLayout(); - bool constantval = parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + bool constantval = parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); visitLoadLike(I, align, constantval, /*mask*/ gutils->getNewFromOriginal(I.getOperand(2)), /*orig_maskInit*/ I.getOperand(3)); @@ -4068,7 +4068,7 @@ class AdjointGenerator : public llvm::InstVisitor { assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG || whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); } else { - assert(TR.query(call.getArgOperand(i)).Inner0().isFloat()); + assert(TR.query(call.getArgOperand(i))[{-1}].isFloat()); OutTypes.push_back(call.getArgOperand(i)); OutFPTypes.push_back(argType); assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF || @@ -5437,8 +5437,7 @@ class AdjointGenerator : public llvm::InstVisitor { assert(dcall); if (!gutils->isConstantValue(&call)) { - if (!call.getType()->isFPOrFPVectorTy() && - TR.query(&call).Inner0().isPossiblePointer()) { + if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) { } else if (Mode != DerivativeMode::ReverseModePrimal) { ((DiffeGradientUtils *)gutils)->differentials[dcall] = ((DiffeGradientUtils 
*)gutils)->differentials[newCall]; @@ -5898,8 +5897,7 @@ class AdjointGenerator : public llvm::InstVisitor { gutils->originalToNewFn[&call] = dcall; gutils->newToOriginalFn.erase(newCall); gutils->newToOriginalFn[dcall] = &call; - if (!call.getType()->isFPOrFPVectorTy() && - TR.query(&call).Inner0().isPossiblePointer()) { + if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) { } else { ((DiffeGradientUtils *)gutils)->differentials[dcall] = ((DiffeGradientUtils *)gutils)->differentials[newCall]; diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 3be0cdfc2895..d20ec4eefbc4 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -293,47 +293,6 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( return false; } - if (shadow) { - if (auto IVI = dyn_cast(user)) - if (isa(IVI) || isa(IVI)) { - if (IVI->getOperand(1) == val) { - SmallVector todo; - todo.push_back(IVI); - SmallVector, 1> - done; - while (todo.size()) { - auto cur = todo.pop_back_val(); - for (auto u : cur->users()) { - if (auto IVI2 = dyn_cast(u)) { - todo.push_back(IVI2); - continue; - } - if (auto IVI2 = dyn_cast(u)) { - todo.push_back(IVI2); - continue; - } - done.emplace_back(cast(u), cur); - } - } - for (auto &pair : done) { - - bool direct = is_use_directly_needed_in_reverse( - gutils, pair.second, mode, pair.first, oldUnreachable, - QueryType::Shadow, recursiveUse); - if (direct) { - - if (EnzymePrintDiffUse) - llvm::errs() - << " Need (partial) direct " << to_string(qtype) << " of " - << *val << " in reverse from insertelem " << *user - << " via " << *pair.second << " in " << *pair.first << "\n"; - return true; - } - } - } - } - } - if (!shadow) if (auto IEI = dyn_cast(user)) { // Only need the index in the reverse, so if the value is not @@ -800,33 +759,26 @@ int DifferentialUseAnalysis::cmpLoopNest(Loop *prev, Loop *next) { return -1; } -void 
DifferentialUseAnalysis::minCut( - const DataLayout &DL, LoopInfo &OrigLI, - const SetVector &Recomputes, - const SetVector &Intermediates, SetVector &Required, - SetVector &MinReq, - const ValueMap - &rematerializableAllocations, - llvm::TargetLibraryInfo &TLI) { +void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, + const SetVector &Recomputes, + const SetVector &Intermediates, + SetVector &Required, + SetVector &MinReq, + const GradientUtils *gutils, + llvm::TargetLibraryInfo &TLI) { Graph G; for (auto V : Intermediates) { G[Node(V, false)].insert(Node(V, true)); - for (auto U : V->users()) { - if (auto I = dyn_cast(U)) { - for (auto pair : rematerializableAllocations) { - if (Intermediates.count(pair.first) && pair.second.stores.count(I)) { - if (V != pair.first) - G[Node(V, true)].insert(Node(pair.first, false)); + forEachDifferentialUser( + [&](Value *U) { + if (Intermediates.count(U)) { + if (V != U) + G[Node(V, true)].insert(Node(U, false)); } - } - } - if (Intermediates.count(U)) { - if (V != U) - G[Node(V, true)].insert(Node(U, false)); - } - } + }, + gutils, V); } - for (auto pair : rematerializableAllocations) { + for (auto pair : gutils->rematerializableAllocations) { if (Intermediates.count(pair.first)) { for (LoadInst *L : pair.second.loads) { if (Intermediates.count(L)) { diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index d9d689eeee64..565e18724869 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -166,6 +166,60 @@ inline bool is_value_needed_in_reverse( } } + if (!TR.anyFloat(const_cast(inst))) + if (auto IVI = dyn_cast(user)) { + bool inserted = false; + if (auto II = dyn_cast(IVI)) + inserted = II->getInsertedValueOperand() == inst || + II->getAggregateOperand() == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getAggregateOperand() == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(1) == inst || 
II->getOperand(0) == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(0) == inst; + if (inserted) { + SmallVector todo; + todo.push_back(IVI); + while (todo.size()) { + auto cur = todo.pop_back_val(); + for (auto u : cur->users()) { + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + + bool partial = false; + if (!gutils->isConstantValue(const_cast(cur))) { + partial = is_value_needed_in_reverse( + gutils, user, mode, seen, oldUnreachable); + } + if (partial) { + + if (EnzymePrintDiffUse) + llvm::errs() + << " Need (partial) direct " << to_string(VT) << " of " + << *inst << " in reverse from insertelem " << *user + << " via " << *cur << " in " << *u << "\n"; + return seen[idx] = true; + } + } + } + } + } + if (VT != QueryType::Primal) continue; } @@ -332,36 +386,14 @@ inline bool is_value_needed_in_reverse( primalUsedInShadowPointer = false; } } - if (auto IVI = dyn_cast(user)) { - bool valueIsIndex = false; - for (unsigned i = 2; i < IVI->getNumOperands(); ++i) { - if (IVI->getOperand(i) == inst) { - if (inst == IVI->getInsertedValueOperand() && - TR.query( - const_cast(IVI->getInsertedValueOperand()))[{-1}] - .isFloat()) { - continue; - } - valueIsIndex = true; - } - } - primalUsedInShadowPointer = valueIsIndex; - } - if (auto EVI = dyn_cast(user)) { - bool valueIsIndex = false; - for (unsigned i = 1; i < EVI->getNumOperands(); ++i) { - if (EVI->getOperand(i) == inst) { - valueIsIndex = true; - } - } - primalUsedInShadowPointer = valueIsIndex; - } + // No need for insert/extractvalue since indices are unsigned + // not llvm runtime values + if (isa(user) || isa(user)) + primalUsedInShadowPointer = false; if (primalUsedInShadowPointer) if (!user->getType()->isVoidTy() && - TR.query(const_cast(user)) - .Inner0() 
- .isPossiblePointer()) { + TR.anyPointer(const_cast(user))) { if (is_value_needed_in_reverse( gutils, user, mode, seen, oldUnreachable)) { if (EnzymePrintDiffUse) @@ -433,11 +465,66 @@ void minCut(const llvm::DataLayout &DL, llvm::LoopInfo &OrigLI, const llvm::SetVector &Recomputes, const llvm::SetVector &Intermediates, llvm::SetVector &Required, - llvm::SetVector &MinReq, - const llvm::ValueMap - &rematerializableAllocations, + llvm::SetVector &MinReq, const GradientUtils *gutils, llvm::TargetLibraryInfo &TLI); +__attribute__((always_inline)) static inline void +forEachDirectInsertUser(llvm::function_ref f, + const GradientUtils *gutils, llvm::Instruction *IVI, + llvm::Value *val, bool useCheck) { + using namespace llvm; + if (!gutils->isConstantValue(IVI)) + return; + bool inserted = false; + if (auto II = dyn_cast(IVI)) + inserted = II->getInsertedValueOperand() == val || + II->getAggregateOperand() == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getAggregateOperand() == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(1) == val || II->getOperand(0) == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(0) == val; + if (inserted) { + SmallVector todo; + todo.push_back(IVI); + while (todo.size()) { + auto cur = todo.pop_back_val(); + for (auto u : cur->users()) { + if (isa(u) || isa(u) || + isa(u) || isa(u)) { + auto I2 = cast(u); + bool subCheck = useCheck; + if (!subCheck) { + subCheck = is_value_needed_in_reverse( + gutils, I2, gutils->mode, gutils->notForAnalysis); + } + if (subCheck) + f(I2); + todo.push_back(I2); + continue; + } + } + } + } +} + +__attribute__((always_inline)) static inline void +forEachDifferentialUser(llvm::function_ref f, + const GradientUtils *gutils, llvm::Value *V, + bool useCheck = false) { + for (auto V2 : V->users()) { + if (auto Inst = llvm::dyn_cast(V2)) { + for (const auto &pair : gutils->rematerializableAllocations) { + if (pair.second.stores.count(Inst)) { + f(llvm::cast(pair.first)); + } + 
} + f(Inst); + forEachDirectInsertUser(f, gutils, Inst, V, useCheck); + } + } +} }; // namespace DifferentialUseAnalysis #endif diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 17ab0c2f4eca..850a92543db3 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -1282,7 +1282,7 @@ bool shouldAugmentCall(CallInst *op, const GradientUtils *gutils) { } if (!op->getType()->isFPOrFPVectorTy() && !gutils->isConstantValue(op) && - gutils->TR.query(op).Inner0().isPossiblePointer()) { + gutils->TR.anyPointer(op)) { modifyPrimal = true; #ifdef PRINT_AUGCALL @@ -1315,7 +1315,7 @@ bool shouldAugmentCall(CallInst *op, const GradientUtils *gutils) { if (!argType->isFPOrFPVectorTy() && !gutils->isConstantValue(op->getArgOperand(i)) && - gutils->TR.query(op->getArgOperand(i)).Inner0().isPossiblePointer()) { + gutils->TR.anyPointer(op->getArgOperand(i))) { if (!isReadOnly(op, i)) { modifyPrimal = true; #ifdef PRINT_AUGCALL diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index ec15cd2e76ff..734d95989442 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3280,7 +3280,7 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { auto &DL = newFunc->getParent()->getDataLayout(); bool constantval = isConstantValue(orig_val) || - parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); // TODO allow recognition of other types that could contain // pointers [e.g. 
{void*, void*} or <2 x i64> ] @@ -4321,8 +4321,7 @@ DIFFE_TYPE GradientUtils::getReturnDiffeType(llvm::Value *orig, subretType = DIFFE_TYPE::DUP_ARG; shadowReturnUsed = true; } else { - if (!orig->getType()->isFPOrFPVectorTy() && - TR.query(orig).Inner0().isPossiblePointer()) { + if (!orig->getType()->isFPOrFPVectorTy() && TR.anyPointer(orig)) { if (DifferentialUseAnalysis::is_value_needed_in_reverse< QueryType::Shadow>(this, orig, cmode, notForAnalysis)) { subretType = DIFFE_TYPE::DUP_ARG; @@ -4359,8 +4358,7 @@ DIFFE_TYPE GradientUtils::getDiffeType(Value *v, bool foreignFunction) const { auto argType = v->getType(); - if (!argType->isFPOrFPVectorTy() && - (TR.query(v).Inner0().isPossiblePointer() || foreignFunction)) { + if (!argType->isFPOrFPVectorTy() && (TR.anyPointer(v) || foreignFunction)) { if (argType->isPointerTy()) { auto at = getBaseObject(v); if (auto arg = dyn_cast(at)) { @@ -5105,9 +5103,21 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, return applyChainRule(oval->getType(), BuilderM, rule); } - if (isConstantValue(oval) && !isa(oval) && - !isa(oval) && !isa(oval) && - !isa(oval)) { + bool shouldNullShadow = isConstantValue(oval); + if (shouldNullShadow) { + if (isa(oval) || isa(oval) || + isa(oval) || isa(oval)) { + shouldNullShadow = false; + auto orig = cast(oval); + if (knownRecomputeHeuristic.count(orig)) { + if (!knownRecomputeHeuristic[orig]) { + shouldNullShadow = true; + } + } + } + } + + if (shouldNullShadow) { // NOTE, this is legal and the correct resolution, however, our activity // analysis honeypot no longer exists @@ -8011,9 +8021,9 @@ void GradientUtils::computeMinCache() { todo.pop_front(); if (Intermediates.count(V)) continue; - if (!DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(this, V, minCutMode, FullSeen, - notForAnalysis)) { + bool multiLevel = DifferentialUseAnalysis::is_value_needed_in_reverse< + QueryType::Primal>(this, V, minCutMode, FullSeen, notForAnalysis); 
+ if (!multiLevel) { continue; } if (!Recomputes.count(V)) { @@ -8033,27 +8043,21 @@ void GradientUtils::computeMinCache() { } } Intermediates.insert(V); - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal, /*OneLevel*/ true>( - this, V, minCutMode, OneLevelSeen, notForAnalysis)) { + bool singleLevel = DifferentialUseAnalysis::is_value_needed_in_reverse< + QueryType::Primal, /*OneLevel*/ true>(this, V, minCutMode, + OneLevelSeen, notForAnalysis); + if (singleLevel) { Required.insert(V); } else { - for (auto V2 : V->users()) { - if (auto Inst = dyn_cast(V2)) - for (auto pair : rematerializableAllocations) { - if (pair.second.stores.count(Inst)) { - todo.push_back(pair.first); - } - } - todo.push_back(V2); - } + DifferentialUseAnalysis::forEachDifferentialUser( + [&](Value *V2) { todo.push_back(V2); }, this, V); } } SetVector MinReq; DifferentialUseAnalysis::minCut(oldFunc->getParent()->getDataLayout(), OrigLI, Recomputes, Intermediates, Required, - MinReq, rematerializableAllocations, TLI); + MinReq, this, TLI); SmallPtrSet NeedGraph; for (Value *V : MinReq) NeedGraph.insert(V); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index b9e670326fd4..d5bf445ee429 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -5773,6 +5773,62 @@ TypeTree TypeResults::query(Value *val) const { return analyzer.getAnalysis(val); } +bool TypeResults::anyFloat(Value *val) const { + assert(val); + assert(val->getType()); + auto q = query(val); + auto dt = q[{-1}]; + if (dt != BaseType::Anything && dt != BaseType::Unknown) + return dt.isFloat(); + + size_t ObjSize = 1; + auto &dl = analyzer.fntypeinfo.Function->getParent()->getDataLayout(); + if (val->getType()->isSized()) + ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; + + for (size_t i = 0; i < ObjSize;) { + dt = q[{(int)i}]; + if (dt == BaseType::Integer) { + i++; + continue; + } + if 
(dt == BaseType::Pointer) { + i += dl.getPointerSize(0); + continue; + } + return true; + } + return false; +} + +bool TypeResults::anyPointer(Value *val) const { + assert(val); + assert(val->getType()); + auto q = query(val); + auto dt = q[{-1}]; + if (dt != BaseType::Anything && dt != BaseType::Unknown) + return dt == BaseType::Pointer; + + size_t ObjSize = 1; + auto &dl = analyzer.fntypeinfo.Function->getParent()->getDataLayout(); + if (val->getType()->isSized()) + ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; + + for (size_t i = 0; i < ObjSize;) { + dt = q[{(int)i}]; + if (dt == BaseType::Integer) { + i++; + continue; + } + if (auto FT = dt.isFloat()) { + i += (dl.getTypeSizeInBits(FT) + 7) / 8; + continue; + } + return true; + } + return false; +} + void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer.dump(ss); } ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h index fc7bc1d24dd7..6e036454143a 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h @@ -174,6 +174,16 @@ class TypeResults { /// The TypeTree of a particular Value TypeTree query(llvm::Value *val) const; + /// Whether any part of the top level register can contain a float + /// e.g. { i64, float } can contain a float, but { i64, i8* } would not. + // Of course, here we compute with type analysis rather than llvm type + bool anyFloat(llvm::Value *val) const; + + /// Whether any part of the top level register can contain a pointer + /// e.g. { i64, i8* } can contain a pointer, but { i64, float } would not. 
+ // Of course, here we compute with type analysis rather than llvm type + bool anyPointer(llvm::Value *val) const; + /// The TypeInfo calling convention FnTypeInfo getAnalyzedTypeInfo() const; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index bf6a00d07cd2..7eb015121a3e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -510,6 +510,13 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, "{(llvm::Constant*)ConstantFP::get(ST->getElementType(0), \"" << rvalue->getValue() << "\"), (llvm::Constant*)ConstantFP::get(ST->getElementType(1), \"" + << ivalue->getValue() << "\")});\n" + << "} else if (auto AT = dyn_cast(ty)) {\n" + << curIndent << INDENT << INDENT + << "ret = ConstantArray::get(AT, " + "{(llvm::Constant*)ConstantFP::get(AT->getElementType(), \"" + << rvalue->getValue() + << "\"), (llvm::Constant*)ConstantFP::get(AT->getElementType(), \"" << ivalue->getValue() << "\")});\n"; os << curIndent << INDENT << "} else assert(0 && \"unhandled cfp\");\n"; os << curIndent << INDENT << "ret;\n"; @@ -2115,14 +2122,20 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, StringRef name = cast(lst->getValues()[0])->getValue(); if (lst->size() >= 2) { auto min = cast(lst->getValues()[1])->getValue(); - int min_int; - min.getAsInteger(10, min_int); + int min_int = 0; + if (min.size() != 0 && min.getAsInteger(10, min_int)) { + PrintFatalError(pattern->getLoc(), + "Could not parse min llvm version as int"); + } if (min.size() != 0 && LLVM_VERSION_MAJOR < min_int) continue; if (lst->size() >= 3) { auto max = cast(lst->getValues()[2])->getValue(); - int max_int; - max.getAsInteger(10, max_int); + int max_int = 0; + if (max.size() != 0 && max.getAsInteger(10, max_int)) { + PrintFatalError(pattern->getLoc(), + "Could not parse max llvm version as int"); + } if (max.size() != 0 && LLVM_VERSION_MAJOR > max_int) 
continue; } From 20221d266f6058fb24650c2c499f408d114b2766 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 21:28:17 -0500 Subject: [PATCH 063/131] Enable c++ warning for runtime activity (#1699) --- enzyme/Enzyme/AdjointGenerator.h | 22 ++++++++++------ enzyme/Enzyme/EnzymeLogic.cpp | 22 ++++++++++------ enzyme/Enzyme/GradientUtils.cpp | 44 ++++++++++++++++++++------------ 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ec9cd5c51c9a..ced5c498b6d0 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1051,7 +1051,7 @@ class AdjointGenerator : public llvm::InstVisitor { } Value *diff = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && constantval) { + if (!EnzymeRuntimeActivityCheck && constantval) { if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) { if (!isa(orig_val) && !isa(orig_val)) { @@ -1059,9 +1059,12 @@ class AdjointGenerator : public llvm::InstVisitor { raw_string_ostream ss(str); ss << "Mismatched activity for: " << I << " const val: " << *orig_val; - diff = unwrap(CustomErrorHandler( - str.c_str(), wrap(&I), ErrorType::MixedActivityError, gutils, - wrap(orig_val), wrap(&BuilderZ))); + if (CustomErrorHandler) + diff = unwrap(CustomErrorHandler( + str.c_str(), wrap(&I), ErrorType::MixedActivityError, gutils, + wrap(orig_val), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", I, ss.str()); } } } @@ -1280,7 +1283,7 @@ class AdjointGenerator : public llvm::InstVisitor { Value *valueop = nullptr; if (constantval) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler) { + if (!EnzymeRuntimeActivityCheck) { if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) { if (!isa(orig_val) && !isa(orig_val)) { @@ -1288,9 +1291,12 @@ class AdjointGenerator : public llvm::InstVisitor { raw_string_ostream ss(str); ss << "Mismatched activity for: " << I << " const val: " << 
*orig_val; - valueop = unwrap(CustomErrorHandler( - str.c_str(), wrap(&I), ErrorType::MixedActivityError, - gutils, wrap(orig_val), wrap(&BuilderZ))); + if (CustomErrorHandler) + valueop = unwrap(CustomErrorHandler( + str.c_str(), wrap(&I), ErrorType::MixedActivityError, + gutils, wrap(orig_val), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", I, ss.str()); } } } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 850a92543db3..fbd078ce23c7 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2548,7 +2548,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( IRBuilder<> BuilderZ(newri); Value *invertri = nullptr; if (gutils->isConstantValue(orig_oldval)) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && gutils->TR.query(orig_oldval)[{-1}].isPossiblePointer()) { if (!isa(orig_oldval) && !isa(orig_oldval)) { @@ -2556,9 +2556,12 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( raw_string_ostream ss(str); ss << "Mismatched activity for: " << *ri << " const val: " << *orig_oldval; - invertri = unwrap(CustomErrorHandler( - str.c_str(), wrap(ri), ErrorType::MixedActivityError, - gutils, wrap(orig_oldval), wrap(&BuilderZ))); + if (CustomErrorHandler) + invertri = unwrap(CustomErrorHandler( + str.c_str(), wrap(ri), ErrorType::MixedActivityError, + gutils, wrap(orig_oldval), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", *ri, ss.str()); } } } @@ -3083,16 +3086,19 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB, if (!ret->getType()->isFPOrFPVectorTy() && TR.getReturnAnalysis().Inner0().isPossiblePointer()) { if (gutils->isConstantValue(ret)) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(ret)[{-1}].isPossiblePointer()) { if (!isa(ret) && !isa(ret)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *inst << 
" const val: " << *ret; - invertedPtr = unwrap(CustomErrorHandler( - str.c_str(), wrap(inst), ErrorType::MixedActivityError, gutils, - wrap(ret), wrap(&nBuilder))); + if (CustomErrorHandler) + invertedPtr = unwrap(CustomErrorHandler( + str.c_str(), wrap(inst), ErrorType::MixedActivityError, + gutils, wrap(ret), wrap(&nBuilder))); + else + EmitWarning("MixedActivityError", *inst, ss.str()); } } } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 734d95989442..0fd01c83a05e 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -5695,15 +5695,18 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *itval = nullptr; { auto tval = arg->getTrueValue(); - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(arg)[{-1}].isPossiblePointer() && !isa(tval) && !isa(tval) && isConstantValue(tval)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *arg << " const val: " << *tval; - itval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), - ErrorType::MixedActivityError, this, - wrap(tval), wrap(&bb))); + if (CustomErrorHandler) + itval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), + ErrorType::MixedActivityError, this, + wrap(tval), wrap(&bb))); + else + EmitWarning("MixedActivityError", *arg, ss.str()); } if (!itval) { itval = invertPointerM(tval, bb, nullShadow); @@ -5712,15 +5715,18 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *ifval = nullptr; { auto fval = arg->getFalseValue(); - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(arg)[{-1}].isPossiblePointer() && !isa(fval) && !isa(fval) && isConstantValue(fval)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *arg << " const val: " << *fval; - ifval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), - 
ErrorType::MixedActivityError, this, - wrap(fval), wrap(&bb))); + if (CustomErrorHandler) + ifval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), + ErrorType::MixedActivityError, this, + wrap(fval), wrap(&bb))); + else + EmitWarning("MixedActivityError", *arg, ss.str()); } if (!ifval) { ifval = invertPointerM(fval, bb, nullShadow); @@ -6064,7 +6070,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *preval = phi->getIncomingValue(j); Value *val = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(phi)[{-1}].isPossiblePointer() && !isa(preval) && !isa(preval) && isConstantValue(preval)) { @@ -6072,9 +6078,12 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, raw_string_ostream ss(str); ss << "Mismatched activity for: " << *phi << " const val: " << *preval; - val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), - ErrorType::MixedActivityError, this, - wrap(preval), wrap(&pre))); + if (CustomErrorHandler) + val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), + ErrorType::MixedActivityError, + this, wrap(preval), wrap(&pre))); + else + EmitWarning("MixedActivityError", *phi, ss.str()); } if (!val) { val = invertPointerM(preval, pre, nullShadow); @@ -6124,7 +6133,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *preval = phi->getIncomingValue(i); Value *val = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(phi)[{-1}].isPossiblePointer() && !isa(preval) && !isa(preval) && isConstantValue(preval)) { @@ -6132,9 +6141,12 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, raw_string_ostream ss(str); ss << "Mismatched activity for: " << *phi << " const val: " << *preval; - val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), - ErrorType::MixedActivityError, this, - wrap(preval), wrap(&pre))); + 
if (CustomErrorHandler) + val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), + ErrorType::MixedActivityError, + this, wrap(preval), wrap(&pre))); + else + EmitWarning("MixedActivityError", *phi, ss.str()); } if (!val) { val = invertPointerM(preval, pre, nullShadow); From a309cc083f64fb6c17689a81feffa804bc7fcb3d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 21:29:31 -0500 Subject: [PATCH 064/131] Mark memcpy of 1 byte as inactive (#1698) * Mark memcpy of 1 byte as inactive * Update ActivityAnalysis.cpp --- enzyme/Enzyme/ActivityAnalysis.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index e6b5a565c112..459da4928560 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -468,6 +468,13 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { return true; } + // Copies of size 1 are inactive [cannot move differentiable data in one byte] + if (auto MTI = dyn_cast(&CI)) { + if (auto sz = dyn_cast(MTI->getOperand(2))) { + if (sz->getValue() == 1) + return true; + } + } return false; } From c356fbe124a5f916ee70964d71b67036f248822d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 21:51:01 -0500 Subject: [PATCH 065/131] Improve nofree demangle support (#1696) * Improve nofree demangle support * fix * fix * fix --- enzyme/Enzyme/ActivityAnalysis.cpp | 15 ++- enzyme/Enzyme/AdjointGenerator.h | 4 +- enzyme/Enzyme/CallDerivatives.cpp | 3 +- enzyme/Enzyme/EnzymeLogic.cpp | 105 +++++++++++++++--- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 6 +- .../test/Integration/ReverseMode/blas_gemm2.c | 3 - .../Integration/ReverseMode/customlog1p.c | 4 - .../test/Integration/ReverseMode/inactivefn.c | 4 - enzyme/test/Integration/ReverseMode/invsqrt.c | 9 +- enzyme/test/Integration/ReverseMode/mycos.c | 18 ++- enzyme/test/Integration/ReverseMode/omp.c | 4 - 
enzyme/test/Integration/ReverseMode/omp2.c | 4 - enzyme/test/Integration/ReverseMode/omp3.c | 2 - enzyme/test/Integration/ReverseMode/omp6.c | 4 - enzyme/test/Integration/ReverseMode/omp_two.c | 4 - .../test/Integration/ReverseMode/ompbound.c | 4 - enzyme/test/Integration/test_utils.h | 4 +- 17 files changed, 121 insertions(+), 76 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 459da4928560..ef6c7a14f811 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -115,9 +115,11 @@ static const StringSet<> InactiveGlobals = { "_ZSt3cin", "_ZSt4cout", "_ZNSt3__14coutE", + "_ZNSt3__15wcoutE", "_ZNSt3__113basic_ostreamIcNS_11char_traitsIcEEE6sentryC1ERS3_", "_ZSt5wcout", "_ZSt4cerr", + "_ZNSt3__14cerrE", "_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE", "_ZTVSt15basic_streambufIcSt11char_traitsIcEE", "_ZTVSt9basic_iosIcSt11char_traitsIcEE", @@ -284,13 +286,17 @@ const StringSet<> KnownInactiveFunctions = { "cuDevicePrimaryCtxRetain", "floor", "floorf", - "floorl" + "floorl", + "\01_fopen", + "fopen", + "fclose", }; const std::set KnownInactiveIntrinsics = { #if LLVM_VERSION_MAJOR >= 12 Intrinsic::experimental_noalias_scope_decl, #endif + Intrinsic::objectsize, Intrinsic::floor, Intrinsic::ceil, Intrinsic::trunc, @@ -407,9 +413,16 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { "std::__1::shuffle_order_engine", "std::__1::basic_streambuf", "std::__1::basic_stringbuf", + "std::__1::basic_istream", + "std::__1::basic_filebuf", + "std::__1::basic_iostream", + "std::__1::basic_ios", + "virtual thunk to std::__1::basic_istream", + "virtual thunk to std::__1::basic_ostream", "std::__detail::_Prime_rehash_policy", "std::__detail::_Hash_code_base", + }; // clang-format on diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ced5c498b6d0..60902eafd2a1 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ 
b/enzyme/Enzyme/AdjointGenerator.h @@ -4400,7 +4400,6 @@ class AdjointGenerator : public llvm::InstVisitor { } } } - size_t freeCount = 0; for (auto LI : geps) { CallInst *freeCall = nullptr; for (auto LU : LI->users()) { @@ -4428,7 +4427,6 @@ class AdjointGenerator : public llvm::InstVisitor { } if (freeCall) { freeCall->eraseFromParent(); - freeCount++; } } } @@ -4599,7 +4597,7 @@ class AdjointGenerator : public llvm::InstVisitor { EmitFailure("CannotDeduceType", call.getDebugLoc(), &call, "failed to deduce type of copy ", call); } -#if LLVM_VERSION_MAJOR < 18 +#if LLVM_VERSION_MAJOR < 17 knownF: #endif unsigned start = 0; diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index c8b2cdaa4e4b..155967e24a33 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -3572,7 +3572,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( ConstantInt::getFalse(call.getContext())); return true; } - if (funcName == "memset" || funcName == "memset_pattern16") { + if (funcName == "memset" || funcName == "memset_pattern16" || + funcName == "__memset_chk") { visitMemSetCommon(call); return true; } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index fbd078ce23c7..bc4856566909 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2336,7 +2336,18 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( if (todiff->empty()) { std::string s; llvm::raw_string_ostream ss(s); - ss << "No augmented forward pass found for " + todiff->getName() << "\n"; + ss << "No augmented forward pass found for " + todiff->getName(); + { + std::string demangledName = llvm::demangle(todiff->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); + } + if (demangledName != todiff->getName()) + ss << "(" << demangledName << ")"; + } + ss << "\n"; 
llvm::Value *toshow = todiff; if (context.req) { toshow = context.req; @@ -5889,37 +5900,47 @@ llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context, cast(CreateNoFree(context, castinst->getOperand(0)))}; return castinst->getWithOperands(reps); } - if (CustomErrorHandler) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No create nofree of unknown value\n"; - ss << *todiff << "\n"; - if (context.req) { - ss << " at context: " << *context.req; + if (EnzymeAssumeUnknownNoFree) { + return todiff; + } + + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No create nofree of unknown value\n"; + ss << *todiff << "\n"; + if (context.req) { + ss << " at context: " << *context.req; + } + if (auto I = dyn_cast(todiff)) { + auto fname = I->getParent()->getParent()->getName(); + if (startsWith(fname, "nofree_")) + fname = fname.substr(7); + std::string demangledName = llvm::demangle(fname.str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); } + ss << " within func " << fname << " (" << demangledName << ")\n"; + } + if (CustomErrorHandler) { CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); return todiff; } - if (EnzymeAssumeUnknownNoFree) { - return todiff; - } - if (context.req) { - EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, - "Cannot create nofree of instruction-created value: ", *todiff); + EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, s); return todiff; } if (auto arg = dyn_cast(todiff)) { auto loc = arg->getDebugLoc(); - EmitFailure("IllegalNoFree", loc, arg, - "Cannot create nofree of instruction-created value: ", *todiff); + EmitFailure("IllegalNoFree", loc, arg, s); return todiff; } - llvm::errs() << " unhandled, create no free of: " << *todiff << "\n"; + llvm::errs() << s; 
llvm_unreachable("unhandled, create no free"); } @@ -6001,6 +6022,7 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { "std::__1::basic_string, std::__1::allocator>::data() const", "std::__1::basic_ostream>::sentry::sentry(std::__1::basic_ostream>&)", "std::__1::basic_ostream>::sentry::~sentry()", + "std::__1::basic_ostream>::flush()", "std::__1::ios_base::__set_badbit_and_consider_rethrow()", "char* std::__1::addressof(char&)", "char const* std::__1::addressof(char const&)", @@ -6011,6 +6033,40 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { "std::__1::ios_base::ios_base()", "std::__1::ios_base::getloc() const", "std::__1::ios_base::clear(unsigned int)", + "std::__1::basic_iostream>::~basic_iostream()", + "std::__1::basic_ios>::~basic_ios()", + "std::__1::basic_streambuf>::basic_streambuf()", + "std::__1::basic_streambuf>::~basic_streambuf()", + "std::__1::basic_streambuf>::imbue(std::__1::locale const&)", + "std::__1::basic_streambuf>::setbuf(char*, long)", + "std::__1::basic_streambuf>::sync()", + "std::__1::basic_streambuf>::showmanyc()", + "std::__1::basic_streambuf>::xsgetn(char*, long)", + "std::__1::basic_streambuf>::uflow()", + "std::__1::basic_filebuf>::basic_filebuf()", + "std::__1::basic_filebuf>::~basic_filebuf()", + "std::__1::basic_filebuf>::open(char const*, unsigned int)", + "std::__1::basic_filebuf>::close()", + "std::__1::basic_filebuf>::sync()", + "std::__1::basic_istream>::~basic_istream()", + "virtual thunk to std::__1::basic_istream>::~basic_istream()", + "virtual thunk to std::__1::basic_ostream>::~basic_ostream()", + "std::__1::basic_ifstream>::~basic_ifstream()", + "std::__1::ios_base::init(void*)", + "std::__1::basic_istream>::read(char*, long)", + "std::__1::basic_ostream>::~basic_ostream()", + "std::__1::basic_string, std::__1::allocator>::__init(unsigned long, char)", + "std::__1::basic_ostream>::write(char const*, long)", + }; + const char* 
NoFreeDemanglesStartsWith[] = { + "std::__1::basic_ostream>::operator<<", + "std::__1::ios_base::imbue", + "std::__1::basic_streambuf>::pubimbue", + "std::__1::basic_stringbuf, std::__1::allocator>::__init_buf_ptrs", + "std::__1::basic_stringbuf, std::__1::allocator>::basic_stringbuf", + "std::__1::basic_string, std::__1::allocator>::operator=", + "std::__1::ctype::widen", + "std::__1::basic_streambuf>::sputn", }; // clang-format on @@ -6032,6 +6088,10 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if (NoFreeDemangles.count(demangledName)) return F; + for (auto Name : NoFreeDemanglesStartsWith) + if (startsWith(demangledName, Name)) + return F; + switch (F->getIntrinsicID()) { case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: @@ -6055,6 +6115,17 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { << F->getName() << ")\n"; if (context.req) { ss << " at context: " << *context.req; + if (auto CB = dyn_cast(context.req)) { + if (auto F = CB->getCalledFunction()) { + std::string demangleF = llvm::demangle(F->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangleF.find("> >", start)) != std::string::npos) { + demangleF.replace(start, 3, ">>"); + } + ss << " (" << demangleF << ")"; + } + } } else { ss << *F << "\n"; } diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index d5bf445ee429..c0b54a4eccbb 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4507,8 +4507,10 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { return; } - if (startsWith(funcName, "_ZNKSt3__112basic_stringIcNS_11char_traitsIcEENS_" - "9allocatorIcEEE13__get_pointer")) { + if (startsWith(funcName, "_ZNKSt3__112basic_string") || + startsWith(funcName, "_ZNSt3__112basic_string") || + startsWith(funcName, "_ZNSt3__112__hash_table") || + startsWith(funcName, 
"_ZNKSt3__115basic_stringbuf")) { return; } diff --git a/enzyme/test/Integration/ReverseMode/blas_gemm2.c b/enzyme/test/Integration/ReverseMode/blas_gemm2.c index a417e0418cd9..b06474282634 100644 --- a/enzyme/test/Integration/ReverseMode/blas_gemm2.c +++ b/enzyme/test/Integration/ReverseMode/blas_gemm2.c @@ -8,9 +8,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi -#include -#include -#include #include "../test_utils.h" #include "../blas_inline.h" diff --git a/enzyme/test/Integration/ReverseMode/customlog1p.c b/enzyme/test/Integration/ReverseMode/customlog1p.c index 95c0697fff77..6f563030f9b3 100644 --- a/enzyme/test/Integration/ReverseMode/customlog1p.c +++ b/enzyme/test/Integration/ReverseMode/customlog1p.c @@ -17,10 +17,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/inactivefn.c b/enzyme/test/Integration/ReverseMode/inactivefn.c index e48edd3486ed..bbd6696a04d9 100644 --- a/enzyme/test/Integration/ReverseMode/inactivefn.c +++ b/enzyme/test/Integration/ReverseMode/inactivefn.c @@ -17,10 +17,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff 
--git a/enzyme/test/Integration/ReverseMode/invsqrt.c b/enzyme/test/Integration/ReverseMode/invsqrt.c index c14ce79ab88d..09ada6757e45 100644 --- a/enzyme/test/Integration/ReverseMode/invsqrt.c +++ b/enzyme/test/Integration/ReverseMode/invsqrt.c @@ -17,13 +17,8 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include -#include -#include - #include "../test_utils.h" +#include // Fast inverse sqrt // Code taken from https://en.wikipedia.org/wiki/Fast_inverse_square_root @@ -74,10 +69,8 @@ int main(int argc, char *argv[]) { double *A = (double*)malloc(sizeof(double) * n); for(int i=0; i -#include -#include - #include "../test_utils.h" +double pow(double, double); + __attribute__((noinline)) -uint64_t factorial(uint64_t x) { +unsigned long long factorial(unsigned long long x) { if (x == 0) return 1; return x * factorial(x-1); } double my_sin(double x) { double result = 0; - uint64_t N = 12; - for(uint64_t i=0; i<=N; i++) { + unsigned long long N = 12; + for(unsigned long long i=0; i<=N; i++) { if (i % 2 == 0) continue; result += pow(x, i) / factorial(i) * (i % 4 == 1 ? 1 : -1); } @@ -38,14 +36,14 @@ double my_sin(double x) { } -uint64_t __enzyme_iter(uint64_t, uint64_t); +unsigned long long __enzyme_iter(unsigned long long, unsigned long long); double __enzyme_autodiff(void*, double); double my_sin2(double x) { double result = 0; - uint64_t N = __enzyme_iter(12, 1); - for(uint64_t i=0; i<=N; i++) { + unsigned long long N = __enzyme_iter(12, 1); + for(unsigned long long i=0; i<=N; i++) { if (i % 2 == 0) continue; result += pow(x, i) / factorial(i) * (i % 4 == 1 ? 
1 : -1); } diff --git a/enzyme/test/Integration/ReverseMode/omp.c b/enzyme/test/Integration/ReverseMode/omp.c index 4d6f0bd164c5..5d4f8ae09970 100644 --- a/enzyme/test/Integration/ReverseMode/omp.c +++ b/enzyme/test/Integration/ReverseMode/omp.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp2.c b/enzyme/test/Integration/ReverseMode/omp2.c index 8cdd06545fd1..44f315fe28f9 100644 --- a/enzyme/test/Integration/ReverseMode/omp2.c +++ b/enzyme/test/Integration/ReverseMode/omp2.c @@ -8,10 +8,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp3.c b/enzyme/test/Integration/ReverseMode/omp3.c index 59bdca34cf9a..970348a7cedd 100644 --- a/enzyme/test/Integration/ReverseMode/omp3.c +++ b/enzyme/test/Integration/ReverseMode/omp3.c @@ -9,8 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O3 %s -S -emit-llvm 
-o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -# include -# include #include "../test_utils.h" void msg(double* in, int *len, unsigned int slen) { diff --git a/enzyme/test/Integration/ReverseMode/omp6.c b/enzyme/test/Integration/ReverseMode/omp6.c index 8a6f34fe235f..abd170d07dfd 100644 --- a/enzyme/test/Integration/ReverseMode/omp6.c +++ b/enzyme/test/Integration/ReverseMode/omp6.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -# include -# include -#include - #include "../test_utils.h" __attribute__((noinline)) diff --git a/enzyme/test/Integration/ReverseMode/omp_two.c b/enzyme/test/Integration/ReverseMode/omp_two.c index 1f95e93f2520..8d3dac92ca75 100644 --- a/enzyme/test/Integration/ReverseMode/omp_two.c +++ b/enzyme/test/Integration/ReverseMode/omp_two.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/ompbound.c b/enzyme/test/Integration/ReverseMode/ompbound.c index 4d6f0bd164c5..5d4f8ae09970 100644 --- a/enzyme/test/Integration/ReverseMode/ompbound.c +++ b/enzyme/test/Integration/ReverseMode/ompbound.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize 
-fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/test_utils.h b/enzyme/test/Integration/test_utils.h index 881aa5f8f703..268922598f54 100644 --- a/enzyme/test/Integration/test_utils.h +++ b/enzyme/test/Integration/test_utils.h @@ -1,9 +1,11 @@ -#ifdef __cplusplus +#if defined(__cplusplus) || defined(__APPLE__) #include #include #include +#include +#include #else struct _IO_FILE; extern struct _IO_FILE* stderr; From 1285a39808c276ddedbe9a6fc491e7eeda64c488 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 21:51:25 -0500 Subject: [PATCH 066/131] Ensure custom derivative functions aren't deleted (#1697) --- enzyme/Enzyme/EnzymeLogic.cpp | 22 ++++++++++ enzyme/Enzyme/PreserveNVVM.cpp | 40 +++++++++++-------- .../test/Enzyme/ReverseMode/custom-sret3.ll | 2 +- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index bc4856566909..ebdbb937733d 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4490,7 +4490,29 @@ Function *EnzymeLogic::CreateForwardDiff( "unknown derivative for function -- metadata incorrect"); } auto md2 = cast(md); + assert(md2); assert(md2->getNumOperands() == 1); + if (!md2->getOperand(0)) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Failed to use custom forward mode derivative for " + << todiff->getName() << "\n"; + ss << " found metadata (but null op0) " << *md2 << "\n"; + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + return ForwardCachedFunctions[tup] = nullptr; + } + 
if (!isa(md2->getOperand(0))) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Failed to use custom forward mode derivative for " + << todiff->getName() << "\n"; + ss << " found metadata (but not constantasmetadata) " + << *md2->getOperand(0) << "\n"; + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + return ForwardCachedFunctions[tup] = nullptr; + } auto gvemd = cast(md2->getOperand(0)); auto foundcalled = cast(gvemd->getValue()); diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index 73d13a009abb..a7df39596f0a 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp @@ -56,6 +56,25 @@ using namespace llvm; #define addAttribute addAttributeAtIndex #endif +//! Returns whether changed. +bool preserveLinkage(bool Begin, Function &F, bool Inlining = true) { + if (Begin && !F.hasFnAttribute("prev_fixup")) { + F.addFnAttr("prev_fixup"); + if (Inlining) { + if (F.hasFnAttribute(Attribute::AlwaysInline)) + F.addFnAttr("prev_always_inline"); + F.removeFnAttr(Attribute::AlwaysInline); + if (F.hasFnAttribute(Attribute::NoInline)) + F.addFnAttr("prev_no_inline"); + F.addFnAttr(Attribute::NoInline); + } + F.addFnAttr("prev_linkage", std::to_string(F.getLinkage())); + F.setLinkage(Function::LinkageTypes::ExternalLinkage); + return true; + } + return false; +} + template static void handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, @@ -237,26 +256,31 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, Fs[fn] = NewF; } + preserveLinkage(true, *Fs[1], false); Fs[0]->setMetadata( "enzyme_augment", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); + preserveLinkage(true, *Fs[2], false); Fs[0]->setMetadata( "enzyme_gradient", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[2])})); } else if (Mode == DerivativeMode::ForwardMode) { assert(numargs == 2); + preserveLinkage(true, *Fs[1], false); 
Fs[0]->setMetadata( "enzyme_derivative", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); } else if (Mode == DerivativeMode::ForwardModeSplit) { assert(numargs == 3); + preserveLinkage(true, *Fs[1], false); Fs[0]->setMetadata( "enzyme_augment", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); + preserveLinkage(true, *Fs[2], false); Fs[0]->setMetadata( "enzyme_splitderivative", llvm::MDTuple::get(Fs[0]->getContext(), @@ -282,22 +306,6 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, } globalsToErase.push_back(&g); } -//! Returns whether changed. -bool preserveLinkage(bool Begin, Function &F) { - if (Begin && !F.hasFnAttribute("prev_fixup")) { - F.addFnAttr("prev_fixup"); - if (F.hasFnAttribute(Attribute::AlwaysInline)) - F.addFnAttr("prev_always_inline"); - if (F.hasFnAttribute(Attribute::NoInline)) - F.addFnAttr("prev_no_inline"); - F.addFnAttr("prev_linkage", std::to_string(F.getLinkage())); - F.setLinkage(Function::LinkageTypes::ExternalLinkage); - F.addFnAttr(Attribute::NoInline); - F.removeFnAttr(Attribute::AlwaysInline); - return true; - } - return false; -} bool preserveNVVM(bool Begin, Function &F) { bool changed = false; diff --git a/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll b/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll index 195470291e2a..da82145bdbd8 100644 --- a/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll +++ b/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll @@ -118,7 +118,7 @@ attributes #4 = { nounwind } !10 = !{!11, !11, i64 0} !11 = !{!"int", !5, i64 0} -; CHECK: define internal void @fixbyval_myblas_cdot_rev(%struct.complex* %arg0, %struct.complex* %arg1, %struct.complex* %arg2, %struct.complex* %arg3, i32 %arg4, i32 %arg5, %struct.complex %arg6, i8* %arg7) +; CHECK: define dso_local void @fixbyval_myblas_cdot_rev(%struct.complex* %arg0, %struct.complex* %arg1, %struct.complex* %arg2, %struct.complex* %arg3, i32 %arg4, i32 %arg5, %struct.complex %arg6, i8* 
%arg7) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = alloca %struct.complex ; CHECK-NEXT: store %struct.complex %arg6, %struct.complex* %0 From 5e453755a1affc6da3d40c4ece9dcebfe4a319ab Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 22:22:57 -0500 Subject: [PATCH 067/131] Fix memcpy of size 1 [part 2] (#1701) --- enzyme/Enzyme/ActivityAnalysis.cpp | 19 ++++++++----------- enzyme/Enzyme/AdjointGenerator.h | 8 ++++++++ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index ef6c7a14f811..79a56d885507 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -432,16 +432,16 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { if (auto iasm = dyn_cast(CI.getCalledOperand())) { if (StringRef(iasm->getAsmString()).contains("exit") || StringRef(iasm->getAsmString()).contains("cpuid")) - return false; + return true; } - auto F = getFunctionFromCall(&CI); - - if (F == nullptr) - return false; - - if (F->hasFnAttribute("enzyme_inactive")) { - return true; + if (auto F = getFunctionFromCall(&CI)) { + if (F->hasFnAttribute("enzyme_inactive")) { + return true; + } + if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { + return true; + } } auto Name = getFuncNameFromCall(&CI); @@ -478,9 +478,6 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { return true; } - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return true; - } // Copies of size 1 are inactive [cannot move differentiable data in one byte] if (auto MTI = dyn_cast(&CI)) { if (auto sz = dyn_cast(MTI->getOperand(2))) { diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 60902eafd2a1..c94f03b4bb0b 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -3275,6 +3275,14 @@ class AdjointGenerator : public llvm::InstVisitor { return; } + // memcpy of size 1 cannot move differentiable 
data [single byte copy] + if (auto ci = dyn_cast(new_size)) { + if (ci->getValue() == 1) { + eraseIfUnused(MTI); + return; + } + } + // copying into nullptr is invalid (not sure why it exists here), but we // shouldn't do it in reverse pass or shadow if (isa(orig_dst) || From 89d20faf5beac95fc319d82d75dc9d39ce22ac3f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 22:51:30 -0500 Subject: [PATCH 068/131] Fix asan error and add mlir alloc/free fns (#1700) * Fix asan error and add mlir alloc/free fns * fix --- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 28 ++--- enzyme/Enzyme/GradientUtils.cpp | 5 +- enzyme/Enzyme/LibraryFuncs.h | 4 + enzyme/test/Enzyme/ReverseMode/mlirmincut.ll | 107 +++++++++++++++++++ 4 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/mlirmincut.ll diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index d20ec4eefbc4..f4eb07a01602 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -849,7 +849,9 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, assert(pair.first.outgoing == 0 && N.outgoing == 1); assert(pair.first.V == N.V); MinReq.insert(N.V); - todo.push_back(N.V); + if (Orig.find(Node(N.V, true)) != Orig.end()) { + todo.push_back(N.V); + } } } } @@ -860,20 +862,20 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, auto V = todo.front(); todo.pop_front(); auto found = Orig.find(Node(V, true)); - if (found->second.size() == 1 && !Required.count(V)) { + assert(found != Orig.end()); + const auto &mp = found->second; + if (mp.size() == 1 && !Required.count(V)) { bool potentiallyRecursive = - isa((*found->second.begin()).V) && - OrigLI.isLoopHeader( - cast((*found->second.begin()).V)->getParent()); + isa((*mp.begin()).V) && + OrigLI.isLoopHeader(cast((*mp.begin()).V)->getParent()); int moreOuterLoop = cmpLoopNest( 
OrigLI.getLoopFor(cast(V)->getParent()), - OrigLI.getLoopFor( - cast(((*found->second.begin()).V))->getParent())); + OrigLI.getLoopFor(cast(((*mp.begin()).V))->getParent())); if (potentiallyRecursive) continue; if (moreOuterLoop == -1) continue; - if (auto ASC = dyn_cast((*found->second.begin()).V)) { + if (auto ASC = dyn_cast((*mp.begin()).V)) { if (ASC->getDestAddressSpace() == 11 || ASC->getDestAddressSpace() == 13) continue; @@ -882,7 +884,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, } // If an allocation call, we cannot cache any "capturing" users if (isAllocationCall(V, TLI)) { - auto next = (*found->second.begin()).V; + auto next = (*mp.begin()).V; bool noncapture = false; if (isa(next)) { noncapture = true; @@ -918,10 +920,12 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, if (moreOuterLoop == 1 || (moreOuterLoop == 0 && DL.getTypeSizeInBits(V->getType()) >= - DL.getTypeSizeInBits((*found->second.begin()).V->getType()))) { + DL.getTypeSizeInBits((*mp.begin()).V->getType()))) { MinReq.remove(V); - MinReq.insert((*found->second.begin()).V); - todo.push_back((*found->second.begin()).V); + auto nnode = (*mp.begin()).V; + MinReq.insert(nnode); + if (Orig.find(Node(nnode, true)) != Orig.end()) + todo.push_back(nnode); } } } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 0fd01c83a05e..b1d8161576f2 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -9162,7 +9162,8 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, tofree = builder.CreateIntToPtr(tofree, getInt8PtrTy(tofree->getContext())); llvm::LibFunc libfunc; - if (allocationfn == "calloc" || allocationfn == "malloc") { + if (allocationfn == "calloc" || allocationfn == "malloc" || + allocationfn == "_mlir_memref_to_llvm_alloc") { libfunc = LibFunc_malloc; } else { bool res = TLI.getLibFunc(allocationfn, libfunc); @@ -9232,6 +9233,8 @@ llvm::CallInst 
*freeKnownAllocation(llvm::IRBuilder<> &builder, if (freename != "free") llvm_unreachable("illegal free"); } + if (allocationfn == "_mlir_memref_to_llvm_alloc") + freename = "_mlir_memref_to_llvm_free"; Type *VoidTy = Type::getVoidTy(tofree->getContext()); Type *IntPtrTy = getInt8PtrTy(tofree->getContext()); diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h index 1a8ebce38eee..d18ecc346802 100644 --- a/enzyme/Enzyme/LibraryFuncs.h +++ b/enzyme/Enzyme/LibraryFuncs.h @@ -49,6 +49,8 @@ static inline bool isAllocationFunction(const llvm::StringRef name, return true; if (name == "calloc" || name == "malloc") return true; + if (name == "_mlir_memref_to_llvm_alloc") + return true; if (name == "swift_allocObject") return true; if (name == "__rust_alloc" || name == "__rust_alloc_zeroed") @@ -123,6 +125,8 @@ static inline bool isDeallocationFunction(const llvm::StringRef name, if (!TLI.getLibFunc(name, libfunc)) { if (name == "free") return true; + if (name == "_mlir_memref_to_llvm_free") + return true; if (name == "__rust_dealloc") return true; if (name == "swift_release") diff --git a/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll new file mode 100644 index 000000000000..6f214ddb2cdc --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll @@ -0,0 +1,107 @@ +; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi + +declare void @__enzyme_autodiff0(...) local_unnamed_addr + +declare void @_mlir_memref_to_llvm_free(ptr) + +declare ptr @_mlir_memref_to_llvm_alloc(i64) + +define void @jit_compiled(ptr %a) { + tail call void (...) 
@__enzyme_autodiff0(ptr nonnull @f, metadata !"enzyme_const", ptr %a, ptr %a, ptr %a, i64 0, metadata !"enzyme_const", ptr %a, metadata !"enzyme_const", ptr %a, i64 0, metadata !"enzyme_const", i64 1, metadata !"enzyme_const", ptr %a, metadata !"enzyme_dupnoneed", ptr %a, ptr %a, i64 0) + ret void +} + +define void @f(ptr %arg, ptr %arg1, i64 %arg2, ptr %arg3, ptr %arg4, i64 %arg5, i64 %arg6, ptr nocapture readnone %arg7, ptr nocapture writeonly %arg8, i64 %arg9) { + %.idx = shl i64 %arg6, 3 + %i = tail call ptr @_mlir_memref_to_llvm_alloc(i64 %.idx) + %i10 = load double, ptr %arg4, align 8 + %i11 = fcmp ogt double %i10, 1.500000e+00 + %i12 = load double, ptr %arg1, align 8 + %i13 = fmul double %i12, 2.000000e+00 + %storemerge = select i1 %i11, double %i13, double %i12 + store double %storemerge, ptr %i, align 8 + %i14 = alloca { ptr, ptr, i64 }, align 8 + store ptr %arg, ptr %i14, align 8 + %.repack2 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 1 + store ptr %arg1, ptr %.repack2, align 8 + %.repack4 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 2 + store i64 %arg2, ptr %.repack4, align 8 + %i15 = alloca { ptr, ptr, i64 }, align 8 + store ptr %arg3, ptr %i15, align 8 + %.repack6 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 1 + store ptr %arg4, ptr %.repack6, align 8 + %.repack8 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 2 + store i64 %arg5, ptr %.repack8, align 8 + %i16 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8 + store ptr %i, ptr %i16, align 8 + %.repack10 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 1 + store ptr %i, ptr %.repack10, align 8 + %.repack12 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 2 + store i64 0, ptr %.repack12, align 8 + %.repack14 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 3 + store i64 %arg6, ptr %.repack14, align 8 + 
%.repack16 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 4 + store i64 1, ptr %.repack16, align 8 + %i17 = tail call ptr @_mlir_memref_to_llvm_alloc(i64 8) + %i18 = alloca { ptr, ptr, i64 }, align 8 + store ptr %i17, ptr %i18, align 8 + %.repack18 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 1 + store ptr %i17, ptr %.repack18, align 8 + %.repack20 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 2 + store i64 0, ptr %.repack20, align 8 + %i19 = load double, ptr %i17, align 8 + store double %i19, ptr %arg8, align 8 + ret void +} + +; CHECK: define internal void @diffef(ptr %arg, ptr %arg1, ptr %"arg1'", i64 %arg2, ptr %arg3, ptr %arg4, i64 %arg5, i64 %arg6, ptr nocapture readnone %arg7, ptr nocapture writeonly %arg8, ptr nocapture %"arg8'", i64 %arg9) +; CHECK-NEXT: invert: +; CHECK-NEXT: %.idx = shl i64 %arg6, 3 +; CHECK-NEXT: %i = tail call ptr @_mlir_memref_to_llvm_alloc(i64 %.idx) +; CHECK-NEXT: %i10 = load double, ptr %arg4, align 8 +; CHECK-NEXT: %i11 = fcmp ogt double %i10, 1.500000e+00 +; CHECK-NEXT: %i12 = load double, ptr %arg1, align 8 +; CHECK-NEXT: %i13 = fmul double %i12, 2.000000e+00 +; CHECK-NEXT: %storemerge = select i1 %i11, double %i13, double %i12 +; CHECK-NEXT: store double %storemerge, ptr %i, align 8 +; CHECK-NEXT: %i14 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %arg, ptr %i14, align 8 +; CHECK-NEXT: %.repack2 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 1 +; CHECK-NEXT: store ptr %arg1, ptr %.repack2, align 8 +; CHECK-NEXT: %.repack4 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 2 +; CHECK-NEXT: store i64 %arg2, ptr %.repack4, align 8 +; CHECK-NEXT: %i15 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %arg3, ptr %i15, align 8 +; CHECK-NEXT: %.repack6 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 1 +; CHECK-NEXT: store ptr %arg4, ptr %.repack6, align 8 +; CHECK-NEXT: %.repack8 = 
getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 2 +; CHECK-NEXT: store i64 %arg5, ptr %.repack8, align 8 +; CHECK-NEXT: %i16 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8 +; CHECK-NEXT: store ptr %i, ptr %i16, align 8 +; CHECK-NEXT: %.repack10 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 1 +; CHECK-NEXT: store ptr %i, ptr %.repack10, align 8 +; CHECK-NEXT: %.repack12 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 2 +; CHECK-NEXT: store i64 0, ptr %.repack12, align 8 +; CHECK-NEXT: %.repack14 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 3 +; CHECK-NEXT: store i64 %arg6, ptr %.repack14, align 8 +; CHECK-NEXT: %.repack16 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 4 +; CHECK-NEXT: store i64 1, ptr %.repack16, align 8 +; CHECK-NEXT: %"i17'mi" = tail call noalias nonnull ptr @_mlir_memref_to_llvm_alloc(i64 8) +; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr nonnull dereferenceable(8) dereferenceable_or_null(8) %"i17'mi", i8 0, i64 8, i1 false) +; CHECK-NEXT: %i17 = tail call ptr @_mlir_memref_to_llvm_alloc(i64 8) +; CHECK-NEXT: %i18 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %i17, ptr %i18, align 8 +; CHECK-NEXT: %.repack18 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 1 +; CHECK-NEXT: store ptr %i17, ptr %.repack18, align 8 +; CHECK-NEXT: %.repack20 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 2 +; CHECK-NEXT: store i64 0, ptr %.repack20, align 8 +; CHECK-NEXT: %0 = load double, ptr %"arg8'", align 8 +; CHECK-NEXT: store double 0.000000e+00, ptr %"arg8'", align 8 +; CHECK-NEXT: %1 = load double, ptr %"i17'mi", align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %0 +; CHECK-NEXT: store double %2, ptr %"i17'mi", align 8 +; CHECK-NEXT: call void @_mlir_memref_to_llvm_free(ptr nonnull %"i17'mi") +; CHECK-NEXT: call void 
@_mlir_memref_to_llvm_free(ptr %i17) +; CHECK-NEXT: call void @_mlir_memref_to_llvm_free(ptr %i) +; CHECK-NEXT: ret void +; CHECK-NEXT: } From 32c3786ceff49a285c6e1d7e548687b0203fd57c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 23:30:54 -0500 Subject: [PATCH 069/131] Erase atomic on err (#1702) --- enzyme/Enzyme/AdjointGenerator.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index c94f03b4bb0b..ae2e72aceb83 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -898,6 +898,9 @@ class AdjointGenerator : public llvm::InstVisitor { setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), BuilderZ); } + gutils->replaceAWithB(gutils->getNewFromOriginal(&I), + UndefValue::get(I.getType())); + eraseIfUnused(I, /*erase*/ true, /*check*/ false); return; } From 012da6b546597bec8f345129857939e156564c0e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 23:48:40 -0500 Subject: [PATCH 070/131] Preserve parmtype (#1703) --- enzyme/Enzyme/FunctionUtils.cpp | 92 +++++++++++++++++++-------------- 1 file changed, 54 insertions(+), 38 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 71ceb9688eea..6bc443276b4b 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2194,10 +2194,14 @@ Function *PreProcessCache::CloneFunctionWithReturns( // Attribute::ElementType)); #endif } - for (auto ty : PrimalParamAttrsToPreserve) - if (F->getAttributes().hasParamAttr(ii, ty)) { - auto attr = F->getAttributes().getParamAttr(ii, ty); - NewF->addParamAttr(jj, attr); + for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (F->getAttributes().hasParamAttr(ii, attr)) { + NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr)); + for (auto ty : PrimalParamAttrsToPreserve) + if (F->getAttributes().hasParamAttr(ii, ty)) { + auto attr = 
F->getAttributes().getParamAttr(ii, ty); + NewF->addParamAttr(jj, attr); + } } if (constant_args[ii] == DIFFE_TYPE::CONSTANT) { if (!i->hasByValAttr()) @@ -2212,8 +2216,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( << " nonconstant arg " << *j << "\n"; } - // Always remove nonnull/noundef since the caller may choose to pass undef - // as an arg if provably it will not be used in the reverse pass + // Always remove nonnull/noundef since the caller may choose to pass + // undef as an arg if provably it will not be used in the reverse pass if (constant_args[ii] == DIFFE_TYPE::DUP_NONEED || mode == DerivativeMode::ReverseModeGradient) { if (F->hasParamAttribute(ii, Attribute::NonNull)) { @@ -2236,6 +2240,13 @@ Function *PreProcessCache::CloneFunctionWithReturns( NewF->addParamAttr(jj + 1, attr); } + for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (F->getAttributes().hasParamAttr(ii, attr)) { + if (width == 1) + NewF->addParamAttr(jj + 1, + F->getAttributes().getParamAttr(ii, attr)); + } + if (F->getAttributes().hasParamAttr(ii, "enzymejl_returnRoots")) { if (width == 1) { NewF->addParamAttr(jj + 1, F->getAttributes().getParamAttr( @@ -2247,7 +2258,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( #if LLVM_VERSION_MAJOR >= 13 // TODO // NewF->addParamAttr(jj + 1, - // F->getParamAttribute(ii, Attribute::ElementType)); + // F->getParamAttribute(ii, + // Attribute::ElementType)); #endif } @@ -2266,7 +2278,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( // jj + 1, // Attribute::get(F->getContext(), // Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, Attribute::StructRet) + // F->getParamAttribute(ii, + // Attribute::StructRet) // .getValueAsType())); #endif } else { @@ -2283,7 +2296,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( // jj + 1, // Attribute::get(F->getContext(), // Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, Attribute::StructRet) + // F->getParamAttribute(ii, + // 
Attribute::StructRet) // .getValueAsType())); #endif } @@ -3654,14 +3668,13 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } /* - // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1 != - c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2; j++) - if (auto c0 = dyn_cast(cur->getOperand(j))) - if (auto cmp0 = dyn_cast(c0->getOperand(0))) - if (auto c1 = dyn_cast(cur->getOperand(1-j))) - if (auto cmp1 = dyn_cast(c0->getOperand(0))) - if (cmp0->getPredicate() == ICmpInst::ICMP_EQ && - cmp1->getPredicate() == ICmpInst::ICMP_EQ) + // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1 + != c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2; + j++) if (auto c0 = dyn_cast(cur->getOperand(j))) if (auto cmp0 = + dyn_cast(c0->getOperand(0))) if (auto c1 = + dyn_cast(cur->getOperand(1-j))) if (auto cmp1 = + dyn_cast(c0->getOperand(0))) if (cmp0->getPredicate() == + ICmpInst::ICMP_EQ && cmp1->getPredicate() == ICmpInst::ICMP_EQ) { for (size_t i0 = 0; i0 < 2; i0++) for (size_t i1 = 0; i1 < 2; i1++) @@ -3694,7 +3707,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (auto C = dyn_cast(fcmp->getOperand(i))) { if (C->isZero()) { // (a1*a2*...an) == 0 -> (a1 == 0) || (a2 == 0) || ... (a2 == 0) - // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... (a2 == 0) + // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... 
(a2 == + // 0) // ] if (auto P = isProduct(fcmp->getOperand(1 - i))) { Value *res = nullptr; @@ -3885,8 +3899,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, cur->isExact()) if (auto C2 = dyn_cast(cur->getOperand(1))) if (auto mul = dyn_cast(cur->getOperand(0))) { - // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if C2 - // divides C1 + // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if + // C2 divides C1 if (mul->getOpcode() == Instruction::Mul) for (int i0 = 0; i0 < 2; i0++) if (auto C1 = dyn_cast(mul->getOperand(i0))) { @@ -3913,7 +3927,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, return "IMulDivConst"; } } - // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if C2 + // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if + // C2 if (mul->getOpcode() == Instruction::Add) for (int i0 = 0; i0 < 2; i0++) if (auto C1 = dyn_cast(mul->getOperand(i0))) { @@ -4123,8 +4138,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, // (a * b) != (c * b) -> (a != c) && b != 0 // auto S1 = SE.getSCEV(cur->getOperand(0)); // auto S2 = SE.getSCEV(cur->getOperand(1)); - // llvm::errs() <<" attempting push: " << *cur << " S1: " << *S1 << " S2: - // " << *S2 << " and " << *cur->getOperand(0) << " " << + // llvm::errs() <<" attempting push: " << *cur << " S1: " << *S1 << " + // S2: " << *S2 << " and " << *cur->getOperand(0) << " " << // *cur->getOperand(1) << "\n"; if (auto mul1 = dyn_cast(cur->getOperand(0))) if (auto mul2 = dyn_cast(cur->getOperand(1))) { @@ -4414,10 +4429,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, for (int j=0; j<2; j++) if (auto CI = dyn_cast(SI->getOperand(1+j))) if (CI->isZero()) { - auto tval = (j == 0) ? CI : pushcse(B.CreateMul(SI->getTrueValue(), - cur->getOperand(1-i), "tval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap())); - auto fval = (j == 1) ? 
CI : pushcse(B.CreateMul(SI->getFalseValue(), + auto tval = (j == 0) ? CI : + pushcse(B.CreateMul(SI->getTrueValue(), cur->getOperand(1-i), "tval." + + cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); auto + fval = (j == 1) ? CI : pushcse(B.CreateMul(SI->getFalseValue(), cur->getOperand(1-i), "fval." + cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); @@ -4488,9 +4503,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, (and1->getType()->isIntegerTy(1) && and2->getType()->isIntegerTy(1) && and1->getOpcode() == Instruction::And && and2->getOpcode() == Instruction::And) { bool done = false; for (int i1=0; i1<2; i1++) for (int - i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto c1 - = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x = - pushcse(B.CreateZExt(x, inst1->getType())); auto y = and2->getOperand(1-i2); + i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto + c1 = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x = + pushcse(B.CreateZExt(x, inst1->getType())); auto y = + and2->getOperand(1-i2); y = pushcse(B.CreateZExt(y, inst2->getType())); @@ -5181,8 +5197,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } - // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))), (sitofp - // b) + // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))), + // (sitofp b) if (cur->getOpcode() == Instruction::FMul && cur->isFast()) { for (int i = 0; i < 2; i++) @@ -5407,8 +5423,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } Value *sel = pushcse( - B.CreateSelect(condition, ConstantFP::get(cur->getType(), 0.0), - fmul, "mulcsi." + cur->getName())); + B.CreateSelect(condition, ConstantFP::get(cur->getType(), + 0.0), fmul, "mulcsi." 
+ cur->getName())); replaceAndErase(cur, sel); return "FMulSIToFPProp"; @@ -6535,8 +6551,8 @@ return true; auto div = ctx.SE.getUDivExpr(MinusX, Y); auto div_e = ctx.SE.getUDivExactExpr(MinusX, Y); - // in case of inexact division, check that these exactly equal for - // replacement + // in case of inexact division, check that these exactly equal + // for replacement if (div == div_e) { if (isEqual) { @@ -6807,8 +6823,8 @@ return true; if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) { return rhs->andB(shared_from_this(), ctx); } - // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) and - // (c or e)) + // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) + // and (c or e)) if (ty == Type::Union && rhs->ty == Type::Union) { if (*this == *rhs->notB(ctx)) { return Constraints::none(); From dd9f977bce5aef059a723d13f9bb2b22e363fd67 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 00:13:06 -0500 Subject: [PATCH 071/131] Handle constant of vector of i1 (#1704) --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index c0b54a4eccbb..9ff5ba6384aa 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -747,6 +747,9 @@ void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA, delete g2; int Off = (int)ai.getLimitedValue(); + if (auto VT = dyn_cast(Val->getType())) + if (VT->getElementType()->isIntegerTy(1)) + Off = i / 8; getConstantAnalysis(Op, TA, analysis); auto mid = analysis[Op]; From 82a6393901cb1110cfbc8e1879c6a8caaaa8224a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 15:25:08 -0500 Subject: [PATCH 072/131] Ensure alloca works well with minCut aliasing (#1705) --- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index f4eb07a01602..8f2e27b3140e 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -883,7 +883,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, continue; } // If an allocation call, we cannot cache any "capturing" users - if (isAllocationCall(V, TLI)) { + if (isAllocationCall(V, TLI) || isa(V)) { auto next = (*mp.begin()).V; bool noncapture = false; if (isa(next)) { From a36ff5c39b131ed393fd737e7dc6897cf8948d4e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 16:57:21 -0500 Subject: [PATCH 073/131] Strengthen capturing alloc check on jlcall (#1706) --- enzyme/Enzyme/GradientUtils.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index b1d8161576f2..d1dedbc70353 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -9304,6 +9304,16 @@ bool GradientUtils::needsCacheWholeAllocation( if (idx < CI->getNumArgOperands()) #endif { + + // Calling a non-empty function with a julia base object, this is fine. + // as GC will deal with any issues with. 
+ if (auto PT = dyn_cast(CI->getArgOperand(idx)->getType())) + if (PT->getAddressSpace() == 10) + if (EnzymeJuliaAddrLoad) + if (auto F = getFunctionFromCall(CI)) + if (!F->empty()) + continue; + if (isNoCapture(CI, idx)) continue; From 437e0526ebe856c43d613d71c28480efdfc5fed6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 18:09:28 -0500 Subject: [PATCH 074/131] BLAS fix gemm overwrite (#1707) --- enzyme/Enzyme/BlasDerivatives.td | 10 +-- .../test/Enzyme/ReverseMode/blas/gemm_f_c.ll | 8 +- .../blas/gemm_f_c_lacpy_runtime_act.ll | 6 +- .../Enzyme/ReverseMode/blas/gemm_f_c_loop.ll | 4 +- .../Enzyme/ReverseMode/blas/gemm_f_c_split.ll | 4 +- .../ReverseMode/blas/gemm_f_c_split_lacpy.ll | 4 +- .../blas/gemm_f_c_split_transpose_lacpy.ll | 4 +- .../blas/gemm_f_c_transpose_lacpy.ll | 8 +- .../ReverseMode/blas/gemm_f_change_ld.ll | 4 +- enzyme/test/Integration/ReverseMode/blas.cpp | 77 +++++++++++++++++++ 10 files changed, 103 insertions(+), 26 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 01894e87e453..f031aa7f191c 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -229,15 +229,15 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A [ /* alpha */ (Seq<["AB", "product", "m", "n"]> - (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n + (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $k, $m), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n (FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)), /* A */ (b<"gemm"> $layout, (Rows $transa, (Concat $transa, transpose<"transb">, $m, $k), (Concat $transb, $transa, $k, $m)), $n, $alpha, (Rows $transa, - (Concat adj<"C">, $B, 
(ld $B, $transb, $ldb, $k, $n)), - (Concat $B, (ld $B, $transb, $ldb, $k, $n), adj<"C">)), + (Concat adj<"C">, $B, (ld $B, $transb, $ldb, $n, $k)), + (Concat $B, (ld $B, $transb, $ldb, $n, $k), adj<"C">)), Constant<"1.0">, adj<"A">), /* B */ (b<"gemm"> $layout, (Rows $transb, @@ -245,8 +245,8 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A (Concat $transb, $transa, $n, $k)), $m, $alpha, (Rows $transb, - (Concat $A, (ld $A, $transa, $lda, $m, $k), adj<"C">), - (Concat adj<"C">, $A, (ld $A, $transa, $lda, $m, $k))), + (Concat $A, (ld $A, $transa, $lda, $k, $m), adj<"C">), + (Concat adj<"C">, $A, (ld $A, $transa, $lda, $k, $m))), Constant<"1.0">, adj<"B">), /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">), /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, Alloca<1>) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll index e84c12e14701..633feb70203a 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll @@ -212,12 +212,12 @@ entry: ; CHECK-NEXT: %[[r65:.+]] = icmp eq i8 %loaded.trans4, 78 ; CHECK-NEXT: %[[r66:.+]] = icmp eq i8 %loaded.trans4, 110 ; CHECK-NEXT: %[[r67:.+]] = or i1 %[[r66]], %[[r65]] -; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r67]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r67]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans5 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r69:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[r71:.+]] = or i1 %[[r70]], %[[r69]] -; CHECK-NEXT: %[[r72:.+]] = select i1 %[[r71]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r72:.+]] = select i1 %[[r71]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r73:.+]] = icmp eq i8 %ld.row.trans6, 110 ; CHECK-NEXT: %[[r74:.+]] = icmp 
eq i8 %ld.row.trans6, 78 @@ -251,12 +251,12 @@ entry: ; CHECK-NEXT: %[[r87:.+]] = icmp eq i8 %[[loaded_trans10]], 78 ; CHECK-NEXT: %[[r88:.+]] = icmp eq i8 %[[loaded_trans10]], 110 ; CHECK-NEXT: %[[r89:.+]] = or i1 %[[r88]], %[[r87]] -; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans11:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r91:.+]] = icmp eq i8 %[[loaded_trans11]], 78 ; CHECK-NEXT: %[[r92:.+]] = icmp eq i8 %[[loaded_trans11]], 110 ; CHECK-NEXT: %[[r93:.+]] = or i1 %[[r92]], %[[r91]] -; CHECK-NEXT: %[[r94:.+]] = select i1 %[[r93]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r94:.+]] = select i1 %[[r93]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans12:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r95:.+]] = icmp eq i8 %[[ld_row_trans12]], 110 ; CHECK-NEXT: %[[r96:.+]] = icmp eq i8 %[[ld_row_trans12]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index 8f1c4f3e566e..78fc4959f21e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -152,7 +152,7 @@ entry: ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans7, 78 ; CHECK-DAG: %[[i42:.+]] = icmp eq i8 %loaded.trans7, 110 ; CHECK-NEXT: %[[i43:.+]] = or i1 %[[i42]], %[[i41]] -; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %m_p, i8* %k_p ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p, i64 1, 
i64 1) @@ -268,12 +268,12 @@ entry: ; CHECK-NEXT: %[[r83:.+]] = icmp eq i8 %[[loaded_trans14]], 78 ; CHECK-NEXT: %[[r84:.+]] = icmp eq i8 %[[loaded_trans14]], 110 ; CHECK-NEXT: %[[r85:.+]] = or i1 %[[r84]], %[[r83]] -; CHECK-NEXT: %[[r86:.+]] = select i1 %[[r85]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r86:.+]] = select i1 %[[r85]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans15:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r87:.+]] = icmp eq i8 %[[loaded_trans15]], 78 ; CHECK-NEXT: %[[r88:.+]] = icmp eq i8 %[[loaded_trans15]], 110 ; CHECK-NEXT: %[[r89:.+]] = or i1 %[[r88]], %[[r87]] -; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans16:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r91:.+]] = icmp eq i8 %[[ld_row_trans16]], 110 ; CHECK-NEXT: %[[r92:.+]] = icmp eq i8 %[[ld_row_trans16]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll index c66d4ff88d63..606ecb1e1c5b 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll @@ -271,12 +271,12 @@ entry: ; CHECK-NEXT: %[[r62:.+]] = icmp eq i8 %loaded.trans30, 78 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %loaded.trans30, 110 ; CHECK-NEXT: %[[r64:.+]] = or i1 %[[r63]], %[[r62]] -; CHECK-NEXT: %[[r65:.+]] = select i1 %[[r64]], i8* %cast.k, i8* %[[r37]] +; CHECK-NEXT: %[[r65:.+]] = select i1 %[[r64]], i8* %[[r37]], i8* %cast.k ; CHECK-NEXT: %loaded.trans31 = load i8, i8* %byref.transa, align 1 ; CHECK-NEXT: %[[r66:.+]] = icmp eq i8 %loaded.trans31, 78 ; CHECK-NEXT: %[[r67:.+]] = icmp eq i8 %loaded.trans31, 110 ; CHECK-NEXT: %[[r68:.+]] = or i1 %[[r67]], %[[r66]] -; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r68]], i8* %cast.k, i8* %[[r37]] +; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r68]], i8* %[[r37]], i8* %cast.k ; CHECK-NEXT: %ld.row.trans32 = load i8, i8* 
%byref.transb, align 1 ; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %ld.row.trans32, 110 ; CHECK-NEXT: %[[r71:.+]] = icmp eq i8 %ld.row.trans32, 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll index 0e30a9fdc7da..48589c0dd732 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll @@ -265,12 +265,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[loaded_trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[loaded_trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[ld_row_trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[ld_row_trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index 64d862b240ff..132e5c14ba4d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -239,12 +239,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: 
%[[trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index f7164d4926e8..c3a424a956a3 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -237,12 +237,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[loaded_trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[loaded_trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[ld_row_trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[ld_row_trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index d0f9ed397b0d..ffc9f795f395 100644 --- 
a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -156,12 +156,12 @@ entry: ; CHECK-DAG: %[[i40:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[i42:.+]] = or i1 %[[i41]], %[[i40]] -; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans6 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a49:.+]] = icmp eq i8 %loaded.trans6, 78 ; CHECK-NEXT: %[[a50:.+]] = icmp eq i8 %loaded.trans6, 110 ; CHECK-NEXT: %[[a51:.+]] = or i1 %[[a50]], %[[a49]] -; CHECK-NEXT: %[[a52:.+]] = select i1 %[[a51]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[a52:.+]] = select i1 %[[a51]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans7 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a53:.+]] = icmp eq i8 %ld.row.trans7, 110 ; CHECK-NEXT: %[[a54:.+]] = icmp eq i8 %ld.row.trans7, 78 @@ -195,12 +195,12 @@ entry: ; CHECK-DAG: %[[i54:.+]] = icmp eq i8 %[[cachedtrans2]], 78 ; CHECK-DAG: %[[i55:.+]] = icmp eq i8 %[[cachedtrans2]], 110 ; CHECK-NEXT: %[[i56:.+]] = or i1 %[[i55]], %[[i54]] -; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans12:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a71:.+]] = icmp eq i8 %[[loaded_trans12]], 78 ; CHECK-NEXT: %[[a72:.+]] = icmp eq i8 %[[loaded_trans12]], 110 ; CHECK-NEXT: %[[a73:.+]] = or i1 %[[a72]], %[[a71]] -; CHECK-NEXT: %[[a74:.+]] = select i1 %[[a73]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[a74:.+]] = select i1 %[[a73]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans13:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a75:.+]] = icmp eq i8 %[[ld_row_trans13]], 110 ; CHECK-NEXT: %[[a76:.+]] = icmp eq i8 %[[ld_row_trans13]], 78 diff --git 
a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index 09189c1fe8d1..0c50fa797b3b 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -135,12 +135,12 @@ entry: ; CHECK-DAG: %[[r16:.+]] = icmp eq i8 %loaded.trans1, 78 ; CHECK-DAG: %[[r17:.+]] = icmp eq i8 %loaded.trans1, 110 ; CHECK-NEXT: %[[r18:.+]] = or i1 %[[r17]], %[[r16]] -; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans2 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a38:.+]] = icmp eq i8 %loaded.trans2, 78 ; CHECK-NEXT: %[[a39:.+]] = icmp eq i8 %loaded.trans2, 110 ; CHECK-NEXT: %[[a40:.+]] = or i1 %[[a39]], %[[a38]] -; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a40]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a40]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a42:.+]] = icmp eq i8 %ld.row.trans3, 110 ; CHECK-NEXT: %[[a43:.+]] = icmp eq i8 %ld.row.trans3, 78 diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index 1c01178b874f..7cc7438cc932 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -42,6 +42,11 @@ void my_dgemm(char layout, char transA, char transB, int M, int N, int K, double inDerivative = true; } +void ow_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { + cblas_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + inDerivative = true; +} + static void dotTests() { @@ -374,6 +379,78 @@ static void gemmTests() { // should be the same). 
checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + Test = "GEMM overwrite"; + + init(); + __enzyme_autodiff((void*) ow_dgemm, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, transB, + enzyme_const, M, + enzyme_const, N, + enzyme_const, K, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + assert(foundCalls.size() > 2); + auto A_cache = (double*)foundCalls[0].pout_arg1; + cblas_dlacpy(layout, '\0', (!transA_bool) ? M : K, (!transA_bool) ? K : M, A, lda, A_cache, (!transA_bool) ? M : K); + inputs[4] = BlasInfo(A_cache, layout, (!transA_bool) ? M : K, (!transA_bool) ? K : M, (!transA_bool) ? M : K); + auto B_cache = (double*)foundCalls[1].pout_arg1; + cblas_dlacpy(layout, '\0', (!transB_bool) ? K : N, (!transB_bool) ? N : K, B, incB, B_cache, (!transB_bool) ? K : N); + inputs[5] = BlasInfo(B_cache, layout, (!transB_bool) ? K : N, (!transB_bool) ? N : K, (!transB_bool) ? K : N); + + ow_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + // dA = + my_dgemm(layout, + transA_bool ? (char)transB : (char)CBLAS_TRANSPOSE::CblasNoTrans, + transA_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transB), + transA_bool ? K : M, + transA_bool ? M : K, + N, + alpha, + transA_bool ? B_cache : dC, + transA_bool ? ( (!transB_bool) ? K : N ) : incC, + transA_bool ? dC : B_cache, + transA_bool ? incC : ( (!transB_bool) ? K : N), + 1.0, dA, lda); + + // dB = + my_dgemm(layout, + transB_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transA), + transB_bool ? (char)transA : (char)CBLAS_TRANSPOSE::CblasNoTrans, //transB, + transB_bool ? N : K, + transB_bool ? K : N, + M, + alpha, + transB_bool ? dC : A_cache, + transB_bool ? incC : ( (!transA_bool) ? M : K), + transB_bool ? A_cache : dC, + transB_bool ? ( (!transA_bool) ? 
M : K) : incC, + 1.0, dB, incB); + + cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 ); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); } From e014a85f415189d097a4a5b2d8bbd49dcd5ffb8f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 20:07:53 -0500 Subject: [PATCH 075/131] BLAS fix vector mode (#1708) --- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index c0e725eaf948..e8535370f5b6 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -855,7 +855,7 @@ void emit_fwd_rewrite_rules(const TGPattern &pattern, raw_ostream &os) { if (ty == ArgType::fp) { const auto name = nameVec[inputType.first]; os << " Value *d_" << name - << " = llvm::ConstantFP::get(fpType, 0.0);\n"; + << " = Constant::getNullValue(gutils->getShadowType(fpType));\n"; } } From 82f8cdf207cdb7b847ef445a1061fb44b5362f26 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 21:32:52 -0500 Subject: [PATCH 076/131] BCLoad: Still ignore if cblas lowering required (#1709) * BCLoad: Still ignore if cblas lowering required * inform which were replaced --- enzyme/BCLoad/BCLoader.cpp | 46 ++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index 6aad07eb6739..02a8abb9223f 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -24,28 +24,30 @@ static inline bool endsWith(llvm::StringRef string, llvm::StringRef suffix) { #endif // LLVM_VERSION_MAJOR } -bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) 
{ +bool provideDefinitions(Module &M, std::set ignoreFunctions, + std::vector &replaced) { std::vector todo; bool seen32 = false; bool seen64 = false; for (auto &F : M) { if (!F.empty()) continue; + if (ignoreFunctions.count(F.getName().str())) + continue; int index = 0; for (auto postfix : {"", "_", "_64_"}) { std::string str; if (strlen(postfix) == 0) { str = F.getName().str(); - if (ignoreFunctions.count(str)) continue; } else if (endsWith(F.getName(), postfix)) { auto blasName = F.getName().substr(0, F.getName().size() - strlen(postfix)).str(); - if (ignoreFunctions.count(blasName)) continue; str = "cblas_" + blasName; } auto found = EnzymeBlasBC.find(str); if (found != EnzymeBlasBC.end()) { + replaced.push_back(F.getName().str()); todo.push_back(found->second); if (index == 1) seen32 = true; @@ -81,13 +83,23 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { }); #endif - if (!BC) + if (!BC) { Err.print("bcloader", llvm::errs()); + continue; + } assert(BC); SmallVector toReplace; for (auto &F : *BC) { if (F.empty()) continue; + if (ignoreFunctions.count(F.getName().str())) { +#if LLVM_VERSION_MAJOR >= 16 + F.erase(F.begin(), F.end()); +#else + F.getBasicBlockList().erase(F.begin(), F.end()); +#endif + continue; + } toReplace.push_back(F.getName().str()); } BC->setTargetTriple(""); @@ -106,12 +118,29 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { extern "C" { uint8_t EnzymeBitcodeReplacement(LLVMModuleRef M, char **FncsNamesToIgnore, - size_t numFncNames) { + size_t numFncNames, const char ***foundP, + size_t *foundLen) { std::set ignoreFunctions = {}; for (size_t i = 0; i < numFncNames; i++) { ignoreFunctions.insert(std::string(FncsNamesToIgnore[i])); } - return provideDefinitions(*unwrap(M), ignoreFunctions); + std::vector replaced; + auto res = provideDefinitions(*unwrap(M), ignoreFunctions, replaced); + + const char **found = nullptr; + if (replaced.size()) { + found = (const char **)malloc(replaced.size() * 
sizeof(const char **)); + for (size_t i = 0; i < replaced.size(); i++) { + char *data = (char *)malloc(replaced[i].size() + 1); + memcpy(data, replaced[i].data(), replaced[i].size()); + data[replaced[i].size()] = 0; + found[i] = data; + } + } + *foundP = found; + *foundLen = replaced.size(); + + return res; } } @@ -121,7 +150,10 @@ class BCLoader final : public ModulePass { static char ID; BCLoader() : ModulePass(ID) {} - bool runOnModule(Module &M) override { return provideDefinitions(M, {}); } + bool runOnModule(Module &M) override { + std::vector replaced; + return provideDefinitions(M, {}, replaced); + } }; } // namespace From 87a7b5dfef3543c0dbd2dbeb8e39f4f19357e291 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 12:20:10 -0500 Subject: [PATCH 077/131] BLAS fix erasure (#1711) --- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 30 ++++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index e8535370f5b6..39cb55384ad9 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -146,14 +146,22 @@ void emit_handleBLAS(ArrayRef blasPatterns, raw_ostream &os) { << " llvm::errs() << \" fallback?\\n\"; \n" << " return false; \n" << " } \n" - << " } \n" - << " \n" - << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" - << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" << " } else { \n" - << " eraseIfUnused(call); \n" - << " } \n" - << " \n" + << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" + << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" + << " } else { \n" + << " eraseIfUnused(call); \n" + << " } \n" + << " if (gutils->knownRecomputeHeuristic.find(&call) !=\n" + << " gutils->knownRecomputeHeuristic.end()) {\n" + << " if (!gutils->knownRecomputeHeuristic[&call]) {\n" + << " auto newCall = gutils->getNewFromOriginal(&call);\n" + << " 
llvm::IRBuilder<> BuilderZ(newCall);\n" + << " gutils->cacheForReverse(BuilderZ, newCall,\n" + << " getIndex(&call, CacheType::Self, BuilderZ));\n" + << " }\n" + << " }\n" + << " }\n" << " return result; \n" << "} \n"; } @@ -229,10 +237,16 @@ void emit_free_and_ending(const TGPattern &pattern, raw_ostream &os) { os << " }\n" << " }\n" + << " \n" + << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" + << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" + << " } else { \n" + << " eraseIfUnused(call); \n" + << " } \n" << " if (gutils->knownRecomputeHeuristic.find(&call) !=\n" << " gutils->knownRecomputeHeuristic.end()) {\n" << " if (!gutils->knownRecomputeHeuristic[&call]) {\n" - << " gutils->cacheForReverse(BuilderZ, newCall,\n" + << " auto cv = gutils->cacheForReverse(BuilderZ, newCall,\n" << " getIndex(&call, CacheType::Self, BuilderZ));\n" << " }\n" << " }\n" From 6ef7d1e134c708b8453f61a03ccd69a934b0f09e Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 12:56:06 -0500 Subject: [PATCH 078/131] Drop all fn references on bc (#1710) --- enzyme/BCLoad/BCLoader.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index 02a8abb9223f..286bfc2e421e 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -93,6 +93,7 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions, if (F.empty()) continue; if (ignoreFunctions.count(F.getName().str())) { + F.dropAllReferences(); #if LLVM_VERSION_MAJOR >= 16 F.erase(F.begin(), F.end()); #else From 4623cb1e530f27b20c87cffae8efcc66de4c2dd3 Mon Sep 17 00:00:00 2001 From: Max Aehle Date: Tue, 13 Feb 2024 18:57:05 +0100 Subject: [PATCH 079/131] Support more C math.h functions (#1605) * Add testcases for coshf, coshl in the forward and reverse mode. - The coshf tests pass. - In the coshl tests, the opt call crashes. 
* Define derivatives of sinhl, coshl, tanhl * Fix frexpl and add tests * Define derivatives of modf, modff, modfl * Add tests for modf, modff, modfl * Activity and type analysis for modf, modff, modfl * clang-format * Define derivatives of fdim, fdimf, fdiml * Add tests for fdim, fdimf, fdiml * Define derivatives for inverse hyperbolic functions asinh, asinhf, asinhl, acosh, acoshf, acoshl, atanh, atanhf, atanhl * Fix nearbyint, nearbyintf, nearbyintl * Define derivatives for erff, erfl, erfcf, erfcl * Add tests for inverse hyperbolic functions --- enzyme/Enzyme/ActivityAnalysis.cpp | 3 +- enzyme/Enzyme/InstructionDerivatives.td | 75 +++++++- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 3 + enzyme/test/Enzyme/ForwardMode/acosh.ll | 31 ++++ enzyme/test/Enzyme/ForwardMode/asinh.ll | 31 ++++ enzyme/test/Enzyme/ForwardMode/asinhf.ll | 31 ++++ enzyme/test/Enzyme/ForwardMode/asinhl.ll | 31 ++++ enzyme/test/Enzyme/ForwardMode/atanh.ll | 30 ++++ enzyme/test/Enzyme/ForwardMode/coshf.ll | 28 +++ enzyme/test/Enzyme/ForwardMode/coshl.ll | 28 +++ enzyme/test/Enzyme/ForwardMode/fdim.ll | 32 ++++ enzyme/test/Enzyme/ForwardMode/frexp.ll | 27 ++- enzyme/test/Enzyme/ForwardMode/modf.ll | 120 +++++++++++++ enzyme/test/Enzyme/ReverseMode/acosh.ll | 46 +++++ enzyme/test/Enzyme/ReverseMode/asinh.ll | 46 +++++ enzyme/test/Enzyme/ReverseMode/asinhf.ll | 46 +++++ enzyme/test/Enzyme/ReverseMode/asinhl.ll | 46 +++++ enzyme/test/Enzyme/ReverseMode/atanh.ll | 45 +++++ enzyme/test/Enzyme/ReverseMode/coshf.ll | 29 +++ enzyme/test/Enzyme/ReverseMode/coshl.ll | 29 +++ enzyme/test/Enzyme/ReverseMode/fdim.ll | 54 ++++++ enzyme/test/Enzyme/ReverseMode/frexp.ll | 28 ++- enzyme/test/Enzyme/ReverseMode/modf.ll | 189 ++++++++++++++++++++ 23 files changed, 1019 insertions(+), 9 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/acosh.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/asinh.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/asinhf.ll create mode 100644 
enzyme/test/Enzyme/ForwardMode/asinhl.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/atanh.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/coshf.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/coshl.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/fdim.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/modf.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/acosh.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/asinh.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/asinhf.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/asinhl.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/atanh.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/coshf.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/coshl.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/fdim.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/modf.ll diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 79a56d885507..3175e277a397 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -900,7 +900,8 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, if (isMemFreeLibMFunction(funcName)) { noActiveWrite = true; } else if (funcName == "frexp" || funcName == "frexpf" || - funcName == "frexpl") { + funcName == "frexpl" || funcName == "modf" || + funcName == "modff" || funcName == "modfl") { noActiveWrite = true; } } diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 58731be4a532..77a67fafe325 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -210,7 +210,9 @@ def AssertingInactiveArg : InactiveArgSpec { class GlobalExpr : Operation{ string value = val; } -def MantissaMaskOfReturn : GlobalExprisX86_FP80Ty()) {\n" " tsize = 80;\n" " high = tsize - 1;\n" -" low = high - 15;\n" +" low = high - 16;\n" + // x86_fp80 has only 15 exponent bits, but we must also + // retain the 
most-significant bit of the mantissa as + // there is no implicit leading 1. " } else if (ty->isFP128Ty()) {\n" " tsize = 128;\n" " high = tsize - 1;\n" @@ -317,13 +322,18 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; - def : CallPattern<(Op $x), ["tanhf"], [(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"coshf">), [ReadNone,NoUnwind]> $x):$c, $c))], (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["tanhl"], + [(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"coshl">), [ReadNone,NoUnwind]> $x):$c, $c))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), ["cosh"], @@ -337,6 +347,12 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["coshl"], + [(FMul (DiffeRet), (Call<(SameTypesFunc<"sinhl">), [ReadNone,NoUnwind]> $x))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), ["sinh"], @@ -350,6 +366,33 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["sinhl"], + [(FMul (DiffeRet), (Call<(SameTypesFunc<"coshl">), [ReadNone,NoUnwind]> $x))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["asinh", "asinhf", "asinhl", "__nv_asinh", "__nv_asinhf"], + [(FDiv (DiffeRet), (Intrinsic<"sqrt"> (FAdd (FMul $x, $x), (ConstantFP<"1.0"> $x))) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["acosh", "acoshf", "acoshl", "__nv_acosh", "__nv_acoshf"], + [(FDiv (DiffeRet), (Intrinsic<"sqrt"> (FSub (FMul $x, $x), (ConstantFP<"1.0"> $x))) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["atanh", "atanhf", "atanhl", "__nv_atanh", "__nv_atanhf"], + [(FDiv (DiffeRet), (FSub (ConstantFP<"1.0"> $x), (FMul $x, $x)) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), 
["exp10"], @@ -443,6 +486,16 @@ def : CallPattern<(Op $x, $y), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x, $integral_part_ptr), + ["modf", "modff", "modfl"], + [ + (DiffeRet), + (InactiveArg) + ], + (ForwardFromSummedReverse), + [ReadOnly, NoUnwind] + >; + def : CallPattern<(Op $x), ["__fd_sincos_1", "__fd_sincos_1f", "__fd_sincos_1l"], [ @@ -564,7 +617,7 @@ def : CallPattern<(Op $n, $x), >; def : CallPattern<(Op $x), - ["erf"], + ["erf","erff","erfl"], [ (FMul (DiffeRet), (FMul (ConstantFP<"1.1283791670955125738961589031215451716881012586580"> $x), (Intrinsic<"exp"> (FNeg (FMul $x, $x))))) ], @@ -580,7 +633,7 @@ def : CallPattern<(Op $x), [ReadNone, NoUnwind] >; def : CallPattern<(Op $x), - ["erfc"], + ["erfc","erfcf","erfcl"], [ (FMul (DiffeRet), (FMul (ConstantFP<"-1.1283791670955125738961589031215451716881012586580"> $x), (Intrinsic<"exp"> (FNeg (FMul $x, $x))))) ], @@ -722,7 +775,7 @@ def : CallPattern<(Op $x, $expout), (DiffeRet), (FMul (BitCast - (And (MantissaMaskOfReturn):$mask, (BitCast $x, (TypeOf $mask)) ), + (And (MantissaMaskOfReturnForFrexp):$mask, (BitCast $x, (TypeOf $mask)) ), (TypeOf $x) ), (ConstantFP<"2"> $x) @@ -839,6 +892,16 @@ def : IntrPattern<(Op $x, $y), (Select (FCmpOLT $x, $y), (SelectIfActive $y, (Shadow $y), (Zero $y)), (SelectIfActive $x, (Shadow $x), (Zero $x))) >; +def : CallPattern<(Op $x, $y), + ["fdim", "fdimf", "fdiml"], + [ + (Select (FCmpOLT $x, $y), (ConstantFP<"0"> $x), (DiffeRet)), + (Select (FCmpOLT $x, $y), (ConstantFP<"0"> $y), (FNeg (DiffeRet))) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def : IntrPattern<(Op $x), [["fabs"]], [ diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 9ff5ba6384aa..b7b900783f48 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -175,6 +175,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"trunc", Intrinsic::trunc}, {"round", Intrinsic::round}, 
{"rint", Intrinsic::rint}, + {"nearbyint", Intrinsic::nearbyint}, {"remainder", Intrinsic::not_intrinsic}, {"copysign", Intrinsic::copysign}, {"nextafter", Intrinsic::not_intrinsic}, @@ -5226,6 +5227,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { CONSIDER(frexpl) CONSIDER2(ldexp, double, double, int) CONSIDER2(modf, double, double, double *) + CONSIDER(modff) + CONSIDER(modfl) CONSIDER2(remquo, double, double, double, int *) CONSIDER(remquof) diff --git a/enzyme/test/Enzyme/ForwardMode/acosh.ll b/enzyme/test/Enzyme/ForwardMode/acosh.ll new file mode 100644 index 000000000000..90108c74de17 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/acosh.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @acosh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @acosh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fsub fast double %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %1) +; CHECK-NEXT: %3 = fdiv fast double %"x'", %2 +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinh.ll b/enzyme/test/Enzyme/ForwardMode/asinh.ll new file mode 100644 index 000000000000..73377c6edfd5 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinh.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @asinh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @asinh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fadd fast double %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %1) +; CHECK-NEXT: %3 = fdiv fast double %"x'", %2 +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinhf.ll b/enzyme/test/Enzyme/ForwardMode/asinhf.ll new file mode 100644 index 000000000000..edd9959120d9 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinhf.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @asinhf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_fwddiff(float (float)* nonnull @tester, float %x, float 1.0) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @asinhf(float) + +; Function Attrs: nounwind +declare float @__enzyme_fwddiff(float (float)*, ...) 
+ +; CHECK: define internal float @fwddiffetester(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast float %x, %x +; CHECK-NEXT: %1 = fadd fast float %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast float @llvm.sqrt.f32(float %1) +; CHECK-NEXT: %3 = fdiv fast float %"x'", %2 +; CHECK-NEXT: ret float %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinhl.ll b/enzyme/test/Enzyme/ForwardMode/asinhl.ll new file mode 100644 index 000000000000..80929fd665c2 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinhl.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @asinhl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_fwddiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x, x86_fp80 0xK3FFF8000000000000000) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @asinhl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_fwddiff(x86_fp80 (x86_fp80)*, ...) 
+ +; CHECK: define internal x86_fp80 @fwddiffetester(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast x86_fp80 %x, %x +; CHECK-NEXT: %1 = fadd fast x86_fp80 %0, 0xK3FFF8000000000000000 +; CHECK-NEXT: %2 = call fast x86_fp80 @llvm.sqrt.f80(x86_fp80 %1) +; CHECK-NEXT: %3 = fdiv fast x86_fp80 %"x'", %2 +; CHECK-NEXT: ret x86_fp80 %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/atanh.ll b/enzyme/test/Enzyme/ForwardMode/atanh.ll new file mode 100644 index 000000000000..358ab107e140 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/atanh.ll @@ -0,0 +1,30 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @atanh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @atanh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fsub fast double 1.000000e+00, %0 +; CHECK-NEXT: %2 = fdiv fast double %"x'", %1 +; CHECK-NEXT: ret double %2 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/coshf.ll b/enzyme/test/Enzyme/ForwardMode/coshf.ll new file mode 100644 index 000000000000..b94270380e52 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/coshf.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @coshf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_fwddiff(float (float)* nonnull @tester, float %x, float 1.0) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @coshf(float) + +; Function Attrs: nounwind +declare float @__enzyme_fwddiff(float (float)*, ...) 
+ +; CHECK: define internal float @fwddiffetester(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast float @sinhf(float %x) +; CHECK-NEXT: %1 = fmul fast float %"x'", %0 +; CHECK-NEXT: ret float %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/coshl.ll b/enzyme/test/Enzyme/ForwardMode/coshl.ll new file mode 100644 index 000000000000..4a6e3054f0fc --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/coshl.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @coshl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_fwddiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x, x86_fp80 0xK3FFF8000000000000000) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @coshl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_fwddiff(x86_fp80 (x86_fp80)*, ...) 
+ +; CHECK: define internal x86_fp80 @fwddiffetester(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast x86_fp80 @sinhl(x86_fp80 %x) +; CHECK-NEXT: %1 = fmul fast x86_fp80 %"x'", %0 +; CHECK-NEXT: ret x86_fp80 %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/fdim.ll b/enzyme/test/Enzyme/ForwardMode/fdim.ll new file mode 100644 index 000000000000..7d333686f7f4 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/fdim.ll @@ -0,0 +1,32 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s ; fi + +declare double @fdim(double, double) + +define double @tester(double %x, double %y) { +entry: + %0 = call double @fdim(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 10.0, double %y, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double, double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fcmp fast olt double %x, %y +; CHECK-NEXT: %1 = select fast i1 %0, double 0.000000e+00, double %"x'" +; CHECK-NEXT: %2 = fcmp fast olt double %x, %y +; CHECK-NEXT: %3 = fneg fast double %"y'" +; CHECK-NEXT: %4 = select fast i1 %2, double 0.000000e+00, double %3 +; CHECK-NEXT: %5 = fadd fast double %1, %4 +; CHECK-NEXT: ret double %5 +; CHECK-NEXT: } + + diff --git a/enzyme/test/Enzyme/ForwardMode/frexp.ll b/enzyme/test/Enzyme/ForwardMode/frexp.ll index 9e421ec21e3a..f9f0916a09b9 100644 --- a/enzyme/test/Enzyme/ForwardMode/frexp.ll +++ b/enzyme/test/Enzyme/ForwardMode/frexp.ll @@ -1,10 +1,11 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s -declare double @frexp(double, i32*) declare double @__enzyme_fwddiff(i8*, ...) declare float @__enzyme_fwddifff(i8*, ...) +declare x86_fp80 @__enzyme_fwddiffl(i8*, ...) +declare double @frexp(double, i32*) define double @test(double %x) { entry: %exp = alloca i32, align 4 @@ -32,6 +33,20 @@ entry: ret float %call } +declare x86_fp80 @frexpl(x86_fp80, i32*) +define x86_fp80 @testl(x86_fp80 %x) { +entry: + %exp = alloca i32, align 4 + %call = call x86_fp80 @frexpl(x86_fp80 %x, i32* %exp) + ret x86_fp80 %call +} + +define x86_fp80 @dtestl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) 
@__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} + ; CHECK: define internal double @fwddiffetest(double %x, double %"x'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = bitcast double %x to i64 @@ -51,3 +66,13 @@ entry: ; CHECK-NEXT: %4 = fdiv fast float %"x'", %3 ; CHECK-NEXT: ret float %4 ; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast x86_fp80 %x to i80 +; CHECK-NEXT: %1 = and i80 604453686435277732577280, %0 +; CHECK-NEXT: %2 = bitcast i80 %1 to x86_fp80 +; CHECK-NEXT: %3 = fmul fast x86_fp80 %2, 0xK40008000000000000000 +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %"x'", %3 +; CHECK-NEXT: ret x86_fp80 %4 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/modf.ll b/enzyme/test/Enzyme/ForwardMode/modf.ll new file mode 100644 index 000000000000..d36c4d659315 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/modf.ll @@ -0,0 +1,120 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +declare double @__enzyme_fwddiff(i8*, ...) +declare float @__enzyme_fwddifff(i8*, ...) +declare x86_fp80 @__enzyme_fwddiffl(i8*, ...) + +; double +declare double @modf(double, double*) +define double @testint(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + %ret = load double, double* %integral_part, align 8 + ret double %ret +} +define double @testfrac(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + ret double %fractional_part +} + +define double @dtestint(double %x, double %dx) { +entry: + %call = call double (i8*, ...) 
@__enzyme_fwddiff(i8* bitcast (double (double)* @testint to i8*), double %x, double %dx) + ret double %call +} +define double @dtestfrac(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @testfrac to i8*), double %x, double %dx) + ret double %call +} + +; float +declare float @modff(float, float*) +define float @testintf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + %ret = load float, float* %integral_part, align 4 + ret float %ret +} +define float @testfracf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + ret float %fractional_part +} + +define float @dtestintf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_fwddifff(i8* bitcast (float (float)* @testintf to i8*), float %x, float %dx) + ret float %call +} +define float @dtestfracf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_fwddifff(i8* bitcast (float (float)* @testfracf to i8*), float %x, float %dx) + ret float %call +} + +; x86_fp80 +declare x86_fp80 @modfl(x86_fp80, x86_fp80*) +define x86_fp80 @testintl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + %ret = load x86_fp80, x86_fp80* %integral_part, align 8 + ret x86_fp80 %ret +} +define x86_fp80 @testfracl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + ret x86_fp80 %fractional_part +} + +define x86_fp80 @dtestintl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) 
@__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testintl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} +define x86_fp80 @dtestfracl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testfracl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} + +; tests + +; CHECK: define internal double @fwddiffetestint(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret double 0.000000e+00 +; CHECK-NEXT: } + +; CHECK: define internal double @fwddiffetestfrac(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret double %"x'" +; CHECK-NEXT: } + +; CHECK: define internal float @fwddiffetestintf(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret float 0.000000e+00 +; CHECK-NEXT: } + +; CHECK: define internal float @fwddiffetestfracf(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret float %"x'" +; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestintl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret x86_fp80 0xK00000000000000000000 +; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestfracl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret x86_fp80 %"x'" +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/acosh.ll b/enzyme/test/Enzyme/ReverseMode/acosh.ll new file mode 100644 index 000000000000..b6ff83ea4f5f --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/acosh.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @acosh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) 
@__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @acosh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fsub fast double %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast double @llvm.sqrt.f64(double %2) +; CHECK-NEXT: %4 = fdiv fast double %0, %3 +; CHECK-NEXT: %5 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fadd fast double %5, %4 +; CHECK-NEXT: store double %6, double* %"x'de", align 8 +; CHECK-NEXT: %7 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %8 = insertvalue { double } undef, double %7, 0 +; CHECK-NEXT: ret { double } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinh.ll b/enzyme/test/Enzyme/ReverseMode/asinh.ll new file mode 100644 index 000000000000..3c34fabc25bf --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinh.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @asinh(double %x) + ret double %0 +} + +define double @test_derivative(double 
%x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @asinh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fadd fast double %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast double @llvm.sqrt.f64(double %2) +; CHECK-NEXT: %4 = fdiv fast double %0, %3 +; CHECK-NEXT: %5 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fadd fast double %5, %4 +; CHECK-NEXT: store double %6, double* %"x'de", align 8 +; CHECK-NEXT: %7 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %8 = insertvalue { double } undef, double %7, 0 +; CHECK-NEXT: ret { double } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinhf.ll b/enzyme/test/Enzyme/ReverseMode/asinhf.ll new file mode 100644 index 000000000000..c113c111abc1 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinhf.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @asinhf(float 
%x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_autodiff(float (float)* nonnull @tester, float %x) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @asinhf(float) + +; Function Attrs: nounwind +declare float @__enzyme_autodiff(float (float)*, ...) + +; CHECK: define internal { float } @diffetester(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"'de", align 4 +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store float %differeturn, float* %"'de", align 4 +; CHECK-NEXT: %0 = load float, float* %"'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"'de", align 4 +; CHECK-NEXT: %1 = fmul fast float %x, %x +; CHECK-NEXT: %2 = fadd fast float %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast float @llvm.sqrt.f32(float %2) +; CHECK-NEXT: %4 = fdiv fast float %0, %3 +; CHECK-NEXT: %5 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %6 = fadd fast float %5, %4 +; CHECK-NEXT: store float %6, float* %"x'de", align 4 +; CHECK-NEXT: %7 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %8 = insertvalue { float } undef, float %7, 0 +; CHECK-NEXT: ret { float } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinhl.ll b/enzyme/test/Enzyme/ReverseMode/asinhl.ll new file mode 100644 index 000000000000..759d05bf3236 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinhl.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call 
fast x86_fp80 @asinhl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_autodiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @asinhl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_autodiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal { x86_fp80 } @diffetester(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store x86_fp80 %differeturn, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"'de", align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %1 = fmul fast x86_fp80 %x, %x +; CHECK-NEXT: %2 = fadd fast x86_fp80 %1, 0xK3FFF8000000000000000 +; CHECK-NEXT: %3 = call fast x86_fp80 @llvm.sqrt.f80(x86_fp80 %2) +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %0, %3 +; CHECK-NEXT: %5 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %6 = fadd fast x86_fp80 %5, %4 +; CHECK-NEXT: store x86_fp80 %6, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %7 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %8 = insertvalue { x86_fp80 } undef, x86_fp80 %7, 0 +; CHECK-NEXT: ret { x86_fp80 } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/atanh.ll b/enzyme/test/Enzyme/ReverseMode/atanh.ll new file mode 100644 index 000000000000..708c4da03bf3 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/atanh.ll @@ -0,0 +1,45 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | 
FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @atanh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @atanh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fsub fast double 1.000000e+00, %1 +; CHECK-NEXT: %3 = fdiv fast double %0, %2 +; CHECK-NEXT: %4 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %5 = fadd fast double %4, %3 +; CHECK-NEXT: store double %5, double* %"x'de", align 8 +; CHECK-NEXT: %6 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0 +; CHECK-NEXT: ret { double } %7 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/coshf.ll b/enzyme/test/Enzyme/ReverseMode/coshf.ll new file mode 100644 index 000000000000..b028239a4ecf --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/coshf.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme 
-mem2reg -simplifycfg -instsimplify -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify)" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @coshf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_autodiff(float (float)* nonnull @tester, float %x) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @coshf(float) + +; Function Attrs: nounwind +declare float @__enzyme_autodiff(float (float)*, ...) + +; CHECK: define internal { float } @diffetester(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast float @sinhf(float %x) +; CHECK-NEXT: %1 = fmul fast float %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { float } undef, float %1, 0 +; CHECK-NEXT: ret { float } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/coshl.ll b/enzyme/test/Enzyme/ReverseMode/coshl.ll new file mode 100644 index 000000000000..dd8af143ff48 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/coshl.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify)" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @coshl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) 
@__enzyme_autodiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @coshl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_autodiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal { x86_fp80 } @diffetester(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast x86_fp80 @sinhl(x86_fp80 %x) +; CHECK-NEXT: %1 = fmul fast x86_fp80 %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { x86_fp80 } undef, x86_fp80 %1, 0 +; CHECK-NEXT: ret { x86_fp80 } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/fdim.ll b/enzyme/test/Enzyme/ReverseMode/fdim.ll new file mode 100644 index 000000000000..113532464547 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/fdim.ll @@ -0,0 +1,54 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s ; fi + +declare double @fdim(double, double) + +define double @tester(double %x, double %y) { +entry: + %0 = call double @fdim(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_autodiff(double (double, double)* nonnull @tester, double %x, double %y) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double, double)*, ...) 
+ +; CHECK: define internal { double, double } @diffetester(double %x, double %y, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: %"y'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"y'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fcmp fast olt double %x, %y +; CHECK-NEXT: %2 = select fast i1 %1, double 0.000000e+00, double %0 +; CHECK-NEXT: %3 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %4 = fadd fast double %3, %0 +; CHECK-NEXT: %5 = select fast i1 %1, double %3, double %4 +; CHECK-NEXT: store double %5, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fcmp fast olt double %x, %y +; CHECK-NEXT: %7 = fneg fast double %0 +; CHECK-NEXT: %8 = select fast i1 %6, double 0.000000e+00, double %7 +; CHECK-NEXT: %9 = load double, double* %"y'de", align 8 +; CHECK-NEXT: %10 = fadd fast double %9, %7 +; CHECK-NEXT: %11 = select fast i1 %6, double %9, double %10 +; CHECK-NEXT: store double %11, double* %"y'de", align 8 +; CHECK-NEXT: %12 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %13 = load double, double* %"y'de", align 8 +; CHECK-NEXT: %14 = insertvalue { double, double } undef, double %12, 0 +; CHECK-NEXT: %15 = insertvalue { double, double } %14, double %13, 1 +; CHECK-NEXT: ret { double, double } %15 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/frexp.ll b/enzyme/test/Enzyme/ReverseMode/frexp.ll index 2e74d6914105..fe134bdab00b 100644 --- a/enzyme/test/Enzyme/ReverseMode/frexp.ll +++ b/enzyme/test/Enzyme/ReverseMode/frexp.ll @@ -1,10 
+1,11 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -adce -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify,adce)" -S | FileCheck %s -declare double @frexp(double, i32*) declare double @__enzyme_autodiff(i8*, ...) declare float @__enzyme_autodifff(i8*, ...) +declare x86_fp80 @__enzyme_autodiffl(i8*, ...) +declare double @frexp(double, i32*) define double @test(double %x) { entry: %exp = alloca i32, align 4 @@ -32,6 +33,20 @@ entry: ret float %call } +declare x86_fp80 @frexpl(x86_fp80, i32*) +define x86_fp80 @testl(x86_fp80 %x) { +entry: + %exp = alloca i32, align 4 + %call = call x86_fp80 @frexpl(x86_fp80 %x, i32* %exp) + ret x86_fp80 %call +} + +define x86_fp80 @dtestl(x86_fp80 %x) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} + ; CHECK: define internal { double } @diffetest(double %x, double %differeturn) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = bitcast double %x to i64 @@ -53,3 +68,14 @@ entry: ; CHECK-NEXT: %5 = insertvalue { float } {{(undef|poison)}}, float %4, 0 ; CHECK-NEXT: ret { float } %5 ; CHECK-NEXT: } + +; CHECK: define internal { x86_fp80 } @diffetestl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast x86_fp80 %x to i80 +; CHECK-NEXT: %1 = and i80 604453686435277732577280, %0 +; CHECK-NEXT: %2 = bitcast i80 %1 to x86_fp80 +; CHECK-NEXT: %3 = fmul fast x86_fp80 %2, 0xK40008000000000000000 +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %differeturn, %3 +; CHECK-NEXT: %5 = insertvalue { x86_fp80 } {{(undef|poison)}}, x86_fp80 %4, 0 +; CHECK-NEXT: ret { x86_fp80 } %5 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/modf.ll b/enzyme/test/Enzyme/ReverseMode/modf.ll new file mode 100644 index 000000000000..5c601a39e340 --- /dev/null +++ 
b/enzyme/test/Enzyme/ReverseMode/modf.ll @@ -0,0 +1,189 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +declare double @__enzyme_autodiff(i8*, ...) +declare float @__enzyme_autodifff(i8*, ...) +declare x86_fp80 @__enzyme_autodiffl(i8*, ...) + +; double +declare double @modf(double, double*) +define double @testint(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + %ret = load double, double* %integral_part, align 8 + ret double %ret +} +define double @testfrac(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + ret double %fractional_part +} + +define double @dtestint(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double)* @testint to i8*), double %x) + ret double %call +} +define double @dtestfrac(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double)* @testfrac to i8*), double %x) + ret double %call +} + +; float +declare float @modff(float, float*) +define float @testintf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + %ret = load float, float* %integral_part, align 4 + ret float %ret +} +define float @testfracf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + ret float %fractional_part +} + +define float @dtestintf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) 
@__enzyme_autodifff(i8* bitcast (float (float)* @testintf to i8*), float %x) + ret float %call +} +define float @dtestfracf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_autodifff(i8* bitcast (float (float)* @testfracf to i8*), float %x) + ret float %call +} + +; x86_fp80 +declare x86_fp80 @modfl(x86_fp80, x86_fp80*) +define x86_fp80 @testintl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + %ret = load x86_fp80, x86_fp80* %integral_part, align 8 + ret x86_fp80 %ret +} +define x86_fp80 @testfracl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + ret x86_fp80 %fractional_part +} + +define x86_fp80 @dtestintl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testintl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} +define x86_fp80 @dtestfracl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) 
@__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testfracl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} + +; double tests + +; CHECK: define internal { double } @diffetestint(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %1 = insertvalue { double } undef, double %0, 0 +; CHECK-NEXT: ret { double } %1 +; CHECK-NEXT: } + +; CHECK: define internal { double } @diffetestfrac(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"fractional_part'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %1 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %0 +; CHECK-NEXT: store double %2, double* %"x'de", align 8 +; CHECK-NEXT: %3 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %4 = insertvalue { double } undef, double %3, 0 +; CHECK-NEXT: ret { double } %4 +; CHECK-NEXT: } + +; float tests + +; CHECK: define internal { float } @diffetestintf(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %1 = insertvalue { float } undef, float %0, 0 +; 
CHECK-NEXT: ret { float } %1 +; CHECK-NEXT: } + +; CHECK: define internal { float } @diffetestfracf(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store float %differeturn, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %0 = load float, float* %"fractional_part'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %1 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %2 = fadd fast float %1, %0 +; CHECK-NEXT: store float %2, float* %"x'de", align 4 +; CHECK-NEXT: %3 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %4 = insertvalue { float } undef, float %3, 0 +; CHECK-NEXT: ret { float } %4 +; CHECK-NEXT: } + +; x86_fp80 tests + +; CHECK: define internal { x86_fp80 } @diffetestintl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %1 = insertvalue { x86_fp80 } undef, x86_fp80 %0, 0 +; CHECK-NEXT: ret { x86_fp80 } %1 +; CHECK-NEXT: } + +; CHECK: define internal { x86_fp80 } @diffetestfracl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: 
+; CHECK-NEXT: store x86_fp80 %differeturn, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %1 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %2 = fadd fast x86_fp80 %1, %0 +; CHECK-NEXT: store x86_fp80 %2, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %3 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %4 = insertvalue { x86_fp80 } undef, x86_fp80 %3, 0 +; CHECK-NEXT: ret { x86_fp80 } %4 +; CHECK-NEXT: } From 55d67c3fad170fabef9d83d10ba534a89fd69d3d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 13:36:28 -0500 Subject: [PATCH 080/131] Act conservative if not affine (#1712) --- enzyme/Enzyme/FunctionUtils.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 6bc443276b4b..160ccc740f8a 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -7197,11 +7197,13 @@ getSparseConditions(bool &legal, Value *val, } } if (scope) - EmitFailure("NoSparsification", I->getDebugLoc(), I, - "F: ", *I->getParent()->getParent(), "\n", + EmitWarning("NoSparsification", *I, " No sparsification: not sparse solvable(icmp): ", *I, " via ", *sub1); - legal = false; + if (SparseDebug) { + llvm::errs() << " getSparse(icmp_dflt, " << *I + << ") = " << *defaultFloat << "\n"; + } return defaultFloat; } From c0c1070533767e11c3e97e3337724f309c60331f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 14:01:19 -0500 Subject: [PATCH 081/131] Change compile time or type analysis err into runtime (#1713) --- enzyme/Enzyme/InstructionDerivatives.td | 9 +++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 28 ++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td 
b/enzyme/Enzyme/InstructionDerivatives.td index 77a67fafe325..32b30bd8f9b4 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -823,6 +823,15 @@ def : CallPattern<(Op (Op $x, $y):$z), [ReadNone, NoUnwind] >; +def : CallPattern<(Op (Op $x, $y):$z), + ["cmplx_inv"], + [ + (CFDiv (CFNeg (DiffeRet)), (CFMul $z, $z)), + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def : IntrPattern<(Op $x), [["sin"]], [(FMul (DiffeRet), (Intrinsic<"cos"> $x))] , diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index b7b900783f48..851db780723d 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -157,6 +157,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"__fd_sincos_1", Intrinsic::not_intrinsic}, {"sincospi", Intrinsic::not_intrinsic}, + {"cmplx_inv", Intrinsic::not_intrinsic}, // bessel functions {"j0", Intrinsic::not_intrinsic}, @@ -2937,7 +2938,32 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T, // If ^ against 0b10000000000, the result is a float bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl); if (validXor) { - ((i == 0) ? RHS : LHS) |= TypeTree(FT).Only(-1, nullptr); + bool Legal = true; + ((i == 0) ? RHS : LHS) + .checkedOrIn(TypeTree(FT).Only(-1, nullptr), + /*pointerintsame*/ false, Legal); + + if (!Legal) { + std::string str; + raw_string_ostream ss(str); + if (!CustomErrorHandler) { + llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; + llvm::errs() << *fntypeinfo.Function << "\n"; + dump(ss); + } + ss << "Illegal updateBinop (xor up) Analysis " << *origin << "\n"; + ss << " (i=" << i << ") " << (i == 0 ? "RHS" : "LHS") << " " + << ((i == 0) ? 
RHS : LHS).str() << " FT from ret: " << *FT + << "\n"; + if (CustomErrorHandler) { + CustomErrorHandler(str.c_str(), wrap(origin), + ErrorType::IllegalTypeAnalysis, (void *)this, + wrap(origin), nullptr); + } + EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(), + origin, ss.str()); + report_fatal_error("Performed illegal updateAnalysis"); + } } } break; From adfc693a187438f58c010ef9c12c04909f1e475e Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 21:00:48 -0500 Subject: [PATCH 082/131] MLIR Make core mlir registration fn (#1715) --- enzyme/Enzyme/ActivityAnalysis.cpp | 6 ++++++ .../CoreDialectsAutoDiffImplementations.cpp | 14 ++++++++++++++ .../CoreDialectsAutoDiffImplementations.h | 2 ++ enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp | 2 ++ enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 11 +---------- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 3175e277a397..243e71dec04c 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1432,6 +1432,12 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { if (containsPointer && !isValuePotentiallyUsedAsPointer(Val)) { containsPointer = false; + if (auto Arg = dyn_cast<Argument>(Val)) { + assert(Arg->hasByValAttr()); + (void)Arg; + InsertConstantValue(TR, Val); + return true; + } } // We do this pointer dance here to ensure that any derived pointers from diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 857784472211..3bf48fc16b81 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -326,3 +326,17 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( return success(); } + +void 
mlir::enzyme::registerCoreDialectAutodiffInterfaces( + DialectRegistry &registry) { + enzyme::registerAffineDialectAutoDiffInterface(registry); + enzyme::registerArithDialectAutoDiffInterface(registry); + enzyme::registerBuiltinDialectAutoDiffInterface(registry); + enzyme::registerLLVMDialectAutoDiffInterface(registry); + enzyme::registerNVVMDialectAutoDiffInterface(registry); + enzyme::registerMathDialectAutoDiffInterface(registry); + enzyme::registerMemRefDialectAutoDiffInterface(registry); + enzyme::registerSCFDialectAutoDiffInterface(registry); + enzyme::registerCFDialectAutoDiffInterface(registry); + enzyme::registerLinalgDialectAutoDiffInterface(registry); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 5a1fbc136708..974b888599d5 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -187,5 +187,7 @@ void registerSCFDialectAutoDiffInterface(DialectRegistry &registry); void registerCFDialectAutoDiffInterface(DialectRegistry &registry); void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry); void registerMathDialectAutoDiffInterface(DialectRegistry &registry); + +void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry); } // namespace enzyme } // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 2b2760dfa8ff..ba953dbca5df 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -238,6 +238,8 @@ FunctionOpInterface CloneFunctionWithReturns( { auto &blk = NewF.getFunctionBody().front(); + assert(F.getFunctionBody().front().getNumArguments() == + constant_args.size()); for (ssize_t i = constant_args.size() - 1; i >= 0; i--) { mlir::Value oval = 
F.getFunctionBody().front().getArgument(i); if (constant_args[i] == DIFFE_TYPE::CONSTANT) diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 90038c1e3236..4a7b9231d1ee 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -97,16 +97,7 @@ int main(int argc, char **argv) { }); // Register the autodiff interface implementations for upstream dialects. - enzyme::registerAffineDialectAutoDiffInterface(registry); - enzyme::registerArithDialectAutoDiffInterface(registry); - enzyme::registerBuiltinDialectAutoDiffInterface(registry); - enzyme::registerLLVMDialectAutoDiffInterface(registry); - enzyme::registerNVVMDialectAutoDiffInterface(registry); - enzyme::registerMathDialectAutoDiffInterface(registry); - enzyme::registerMemRefDialectAutoDiffInterface(registry); - enzyme::registerSCFDialectAutoDiffInterface(registry); - enzyme::registerCFDialectAutoDiffInterface(registry); - enzyme::registerLinalgDialectAutoDiffInterface(registry); + enzyme::registerCoreDialectAutodiffInterfaces(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Enzyme modular optimizer driver", registry)); From 9384fe20caec02bd30f302e32f4f1c1f7ccb7d9d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 23:08:18 -0500 Subject: [PATCH 083/131] Fix enzymemlir visibility and return (#1716) --- enzyme/BUILD | 3 ++- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/enzyme/BUILD b/enzyme/BUILD index 0ad691e5aae0..9bf6fadabe9a 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -7,7 +7,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//visibility:public"], ) cc_library( @@ -29,6 +29,7 @@ cc_binary( "@llvm-project//llvm:TableGen", "@llvm-project//llvm:config", ], + visibility = ["//visibility:public"], ) gentbl( diff --git 
a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index d96c18a68c09..a5c4d45ced95 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -86,11 +86,14 @@ struct DifferentiateWrapperPass FunctionOpInterface newFunc = Logic.CreateForwardDiff( fn, retType, constants, TA, - /*should return*/ false, mode, freeMemory, width, + /*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory, + width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); if (outfn == "") { fn->erase(); + SymbolTable::setSymbolName(cast(newFunc), + (std::string)infn); } else { SymbolTable::setSymbolName(cast(newFunc), (std::string)outfn); From 3dcc55a4ca93a95e269433d0ada1829c4cb55ce7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 14 Feb 2024 20:27:28 -0500 Subject: [PATCH 084/131] MLIR fix broken test (#1718) --- enzyme/test/MLIR/ForwardMode/wrap.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/test/MLIR/ForwardMode/wrap.mlir b/enzyme/test/MLIR/ForwardMode/wrap.mlir index 7cef99c691f8..5ff0e1540d13 100644 --- a/enzyme/test/MLIR/ForwardMode/wrap.mlir +++ b/enzyme/test/MLIR/ForwardMode/wrap.mlir @@ -7,10 +7,10 @@ module { } } -// CHECK: func.func private @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK: func.func private @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : f64 // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 -// CHECK-NEXT: return %[[i2]] : f64 +// CHECK-NEXT: return %[[i3]], %[[i2]] : f64, f64 // CHECK-NEXT: } From eaa72381f5134f4f0f43cbc75fdbb2abfe4f9603 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 14 Feb 2024 20:36:16 -0500 Subject: [PATCH 085/131] MLIR change visibility for 
wrapper pass (#1719) --- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index a5c4d45ced95..327c3d254911 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -92,6 +92,7 @@ struct DifferentiateWrapperPass /*augmented*/ nullptr); if (outfn == "") { fn->erase(); + SymbolTable::setSymbolVisibility(newFunc, SymbolTable::Visibility::Private); SymbolTable::setSymbolName(cast(newFunc), (std::string)infn); } else { From b74b7e9e8f67198c2b96724a1d042714ff6b2277 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 14 Feb 2024 20:45:28 -0500 Subject: [PATCH 086/131] MLIR embarrassing bugfix (#1720) --- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 327c3d254911..e50306b4a32b 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -92,7 +92,8 @@ struct DifferentiateWrapperPass /*augmented*/ nullptr); if (outfn == "") { fn->erase(); - SymbolTable::setSymbolVisibility(newFunc, SymbolTable::Visibility::Private); + SymbolTable::setSymbolVisibility(newFunc, + SymbolTable::Visibility::Public); SymbolTable::setSymbolName(cast(newFunc), (std::string)infn); } else { From 742c8d035c7920f0c0ca6f9a72146b8f886a8228 Mon Sep 17 00:00:00 2001 From: "Ivan R. 
Ivanov" Date: Wed, 14 Feb 2024 22:26:54 -0800 Subject: [PATCH 087/131] Add full module truncation mode (#1717) * Add full module truncation mode * Fix optional header * Forgot test * Apply suggestions from code review Co-authored-by: Tim Gymnich * Change to command line option * Fix comp --------- Co-authored-by: Tim Gymnich --- enzyme/Enzyme/Enzyme.cpp | 86 ++++++++++++++++++- enzyme/Enzyme/EnzymeLogic.cpp | 28 +++--- enzyme/Enzyme/EnzymeLogic.h | 2 +- enzyme/test/Integration/Truncate/simple.cpp | 5 +- .../Integration/Truncate/truncate-all.cpp | 45 ++++++++++ 5 files changed, 147 insertions(+), 19 deletions(-) create mode 100644 enzyme/test/Integration/Truncate/truncate-all.cpp diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 99c429efa08e..055b6f394842 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -37,10 +37,9 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" +#include #if LLVM_VERSION_MAJOR <= 16 #include "llvm/ADT/Optional.h" -#else -#include #endif #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" @@ -113,6 +112,12 @@ llvm::cl::opt EnzymeAttributor("enzyme-attributor", cl::init(false), llvm::cl::opt EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden, cl::desc("Whether to enable openmp opt")); +llvm::cl::opt EnzymeTruncateAll( + "enzyme-truncate-all", cl::init(""), cl::Hidden, + cl::desc( + "Truncate all floating point operations. " + "E.g. 
\"64to32\" or \"64to-\".")); + #if LLVM_VERSION_MAJOR >= 14 #define addAttribute addAttributeAtIndex #define getAttribute getAttributeAtIndex @@ -2044,6 +2049,76 @@ class EnzymeBase { return status; } + bool handleFullModuleTrunc(Function &F) { + typedef std::vector> + TruncationsTy; + static TruncationsTy FullModuleTruncs = []() -> TruncationsTy { + StringRef ConfigStr(EnzymeTruncateAll); + auto Invalid = [=]() { + // TODO emit better diagnostic + llvm::errs() << "error: invalid format for truncation config\n"; + abort(); + }; + + // "64" or "11-52" + auto parseFloatRepr = [&]() -> std::optional { + unsigned Tmp = 0; + if (ConfigStr.consumeInteger(10, Tmp)) + return {}; + if (ConfigStr.consume_front("-")) { + unsigned Tmp2 = 0; + if (ConfigStr.consumeInteger(10, Tmp2)) + Invalid(); + return FloatRepresentation(Tmp, Tmp2); + } + return getDefaultFloatRepr(Tmp); + }; + + // Parse "64to32;32to16;5-10to4-9" + TruncationsTy Tmp; + while (true) { + auto From = parseFloatRepr(); + if (!From && !ConfigStr.empty()) + Invalid(); + if (!From) + break; + if (!ConfigStr.consume_front("to")) + Invalid(); + auto To = parseFloatRepr(); + if (!To) + Invalid(); + Tmp.push_back({*From, *To}); + ConfigStr.consume_front(";"); + } + return Tmp; + }(); + + if (FullModuleTruncs.empty()) + return false; + + // TODO sort truncations (64to32, then 32to16 will make everything 16) + for (auto Truncation : FullModuleTruncs) { + IRBuilder<> Builder(F.getContext()); + RequestContext context(&*F.getEntryBlock().begin(), &Builder); + Function *TruncatedFunc = + Logic.CreateTruncateFunc(context, &F, Truncation.first, + Truncation.second, TruncOpFullModuleMode); + + ValueToValueMapTy Mapping; + for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args())) + Mapping[&TArg] = &Arg; + + // Move the truncated body into the original function + F.deleteBody(); + F.getBasicBlockList().splice(F.begin(), + TruncatedFunc->getBasicBlockList()); + RemapFunction(F, Mapping, + RF_NoModuleLevelChanges | 
RF_IgnoreMissingLocals); + TruncatedFunc->deleteBody(); + } + return true; + } + bool lowerEnzymeCalls(Function &F, std::set &done) { if (done.count(&F)) return false; @@ -2052,6 +2127,9 @@ class EnzymeBase { if (F.empty()) return false; + if (handleFullModuleTrunc(F)) + return true; + bool Changed = false; for (BasicBlock &BB : F) @@ -2629,10 +2707,10 @@ class EnzymeBase { HandleBatch(call); } for (auto call : toTruncateFuncMem) { - HandleTruncateFunc(call, TruncMem); + HandleTruncateFunc(call, TruncMemMode); } for (auto call : toTruncateFuncOp) { - HandleTruncateFunc(call, TruncOp); + HandleTruncateFunc(call, TruncOpMode); } for (auto call : toTruncateValue) { HandleTruncateValue(call, true); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index ebdbb937733d..731826bb793f 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4969,7 +4969,7 @@ class TruncateGenerator : public llvm::InstVisitor { fromType = from.getBuiltinType(B.getContext()); toType = to.getType(B.getContext()); - if (mode == TruncMem) + if (mode == TruncMemMode) tmpBlock = B.CreateAlloca(fromType); else tmpBlock = nullptr; @@ -5002,9 +5002,10 @@ class TruncateGenerator : public llvm::InstVisitor { Value *truncate(IRBuilder<> &B, Value *v) { switch (mode) { - case TruncMem: + case TruncMemMode: return floatMemTruncate(B, v, tmpBlock, from, to); - case TruncOp: + case TruncOpMode: + case TruncOpFullModuleMode: return floatValTruncate(B, v, tmpBlock, from, to); default: llvm_unreachable("Unknown trunc mode"); @@ -5013,9 +5014,10 @@ class TruncateGenerator : public llvm::InstVisitor { Value *expand(IRBuilder<> &B, Value *v) { switch (mode) { - case TruncMem: + case TruncMemMode: return floatMemExpand(B, v, tmpBlock, from, to); - case TruncOp: + case TruncOpMode: + case TruncOpFullModuleMode: return floatValExpand(B, v, tmpBlock, from, to); default: llvm_unreachable("Unknown trunc mode"); @@ -5088,7 +5090,7 @@ class TruncateGenerator : public 
llvm::InstVisitor { } void visitSelectInst(llvm::SelectInst &SI) { switch (mode) { - case TruncMem: { + case TruncMemMode: { auto newI = getNewFromOriginal(&SI); IRBuilder<> B(newI); auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); @@ -5101,7 +5103,8 @@ class TruncateGenerator : public llvm::InstVisitor { newI->eraseFromParent(); return; } - case TruncOp: + case TruncOpMode: + case TruncOpFullModuleMode: return; default: llvm_unreachable(""); @@ -5347,9 +5350,11 @@ class TruncateGenerator : public llvm::InstVisitor { if (handleKnownCalls(CI, called, getFuncNameFromCall(&CI), newCall)) return; - RequestContext ctx(&CI, &BuilderZ); - auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); - newCall->setCalledOperand(val); + if (mode != TruncOpFullModuleMode) { + RequestContext ctx(&CI, &BuilderZ); + auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); + newCall->setCalledOperand(val); + } return; } void visitFPTruncInst(FPTruncInst &I) { return; } @@ -5426,7 +5431,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); std::string truncName = std::string("__enzyme_done_truncate_") + - (mode == TruncMem ? "mem" : "op") + "_func_" + + (mode == TruncMemMode ? 
"mem" : "op") + "_func_" + from.to_string() + "_" + to.to_string() + "_" + totrunc->getName().str(); Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, @@ -5463,6 +5468,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm_unreachable("attempting to truncate function without definition"); } + // TODO This is overloaded an doesnt do what it should do here if (from < to) { std::string s; llvm::raw_string_ostream ss(s); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 3923f205037e..1e9bf216b6e9 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -286,7 +286,7 @@ getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { } } -enum TruncateMode { TruncMem, TruncOp }; +enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; struct FloatRepresentation { // |_|__________|_________________| diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 792366d687e0..53cf859c84da 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -15,14 +15,13 @@ double simple_add(double a, double b) { double intrinsics(double a, double b) { return sqrt(a) * pow(b, 2); } -// TODO +// TODO trunc mem mode double constt(double a, double b) { return 2; } double compute(double *A, double *B, double *C, int n) { for (int i = 0; i < n; i++) { - C[i] = A[i] * 2; - // C[i] A=[i] * 2 + B[i] * sqrt(A[i]) ; + C[i] = A[i] * 2 + B[i] * sqrt(A[i]); } return C[0]; } diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp new file mode 100644 index 000000000000..ad5df438842f --- /dev/null +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -0,0 +1,45 @@ +// Baseline +// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli -)" 
== "900000000.560000" ] ; fi + +// Truncated +// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli -)" == "900000000.000000" ] ; fi +// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli -)" == "900000000.000000" ] ; fi + +#include + +#include "../test_utils.h" + +#define N 10 + +#define floatty double + + +__attribute__((noinline)) +floatty simple_add(floatty a, floatty b) { + return a + b; +} +__attribute__((noinline)) +floatty intrinsics(floatty a, floatty b) { + return sqrt(a) * pow(b, 2); +} +__attribute__((noinline)) +floatty compute(floatty *A, floatty *B, floatty *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] / 2 + intrinsics(A[i], simple_add(B[i] * 10000, 0.000001)); + } + return C[0]; +} + +int main() { + floatty A[N]; + floatty B[N]; + floatty C[N]; + + for (int i = 0; i < N; i++) { + A[i] = 1 + i % 5; + B[i] = 1 + i % 3; + } + + compute(A, B, C, N); + printf("%f\n", C[5]); +} From 1b0a46dd751f204e1bbef8cc0641e3a1ae27c74f Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 12:44:58 -0500 Subject: [PATCH 088/131] MLIR correctly preserve attributes for shadows (#1721) * MLIR correctly preserve attributes for shadows * fix --- enzyme/Enzyme/MLIR/Implementations/Common.td | 24 ++-- .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 105 +++++++++++++++++- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 4 +- 3 files changed, 118 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 1451d99f17ca..70f4c39f99d6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -62,17 +62,17 @@ def Op { class ArithInst : Inst; class MathInst : Inst; -def AddF : ArithInst<"arith::AddFOp">; -def SubF : ArithInst<"arith::SubFOp">; 
-def NegF : ArithInst<"arith::NegFOp">; -def MulF : ArithInst<"arith::MulFOp">; -def DivF : ArithInst<"arith::DivFOp">; -def RemF : ArithInst<"arith::RemFOp">; +def AddF : ArithInst<"AddFOp">; +def SubF : ArithInst<"SubFOp">; +def NegF : ArithInst<"NegFOp">; +def MulF : ArithInst<"MulFOp">; +def DivF : ArithInst<"DivFOp">; +def RemF : ArithInst<"RemFOp">; -def CheckedMulF : ArithInst<"arith::MulFOp">; -def CheckedDivF : ArithInst<"arith::DivFOp">; +def CheckedMulF : ArithInst<"MulFOp">; +def CheckedDivF : ArithInst<"DivFOp">; -def CosF : MathInst<"math::CosOp">; -def SinF : MathInst<"math::SinOp">; -def ExpF : MathInst<"math::ExpOp">; -def SqrtF : MathInst<"math::SqrtOp">; +def CosF : MathInst<"CosOp">; +def SinF : MathInst<"SinOp">; +def ExpF : MathInst<"ExpOp">; +def SqrtF : MathInst<"SqrtOp">; diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index ba953dbca5df..7e802b6a952e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -1,3 +1,5 @@ +#include "llvm/ADT/APSInt.h" + #include "CloneFunction.h" using namespace mlir; @@ -203,7 +205,7 @@ FunctionOpInterface CloneFunctionWithReturns( SmallPtrSetImpl &constants, SmallPtrSetImpl &nonconstants, SmallPtrSetImpl &returnvals, ReturnType returnValue, - DIFFE_TYPE ReturnType, Twine name, IRMapping &VMap, + DIFFE_TYPE DReturnType, Twine name, IRMapping &VMap, std::map &OpMap, bool diffeReturnArg, mlir::Type additionalArg) { assert(!F.getFunctionBody().empty()); @@ -211,7 +213,7 @@ FunctionOpInterface CloneFunctionWithReturns( // llvm::ValueToValueMapTy VMap; auto FTy = getFunctionTypeForClone( F.getFunctionType().cast(), mode, width, - additionalArg, constant_args, diffeReturnArg, returnValue, ReturnType); + additionalArg, constant_args, diffeReturnArg, returnValue, DReturnType); /* for (Block &BB : F.getFunctionBody().getBlocks()) { @@ -267,5 +269,104 @@ FunctionOpInterface 
CloneFunctionWithReturns( } } + std::string ToClone[] = { + "bufferization.writable", + "mhlo.sharding", + "mhlo.layout_mode", + "xla_framework.input_mapping", + "xla_framework.result_mapping", + }; + size_t newxlacnt = 0; + { + size_t oldi = 0; + size_t newi = 0; + while (oldi < F.getNumResults()) { + bool primalReturn = returnValue == ReturnType::ArgsWithReturn || + returnValue == ReturnType::ArgsWithTwoReturns || + (returnValue == ReturnType::TapeAndReturn && + DReturnType == DIFFE_TYPE::CONSTANT) || + returnValue == ReturnType::TapeAndTwoReturns || + returnValue == ReturnType::TwoReturns || + (returnValue == ReturnType::Return && + DReturnType == DIFFE_TYPE::CONSTANT); + if (primalReturn) { + for (auto attrName : ToClone) { + auto attrNameS = StringAttr::get(F->getContext(), attrName); + NewF.removeResultAttr(newi, attrNameS); + if (auto attr = F.getResultAttr(oldi, attrName)) { + if (attrName == "xla_framework.result_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setResultAttr(newi, attrNameS, attr); + } + } + newi++; + } + if (DReturnType == DIFFE_TYPE::DUP_ARG || + DReturnType == DIFFE_TYPE::DUP_NONEED) { + for (auto attrName : ToClone) { + auto attrNameS = StringAttr::get(F->getContext(), attrName); + NewF.removeResultAttr(newi, attrNameS); + if (auto attr = F.getResultAttr(oldi, attrName)) { + if (attrName == "xla_framework.result_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setResultAttr(newi, attrNameS, attr); + } + } + newi++; + } + oldi++; + } + } + { + size_t oldi = 0; + size_t newi = 0; + while (oldi < F.getNumArguments()) { + for (auto attrName : ToClone) { + NewF.removeArgAttr(newi, attrName); + if (auto attr = F.getArgAttr(oldi, attrName)) { + if (attrName == "xla_framework.input_mapping") { + auto iattr = cast(attr); + 
APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setArgAttr(newi, attrName, attr); + } + } + + newi++; + if (constant_args[oldi] == DIFFE_TYPE::DUP_ARG || + constant_args[oldi] == DIFFE_TYPE::DUP_NONEED) { + + for (auto attrName : ToClone) { + NewF.removeArgAttr(newi, attrName); + if (auto attr = F.getArgAttr(oldi, attrName)) { + if (attrName == "xla_framework.input_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setArgAttr(newi, attrName, attr); + } + } + newi++; + } + oldi++; + } + } + return NewF; } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 7eb015121a3e..b61cb6fa856a 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -855,7 +855,9 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, } else if (opName == "CheckedDiv") { os << "checkedDiv(" << builder << ", "; } else if (intrinsic == MLIRDerivatives) { - os << builder << ".create<" << opName << ">(op.getLoc(), "; + auto dialect = Def->getValueAsString("dialect"); + os << builder << ".create<" << dialect << "::" << opName + << ">(op.getLoc(), "; } else { os << builder << ".Create" << opName << "("; } From f808be8d07a6d35f877a08dedd8eb4ef4caf87ce Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 20:51:26 -0500 Subject: [PATCH 089/131] MLIR: handle multi-result functions (#1723) --- .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 30 ++++---- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 68 +++++++++++-------- enzyme/test/MLIR/ForwardMode/multiout.mlir | 24 +++++++ 3 files changed, 79 insertions(+), 43 deletions(-) create mode 100644 enzyme/test/MLIR/ForwardMode/multiout.mlir diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp 
b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 7e802b6a952e..9b5c007c62ee 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -19,22 +19,26 @@ mlir::FunctionType getFunctionTypeForClone( SmallVector RetTypes; if (returnValue == ReturnType::ArgsWithReturn || returnValue == ReturnType::Return) { - assert(FTy.getNumResults() == 1); - if (ReturnType != DIFFE_TYPE::CONSTANT && - ReturnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(FTy.getResult(0), width)); - } else { - RetTypes.push_back(FTy.getResult(0)); + assert(FTy.getNumResults() >= 1); + for (size_t i = 0; i < FTy.getNumResults(); i++) { + if (ReturnType != DIFFE_TYPE::CONSTANT && + ReturnType != DIFFE_TYPE::OUT_DIFF) { + RetTypes.push_back(getShadowType(FTy.getResult(i), width)); + } else { + RetTypes.push_back(FTy.getResult(i)); + } } } else if (returnValue == ReturnType::ArgsWithTwoReturns || returnValue == ReturnType::TwoReturns) { - assert(FTy.getNumResults() == 1); - RetTypes.push_back(FTy.getResult(0)); - if (ReturnType != DIFFE_TYPE::CONSTANT && - ReturnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(FTy.getResult(0), width)); - } else { - RetTypes.push_back(FTy.getResult(0)); + assert(FTy.getNumResults() >= 1); + for (size_t i = 0; i < FTy.getNumResults(); i++) { + RetTypes.push_back(FTy.getResult(i)); + if (ReturnType != DIFFE_TYPE::CONSTANT && + ReturnType != DIFFE_TYPE::OUT_DIFF) { + RetTypes.push_back(getShadowType(FTy.getResult(i), width)); + } else { + RetTypes.push_back(FTy.getResult(i)); + } } } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 1da6904893f8..064b6f181c95 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -46,44 +46,52 @@ void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, switch (retVal) { case ReturnType::Return: { - auto ret = 
inst->getOperand(0); - - mlir::Value toret; - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!isa(ret.getType()) && true /*type analysis*/) { - toret = gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - toret = gutils->invertPointerM(ret, nBuilder); - } else { - Type retTy = ret.getType().cast().getShadowType(); - toret = retTy.cast().createNullValue(nBuilder, - ret.getLoc()); + for (size_t i = 0; i < inst->getNumOperands(); i++) { + auto ret = inst->getOperand(i); + + mlir::Value toret; + if (retType == DIFFE_TYPE::CONSTANT) { + toret = gutils->getNewFromOriginal(ret); + } else if (!isa(ret.getType()) && + true /*type analysis*/) { + toret = gutils->invertPointerM(ret, nBuilder); + } else if (!gutils->isConstantValue(ret)) { + toret = gutils->invertPointerM(ret, nBuilder); + } else { + Type retTy = + ret.getType().cast().getShadowType(); + toret = retTy.cast().createNullValue( + nBuilder, ret.getLoc()); + } + retargs.push_back(toret); } - retargs.push_back(toret); break; } case ReturnType::TwoReturns: { if (retType == DIFFE_TYPE::CONSTANT) assert(false && "Invalid return type"); - auto ret = inst->getOperand(0); - - retargs.push_back(gutils->getNewFromOriginal(ret)); - - mlir::Value toret; - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!isa(ret.getType()) && true /*type analysis*/) { - toret = gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - toret = gutils->invertPointerM(ret, nBuilder); - } else { - Type retTy = ret.getType().cast().getShadowType(); - toret = retTy.cast().createNullValue(nBuilder, - ret.getLoc()); + for (size_t i = 0; i < inst->getNumOperands(); i++) { + auto ret = inst->getOperand(i); + + retargs.push_back(gutils->getNewFromOriginal(ret)); + + mlir::Value toret; + if (retType == DIFFE_TYPE::CONSTANT) { + toret = gutils->getNewFromOriginal(ret); + } else if (!isa(ret.getType()) 
&& + true /*type analysis*/) { + toret = gutils->invertPointerM(ret, nBuilder); + } else if (!gutils->isConstantValue(ret)) { + toret = gutils->invertPointerM(ret, nBuilder); + } else { + Type retTy = + ret.getType().cast().getShadowType(); + toret = retTy.cast().createNullValue( + nBuilder, ret.getLoc()); + } + retargs.push_back(toret); } - retargs.push_back(toret); break; } case ReturnType::Void: { diff --git a/enzyme/test/MLIR/ForwardMode/multiout.mlir b/enzyme/test/MLIR/ForwardMode/multiout.mlir new file mode 100644 index 000000000000..e42322a7f74f --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/multiout.mlir @@ -0,0 +1,24 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> (f64, f64) { + %y = arith.mulf %x, %x : f64 + return %y, %x : f64, f64 + } + func.func @dsq(%x : f64, %dx : f64) -> (f64, f64) { + %r:2 = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64, f64) + return %r#0, %r#1 : f64, f64 + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { +// CHECK-NEXT: %[[i0:[0-9]+]]:2 = call @fwddiffesquare(%[[arg0]], %[[arg1]]) : (f64, f64) -> (f64, f64) +// CHECK-NEXT: return %[[i0]]#0, %[[i0]]#1 : f64, f64 +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : f64 +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: return %[[i2]], %[[arg1]] : f64, f64 +// CHECK-NEXT: } From f463384c7f3ae601db25acdb9213e03f7a4daaba Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 21:17:15 -0500 Subject: [PATCH 090/131] MLIR better error for incorrect arg activity count (#1724) --- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 3 +++ enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 8 
+++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 064b6f181c95..edd21fbbf1ea 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -126,6 +126,9 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( llvm::errs() << fn << "\n"; llvm_unreachable("Differentiating empty function"); } + assert(fn.getFunctionBody().front().getNumArguments() == constants.size()); + assert(fn.getFunctionBody().front().getNumArguments() == + volatile_args.size()); MForwardCacheKey tup = { fn, retType, constants, diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index e50306b4a32b..1b37e03383a8 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -71,6 +71,13 @@ struct DifferentiateWrapperPass } } + if (constants.size() != fn.getFunctionBody().front().getNumArguments()) { + fn->emitError() + << "Incorrect number of arg activity states for function, found " + << split; + return; + } + DIFFE_TYPE retType = retTy.getValue(); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); @@ -108,7 +115,6 @@ struct DifferentiateWrapperPass namespace mlir { namespace enzyme { std::unique_ptr createDifferentiateWrapperPass() { - new DifferentiateWrapperPass(); return std::make_unique(); } } // namespace enzyme From 2c753c97fcb41623e9aca972edfc08202b23e04f Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 22:56:23 -0500 Subject: [PATCH 091/131] MLIR fix activity analysis assertion error (#1725) --- enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 349213ac2fbf..06b3257cd3a8 100644 --- 
a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -2598,12 +2598,13 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( // return true; // } Value callVal = call.getCallableForCallee().dyn_cast(); - if (isConstantValue(TR, callVal)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-constfn " - // << *inst << " - " << *callVal << "\n"; - return true; - } + if (callVal) + if (isConstantValue(TR, callVal)) { + // if (EnzymePrintActivity) + // llvm::errs() << "constant(" << (int)directions << ") up-constfn " + // << *inst << " - " << *callVal << "\n"; + return true; + } } if (auto gep = dyn_cast(op)) { From bf0261421bdbb78d61c0117aef4ca7dbd1f30575 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 23:58:48 -0500 Subject: [PATCH 092/131] MLIR constantfp (#1726) --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 29 ++++++++++++++------ 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index b61cb6fa856a..fd2ab67ee725 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -449,17 +449,30 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") + resultTree->getAsString()); - os << "ConstantFP::get("; - if (resultRoot->getArgName(0)) { + if (intrinsic == MLIRDerivatives) { + os << builder << ".create<" + << cast(Def->getValueInit("dialect"))->getValue() + << "::" << cast(Def->getValueInit("opName"))->getValue() + << ">(op.getLoc(), "; auto name = resultRoot->getArgName(0)->getAsUnquotedString(); auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); assert(!isVec); - os << ord; - } else - PrintFatalError(pattern->getLoc(), - Twine("unknown named operand in constantfp") + - 
resultTree->getAsString()); - os << "->getType(), \"" << value->getValue() << "\")"; + os << ord << ".getType(), getTensorAttr(" << ord << ".getType(), "; + os << "\"" << value->getValue() << "\"))"; + } else { + + os << "ConstantFP::get("; + if (resultRoot->getArgName(0)) { + auto name = resultRoot->getArgName(0)->getAsUnquotedString(); + auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); + assert(!isVec); + os << ord; + } else + PrintFatalError(pattern->getLoc(), + Twine("unknown named operand in constantfp") + + resultTree->getAsString()); + os << "->getType(), \"" << value->getValue() << "\")"; + } return false; } else if (opName == "Zero" || Def->isSubClassOf("Zero")) { if (resultRoot->getNumArgs() != 1) From 4ccab29dc691cb43d250a7c5ca612c3ff9cd23e3 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 16 Feb 2024 01:04:31 -0500 Subject: [PATCH 093/131] MLIR improve error handling --- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 7 +++++- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 3 ++- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 24 +++++++++++++++---- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index edd21fbbf1ea..3025b5963fcf 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -183,6 +183,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( unnecessaryInstructions, gutils, TLI); */ + bool valid = true; for (Block &oBB : gutils->oldFunc.getFunctionBody().getBlocks()) { // Don't create derivatives for code that results in termination if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) { @@ -205,7 +206,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto last = oBB.empty() ? oBB.end() : std::prev(oBB.end()); for (auto it = first; it != last; ++it) { // TODO: propagate errors. 
- (void)gutils->visitChild(&*it); + auto res = gutils->visitChild(&*it); + valid &= res.succeeded(); } createTerminator(gutils, &oBB, retType, returnValue); @@ -232,6 +234,9 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto nf = gutils->newFunc; delete gutils; + if (!valid) + return nullptr; + // if (PostOpt) // PPC.optimizeIntermediate(nf); // if (EnzymePrint) { diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 467a8f59ec6e..45473c75921f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -318,5 +318,6 @@ LogicalResult MGradientUtils::visitChild(Operation *op) { return iface.createForwardModeTangent(builder, this); } } - return op->emitError() << "could not compute the adjoint for this operation"; + return op->emitError() << "could not compute the adjoint for this operation " + << *op; } diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index efe4b7d53f76..1e2c55640280 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -35,7 +35,7 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; template - void HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) { + LogicalResult HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) { std::vector constants; SmallVector args; @@ -83,16 +83,20 @@ struct DifferentiatePass : public DifferentiatePassBase { /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); + if (!newFunc) + return failure(); OpBuilder builder(CI); auto dCI = builder.create(CI.getLoc(), newFunc.getName(), newFunc.getResultTypes(), args); CI.replaceAllUsesWith(dCI); CI->erase(); + return success(); } template - void HandleAutoDiffReverse(SymbolTableCollection &symbolTable, T CI) { 
+ LogicalResult HandleAutoDiffReverse(SymbolTableCollection &symbolTable, + T CI) { std::vector constants; SmallVector args; @@ -144,12 +148,15 @@ struct DifferentiatePass : public DifferentiatePassBase { /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr, symbolTable); + if (!newFunc) + return failure(); OpBuilder builder(CI); auto dCI = builder.create(CI.getLoc(), newFunc.getName(), newFunc.getResultTypes(), args); CI.replaceAllUsesWith(dCI); CI->erase(); + return success(); } void lowerEnzymeCalls(SymbolTableCollection &symbolTable, @@ -167,7 +174,11 @@ struct DifferentiatePass : public DifferentiatePassBase { for (auto T : toLower) { if (auto F = dyn_cast(T)) { - HandleAutoDiff(symbolTable, F); + auto res = HandleAutoDiff(symbolTable, F); + if (!res.succeeded()) { + signalPassFailure(); + return; + } } else { llvm_unreachable("Illegal type"); } @@ -187,7 +198,11 @@ struct DifferentiatePass : public DifferentiatePassBase { for (auto T : toLower) { if (auto F = dyn_cast(T)) { - HandleAutoDiffReverse(symbolTable, F); + auto res = HandleAutoDiffReverse(symbolTable, F); + if (!res.succeeded()) { + signalPassFailure(); + return; + } } else { llvm_unreachable("Illegal type"); } @@ -201,7 +216,6 @@ struct DifferentiatePass : public DifferentiatePassBase { namespace mlir { namespace enzyme { std::unique_ptr createDifferentiatePass() { - new DifferentiatePass(); return std::make_unique(); } } // namespace enzyme From dbfa740fa8c170f828a8396267eafff169779b42 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 16 Feb 2024 11:28:38 -0500 Subject: [PATCH 094/131] MLIR improve tblgen --- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 4 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 18 ++++++++++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 1b37e03383a8..b8d91f2c82e8 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -97,6 +97,10 @@ struct DifferentiateWrapperPass width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); + if (!newFunc) { + signalPassFailure(); + return; + } if (outfn == "") { fn->erase(); SymbolTable::setSymbolVisibility(newFunc, diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index fd2ab67ee725..1e1ac3697095 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -454,9 +454,15 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, << cast(Def->getValueInit("dialect"))->getValue() << "::" << cast(Def->getValueInit("opName"))->getValue() << ">(op.getLoc(), "; - auto name = resultRoot->getArgName(0)->getAsUnquotedString(); - auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); - assert(!isVec); + std::string ord; + if (resultRoot->getNumArgs() == 0) { + ord = "op->getResult(0)"; + } else { + auto name = resultRoot->getArgName(0)->getAsUnquotedString(); + auto [ord1, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); + assert(!isVec); + ord = ord1; + } os << ord << ".getType(), getTensorAttr(" << ord << ".getType(), "; os << "\"" << value->getValue() << "\"))"; } else { @@ -1636,7 +1642,11 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } } else { - os << " Value *res = "; + if (intrinsic == MLIRDerivatives) { + os << " mlir::Value res = 
nullptr;\n"; + } else { + os << " Value *res = "; + } ArrayRef retidx{}; bool vectorValued = handle(" ", "fwdnsrarg", os, pattern, duals, "Builder2", From d158f8350741d7a07e26512ecb319c6986e21e1e Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 16 Feb 2024 19:52:54 +0000 Subject: [PATCH 095/131] [bazel] export common derivative definitions --- enzyme/BUILD | 7 +++++++ enzyme/Enzyme/MLIR/Implementations/Common.td | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/enzyme/BUILD b/enzyme/BUILD index 9bf6fadabe9a..7c78c75a4051 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -389,6 +389,13 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) +td_library( + name = "ImplementationsCommonTdFiles", + srcs = [ + "Enzyme/MLIR/Implementations/Common.td", + ], +) + gentbl( name = "affine-derivatives", tbl_outs = [( diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 70f4c39f99d6..f558d1225a76 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -1,3 +1,6 @@ +#ifndef ENZYME_MLIR_IMPLEMENTATIONS_COMMON +#define ENZYME_MLIR_IMPLEMENTATIONS_COMMON + class InactiveOp { string dialect = dialect_; string opName = opName_; @@ -76,3 +79,5 @@ def CosF : MathInst<"CosOp">; def SinF : MathInst<"SinOp">; def ExpF : MathInst<"ExpOp">; def SqrtF : MathInst<"SqrtOp">; + +#endif // ENZYME_MLIR_IMPLEMENTATIONS_COMMON From 2eb83870cbd8b26fc969ea99e691609becb1fb91 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 16 Feb 2024 14:59:15 -0500 Subject: [PATCH 096/131] MLIR more error messages for interface internals --- .../CoreDialectsAutoDiffImplementations.cpp | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 3bf48fc16b81..c18a7bd0e921 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -136,7 +136,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( } } } - orig->emitWarning() + orig->emitError() << "Unsupported constant arg to memory identity forward " "handler(opidx=" << operand.getOperandNumber() << ", op=" << operand.get() << ")\n"; @@ -175,8 +175,9 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler( if (auto iface = dyn_cast(shadowRes.getType())) { return iface.zeroInPlace(builder, orig->getLoc(), shadowRes); } else { - orig->emitWarning() << "memref.alloc element type does not implement " - "AutoDiffTypeInterface"; + orig->emitError() << "Type " << shadowRes.getType() + << " does not implement " + "AutoDiffTypeInterface"; return failure(); } } @@ -235,16 +236,22 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( if (gutils->isConstantValue(result)) continue; auto typeIface = dyn_cast(result.getType()); - if (!typeIface) + if (!typeIface) { + op->emitError() << " AutoDiffTypeInterface not implemented for " + << result.getType() << "\n"; return failure(); + } newOpResultTypes.push_back(typeIface.getShadowType()); } // For all operands that are forwarded to the body, if they are active, also // add the shadow as operand. 
auto regionBranchOp = dyn_cast(op); - if (!regionBranchOp) + if (!regionBranchOp) { + op->emitError() << " RegionBranchOpInterface not implemented for " << *op + << "\n"; return failure(); + } SmallVector successors; // TODO: we may need to record, for every successor, which of its inputs @@ -283,8 +290,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( // yielded by terminators, and only those values. auto iface = dyn_cast(op); - if (!iface) + if (!iface) { + op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for " + << *op << "\n"; return failure(); + } Operation *replacement = iface.createWithShadows( builder, gutils, op, newOperands, newOpResultTypes); for (auto &&[region, replacementRegion] : @@ -314,8 +324,9 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( for (auto &origRegion : op->getRegions()) { for (auto &origBlock : origRegion) { for (Operation &o : origBlock) { - if (failed(gutils->visitChild(&o))) + if (failed(gutils->visitChild(&o))) { return failure(); + } } } } From 26ef37ea96d4f94a24409d5e59329f1255c0ac95 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 16 Feb 2024 17:57:24 -0500 Subject: [PATCH 097/131] Add smax, smin, umin, umax to type analysis (#1634) * Add s/u/min/max * Add opaque pointer in test * Update test and llvm version * Add s/u/min/max * Add opaque pointer in test * Update test * Pointer type case * Update test * Update test * Lost line * Break missing * Update logic * Update function * Update * Update test * Update test * gt llvm 11 * Add version check in test * Update test --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 29 ++++++++++++++++++++- enzyme/test/TypeAnalysis/smax.ll | 26 ++++++++++++++++++ enzyme/test/TypeAnalysis/smax0.ll | 23 ++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/TypeAnalysis/smax.ll create mode 100644 enzyme/test/TypeAnalysis/smax0.ll diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp 
b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 851db780723d..04d1822d4e23 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -3892,7 +3892,34 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { .Only(-1, &I), &I); return; - +#if LLVM_VERSION_MAJOR >= 12 + case Intrinsic::smax: + case Intrinsic::smin: + case Intrinsic::umax: + case Intrinsic::umin: + if (direction & UP) { + auto returnType = getAnalysis(&I)[{-1}]; + if (returnType == BaseType::Integer || returnType == BaseType::Pointer) { + updateAnalysis(I.getOperand(0), TypeTree(returnType).Only(-1, &I), &I); + updateAnalysis(I.getOperand(1), TypeTree(returnType).Only(-1, &I), &I); + } + } + if (direction & DOWN) { + auto opType0 = getAnalysis(I.getOperand(0))[{-1}]; + auto opType1 = getAnalysis(I.getOperand(1))[{-1}]; + if (opType0 == opType1 && + (opType0 == BaseType::Integer || opType0 == BaseType::Pointer)) { + updateAnalysis(&I, TypeTree(opType0).Only(-1, &I), &I); + } else if (opType0 == BaseType::Integer && + opType1 == BaseType::Anything) { + updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); + } else if (opType1 == BaseType::Integer && + opType0 == BaseType::Anything) { + updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); + } + } + return; +#endif case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: case Intrinsic::ssub_with_overflow: diff --git a/enzyme/test/TypeAnalysis/smax.ll b/enzyme/test/TypeAnalysis/smax.ll new file mode 100644 index 000000000000..b050ccc9ef18 --- /dev/null +++ b/enzyme/test/TypeAnalysis/smax.ll @@ -0,0 +1,26 @@ +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=smax -o /dev/null | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=smax -S -o /dev/null | FileCheck %s; fi + +define i32 @smax(i32 %a, i32 %b) { 
+entry: + %0 = call i32 @llvm.smax.i32(i32 %a, i32 %b) + %1 = call i32 @getint() + %2 = call i32 @getint() + %3 = call i32 @llvm.smax.i32(i32 %1, i32 %2) + ret i32 %3 +} + +declare i32 @llvm.smax.i32(i32, i32) + +declare i32 @getint() + + +; CHECK: smax - {[-1]:Integer} |{[-1]:Integer}:{} {[-1]:Integer}:{} +; CHECK-NEXT: i32 %a: {[-1]:Integer} +; CHECK-NEXT: i32 %b: {[-1]:Integer} +; CHECK-NEXT: entry +; CHECK-NEXT: %0 = call i32 @llvm.smax.i32(i32 %a, i32 %b): {[-1]:Integer} +; CHECK-NEXT: %1 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %2 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %3 = call i32 @llvm.smax.i32(i32 %1, i32 %2): {[-1]:Integer} +; CHECK-NEXT: ret i32 %3: {} diff --git a/enzyme/test/TypeAnalysis/smax0.ll b/enzyme/test/TypeAnalysis/smax0.ll new file mode 100644 index 000000000000..de79bcde70bb --- /dev/null +++ b/enzyme/test/TypeAnalysis/smax0.ll @@ -0,0 +1,23 @@ +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=smax0 -o /dev/null | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=smax0 -S -o /dev/null | FileCheck %s; fi + +define i32 @smax0(i32 %a, i32 %b) { +entry: + %0 = call i32 @llvm.smax.i32(i32 %a, i32 0) + %1 = call i32 @getint() + %2 = call i32 @llvm.smax.i32(i32 %1, i32 0) + ret i32 %2 +} + +declare i32 @llvm.smax.i32(i32, i32) + +declare i32 @getint() + +; CHECK: smax0 - {[-1]:Integer} |{[-1]:Integer}:{} {[-1]:Integer}:{} +; CHECK-NEXT: i32 %a: {[-1]:Integer} +; CHECK-NEXT: i32 %b: {[-1]:Integer} +; CHECK-NEXT: entry +; CHECK-NEXT: %0 = call i32 @llvm.smax.i32(i32 %a, i32 0): {[-1]:Integer} +; CHECK-NEXT: %1 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %2 = call i32 @llvm.smax.i32(i32 %1, i32 0): {[-1]:Integer} +; CHECK-NEXT: ret i32 %2: {} From 12030cd074a6a167807de8eff733ee9b67e2a42f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 16 Feb 2024 18:43:48 -0500 
Subject: [PATCH 098/131] Update opaque pointer test (#1732) --- enzyme/test/Enzyme/ReverseMode/mlirmincut.ll | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll index 6f214ddb2cdc..dc0e9599fe22 100644 --- a/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll +++ b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll @@ -1,4 +1,4 @@ -; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %OPnewLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi declare void @__enzyme_autodiff0(...) local_unnamed_addr From 616a88cfa03c36f61afa325f03582935b271d307 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 17 Feb 2024 21:32:51 -0500 Subject: [PATCH 099/131] MLIR add more general control flow handler interface (#1731) * MLIR add more general control flow handler interface * fixup act * more activ improvements --- .../Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 1029 +++++++++-------- .../Enzyme/MLIR/Analysis/ActivityAnalysis.h | 10 +- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 8 +- .../CoreDialectsAutoDiffImplementations.cpp | 80 +- .../CoreDialectsAutoDiffImplementations.h | 9 + .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 6 + .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 11 +- enzyme/Enzyme/MLIR/Passes/Passes.td | 7 + .../MLIR/Passes/PrintActivityAnalysis.cpp | 85 +- enzyme/test/MLIR/ActivityAnalysis/region.mlir | 27 + 10 files changed, 720 insertions(+), 552 deletions(-) create mode 100644 enzyme/test/MLIR/ActivityAnalysis/region.mlir diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 06b3257cd3a8..30fbfc8ef345 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ 
b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -248,6 +248,8 @@ static Operation *getFunctionFromCall(CallOpInterface iface) { return SymbolTable::lookupNearestSymbolFrom(iface.getOperation(), symbol); } +constexpr bool EnzymePrintActivity = true; + /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode bool mlir::enzyme::ActivityAnalyzer::isFunctionArgumentConstant( @@ -465,7 +467,7 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // The return instruction doesn't impact activity (handled specifically // during adjoint generation) - if (isa(I)) + if (I->hasTrait()) return true; if (auto ifaceOp = dyn_cast(I)) { @@ -485,9 +487,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, } if (notForAnalysis.count(I->getBlock())) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as dominates unreachable " << *I - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as dominates unreachable " << *I + << "\n"; InsertConstantOperation(TR, I); return true; } @@ -495,14 +497,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (auto CI = dyn_cast(I)) { // TODO(PR #904): This needs to be put into the enzyme dialect if (CI->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active " << *I << "\n"; ActiveOperations.insert(I); return false; } if (CI->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -510,14 +512,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (called) { if (called->hasAttr("enzyme_active")) { - 
// if (EnzymePrintActivity) - // llvm::errs() << "forced active " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active " << *I << "\n"; ActiveOperations.insert(I); return false; } if (called->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -683,11 +685,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // If all returned values constant otherwise, the operation is inactive if (llvm::all_of(I->getResults(), [&](Value v) { return isConstantValue(TR, v); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from known constant - // non-writing " - // "instruction " - // << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction from known constant non-writing " + "instruction " + << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -710,9 +711,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (llvm::all_of(I->getResults(), [&](Value val) { return isValueInactiveFromUsers(TR, val, UseActivity::None); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction[" << (int)directions - // << "] from users instruction " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction[" << (int)directions + << "] from users instruction " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -724,9 +725,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, return DownHypothesis->isValueInactiveFromUsers( TR, val, UseActivity::None); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction[" << (int)directions - // << "] from users instruction " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant 
instruction[" << (int)directions + << "] from users instruction " << *I << "\n"; InsertConstantOperation(TR, I); insertConstantsFrom(TR, *DownHypothesis); return true; @@ -748,65 +749,42 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, new mlir::enzyme::ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantOperations.insert(I); assert(directions & UP); - if (UpHypothesis->isOperationInactiveFromOrigin(TR, I)) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from origin " - // "instruction " - // << *I << "\n"; + SmallPtrSet toredo; + if (UpHypothesis->isOperationInactiveFromOrigin(TR, I, std::nullopt, + &toredo)) { + if (EnzymePrintActivity) + llvm::errs() << " constant instruction from origin " + "instruction " + << *I << "\n"; InsertConstantOperation(TR, I); insertConstantsFrom(TR, *UpHypothesis); if (DownHypothesis) insertConstantsFrom(TR, *DownHypothesis); return true; } else if (directions == (UP | DOWN)) { - // TODO: what does this mean for interfaces? - if (isa< - // clang-format off - LLVM::LoadOp, - LLVM::StoreOp, - // Integer binary ops. - LLVM::AddOp, - LLVM::SubOp, - LLVM::MulOp, - LLVM::UDivOp, - LLVM::SDivOp, - LLVM::URemOp, - LLVM::SRemOp, - LLVM::AndOp, - LLVM::OrOp, - LLVM::XOrOp, - LLVM::ShlOp, - LLVM::LShrOp, - LLVM::AShrOp, - // Float binary ops. - LLVM::FAddOp, - LLVM::FSubOp, - LLVM::FMulOp, - LLVM::FDivOp, - LLVM::FRemOp, - LLVM::FNegOp - // clang-format on - >(I)) { - for (Value operand : I->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - ReEvaluateOpIfInactiveValue[operand].insert(I); - } - } + for (Value operand : toredo) { + ReEvaluateOpIfInactiveValue[operand].insert(I); } } } // Otherwise we must fall back and assume this instruction to be active. 
ActiveOperations.insert(I); - // if (EnzymePrintActivity) - // llvm::errs() << "couldnt decide fallback as nonconstant instruction(" - // << (int)directions << "):" << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "couldnt decide fallback as nonconstant instruction(" + << (int)directions << "):" << *I << "\n"; if (noActiveWrite && (directions == (UP | DOWN))) for (Value result : I->getResults()) ReEvaluateOpIfInactiveValue[result].insert(I); return false; } +static bool isFunctionReturn(Operation *op) { + if (!op->hasTrait()) + return false; + return dyn_cast(op->getParentOp()); +} + static bool isValuePotentiallyUsedAsPointer(Value val) { std::deque todo = {val}; SmallPtrSet seen; @@ -817,7 +795,39 @@ static bool isValuePotentiallyUsedAsPointer(Value val) { continue; seen.insert(cur); for (Operation *user : cur.getUsers()) { - if (isa(user)) + if (isa(user->getParentOp())) + if (auto termIface = + dyn_cast(user)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + auto parentOp = termIface->getParentOp(); + for (auto &successor : successors) { + OperandRange operandRange = + termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[prev, post] : llvm::zip(operandRange, targetValues)) { + if (prev == cur) { + todo.push_back(post); + } + } + } + continue; + } + if (auto iface = dyn_cast(user)) { + for (auto &op : user->getOpOperands()) + if (op.get() == cur) + if (auto blk = + iface.getSuccessorBlockArgument(op.getOperandNumber())) + todo.push_back(*blk); + continue; + } + if (isFunctionReturn(user)) return true; // The operation is known not to read or write memory. 
if (isa(user) && @@ -828,10 +838,9 @@ static bool isValuePotentiallyUsedAsPointer(Value val) { } continue; } - // if (EnzymePrintActivity) - // llvm::errs() << " VALUE potentially used as pointer " << *val << " by - // " - // << *u << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " VALUE potentially used as pointer " << val << " by " + << *user << "\n"; return true; } } @@ -889,7 +898,134 @@ static FunctionOpInterface getFunctionIfArgument(Value value) { return dyn_cast(block->getParentOp()); } -// TODO: move the extraction based on dataflow here. +// For a given instruction, determine whether it is a terminator which +// controls dataflow out, and if so return all users either in results +// or blockarguments +static std::optional> +getPotentialTerminatorUsers(Operation *op, Value parent) { + auto block = op->getBlock(); + + if (block->getTerminator() != op) + return {}; + if (isFunctionReturn(op)) + return {}; + + SmallVector results; + + if (isa(op->getParentOp())) + if (auto termIface = dyn_cast(op)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + auto parentOp = termIface->getParentOp(); + SmallVector results; + for (auto &successor : successors) { + OperandRange operandRange = termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? 
parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[prev, post] : llvm::zip(operandRange, targetValues)) { + if (prev == parent) { + results.push_back(post); + } + } + } + return std::move(results); + } + if (auto iface = dyn_cast(op)) { + for (auto &operand : op->getOpOperands()) + if (operand.get() == parent) + if (auto blk = + iface.getSuccessorBlockArgument(operand.getOperandNumber())) { + results.push_back(*blk); + return std::move(results); + } + } + + // assume all terminator operands potentially flow into all op results + for (auto res : op->getParentOp()->getResults()) + results.push_back(res); + + // assume all terminator operands potentially flow into all blockArgs in + // region + for (auto &blk : *block->getParent()) + for (auto arg : blk.getArguments()) + results.push_back(arg); + + // assume all terminator operands potentially flow into all other region + // entries + for (auto ® : op->getParentOp()->getRegions()) + for (auto arg : reg.front().getArguments()) + results.push_back(arg); + + return std::move(results); +} + +// For a result of an op, find all values which could flow into this result +static SmallVector getPotentialIncomingValues(OpResult res) { + Operation *owner = res.getOwner(); + SmallVector potentialSources; + + auto resultNo = res.getResultNumber(); + + if (auto iface = dyn_cast(owner)) { + SmallVector successors; + iface.getSuccessorRegions(RegionBranchPoint::parent(), successors); + for (auto &succ : successors) { + if (!succ.isParent()) + continue; + auto successorOperands = + llvm::to_vector(iface.getEntrySuccessorOperands(succ)); + + if (successorOperands.size() != owner->getNumResults()) { + llvm::errs() << *owner << "\n"; + } + assert(successorOperands.size() == owner->getNumResults() && + "expected all results to be populated with incoming operands"); + + potentialSources.push_back(successorOperands[resultNo]); + } + } else { + // assume all inputs 
potentially flow into all op results + for (auto operand : owner->getOperands()) { + potentialSources.push_back(operand); + } + } + + for (Region ®ion : owner->getRegions()) { + for (Block &block : region) { + // TODO: MLIR blocks without terminator? + if (auto iface = dyn_cast( + block.getTerminator())) { + // TODO: the interface may also tell us which regions are allowed to + // yield parent op results, and which only branch to other regions. + auto successorOperands = llvm::to_vector( + iface.getSuccessorOperands(RegionBranchPoint::parent())); + // TODO: understand/document the assumption of how operands flow. + + if (successorOperands.size() != owner->getNumResults()) { + llvm::errs() << *owner << "\n"; + } + assert(successorOperands.size() == owner->getNumResults() && + "expected all results to be populated with yielded " + "terminator operands"); + potentialSources.push_back(successorOperands[resultNo]); + } else { + // assume all terminator operands potentially flow into op results + for (Value v : block.getTerminator()->getOperands()) + potentialSources.push_back(v); + } + } + } + + return potentialSources; +} + +// For a blockargument, find all non-operand values which could flow into +// this result static SmallVector getPotentialIncomingValues(BlockArgument arg) { SetVector potentialSources; @@ -986,6 +1122,10 @@ static SmallVector getPotentialIncomingValues(BlockArgument arg) { potentialSources.insert(v); } } + + // and also any operand to the parent + for (auto op : parent->getOperands()) + potentialSources.insert(op); } return potentialSources.takeVector(); @@ -1144,8 +1284,16 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } if (auto arg = Val.dyn_cast()) { - auto funcIface = dyn_cast_or_null( - arg.getParentBlock()->getParentOp()); + // All arguments must be marked constant/nonconstant ahead of time + if (auto funcIface = dyn_cast_or_null( + arg.getParentBlock()->getParentOp())) + if (funcIface && 
arg.getOwner()->isEntryBlock() && + !funcIface.getArgAttr(arg.getArgNumber(), + LLVM::LLVMDialect::getByValAttrName())) { + llvm::errs() << funcIface << "\n"; + llvm::errs() << Val << "\n"; + assert(0 && "must've put arguments in constant/nonconstant"); + } // if (!funcIface || !arg.getOwner()->isEntryBlock()) { // TODO: we want a more advanced analysis based on MLIR interfaces here // For now, conservatively assume all block arguments are active @@ -1174,25 +1322,15 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // } // } // } - - // All arguments must be marked constant/nonconstant ahead of time - if (funcIface && arg.getOwner()->isEntryBlock() && - !funcIface.getArgAttr(arg.getArgNumber(), - LLVM::LLVMDialect::getByValAttrName())) { - llvm::errs() << funcIface << "\n"; - llvm::errs() << Val << "\n"; - assert(0 && "must've put arguments in constant/nonconstant"); - } } // This value is certainly an integer (and only and integer, not a pointer or // float). 
Therefore its value is constant if (TR.intType(1, Val, /*errIfNotFound*/ false).isIntegral()) { - // if (EnzymePrintActivity) - // llvm::errs() << " Value const as integral " << (int)directions << " " - // << *Val << " " - // << TR.intType(1, Val, /*errIfNotFound*/ false).str() << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Value const as integral " << (int)directions << " " + << Val << " " + << TR.intType(1, Val, /*errIfNotFound*/ false).str() << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1368,28 +1506,28 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (auto CI = Val.getDefiningOp()) { if (CI->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active val " << Val << "\n"; ActiveValues.insert(Val); return false; } if (CI->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive val " << Val << "\n"; InsertConstantValue(TR, Val); return true; } Operation *called = getFunctionFromCall(CI); if (called) { if (called->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active val " << Val << "\n"; ActiveValues.insert(Val); return false; } if (called->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive val " << Val << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1440,14 +1578,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, LLVM::LLVMDialect::getByValAttrName())) { bool res = isConstantValue(TR, TmpOrig); if (res) { - // if (EnzymePrintActivity) - // llvm::errs() << " arg const from 
orig val=" << *Val - // << " orig=" << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " arg const from orig val=" << Val + << " orig=" << TmpOrig << "\n"; InsertConstantValue(TR, Val); } else { - // if (EnzymePrintActivity) - // llvm::errs() << " arg active from orig val=" << *Val - // << " orig=" << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " arg active from orig val=" << Val + << " orig=" << TmpOrig << "\n"; ActiveValues.insert(Val); } return res; @@ -1722,10 +1860,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // argument if (TmpOrig != Val) { if (isConstantValue(TR, TmpOrig)) { - // if (EnzymePrintActivity) - // llvm::errs() << " Potential Pointer(" << (int)directions << ") " - // << *Val << " inactive from inactive origin " - // << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Potential Pointer(" << (int)directions << ") " + << Val << " inactive from inactive origin " << TmpOrig + << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1737,11 +1875,17 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (!op || (!mayReadFromMemory(op) && !mayAllocateMemory(op))) { if (directions == UP && !Val.isa()) { if (isValueInactiveFromOrigin(TR, Val)) { + if (EnzymePrintActivity) + llvm::errs() << " Non-function value inactive from origin(" + << (int)directions << ") " << Val << "\n"; InsertConstantValue(TR, Val); return true; } } else { if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { + if (EnzymePrintActivity) + llvm::errs() << " Non-function value_v2 inactive from origin(" + << (int)directions << ") " << Val << "\n"; InsertConstantValue(TR, Val); insertConstantsFrom(TR, *UpHypothesis); return true; @@ -1755,9 +1899,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // can be loaded/stored cannot be assesed and therefore we default to assume // it to be active if (directions != (UP | DOWN)) { - // if 
(EnzymePrintActivity) - // llvm::errs() << " " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << Val << "\n"; ActiveValues.insert(Val); return false; } @@ -1785,13 +1929,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { Hypothesis->DeducingPointers.insert(Val); - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction hypothesis: " << *VI << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction hypothesis: " << Val << "\n"; } else { - // if (EnzymePrintActivity) - // llvm::errs() << " cannot show constant instruction hypothesis: " - // << *VI << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " cannot show constant instruction hypothesis: " + << Val << "\n"; } } @@ -1920,8 +2063,8 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If we haven't already shown a potentially active load // check if this loads the given value and is active if (!potentiallyActiveLoad && isRefSet(modRef)) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active load: " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active load: " << *op << "\n"; if (isa(op)) { // TODO: this assumption should be built into the MLIR interface // verifier, or alternatively we should relax it. 
@@ -1945,12 +2088,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, for (Operation *user : V.getUsers()) { if (mayWriteToMemory(user)) { if (!Hypothesis->isConstantOperation(TR, user)) { - // if (EnzymePrintActivity) - // llvm::errs() - // << "potential active store via " - // "pointer in load: " - // << *I << " of " << *Val << " via " << *U << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store via " + "pointer in load: " + << *op << " of " << Val << " via " + << *user << "\n"; potentiallyActiveStore = true; return true; } @@ -2031,10 +2173,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, !Hypothesis->isConstantValue(TR, V) && TR.query(V)[{-1}].isPossiblePointer(); })) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active store via pointer in " - // "unknown inst: " - // << *I << " of " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store via pointer in " + "unknown inst: " + << *op << " of " << Val << "\n"; potentiallyActiveStore = true; } } @@ -2042,25 +2184,25 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } } if ((!potentiallyActiveStore || !potentialStore) && isModSet(modRef)) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active store: " << *I << " Val=" << *Val - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store: " << *op << " Val=" << Val + << "\n"; if (auto SI = dyn_cast(op)) { bool cop = !Hypothesis->isConstantValue(TR, SI.getValue()); - // if (EnzymePrintActivity) - // llvm::errs() << " -- store potential activity: " << (int)cop - // << " - " << *SI << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- store potential activity: " << (int)cop + << " - " << *SI << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; } else if (auto SI = dyn_cast(op)) 
{ // FIXME: this is a copy-pasta form above to work with MLIR memrefs. bool cop = !Hypothesis->isConstantValue(TR, SI.getValueToStore()); - // if (EnzymePrintActivity) - // llvm::errs() << " -- store potential activity: " << (int)cop - // << " - " << *SI << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- store potential activity: " << (int)cop + << " - " << *SI << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; @@ -2077,11 +2219,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // TODO: note that this can be optimized (especially for function // calls) auto cop = !Hypothesis->isConstantOperation(TR, op); - // if (EnzymePrintActivity) - // llvm::errs() << " -- unknown store potential activity: " << - // (int)cop - // << " - " << *I << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- unknown store potential activity: " << (int)cop + << " - " << *op << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; @@ -2146,11 +2287,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } activeLoadAndStore:; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *Val - // << " potentiallyActiveLoad=" << potentiallyActiveLoad - // << " potentiallyActiveStore=" << potentiallyActiveStore - // << " potentialStore=" << potentialStore << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << Val + << " potentiallyActiveLoad=" << potentiallyActiveLoad + << " potentiallyActiveStore=" << potentiallyActiveStore + << " potentialStore=" << potentialStore << "\n"; if (potentiallyActiveLoad && potentiallyActiveStore) { insertAllFrom(TR, *Hypothesis, Val); // TODO have insertall dependence on this @@ -2212,15 +2353,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, new mlir::enzyme::ActivityAnalyzer(*DownHypothesis, 
DOWN)); DownHypothesis2->ConstantValues.insert(TmpOrig); if (DownHypothesis2->isValueActivelyStoredOrReturned(TR, TmpOrig)) { - // if (EnzymePrintActivity) - // llvm::errs() << " active from ivasor: " << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " active from ivasor: " << TmpOrig << "\n"; ActiveDown = true; } } else { // unknown origin that could've been stored/returned/etc - // if (EnzymePrintActivity) - // llvm::errs() << " active from unknown origin: " << *TmpOrig << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " active from unknown origin: " << TmpOrig << "\n"; ActiveDown = true; } } @@ -2257,13 +2397,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If we go to an active return and only load it, however, that doesnt // transfer derivatives and we can say this memory is inactive - // if (EnzymePrintActivity) - // llvm::errs() << " @@MEMSEARCH" << (int)directions << ">" << *Val - // << " potentiallyActiveLoad=" << potentiallyActiveLoad - // << " potentialStore=" << potentialStore - // << " ActiveUp=" << ActiveUp << " ActiveDown=" << - // ActiveDown - // << " ActiveMemory=" << ActiveMemory << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " @@MEMSEARCH" << (int)directions << ">" << Val + << " potentiallyActiveLoad=" << potentiallyActiveLoad + << " potentialStore=" << potentialStore + << " ActiveUp=" << ActiveUp << " ActiveDown=" << ActiveDown + << " ActiveMemory=" << ActiveMemory << "\n"; if (ActiveMemory) { ActiveValues.insert(Val); @@ -2293,39 +2432,21 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // this value is inactive, we are inactive Since we won't look at uses to // prove, we can inductively assume this is inactive if (directions & UP) { - if (directions == UP && !Val.isa()) { - if (isValueInactiveFromOrigin(TR, Val)) { - InsertConstantValue(TR, Val); - return true; - } else if (Operation *op = Val.getDefiningOp()) { - if (directions == (UP | DOWN)) 
{ - for (Value operand : op->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - for (Value result : op->getResults()) { - ReEvaluateValueIfInactiveValue[operand].insert(result); - } - } - } - } - } + UpHypothesis = std::shared_ptr( + new mlir::enzyme::ActivityAnalyzer(*this, UP)); + UpHypothesis->ConstantValues.insert(Val); + SmallPtrSet toredo; + if (UpHypothesis->isValueInactiveFromOrigin(TR, Val, &toredo)) { + insertConstantsFrom(TR, *UpHypothesis); + InsertConstantValue(TR, Val); + if (EnzymePrintActivity) + llvm::errs() << " Value constant from origin [" << (int)directions + << "]" << Val << "\n"; + return true; } else { - UpHypothesis = std::shared_ptr( - new mlir::enzyme::ActivityAnalyzer(*this, UP)); - UpHypothesis->ConstantValues.insert(Val); - if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } else if (Operation *op = Val.getDefiningOp()) { - if (directions == (UP | DOWN)) { - for (Value operand : op->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - for (Value result : op->getResults()) { - ReEvaluateValueIfInactiveValue[operand].insert(result); - } - } - } - } + for (Value result : toredo) { + if (result != Val) + ReEvaluateValueIfInactiveValue[result].insert(Val); } } } @@ -2335,164 +2456,54 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If all users are inactive, this is therefore inactive. // Since we won't look at origins to prove, we can inductively assume this // is inactive - - // As an optimization if we are going down already - // and we won't use ourselves (done by PHI's), we - // dont need to inductively assume we're true - // and can instead use this object! 
- if (directions == DOWN && !Val.isa()) { - if (isValueInactiveFromUsers(TR, Val, UseActivity::None)) { - if (UpHypothesis) - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } - } else { - auto DownHypothesis = std::shared_ptr( - new mlir::enzyme::ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(Val); - if (DownHypothesis->isValueInactiveFromUsers(TR, Val, - UseActivity::None)) { - insertConstantsFrom(TR, *DownHypothesis); - if (UpHypothesis) - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } + auto DownHypothesis = std::shared_ptr( + new mlir::enzyme::ActivityAnalyzer(*this, DOWN)); + DownHypothesis->ConstantValues.insert(Val); + if (DownHypothesis->isValueInactiveFromUsers(TR, Val, UseActivity::None)) { + insertConstantsFrom(TR, *DownHypothesis); + if (UpHypothesis) + insertConstantsFrom(TR, *UpHypothesis); + InsertConstantValue(TR, Val); + return true; } } - // if (EnzymePrintActivity) - // llvm::errs() << " Value nonconstant (couldn't disprove)[" << - // (int)directions - // << "]" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Value nonconstant (couldn't disprove)[" << (int)directions + << "]" << Val << "\n"; ActiveValues.insert(Val); return false; } /// Is the value guaranteed to be inactive because of how it's produced. bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromOrigin( - MTypeResults const &TR, Value val) { + MTypeResults const &TR, Value val, SmallPtrSetImpl *inactArg) { // Must be an analyzer only searching up assert(directions == UP); - // TODO: use getPotentialIncomingValues here to avoid duplciation. - if (auto arg = val.dyn_cast()) { - if (arg.getOwner()->isEntryBlock()) { - Operation *parent = arg.getOwner()->getParentOp(); - Region *parentRegion = arg.getOwner()->getParent(); - SetVector potentialSources; - // Use region interface to find the values flowing into the entry block. 
- if (auto iface = dyn_cast(parent)) { - auto isRegionSucessorOf = [arg](RegionBranchOpInterface iface, - Region *region, - RegionBranchPoint predecessor, - SetVector &potentialSources) { - SmallVector successors; - iface.getSuccessorRegions(predecessor, successors); - for (const RegionSuccessor &successor : successors) { - if (successor.getSuccessor() != region) - continue; - - unsigned operandOffset = static_cast(-1); - for (const auto &en : - llvm::enumerate(successor.getSuccessorInputs())) { - if (en.value() != arg) - continue; - operandOffset = en.index(); - } - assert(operandOffset != static_cast(-1) && - "could not locate the position of the argument in the " - "successor input list"); - - // Find the values that are forwarded to entry block arguments of - // the current region. - if (predecessor.isParent()) { - // XXX: this assumes a contiguous slice of operands is mapped 1-1 - // without swaps to a contiguous slice of entry block arguments. - assert(iface.getEntrySuccessorOperands(region).size() == - successor.getSuccessorInputs().size()); - potentialSources.insert( - iface.getEntrySuccessorOperands(region)[operandOffset]); - } else { - // Find all block terminators in the predecessor region that - // may be branching to this region, and get the operands they - // forward. - for (Block &block : *predecessor.getRegionOrNull()) { - // TODO: MLIR block without terminator - if (auto terminator = - dyn_cast( - block.getTerminator())) { - // XXX: this assumes a contiguous slice of operands is mapped - // 1-1 without swaps to a contiguous slice of entry block - // arguments. - assert(terminator.getSuccessorOperands(region).size() == - successor.getSuccessorInputs().size()); - potentialSources.insert( - terminator.getSuccessorOperands(region)[operandOffset]); - } else { - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); - } - } - } - } - }; - - // Find all possible source regions for the current region. 
- isRegionSucessorOf(iface, parentRegion, RegionBranchPoint::parent(), - potentialSources); - for (Region ®ion : parent->getRegions()) - isRegionSucessorOf(iface, parentRegion, region, potentialSources); - - } else { - // Conservatively assume any op operand and any terminator operand of - // any region can flow into any block argument. - for (Region ®ion : parent->getRegions()) { - for (Block &block : region) { - // TODO: MLIR blocks without terminator? - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); - } - } - } - - return llvm::all_of(potentialSources, [&](Value value) { - return isConstantValue(TR, value); - }); - } - - // Look at values flowing into block arguments. - for (Block *predecessor : arg.getOwner()->getPredecessors()) { - Operation *terminator = predecessor->getTerminator(); - if (auto iface = dyn_cast(terminator)) { - for (const auto &en : llvm::enumerate(predecessor->getSuccessors())) { - if (en.value() != arg.getOwner()) - continue; - - Value inflow = iface.getSuccessorOperands(en.index()) - .getForwardedOperands()[arg.getArgNumber()]; - if (!isConstantValue(TR, inflow)) - return false; - } - } else { - for (Value operand : terminator->getOperands()) { - if (!isConstantValue(TR, operand)) - return false; + for (auto v : getPotentialIncomingValues(arg)) { + if (!isConstantValue(TR, v)) { + if (EnzymePrintActivity) { + llvm::errs() << " blockarg: " << arg + << " may be active due to inflow from " << v << "\n"; } + if (inactArg) + inactArg->insert(v); + return false; } } - return true; } return isOperationInactiveFromOrigin(TR, val.getDefiningOp(), - val.cast().getResultNumber()); + val.cast().getResultNumber(), + inactArg); } bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( - MTypeResults const &TR, Operation *op, std::optional resultNo) { + MTypeResults const &TR, Operation *op, std::optional resultNo, + SmallPtrSetImpl *inactArg) { // Must be an analyzer only searching up assert(directions == 
UP); @@ -2510,19 +2521,22 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( } } - // if (EnzymePrintActivity) - // llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *op << "\n"; if (auto store = dyn_cast(op)) { if (isConstantValue(TR, store.getValue()) || isConstantValue(TR, store.getAddr())) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as store operand is inactive - // " - // << *inst << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as store operand is inactive" + << *op << "\n"; return true; } + if (inactArg) { + inactArg->insert(store.getValue()); + inactArg->insert(store.getAddr()); + } + return false; } if (isa(op)) { @@ -2530,11 +2544,15 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( // values and thus the store is inactive if (isConstantValue(TR, op->getOperand(0)) || isConstantValue(TR, op->getOperand(1))) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as memtransfer " << *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as memtransfer " << *op << "\n"; return true; } + if (inactArg) { + inactArg->insert(op->getOperand(0)); + inactArg->insert(op->getOperand(1)); + } + return false; } if (auto call = dyn_cast(op)) { @@ -2576,9 +2594,9 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( if (KnownInactiveFunctions.count(funcName.str()) || MPIInactiveCommAllocators.find(funcName.str()) != MPIInactiveCommAllocators.end()) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions - // << ") up-knowninactivecall " << *inst << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions + << ") up-knowninactivecall " << *op << "\n"; return true; } @@ -2600,9 +2618,9 @@ bool 
mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( Value callVal = call.getCallableForCallee().dyn_cast(); if (callVal) if (isConstantValue(TR, callVal)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-constfn " - // << *inst << " - " << *callVal << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-constfn " + << *op << " - " << callVal << "\n"; return true; } } @@ -2610,12 +2628,14 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( if (auto gep = dyn_cast(op)) { // A gep's only args that could make it active is the pointer operand if (isConstantValue(TR, gep.getBase())) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-gep " << - // *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-gep " << *op + << "\n"; return true; } + if (inactArg) { + inactArg->insert(gep.getBase()); + } return false; } @@ -2676,97 +2696,79 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( if (isConstantValue(TR, si.getTrueValue()) && isConstantValue(TR, si.getFalseValue())) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-sel:" << - // *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-sel:" << *op + << "\n"; return true; } + if (inactArg) { + inactArg->insert(si.getTrueValue()); + inactArg->insert(si.getFalseValue()); + } return false; } - { - bool seenuse = false; - //! TODO does not consider reading from global memory that is active and not - //! 
an argument + if (!resultNo) { for (Value a : op->getOperands()) { bool hypval = isConstantValue(TR, a); if (!hypval) { - // if (EnzymePrintActivity) - // llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " - // << *inst << " op " << *a << "\n"; - seenuse = true; - break; - } - } - if (!resultNo) { - // Conservatively check all top-level operations nested in the region, - // there is recursion there. - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - // XXX: We think that we don't need to check block arguments here - // because values flow into them either from operands of the parent op - // or from the op itself. - if (llvm::any_of(block, [&](Operation &nested) { - // No need to check the results, even if they may be active, - // because in absence of resultNo, we are checking for the - // entire op being inactive not individual values. - // - // // The loop _operation_ is inactive, but the result is, just - // // like the GEP inside it. - // %r = scf.for %i.. { - // // The GEP operation is not active, but the result is. - // %active_r = llvm.gep ... %active_operand - // scf.yield %active_r - // } - return !isConstantOperation(TR, &nested); - })) { - seenuse = true; - break; - } + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " + << *op << " op " << a << "\n"; + if (inactArg) { + inactArg->insert(a); } - if (seenuse) - break; + return false; } - } else { - SetVector potentialSources; - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - // TODO: MLIR blocks without terminator? - if (auto iface = dyn_cast( - block.getTerminator())) { - // TODO: the interface may also tell us which regions are allowed to - // yield parent op results, and which only branch to other regions. - auto successorOperands = llvm::to_vector( - iface.getSuccessorOperands(RegionBranchPoint::parent())); - // TODO: understand/document the assumption of how operands flow. 
- assert(successorOperands.size() == op->getNumResults() && - "expected all results to be populated with yielded " - "terminator operands"); - potentialSources.insert(successorOperands[*resultNo]); - } else { - // assume all terminator operands potentially flow into op results - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); + } + // Conservatively check all top-level operations nested in the region, + // there is recursion there. + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + // XXX: We think that we don't need to check block arguments here + // because values flow into them either from operands of the parent op + // or from the op itself. + for (Operation &nested : block) { + // No need to check the results, even if they may be active, + // because in absence of resultNo, we are checking for the + // entire op being inactive not individual values. + // + // // The loop _operation_ is inactive, but the result is, just + // // like the GEP inside it. + // %r = scf.for %i.. { + // // The GEP operation is not active, but the result is. + // %active_r = llvm.gep ... %active_operand + // scf.yield %active_r + // } + if (!isConstantOperation(TR, &nested)) { + // TODO set inactArg here, except with constant operand. 
+ // assert(!inactArg); + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions + << ") up-inst-op " << *op << " sub-op " << nested + << "\n"; + return false; } } } - if (llvm::any_of(potentialSources, [&](Value value) { - return !isConstantValue(TR, value); - })) { - seenuse = true; - } } - - if (!seenuse) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-inst:" << - // *inst - // << "\n"; - return true; + } else { + for (auto value : getPotentialIncomingValues(op->getResult(*resultNo))) { + if (!isConstantValue(TR, value)) { + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " + << *op << " value " << value << "\n"; + if (inactArg) + inactArg->insert(value); + return false; + } } - return false; } + + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-inst:" << *op + << "\n"; + return true; } /// Is the value free of any active uses @@ -2778,9 +2780,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // To ensure we can call down - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << " UA=" << (int)PUA << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << " UA=" << (int)PUA << "\n"; bool seenuse = false; // user, predecessor @@ -2834,9 +2836,8 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( } } - // if (EnzymePrintActivity) - // llvm::errs() << " considering use of " << *val << " - " << *a - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " considering use of " << val << " - " << *a << "\n"; // Only ignore stores to the operand, not storing the operand // somewhere @@ -2915,25 +2916,23 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // vtodo.push_back(TmpOrig_2); // continue; // } - // if (EnzymePrintActivity) - // llvm::errs() << " -- cannot continuing indirect store from - // " - // << *val << " due to " << *TmpOrig << "\n"; + if 
(EnzymePrintActivity) + llvm::errs() << " -- cannot continuing indirect store from" + << val << " due to " << TmpOrig << "\n"; shouldContinue = false; break; } if (shouldContinue) { - // if (EnzymePrintActivity) - // llvm::errs() << " -- continuing indirect store from " << - // *val - // << " into:\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- continuing indirect store from " << val + << " into:\n"; done.insert(std::make_tuple(SI.getOperation(), SI.getValue(), UA)); for (Value TmpOrig : newAllocaSet) { for (Operation *a : TmpOrig.getUsers()) { todo.push_back(std::make_tuple(a, TmpOrig, UA)); - // if (EnzymePrintActivity) - // llvm::errs() << " ** " << *a << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " ** " << *a << "\n"; } AllocaSet.insert(TmpOrig); shouldContinue = true; @@ -2993,10 +2992,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( break; } if (shouldContinue) { - // if (EnzymePrintActivity) - // llvm::errs() << " -- continuing indirect store2 from " << - // *val - // << " via " << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- continuing indirect store2 from " << val + << " via " << TmpOrig << "\n"; continue; } } @@ -3034,10 +3032,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // } if (isa(a)) { - // if (EnzymePrintActivity) - // llvm::errs() << "found constant(" << (int)directions - // << ") allocainst use:" << *val << " user " << *a << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << "found constant(" << (int)directions + << ") allocainst use:" << val << " user " << *a << "\n"; continue; } @@ -3066,12 +3063,21 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // continue; // This use is only active if specified - if (isa(a)) { - if (ActiveReturns == DIFFE_TYPE::CONSTANT && - UA != UseActivity::AllStores) { + if (UA != UseActivity::AllStores) { + if (auto termUsers = getPotentialTerminatorUsers(a, parent)) { + for (auto post : *termUsers) { + for (Operation 
*postUser : post.getUsers()) { + todo.push_back(std::make_tuple(postUser, post, UA)); + } + } continue; - } else { - return false; + } + if (isFunctionReturn(a)) { + if (ActiveReturns == DIFFE_TYPE::CONSTANT) { + continue; + } else { + return false; + } } } @@ -3166,10 +3172,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (auto call = dyn_cast(a)) { bool ConstantArg = isFunctionArgumentConstant(call, parent); if (ConstantArg && UA != UseActivity::AllStores) { - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant callinst use:" << *val - // << " user " << *call << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant callinst use:" << val + << " user " << *call << "\n"; + } continue; } @@ -3190,10 +3196,8 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (operand.getDefiningOp()) { bool legal = true; - for (unsigned i = 0; i < call.getArgOperands().size() + 1; ++i) { - // FIXME: this is based on an assumption that the callee operand - // precedes arg operands. - Value a = call->getOperand(i); + for (unsigned i = 0; i < call.getArgOperands().size(); ++i) { + Value a = call.getArgOperands()[i]; // FIXME: yet another ingrained assumption that integers cannot be // active. 
@@ -3267,11 +3271,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (Operation *I = a) { if (notForAnalysis.count(I->getBlock())) { // TODO(PR #904): replace the "EnzymePrintActivity" flag with LLVM_DEBUG - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant unreachable inst use:" << - // *val - // << " user " << *I << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant unreachable inst use:" << val + << " user " << *I << "\n"; + } continue; } if (UA != UseActivity::AllStores && ConstantOperations.count(I)) { @@ -3280,11 +3283,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( .isa() || ConstantValues.count(val); })) { - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant inst use:" << *val << " - // user " - // << *I << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant inst use:" << val << " user " + << *I << "\n"; + } continue; } UseActivity NU = UA; @@ -3359,17 +3361,16 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( *FoundInst = I; } - // if (EnzymePrintActivity) - // llvm::errs() << "Value nonconstant inst (uses):" << *val << " user " << - // *a - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "Value nonconstant inst (uses):" << val << " user " << *a + << "\n"; seenuse = true; break; } - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << "\n"; return !seenuse; } @@ -3386,10 +3387,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( return StoredOrReturnedCache[key]; } - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << "\n"; StoredOrReturnedCache[key] = false; @@ -3415,14 +3416,21 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( continue; } - if (isa(a)) { + if (auto termUsers = 
getPotentialTerminatorUsers(a, val)) { + for (auto post : *termUsers) + if (isValueActivelyStoredOrReturned(TR, post, outside)) { + return StoredOrReturnedCache[key] = true; + } + return false; + } + if (isFunctionReturn(a)) { if (ActiveReturns == DIFFE_TYPE::CONSTANT) continue; - // if (EnzymePrintActivity) - // llvm::errs() << " " - // << " active from-ret>" << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " + << " active from-ret>" << val << "\n"; StoredOrReturnedCache[key] = true; return true; } @@ -3445,11 +3453,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // Storing into active value, return true if (!isConstantValue(TR, SI.getValue())) { StoredOrReturnedCache[key] = true; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << " store into=" << *SI << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << " store into=" << *SI << "\n"; return true; } } @@ -3458,11 +3466,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // Storing into active memory, return true if (!isConstantValue(TR, SI.getAddr())) { StoredOrReturnedCache[key] = true; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << " store=" << *SI - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << " store=" << *SI + << "\n"; return true; } continue; @@ -3513,18 +3521,17 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // it is written to active memory // TODO handle more memory instructions above to be less conservative - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << " - use=" << *a - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << " - use=" << *a << "\n"; return StoredOrReturnedCache[key] = true; } - // if (EnzymePrintActivity) - // llvm::errs() << " " - // << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " + << val << "\n"; return false; } @@ -3540,9 +3547,9 @@ void 
mlir::enzyme::ActivityAnalyzer::InsertConstantOperation( if (!ActiveValues.count(toeval)) continue; ActiveValues.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of val " << *toeval - // << " due to inst " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of val " << toeval + << " due to inst " << *I << "\n"; isConstantValue(TR, toeval); } } @@ -3558,9 +3565,9 @@ void mlir::enzyme::ActivityAnalyzer::InsertConstantValue(MTypeResults const &TR, if (!ActiveValues.count(toeval)) continue; ActiveValues.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of val " << *toeval - // << " due to value " << *V << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of val " << toeval + << " due to value " << V << "\n"; isConstantValue(TR, toeval); } } @@ -3572,9 +3579,9 @@ void mlir::enzyme::ActivityAnalyzer::InsertConstantValue(MTypeResults const &TR, if (!ActiveOperations.count(toeval)) continue; ActiveOperations.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of inst " << *toeval - // << " due to value " << *V << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of inst " << *toeval + << " due to value " << V << "\n"; isConstantOperation(TR, toeval); } } diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h index bff41155fc3b..c73df2025883 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h @@ -143,12 +143,18 @@ class ActivityAnalyzer { bool isFunctionArgumentConstant(mlir::CallOpInterface CI, Value val); /// Is the value guaranteed to be inactive because of how it's produced. 
- bool isValueInactiveFromOrigin(MTypeResults const &TR, Value val); + /// If active and inactArg is non-null, store any values which may allow this + /// to succeed in the future + bool isValueInactiveFromOrigin( + MTypeResults const &TR, Value val, + llvm::SmallPtrSetImpl *inactArg = nullptr); + /// Is the operation guaranteed to be inactive because of how its operands are /// produced. bool isOperationInactiveFromOrigin( MTypeResults const &TR, Operation *op, - std::optional resultNo = std::nullopt); + std::optional resultNo = std::nullopt, + llvm::SmallPtrSetImpl *inactArg = nullptr); public: enum class UseActivity { diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 613e0ef2ad7e..a15b3283e10d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -57,7 +57,13 @@ class TensorTypeInterface public: Value createNullValue(Type self, OpBuilder &builder, Location loc) const { auto tenType = self.cast(); - auto attr = DenseElementsAttr::get(tenType, 0); + auto ET = tenType.getElementType(); + size_t num = 1; + for (auto sz : tenType.getShape()) + num *= sz; + APFloat apvalue(ET.cast().getFloatSemantics(), 0); + SmallVector supportedValues(num, apvalue); + auto attr = DenseElementsAttr::get(tenType, supportedValues); return builder.create(loc, tenType, attr); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index c18a7bd0e921..43b4757bdf53 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -217,32 +217,11 @@ void mlir::enzyme::detail::regionTerminatorForwardHandler( // Assuming shadows following the originals are 
fine. // TODO: consider extending to have a ShadowableTerminatorOpInterface Operation *replTerminator = gutils->getNewFromOriginal(origTerminator); - Operation *newTerminator = builder.clone(*replTerminator); - newTerminator->setOperands(newOperands); - gutils->erase(replTerminator); + replTerminator->setOperands(newOperands); } LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( Operation *op, OpBuilder &builder, MGradientUtils *gutils) { - // For all active results, add shadow types. - // For now, assuming all results are relevant. - Operation *newOp = gutils->getNewFromOriginal(op); - SmallVector newOpResultTypes; - newOpResultTypes.reserve(op->getNumResults() * 2); - for (Value result : op->getResults()) { - // TODO only if used (can we DCE the primal after having done the - // derivative). - newOpResultTypes.push_back(result.getType()); - if (gutils->isConstantValue(result)) - continue; - auto typeIface = dyn_cast(result.getType()); - if (!typeIface) { - op->emitError() << " AutoDiffTypeInterface not implemented for " - << result.getType() << "\n"; - return failure(); - } - newOpResultTypes.push_back(typeIface.getShadowType()); - } // For all operands that are forwarded to the body, if they are active, also // add the shadow as operand. @@ -253,28 +232,72 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( return failure(); } - SmallVector successors; // TODO: we may need to record, for every successor, which of its inputs // need a shadow to recreate the body correctly. 
llvm::SmallDenseSet operandPositionsToShadow; + llvm::SmallDenseSet resultPositionsToShadow; + + SmallVector entrySuccessors; regionBranchOp.getEntrySuccessorRegions( - SmallVector(op->getNumOperands(), Attribute()), successors); - for (const RegionSuccessor &successor : successors) { - if (!successor.isParent() && successor.getSuccessor()->empty()) - continue; + SmallVector(op->getNumOperands(), Attribute()), + entrySuccessors); + + for (const RegionSuccessor &successor : entrySuccessors) { OperandRange operandRange = regionBranchOp.getEntrySuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? op->getResults() + : successor.getSuccessorInputs(); + // Need to know which of the arguments are being forwarded to from // operands. for (auto &&[i, regionValue, operand] : - llvm::enumerate(successor.getSuccessorInputs(), operandRange)) { + llvm::enumerate(targetValues, operandRange)) { if (gutils->isConstantValue(regionValue)) continue; operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i); + if (successor.isParent()) + resultPositionsToShadow.insert(i); } } + + for (auto res : op->getResults()) + if (!gutils->isConstantValue(res)) + resultPositionsToShadow.insert(res.getResultNumber()); + + return controlFlowForwardHandler( + op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow); +} + +LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils, + const llvm::SmallDenseSet &operandPositionsToShadow, + const llvm::SmallDenseSet &resultPositionsToShadow) { + // For all active results, add shadow types. + // For now, assuming all results are relevant. + Operation *newOp = gutils->getNewFromOriginal(op); + SmallVector newOpResultTypes; + newOpResultTypes.reserve(op->getNumResults() * 2); + for (auto result : op->getResults()) { + // TODO only if used (can we DCE the primal after having done the + // derivative). 
+ newOpResultTypes.push_back(result.getType()); + if (!gutils->isConstantValue(result)) { + assert(resultPositionsToShadow.count(result.getResultNumber())); + } + if (!resultPositionsToShadow.count(result.getResultNumber())) + continue; + auto typeIface = dyn_cast(result.getType()); + if (!typeIface) { + op->emitError() << " AutoDiffTypeInterface not implemented for " + << result.getType() << "\n"; + return failure(); + } + newOpResultTypes.push_back(typeIface.getShadowType()); + } + SmallVector newOperands; newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size()); for (OpOperand &operand : op->getOpOperands()) { @@ -297,6 +320,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( } Operation *replacement = iface.createWithShadows( builder, gutils, op, newOperands, newOpResultTypes); + assert(replacement->getNumResults() == newOpResultTypes.size()); for (auto &&[region, replacementRegion] : llvm::zip(newOp->getRegions(), replacement->getRegions())) { replacementRegion.takeBody(region); diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 974b888599d5..a7ec6b179986 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -15,10 +15,13 @@ #include "Interfaces/AutoDiffOpInterface.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" + namespace mlir { class DialectRegistry; class Operation; class OpBuilder; +class RegionSuccessor; namespace enzyme { class MGradientUtils; @@ -27,9 +30,15 @@ class MGradientUtilsReverse; namespace detail { // Non-template implementation of // AutoDiffUsingControlFlow::createForwardModeTangent. 
+ LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils); +LogicalResult controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils, + const llvm::SmallDenseSet &operandPositionsToShadow, + const llvm::SmallDenseSet &resultPositionsToShadow); + // Implements forward-mode differentiation of branching operations. // Assumes that successive shadows are legal void branchingForwardHandler(Operation *op, OpBuilder &builder, diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index b9cbfe3e6913..39de939bc27a 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -165,6 +165,11 @@ Create reverse mode adjoint for an operation. */ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) { + if (llvm::all_of(op->getResults(), + [gutils](Value v) { return gutils->isConstantValue(v); }) && + gutils->isConstantInstruction(op)) { + return; + } if (auto ifaceOp = dyn_cast(op)) { SmallVector caches = ifaceOp.cacheValues(gutils); ifaceOp.createReverseModeAdjoint(builder, gutils, caches); @@ -175,6 +180,7 @@ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, gutils->clearValue(result, builder); } } + op->emitError() << "could not compute the adjoint for this operation " << *op; } void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 45473c75921f..d83e2da59a17 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -115,14 +115,14 @@ mlir::enzyme::MGradientUtils::getNewFromOriginal(mlir::Block *originst) const { Operation * mlir::enzyme::MGradientUtils::getNewFromOriginal(Operation *originst) const { + assert(originst); auto found = 
originalToNewFnOps.find(originst); if (found == originalToNewFnOps.end()) { llvm::errs() << oldFunc << "\n"; llvm::errs() << newFunc << "\n"; for (auto &pair : originalToNewFnOps) { llvm::errs() << " map[" << pair.first << "] = " << pair.second << "\n"; - // llvm::errs() << " map[" << pair.first << "] = " << pair.second << " - // -- " << *pair.first << " " << *pair.second << "\n"; + llvm::errs() << " map[" << *pair.first << "] = " << *pair.second << "\n"; } llvm::errs() << originst << " - " << *originst << "\n"; llvm_unreachable("Could not get new op from original"); @@ -154,7 +154,12 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, if (isConstantValue(v)) { if (auto iface = v.getType().dyn_cast()) { OpBuilder::InsertionGuard guard(Builder2); - Builder2.setInsertionPoint(getNewFromOriginal(v.getDefiningOp())); + if (auto op = v.getDefiningOp()) + Builder2.setInsertionPoint(getNewFromOriginal(op)); + else { + auto ba = cast(v); + Builder2.setInsertionPointToStart(getNewFromOriginal(ba.getOwner())); + } Value dv = iface.createNullValue(Builder2, v.getLoc()); invertedPointers.map(v, dv); return dv; diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 432b0938f763..d2a04fc37412 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -101,6 +101,13 @@ def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> { /*default=*/"false", /*description=*/"Annotate every operation and value with its activity" >, + Option< + /*C++ variable name=*/"dataflow", + /*CLI argument=*/"dataflow", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"Whether to use the new Dataflow activity analysis" + >, Option< /*C++ variable name=*/"inactiveArgs", /*CLI argument=*/"inactive-args", diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index fb3ef000a0c1..ac88d0fc86fa 100644 --- 
a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -10,8 +10,10 @@ // analysis. // //===----------------------------------------------------------------------===// +#include "Analysis/ActivityAnalysis.h" #include "Analysis/DataFlowActivityAnalysis.h" #include "Dialect/Ops.h" +#include "Interfaces/EnzymeLogic.h" #include "Passes/PassDetails.h" #include "Passes/Passes.h" @@ -116,10 +118,79 @@ struct PrintActivityAnalysisPass } } + void runActivityAnalysis(bool dataflow, FunctionOpInterface callee, + ArrayRef argActivities, + ArrayRef resultActivities, + bool print, bool verbose, bool annotate) { + if (dataflow) { + enzyme::runDataFlowActivityAnalysis(callee, argActivities, + /*print=*/true, verbose, annotate); + } else { + + SmallPtrSet blocksNotForAnalysis; + + mlir::enzyme::MTypeResults TR; // TODO + SmallPtrSet constant_values; + SmallPtrSet activevals_; + for (auto &&[arg, act] : + llvm::zip(callee.getFunctionBody().getArguments(), argActivities)) { + if (act == enzyme::Activity::enzyme_const) + constant_values.insert(arg); + else + activevals_.insert(arg); + } + auto ReturnActivity = DIFFE_TYPE::CONSTANT; + for (auto act : resultActivities) + if (act != enzyme::Activity::enzyme_const) + ReturnActivity = DIFFE_TYPE::DUP_ARG; + + enzyme::ActivityAnalyzer activityAnalyzer( + blocksNotForAnalysis, constant_values, activevals_, ReturnActivity); + + callee.walk([&](Operation *op) { + + }); + MLIRContext *ctx = callee.getContext(); + callee.walk([&](Operation *op) { + if (print) + llvm::outs() << " Operation: " << *op << "\n"; + for (auto ® : op->getRegions()) { + for (auto &blk : reg.getBlocks()) { + for (auto &arg : blk.getArguments()) { + bool icv = activityAnalyzer.isConstantValue(TR, arg); + if (annotate) + op->setAttr("enzyme.arg_icv" + + std::to_string(arg.getArgNumber()), + BoolAttr::get(ctx, icv)); + if (print) + llvm::outs() << " arg: " << arg << " icv=" << icv << "\n"; + } + } + } + + bool ici = 
activityAnalyzer.isConstantOperation(TR, op); + if (annotate) + op->setAttr("enzyme.ici", BoolAttr::get(ctx, ici)); + if (print) + llvm::outs() << " op ici=" << ici << "\n"; + + for (auto res : op->getResults()) { + bool icv = activityAnalyzer.isConstantValue(TR, res); + if (annotate) + op->setAttr("enzyme.res_icv" + + std::to_string(res.getResultNumber()), + BoolAttr::get(ctx, icv)); + if (print) + llvm::outs() << " res: " << res << " icv=" << icv << "\n"; + } + }); + } + } + void runOnOperation() override { auto moduleOp = cast(getOperation()); - if (annotate) { + if (annotate && dataflow) { // Infer the activity attributes from the __enzyme_autodiff call Operation *autodiff_decl = moduleOp.lookupSymbol("__enzyme_autodiff"); if (!autodiff_decl) @@ -148,8 +219,8 @@ struct PrintActivityAnalysisPass // supplied annotation. First argument is the callee inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); } return; } @@ -163,8 +234,8 @@ struct PrintActivityAnalysisPass resultActivities{callee.getNumResults()}; initializeArgAndResActivities(callee, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); }); return; } @@ -186,8 +257,8 @@ struct PrintActivityAnalysisPass resultActivities{callee.getNumResults()}; initializeArgAndResActivities(callee, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); } } }; diff --git 
a/enzyme/test/MLIR/ActivityAnalysis/region.mlir b/enzyme/test/MLIR/ActivityAnalysis/region.mlir new file mode 100644 index 000000000000..4526e1325661 --- /dev/null +++ b/enzyme/test/MLIR/ActivityAnalysis/region.mlir @@ -0,0 +1,27 @@ +// RUN: %eopt --pass-pipeline="builtin.module(print-activity-analysis{dataflow=false annotate=true})" %s --split-input-file 2>&1 | FileCheck %s + +// A function that contains active and inactive region dataflow + +func.func @region(%x: f64) -> (f64, f64) { + %f0 = arith.constant 0.0 : f64 + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + %r0:2 = scf.for %arg12 = %c0 to %c10 step %c1 iter_args(%arg13 = %f0, %arg14 = %f0) -> (f64, f64) { + %m = arith.addf %arg13, %x : f64 + scf.yield %m, %arg14 : f64, f64 + } + return %r0#0, %r0#1 : f64, f64 +} + +// CHECK: func.func @region(%arg0: f64) -> (f64, f64) attributes {enzyme.arg_icv0 = false, enzyme.ici = false} { +// CHECK-NEXT: %cst = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 0.000000e+00 : f64 +// CHECK-NEXT: %c0 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 0 : index +// CHECK-NEXT: %c10 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 10 : index +// CHECK-NEXT: %c1 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 1 : index +// CHECK-NEXT: %0:2 = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst, %arg3 = %cst) -> (f64, f64) { +// CHECK-NEXT: %1 = arith.addf %arg2, %arg0 {enzyme.ici = false, enzyme.res_icv0 = false} : f64 +// CHECK-NEXT: scf.yield {enzyme.ici = true} %1, %arg3 : f64, f64 +// CHECK-NEXT: } {enzyme.arg_icv0 = true, enzyme.arg_icv1 = false, enzyme.arg_icv2 = true, enzyme.ici = false, enzyme.res_icv0 = false, enzyme.res_icv1 = true} +// CHECK-NEXT: return {enzyme.ici = true} %0#0, %0#1 : f64, f64 +// CHECK-NEXT: } From 94aa6edcd5f1a828f770a2ad1c1c50be60d8f8a1 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sat, 17 Feb 2024 21:36:56 -0500 Subject: [PATCH 100/131] Disable MLIR print activity --- enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 30fbfc8ef345..1415559b3f6b 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -248,7 +248,7 @@ static Operation *getFunctionFromCall(CallOpInterface iface) { return SymbolTable::lookupNearestSymbolFrom(iface.getOperation(), symbol); } -constexpr bool EnzymePrintActivity = true; +constexpr bool EnzymePrintActivity = false; /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode From c1f59a9838a73fb15f278fdcefa2f78ac7112a6b Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 17 Feb 2024 22:41:19 -0500 Subject: [PATCH 101/131] MLIR fix forward terminator interface (#1733) --- .../CoreDialectsAutoDiffImplementations.cpp | 46 +++++++++++-------- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 6 +-- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 43b4757bdf53..025946ef9651 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -186,29 +186,39 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler( void mlir::enzyme::detail::regionTerminatorForwardHandler( Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { - auto termIface = cast(origTerminator); - auto parentOp = termIface->getParentOp(); - - SmallVector successors; - termIface.getSuccessorRegions( - SmallVector(termIface->getNumOperands(), 
Attribute()), - successors); + auto parentOp = origTerminator->getParentOp(); llvm::SmallDenseSet operandsToShadow; - for (auto &successor : successors) { - OperandRange operandRange = termIface.getSuccessorOperands(successor); - ValueRange targetValues = successor.isParent() - ? parentOp->getResults() - : successor.getSuccessorInputs(); - assert(operandRange.size() == targetValues.size()); - for (auto &&[i, target] : llvm::enumerate(targetValues)) { - if (!gutils->isConstantValue(target)) - operandsToShadow.insert(operandRange.getBeginOperandIndex() + i); + if (auto termIface = + dyn_cast(origTerminator)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(origTerminator->getNumOperands(), Attribute()), + successors); + + for (auto &successor : successors) { + OperandRange operandRange = termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[i, target] : llvm::enumerate(targetValues)) { + if (!gutils->isConstantValue(target)) + operandsToShadow.insert(operandRange.getBeginOperandIndex() + i); + } + } + } else { + assert(parentOp->getNumResults() == origTerminator->getNumOperands()); + for (auto res : parentOp->getResults()) { + if (!gutils->isConstantValue(res)) + operandsToShadow.insert(res.getResultNumber()); } } + SmallVector newOperands; - newOperands.reserve(termIface->getNumOperands() + operandsToShadow.size()); - for (OpOperand &operand : termIface->getOpOperands()) { + newOperands.reserve(origTerminator->getNumOperands() + + operandsToShadow.size()); + for (OpOperand &operand : origTerminator->getOpOperands()) { newOperands.push_back(gutils->getNewFromOriginal(operand.get())); if (operandsToShadow.contains(operand.getOperandNumber())) newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp 
b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index d83e2da59a17..4598394b9e44 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -306,11 +306,7 @@ void mlir::enzyme::MGradientUtils::forceAugmentedReturns() { LogicalResult MGradientUtils::visitChild(Operation *op) { if (mode == DerivativeMode::ForwardMode) { - // In absence of a proper activity analysis, approximate it by treating any - // side effect-free operation producing constants as inactive. - // if (auto iface = dyn_cast(op)) { - if (!isa(op) && - !isa(op) && + if ((op->getBlock()->getTerminator() != op) && llvm::all_of(op->getResults(), [this](Value v) { return isConstantValue(v); }) && /*iface.hasNoEffect()*/ activityAnalyzer->isConstantOperation(TR, op)) { From b675db757a1df1646b9518042859cef41e477602 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 17 Feb 2024 22:43:36 -0500 Subject: [PATCH 102/131] MLIR add constantfp to common.td (#1734) --- enzyme/Enzyme/MLIR/Implementations/Common.td | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index f558d1225a76..10f9c1532014 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -62,6 +62,12 @@ class Inst : Operation : Operation { + string value = val; + string dialect = dialect_; + string opName = op_; +} + class ArithInst : Inst; class MathInst : Inst; From d7c94ddd6ac7bb026a741aed030734ba8299bbca Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 16:08:08 -0500 Subject: [PATCH 103/131] MLIR handle dual tablegen (#1735) * MLIR handle dual tablegen * fix * extend constantfp --- .../MLIR/Implementations/ArithDerivatives.td | 4 +-- enzyme/Enzyme/MLIR/Implementations/Common.td | 19 ++++++++-- .../CoreDialectsAutoDiffImplementations.cpp | 17 +++++++++ .../CoreDialectsAutoDiffImplementations.h | 7 ++++ 
enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 35 ++++++++++++------- 5 files changed, 65 insertions(+), 17 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index 3ed038fa4cb6..3d53793be3af 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -28,6 +28,6 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y), [ (CheckedDivF (DiffeRet), $y), (NegF (MulF (CheckedDivF (DiffeRet), $y), (DivF $x, $y))) - ] - // (CheckedDiv (FSub (SelectIfActive $x, (FMul (Shadow $x), $y), (Zero $x)), (SelectIfActive $y, (FMul (Shadow $y), $x), (Zero $y))), (FMul $y, $y)) + ], + (CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y)) >; diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 10f9c1532014..fe33089f15e6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -36,11 +36,18 @@ class RegionTerminatorOp { string opName = opName_; } -class MLIRDerivative resultOps> { +class ForwardFromSummedReverseInternal { + int unused = unused_; +} +def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>; + + +class MLIRDerivative resultOps, dag forwardOps=(ForwardFromSummedReverse)> { string dialect = dialect_; string opName = opName_; dag PatternToMatch = patternToMatch; list ArgDerivatives = resultOps; + dag ArgDuals = forwardOps; } class Operation { @@ -54,6 +61,9 @@ class DiffeRetIndex indices_> { } def DiffeRet : DiffeRetIndex<[-1]>; +def Shadow : Operation { +} + class Inst : Operation { string name = mnemonic; string dialect = dialect_; @@ -62,10 +72,15 @@ class Inst : Operation : Operation { +def SelectIfActive : Operation { + +} + +class 
ConstantFP : Operation { string value = val; string dialect = dialect_; string opName = op_; + string type = type_; } class ArithInst : Inst; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 025946ef9651..f9b2d61b0dce 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -19,6 +19,23 @@ using namespace mlir; using namespace mlir::enzyme; +mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, + llvm::StringRef value) { + using namespace mlir; + if (auto T = dyn_cast(type)) { + size_t num = 1; + for (auto sz : T.getShape()) + num *= sz; + APFloat apvalue(T.getElementType().cast().getFloatSemantics(), + value); + SmallVector supportedValues(num, apvalue); + return DenseFPElementsAttr::get(type.cast(), supportedValues); + } + auto T = cast(type); + APFloat apvalue(T.getFloatSemantics(), value); + return FloatAttr::get(T, apvalue); +} + void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, OpBuilder &builder, MGradientUtils *gutils) { diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index a7ec6b179986..7ee0be2adb70 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -12,6 +12,9 @@ // //===----------------------------------------------------------------------===// +#ifndef ENZYMEMLIR_CORE_IMPL_H_ +#define ENZYMEMLIR_CORE_IMPL_H_ + #include "Interfaces/AutoDiffOpInterface.h" #include "mlir/Support/LogicalResult.h" @@ -198,5 +201,9 @@ void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); void 
registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); + +mlir::TypedAttr getConstantAttr(mlir::Type type, llvm::StringRef value); } // namespace enzyme } // namespace mlir + +#endif diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 1e1ac3697095..35c937f79807 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -385,7 +385,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << "({\n"; os << curIndent << INDENT << "// Computing SelectIfActive\n"; - os << curIndent << INDENT << "Value *imVal = nullptr;\n"; + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value imVal = nullptr;\n"; + else + os << curIndent << INDENT << "llvm::Value *imVal = nullptr;\n"; os << curIndent << INDENT << "if (!gutils->isConstantValue("; @@ -415,7 +418,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, retidx, origName, newFromOriginal, intrinsic); os << ";\n"; - if (!vector) { + if (!vector && intrinsic != MLIRDerivatives) { os << curIndent << INDENT << INDENT << "llvm::Value* vec_imVal = gutils->getWidth() == 1 ? 
imVal : " "UndefValue::get(gutils->getShadowType(imVal" @@ -440,16 +443,15 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << curIndent << "})"; return true; } else if (opName == "ConstantFP" || Def->isSubClassOf("ConstantFP")) { - if (resultRoot->getNumArgs() != 1) - PrintFatalError(pattern->getLoc(), - "only single op constantfp supported"); - auto value = dyn_cast(Def->getValueInit("value")); if (!value) PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") + resultTree->getAsString()); if (intrinsic == MLIRDerivatives) { + if (resultRoot->getNumArgs() > 1) + PrintFatalError(pattern->getLoc(), + "only zero or single op constantfp supported"); os << builder << ".create<" << cast(Def->getValueInit("dialect"))->getValue() << "::" << cast(Def->getValueInit("opName"))->getValue() @@ -463,9 +465,17 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, assert(!isVec); ord = ord1; } - os << ord << ".getType(), getTensorAttr(" << ord << ".getType(), "; + os << ord << ".getType(), "; + auto typeCast = + dyn_cast(Def->getValueInit("type"))->getValue(); + if (typeCast != "") + os << "(" << typeCast << ")"; + os << "mlir::enzyme::getConstantAttr(" << ord << ".getType(), "; os << "\"" << value->getValue() << "\"))"; } else { + if (resultRoot->getNumArgs() != 1) + PrintFatalError(pattern->getLoc(), + "only single op constantfp supported"); os << "ConstantFP::get("; if (resultRoot->getArgName(0)) { @@ -1269,9 +1279,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, for (Record *pattern : patterns) { DagInit *tree = pattern->getValueAsDag("PatternToMatch"); - DagInit *duals = nullptr; - if (intrinsic != MLIRDerivatives) - duals = pattern->getValueAsDag("ArgDuals"); + DagInit *duals = pattern->getValueAsDag("ArgDuals"); + assert(duals); // Emit RewritePattern for Pattern. 
ListInit *argOps = pattern->getValueAsListInit("ArgDerivatives"); @@ -1514,8 +1523,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } // TODO - if (!duals || - duals->getOperator()->getAsString() == + if (duals->getOperator()->getAsString() == "ForwardFromSummedReverseInternal" || cast(duals->getOperator()) ->getDef() @@ -1649,7 +1657,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } ArrayRef retidx{}; bool vectorValued = - handle(" ", "fwdnsrarg", os, pattern, duals, "Builder2", + handle(" ", "fwdnsrarg", os, pattern, duals, + (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", nameToOrdinal, /*lookup*/ false, retidx, origName, /*newFromOriginal*/ true, intrinsic); (void)vectorValued; From 1e1c0eb1c9b4ae3fa6b0acc2394e305b3fc4e042 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 18:00:16 -0500 Subject: [PATCH 104/131] MLIR fix custom fwd tblgen (#1736) --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 35c937f79807..84dce5f2a08c 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1651,7 +1651,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } else { if (intrinsic == MLIRDerivatives) { - os << " mlir::Value res = nullptr;\n"; + os << " mlir::Value res = "; } else { os << " Value *res = "; } From 2dcf5c5bfbc93805964a4a519665591bdee38046 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sun, 18 Feb 2024 19:00:36 -0500 Subject: [PATCH 105/131] MLIR support globalexpr --- enzyme/Enzyme/MLIR/Implementations/Common.td | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index fe33089f15e6..83b89f7cf0a4 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -64,6 +64,10 @@ def DiffeRet : DiffeRetIndex<[-1]>; def Shadow : Operation { } +class GlobalExpr : Operation{ + string value = val; +} + class Inst : Operation { string name = mnemonic; string dialect = dialect_; From f2a9361619ae27c744e32734f0eae89a72a69b4a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 19:57:23 -0500 Subject: [PATCH 106/131] Handle new debug format conversion (#1737) --- enzyme/Enzyme/FunctionUtils.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 160ccc740f8a..fdba018bd600 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -754,6 +754,16 @@ void PreProcessCache::AlwaysInline(Function *NewF) { for (auto CI : ToInline) { InlineFunctionInfo IFI; +#if LLVM_VERSION_MAJOR >= 18 + auto F = CI->getCalledFunction(); + if (CI->getParent()->IsNewDbgInfoFormat != F->IsNewDbgInfoFormat) { + if (CI->getParent()->IsNewDbgInfoFormat) { + F->convertToNewDbgValues(); + } else { + F->convertFromNewDbgValues(); + } + } +#endif InlineFunction(*CI, IFI); } } From f3fca8d1443fb688087f2b09f18299082d668ae5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 22:18:45 -0500 Subject: [PATCH 107/131] TypeAnalysis cleanup null result (#1738) --- enzyme/Enzyme/AdjointGenerator.h | 12 ++--- enzyme/Enzyme/CApi.cpp | 2 +- enzyme/Enzyme/PreserveNVVM.cpp | 8 +-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 57 +++++++++------------ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h | 5 +- 5 files changed, 
39 insertions(+), 45 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ae2e72aceb83..36b337329a2e 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -100,7 +100,7 @@ class AdjointGenerator : public llvm::InstVisitor { using namespace llvm; assert(TR.getFunction() == gutils->oldFunc); - for (auto &pair : TR.analyzer.analysis) { + for (auto &pair : TR.analyzer->analysis) { if (auto in = dyn_cast(pair.first)) { if (in->getParent()->getParent() != gutils->oldFunc) { llvm::errs() << "inf: " << *in->getParent()->getParent() << "\n"; @@ -4152,7 +4152,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (called) { subdata = &gutils->Logic.CreateAugmentedPrimal( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*return is used*/ false, /*shadowReturnUsed*/ false, nextTypeInfo, overwritten_args, false, gutils->getWidth(), @@ -4368,7 +4368,7 @@ class AdjointGenerator : public llvm::InstVisitor { : nullptr, .forceAnonymousTape = false, .typeInfo = nextTypeInfo}, - TR.analyzer.interprocedural, subdata, + TR.analyzer->interprocedural, subdata, /*omp*/ true); if (subdata->returns.find(AugmentedStruct::Tape) != @@ -4896,7 +4896,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (called) { newcalled = gutils->Logic.CreateForwardDiff( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*returnValue*/ subretused, Mode, ((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(), tape ? 
tape->getType() : nullptr, nextTypeInfo, overwritten_args, @@ -5311,7 +5311,7 @@ class AdjointGenerator : public llvm::InstVisitor { Mode == DerivativeMode::ReverseModeCombined) { subdata = &gutils->Logic.CreateAugmentedPrimal( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*return is used*/ subretused, shadowReturnUsed, nextTypeInfo, overwritten_args, false, gutils->getWidth(), gutils->AtomicAdd); if (Mode == DerivativeMode::ReverseModePrimal) { @@ -5689,7 +5689,7 @@ class AdjointGenerator : public llvm::InstVisitor { .additionalType = tape ? tape->getType() : nullptr, .forceAnonymousTape = false, .typeInfo = nextTypeInfo}, - TR.analyzer.interprocedural, subdata); + TR.analyzer->interprocedural, subdata); if (!newcalled) return; FT = cast(newcalled)->getFunctionType(); diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 15bc6c795aa2..006965496f43 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -297,7 +297,7 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) { void *EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI, LLVMValueRef F) { FnTypeInfo FTI(eunwrap(CTI, cast(unwrap(F)))); - return (void *)&((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer; + return (void *)((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer; } void *EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G) { diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index a7df39596f0a..84b8b5b9540a 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp @@ -60,12 +60,12 @@ using namespace llvm; bool preserveLinkage(bool Begin, Function &F, bool Inlining = true) { if (Begin && !F.hasFnAttribute("prev_fixup")) { F.addFnAttr("prev_fixup"); + if (F.hasFnAttribute(Attribute::AlwaysInline)) + F.addFnAttr("prev_always_inline"); + if (F.hasFnAttribute(Attribute::NoInline)) + 
F.addFnAttr("prev_no_inline"); if (Inlining) { - if (F.hasFnAttribute(Attribute::AlwaysInline)) - F.addFnAttr("prev_always_inline"); F.removeFnAttr(Attribute::AlwaysInline); - if (F.hasFnAttribute(Attribute::NoInline)) - F.addFnAttr("prev_no_inline"); F.addFnAttr(Attribute::NoInline); } F.addFnAttr("prev_linkage", std::to_string(F.getLinkage())); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 04d1822d4e23..3204b5667086 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -5745,20 +5745,10 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { return TypeResults(analysis); } -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wnull-dereference" -#else -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wnull-dereference" -#endif + if (fn.Function->empty()) - return TypeResults(*(TypeAnalyzer *)nullptr); -#ifdef __clang__ -#pragma clang diagnostic pop -#else -#pragma GCC diagnostic pop -#endif + return TypeResults(nullptr); + auto res = analyzedFunctions.emplace(fn, new TypeAnalyzer(fn, *this)); auto &analysis = *res.first->second; @@ -5808,30 +5798,31 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { return TypeResults(analysis); } -TypeResults::TypeResults(TypeAnalyzer &analyzer) : analyzer(analyzer) {} +TypeResults::TypeResults(TypeAnalyzer &analyzer) : analyzer(&analyzer) {} +TypeResults::TypeResults(std::nullptr_t) : analyzer(nullptr) {} FnTypeInfo TypeResults::getAnalyzedTypeInfo() const { - FnTypeInfo res(analyzer.fntypeinfo.Function); - for (auto &arg : analyzer.fntypeinfo.Function->args()) { + FnTypeInfo res(analyzer->fntypeinfo.Function); + for (auto &arg : analyzer->fntypeinfo.Function->args()) { res.Arguments.insert(std::pair(&arg, query(&arg))); } res.Return = getReturnAnalysis(); - res.KnownValues = analyzer.fntypeinfo.KnownValues; + res.KnownValues = 
analyzer->fntypeinfo.KnownValues; return res; } FnTypeInfo TypeResults::getCallInfo(CallBase &CI, Function &fn) const { - return analyzer.getCallInfo(CI, fn); + return analyzer->getCallInfo(CI, fn); } TypeTree TypeResults::query(Value *val) const { if (auto inst = dyn_cast(val)) { - assert(inst->getParent()->getParent() == analyzer.fntypeinfo.Function); + assert(inst->getParent()->getParent() == analyzer->fntypeinfo.Function); } if (auto arg = dyn_cast(val)) { - assert(arg->getParent() == analyzer.fntypeinfo.Function); + assert(arg->getParent() == analyzer->fntypeinfo.Function); } - return analyzer.getAnalysis(val); + return analyzer->getAnalysis(val); } bool TypeResults::anyFloat(Value *val) const { @@ -5843,7 +5834,7 @@ bool TypeResults::anyFloat(Value *val) const { return dt.isFloat(); size_t ObjSize = 1; - auto &dl = analyzer.fntypeinfo.Function->getParent()->getDataLayout(); + auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout(); if (val->getType()->isSized()) ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; @@ -5871,7 +5862,7 @@ bool TypeResults::anyPointer(Value *val) const { return dt == BaseType::Pointer; size_t ObjSize = 1; - auto &dl = analyzer.fntypeinfo.Function->getParent()->getDataLayout(); + auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout(); if (val->getType()->isSized()) ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; @@ -5890,7 +5881,7 @@ bool TypeResults::anyPointer(Value *val) const { return false; } -void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer.dump(ss); } +void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer->dump(ss); } ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, bool pointerIntSame) const { @@ -5913,7 +5904,7 @@ ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, if (auto inst = dyn_cast(val)) { llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; llvm::errs() << 
*inst->getParent()->getParent() << "\n"; - for (auto &pair : analyzer.analysis) { + for (auto &pair : analyzer->analysis) { llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() << "\n"; } @@ -5948,7 +5939,7 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, assert(val->getType()); auto q = query(val).Data0(); if (!(val->getType()->isPointerTy() || q[{}] == BaseType::Pointer)) { - llvm::errs() << *analyzer.fntypeinfo.Function << "\n"; + llvm::errs() << *analyzer->fntypeinfo.Function << "\n"; dump(); llvm::errs() << "val: " << *val << "\n"; } @@ -5973,7 +5964,7 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, } if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) { - auto &res = analyzer; + auto &res = *analyzer; if (auto inst = dyn_cast(val)) { llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; llvm::errs() << *inst->getParent()->getParent() << "\n"; @@ -6006,15 +5997,15 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, << "\n"; } } - llvm::errs() << "fn: " << *analyzer.fntypeinfo.Function << "\n"; + llvm::errs() << "fn: " << *analyzer->fntypeinfo.Function << "\n"; dump(); llvm::errs() << "could not deduce type of integer " << *val << " num:" << num << " q:" << q.str() << " \n"; llvm::DiagnosticLocation loc = - analyzer.fntypeinfo.Function->getSubprogram(); + analyzer->fntypeinfo.Function->getSubprogram(); Instruction *codeLoc = - &*analyzer.fntypeinfo.Function->getEntryBlock().begin(); + &*analyzer->fntypeinfo.Function->getEntryBlock().begin(); if (auto inst = dyn_cast(val)) { loc = inst->getDebugLoc(); codeLoc = inst; @@ -6117,15 +6108,15 @@ TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, } Function *TypeResults::getFunction() const { - return analyzer.fntypeinfo.Function; + return analyzer->fntypeinfo.Function; } TypeTree TypeResults::getReturnAnalysis() const { - return analyzer.getReturnAnalysis(); 
+ return analyzer->getReturnAnalysis(); } std::set TypeResults::knownIntegralValues(Value *val) const { - return analyzer.knownIntegralValues(val); + return analyzer->knownIntegralValues(val); } std::set TypeAnalyzer::knownIntegralValues(Value *val) { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h index 6e036454143a..19c8e2d4d19f 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h @@ -156,9 +156,10 @@ class TypeAnalysis; /// on a given function class TypeResults { public: - TypeAnalyzer &analyzer; + TypeAnalyzer *analyzer; public: + TypeResults(std::nullptr_t); TypeResults(TypeAnalyzer &analyzer); ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound = true, bool pointerIntSame = false) const; @@ -258,6 +259,8 @@ class TypeAnalyzer : public llvm::InstVisitor { FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn); + TypeAnalyzer(TypeAnalysis &TA); + TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA, uint8_t direction = BOTH); From 026f8a4d26dad002f65e0d92dc9d71103a269290 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 23:04:34 -0500 Subject: [PATCH 108/131] Cleanup gutils to avoid null reference (#1739) --- enzyme/Enzyme/AdjointGenerator.h | 4 +- enzyme/Enzyme/CallDerivatives.cpp | 2 +- enzyme/Enzyme/EnzymeLogic.cpp | 26 ++--- enzyme/Enzyme/GradientUtils.cpp | 154 +++++++++++++++--------------- enzyme/Enzyme/GradientUtils.h | 10 +- 5 files changed, 98 insertions(+), 98 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 36b337329a2e..6a0af0d2b78b 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1542,7 +1542,7 @@ class AdjointGenerator : public llvm::InstVisitor { lc) && gutils->getNewFromOriginal(P0->getParent()) == lc.header) { SmallVector Latches; - gutils->OrigLI.getLoopFor(P0->getParent())->getLoopLatches(Latches); + 
gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (&SI != P0->getIncomingValueForBlock(Latch)) { @@ -2206,7 +2206,7 @@ class AdjointGenerator : public llvm::InstVisitor { lc) && gutils->getNewFromOriginal(P0->getParent()) == lc.header) { SmallVector Latches; - gutils->OrigLI.getLoopFor(P0->getParent())->getLoopLatches(Latches); + gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (&BO != P0->getIncomingValueForBlock(Latch)) { diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 155967e24a33..8c77a51f954d 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -3323,7 +3323,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( // rematerialization is loop level. This is because one can have a // loop level cache, but a function level allocation (e.g. for stack // allocas). If we deleted it here, we would have no allocation! - auto AllocationLoop = gutils->OrigLI.getLoopFor(call.getParent()); + auto AllocationLoop = gutils->OrigLI->getLoopFor(call.getParent()); // An allocation within a loop, must definitionally be a loop level // allocation (but not always the other way around. 
if (AllocationLoop) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 731826bb793f..120722dc7fb6 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -981,7 +981,7 @@ void calculateUnusedValuesInFunction( if (newMemory) { bool foundStore = false; allInstructionsBetween( - gutils->OrigLI, cast(at), + *gutils->OrigLI, cast(at), const_cast(mti), [&](Instruction *I) -> bool { if (!I->mayWriteToMemory()) @@ -994,7 +994,7 @@ void calculateUnusedValuesInFunction( } if (writesToMemoryReadBy( - gutils->OrigAA, TLI, + *gutils->OrigAA, TLI, /*maybeReader*/ const_cast(mti), /*maybeWriter*/ I)) { foundStore = true; @@ -1143,7 +1143,7 @@ void calculateUnusedStoresInFunction( if (newMemory) { bool foundStore = false; allInstructionsBetween( - gutils->OrigLI, cast(at), + *gutils->OrigLI, cast(at), const_cast(mti), [&](Instruction *I) -> bool { if (!I->mayWriteToMemory()) return /*earlyBreak*/ false; @@ -1152,7 +1152,7 @@ void calculateUnusedStoresInFunction( // if (I == &MTI) return; if (writesToMemoryReadBy( - gutils->OrigAA, TLI, + *gutils->OrigAA, TLI, /*maybeReader*/ const_cast(mti), /*maybeWriter*/ I)) { foundStore = true; @@ -1552,7 +1552,7 @@ bool legalCombinedForwardReverse( auto consider = [&](Instruction *user) { if (!user->mayReadFromMemory()) return false; - if (writesToMemoryReadBy(gutils->OrigAA, gutils->TLI, + if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI, /*maybeReader*/ user, /*maybeWriter*/ inst)) { @@ -1585,7 +1585,7 @@ bool legalCombinedForwardReverse( if (!post->mayWriteToMemory()) return false; - if (writesToMemoryReadBy(gutils->OrigAA, gutils->TLI, + if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI, /*maybeReader*/ inst, /*maybeWriter*/ post)) { if (EnzymePrintPerf) { @@ -2398,9 +2398,9 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, gutils->rematerializableAllocations, gutils->TR, - gutils->OrigAA, gutils->oldFunc, + 
*gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, DerivativeMode::ReverseModePrimal, omp); const std::map> overwritten_args_map = CA.compute_overwritten_args_for_callsites(); @@ -3346,7 +3346,7 @@ void createInvertedTerminator(DiffeGradientUtils *gutils, gutils->getNewFromOriginal(orig->getParent()) == loopContext.header && loopContext.exitBlocks.size() == 1) { SmallVector Latches; - gutils->OrigLI.getLoopFor(orig->getParent())->getLoopLatches(Latches); + gutils->OrigLI->getLoopFor(orig->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (activeUses[0] != orig->getIncomingValueForBlock(Latch)) { @@ -4080,9 +4080,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient( gutils->computeGuaranteedFrees(); CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, gutils->rematerializableAllocations, gutils->TR, - gutils->OrigAA, gutils->oldFunc, + *gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, key.mode, omp); const std::map> overwritten_args_map = (augmenteddata) ? 
augmenteddata->overwritten_args_map @@ -4734,10 +4734,10 @@ Function *EnzymeLogic::CreateForwardDiff( gutils->computeGuaranteedFrees(); CacheAnalysis CA( gutils->allocationsWithGuaranteedFree, - gutils->rematerializableAllocations, gutils->TR, gutils->OrigAA, + gutils->rematerializableAllocations, gutils->TR, *gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, mode, omp); const std::map> overwritten_args_map = CA.compute_overwritten_args_for_callsites(); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index d1dedbc70353..9025b890a2f3 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -169,19 +169,19 @@ GradientUtils::GradientUtils( : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), invertedPointers(), OrigDT(oldFunc_->empty() - ? *((DominatorTree *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((DominatorTree *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), OrigPDT(oldFunc_->empty() - ? *((PostDominatorTree *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((PostDominatorTree *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), OrigLI(oldFunc_->empty() - ? *((LoopInfo *)nullptr) - : Logic.PPC.FAM.getResult(*oldFunc_)), + ? ((LoopInfo *)nullptr) + : &Logic.PPC.FAM.getResult(*oldFunc_)), OrigSE(oldFunc_->empty() - ? *((ScalarEvolution *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((ScalarEvolution *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), notForAnalysis(getGuaranteedUnreachable(oldFunc_)), ATA(oldFunc_->empty() @@ -191,8 +191,8 @@ GradientUtils::GradientUtils( notForAnalysis, TLI_, constantvalues_, activevals_, ReturnActivity)), overwritten_args_map_ptr(nullptr), tid(nullptr), numThreads(nullptr), - OrigAA(oldFunc_->empty() ? 
*((AAResults *)nullptr) - : Logic.PPC.getAAResultsFromFunction(oldFunc_)), + OrigAA(oldFunc_->empty() ? ((AAResults *)nullptr) + : &Logic.PPC.getAAResultsFromFunction(oldFunc_)), TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_) { if (oldFunc_->empty()) return; @@ -242,7 +242,7 @@ GradientUtils::GradientUtils( for (BasicBlock &BB : *oldFunc) { bool legal = true; for (auto BRet : ReturningBlocks) { - if (!(BRet == &BB || OrigDT.dominates(&BB, BRet))) { + if (!(BRet == &BB || OrigDT->dominates(&BB, BRet))) { legal = false; break; } @@ -508,7 +508,7 @@ Value *GradientUtils::getOrInsertConditionalIndex(Value *val, LoopContext &lc, bool GradientUtils::assumeDynamicLoopOfSizeOne(Loop *L) const { if (!EnzymeInactiveDynamic) return false; - auto OL = OrigLI.getLoopFor(isOriginal(L->getHeader())); + auto OL = OrigLI->getLoopFor(isOriginal(L->getHeader())); assert(OL); for (auto OB : OL->getBlocks()) { for (auto &OI : *OB) { @@ -3788,7 +3788,7 @@ bool GradientUtils::legalRecompute(const Value *val, struct { Function *func; const LoopInfo &FLI; - } options[2] = {{newFunc, LI}, {oldFunc, OrigLI}}; + } options[2] = {{newFunc, LI}, {oldFunc, *OrigLI}}; for (const auto &tup : options) { if (parent->getParent() == tup.func) { for (auto &val : phi->incoming_values()) { @@ -3928,7 +3928,7 @@ bool GradientUtils::legalRecompute(const Value *val, const_cast(orig), [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && writesToMemoryReadBy( - OrigAA, TLI, + *OrigAA, TLI, /*maybeReader*/ const_cast(orig), /*maybeWriter*/ I)) { failed = true; @@ -3951,7 +3951,7 @@ bool GradientUtils::legalRecompute(const Value *val, } origStart = origStart->getNextNode(); } while (true); - if (OrigDT.dominates(origStart, const_cast(orig))) { + if (OrigDT->dominates(origStart, const_cast(orig))) { bool failed = false; allInstructionsBetween( @@ -3959,7 +3959,7 @@ bool GradientUtils::legalRecompute(const Value *val, const_cast(orig), [&](Instruction *I) -> bool { if 
(I->mayWriteToMemory() && writesToMemoryReadBy( - OrigAA, TLI, + *OrigAA, TLI, /*maybeReader*/ const_cast(orig), /*maybeWriter*/ I)) { failed = true; @@ -5290,7 +5290,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, if (F && isMemFreeLibMFunction(F->getName())) { continue; } - if (llvm::isModOrRefSet(OrigAA.getModRefInfo(CI, Loc))) { + if (llvm::isModOrRefSet(OrigAA->getModRefInfo(CI, Loc))) { seen = true; llvm::errs() << " cannot shadow-inline global " << *oval << " due to " << *CI << "\n"; @@ -6336,9 +6336,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // + or because the loop nests share no ancestry bool loopLegal = true; - for (Loop *idx = OrigLI.getLoopFor(orig); idx != nullptr; + for (Loop *idx = OrigLI->getLoopFor(orig); idx != nullptr; idx = idx->getParentLoop()) { - for (Loop *fdx = OrigLI.getLoopFor(forwardBlock); fdx != nullptr; + for (Loop *fdx = OrigLI->getLoopFor(forwardBlock); fdx != nullptr; fdx = fdx->getParentLoop()) { if (idx == fdx) { loopLegal = false; @@ -6542,9 +6542,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // << "\n"; allInstructionsBetween( - OrigLI, orig2, origInst, [&](Instruction *I) -> bool { + *OrigLI, orig2, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -6558,12 +6558,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (auto ar1 = dyn_cast(scev1)) { if (auto ar2 = dyn_cast(scev2)) { - if (ar1->getStart() != OrigSE.getCouldNotCompute() && + if (ar1->getStart() != OrigSE->getCouldNotCompute() && ar1->getStart() == ar2->getStart() && - ar1->getStepRecurrence(OrigSE) != - OrigSE.getCouldNotCompute() && - ar1->getStepRecurrence(OrigSE) == - ar2->getStepRecurrence(OrigSE)) { + ar1->getStepRecurrence(*OrigSE) != + OrigSE->getCouldNotCompute() && + 
ar1->getStepRecurrence(*OrigSE) == + ar2->getStepRecurrence(*OrigSE)) { LoopContext l1; getContext(ar1->getLoop()->getHeader(), l1); @@ -6591,7 +6591,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } } - auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); + auto scev1 = OrigSE->getSCEV(origInst->getPointerOperand()); auto Arch = llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); @@ -6599,7 +6599,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Arch == Triple::amdgcn ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local : 3; - if (EnzymeSharedForward && scev1 != OrigSE.getCouldNotCompute() && + if (EnzymeSharedForward && scev1 != OrigSE->getCouldNotCompute() && cast(orig_liobj->getType())->getAddressSpace() == SharedAddrSpace) { Value *resultValue = nullptr; @@ -6608,7 +6608,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, assert(pair.first->getType() == pair.second->getType()); newavail[pair.first] = pair.second; } - allDomPredecessorsOf(origInst, OrigDT, [&](Instruction *pred) { + allDomPredecessorsOf(origInst, *OrigDT, [&](Instruction *pred) { if (auto SI = dyn_cast(pred)) { // auto NewSI = cast(getNewFromOriginal(SI)); auto si2obj = getBaseObject(SI->getPointerOperand()); @@ -6619,10 +6619,10 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool lastStore = true; bool interveningSync = false; allInstructionsBetween( - OrigLI, SI, origInst, [&](Instruction *potentialAlias) { + *OrigLI, SI, origInst, [&](Instruction *potentialAlias) { if (!potentialAlias->mayWriteToMemory()) return false; - if (!writesToMemoryReadBy(OrigAA, TLI, origInst, + if (!writesToMemoryReadBy(*OrigAA, TLI, origInst, potentialAlias)) return false; @@ -6640,7 +6640,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (mid == SI) return false; - if (!writesToMemoryReadBy(OrigAA, TLI, origInst, + if (!writesToMemoryReadBy(*OrigAA, TLI, origInst, mid)) { return false; } @@ 
-6667,16 +6667,16 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (!lastStore) return false; - auto scev2 = OrigSE.getSCEV(SI->getPointerOperand()); + auto scev2 = OrigSE->getSCEV(SI->getPointerOperand()); bool legal = scev1 == scev2; if (auto ar2 = dyn_cast(scev2)) { if (auto ar1 = dyn_cast(scev1)) { - if (ar2->getStart() != OrigSE.getCouldNotCompute() && + if (ar2->getStart() != OrigSE->getCouldNotCompute() && ar1->getStart() == ar2->getStart() && - ar2->getStepRecurrence(OrigSE) != - OrigSE.getCouldNotCompute() && - ar1->getStepRecurrence(OrigSE) == - ar2->getStepRecurrence(OrigSE)) { + ar2->getStepRecurrence(*OrigSE) != + OrigSE->getCouldNotCompute() && + ar1->getStepRecurrence(*OrigSE) == + ar2->getStepRecurrence(*OrigSE)) { LoopContext l1; getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), @@ -6729,15 +6729,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, ValueToValueMapTy ThreadLookup; bool legal = true; for (size_t i = 0; i < svals.size(); i++) { - auto ss = OrigSE.getSCEV(svals[i]); - auto ls = OrigSE.getSCEV(lvals[i]); + auto ss = OrigSE->getSCEV(svals[i]); + auto ls = OrigSE->getSCEV(lvals[i]); if (cast(ss->getType())->getBitWidth() > cast(ls->getType())->getBitWidth()) { - ls = OrigSE.getZeroExtendExpr(ls, ss->getType()); + ls = OrigSE->getZeroExtendExpr(ls, ss->getType()); } if (cast(ss->getType())->getBitWidth() < cast(ls->getType())->getBitWidth()) { - ls = OrigSE.getTruncateExpr(ls, ss->getType()); + ls = OrigSE->getTruncateExpr(ls, ss->getType()); } if (ls != ss) { if (auto II = dyn_cast(svals[i])) { @@ -6824,7 +6824,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, auto origPH = cast_or_null(isOriginal(ctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { goto noSpeedCache; } @@ -6837,10 +6837,10 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; 
allInstructionsBetween( - OrigLI, &*origTerm, origInst, + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ tmpload, /*maybeWriter*/ I)) { failed = true; @@ -6858,15 +6858,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; auto origPH = cast_or_null(isOriginal(nctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { break; } Instruction *origTerm = origPH->getTerminator(); allInstructionsBetween( - OrigLI, &*origTerm, origInst, + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ tmpload, /*maybeWriter*/ I)) { failed = true; @@ -6958,7 +6958,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } } - auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); + auto scev1 = OrigSE->getSCEV(origInst->getPointerOperand()); // Store in memcpy opt Value *lim = nullptr; BasicBlock *ctx = nullptr; @@ -6966,7 +6966,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Value *offset = nullptr; if (auto ar1 = dyn_cast(scev1)) { if (auto step = - dyn_cast(ar1->getStepRecurrence(OrigSE))) { + dyn_cast(ar1->getStepRecurrence(*OrigSE))) { if (step->getAPInt() != loadSize) goto noSpeedCache; @@ -6983,7 +6983,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, auto origPH = cast_or_null(isOriginal(ctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { goto noSpeedCache; } @@ -7002,7 +7002,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, SmallVector InsertedInstructions; { SCEVExpander OrigExp( - OrigSE, ctx->getParent()->getParent()->getDataLayout(), + 
*OrigSE, ctx->getParent()->getParent()->getDataLayout(), "enzyme"); OrigExp.setInsertPoint( @@ -7023,7 +7023,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // instructions. llvm::stable_sort(InsertedInstructions, [this](Instruction *A, Instruction *B) { - return OrigDT.dominates(A, B); + return OrigDT->dominates(A, B); }); for (auto a : InsertedInstructions) { assert(!isa(a)); @@ -7054,7 +7054,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, available.clear(); for (auto I : llvm::reverse(InsertedInstructions)) { assert(I->getNumUses() == 0); - OrigSE.forgetValue(I); + OrigSE->forgetValue(I); I->eraseFromParent(); } #endif @@ -7067,9 +7067,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; allInstructionsBetween( - OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -7091,14 +7091,14 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; auto origPH = cast_or_null(isOriginal(nctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { break; } Instruction *origTerm = origPH->getTerminator(); allInstructionsBetween( - OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -7891,7 +7891,7 @@ void GradientUtils::computeMinCache() { for (BasicBlock &BB : *oldFunc) { if (notForAnalysis.count(&BB)) continue; - auto L = OrigLI.getLoopFor(&BB); + auto L = OrigLI->getLoopFor(&BB); auto invariant = [&](Value *V) { 
if (isa(V)) @@ -7899,20 +7899,20 @@ void GradientUtils::computeMinCache() { if (isa(V)) return true; if (auto I = dyn_cast(V)) { - if (!L->contains(OrigLI.getLoopFor(I->getParent()))) + if (!L->contains(OrigLI->getLoopFor(I->getParent()))) return true; } return false; }; for (Instruction &I : BB) { if (auto PN = dyn_cast(&I)) { - if (!OrigLI.isLoopHeader(&BB)) + if (!OrigLI->isLoopHeader(&BB)) continue; if (PN->getType()->isIntegerTy()) { bool legal = true; SmallPtrSet Increment; for (auto B : PN->blocks()) { - if (OrigLI.getLoopFor(B) == L) { + if (OrigLI->getLoopFor(B) == L) { if (auto BO = dyn_cast( PN->getIncomingValueForBlock(B))) { if (BO->getOpcode() == BinaryOperator::Add) { @@ -8000,7 +8000,7 @@ void GradientUtils::computeMinCache() { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(&BB); L != nullptr; + for (Loop *L = OrigLI->getLoopFor(&BB); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8042,7 +8042,7 @@ void GradientUtils::computeMinCache() { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(cast(V)->getParent()); + for (Loop *L = OrigLI->getLoopFor(cast(V)->getParent()); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8068,8 +8068,8 @@ void GradientUtils::computeMinCache() { SetVector MinReq; DifferentialUseAnalysis::minCut(oldFunc->getParent()->getDataLayout(), - OrigLI, Recomputes, Intermediates, Required, - MinReq, this, TLI); + *OrigLI, Recomputes, Intermediates, + Required, MinReq, this, TLI); SmallPtrSet NeedGraph; for (Value *V : MinReq) NeedGraph.insert(V); @@ -8098,7 +8098,7 @@ void GradientUtils::computeMinCache() { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(cast(V)->getParent()); + for (Loop *L = 
OrigLI->getLoopFor(cast(V)->getParent()); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8746,13 +8746,13 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { } // Find the outermost loop of all stores, and the allocation/lifetime - Loop *outer = OrigLI.getLoopFor(V->getParent()); + Loop *outer = OrigLI->getLoopFor(V->getParent()); if (LifetimeStarts.size() == 1) { - outer = OrigLI.getLoopFor((*LifetimeStarts.begin())->getParent()); + outer = OrigLI->getLoopFor((*LifetimeStarts.begin())->getParent()); } for (auto S : stores) { - outer = getAncestor(outer, OrigLI.getLoopFor(S->getParent())); + outer = getAncestor(outer, OrigLI->getLoopFor(S->getParent())); } // May now read pointers for storing into other pointers. Therefore we @@ -8766,8 +8766,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI, res, - outer)) { + if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, LI, + res, outer)) { EmitWarning("NotPromotable", *LI, " Could not promote shadow allocation ", *V, " due to pointer load ", *LI, @@ -8821,7 +8821,7 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI, res, + if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, LI, res, outer)) { EmitWarning("NotPromotable", *LI, " Could not promote allocation ", *V, " due to load ", *LI, @@ -8837,8 +8837,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI.loadCall, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI.loadCall, - res, outer)) { + if 
(overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, + LI.loadCall, res, outer)) { EmitWarning("NotPromotable", *LI.loadCall, " Could not promote allocation ", *V, " due to load-like call ", *LI.loadCall, @@ -9049,7 +9049,7 @@ void GradientUtils::computeGuaranteedFrees() { bool hasPDFree = false; if (dc->getParent() == CI->getParent() || - OrigPDT.dominates(CI->getParent(), dc->getParent())) { + OrigPDT->dominates(CI->getParent(), dc->getParent())) { hasPDFree = true; } diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 6a8c9fd61b9e..7552bec38dab 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -128,10 +128,10 @@ class GradientUtils : public CacheUtility { DerivativeMode mode; llvm::Function *oldFunc; llvm::ValueMap invertedPointers; - llvm::DominatorTree &OrigDT; - llvm::PostDominatorTree &OrigPDT; - llvm::LoopInfo &OrigLI; - llvm::ScalarEvolution &OrigSE; + llvm::DominatorTree *OrigDT; + llvm::PostDominatorTree *OrigPDT; + llvm::LoopInfo *OrigLI; + llvm::ScalarEvolution *OrigSE; /// (Original) Blocks which dominate all returns llvm::SmallPtrSet BlocksDominatingAllReturns; @@ -353,7 +353,7 @@ class GradientUtils : public CacheUtility { } public: - llvm::AAResults &OrigAA; + llvm::AAResults *OrigAA; TypeAnalysis &TA; TypeResults TR; bool omp; From cc7ef6f2c885301d49e3ad452d1be018e9974be5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 23:13:29 -0500 Subject: [PATCH 109/131] Fix integration test if not built with compiler-rt (#1740) --- enzyme/test/Integration/ReverseMode/dbginfo.c | 15 +++++++++++++++ enzyme/test/Integration/ReverseMode/taylorlog.c | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/enzyme/test/Integration/ReverseMode/dbginfo.c b/enzyme/test/Integration/ReverseMode/dbginfo.c index 31d9bafe4439..b094dffb5c6d 100644 --- a/enzyme/test/Integration/ReverseMode/dbginfo.c +++ b/enzyme/test/Integration/ReverseMode/dbginfo.c @@ -13,6 +13,21 @@ double 
__enzyme_autodiff(void*, double, unsigned); +// May be needed if not built with compiler-rt +double __powidf2(double a, int b) { + const int recip = b < 0; + double r = 1; + while (1) { + if (b & 1) + r *= a; + b /= 2; + if (b == 0) + break; + a *= a; + } + return recip ? 1 / r : r; +} + static double taylorlog(double x, unsigned SINCOSN) { double sum = 0; for(int i=1; i<=SINCOSN; i++) { diff --git a/enzyme/test/Integration/ReverseMode/taylorlog.c b/enzyme/test/Integration/ReverseMode/taylorlog.c index fdecb8aac7af..522928e3f60f 100644 --- a/enzyme/test/Integration/ReverseMode/taylorlog.c +++ b/enzyme/test/Integration/ReverseMode/taylorlog.c @@ -13,6 +13,21 @@ double __enzyme_autodiff(void*, double, unsigned); +// May be needed if not built with compiler-rt +double __powidf2(double a, int b) { + const int recip = b < 0; + double r = 1; + while (1) { + if (b & 1) + r *= a; + b /= 2; + if (b == 0) + break; + a *= a; + } + return recip ? 1 / r : r; +} + static double taylorlog(double x, unsigned SINCOSN) { double sum = 0; for(int i=1; i<=SINCOSN; i++) { From b0cee973864ca6a6bbf6e6a7fe96cab00a7f7550 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 19 Feb 2024 18:55:08 -0500 Subject: [PATCH 110/131] Test assert-less build (#1742) --- enzyme/CMakeLists.txt | 2 +- enzyme/Enzyme/ActivityAnalysis.cpp | 7 ++++-- enzyme/Enzyme/AdjointGenerator.h | 4 +++ enzyme/Enzyme/CApi.cpp | 2 ++ enzyme/Enzyme/CacheUtility.cpp | 3 +++ enzyme/Enzyme/CallDerivatives.cpp | 1 + enzyme/Enzyme/Clang/EnzymeClang.cpp | 4 +-- enzyme/Enzyme/DiffeGradientUtils.cpp | 9 +++++++ enzyme/Enzyme/DifferentialUseAnalysis.cpp | 7 ++++-- enzyme/Enzyme/EnzymeLogic.cpp | 2 ++ enzyme/Enzyme/FunctionUtils.cpp | 2 ++ enzyme/Enzyme/GradientUtils.cpp | 25 ++++++++++++++++--- enzyme/Enzyme/GradientUtils.h | 2 ++ enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 3 +-- enzyme/Enzyme/TypeAnalysis/BaseType.h | 2 ++ enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp | 5 ++-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 5 
++++ enzyme/Enzyme/TypeAnalysis/TypeTree.h | 2 ++ enzyme/Enzyme/Utils.h | 9 +++++-- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 2 ++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 4 +-- 21 files changed, 84 insertions(+), 18 deletions(-) diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 13a9b4973c54..32d09c19e302 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -14,7 +14,7 @@ add_definitions(-DENZYME_VERSION_MINOR=${ENZYME_MINOR_VERSION}) add_definitions(-DENZYME_VERSION_PATCH=${ENZYME_PATCH_VERSION}) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else") +SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull") SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -ggdb") SET(CMAKE_CXX_FLAGS_RELEASE "-O2") diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 243e71dec04c..274c2fcea183 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1049,9 +1049,11 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } assert(TR.getFunction() == I->getParent()->getParent()); } +#ifndef NDEBUG if (auto Arg = dyn_cast(Val)) { assert(TR.getFunction() == Arg->getParent()); } +#endif // Void values are definitionally inactive if (Val->getType()->isVoidTy()) @@ -2305,6 +2307,9 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { // this value is inactive, we are inactive Since we won't look at uses to // prove, we can inductively assume this is inactive if (directions & UP) { + if (!UpHypothesis) + UpHypothesis = + std::shared_ptr(new ActivityAnalyzer(*this, UP)); if (directions == UP && !isa(Val)) { if (isInstructionInactiveFromOrigin(TR, Val, true)) { InsertConstantValue(TR, Val); @@ -2320,8 +2325,6 @@ bool 
ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } } } else { - UpHypothesis = - std::shared_ptr(new ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantValues.insert(Val); if (UpHypothesis->isInstructionInactiveFromOrigin(TR, Val, true)) { insertConstantsFrom(TR, *UpHypothesis); diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 6a0af0d2b78b..dbcd0947d342 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -408,6 +408,7 @@ class AdjointGenerator : public llvm::InstVisitor { constantval |= gutils->isConstantValue(&I); Type *type = gutils->getShadowType(I.getType()); + (void)type; auto *newi = dyn_cast(gutils->getNewFromOriginal(&I)); @@ -621,6 +622,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (primalNeededInReverse) { inst = gutils->cacheForReverse(BuilderZ, newi, getIndex(&I, CacheType::Self, BuilderZ)); + (void)inst; assert(inst->getType() == type); if (Mode == DerivativeMode::ReverseModeGradient || @@ -3777,6 +3779,7 @@ class AdjointGenerator : public llvm::InstVisitor { setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), Builder2); } + (void)vdiff; switch (ID) { @@ -5201,6 +5204,7 @@ class AdjointGenerator : public llvm::InstVisitor { // Note sometimes whattype mistakenly says something should be // constant [because composed of integer pointers alone] + (void)argType; assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG || whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); } else { diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 006965496f43..8c1c295b54f6 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -1184,6 +1184,7 @@ LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC, for (auto s : sub) { uint64_t ival; bool b = s.getAsInteger(10, ival); + (void)b; assert(!b); previdx.push_back(ival); } @@ -1241,6 +1242,7 @@ LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, 
LLVMValueRef V_r, APInt Offset(width, 0); bool success = collectOffset(cast(gep), DL, width, VariableOffsets, Offset); + (void)success; assert(success); Value *start = ConstantInt::get(T, Offset); for (auto &pair : VariableOffsets) diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp index 68c6e46784cf..cba31f314000 100644 --- a/enzyme/Enzyme/CacheUtility.cpp +++ b/enzyme/Enzyme/CacheUtility.cpp @@ -200,6 +200,7 @@ std::pair FindCanonicalIV(Loop *L, Type *Ty) { } llvm::errs() << *Header << "\n"; assert(0 && "Could not find canonical IV"); + return std::pair(nullptr, nullptr); } // Attempt to rewrite all phinode's in the loop in terms of the @@ -1330,8 +1331,10 @@ void CacheUtility::storeInstructionInCache(LimitContext ctx, IRBuilder<> &BuilderM, Value *val, AllocaInst *cache, MDNode *TBAA) { assert(BuilderM.GetInsertBlock()->getParent() == newFunc); +#ifndef NDEBUG if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == newFunc); +#endif IRBuilder<> v(BuilderM.GetInsertBlock()); v.SetInsertPoint(BuilderM.GetInsertBlock(), BuilderM.GetInsertPoint()); v.setFastMathFlags(getFast()); diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 8c77a51f954d..6280dc9f1b44 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -4047,6 +4047,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( return true; } assert(!unnecessaryValues.count(rmat.first)); + (void)primalNeededInReverse; assert(primalNeededInReverse); } } diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index ed01f1bf5739..a34a6429dcf7 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -141,10 +141,10 @@ class EnzymePlugin final : public clang::ASTConsumer { using namespace clang; DeclGroupRef::iterator it; - Visitor v(CI); + // Visitor v(CI); // Forcibly require emission of all libdevice for (it = dg.begin(); it != dg.end(); 
++it) { - v.TraverseDecl(*it); + // v.TraverseDecl(*it); if (auto FD = dyn_cast(*it)) { if (!FD->hasAttr()) continue; diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 43a14051c4fd..552d2894d759 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -164,10 +164,12 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( AllocaInst *DiffeGradientUtils::getDifferential(Value *val) { assert(val); +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif assert(inversionAllocs); Type *type = getShadowType(val->getType()); @@ -195,10 +197,12 @@ AllocaInst *DiffeGradientUtils::getDifferential(Value *val) { } Value *DiffeGradientUtils::diffe(Value *val, IRBuilder<> &BuilderM) { +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif if (isConstantValue(val)) { llvm::errs() << *newFunc << "\n"; @@ -336,6 +340,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, llvm::errs() << "} start=" << start << " size=" << size << " storeSize=" << storeSize << " val=" << *val << "\n"; assert(0 && "unhandled accumulate with partial sizes"); + return {}; } SmallVector @@ -345,10 +350,12 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, assert(mode == DerivativeMode::ReverseModeGradient || mode == DerivativeMode::ReverseModeCombined); +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif SmallVector addedSelects; @@ -659,6 +666,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, void DiffeGradientUtils::setDiffe(Value *val, Value *toset, IRBuilder<> &BuilderM) { 
+#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) @@ -668,6 +676,7 @@ void DiffeGradientUtils::setDiffe(Value *val, Value *toset, llvm::errs() << *val << "\n"; } assert(!isConstantValue(val)); +#endif toset = SanitizeDerivatives(val, toset, BuilderM); if (mode == DerivativeMode::ForwardMode || mode == DerivativeMode::ForwardModeSplit) { diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 8f2e27b3140e..b0b8c48ab5c9 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -54,9 +54,11 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( const SmallPtrSetImpl &oldUnreachable, QueryType qtype, bool *recursiveUse) { TypeResults const &TR = gutils->TR; +#ifndef NDEBUG if (auto ainst = dyn_cast(val)) { assert(ainst->getParent()->getParent() == gutils->oldFunc); } +#endif bool shadow = qtype == QueryType::Shadow || qtype == QueryType::ShadowByConstPrimal; @@ -79,8 +81,7 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( if (!user) { if (EnzymePrintDiffUse) - llvm::errs() << " Need: of " << *val << " in reverse as unknown user " - << *user << "\n"; + llvm::errs() << " Need: of " << *val << " in reverse as nullptr user\n"; return true; } @@ -794,12 +795,14 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, } } } +#ifndef NDEBUG for (auto R : Required) { assert(Intermediates.count(R)); } for (auto R : Recomputes) { assert(Intermediates.count(R)); } +#endif Graph Orig = G; diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 120722dc7fb6..0767b3e52529 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3632,6 +3632,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( if (hasMetadata(key.todiff, "enzyme_gradient")) { std::set seen; +#ifndef NDEBUG DIFFE_TYPE subretType = 
whatType(key.todiff->getReturnType(), DerivativeMode::ReverseModeGradient, /*intAreConstant*/ false, seen); @@ -3639,6 +3640,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( key.todiff->getReturnType()->isEmptyTy()) subretType = DIFFE_TYPE::CONSTANT; assert(subretType == key.retType); +#endif if (key.mode == DerivativeMode::ReverseModeCombined) { auto res = getDefaultFunctionTypeForGradient( diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index fdba018bd600..218ca9549dfc 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -6295,6 +6295,7 @@ class Constraints : public std::enable_shared_from_this { assert(t != Type::None); assert(c.size() != 0); assert(c.size() != 1); +#ifndef NDEBUG SmallVector tmp(c.begin(), c.end()); for (unsigned i = 0; i < tmp.size(); i++) for (unsigned j = 0; j < i; j++) @@ -6317,6 +6318,7 @@ class Constraints : public std::enable_shared_from_this { if (auto s = dyn_cast(tmp[j]->node)) assert(s->getLoop() != tmp[i]->Loop); } +#endif } bool operator==(const Constraints &rhs) const { diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 9025b890a2f3..57f21d8d9785 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -607,12 +607,14 @@ BasicBlock *GradientUtils::getOriginalFromNew(const BasicBlock *newinst) const { Value *GradientUtils::isOriginal(const Value *newinst) const { if (isa(newinst) || isa(newinst)) return const_cast(newinst); +#ifndef NDEBUG if (auto arg = dyn_cast(newinst)) { assert(arg->getParent() == newFunc); } if (auto inst = dyn_cast(newinst)) { assert(inst->getParent()->getParent() == newFunc); } +#endif auto found = newToOriginalFn.find(newinst); if (found == newToOriginalFn.end()) return nullptr; @@ -2519,11 +2521,13 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, return malloc; } +#ifndef NDEBUG if (auto CI = dyn_cast(malloc)) { if (auto F = 
CI->getCalledFunction()) { assert(F->getName() != "omp_get_thread_num"); } } +#endif if (malloc->getType()->isTokenTy()) { llvm::errs() << " oldFunc: " << *oldFunc << "\n"; @@ -3425,8 +3429,9 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { lctx, placeholder->getType(), placeholder->getName(), /*shouldFree*/ true); assert(cache); + Value *placeholder_tmp = placeholder; found = insert_or_assign( - scopeMap, (Value *&)placeholder, + scopeMap, placeholder_tmp, std::pair, LimitContext>(cache, lctx)); } auto cache = found->second.first; @@ -4754,12 +4759,14 @@ void GradientUtils::setPtrDiffe(Instruction *orig, Value *ptr, Value *newval, SyncScope::ID syncScope, Value *mask, ArrayRef noAlias, ArrayRef scopes) { +#ifndef NDEBUG if (auto inst = dyn_cast(ptr)) { assert(inst->getParent()->getParent() == oldFunc); } if (auto arg = dyn_cast(ptr)) { assert(arg->getParent() == oldFunc); } +#endif Value *origptr = ptr; @@ -5034,12 +5041,14 @@ llvm::Value *GradientUtils::recursiveFAdd(llvm::IRBuilder<> &B, Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, bool nullShadow) { assert(oval); +#ifndef NDEBUG if (auto inst = dyn_cast(oval)) { assert(inst->getParent()->getParent() == oldFunc); } if (auto arg = dyn_cast(oval)) { assert(arg->getParent() == oldFunc); } +#endif if (isa(oval)) { return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); @@ -6830,7 +6839,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Instruction *origTerm = origPH->getTerminator(); if (!origTerm) - llvm::errs() << *origTerm << "\n"; + llvm::errs() << *origPH << "\n"; assert(origTerm); IRBuilder<> OB(origTerm); LoadInst *tmpload = OB.CreateLoad(AT, orig_liobj, "'tmpload"); @@ -7031,6 +7040,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap, /*scope*/ nullptr, /*cache*/ false)); assert(uw->getType() == a->getType()); +#ifndef NDEBUG for (size_t 
i = 0; i < uw->getNumOperands(); i++) { auto op = uw->getOperand(i); if (auto arg = dyn_cast(op)) @@ -7038,6 +7048,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, else if (auto inst = dyn_cast(op)) assert(inst->getParent()->getParent() == newFunc); } +#endif available[a] = uw; unwrappedLoads.erase(cast(uw)); } @@ -7259,7 +7270,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, AllocaInst *cache = createCacheForScope( lctx, inst->getType(), inst->getName(), /*shouldFree*/ true); assert(cache); - insert_or_assign(scopeMap, (Value *&)inst, + Value *inst_tmp = inst; + insert_or_assign(scopeMap, inst_tmp, std::pair, LimitContext>( cache, lctx)); } @@ -7333,9 +7345,11 @@ void GradientUtils::branchToCorrespondingTarget( if (replacePHIs->size() == 0) return; +#ifndef NDEBUG for (auto x : *replacePHIs) { assert(targetToPreds.find(x.first) != targetToPreds.end()); } +#endif } if (targetToPreds.size() == 1) { @@ -8177,11 +8191,13 @@ void GradientUtils::forceActiveDetection() { bool GradientUtils::isConstantValue(Value *val) const { if (auto inst = dyn_cast(val)) { + (void)inst; assert(inst->getParent()->getParent() == oldFunc); return ATA->isConstantValue(TR, val); } if (auto arg = dyn_cast(val)) { + (void)arg; assert(arg->getParent() == oldFunc); return ATA->isConstantValue(TR, val); } @@ -8892,6 +8908,7 @@ void GradientUtils::replaceAWithB(Value *A, Value *B, bool storeInCache) { // Check that the replacement doesn't already exist in the mapping // thereby resulting in a conflict. 
+#ifndef NDEBUG { auto found = newToOriginalFn.find(A); if (found != newToOriginalFn.end()) { @@ -8899,6 +8916,7 @@ void GradientUtils::replaceAWithB(Value *A, Value *B, bool storeInCache) { assert(foundB == newToOriginalFn.end()); } } +#endif CacheUtility::replaceAWithB(A, B, storeInCache); } @@ -9167,6 +9185,7 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, libfunc = LibFunc_malloc; } else { bool res = TLI.getLibFunc(allocationfn, libfunc); + (void)res; assert(res && "ought find known allocation fn"); } diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 7552bec38dab..98cae4145639 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -601,11 +601,13 @@ class GradientUtils : public CacheUtility { llvm::ArrayRef diffs, llvm::IRBuilder<> &Builder, Func rule) { if (width > 1) { +#ifndef NDEBUG for (auto diff : diffs) { assert(diff); assert(llvm::cast(diff->getType())->getNumElements() == width); } +#endif llvm::Type *wrappedType = llvm::ArrayType::get(diffType, width); llvm::Value *res = llvm::UndefValue::get(wrappedType); for (unsigned int i = 0; i < getWidth(); ++i) { diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 3025b5963fcf..7d4ccead7764 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -199,8 +199,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( continue; } - auto term = oBB.getTerminator(); - assert(term); + assert(oBB.getTerminator()); auto first = oBB.begin(); auto last = oBB.empty() ? 
oBB.end() : std::prev(oBB.end()); diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index 2f5275c4b6f8..71d6e0910408 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -57,6 +57,7 @@ static inline std::string to_string(BaseType t) { return "Unknown"; } assert(0 && "unknown inttype"); + return ""; } /// Convert string to BaseType @@ -72,5 +73,6 @@ static inline BaseType parseBaseType(llvm::StringRef str) { if (str == "Unknown") return BaseType::Unknown; assert(0 && "Unknown BaseType string"); + return BaseType::Unknown; } #endif diff --git a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp b/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp index 899ce46d9f9e..2376c9b23353 100644 --- a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp +++ b/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp @@ -83,7 +83,6 @@ TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) { assert(0 && "There shouldn't be non-constant-size arrays in Rust"); } } - return Result; } else if (Type.getTag() == dwarf::DW_TAG_structure_type || Type.getTag() == dwarf::DW_TAG_union_type) { DINodeArray Elements = Type.getElements(); @@ -108,11 +107,11 @@ TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) { firstSubTT = !firstSubTT; } } - return Result; } else { assert(0 && "Composite types other than arrays, structs and unions are not " "supported by Rust debug info parser"); } + return Result; } TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) { @@ -134,6 +133,7 @@ TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) { assert(0 && "Derived types other than pointers and members are not " "supported by Rust debug info parser"); } + return {}; } TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) { @@ -151,6 +151,7 @@ TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) { assert(0 && "Types other than floating-points, 
integers, arrays, pointers, " "slices, and structs are not supported by debug info parser"); } + return {}; } bool isU8PointerType(DIType &type) { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 3204b5667086..86a5243c438a 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -1860,6 +1860,7 @@ void TypeAnalyzer::visitGEPOperator(GEPOperator &gep) { MapVector VariableOffsets; bool legalOffset = collectOffset(&gep, DL, BitWidth, VariableOffsets, constOffset); + (void)legalOffset; assert(legalOffset); SmallVector, 4> idnext; @@ -5816,12 +5817,14 @@ FnTypeInfo TypeResults::getCallInfo(CallBase &CI, Function &fn) const { } TypeTree TypeResults::query(Value *val) const { +#ifndef NDEBUG if (auto inst = dyn_cast(val)) { assert(inst->getParent()->getParent() == analyzer->fntypeinfo.Function); } if (auto arg = dyn_cast(val)) { assert(arg->getParent() == analyzer->fntypeinfo.Function); } +#endif return analyzer->getAnalysis(val); } @@ -5989,8 +5992,10 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, if (auto arg = dyn_cast(val)) { llvm::errs() << *arg->getParent() << "\n"; for (auto &pair : res.analysis) { +#ifndef NDEBUG if (auto in = dyn_cast(pair.first)) assert(in->getParent()->getParent() == arg->getParent()); +#endif llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() << " int: " + to_string(res.knownIntegralValues(pair.first)) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index b69b83e5b355..3843623ca32a 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -343,11 +343,13 @@ class TypeTree : public std::enable_shared_from_this { /// Whether this TypeTree contains any information bool isKnown() const { +#ifndef NDEBUG for (const auto &pair : mapping) { // we should assert here as we shouldn't keep any unknown maps for 
// efficiency assert(pair.second.isKnown()); } +#endif return mapping.size() != 0; } diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index f4f72f8cf52e..37edfc1a985d 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1158,6 +1158,7 @@ static inline llvm::Optional getAllocationIndexFromCall(T *op) bool b = AttrList.getAttribute("enzyme_allocator") .getValueAsString() .getAsInteger(10, res); + (void)b; assert(!b); #if LLVM_VERSION_MAJOR >= 16 return std::optional(res); @@ -1172,6 +1173,7 @@ static inline llvm::Optional getAllocationIndexFromCall(T *op) bool b = called->getFnAttribute("enzyme_allocator") .getValueAsString() .getAsInteger(10, res); + (void)b; assert(!b); #if LLVM_VERSION_MAJOR >= 16 return std::optional(res); @@ -1228,6 +1230,7 @@ static inline std::vector getDeallocationIndicesFromCall(T *op) { for (auto ind : inds) { ssize_t Result; bool b = ind.getAsInteger(10, Result); + (void)b; assert(!b); vinds.push_back(Result); } @@ -1355,10 +1358,11 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) { auto AttrList = Call->getAttributes().getAttributes( llvm::AttributeList::FunctionIndex); if (AttrList.hasAttribute("enzyme_pointermath")) { - size_t res; + size_t res = 0; bool failed = AttrList.getAttribute("enzyme_pointermath") .getValueAsString() .getAsInteger(10, res); + (void)failed; assert(!failed); V = Call->getArgOperand(res); continue; @@ -1386,10 +1390,11 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) { auto AttrList = fn->getAttributes().getAttributes( llvm::AttributeList::FunctionIndex); if (AttrList.hasAttribute("enzyme_pointermath")) { - size_t res; + size_t res = 0; bool failed = AttrList.getAttribute("enzyme_pointermath") .getValueAsString() .getAsInteger(10, res); + (void)failed; assert(!failed); V = Call->getArgOperand(res); continue; diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 39cb55384ad9..3689c77d4d74 100644 --- 
a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -444,6 +444,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) { break; } } + (void)hasInt; assert(hasInt); os << " Type* cublas_retty = nullptr;\n" @@ -482,6 +483,7 @@ void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) { break; } } + (void)foundInt; assert(foundInt && "no int type found in blas call"); os << " // fpType already given by blas type (s, d, c, z) \n" diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 84dce5f2a08c..04f63e1ad3ac 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1340,13 +1340,13 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, StringRef name = cast(lst->getValues()[0])->getValue(); if (lst->size() >= 2) { auto min = cast(lst->getValues()[1])->getValue(); - int min_int; + int min_int = 100000; min.getAsInteger(10, min_int); if (min.size() != 0 && LLVM_VERSION_MAJOR < min_int) continue; if (lst->size() >= 3) { auto max = cast(lst->getValues()[2])->getValue(); - int max_int; + int max_int = 0; max.getAsInteger(10, max_int); if (max.size() != 0 && LLVM_VERSION_MAJOR > max_int) continue; From d04f10346fc75a58eb96604db62de53d07fab6fe Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 19 Feb 2024 20:14:07 -0500 Subject: [PATCH 111/131] Add ResultTypes tablegen --- enzyme/Enzyme/MLIR/Implementations/Common.td | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 83b89f7cf0a4..33d9f12c2f37 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -87,6 +87,8 @@ class ConstantFP : Ope string type = type_; } +def ResultTypes : GlobalExprgetResultTypes()">; + class ArithInst : Inst; class MathInst : Inst; From 
1ea2e0be2db660ee54bde212194f7446b5479e8d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 20 Feb 2024 09:31:20 -0500 Subject: [PATCH 112/131] TypeTree speedup shiftindicies (#1744) --- enzyme/Enzyme/TypeAnalysis/TypeTree.h | 173 ++++++++++++++++++++++---- 1 file changed, 152 insertions(+), 21 deletions(-) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index 3843623ca32a..2cb4a7506c94 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -724,17 +724,51 @@ class TypeTree : public std::enable_shared_from_this { } /// Replace mappings in the range in [offset, offset+maxSize] with those in - // [addOffset, addOffset + maxSize]. In other worse, select all mappings in + // [addOffset, addOffset + maxSize]. In other words, select all mappings in // [offset, offset+maxSize] then add `addOffset` TypeTree ShiftIndices(const llvm::DataLayout &dl, const int offset, const int maxSize, size_t addOffset = 0) const { + + // If we have no terms 1+ layer deep return the current result as a shift + // won't change anything. This also makes the latercode simpler as it + // can assume at least a first index exists. + if (minIndices.size() == 0) + return *this; + + // If we have no size in return, simply return an empty type tree. Again + // this simplifies later code which can assume that a minus one expantion + // will always result in an added variable (which would not be the case + // on a size == 0). + if (maxSize == 0) + return TypeTree(); + TypeTree Result; + // The normal orIn / insert methods do collision checking, which is slow + // (and presently O(n)). This is because an expansion of a -1 which could + // conflict with a fixed value. 
Consider calling this + // ShiftIndicies(offset=0, maxSize=2, addOffset=0, tt={[-1]:Integer, + // [1]:Anything}) the -1 would expand to [0]:Int, [1]:Int, which would need + // to be merged with [1]:Anything + // + // The only possible values which can cause a conflict are minus -1's. + // As a result, we start with a fast insertion (aka without check) of + // non-expanded values, since they just do a literal shift which needs no + // extra checking, besides bounds checks. + // + // Since we're doing things manually, we also need to manually preserve TT + // invariants. Specifically, TT limits all values to have offsets < + // MAX_OFFSET, unless it is the smallest offset at that depth. (e.g. so we + // can still hava typetree {[123456]:Int}, even if limit is 100). + // + // First compute the minimum 0th index to be kept. + Result.minIndices.resize(minIndices.size(), INT_MAX); + for (const auto &pair : mapping) { if (pair.first.size() == 0) { if (pair.second == BaseType::Pointer || pair.second == BaseType::Anything) { - Result.insert(pair.first, pair.second); + Result.mapping.emplace(pair.first, pair.second); continue; } @@ -743,55 +777,152 @@ class TypeTree : public std::enable_shared_from_this { llvm_unreachable("ShiftIndices called on a nonpointer/anything"); } - std::vector next(pair.first); + int next0 = pair.first[0]; + + if (next0 == -1) { + if (maxSize == -1) { + // Max size does not clip the next index + + // If we have a follow up offset add, we lose the -1 since we only + // represent [0, inf) with -1 not the [addOffset, inf) required here + if (addOffset != 0) { + next0 = addOffset; + } + + } else { + // We're going to insert addOffset + 0...maxSize so the new minIndex + // is addOffset + Result.minIndices[0] = addOffset; + for (size_t i = 1, sz = pair.first.size(); i < sz; i++) + if (pair.first[i] < Result.minIndices[i]) + Result.minIndices[i] = pair.first[i]; + continue; + } + } else { + // Too small for range + if (next0 < offset) { + continue; + } + 
next0 -= offset; + + if (maxSize != -1) { + if (next0 >= maxSize) + continue; + } + + next0 += addOffset; + } + if (next0 < Result.minIndices[0]) + Result.minIndices[0] = next0; + for (size_t i = 1, sz = pair.first.size(); i < sz; i++) + if (pair.first[i] < Result.minIndices[i]) + Result.minIndices[i] = pair.first[i]; + } + + // Max depth of actual inserted values + size_t maxInsertedDepth = 0; + + // Insert all + for (const auto &pair : mapping) { + if (pair.first.size() == 0) + continue; - if (next[0] == -1) { + int next0 = pair.first[0]; + + if (next0 == -1) { if (maxSize == -1) { // Max size does not clip the next index // If we have a follow up offset add, we lose the -1 since we only // represent [0, inf) with -1 not the [addOffset, inf) required here if (addOffset != 0) { - next[0] = addOffset; + next0 = addOffset; } } else { - // This needs to become 0...maxSize as seen below + // This needs to become 0...maxSize handled separately as it is the + // only insertion that could have collisions + continue; } } else { // Too small for range - if (next[0] < offset) { + if (next0 < offset) { continue; } - next[0] -= offset; + next0 -= offset; if (maxSize != -1) { - if (next[0] >= maxSize) + if (next0 >= maxSize) continue; } - next[0] += addOffset; + next0 += addOffset; } - size_t chunk = 1; - auto op = operator[]({pair.first[0]}); - if (auto flt = op.isFloat()) { - chunk = dl.getTypeSizeInBits(flt) / 8; - } else if (op == BaseType::Pointer) { - chunk = dl.getPointerSizeInBits() / 8; + // If after moving this would not merit being kept for being a min index + // or being within the max type offset, skip it. 
+ if (next0 > MaxTypeOffset) { + bool minIndex = next0 == Result.minIndices[0]; + if (!minIndex) + for (size_t i = 1; i < pair.first.size(); i++) { + if (pair.first[i] == Result.minIndices[i]) { + minIndex = true; + break; + } + } + if (!minIndex) + continue; } - if (next[0] == -1 && maxSize != -1) { + std::vector next(pair.first); + next[0] = next0; + Result.mapping.emplace(next, pair.second); + if (next.size() > maxInsertedDepth) + maxInsertedDepth = next.size(); + } + + // Insert and expand the minus one, if needed + if (maxSize != -1) + for (const auto &pair : mapping) { + if (pair.first.size() == 0) + continue; + if (pair.first[0] != -1) + continue; + + size_t chunk = 1; + std::vector next(pair.first); + auto op = operator[]({next[0]}); + if (auto flt = op.isFloat()) { + chunk = dl.getTypeSizeInBits(flt) / 8; + } else if (op == BaseType::Pointer) { + chunk = dl.getPointerSizeInBits() / 8; + } auto offincr = (chunk - offset % chunk) % chunk; + bool inserted = false; for (int i = offincr; i < maxSize; i += chunk) { next[0] = i + addOffset; - Result.orIn(next, pair.second); + ConcreteType prev(pair.second); + // We can use faster checks here, since we know there can be no + // -1's that we would conflict with, only conflicts from previous + // fixed value insertions. + auto found = Result.mapping.find(next); + if (found != Result.mapping.end()) { + // orIn returns if changed, update the value in the map if so + // with the new value. + if (prev.orIn(found->second, /*pointerIntSame*/ false)) + found->second = prev; + } else { + Result.mapping.emplace(next, pair.second); + } + inserted = true; } - } else { - Result.orIn(next, pair.second); + if (inserted && next.size() > maxInsertedDepth) + maxInsertedDepth = next.size(); } - } + // Resize minIndices down if we dropped any higher-depth indices for being + // out of scope. 
+ Result.minIndices.resize(maxInsertedDepth); return Result; } From e8ca2b1de3b770c767145d027b357bed97178bb0 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 20 Feb 2024 16:03:49 +0100 Subject: [PATCH 113/131] Split tests into two packages in Bazel build. (#1741) This separates MLIR tests and regular Enzyme integration tests into independent packages, making sure only a subset of targets is rebuilt and retested when requested to decrease overhead. --- enzyme/BUILD | 53 ++--------------------------------- enzyme/test/BUILD | 32 +++++++++++++++++++++ enzyme/test/Integration/BUILD | 29 +++++++++++++++++++ enzyme/test/MLIR/BUILD | 25 +++++++++++++++++ 4 files changed, 88 insertions(+), 51 deletions(-) create mode 100644 enzyme/test/BUILD create mode 100644 enzyme/test/Integration/BUILD create mode 100644 enzyme/test/MLIR/BUILD diff --git a/enzyme/BUILD b/enzyme/BUILD index 7c78c75a4051..6f4459f0ebd2 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -1,7 +1,5 @@ -load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") load("@llvm-project//llvm:tblgen.bzl", "gentbl") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@bazel_skylib//rules:expand_template.bzl", "expand_template") licenses(["notice"]) @@ -640,52 +638,5 @@ cc_binary( ], ) -# Generates lit config input file by applying path placeholder substitutions -# similar to the configure_lit_site_cfg CMake macro. 
-expand_template( - name = "lit_site_cfg_py", - testonly = True, - out = "test/lit.site.cfg.py", - substitutions = { - "@LLVM_VERSION_MAJOR@": "18", - "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", - "@LLVM_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@ENZYME_SOURCE_DIR@": "", - "@ENZYME_BINARY_DIR@": "", - "@TARGET_TRIPLE@": "", - "@TARGETS_TO_BUILD@": "ALL", - "@LLVM_SHLIBEXT@": ".so", - }, - template = "test/lit.site.cfg.py.in", - visibility = ["//visibility:private"], -) - -[ - lit_test( - name = "%s.test" % src, - srcs = [src], - data = [ - ":enzyme-clang", - ":enzyme-clang++", - ":enzyme-opt", - ":enzymemlir-opt", - ":test/lit.cfg.py", - ":test/lit.site.cfg.py", - "@llvm-project//clang:builtin_headers_gen", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:count", - "@llvm-project//llvm:lli", - "@llvm-project//llvm:not", - ] + glob(["test/**/*.h"]), - ) - for src in glob( - [ - "test/**/*.mlir", - "test/Integration/**/*.c", - "test/Integration/**/.cpp", - ], - exclude = ["test/**/*omp*.c"], - ) -] +exports_files(["run_lit.sh"]) + diff --git a/enzyme/test/BUILD b/enzyme/test/BUILD new file mode 100644 index 000000000000..47143966810d --- /dev/null +++ b/enzyme/test/BUILD @@ -0,0 +1,32 @@ +# Enzyme tests. + +load("@llvm-project//llvm:lit_test.bzl", "package_path") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +# Generates lit config input file by applying path placeholder substitutions +# similar to the configure_lit_site_cfg CMake macro. 
+expand_template( + name = "lit_site_cfg_py", + testonly = True, + out = "lit.site.cfg.py", + substitutions = { + "@LLVM_VERSION_MAJOR@": "18", + "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", + "@LLVM_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@ENZYME_SOURCE_DIR@": "", + "@ENZYME_BINARY_DIR@": "", + "@TARGET_TRIPLE@": "", + "@TARGETS_TO_BUILD@": "ALL", + "@LLVM_SHLIBEXT@": ".so", + }, + template = "lit.site.cfg.py.in", + visibility = [":__subpackages__"], +) + +exports_files( + ["lit.cfg.py"], + visibility = [":__subpackages__"], +) + diff --git a/enzyme/test/Integration/BUILD b/enzyme/test/Integration/BUILD new file mode 100644 index 000000000000..ab14f8fca82e --- /dev/null +++ b/enzyme/test/Integration/BUILD @@ -0,0 +1,29 @@ +# Enzyme integration tests. + +load("@llvm-project//llvm:lit_test.bzl", "lit_test") + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + "//:enzyme-clang", + "//:enzyme-clang++", + "//:enzyme-opt", + "//test:lit.cfg.py", + "//test:lit.site.cfg.py", + "@llvm-project//clang:builtin_headers_gen", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:not", + ] + glob(["**/*.h"]), + ) + for src in glob( + [ + "**/*.c", + "**/.cpp", + ], + exclude = ["**/*omp*.c"], + ) +] diff --git a/enzyme/test/MLIR/BUILD b/enzyme/test/MLIR/BUILD new file mode 100644 index 000000000000..0af1496b3a5c --- /dev/null +++ b/enzyme/test/MLIR/BUILD @@ -0,0 +1,25 @@ +# MLIR-specific tests for Enzyme. 
+ +load("@llvm-project//llvm:lit_test.bzl", "lit_test") + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + "//:enzymemlir-opt", + "//test:lit.cfg.py", + "//test:lit.site.cfg.py", + "@llvm-project//clang:builtin_headers_gen", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:not", + ] + glob(["**/*.h"]), + ) + for src in glob( + [ + "**/*.mlir", + ], + ) +] From f600e2e6dfdb4312a36e7f9043f2d4231b4f1795 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 20 Feb 2024 17:20:02 +0100 Subject: [PATCH 114/131] silence warnings in bazel (#1745) * Split tests into two packages in Bazel build. This separates MLIR tests and regular Enzyme integration tests into independent packages, making sure only a subset of targets is rebuilt and retested when requested to decrease overhead. * silence warnings in bazel There are a bunch of these in the codebase --- enzyme/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enzyme/BUILD b/enzyme/BUILD index 6f4459f0ebd2..e582b435d387 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -161,6 +161,8 @@ cc_library( "-DENZYME_VERSION_MAJOR=0", "-DENZYME_VERSION_MINOR=0", "-DENZYME_VERSION_PATCH=79", + "-Wno-unused-variable", + "-Wno-return-type", ], data = ["@llvm-project//clang:builtin_headers_gen"], visibility = ["//visibility:public"], From 1beb98b51442d50652eaa3ffb9574f4720d611f1 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 20 Feb 2024 20:42:44 -0500 Subject: [PATCH 115/131] [ActivityAnalysis] Remove isConstantValue call in activity analysis (#1608) * Remove cop2 * Add integration test * Add test * Update enzyme/test/ActivityAnalysis/integration.ll * Update test * Update test * Add missing return * Format and test * Back gt * Test fix trial * Lower test llvm to 15 * Update enzyme/test/ActivityAnalysis/integration.ll * Update activity printer for opaque type * update -> * Update activity analysis * Update if/elif llvm * 
Update llvm versioning --- enzyme/Enzyme/ActivityAnalysis.cpp | 8 +- enzyme/Enzyme/ActivityAnalysisPrinter.cpp | 38 +++++---- enzyme/test/ActivityAnalysis/integration.ll | 93 +++++++++++++++++++++ 3 files changed, 121 insertions(+), 18 deletions(-) create mode 100644 enzyme/test/ActivityAnalysis/integration.ll diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 274c2fcea183..0f8ee3c6161d 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -2073,14 +2073,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { << "\n"; if (auto SI = dyn_cast(I)) { bool cop = !Hypothesis->isConstantValue(TR, SI->getValueOperand()); - bool cop2 = !Hypothesis->isConstantValue(TR, SI->getPointerOperand()); + // bool cop2 = !Hypothesis->isConstantValue(TR, + // SI->getPointerOperand()); if (EnzymePrintActivity) - llvm::errs() << " -- store potential activity: " << (int)cop << "," - << (int)cop2 << "," + llvm::errs() << " -- store potential activity: " << (int)cop << " - " << *SI << " of " << " Val=" << *Val << "\n"; potentialStore = I; - if (cop && cop2) + if (cop) // && cop2) potentiallyActiveStore = SI; } else if (auto MTI = dyn_cast(I)) { bool cop = !Hypothesis->isConstantValue(TR, MTI->getArgOperand(1)); diff --git a/enzyme/Enzyme/ActivityAnalysisPrinter.cpp b/enzyme/Enzyme/ActivityAnalysisPrinter.cpp index a69d7c9e56e8..bfca7860fc45 100644 --- a/enzyme/Enzyme/ActivityAnalysisPrinter.cpp +++ b/enzyme/Enzyme/ActivityAnalysisPrinter.cpp @@ -89,14 +89,19 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) { if (a.getType()->isFPOrFPVectorTy()) { dt = ConcreteType(a.getType()->getScalarType()); } else if (a.getType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR >= 17 -#else - auto et = a.getType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = 
TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); +#if LLVM_VERSION_MAJOR < 17 +#if LLVM_VERSION_MAJOR >= 13 + if (a.getContext().supportsTypedPointers()) { +#endif + auto et = a.getType()->getPointerElementType(); + if (et->isFPOrFPVectorTy()) { + dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); + } else if (et->isPointerTy()) { + dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); + } +#if LLVM_VERSION_MAJOR >= 13 } +#endif #endif } else if (a.getType()->isIntOrIntVectorTy()) { dt = ConcreteType(BaseType::Integer); @@ -113,14 +118,19 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) { if (F.getReturnType()->isFPOrFPVectorTy()) { dt = ConcreteType(F.getReturnType()->getScalarType()); } else if (F.getReturnType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR >= 17 -#else - auto et = F.getReturnType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); +#if LLVM_VERSION_MAJOR < 17 +#if LLVM_VERSION_MAJOR >= 13 + if (F.getContext().supportsTypedPointers()) { +#endif + auto et = F.getReturnType()->getPointerElementType(); + if (et->isFPOrFPVectorTy()) { + dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); + } else if (et->isPointerTy()) { + dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); + } +#if LLVM_VERSION_MAJOR >= 13 } +#endif #endif } else if (F.getReturnType()->isIntOrIntVectorTy()) { dt = ConcreteType(BaseType::Integer); diff --git a/enzyme/test/ActivityAnalysis/integration.ll b/enzyme/test/ActivityAnalysis/integration.ll new file mode 100644 index 000000000000..28ce009a9b5b --- /dev/null +++ b/enzyme/test/ActivityAnalysis/integration.ll @@ -0,0 +1,93 @@ +; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %OPnewLoadEnzyme -passes="print-activity-analysis" -activity-analysis-func=f.preprocess -S 
| FileCheck %s; fi + +declare void @free(ptr) + +declare ptr @malloc(i64) + +; This function just returns 2*input, its derivate should be 2.0. +define void @f.preprocess(ptr %param, i64 %mallocsize, ptr %res) { + + ; arithmetic block, changing anything here makes the bug go away + %buffer1 = call ptr @malloc(i64 %mallocsize) + %tmp = call ptr @malloc(i64 72) + %ptrtoint = ptrtoint ptr %tmp to i64 + %and = and i64 %ptrtoint, -64 + %inttoptr = inttoptr i64 %and to ptr + %loadarg = load double, ptr %param + %storedargmul = fmul double %loadarg, 4.000000e+00 + store double %storedargmul, ptr %inttoptr + call void @free(ptr %tmp) + store double %storedargmul, ptr %buffer1 + + ; prep arg 0 by setting the aligned pointer to the input + %arg0 = alloca { ptr, ptr, i64 } + %arg0_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg0, i64 0, i32 1 + store ptr %param, ptr %arg0_aligned + + ; prep arg 1 by setting the aligned pointer to buffer1 + %arg1 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] } + %arg1_aligned = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %arg1, i64 0, i32 1 + store ptr %buffer1, ptr %arg1_aligned + + ; prep arg 2 by setting the aligned pointer to buffer2 + %arg2 = alloca { ptr, ptr, i64 } + %arg2_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg2, i64 0, i32 1 + %buffer2 = call ptr @malloc(i64 8) + store ptr %buffer2, ptr %arg2_aligned + + ; nested call, required for bug + call void @nested(ptr %arg0, ptr %arg1, ptr %arg2) + + ; return a result from this function, needs to be positioned after arithmetic block for bug + %x = load double, ptr %param + %y = fmul double %x, 2.0 + store double %y, ptr %res + + ret void +} + +; Identity function, 2nd argument required for bug (but not used) +define void @nested(ptr %arg0, ptr %arg1, ptr %arg2) { + + ; load aligned pointer from %arg0 & load argument value + %loadarg = load { ptr, ptr, i64 }, ptr %arg0 + %extractarg = extractvalue { ptr, ptr, i64 } %loadarg, 1 + 
%loadextractarg = load double, ptr %extractarg + + ; load aligned pointer from %arg2 & store result value + %loadarg2 = load { ptr, ptr, i64 }, ptr %arg2 + %extractarg2 = extractvalue { ptr, ptr, i64 } %loadarg2, 1 + store double %loadextractarg, ptr %extractarg2 + + ret void +} + +; CHECK: ptr %param: icv:0 +; CHECK-NEXT: i64 %mallocsize: icv:1 +; CHECK-NEXT: ptr %res: icv:0 + +; CHECK: %buffer1 = call ptr @malloc(i64 %mallocsize): icv:0 ici:1 +; CHECK-NEXT: %tmp = call ptr @malloc(i64 72): icv:1 ici:1 +; CHECK-NEXT: %ptrtoint = ptrtoint ptr %tmp to i64: icv:1 ici:1 +; CHECK-NEXT: %and = and i64 %ptrtoint, -64: icv:1 ici:1 +; CHECK-NEXT: %inttoptr = inttoptr i64 %and to ptr: icv:1 ici:1 +; CHECK-NEXT: %loadarg = load double, ptr %param, align 8: icv:0 ici:0 +; CHECK-NEXT: %storedargmul = fmul double %loadarg, 4.000000e+00: icv:0 ici:0 +; CHECK-NEXT: store double %storedargmul, ptr %inttoptr, align 8: icv:1 ici:1 +; CHECK-NEXT: call void @free(ptr %tmp): icv:1 ici:1 +; CHECK-NEXT: store double %storedargmul, ptr %buffer1, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg0 = alloca { ptr, ptr, i64 }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg0_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg0, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: store ptr %param, ptr %arg0_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg1 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg1_aligned = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %arg1, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: store ptr %buffer1, ptr %arg1_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg2 = alloca { ptr, ptr, i64 }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg2_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg2, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: %buffer2 = call ptr @malloc(i64 8): icv:0 ici:1 +; CHECK-NEXT: store ptr %buffer2, ptr %arg2_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: call void @nested(ptr %arg0, ptr %arg1, ptr %arg2): 
icv:1 ici:0 +; CHECK-NEXT: %x = load double, ptr %param, align 8: icv:0 ici:0 +; CHECK-NEXT: %y = fmul double %x, 2.000000e+00: icv:0 ici:0 +; CHECK-NEXT: store double %y, ptr %res, align 8: icv:1 ici:0 +; CHECK-NEXT: ret void: icv:1 ici:1 From d4cb8b3dd921ead3a5f3c8e29fd071077402c36b Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 21 Feb 2024 09:51:07 -0500 Subject: [PATCH 116/131] TypeAnalysis permit passing info as fn parm attrs (#1746) --- enzyme/Enzyme/AdjointGenerator.h | 53 +++++++------ enzyme/Enzyme/EnzymeLogic.cpp | 36 ++++++++- enzyme/Enzyme/FunctionUtils.cpp | 22 +++++- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 67 ++++++++++++++-- enzyme/Enzyme/TypeAnalysis/TypeTree.h | 87 ++++++++++++++++++++- 5 files changed, 231 insertions(+), 34 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index dbcd0947d342..9f1d3cd65dfc 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4757,10 +4757,11 @@ class AdjointGenerator : public llvm::InstVisitor { // call.getParamAttr(i, Attribute::StructRet).getValueAsType())); #endif } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - structAttrs[args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); + } for (auto ty : PrimalParamAttrsToPreserve) if (call.getAttributes().hasParamAttr(i, ty)) { auto attr = call.getAttributes().getParamAttr(i, ty); @@ -4815,15 +4816,16 @@ class AdjointGenerator : public llvm::InstVisitor { structAttrs[args.size()].push_back(attr); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - if (gutils->getWidth() == 1) { - structAttrs[args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } else { - 
structAttrs[args.size()].push_back( - Attribute::get(call.getContext(), "enzyme_sret_v")); + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + if (gutils->getWidth() == 1) { + structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); + } else if (attr == "enzymejl_returnRoots") { + structAttrs[args.size()].push_back( + Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + } } - } if (call.paramHasAttr(i, Attribute::StructRet)) { if (gutils->getWidth() == 1) { structAttrs[args.size()].push_back( @@ -5050,10 +5052,11 @@ class AdjointGenerator : public llvm::InstVisitor { if (call.isByValArgument(i)) { preByVal[pre_args.size()] = call.getParamByValType(i); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - structAttrs[pre_args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + structAttrs[pre_args.size()].push_back(call.getParamAttr(i, attr)); + } if (call.paramHasAttr(i, Attribute::StructRet)) { structAttrs[pre_args.size()].push_back( #if LLVM_VERSION_MAJOR >= 12 @@ -5146,15 +5149,17 @@ class AdjointGenerator : public llvm::InstVisitor { structAttrs[pre_args.size()].push_back(attr); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - if (gutils->getWidth() == 1) { - structAttrs[pre_args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } else { - structAttrs[pre_args.size()].push_back( - Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + if (gutils->getWidth() == 1) { + structAttrs[pre_args.size()].push_back( + 
call.getParamAttr(i, attr)); + } else if (attr == "enzymejl_returnRoots") { + structAttrs[pre_args.size()].push_back( + Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + } } - } if (call.paramHasAttr(i, Attribute::StructRet)) { if (gutils->getWidth() == 1) { structAttrs[pre_args.size()].push_back( diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 0767b3e52529..76fc71921654 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -82,6 +82,7 @@ #if LLVM_VERSION_MAJOR >= 14 #define addAttribute addAttributeAtIndex +#define getAttribute getAttributeAtIndex #define removeAttribute removeAttributeAtIndex #endif @@ -2784,7 +2785,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( NewF->addParamAttr(attrIndex, Attribute::NoAlias); } for (auto name : {"enzyme_sret", "enzyme_sret_v", "enzymejl_returnRoots", - "enzymejl_returnRoots_v"}) + "enzymejl_returnRoots_v", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) if (nf->getAttributes().hasParamAttr(attrIndex, name)) { NewF->addParamAttr(attrIndex, nf->getAttributes().getParamAttr(attrIndex, name)); @@ -2796,6 +2798,38 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( ++attrIndex; } +#if LLVM_VERSION_MAJOR >= 14 + for (auto attr : {"enzyme_ta_norecur"}) + if (nf->getAttributes().hasAttributeAtIndex(AttributeList::FunctionIndex, + attr)) { + NewF->addFnAttr( + nf->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (nf->getAttributes().hasAttributeAtIndex(AttributeList::ReturnIndex, + attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + nf->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } +#else + for (auto attr : {"enzyme_ta_norecur"}) + if (nf->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) { + NewF->addFnAttr( + 
nf->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (nf->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + nf->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } +#endif + SmallVector Returns; #if LLVM_VERSION_MAJOR >= 13 CloneFunctionInto(NewF, nf, VMap, CloneFunctionChangeType::LocalChangesOnly, diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 218ca9549dfc..3e2ade41b1c9 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2180,8 +2180,24 @@ Function *PreProcessCache::CloneFunctionWithReturns( VMapO->getMDMap() = VMap.getMDMap(); } + for (auto attr : {"enzyme_ta_norecur"}) + if (F->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) { + NewF->addAttribute( + AttributeList::FunctionIndex, + F->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + F->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } + bool hasPtrInput = false; unsigned ii = 0, jj = 0; + for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) { if (F->hasParamAttribute(ii, Attribute::StructRet)) { NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret")); @@ -2204,7 +2220,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( // Attribute::ElementType)); #endif } - for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"}) + for (auto attr : + {"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"}) if (F->getAttributes().hasParamAttr(ii, attr)) { NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr)); for (auto ty : 
PrimalParamAttrsToPreserve) @@ -2250,7 +2267,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( NewF->addParamAttr(jj + 1, attr); } - for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"}) + for (auto attr : + {"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"}) if (F->getAttributes().hasParamAttr(ii, attr)) { if (width == 1) NewF->addParamAttr(jj + 1, diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 86a5243c438a..2515169ae0e4 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -61,6 +61,12 @@ #include +#if LLVM_VERSION_MAJOR >= 14 +#define getAttribute getAttributeAtIndex +#define hasAttribute hasAttributeAtIndex +#define addAttribute addAttributeAtIndex +#endif + using namespace llvm; extern "C" { @@ -1207,18 +1213,57 @@ void TypeAnalyzer::considerTBAA() { } if (CallBase *call = dyn_cast(&I)) { +#if LLVM_VERSION_MAJOR >= 14 + size_t num_args = call->arg_size(); +#else + size_t num_args = call->getNumArgOperands(); +#endif + + if (call->getAttributes().hasAttribute(AttributeList::ReturnIndex, + "enzyme_type")) { + auto attr = call->getAttributes().getAttribute( + AttributeList::ReturnIndex, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call, TT, call); + } + for (size_t i = 0; i < num_args; i++) { + if (call->getAttributes().hasParamAttr(i, "enzyme_type")) { + auto attr = call->getAttributes().getParamAttr(i, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call->getArgOperand(i), TT, call); + } + } + Function *F = call->getCalledFunction(); + + if (F) { + if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, + "enzyme_type")) { + auto attr = F->getAttributes().getAttribute( + AttributeList::ReturnIndex, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), 
call->getContext()); + updateAnalysis(call, TT, call); + } + size_t f_num_args = F->arg_size(); + for (size_t i = 0; i < f_num_args; i++) { + if (F->getAttributes().hasParamAttr(i, "enzyme_type")) { + auto attr = F->getAttributes().getParamAttr(i, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call->getArgOperand(i), TT, call); + } + } + } + if (auto castinst = dyn_cast(call->getCalledOperand())) { if (castinst->isCast()) if (auto fn = dyn_cast(castinst->getOperand(0))) { F = fn; } } -#if LLVM_VERSION_MAJOR >= 14 - size_t num_args = call->arg_size(); -#else - size_t num_args = call->getNumArgOperands(); -#endif if (F && F->getName().contains("__enzyme_float")) { assert(num_args == 1 || num_args == 2); assert(call->getArgOperand(0)->getType()->isPointerTy()); @@ -4356,9 +4401,16 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } } + if (call.hasFnAttr("enzyme_ta_norecur")) + return; + Function *ci = getFunctionFromCall(&call); if (ci) { + if (ci->getAttributes().hasAttribute(AttributeList::FunctionIndex, + "enzyme_ta_norecur")) + return; + StringRef funcName = getFuncNameFromCall(&call); auto blasMetaData = extractBLAS(funcName); @@ -4367,6 +4419,9 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { #include "BlasTA.inc" } + // Manual TT specification is non-interprocedural and already handled once + // at the start. + // When compiling Enzyme against standard LLVM, and not Intel's // modified version of LLVM, the intrinsic `llvm.intel.subscript` is // not fully understood by LLVM. 
One of the results of this is that the @@ -5596,7 +5651,7 @@ bool TypeAnalyzer::mustRemainInteger(Value *val, bool *returned) { FnTypeInfo TypeAnalyzer::getCallInfo(CallBase &call, Function &fn) { FnTypeInfo typeInfo(&fn); - int argnum = 0; + size_t argnum = 0; for (auto &arg : fn.args()) { if (argnum >= call.arg_size()) { typeInfo.Arguments.insert( diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index 2cb4a7506c94..bfcd9f488ab5 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -83,6 +83,92 @@ class TypeTree : public std::enable_shared_from_this { } } + static TypeTree parse(llvm::StringRef str, llvm::LLVMContext &ctx) { + using namespace llvm; + assert(str[0] == '{'); + str = str.substr(1); + + TypeTree Result; + while (true) { + while (str[0] == ' ') + str = str.substr(1); + if (str[0] == '}') + break; + + assert(str[0] == '['); + str = str.substr(1); + + std::vector idxs; + while (true) { + while (str[0] == ' ') + str = str.substr(1); + if (str[0] == ']') { + str = str.substr(1); + break; + } + + int idx; + bool failed = str.consumeInteger(10, idx); + (void)failed; + assert(!failed); + idxs.push_back(idx); + + while (str[0] == ' ') + str = str.substr(1); + + if (str[0] == ',') { + str = str.substr(1); + } + } + + while (str[0] == ' ') + str = str.substr(1); + + assert(str[0] == ':'); + str = str.substr(1); + + while (str[0] == ' ') + str = str.substr(1); + + auto endval = str.find(','); + auto endval2 = str.find('}'); + auto endval3 = str.find(' '); + + if (endval2 != StringRef::npos && + (endval == StringRef::npos || endval2 < endval)) + endval = endval2; + if (endval3 != StringRef::npos && + (endval == StringRef::npos || endval3 < endval)) + endval = endval3; + assert(endval != StringRef::npos); + + auto tystr = str.substr(0, endval); + str = str.substr(endval); + + ConcreteType CT(tystr, ctx); + Result.mapping.emplace(idxs, CT); + if (Result.minIndices.size() < 
idxs.size()) { + for (size_t i = Result.minIndices.size(), end = idxs.size(); i < end; + ++i) { + Result.minIndices.push_back(idxs[i]); + } + } + for (size_t i = 0, end = idxs.size(); i < end; ++i) { + if (idxs[i] < Result.minIndices[i]) + Result.minIndices[i] = idxs[i]; + } + + while (str[0] == ' ') + str = str.substr(1); + + if (str[0] == ',') { + str = str.substr(1); + } + } + + return Result; + } + /// Utility helper to lookup the mapping const ConcreteTypeMapType &getMapping() const { return mapping; } @@ -1228,7 +1314,6 @@ class TypeTree : public std::enable_shared_from_this { if (found != RHS.mapping.end()) { RightCT = found->second; } - bool SubLegal = true; changed |= CT.binopIn(SubLegal, RightCT, Op); if (!SubLegal) { From 638ac37a037d1b63c401f06ead22771945590854 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 21 Feb 2024 10:51:26 -0500 Subject: [PATCH 117/131] TypeAnalysis speed up Canonicalize (#1747) * TypeAnalysis speed up Canonicalize * replace old impl --- enzyme/Enzyme/TypeAnalysis/TypeTree.h | 45 ++++++++++++++++++++------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index bfcd9f488ab5..22a4c25d9361 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -702,22 +702,22 @@ class TypeTree : public std::enable_shared_from_this { staging[next][pair.second].insert(pair.first[0]); } - mapping.clear(); + // TypeTree mappings which did not get combined + std::map, ConcreteType> unCombinedToAdd; - for (auto &pair : staging) { + // TypeTree mappings which did get combined into an outer -1 + std::map, ConcreteType> combinedToAdd; + + for (const auto &pair : staging) { auto &pnext = pair.first; - for (auto &pair2 : pair.second) { + for (const auto &pair2 : pair.second) { auto dt = pair2.first; const auto &set = pair2.second; - // llvm::errs() << " - set: {"; - // for(auto s : set) llvm::errs() << s << ", "; - // 
llvm::errs() << "} len=" << len << "\n"; - - bool legalCombine = set.count(-1); + bool legalCombine = false; // See if we can canonicalize the outermost index into a -1 - if (!legalCombine) { + if (!set.count(-1)) { size_t chunk = 1; if (pnext.size() > 0) { chunk = dl.getPointerSizeInBits() / 8; @@ -745,15 +745,38 @@ class TypeTree : public std::enable_shared_from_this { next.push_back(v); if (legalCombine) { - insert(next, dt, /*intsAreLegalPointerSub*/ true); + combinedToAdd.emplace(next, dt); } else { for (auto e : set) { next[0] = e; - insert(next, dt); + unCombinedToAdd.emplace(next, dt); } } } } + + // If we combined nothing, just return since there are no + // changes. + if (combinedToAdd.size() == 0) { + return; + } + + // Non-combined ones do not conflict, since they were already in + // a TT which we can assume contained no conflicts. + mapping = std::move(unCombinedToAdd); + minIndices[0] = -1; + + // Fusing several terms into a minus one can create a conflict + // if the prior minus one was already in the map + // time, or also generated by fusion. + // E.g. 
{-1:Anything, [0]:Pointer} on 8 -> create a [-1]:Pointer + // which conflicts + // Alternatively [-1,-1,-1]:Pointer, and generated a [-1,0,-1] fusion + for (const auto &pair : combinedToAdd) { + insert(pair.first, pair.second); + } + + return; } /// Keep only pointers (or anything's) to a repeated value (represented by -1) From 72342461d2e37cfdc9765084f4860f1d2b6da7d4 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 26 Feb 2024 20:35:07 +0100 Subject: [PATCH 118/131] Update enzyme-ci.yml (#1754) --- .github/workflows/enzyme-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 4a1ef073160d..fcb94bb604f3 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -61,7 +61,7 @@ jobs: strategy: fail-fast: false matrix: - llvm: ["11", "12", "13", "14", "15"] + llvm: ["12", "13", "14", "15", "16"] build: ["Release", "Debug"] # "RelWithDebInfo" timeout-minutes: 30 From 9884107d3df26d1e8d826c4dada22f7f1971cb0b Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 26 Feb 2024 20:35:24 +0100 Subject: [PATCH 119/131] Create dependabot.yml (#1755) --- .github/dependabot.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000000..edb6037854cc --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: +# Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" From 26c43370e4d56730d0b00e4e3b2705b28f7df375 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 26 Feb 2024 20:35:59 +0100 Subject: [PATCH 120/131] Export compile_commands.json to enable clangd support (#1753) --- .devcontainer/devcontainer.json | 6 ++++-- .gitignore | 1 + enzyme/CMakeLists.txt | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index d201b73d532d..a9b75aefae89 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,5 +1,5 @@ // available ubuntu versions: [20, 22] -// available llvm versions: [9, 10, 11, 12, 13, 14, 15] +// available llvm versions: [11, 12, 13, 14, 15, 16, 17, 18] { "name": "Enzyme", "image": "ghcr.io/enzymead/enzyme-dev-docker/ubuntu-22-llvm-16:latest", @@ -14,7 +14,9 @@ "customizations": { "vscode": { "extensions": [ - "ms-vscode.cpptools-extension-pack" + "llvm-vs-code-extensions.vscode-clangd", + "BazelBuild.vscode-bazel", + "twxs.cmake" ] } } diff --git a/.gitignore b/.gitignore index aad84254cf78..33a6c3f8cdd8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ enzyme/benchmarks/ReverseMode/*/*.o enzyme/benchmarks/ReverseMode/*/*.exe enzyme/benchmarks/ReverseMode/*/results.txt enzyme/benchmarks/ReverseMode/*/results.json +.cache diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 32d09c19e302..077f4f3a9554 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -13,6 +13,7 @@ add_definitions(-DENZYME_VERSION_MAJOR=${ENZYME_MAJOR_VERSION}) add_definitions(-DENZYME_VERSION_MINOR=${ENZYME_MINOR_VERSION}) 
add_definitions(-DENZYME_VERSION_PATCH=${ENZYME_PATCH_VERSION}) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull") SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -ggdb") From 2dec42bf7df5015da2e78a1dd3f1de6e5b2db6ce Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 00:48:11 +0100 Subject: [PATCH 121/131] Bump actions/cache from 3 to 4 (#1757) Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/enzyme-mlir.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index b4c581a3a720..f36ed1671c46 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -46,7 +46,7 @@ jobs: - name: Cache MLIR id: cache-mlir - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: llvm-project/mlir-build key: ${{ matrix.llbuild }}-${{ matrix.os }}-mlir-${{ steps.mlir-commit.outputs.sha_short }} From d1b26fd50085f47f9b0a93071db5716de69a664c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:49:15 +0000 Subject: [PATCH 122/131] Bump mattnotmitt/doxygen-action from 1.9.2 to 1.9.8 (#1758) Bumps [mattnotmitt/doxygen-action](https://github.com/mattnotmitt/doxygen-action) from 1.9.2 to 1.9.8. - [Release notes](https://github.com/mattnotmitt/doxygen-action/releases) - [Commits](https://github.com/mattnotmitt/doxygen-action/compare/v1.9.2...v1.9.8) --- updated-dependencies: - dependency-name: mattnotmitt/doxygen-action dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/doxygen.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/doxygen.yml b/.github/workflows/doxygen.yml index 659ed659589a..8b6278c99a26 100644 --- a/.github/workflows/doxygen.yml +++ b/.github/workflows/doxygen.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v3 - - uses: mattnotmitt/doxygen-action@v1.9.2 + - uses: mattnotmitt/doxygen-action@v1.9.8 with: working-directory: 'enzyme/' doxyfile-path: 'doxygen.cfg' From 53c69356cffa225b19eb6118304bf6de131b3be0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:56:21 +0000 Subject: [PATCH 123/131] Bump actions/checkout from 3 to 4 (#1759) Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/bcload.yml | 2 +- .github/workflows/benchmark.yml | 2 +- .github/workflows/ccpp.yml | 2 +- .github/workflows/doxygen.yml | 2 +- .github/workflows/enzyme-ci.yml | 6 +++--- .github/workflows/enzyme-julia.yml | 4 ++-- .github/workflows/enzyme-mlir.yml | 4 ++-- .github/workflows/format.yml | 2 +- .github/workflows/fortran.yml | 2 +- .github/workflows/tagger.yml | 4 ++-- 10 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/bcload.yml b/.github/workflows/bcload.yml index 33e409fd8bfe..9cbbdb7094ff 100644 --- a/.github/workflows/bcload.yml +++ b/.github/workflows/bcload.yml @@ -27,7 +27,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: cd enzyme && rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 46fa9edb1bed..d859468708b5 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -37,7 +37,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index 46c64d5b369e..6efe084fdc07 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -34,7 +34,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' 
/usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/doxygen.yml b/.github/workflows/doxygen.yml index 8b6278c99a26..12e64dc5d8ce 100644 --- a/.github/workflows/doxygen.yml +++ b/.github/workflows/doxygen.yml @@ -9,7 +9,7 @@ jobs: docs: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: mattnotmitt/doxygen-action@v1.9.8 with: diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index fcb94bb604f3..5f8e9ea6e524 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -32,7 +32,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake @@ -71,7 +71,7 @@ jobs: brew update brew install llvm@${{ matrix.llvm }} make cmake sudo python3 -m pip install --upgrade pip lit requests - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake @@ -109,7 +109,7 @@ jobs: run: | brew install llvm@${{ matrix.llvm }} make cmake gcc sudo python3 -m pip install --upgrade pip lit - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/enzyme-julia.yml b/.github/workflows/enzyme-julia.yml index fcb6ff659a07..4aabf7204f5b 100644 --- a/.github/workflows/enzyme-julia.yml +++ b/.github/workflows/enzyme-julia.yml @@ -28,8 +28,8 @@ jobs: - x64 timeout-minutes: 60 steps: 
- - uses: actions/checkout@v3 - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - uses: actions/checkout@v4 with: repository: 'wsmoses/Enzyme.jl' path: ./jl diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index f36ed1671c46..aaced6f4e694 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -29,11 +29,11 @@ jobs: sudo apt-get update sudo apt-get install -y binutils ninja-build cmake gcc g++ python3 python3-dev - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: path: 'Enzyme' - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' ref: 'bc82cfb38d83f1afeb2c290aa472c2e2e88919cb' diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 55b2c99749be..c01757d0ff86 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: DoozyX/clang-format-lint-action@v0.16.2 with: source: 'enzyme/Enzyme enzyme/tools/enzyme-tblgen' diff --git a/.github/workflows/fortran.yml b/.github/workflows/fortran.yml index a31294c90b4f..e5e960345f63 100644 --- a/.github/workflows/fortran.yml +++ b/.github/workflows/fortran.yml @@ -40,7 +40,7 @@ jobs: sudo apt-get update && sudo apt-get install -y intel-oneapi-compiler-fortran-${{ matrix.ifx }} intel-oneapi-mpi-${{ matrix.mpi }} intel-oneapi-mpi-devel-${{ matrix.mpi }} source /opt/intel/oneapi/setvars.sh printenv >> $GITHUB_ENV - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: generate build system run: | rm -rf build && mkdir build && cd build diff --git a/.github/workflows/tagger.yml b/.github/workflows/tagger.yml index 9d712fb0cd6c..d40b3a1d943b 100644 --- a/.github/workflows/tagger.yml +++ b/.github/workflows/tagger.yml @@ -17,12 +17,12 @@ jobs: private_key: ${{ secrets.APP_PRIVATE_KEY }} repository: JuliaPackaging/Yggdrasil - - 
uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: 'JuliaPackaging/Yggdrasil' path: ygg - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: path: enz - name: replace From 1588442942fd2ed5971087034725d330b65cc1dd Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Tue, 27 Feb 2024 01:32:59 +0100 Subject: [PATCH 124/131] Update tagger.yml (#1762) --- .github/workflows/tagger.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tagger.yml b/.github/workflows/tagger.yml index d40b3a1d943b..e4a015c4145a 100644 --- a/.github/workflows/tagger.yml +++ b/.github/workflows/tagger.yml @@ -10,12 +10,12 @@ jobs: name: Enzyme Tag CI runs-on: ubuntu-latest steps: - - uses: tibdex/github-app-token@v1 + - uses: actions/create-github-app-token@v1 id: generate_token with: app_id: ${{ secrets.APP_ID }} private_key: ${{ secrets.APP_PRIVATE_KEY }} - repository: JuliaPackaging/Yggdrasil + repositories: JuliaPackaging/Yggdrasil - uses: actions/checkout@v4 with: From 3e2de5de6c03b9fe138d657ebfde42fa05af3a4e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 26 Feb 2024 21:01:36 -0500 Subject: [PATCH 125/131] C++ error message for incorrect custom gradient type (#1764) --- enzyme/Enzyme/AdjointGenerator.h | 4 ++-- enzyme/Enzyme/EnzymeLogic.cpp | 38 +++++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 9f1d3cd65dfc..a3abf6991e95 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4821,7 +4821,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (call.getAttributes().hasParamAttr(i, attr)) { if (gutils->getWidth() == 1) { structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); - } else if (attr == "enzymejl_returnRoots") { + } else if (attr == std::string("enzymejl_returnRoots")) { structAttrs[args.size()].push_back( Attribute::get(call.getContext(), 
"enzymejl_returnRoots_v")); } @@ -5155,7 +5155,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (gutils->getWidth() == 1) { structAttrs[pre_args.size()].push_back( call.getParamAttr(i, attr)); - } else if (attr == "enzymejl_returnRoots") { + } else if (attr == std::string("enzymejl_returnRoots")) { structAttrs[pre_args.size()].push_back( Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 76fc71921654..f930d4e1375b 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3937,14 +3937,40 @@ Function *EnzymeLogic::CreatePrimalAndGradient( hasTape = false; // res.first.push_back(StructType::get(todiff->getContext(), {})); } else { - llvm::errs() << "expected args: ["; + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Bad function type of custom reverse pass for function " + << key.todiff->getName() << " of type " + << *key.todiff->getFunctionType() << "\n"; + ss << " expected gradient function to have argument types ["; + bool seen = false; for (auto a : res.first) { - llvm::errs() << *a << " "; + if (seen) + ss << ", "; + seen = true; + ss << *a; + } + ss << "]\n"; + ss << " Instead found " << foundcalled->getName() << " of type " + << *foundcalled->getFunctionType() << "\n"; + Value *toshow = key.todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *key.todiff << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(key.todiff), + wrap(context.ip)); + } else if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + } else { + assert(0 && "bad type for custom gradient"); + llvm_unreachable("bad type for custom gradient"); } - llvm::errs() << "]\n"; - llvm::errs() << *foundcalled << "\n"; - assert(0 && "bad type for custom gradient"); - 
llvm_unreachable("bad type for custom gradient"); } auto st = dyn_cast(foundcalled->getReturnType()); From 744ede0c8525271cb7e977f91ff601e794f9c3ec Mon Sep 17 00:00:00 2001 From: Matin Raayai <30674652+matinraayai@users.noreply.github.com> Date: Tue, 27 Feb 2024 00:56:50 -0500 Subject: [PATCH 126/131] Fixed incorrect usage of llvm::Function::splice. (#1751) * Fixed incorrect usage of llvm::Function::splice. * Put an #ifdef to take into account the LLVM version when using llvm::Function::splice. * fix format --------- Co-authored-by: William S. Moses --- enzyme/Enzyme/Enzyme.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 055b6f394842..47e9f2bba91f 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -2110,8 +2110,12 @@ class EnzymeBase { // Move the truncated body into the original function F.deleteBody(); +#if LLVM_VERSION_MAJOR >= 16 + F.splice(F.begin(), TruncatedFunc); +#else F.getBasicBlockList().splice(F.begin(), TruncatedFunc->getBasicBlockList()); +#endif RemapFunction(F, Mapping, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); TruncatedFunc->deleteBody(); From ca06ef36c94015695a13b786cdce89a269ec338f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 27 Feb 2024 01:16:36 -0500 Subject: [PATCH 127/131] Allow custom importing of files and syntactic sugar (#1752) * Allow custom importing of files and syntactic sugar * Fix build on older llvm vers * Update sugar.cpp * Update EnzymeClang.cpp * fix * fixup * dump * more printing * print * fixup --------- Co-authored-by: Ivan Radanov Ivanov --- enzyme/BUILD | 15 + enzyme/Enzyme/CMakeLists.txt | 6 + enzyme/Enzyme/Clang/EnzymeClang.cpp | 37 ++ enzyme/Enzyme/Clang/include_utils.td | 458 ++++++++++++++++++ enzyme/Enzyme/Enzyme.cpp | 14 +- enzyme/Enzyme/Utils.cpp | 333 ++++++++++--- enzyme/Enzyme/Utils.h | 2 + enzyme/test/Integration/ReverseMode/sugar.cpp | 91 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 29 +- 
enzyme/tools/enzyme-tblgen/enzyme-tblgen.h | 1 + 10 files changed, 916 insertions(+), 70 deletions(-) create mode 100644 enzyme/Enzyme/Clang/include_utils.td create mode 100644 enzyme/test/Integration/ReverseMode/sugar.cpp diff --git a/enzyme/BUILD b/enzyme/BUILD index e582b435d387..dc36cb4ad0c5 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -142,6 +142,20 @@ gentbl( ], ) +gentbl( + name = "include-utils", + tbl_outs = [( + "-gen-header-strings", + "IncludeUtils.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/Clang/include_utils.td", + td_srcs = ["Enzyme/Clang/include_utils.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeStatic", srcs = glob( @@ -167,6 +181,7 @@ cc_library( data = ["@llvm-project//clang:builtin_headers_gen"], visibility = ["//visibility:public"], deps = [ + "include-utils", ":binop-derivatives", ":blas-attributor", ":blas-derivatives", diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 1cd6e84c5be1..b27e4beb08cd 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -37,6 +37,10 @@ add_public_tablegen_target(BlasDeclarationsIncGen) add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) +set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) +enzyme_tablegen(IncludeUtils.inc -gen-header-strings) +add_public_tablegen_target(IncludeUtilsIncGen) + include_directories(${CMAKE_CURRENT_BINARY_DIR}) set(LLVM_LINK_COMPONENTS Demangle) @@ -74,6 +78,7 @@ if (${Clang_FOUND}) LLVM ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp @@ -107,6 +112,7 @@ if (${Clang_FOUND}) clang ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() 
add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index a34a6429dcf7..0072c958b517 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -25,16 +25,20 @@ #include "clang/AST/Attr.h" #include "clang/AST/DeclGroup.h" #include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Basic/FileManager.h" #include "clang/Basic/MacroBuilder.h" #include "clang/Frontend/CompilerInstance.h" #include "clang/Frontend/FrontendAction.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Lex/HeaderSearch.h" #include "clang/Lex/PreprocessorOptions.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" #include "../Utils.h" +#include "IncludeUtils.inc" + using namespace clang; #if LLVM_VERSION_MAJOR >= 18 @@ -134,6 +138,39 @@ class EnzymePlugin final : public clang::ASTConsumer { Builder.defineMacro("ENZYME_VERSION_PATCH", std::to_string(ENZYME_VERSION_PATCH)); CI.getPreprocessor().setPredefines(Predefines.str()); + + auto baseFS = &CI.getFileManager().getVirtualFileSystem(); + llvm::vfs::OverlayFileSystem *fuseFS( + new llvm::vfs::OverlayFileSystem(baseFS)); + IntrusiveRefCntPtr fs( + new llvm::vfs::InMemoryFileSystem()); + + struct tm y2k = {}; + + y2k.tm_hour = 0; + y2k.tm_min = 0; + y2k.tm_sec = 0; + y2k.tm_year = 100; + y2k.tm_mon = 0; + y2k.tm_mday = 1; + time_t timer = mktime(&y2k); + for (const auto &pair : include_headers) { + fs->addFile(StringRef(pair[0]), timer, + llvm::MemoryBuffer::getMemBuffer( + StringRef(pair[1]), StringRef(pair[0]), + /*RequiresNullTerminator*/ true)); + } + + fuseFS->pushOverlay(fs); + fuseFS->pushOverlay(baseFS); + CI.getFileManager().setVirtualFileSystem(fuseFS); + + auto DE = CI.getFileManager().getDirectoryRef("/enzymeroot"); + assert(DE); + auto DL = DirectoryLookup(*DE, SrcMgr::C_User, + /*isFramework=*/false); + 
CI.getPreprocessor().getHeaderSearchInfo().AddSearchPath(DL, + /*isAngled=*/true); } ~EnzymePlugin() {} void HandleTranslationUnit(ASTContext &context) override {} diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td new file mode 100644 index 000000000000..1c99d219ce69 --- /dev/null +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -0,0 +1,458 @@ +class Headers { + string filename = filename_; + string contents = contents_; +} + +def : Headers<"/enzymeroot/enzyme/utils", [{ +#pragma once + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +extern int enzyme_const_return; +extern int enzyme_active_return; +extern int enzyme_dup_return; + +extern int enzyme_primal_return; +extern int enzyme_noret; + +template +Return __enzyme_autodiff(T...); + +template +Return __enzyme_fwddiff(T...); + +#include + +namespace enzyme { + + struct nodiff{}; + + template + struct ReverseMode { + + }; + using Reverse = ReverseMode; + using ReverseWithPrimal = ReverseMode; + + template < typename T > + struct Active{ + T value; + Active(T &&v) : value(v) {} + operator T&() { return value; } + }; + + template < typename T > + struct Duplicated{ + T value; + T shadow; + Duplicated(T &&v, T&& s) : value(v), shadow(s) {} + }; + + template < typename T > + struct Const{ + T value; + Const(T &&v) : value(v) {} + operator T&() { return value; } + }; + + template < typename T > + struct type_info { + static constexpr bool is_active = false; + using type = nodiff; + }; + + template < typename T > + struct type_info < Active >{ + static constexpr bool is_active = true; + using type = T; + }; + + template < typename ... T > + struct concatenated; + + template < typename ... S, typename T, typename ... rest > + struct concatenated < tuple < S ... >, T, rest ... > { + using type = typename concatenated< tuple< S ..., T>, rest ... 
>::type; + }; + + template < typename T > + struct concatenated < T > { + using type = T; + }; + + // Yikes! + // slightly cleaner in C++20, with std::remove_cvref + template < typename ... T > + struct autodiff_return; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple, + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type>; + }; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple< + typename type_info::type, + typename concatenated< tuple< >, + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type + >; + }; + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Active & arg) { + return enzyme::tuple{enzyme_out, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Const & arg) { + return enzyme::tuple{enzyme_const, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Active & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Const & arg) { + return enzyme::tuple{arg.value}; + } + + namespace detail { + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(T &&t); + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple>{get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> 
&&t) { + return tuple{get<1>(t), get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))...)); + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence) { + return f(enzyme::get(impl::forward(t))...); + } + + template < typename T > + struct default_ret_activity { + using type = Const; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template < typename T > + struct ret_global; + + template + struct ret_global> { + static constexpr int* value = &enzyme_const_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_active_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_dup_return; + }; + + template + struct ret_used; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_primal_return; + }; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_noret; + }; + + } // namespace detail + + + + template < typename return_type, typename function, typename ... enz_arg_types > + __attribute__((always_inline)) + auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::primal_apply_impl(f, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename function, typename ... arg_types> + auto primal_call(function && f, arg_types && ... args) { + return primal_impl(impl::forward(f), enzyme::tuple_cat(primal_args(args)...)); + } + + template < typename return_type, typename function, typename RetActivity, typename ... 
enz_arg_types > + __attribute__((always_inline)) + auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::rev_apply_impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } + + template < typename DiffMode, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using primal_return_type = decltype(primal_call(impl::forward(f), impl::forward(args)...)); + using RetActivity = typename detail::default_ret_activity::type; + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } +} +}]>; + +def : Headers<"/enzymeroot/enzyme/type_traits", [{ +#pragma once + +#include + +namespace enzyme { + +// this is already in C++20, but we reimplement it here for older C++ versions +template < typename T > +struct remove_cvref { + using type = + typename std::remove_reference< + typename std::remove_cv< + T + >::type + >::type; +}; + +template < typename T > +using remove_cvref_t = typename remove_cvref::type; + +namespace impl { + template + __attribute__((always_inline)) + constexpr _Tp&& + forward(std::remove_reference_t<_Tp>& __t) noexcept + { return static_cast<_Tp&&>(__t); } + + /** + * @brief Forward an rvalue. + * @return The parameter cast to the specified type. + * + * This function is used to implement "perfect forwarding". 
+ */ + template + __attribute__((always_inline)) + constexpr _Tp&& + forward(std::remove_reference_t<_Tp>&& __t) noexcept + { + static_assert(!std::is_lvalue_reference<_Tp>::value, + "enzyme::impl::forward must not be used to convert an rvalue to an lvalue"); + return static_cast<_Tp&&>(__t); + } + +} + +} +}]>; + +def : Headers<"/enzymeroot/enzyme/tuple", [{ +#pragma once + +///////////// +// tuple.h // +///////////// + +// why reinvent the wheel and implement a tuple class? +// - ensure data is laid out in the same order the types are specified +// see: https://github.com/EnzymeAD/Enzyme/issues/1191#issuecomment-1556239213 +// - CUDA compatibility: std::tuple has some compatibility issues when used +// in a __device__ context (this may get better in c++20 with the improved +// constexpr support for std::tuple). Owning the implementation lets +// us add __host__ __device__ annotations to any part of it + +#include // for std::integer_sequence + +#include + +#define _NOEXCEPT noexcept +namespace enzyme { + +template +struct Index {}; + +template +struct value_at_position { + __attribute__((always_inline)) + T & operator[](Index) { return value; } + + __attribute__((always_inline)) + constexpr const T & operator[](Index) const { return value; } + T value; +}; + +template +struct tuple_base; + +template +struct tuple_base, T...> + : public value_at_position... { + using value_at_position::operator[]...; +}; + +template +struct tuple : public tuple_base, T...> {}; + +template +__attribute__((always_inline)) +tuple(T ...) 
-> tuple; + +template < int i, typename Tuple > +__attribute__((always_inline)) +decltype(auto) get(Tuple && tup) { + constexpr bool is_lvalue = std::is_lvalue_reference_v; + constexpr bool is_const = std::is_const_v>; + using T = remove_cvref_t< decltype(tup[Index{ } ]) >; + if constexpr ( is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr ( is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } +} + +template < int i, typename ... T> +__attribute__((always_inline)) +decltype(auto) get(const tuple< T ... > & tup) { + return tup[Index{} ]; +} + +template +struct tuple_size; + +template +struct tuple_size> : std::integral_constant {}; + +template +static constexpr size_t tuple_size_v = tuple_size::value; + +template +__attribute__((always_inline)) +constexpr auto forward_as_tuple(T&&... args) noexcept { + return tuple{impl::forward(args)...}; +} + +namespace impl { + +template +struct make_tuple_from_fwd_tuple; + +template +struct make_tuple_from_fwd_tuple> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd) { + return tuple{get(impl::forward(fwd))...}; + } +}; + +template +struct concat_with_fwd_tuple; + +template < typename Tuple > +using iseq = std::make_index_sequence > >; + +template +struct concat_with_fwd_tuple, std::index_sequence> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) { + return forward_as_tuple(get(impl::forward(fwd))..., get(impl::forward(t))...); + } +}; + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(Tuple&& ret) { + return make_tuple_from_fwd_tuple< iseq< Tuple > >::f(impl::forward< Tuple >(ret)); +} + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(FWD_TUPLE&& fwd, first&& t, rest&&... 
ts) { + return tuple_cat(concat_with_fwd_tuple< iseq, iseq >::f(impl::forward(fwd), impl::forward(t)), impl::forward(ts)...); +} + +} // namespace impl + +template +__attribute__((always_inline)) +constexpr auto tuple_cat(Tuples&&... tuples) { + return impl::tuple_cat(impl::forward(tuples)...); +} + +} // namespace enzyme +#undef _NOEXCEPT +}]>; + +def : Headers<"/enzymeroot/enzyme/enzyme", [{ +#ifdef __cplusplus +#include "enzyme/utils" +#else +#warning "Enzyme wrapper templates only available in C++" +#endif +}]>; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 47e9f2bba91f..70f173f5734a 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -435,6 +435,9 @@ std::optional getMetadataName(llvm::Value *res) Optional getMetadataName(llvm::Value *res) #endif { + if (auto S = simplifyLoad(res)) + return getMetadataName(S); + if (auto av = dyn_cast(res)) { return cast(av->getMetadata())->getString(); } else if ((isa(res) || isa(res)) && @@ -463,12 +466,11 @@ Optional getMetadataName(llvm::Value *res) return gv->getName(); } else if (auto gv = dyn_cast(res)) { return gv->getName(); - } else { - if (isa(res)) { - return recursePhiReads(cast(res)); - } - return {}; + } else if (isa(res)) { + return recursePhiReads(cast(res)); } + + return {}; } static Value *adaptReturnedVector(Value *ret, Value *diffret, @@ -3197,6 +3199,7 @@ AnalysisKey EnzymeNewPM::Key; #include "PreserveNVVM.h" #include "TypeAnalysis/TypeAnalysisPrinter.h" #include "llvm/Passes/PassBuilder.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" #if LLVM_VERSION_MAJOR >= 15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "llvm/Transforms/IPO/CalledValuePropagation.h" @@ -3427,6 +3430,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { #else prePass(MPM); #endif + MPM.addPass(llvm::AlwaysInlinerPass()); FunctionPassManager OptimizerPM; FunctionPassManager OptimizerPM2; #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/Utils.cpp 
b/enzyme/Enzyme/Utils.cpp index ff44cbaa715d..283460673922 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2264,7 +2264,258 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm_unreachable("unknown inst2"); } -Function *GetFunctionFromValue(Value *fn) { +// Find the base pointer of ptr and the offset in bytes from the start of +// the returned base pointer to this value. +AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) { + offset = 0; + while (true) { + if (auto CI = dyn_cast(ptr)) { + ptr = CI->getOperand(0); + continue; + } + if (auto CI = dyn_cast(ptr)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + return nullptr; + } + offset += Offset.getZExtValue(); + ptr = CI->getOperand(0); + continue; + } + if (isa(ptr)) { + break; + } + if (auto LI = dyn_cast(ptr)) { + if (auto S = simplifyLoad(LI)) { + ptr = S; + continue; + } + } + return nullptr; + } + return cast(ptr); +} + +// Find all user instructions of AI, returning tuples of Unlike a simple get users, this will recurse through any +// constant gep offsets and casts +SmallVector, 1> +findAllUsersOf(Value *AI) { + SmallVector, 1> todo; + todo.emplace_back(AI, 0); + + SmallVector, 1> users; + while (todo.size()) { + auto pair = todo.pop_back_val(); + Value *ptr = pair.first; + size_t suboff = pair.second; + + for (auto U : ptr->users()) { + if (auto CI = dyn_cast(U)) { + todo.emplace_back(CI, suboff); + continue; + } + if (auto CI = dyn_cast(U)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, 
Offset); + + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + users.emplace_back(cast(U), ptr, suboff); + continue; + } + todo.emplace_back(CI, suboff + Offset.getZExtValue()); + continue; + } + users.emplace_back(cast(U), ptr, suboff); + continue; + } + } + return users; +} + +// Given a pointer, find all values of size `valSz` which could be loaded from +// that pointer when indexed at offset. If it is impossible to guarantee that +// the set contains all such values, set legal to false +SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, + size_t valSz, bool &legal) { + SmallVector options; + + auto todo = findAllUsersOf(ptr0); + std::set> seen; + + while (todo.size()) { + auto pair = todo.pop_back_val(); + if (seen.count(pair)) + continue; + seen.insert(pair); + Instruction *U = std::get<0>(pair); + Value *ptr = std::get<1>(pair); + size_t suboff = std::get<2>(pair); + + // Read only users do not set the memory inside of ptr + if (isa(U)) { + continue; + } + if (auto MTI = dyn_cast(U)) + if (MTI->getOperand(0) != ptr) { + continue; + } + if (auto I = dyn_cast(U)) { + if (!I->mayWriteToMemory() && I->getType()->isVoidTy()) + continue; + } + + if (auto SI = dyn_cast(U)) { + auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout(); + + // We are storing into the ptr + if (SI->getPointerOperand() == ptr) { + auto storeSz = + (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) / + 8; + // If store is before the load would start + if (storeSz + suboff <= offset) + continue; + // if store starts after load would start + if (offset + valSz <= suboff) + continue; + + if (valSz == storeSz) { + options.push_back(SI->getValueOperand()); + continue; + } + } + + // We capture our pointer of interest, if it is stored into an alloca, + // all loads of said alloca would potentially store into. 
+ if (SI->getValueOperand() == ptr) { + if (suboff == 0) { + size_t mid_offset = 0; + if (auto AI2 = + getBaseAndOffset(SI->getPointerOperand(), mid_offset)) { + bool sublegal = true; + auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8; + auto subPtrs = + getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal); + if (!sublegal) { + legal = false; + return options; + } + for (auto subPtr : subPtrs) { + for (const auto &pair3 : findAllUsersOf(subPtr)) { + todo.emplace_back(pair3); + } + } + continue; + } + } + } + } + + if (auto II = dyn_cast(U)) { + if (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + } + + // If we copy into the ptr at a location that includes the offset, consider + // all sub uses + if (auto MTI = dyn_cast(U)) { + if (auto CI = dyn_cast(MTI->getLength())) { + if (MTI->getOperand(0) == ptr && suboff == 0 && + CI->getValue().uge(offset + valSz)) { + size_t midoffset = 0; + auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset); + if (!AI2) { + legal = false; + return options; + } + if (midoffset != 0) { + legal = false; + return options; + } + for (const auto &pair3 : findAllUsersOf(AI2)) { + todo.emplace_back(pair3); + } + continue; + } + } + } + + legal = false; + return options; + } + + return options; +} + +// Perform mem2reg/sroa to identify the innermost value being represented. 
+Value *simplifyLoad(Value *V, size_t valSz) { + if (auto LI = dyn_cast(V)) { + if (valSz == 0) { + auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(LI->getType()) + 7) / 8; + } + + Value *ptr = LI->getPointerOperand(); + size_t offset = 0; + + if (auto ptr2 = simplifyLoad(ptr)) { + ptr = ptr2; + } + auto AI = getBaseAndOffset(ptr, offset); + if (!AI) { + return nullptr; + } + + bool legal = true; + auto opts = getAllLoadedValuesFrom(AI, offset, valSz, legal); + + if (!legal) { + return nullptr; + } + std::set res; + for (auto opt : opts) { + Value *v2 = simplifyLoad(opt, valSz); + if (v2) + res.insert(v2); + else + res.insert(opt); + } + if (res.size() != 1) { + return nullptr; + } + Value *retval = *res.begin(); + return retval; + } + if (auto EVI = dyn_cast(V)) { + bool allZero = true; + for (auto idx : EVI->getIndices()) { + if (idx != 0) + allZero = false; + } + if (valSz == 0) { + auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(EVI->getType()) + 7) / 8; + } + if (allZero) + if (auto LI = dyn_cast(EVI->getAggregateOperand())) { + return simplifyLoad(LI, valSz); + } + } + return nullptr; +} + +Value *GetFunctionValFromValue(Value *fn) { while (!isa(fn)) { if (auto ci = dyn_cast(fn)) { fn = ci->getOperand(0); @@ -2294,6 +2545,7 @@ Function *GetFunctionFromValue(Value *fn) { } if (ret.size() == 1) { auto val = *ret.begin(); + val = GetFunctionValFromValue(val); if (isa(val)) { fn = val; continue; @@ -2315,6 +2567,14 @@ Function *GetFunctionFromValue(Value *fn) { } if (ret.size() == 1) { auto val = *ret.begin(); + while (isa(val)) { + auto v2 = simplifyLoad(val); + if (v2) { + val = v2; + continue; + } + break; + } if (isa(val)) { fn = val; continue; @@ -2326,73 +2586,18 @@ Function *GetFunctionFromValue(Value *fn) { } } } - if (auto LI = dyn_cast(fn)) { - auto obj = getBaseObject(LI->getPointerOperand()); - if (isa(obj)) { - std::set> done; - 
SmallVector, 1> todo; - Value *stored = nullptr; - bool legal = true; - for (auto U : obj->users()) { - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, obj)); - else { - legal = false; - break; - } - } - while (legal && todo.size()) { - auto tup = todo.pop_back_val(); - if (done.count(tup)) - continue; - done.insert(tup); - auto cur = tup.first; - auto prev = tup.second; - if (auto SI = dyn_cast(cur)) - if (SI->getPointerOperand() == prev) { - if (stored == SI->getValueOperand()) - continue; - else if (stored == nullptr) { - stored = SI->getValueOperand(); - continue; - } else { - legal = false; - break; - } - } - - if (isPointerArithmeticInst(cur, /*includephi*/ true)) { - for (auto U : cur->users()) { - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, cur)); - else { - legal = false; - break; - } - } - continue; - } - - if (isa(cur)) - continue; - - if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy()) - continue; - - legal = false; - break; - } - - if (legal && stored) { - fn = stored; - continue; - } - } + if (auto S = simplifyLoad(fn)) { + fn = S; + continue; } break; } - return dyn_cast(fn); + return fn; +} + +Function *GetFunctionFromValue(Value *fn) { + return dyn_cast(GetFunctionValFromValue(fn)); } #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 37edfc1a985d..5a4b3c31cef7 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1248,6 +1248,8 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Function *GetFunctionFromValue(llvm::Value *fn); +llvm::Value *simplifyLoad(llvm::Value *LI, size_t valSz = 0); + static inline bool shouldDisableNoWrite(const llvm::CallInst *CI) { auto F = getFunctionFromCall(CI); auto funcName = getFuncNameFromCall(CI); diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp new file mode 100644 index 000000000000..8524342e1bfc --- /dev/null +++ 
b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -0,0 +1,91 @@ +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -mllvm -print-before-all -mllvm -print-after-all -mllvm -print-module-scope -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + +#include "../test_utils.h" + +#include + +double foo(double x, double y) { return x * y; } + +double square(double x) { return x * x; } + +struct pair { + double x; + double y; +}; + +int main() { + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq2 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq3 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq3_prim = %f\n", prim); 
+ APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq4 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq4_prim = %f\n", prim); + APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + printf("dmul %f %f\n", y1, y2); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + auto prim = enzyme::get<1>(y); + printf("dmul2 %f %f\n", y1, y2); + printf("dmul_prim %f\n", prim); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + APPROX_EQ(prim, 2.7*3.1, 1e-10); + } + + { + auto &&[z1, z2] = __enzyme_autodiff((void*)foo, enzyme_out, 3.1, enzyme_out, 2.7); + printf("dmul2 %f %f\n", z1, z2); + APPROX_EQ(z1, 2.7, 1e-10); + APPROX_EQ(z2, 3.1, 1e-10); + } + +} diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 04f63e1ad3ac..143c85ea684e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -65,7 +65,9 @@ static cl::opt cl::values(clEnumValN(MLIRDerivatives, "gen-mlir-derivatives", "Generate MLIR derivative")), cl::values(clEnumValN(CallDerivatives, "gen-call-derivatives", - "Generate call derivative"))); + "Generate call derivative")), + cl::values(clEnumValN(GenHeaderVariables, "gen-header-strings", + "Generate header strings"))); void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, StringRef FT, StringRef cconv, Init *func, @@ -1248,6 +1250,24 @@ void printDiffUse( } } +static void 
emitHeaderIncludes(const RecordKeeper &recordKeeper, + raw_ostream &os) { + const auto &patterns = recordKeeper.getAllDerivedDefinitions("Headers"); + os << "const char* include_headers[][2] = {\n"; + bool seen = false; + for (Record *pattern : patterns) { + if (seen) + os << ",\n"; + auto filename = pattern->getValueAsString("filename"); + auto contents = pattern->getValueAsString("contents"); + os << "{\"" << filename << "\"\n,"; + os << "R\"(" << contents << ")\"\n"; + os << "}"; + seen = true; + } + os << "};\n"; +} + static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { emitSourceFileHeader("Rewriters", os); @@ -1268,6 +1288,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case BinopDerivatives: patternNames = "BinopPattern"; break; + case GenHeaderVariables: case GenBlasDerivatives: case UpdateBlasDecl: case UpdateBlasTA: @@ -1299,6 +1320,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case MLIRDerivatives: { auto opName = pattern->getValueAsString("opName"); @@ -2089,6 +2111,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDiffUse"); case CallDerivatives: patternNames = "CallPattern"; @@ -2127,6 +2150,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case CallDerivatives: { os << " if (("; @@ -2283,6 +2307,9 @@ static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { case UpdateBlasTA: emitBlasTAUpdater(records, os); return false; 
+ case GenHeaderVariables: + emitHeaderIncludes(records, os); + return false; default: errs() << "unknown tablegen action!\n"; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h index 368644ba0b5d..742a96d023ae 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h @@ -24,6 +24,7 @@ enum ActionType { UpdateBlasDecl, UpdateBlasTA, GenBlasDiffUse, + GenHeaderVariables, }; void emitDiffUse(const llvm::RecordKeeper &recordKeeper, llvm::raw_ostream &os, From f36e5a15909ef3fcede6372c41161d184e3b6162 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:45:59 +0000 Subject: [PATCH 128/131] Bump peter-evans/create-pull-request from 3 to 6 (#1760) Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 3 to 6. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v3...v6) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Tim Gymnich --- .github/workflows/tagger.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tagger.yml b/.github/workflows/tagger.yml index e4a015c4145a..46f8874e92e3 100644 --- a/.github/workflows/tagger.yml +++ b/.github/workflows/tagger.yml @@ -37,7 +37,7 @@ jobs: git add . 
- name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v3 + uses: peter-evans/create-pull-request@v6 with: path: ygg commit-message: "Upgrade enzyme to ${{ github.ref }}" From c38423bacd84bdc1684bad9b9b4f100294e28364 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 27 Feb 2024 15:23:06 +0100 Subject: [PATCH 129/131] Update MLIR to 2c9b6c1b36b8185299de083c3058e0c1e7760442 (#1765) Remove most uses of `ConversionPatternRewriter` that were spurious anyway. Update one use to use `IRRewriter` instead since the conversion rewriter should no longer be used directly. Clean up includes accordingly. --- .github/workflows/enzyme-mlir.yml | 2 +- .../Implementations/LinalgAutoDiffOpInterfaceImpl.cpp | 4 +--- enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp | 9 --------- enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp | 3 --- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 11 ++--------- enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp | 10 ---------- 6 files changed, 4 insertions(+), 35 deletions(-) diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index aaced6f4e694..34a7828af856 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' - ref: 'bc82cfb38d83f1afeb2c290aa472c2e2e88919cb' + ref: '2c9b6c1b36b8185299de083c3058e0c1e7760442' path: 'llvm-project' - name: Get MLIR commit hash diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 334d7325d7ec..8ecb851af90b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -79,9 +79,7 @@ struct GenericOpInterfaceReverse cast(gutils->getNewFromOriginal(linalgOp)); // Replace the op by a linalg.generic op if necessary - // TODO : 
IRRewriter rewriter(builder.getContext()/*, - // builder.getListener()*/); - ConversionPatternRewriter rewriter(builder.getContext()); + IRRewriter rewriter(builder.getContext(), builder.getListener()); auto failiureOrLinalgOp = generalizeNamedOp(rewriter, newOp); if (!failed(failiureOrLinalgOp)) { linalg::GenericOp replacement = failiureOrLinalgOp.value(); diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp index 448dd67b8dd3..5baf2f982d9d 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp @@ -11,18 +11,12 @@ // procedure to the MemRef dialect. //===----------------------------------------------------------------------===// -#include "Dialect/Dialect.h" #include "Dialect/Ops.h" #include "PassDetails.h" #include "Passes/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Rewrite/PatternApplicator.h" #include "llvm/Support/raw_ostream.h" #include "Interfaces/AutoDiffTypeInterface.h" @@ -51,9 +45,6 @@ SmallVector applyAffineMap(AffineMap aMap, SmallVector indices, struct AddToOpToIndexAndLoadPass : public enzyme::AddToOpToIndexAndLoadPassBase { void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - getOperation()->walk([&](Operation *op) { auto loc = op->getLoc(); auto enzymeAdjoint = dyn_cast(op); diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp index de2ebeba376d..01bcd6683a96 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp @@ -105,9 +105,6 @@ void processGenericDuplication(Operation *op, OpBuilder &builder, 
Location loc, struct AddToOpToSplitPass : public enzyme::AddToOpToSplitPassBase { void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - getOperation()->walk([&](Operation *op) { auto enzymeAdjoint = dyn_cast(op); auto loc = op->getLoc(); diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 1e2c55640280..e43db26b21a2 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -11,16 +11,13 @@ //===----------------------------------------------------------------------===// #include "Dialect/Ops.h" -#include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "PassDetails.h" #include "Passes/Passes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #define DEBUG_TYPE "enzyme" @@ -221,13 +218,9 @@ std::unique_ptr createDifferentiatePass() { } // namespace enzyme } // namespace mlir -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" - void DifferentiatePass::runOnOperation() { SymbolTableCollection symbolTable; symbolTable.getSymbolTable(getOperation()); - ConversionPatternRewriter B(getOperation()->getContext()); getOperation()->walk( [&](FunctionOpInterface op) { lowerEnzymeCalls(symbolTable, op); }); } diff --git a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp b/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp index b8a42f339280..d4e374d6c68f 100644 --- a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp +++ b/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp @@ -11,19 +11,12 @@ // procedure to the MemRef dialect. 
//===----------------------------------------------------------------------===// -#include "Dialect/Dialect.h" #include "Dialect/Ops.h" #include "PassDetails.h" #include "Passes/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Rewrite/PatternApplicator.h" - #include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -34,9 +27,6 @@ struct ShadowedGradientToCachePass : public enzyme::ShadowedGradientToCachePassBase< ShadowedGradientToCachePass> { void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - getOperation()->walk([&](Operation *op) { if (auto initOp = dyn_cast(op)) { if (auto type = From 21ead64a8879163f4a14c67a32e9f655c8eb99b7 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 27 Feb 2024 16:16:37 +0100 Subject: [PATCH 130/131] Fix clang-tidy findings (#1766) Includes a real memory problem: ValueRange is a non-owning container, assigning the returned value of type SmallVector to ValueRange will lead to ValueRange starting with a dangling pointer. 
--- enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp | 4 ++-- enzyme/Enzyme/MLIR/Passes/Passes.h | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp index 5baf2f982d9d..648509f58e52 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp @@ -85,7 +85,7 @@ struct AddToOpToIndexAndLoadPass // auto load = cacheBuilder.create(loc, inputs[i], map[i], // indices); auto store = cacheBuilder.create(loc, load, // inputs[i], map[i], indices); - ValueRange mapAppliedIndices = + SmallVector mapAppliedIndices = applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc); auto load = cacheBuilder.create(loc, outs[i], mapAppliedIndices); @@ -96,7 +96,7 @@ struct AddToOpToIndexAndLoadPass } for (int i = 0; i < retargs.size(); i++) { - ValueRange mapAppliedIndices = + SmallVector mapAppliedIndices = applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc); auto load = cacheBuilder.create(loc, outs[i], mapAppliedIndices); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 25362d01294a..80c88373090c 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -65,12 +65,15 @@ class MemRefDialect; namespace func { class FuncDialect; -} +} // end namespace func +namespace affine { class AffineDialect; +} // end namespace affine + namespace LLVM { class LLVMDialect; -} +} // end namespace LLVM #define GEN_PASS_REGISTRATION #include "Passes/Passes.h.inc" From 8d75756c0792ffabf07b239440bc5c1fa48fb799 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 27 Feb 2024 16:53:36 -0500 Subject: [PATCH 131/131] Correct complex inv (#1767) --- enzyme/Enzyme/InstructionDerivatives.td | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td 
b/enzyme/Enzyme/InstructionDerivatives.td index 32b30bd8f9b4..6491fa032798 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -162,12 +162,20 @@ def CFNeg : SubRoutine<(Op (Op $re, $im):$z), (FNeg $re), (FNeg $im) )>; + +def Conj : SubRoutine<(Op (Op $re, $im):$z), + (ArrayRet + $re, + (FNeg $im) + )>; + def CFExp : SubRoutine<(Op (Op $re, $im):$z), (ArrayRet (FMul (FExp $re):$exp, (FCos $im)), (FMul $exp, (FSin $im)) )>; + // Same function as the one being called def SameFunc { } @@ -826,9 +834,10 @@ def : CallPattern<(Op (Op $x, $y):$z), def : CallPattern<(Op (Op $x, $y):$z), ["cmplx_inv"], [ - (CFDiv (CFNeg (DiffeRet)), (CFMul $z, $z)), + // Reverse mode needs to return the conjugate + (Conj (CFDiv (CFNeg (Conj (DiffeRet))), (CFMul $z, $z))), ], - (ForwardFromSummedReverse), + (CFDiv (CFNeg (Shadow $z)), (CFMul $z, $z)), [ReadNone, NoUnwind] >;