Skip to content

Commit

Permalink
WIP unified accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Sep 9, 2024
1 parent 64afdbb commit e5d2c75
Showing 1 changed file with 198 additions and 6 deletions.
204 changes: 198 additions & 6 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include <mpfr.h>

#include <cerrno>
#include <cmath>
#include <cstring>
Expand All @@ -39,6 +41,7 @@
#include <limits>
#include <map>
#include <numeric>
#include <random>
#include <regex>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -107,9 +110,15 @@ static cl::opt<std::string> FPOptSolverType("fpopt-solver-type", cl::init("dp"),
// Hidden command-line option `-fpopt-comp-cost-budget`: upper limit on the
// computation cost the solver is allowed to spend. Defaults to 1e11.
static cl::opt<int64_t> FPOptComputationCostBudget(
"fpopt-comp-cost-budget", cl::init(100000000000L), cl::Hidden,
cl::desc("The maximum computation cost budget for the solver"));
static cl::opt<int> FPOptMaxFPCCDepth(
static cl::opt<unsigned> FPOptMaxFPCCDepth(
"fpopt-max-fpcc-depth", cl::init(10), cl::Hidden,
cl::desc("The maximum depth of a floating-point connected component"));
// Hidden command-line option `-fpopt-random-seed`: the single seed shared by
// the randomized parts of the pass. It is forwarded to Herbie via `--seed`
// and seeds the `std::mt19937` generator used to sample input points, so a
// fixed seed keeps both reproducible.
static cl::opt<unsigned>
FPOptRandomSeed("fpopt-random-seed", cl::init(239778888), cl::Hidden,
cl::desc("The random seed used in the FPOpt pass"));
// Hidden command-line option `-fpopt-num-samples`: how many points
// `getSampledPoints` draws from the hypercube spanned by the input bounds.
// Defaults to 10.
static cl::opt<unsigned>
FPOptNumSamples("fpopt-num-samples", cl::init(10), cl::Hidden,
cl::desc("Number of sampled points for input hypercube"));
}

class FPNode {
Expand Down Expand Up @@ -393,18 +402,17 @@ class FPNode {

return val;
}

// Default answer for a generic node: it is an expression. Declared virtual
// so that derived node classes can report otherwise.
virtual bool isExpression() const { return true; }
};

// Represents a true LLVM Value
class FPLLValue : public FPNode {
Value *value;
double lb = std::numeric_limits<double>::infinity();
double ub = -std::numeric_limits<double>::infinity();
bool input = false; // Whether `llvm::Value` is an input of an FPCC

public:
Value *value;

explicit FPLLValue(Value *value, const std::string &op,
const std::string &dtype)
: FPNode(NodeType::LLValue, op, dtype), value(value) {}
Expand Down Expand Up @@ -553,6 +561,133 @@ class FPConst : public FPNode {
}
};

