Skip to content

Commit

Permalink
Enable module in undef for type (#2015)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jul 27, 2024
1 parent 4c80601 commit f62c342
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 15 deletions.
9 changes: 5 additions & 4 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4719,6 +4719,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
BuilderZ.setFastMathFlags(getFast());

CallInst *newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
Module &M = *call.getParent()->getParent()->getParent();

bool foreignFunction = called == nullptr;

Expand Down Expand Up @@ -4816,7 +4817,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
(writeOnlyNoCapture && readOnly);

if (replace) {
argi = getUndefinedValueForType(argi->getType());
argi = getUndefinedValueForType(M, argi->getType());
}
argsInverted.push_back(argTy);
args.push_back(argi);
Expand Down Expand Up @@ -5127,7 +5128,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
(argTy == DIFFE_TYPE::DUP_NONEED &&
(writeOnlyNoCapture ||
!isa<Argument>(getBaseObject(call.getArgOperand(i)))))) {
prearg = getUndefinedValueForType(argi->getType());
prearg = getUndefinedValueForType(M, argi->getType());
preType = ValueType::None;
}
pre_args.push_back(prearg);
Expand All @@ -5145,7 +5146,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
(argTy == DIFFE_TYPE::DUP_NONEED &&
(writeOnlyNoCapture ||
!isa<Argument>(getBaseObject(call.getOperand(i))))))) {
argi = getUndefinedValueForType(argi->getType());
argi = getUndefinedValueForType(M, argi->getType());
revType = ValueType::None;
}
args.push_back(lookup(argi, Builder2));
Expand Down Expand Up @@ -5214,7 +5215,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {

if (writeOnlyNoCapture && !replaceFunction &&
TR.query(call.getArgOperand(i))[{-1, -1}] == BaseType::Pointer) {
darg = getUndefinedValueForType(argi->getType());
darg = getUndefinedValueForType(M, argi->getType());
} else {
darg = gutils->invertPointerM(call.getArgOperand(i), Builder2);
revType = (revType == ValueType::None) ? ValueType::Shadow
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1648,8 +1648,9 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
FT->isVarArg());

// Create the new function
auto &M = *F->getParent();
Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(),
F->getName(), F->getParent());
F->getName(), &M);

ValueToValueMapTy VMap;
// Loop over the arguments, copying the names of the mapped arguments over...
Expand Down Expand Up @@ -1948,7 +1949,7 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
}
if (outinds.size() > 1)
out = B.CreateInBoundsGEP(sretTy, out, outinds);
B.CreateStore(getUndefinedValueForType(PT), out);
B.CreateStore(getUndefinedValueForType(M, PT), out);
}
return;
}
Expand Down
6 changes: 4 additions & 2 deletions enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
alloc->setAlignment(Align(align));
}
if (sublimits.size() == 0) {
auto val = getUndefinedValueForType(types.back());
auto val = getUndefinedValueForType(*newFunc->getParent(), types.back());
if (!isa<UndefValue>(val))
scopeInstructions[alloc].push_back(entryBuilder.CreateStore(val, alloc));
}
Expand Down Expand Up @@ -949,7 +949,9 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
// TODO change this to a power-of-two allocation strategy

auto zerostore = allocationBuilder.CreateStore(
getUndefinedValueForType(allocType, /*forceZero*/ true), storeInto);
getUndefinedValueForType(*newFunc->getParent(), allocType,
/*forceZero*/ true),
storeInto);
scopeInstructions[alloc].push_back(zerostore);

IRBuilder<> build(containedloops.back().first.incvar->getNextNode());
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3009,8 +3009,8 @@ bool AdjointGenerator::handleKnownCallDerivatives(
if (!forwardsShadow) {
if (Mode == DerivativeMode::ReverseModePrimal) {
// Needs a stronger replacement check/assertion.
Value *replacement =
getUndefinedValueForType(placeholder->getType());
Value *replacement = getUndefinedValueForType(
*gutils->oldFunc->getParent(), placeholder->getType());
gutils->replaceAWithB(placeholder, replacement);
gutils->invertedPointers.erase(found);
gutils->invertedPointers.insert(std::make_pair(
Expand Down
3 changes: 2 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8236,7 +8236,8 @@ bool GradientUtils::isOriginalBlock(const BasicBlock &BB) const {
void GradientUtils::eraseFictiousPHIs() {
{
for (auto P : rematerializedPrimalOrShadowAllocations) {
Value *replacement = getUndefinedValueForType(P->getType());
Value *replacement =
getUndefinedValueForType(*oldFunc->getParent(), P->getType());
P->replaceAllUsesWith(replacement);
erase(P);
}
Expand Down
8 changes: 5 additions & 3 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ void (*CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef,
LLVMValueRef *(*EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef,
uint64_t *size) = nullptr;
LLVMTypeRef (*EnzymeDefaultTapeType)(LLVMContextRef) = nullptr;
LLVMValueRef (*EnzymeUndefinedValueForType)(LLVMTypeRef, uint8_t) = nullptr;
LLVMValueRef (*EnzymeUndefinedValueForType)(LLVMModuleRef, LLVMTypeRef,
uint8_t) = nullptr;

LLVMValueRef (*EnzymeSanitizeDerivatives)(LLVMValueRef, LLVMValueRef toset,
LLVMBuilderRef,
Expand Down Expand Up @@ -2928,10 +2929,11 @@ llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
return {};
}

llvm::Constant *getUndefinedValueForType(llvm::Type *T, bool forceZero) {
llvm::Constant *getUndefinedValueForType(llvm::Module &M, llvm::Type *T,
bool forceZero) {
if (EnzymeUndefinedValueForType)
return cast<Constant>(
unwrap(EnzymeUndefinedValueForType(wrap(T), forceZero)));
unwrap(EnzymeUndefinedValueForType(wrap(&M), wrap(T), forceZero)));
else if (EnzymeZeroCache || forceZero)
return Constant::getNullValue(T);
else
Expand Down
3 changes: 2 additions & 1 deletion enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1695,7 +1695,8 @@ static inline bool isNoEscapingAllocation(const llvm::CallBase *call) {

bool attributeKnownFunctions(llvm::Function &F);

llvm::Constant *getUndefinedValueForType(llvm::Type *T, bool forceZero = false);
llvm::Constant *getUndefinedValueForType(llvm::Module &M, llvm::Type *T,
bool forceZero = false);

llvm::Value *SanitizeDerivatives(llvm::Value *val, llvm::Value *toset,
llvm::IRBuilder<> &BuilderM,
Expand Down

0 comments on commit f62c342

Please sign in to comment.