Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup gutils to avoid null reference #1739

Merged
merged 1 commit into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1542,7 +1542,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
lc) &&
gutils->getNewFromOriginal(P0->getParent()) == lc.header) {
SmallVector<BasicBlock *, 1> Latches;
gutils->OrigLI.getLoopFor(P0->getParent())->getLoopLatches(Latches);
gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches);
bool allIncoming = true;
for (auto Latch : Latches) {
if (&SI != P0->getIncomingValueForBlock(Latch)) {
Expand Down Expand Up @@ -2206,7 +2206,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
lc) &&
gutils->getNewFromOriginal(P0->getParent()) == lc.header) {
SmallVector<BasicBlock *, 1> Latches;
gutils->OrigLI.getLoopFor(P0->getParent())->getLoopLatches(Latches);
gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches);
bool allIncoming = true;
for (auto Latch : Latches) {
if (&BO != P0->getIncomingValueForBlock(Latch)) {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3323,7 +3323,7 @@ bool AdjointGenerator::handleKnownCallDerivatives(
// rematerialization is loop level. This is because one can have a
// loop level cache, but a function level allocation (e.g. for stack
// allocas). If we deleted it here, we would have no allocation!
auto AllocationLoop = gutils->OrigLI.getLoopFor(call.getParent());
auto AllocationLoop = gutils->OrigLI->getLoopFor(call.getParent());
// An allocation within a loop, must definitionally be a loop level
// allocation (but not always the other way around.
if (AllocationLoop)
Expand Down
26 changes: 13 additions & 13 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ void calculateUnusedValuesInFunction(
if (newMemory) {
bool foundStore = false;
allInstructionsBetween(
gutils->OrigLI, cast<Instruction>(at),
*gutils->OrigLI, cast<Instruction>(at),
const_cast<MemTransferInst *>(mti),
[&](Instruction *I) -> bool {
if (!I->mayWriteToMemory())
Expand All @@ -994,7 +994,7 @@ void calculateUnusedValuesInFunction(
}

if (writesToMemoryReadBy(
gutils->OrigAA, TLI,
*gutils->OrigAA, TLI,
/*maybeReader*/ const_cast<MemTransferInst *>(mti),
/*maybeWriter*/ I)) {
foundStore = true;
Expand Down Expand Up @@ -1143,7 +1143,7 @@ void calculateUnusedStoresInFunction(
if (newMemory) {
bool foundStore = false;
allInstructionsBetween(
gutils->OrigLI, cast<Instruction>(at),
*gutils->OrigLI, cast<Instruction>(at),
const_cast<MemTransferInst *>(mti), [&](Instruction *I) -> bool {
if (!I->mayWriteToMemory())
return /*earlyBreak*/ false;
Expand All @@ -1152,7 +1152,7 @@ void calculateUnusedStoresInFunction(

// if (I == &MTI) return;
if (writesToMemoryReadBy(
gutils->OrigAA, TLI,
*gutils->OrigAA, TLI,
/*maybeReader*/ const_cast<MemTransferInst *>(mti),
/*maybeWriter*/ I)) {
foundStore = true;
Expand Down Expand Up @@ -1552,7 +1552,7 @@ bool legalCombinedForwardReverse(
auto consider = [&](Instruction *user) {
if (!user->mayReadFromMemory())
return false;
if (writesToMemoryReadBy(gutils->OrigAA, gutils->TLI,
if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI,
/*maybeReader*/ user,
/*maybeWriter*/ inst)) {

Expand Down Expand Up @@ -1585,7 +1585,7 @@ bool legalCombinedForwardReverse(
if (!post->mayWriteToMemory())
return false;

if (writesToMemoryReadBy(gutils->OrigAA, gutils->TLI,
if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI,
/*maybeReader*/ inst,
/*maybeWriter*/ post)) {
if (EnzymePrintPerf) {
Expand Down Expand Up @@ -2398,9 +2398,9 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(

CacheAnalysis CA(gutils->allocationsWithGuaranteedFree,
gutils->rematerializableAllocations, gutils->TR,
gutils->OrigAA, gutils->oldFunc,
*gutils->OrigAA, gutils->oldFunc,
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable,
*gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable,
_overwritten_argsPP, DerivativeMode::ReverseModePrimal, omp);
const std::map<CallInst *, const std::vector<bool>> overwritten_args_map =
CA.compute_overwritten_args_for_callsites();
Expand Down Expand Up @@ -3346,7 +3346,7 @@ void createInvertedTerminator(DiffeGradientUtils *gutils,
gutils->getNewFromOriginal(orig->getParent()) == loopContext.header &&
loopContext.exitBlocks.size() == 1) {
SmallVector<BasicBlock *, 1> Latches;
gutils->OrigLI.getLoopFor(orig->getParent())->getLoopLatches(Latches);
gutils->OrigLI->getLoopFor(orig->getParent())->getLoopLatches(Latches);
bool allIncoming = true;
for (auto Latch : Latches) {
if (activeUses[0] != orig->getIncomingValueForBlock(Latch)) {
Expand Down Expand Up @@ -4080,9 +4080,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
gutils->computeGuaranteedFrees();
CacheAnalysis CA(gutils->allocationsWithGuaranteedFree,
gutils->rematerializableAllocations, gutils->TR,
gutils->OrigAA, gutils->oldFunc,
*gutils->OrigAA, gutils->oldFunc,
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable,
*gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable,
_overwritten_argsPP, key.mode, omp);
const std::map<CallInst *, const std::vector<bool>> overwritten_args_map =
(augmenteddata) ? augmenteddata->overwritten_args_map
Expand Down Expand Up @@ -4734,10 +4734,10 @@ Function *EnzymeLogic::CreateForwardDiff(
gutils->computeGuaranteedFrees();
CacheAnalysis CA(
gutils->allocationsWithGuaranteedFree,
gutils->rematerializableAllocations, gutils->TR, gutils->OrigAA,
gutils->rematerializableAllocations, gutils->TR, *gutils->OrigAA,
gutils->oldFunc,
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable,
*gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable,
_overwritten_argsPP, mode, omp);
const std::map<CallInst *, const std::vector<bool>> overwritten_args_map =
CA.compute_overwritten_args_for_callsites();
Expand Down
Loading
Loading