// Recursively evaluate the expression rooted at `node` with MPFR, using `prec`
// bits of precision and round-to-nearest for every operation.
//
// Parameters:
//   node        - root of the (sub)expression to evaluate.
//   inputValues - concrete value for each `llvm::Value` leaf. A value missing
//                 from the map evaluates as 0.0 (the `lookup` default).
//   prec        - MPFR precision, in bits, used for `res` and all temporaries.
//   res         - already-initialized MPFR variable receiving the result; its
//                 precision is reset to `prec`, discarding its previous value.
//
// Leaves:
//   * FPConst   -> its lower bound.
//   * FPLLValue -> the entry of `inputValues` for the wrapped `llvm::Value`.
// Supported operators: neg, +, -, *, /, sin, cos. Any other operator is
// reported through `llvm_unreachable`.
void goldValueHelper(FPNode *node,
                     const SmallMapVector<Value *, double, 4> &inputValues,
                     const unsigned prec, mpfr_t &res) {
  mpfr_set_prec(res, prec);

  if (auto *constNode = dyn_cast<FPConst>(node)) {
    double constVal = constNode->getLowerBound(); // TODO: Can be improved
    mpfr_set_d(res, constVal, MPFR_RNDN);
    return;
  }

  if (auto *valueNode = dyn_cast<FPLLValue>(node)) {
    double inputValue = inputValues.lookup(valueNode->value);
    mpfr_set_d(res, inputValue, MPFR_RNDN);
    return;
  }

  const auto &op = node->op;
  const bool isUnary = op == "neg" || op == "sin" || op == "cos";
  const bool isBinary = op == "+" || op == "-" || op == "*" || op == "/";

  // Reject unknown operators before touching any operand, so nothing is
  // allocated on the error path.
  if (!isUnary && !isBinary) {
    std::string msg = "goldValueHelper: Unexpected operator " + node->op;
    llvm_unreachable(msg.c_str());
  }

  // Evaluate the operands left to right into temporaries owned by this frame.
  const unsigned numOperands = isUnary ? 1 : 2;
  mpfr_t operandResults[2];
  for (unsigned i = 0; i < numOperands; ++i) {
    mpfr_init(operandResults[i]);
    goldValueHelper(node->operands[i], inputValues, prec, operandResults[i]);
  }

  if (op == "neg") {
    mpfr_neg(res, operandResults[0], MPFR_RNDN);
  } else if (op == "sin") {
    mpfr_sin(res, operandResults[0], MPFR_RNDN);
  } else if (op == "cos") {
    mpfr_cos(res, operandResults[0], MPFR_RNDN);
  } else if (op == "+") {
    mpfr_add(res, operandResults[0], operandResults[1], MPFR_RNDN);
  } else if (op == "-") {
    mpfr_sub(res, operandResults[0], operandResults[1], MPFR_RNDN);
  } else if (op == "*") {
    mpfr_mul(res, operandResults[0], operandResults[1], MPFR_RNDN);
  } else { // op == "/"
    mpfr_div(res, operandResults[0], operandResults[1], MPFR_RNDN);
  }

  // Every initialized temporary is released on every path; previously the
  // `sin` and `cos` branches leaked theirs.
  for (unsigned i = 0; i < numOperands; ++i) {
    mpfr_clear(operandResults[i]);
  }
  // TODO: support the remaining operators.
}

// Compute a "ground truth" value for the expression rooted at `output`,
// evaluated at the input point `inputValues`, using MPFR.
//
//   0. `FPNode.dtype` is ignored.
//   1. The expression is evaluated by `goldValueHelper` at the current working
//      precision (starting at 64 bits). FPConst operands use their lower
//      bound; FPLLValue operands are read from `inputValues`.
//   2. The working precision is extended by 16 bits per round until two
//      consecutive evaluations agree according to
//      `mpfr_eq(..., groundTruthPrec)`.
//
// A NaN result is returned immediately: NaN never compares equal to anything,
// so it could never satisfy the convergence test.
//
// Returns the converged result rounded to the nearest double.
double getGoldValue(FPNode *output,
                    const SmallMapVector<Value *, double, 4> &inputValues,
                    const unsigned groundTruthPrec = 53) {
  assert(output);

  unsigned currentPrec = 64;

  mpfr_t res, prevRes;
  mpfr_init2(res, currentPrec);
  mpfr_init2(prevRes, currentPrec);

  // Only compare against `prevRes` once it holds a real evaluation.
  bool havePrev = false;

  while (true) {
    // `goldValueHelper` resets the precision of `res` to `currentPrec`.
    goldValueHelper(output, inputValues, currentPrec, res);

    if (mpfr_nan_p(res)) {
      break;
    }

    // TODO: Not really eq. Need to check the first `groundTruthPrec` bits
    if (havePrev && mpfr_eq(res, prevRes, groundTruthPrec)) {
      break;
    }

    // `mpfr_set_prec` resets the value to NaN, so the precision has to be
    // adjusted *before* the value is copied. Doing it afterwards wiped the
    // saved result and made the convergence test unsatisfiable.
    mpfr_set_prec(prevRes, currentPrec);
    mpfr_set(prevRes, res, MPFR_RNDN);
    havePrev = true;
    currentPrec += 16;

    // TODO: Add a timeout mechanism
  }

  // Read the result out before releasing the MPFR variables.
  const double goldValue = mpfr_get_d(res, MPFR_RNDN);
  mpfr_clears(res, prevRes, (mpfr_ptr)0);
  return goldValue;
}

