Skip to content

Commit

Permalink
Improve errors on llvm main (#1969)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jul 7, 2024
1 parent f80c238 commit cd39401
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 7 deletions.
44 changes: 42 additions & 2 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4399,8 +4399,14 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
if (subdata && subdata->returns.find(AugmentedStruct::Tape) !=
subdata->returns.end()) {
if (Mode == DerivativeMode::ReverseModeGradient) {
if (tape == nullptr)
if (tape == nullptr) {
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
tape = BuilderZ.CreatePHI(subdata->tapeType, 0, "tapeArg");
}
tape = gutils->cacheForReverse(
BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ));
}
Expand Down Expand Up @@ -4949,7 +4955,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {

auto idx = *tapeIdx;
FunctionType *FT = subdata->fn->getFunctionType();

#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
tape = BuilderZ.CreatePHI(
(tapeIdx == -1)
? FT->getReturnType()
Expand Down Expand Up @@ -5599,6 +5609,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
if (!tape) {
assert(tapeIdx);
auto tval = *tapeIdx;
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
tape = BuilderZ.CreatePHI(
(tapeIdx == -1) ? FT->getReturnType()
: cast<StructType>(FT->getReturnType())
Expand All @@ -5614,12 +5629,22 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
if (DifferentialUseAnalysis::is_value_needed_in_reverse<
QueryType::Primal>(gutils, &call, Mode, oldUnreachable) &&
!gutils->unnecessaryIntermediates.count(&call)) {
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
cachereplace = BuilderZ.CreatePHI(call.getType(), 1,
call.getName() + "_tmpcacheB");
cachereplace = gutils->cacheForReverse(
BuilderZ, cachereplace,
getIndex(&call, CacheType::Self, BuilderZ));
} else {
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
auto pn = BuilderZ.CreatePHI(
call.getType(), 1, (call.getName() + "_replacementE").str());
gutils->fictiousPHIs[pn] = &call;
Expand Down Expand Up @@ -5719,12 +5744,22 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
QueryType::Primal>(gutils, &call, Mode, oldUnreachable) &&
!gutils->unnecessaryIntermediates.count(&call)) {
assert(!replaceFunction);
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
cachereplace = BuilderZ.CreatePHI(call.getType(), 1,
call.getName() + "_cachereplace2");
cachereplace = gutils->cacheForReverse(
BuilderZ, cachereplace,
getIndex(&call, CacheType::Self, BuilderZ));
} else {
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
auto pn = BuilderZ.CreatePHI(call.getType(), 1,
call.getName() + "_replacementC");
gutils->fictiousPHIs[pn] = &call;
Expand Down Expand Up @@ -6205,6 +6240,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
// EnzymeLogic.
tapeType = (llvm::Type *)fd->second;

#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
tape = BuilderZ.CreatePHI(tapeType, 0);
tape = gutils->cacheForReverse(
BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ),
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,6 @@ Value *CacheUtility::getCachePointer(llvm::Type *T, bool inForwardPass,
Value *extraSize) {
assert(ctx.Block);
assert(cache);

auto sublimits = getSubLimits(inForwardPass, &BuilderM, ctx, extraSize);

Value *next = cache;
Expand Down
10 changes: 8 additions & 2 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,14 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM,
old = BuilderM.CreateLoad(getShadowType(val->getType()), ptr);
}
if (dif->getType() != old->getType()) {
llvm::errs() << " val: " << *val << " dif: " << *dif << " old: " << *old
<< "\n";
if (auto inst = dyn_cast<Instruction>(val)) {
EmitFailure("IllegalAddingType", inst->getDebugLoc(), inst, "val ", *val,
" dif ", *dif, " old ", *old);
return addedSelects;
}
llvm::errs() << " IllegalAddingType val: " << *val << " dif: " << *dif
<< " old: " << *old << "\n";
llvm_unreachable("IllegalAddingType");
}

assert(dif->getType() == old->getType());
Expand Down
25 changes: 23 additions & 2 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8349,6 +8349,11 @@ void GradientUtils::forceAugmentedReturns() {
if (!isConstantValue(inst)) {
IRBuilder<> BuilderZ(inst);
getForwardBuilder(BuilderZ);
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
Type *antiTy = getShadowType(inst->getType());
PHINode *anti =
BuilderZ.CreatePHI(antiTy, 1, inst->getName() + "'dual_phi");
Expand All @@ -8368,6 +8373,11 @@ void GradientUtils::forceAugmentedReturns() {
if (isa<LoadInst>(inst)) {
IRBuilder<> BuilderZ(inst);
getForwardBuilder(BuilderZ);
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
Type *antiTy = getShadowType(inst->getType());
PHINode *anti =
BuilderZ.CreatePHI(antiTy, 1, inst->getName() + "'il_phi");
Expand Down Expand Up @@ -8406,6 +8416,11 @@ void GradientUtils::forceAugmentedReturns() {

IRBuilder<> BuilderZ(inst);
getForwardBuilder(BuilderZ);
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif

// Shadow allocations must strictly preceede the primal, lest Julia have
// GC issues. Consider the following: %r = gc_alloc() init %r
Expand All @@ -8420,8 +8435,14 @@ void GradientUtils::forceAugmentedReturns() {
// inside %dr would hit garbage and segfault. However, by having the %dr
// first, then it will be zero'd before the %r allocation, preventing the
// issue.
if (isAllocationCall(inst, TLI))
if (isAllocationCall(inst, TLI)) {
BuilderZ.SetInsertPoint(getNewFromOriginal(inst));
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
}
Type *antiTy = getShadowType(inst->getType());

PHINode *anti = BuilderZ.CreatePHI(antiTy, 1, op->getName() + "'ip_phi");
Expand Down Expand Up @@ -9051,7 +9072,7 @@ void GradientUtils::eraseWithPlaceholder(Instruction *I, Instruction *orig,
if (!inspos.getHeadBit()) {
auto srcmarker = I->getParent()->getMarker(inspos);
if (srcmarker && !srcmarker->empty()) {
inspos--;
inspos.setHeadBit(true);
}
}
}
Expand Down

0 comments on commit cd39401

Please sign in to comment.