Skip to content

Commit

Permalink
generalized dp solver
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Sep 30, 2024
1 parent 0621807 commit 834d7c3
Showing 1 changed file with 109 additions and 11 deletions.
120 changes: 109 additions & 11 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <sstream>
#include <string>
#include <utility>
#include <variant>

#include "Herbie.h"
#include "Utils.h"
Expand Down Expand Up @@ -2131,6 +2132,8 @@ class ApplicableFPCC {
// topological order with respect to operand dependencies. Insert FP casts
// between llvm::Value inputs and first level of instructions to be changed.
// Restore precisions of the last level of instructions to be changed.
llvm::errs() << "Applying PT candidate #" << candidateIndex << ": "
<< candidates[candidateIndex].desc << "\n";
candidates[candidateIndex].apply(component);
}

Expand Down Expand Up @@ -2650,6 +2653,17 @@ bool accuracyGreedySolver(
return changed;
}

struct SolutionStep {
std::variant<ApplicableOutput *, ApplicableFPCC *> item;
size_t candidateIndex;

SolutionStep(ApplicableOutput *ao_, size_t idx)
: item(ao_), candidateIndex(idx) {}

SolutionStep(ApplicableFPCC *acc_, size_t idx)
: item(acc_), candidateIndex(idx) {}
};

bool accuracyDPSolver(
SmallVector<ApplicableOutput, 4> &AOs, SmallVector<ApplicableFPCC, 4> &ACCs,
std::unordered_map<Value *, std::shared_ptr<FPNode>> &valueToNodeMap,
Expand All @@ -2659,9 +2673,7 @@ bool accuracyDPSolver(
<< FPOptComputationCostBudget << "\n";

using CostMap = std::map<InstructionCost, double>;
using SolutionMap =
std::map<InstructionCost,
SmallVector<std::pair<ApplicableOutput *, size_t>>>;
using SolutionMap = std::map<InstructionCost, SmallVector<SolutionStep>>;

CostMap costToAccuracyMap;
costToAccuracyMap[0] = 0;
Expand All @@ -2672,8 +2684,6 @@ bool accuracyDPSolver(
CostMap newCostToAccuracyMap;
SolutionMap newCostToSolutionMap;

llvm::errs() << "Processing AO: " << AO.expr << "\n";

for (const auto &pair : costToAccuracyMap) {
InstructionCost currCompCost = pair.first;
double currAccCost = pair.second;
Expand All @@ -2694,7 +2704,7 @@ bool accuracyDPSolver(
InstructionCost newCompCost = currCompCost + candCompCost;
double newAccCost = currAccCost + candAccCost;

llvm::errs() << "Candidate " << i
llvm::errs() << "AO candidate " << i
<< " has accuracy cost: " << candAccCost
<< " and computation cost: " << candCompCost << "\n";

Expand All @@ -2704,7 +2714,7 @@ bool accuracyDPSolver(
newCostToAccuracyMap[newCompCost] = newAccCost;
newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost];
newCostToSolutionMap[newCompCost].emplace_back(&AO, i);
llvm::errs() << "Updating accuracy map (candidate " << i
llvm::errs() << "Updating accuracy map (AO candidate " << i
<< "): computation cost " << newCompCost
<< " -> accuracy cost " << newAccCost << "\n";
}
Expand All @@ -2724,7 +2734,7 @@ bool accuracyDPSolver(
double otherAccCost = r.second;

if (currCompCost > otherCompCost && currAccCost >= otherAccCost) {
llvm::errs() << "Candidate with computation cost: " << currCompCost
llvm::errs() << "AO candidate with computation cost: " << currCompCost
<< " and accuracy cost: " << currAccCost
<< " is dominated by candidate with computation cost: "
<< otherCompCost
Expand All @@ -2745,6 +2755,81 @@ bool accuracyDPSolver(
costToSolutionMap.swap(prunedCostToSolutionMap);
}

// DP pass over the ACCs: for every ACC, extend each existing table entry
// (computation cost -> accuracy cost, plus the steps that produced it) with
// either "apply nothing" or one of the ACC's candidates, then prune dominated
// entries before moving on to the next ACC.
for (auto &ACC : ACCs) {
// Fresh tables for this ACC; they replace the running tables after pruning.
CostMap newCostToAccuracyMap;
SolutionMap newCostToSolutionMap;

for (const auto &pair : costToAccuracyMap) {
InstructionCost currCompCost = pair.first;
double currAccCost = pair.second;

// Option 1: apply no candidate for this ACC. The entry is carried over
// unchanged unless an entry with the same computation cost and an accuracy
// cost that is not larger has already been recorded.
if (newCostToAccuracyMap.find(currCompCost) ==
newCostToAccuracyMap.end() ||
newCostToAccuracyMap[currCompCost] > currAccCost) {
newCostToAccuracyMap[currCompCost] = currAccCost;
newCostToSolutionMap[currCompCost] = costToSolutionMap[currCompCost];
}

// Option 2: apply exactly one candidate of this ACC on top of the current
// entry.
for (auto &candidate : enumerate(ACC.candidates)) {
size_t i = candidate.index();
auto candCompCost = ACC.getCompCostDelta(i);
auto candAccCost = ACC.getAccCostDelta(i);

// Both costs accumulate additively on top of the current entry.
InstructionCost newCompCost = currCompCost + candCompCost;
double newAccCost = currAccCost + candAccCost;

llvm::errs() << "ACC candidate " << i << " (" << candidate.value().desc
<< ") has accuracy cost: " << candAccCost
<< " and computation cost: " << candCompCost << "\n";

// Keep only the smallest accuracy cost seen for each computation cost.
// The stored solution is a copy of the current entry's steps with this
// (ACC, candidate index) step appended.
if (newCostToAccuracyMap.find(newCompCost) ==
newCostToAccuracyMap.end() ||
newCostToAccuracyMap[newCompCost] > newAccCost) {
newCostToAccuracyMap[newCompCost] = newAccCost;
newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost];
newCostToSolutionMap[newCompCost].emplace_back(&ACC, i);
llvm::errs() << "Updating accuracy map (ACC candidate " << i
<< "): computation cost " << newCompCost
<< " -> accuracy cost " << newAccCost << "\n";
}
}
}

// Pruning: drop every entry for which some other entry has a strictly
// smaller computation cost and an accuracy cost that is not larger.
CostMap prunedCostToAccuracyMap;
SolutionMap prunedCostToSolutionMap;

for (const auto &l : newCostToAccuracyMap) {
InstructionCost currCompCost = l.first;
double currAccCost = l.second;

bool dominated = false;
// Pairwise scan over the whole new table; stops at the first dominating
// entry.
for (const auto &r : newCostToAccuracyMap) {
InstructionCost otherCompCost = r.first;
double otherAccCost = r.second;

if (currCompCost > otherCompCost && currAccCost >= otherAccCost) {
llvm::errs() << "ACC candidate with computation cost: "
<< currCompCost << " and accuracy cost: " << currAccCost
<< " is dominated by candidate with computation cost: "
<< otherCompCost
<< " and accuracy cost: " << otherAccCost << "\n";
dominated = true;
break;
}
}

if (!dominated) {
prunedCostToAccuracyMap[currCompCost] = currAccCost;
prunedCostToSolutionMap[currCompCost] =
newCostToSolutionMap[currCompCost];
}
}

// The pruned tables become the running tables for the next ACC.
costToAccuracyMap.swap(prunedCostToAccuracyMap);
costToSolutionMap.swap(prunedCostToSolutionMap);
}

llvm::errs() << "DP Table: \n";
for (const auto &pair : costToAccuracyMap) {
llvm::errs() << "Computation cost: " << pair.first
Expand Down Expand Up @@ -2780,12 +2865,25 @@ bool accuracyDPSolver(
assert(costToSolutionMap.find(bestCompCost) != costToSolutionMap.end() &&
"FPOpt DP solver: expected a solution!");

llvm::errs() << "\n!!! DP solver: Applying solution ... !!!\n";
for (const auto &solution : costToSolutionMap[bestCompCost]) {
auto *AO = solution.first;
size_t i = solution.second;
AO->apply(i, valueToNodeMap, symbolToValueMap);
std::visit(
[&](auto *item) {
using T = std::decay_t<decltype(*item)>;
if constexpr (std::is_same_v<T, ApplicableOutput>) {
item->apply(solution.candidateIndex, valueToNodeMap,
symbolToValueMap);
} else if constexpr (std::is_same_v<T, ApplicableFPCC>) {
item->apply(solution.candidateIndex);
} else {
llvm_unreachable(
"accuracyDPSolver: Unexpected type of solution step");
}
},
solution.item);
changed = true;
}
llvm::errs() << "!!! DP Solver: Solution applied !!!\n\n";

return changed;
}
Expand Down

0 comments on commit 834d7c3

Please sign in to comment.