FPNode *
parseHerbieExpr(const std::string &expr,
std::unordered_map<Value *, FPNode *> &valueToNodeMap,
Expand Down Expand Up @@ -1210,8 +1345,9 @@ bool improveViaHerbie(
input.close();

std::string Program = HERBIE_BINARY;
SmallVector<llvm::StringRef> Args = {Program, "report", "--seed",
"239778888", "--timeout", "60"};
SmallVector<llvm::StringRef> Args = {
Program, "report", "--seed", std::to_string(FPOptRandomSeed),
"--timeout", "60"};

Args.push_back("--disable");
Args.push_back("generate:proofs"); // We can't show HTML reports
Expand Down Expand Up @@ -1534,6 +1670,59 @@ std::string getPrecondition(
return preconditions.empty() ? "TRUE" : "(and" + preconditions + ")";
}

bool getSampledPoints(
const SmallSet<std::string, 8> &args,
const std::unordered_map<Value *, FPNode *> &valueToNodeMap,
const std::unordered_map<std::string, Value *> &symbolToValueMap,
SmallVector<SmallMapVector<Value *, double, 4>, 4> &sampledPoints) {
std::mt19937 gen(FPOptRandomSeed);
std::uniform_real_distribution<> dis;

// Create a hypercube of input operands
SmallMapVector<Value *, SmallVector<double, 2>, 4> hypercube;
for (const auto &arg : args) {
const auto *node = valueToNodeMap.at(symbolToValueMap.at(arg));
Value *val = symbolToValueMap.at(arg);

double lower = node->getLowerBound();
double upper = node->getUpperBound();

hypercube.insert({val, {lower, upper}});
}

llvm::errs() << "Hypercube:\n";
for (const auto &entry : hypercube) {
Value *val = entry.first;
double lower = entry.second[0];
double upper = entry.second[1];
llvm::errs() << valueToNodeMap.at(val)->symbol << ": [" << lower << ", "
<< upper << "]\n";
}

// Sample `FPOptNumSamples` points from the hypercube. Store it in
// `sampledPoints`.
sampledPoints.clear();
sampledPoints.resize(FPOptNumSamples);
for (int i = 0; i < FPOptNumSamples; ++i) {
SmallMapVector<Value *, double, 4> point;
for (const auto &entry : hypercube) {
Value *val = entry.first;
double lower = entry.second[0];
double upper = entry.second[1];
double sample = dis(gen, decltype(dis)::param_type{lower, upper});
point.insert({val, sample});
}
sampledPoints[i] = point;
llvm::errs() << "Sample " << i << ":\n";
for (const auto &entry : point) {
llvm::errs() << valueToNodeMap.at(entry.first)->symbol << ": "
<< entry.second << "\n";
}
}

return true;
}

// Given the cost budget `FPOptComputationCostBudget`, we want to minimize the
// accuracy cost of the rewritten expressions.
bool accuracyGreedySolver(
Expand Down Expand Up @@ -2118,6 +2307,9 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) {
if (!FPOptLogPath.empty()) {
std::string precondition =
getPrecondition(args, valueToNodeMap, symbolToValueMap);
SmallVector<SmallMapVector<Value *, double, 4>, 4> sampledPoints;
getSampledPoints(args, valueToNodeMap, symbolToValueMap,
sampledPoints);
properties += " :pre " + precondition;
}

Expand Down

0 comments on commit e5d2c75

Please sign in to comment.