From f62c3427f80f05a135b218f083b5204a094564b2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 21:01:35 -0400 Subject: [PATCH] Enable module in undef for type (#2015) --- enzyme/Enzyme/AdjointGenerator.h | 9 +++++---- enzyme/Enzyme/CApi.cpp | 5 +++-- enzyme/Enzyme/CacheUtility.cpp | 6 ++++-- enzyme/Enzyme/CallDerivatives.cpp | 4 ++-- enzyme/Enzyme/GradientUtils.cpp | 3 ++- enzyme/Enzyme/Utils.cpp | 8 +++++--- enzyme/Enzyme/Utils.h | 3 ++- 7 files changed, 23 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index f66639edce70..19a832dd8e8e 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4719,6 +4719,7 @@ class AdjointGenerator : public llvm::InstVisitor { BuilderZ.setFastMathFlags(getFast()); CallInst *newCall = cast(gutils->getNewFromOriginal(&call)); + Module &M = *call.getParent()->getParent()->getParent(); bool foreignFunction = called == nullptr; @@ -4816,7 +4817,7 @@ class AdjointGenerator : public llvm::InstVisitor { (writeOnlyNoCapture && readOnly); if (replace) { - argi = getUndefinedValueForType(argi->getType()); + argi = getUndefinedValueForType(M, argi->getType()); } argsInverted.push_back(argTy); args.push_back(argi); @@ -5127,7 +5128,7 @@ class AdjointGenerator : public llvm::InstVisitor { (argTy == DIFFE_TYPE::DUP_NONEED && (writeOnlyNoCapture || !isa(getBaseObject(call.getArgOperand(i)))))) { - prearg = getUndefinedValueForType(argi->getType()); + prearg = getUndefinedValueForType(M, argi->getType()); preType = ValueType::None; } pre_args.push_back(prearg); @@ -5145,7 +5146,7 @@ class AdjointGenerator : public llvm::InstVisitor { (argTy == DIFFE_TYPE::DUP_NONEED && (writeOnlyNoCapture || !isa(getBaseObject(call.getOperand(i))))))) { - argi = getUndefinedValueForType(argi->getType()); + argi = getUndefinedValueForType(M, argi->getType()); revType = ValueType::None; } args.push_back(lookup(argi, Builder2)); @@ -5214,7 +5215,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (writeOnlyNoCapture && !replaceFunction && TR.query(call.getArgOperand(i))[{-1, -1}] == BaseType::Pointer) { - darg = getUndefinedValueForType(argi->getType()); + darg = getUndefinedValueForType(M, argi->getType()); } else { darg = gutils->invertPointerM(call.getArgOperand(i), Builder2); revType = (revType == ValueType::None) ? ValueType::Shadow diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index fa7001bb70eb..87dec8ca6954 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -1648,8 +1648,9 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { FT->isVarArg()); // Create the new function + auto &M = *F->getParent(); Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), - F->getName(), F->getParent()); + F->getName(), &M); ValueToValueMapTy VMap; // Loop over the arguments, copying the names of the mapped arguments over... @@ -1948,7 +1949,7 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { } if (outinds.size() > 1) out = B.CreateInBoundsGEP(sretTy, out, outinds); - B.CreateStore(getUndefinedValueForType(PT), out); + B.CreateStore(getUndefinedValueForType(M, PT), out); } return; } diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp index 1630d8c03938..84181c1be8a2 100644 --- a/enzyme/Enzyme/CacheUtility.cpp +++ b/enzyme/Enzyme/CacheUtility.cpp @@ -840,7 +840,7 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T, alloc->setAlignment(Align(align)); } if (sublimits.size() == 0) { - auto val = getUndefinedValueForType(types.back()); + auto val = getUndefinedValueForType(*newFunc->getParent(), types.back()); if (!isa(val)) scopeInstructions[alloc].push_back(entryBuilder.CreateStore(val, alloc)); } @@ -949,7 +949,9 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T, // TODO change this to a power-of-two allocation strategy auto zerostore = allocationBuilder.CreateStore( - getUndefinedValueForType(allocType, /*forceZero*/ true), storeInto); + getUndefinedValueForType(*newFunc->getParent(), allocType, + /*forceZero*/ true), + storeInto); scopeInstructions[alloc].push_back(zerostore); IRBuilder<> build(containedloops.back().first.incvar->getNextNode()); diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index fc74e10a7a18..b8154840988b 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -3009,8 +3009,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( if (!forwardsShadow) { if (Mode == DerivativeMode::ReverseModePrimal) { // Needs a stronger replacement check/assertion. - Value *replacement = - getUndefinedValueForType(placeholder->getType()); + Value *replacement = getUndefinedValueForType( + *gutils->oldFunc->getParent(), placeholder->getType()); gutils->replaceAWithB(placeholder, replacement); gutils->invertedPointers.erase(found); gutils->invertedPointers.insert(std::make_pair( diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 2eafa51c79b3..bc7523760272 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -8236,7 +8236,8 @@ bool GradientUtils::isOriginalBlock(const BasicBlock &BB) const { void GradientUtils::eraseFictiousPHIs() { { for (auto P : rematerializedPrimalOrShadowAllocations) { - Value *replacement = getUndefinedValueForType(P->getType()); + Value *replacement = + getUndefinedValueForType(*oldFunc->getParent(), P->getType()); P->replaceAllUsesWith(replacement); erase(P); } diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 46b533a51d5b..f151ae6fc4ba 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -67,7 +67,8 @@ void (*CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef, LLVMValueRef *(*EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef, uint64_t *size) = nullptr; LLVMTypeRef (*EnzymeDefaultTapeType)(LLVMContextRef) = nullptr; -LLVMValueRef (*EnzymeUndefinedValueForType)(LLVMTypeRef, uint8_t) = nullptr; +LLVMValueRef (*EnzymeUndefinedValueForType)(LLVMModuleRef, LLVMTypeRef, + uint8_t) = nullptr; LLVMValueRef (*EnzymeSanitizeDerivatives)(LLVMValueRef, LLVMValueRef toset, LLVMBuilderRef, @@ -2928,10 +2929,11 @@ llvm::Optional extractBLAS(llvm::StringRef in) return {}; } -llvm::Constant *getUndefinedValueForType(llvm::Type *T, bool forceZero) { +llvm::Constant *getUndefinedValueForType(llvm::Module &M, llvm::Type *T, + bool forceZero) { if (EnzymeUndefinedValueForType) return cast( - unwrap(EnzymeUndefinedValueForType(wrap(T), forceZero))); + unwrap(EnzymeUndefinedValueForType(wrap(&M), wrap(T), forceZero))); else if (EnzymeZeroCache || forceZero) return Constant::getNullValue(T); else diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 1cfe729e0329..2e48a4c7b0f1 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1695,7 +1695,8 @@ static inline bool isNoEscapingAllocation(const llvm::CallBase *call) { bool attributeKnownFunctions(llvm::Function &F); -llvm::Constant *getUndefinedValueForType(llvm::Type *T, bool forceZero = false); +llvm::Constant *getUndefinedValueForType(llvm::Module &M, llvm::Type *T, + bool forceZero = false); llvm::Value *SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM,