diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp index a41adb0515e..e38bc208f08 100644 --- a/enzyme/Enzyme/Herbie.cpp +++ b/enzyme/Enzyme/Herbie.cpp @@ -30,6 +30,8 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include + #include #include #include @@ -39,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -107,9 +110,15 @@ static cl::opt FPOptSolverType("fpopt-solver-type", cl::init("dp"), static cl::opt FPOptComputationCostBudget( "fpopt-comp-cost-budget", cl::init(100000000000L), cl::Hidden, cl::desc("The maximum computation cost budget for the solver")); -static cl::opt FPOptMaxFPCCDepth( +static cl::opt FPOptMaxFPCCDepth( "fpopt-max-fpcc-depth", cl::init(10), cl::Hidden, cl::desc("The maximum depth of a floating-point connected component")); +static cl::opt + FPOptRandomSeed("fpopt-random-seed", cl::init(239778888), cl::Hidden, + cl::desc("The random seed used in the FPOpt pass")); +static cl::opt + FPOptNumSamples("fpopt-num-samples", cl::init(10), cl::Hidden, + cl::desc("Number of sampled points for input hypercube")); } class FPNode { @@ -393,18 +402,17 @@ class FPNode { return val; } - - virtual bool isExpression() const { return true; } }; // Represents a true LLVM Value class FPLLValue : public FPNode { - Value *value; double lb = std::numeric_limits::infinity(); double ub = -std::numeric_limits::infinity(); bool input = false; // Whether `llvm::Value` is an input of an FPCC public: + Value *value; + explicit FPLLValue(Value *value, const std::string &op, const std::string &dtype) : FPNode(NodeType::LLValue, op, dtype), value(value) {} @@ -553,6 +561,133 @@ class FPConst : public FPNode { } }; +// Compute the expression with MPFR at `prec` precision +// recursively. When operand is a FPConst, use its lower +// bound. When operand is a FPLLValue, get its inputs from +// `inputs`. +void goldValueHelper(FPNode *node, + const SmallMapVector &inputValues, + const unsigned prec, mpfr_t &res) { + mpfr_set_prec(res, prec); + + if (auto *constNode = dyn_cast(node)) { + double constVal = constNode->getLowerBound(); // TODO: Can be improved + mpfr_set_d(res, constVal, MPFR_RNDN); + } else if (auto *valueNode = dyn_cast(node)) { + double inputValue = inputValues.lookup(valueNode->value); + mpfr_set_d(res, inputValue, MPFR_RNDN); + } else { + if (node->op == "neg") { + mpfr_t operandResult; + mpfr_init(operandResult); + goldValueHelper(node->operands[0], inputValues, prec, operandResult); + mpfr_neg(res, operandResult, MPFR_RNDN); + mpfr_clear(operandResult); + } else if (node->op == "+") { + mpfr_t operandResults[2]; + for (int i = 0; i < 2; i++) { + mpfr_init(operandResults[i]); + goldValueHelper(node->operands[i], inputValues, prec, + operandResults[i]); + } + mpfr_add(res, operandResults[0], operandResults[1], MPFR_RNDN); + for (int i = 0; i < 2; i++) { + mpfr_clear(operandResults[i]); + } + } else if (node->op == "-") { + mpfr_t operandResults[2]; + for (int i = 0; i < 2; i++) { + mpfr_init(operandResults[i]); + goldValueHelper(node->operands[i], inputValues, prec, + operandResults[i]); + } + mpfr_sub(res, operandResults[0], operandResults[1], MPFR_RNDN); + for (int i = 0; i < 2; i++) { + mpfr_clear(operandResults[i]); + } + } else if (node->op == "*") { + mpfr_t operandResults[2]; + for (int i = 0; i < 2; i++) { + mpfr_init(operandResults[i]); + goldValueHelper(node->operands[i], inputValues, prec, + operandResults[i]); + } + mpfr_mul(res, operandResults[0], operandResults[1], MPFR_RNDN); + for (int i = 0; i < 2; i++) { + mpfr_clear(operandResults[i]); + } + } else if (node->op == "/") { + mpfr_t operandResults[2]; + for (int i = 0; i < 2; i++) { + mpfr_init(operandResults[i]); + goldValueHelper(node->operands[i], inputValues, prec, + operandResults[i]); + } + mpfr_div(res, operandResults[0], operandResults[1], MPFR_RNDN); + for (int i = 0; i < 2; i++) { + mpfr_clear(operandResults[i]); + } + } else if (node->op == "sin") { + mpfr_t operandResult; + mpfr_init(operandResult); + goldValueHelper(node->operands[0], inputValues, prec, operandResult); + mpfr_sin(res, operandResult, MPFR_RNDN); + } else if (node->op == "cos") { + mpfr_t operandResult; + mpfr_init(operandResult); + goldValueHelper(node->operands[0], inputValues, prec, operandResult); + mpfr_cos(res, operandResult, MPFR_RNDN); + } else { + std::string msg = "goldValueHelper: Unexpected operator " + node->op; + llvm_unreachable(msg.c_str()); + } + // TODO + } +} + +// If looking for ground truth, compute a "correct" answer with MPFR. +// For each sampled input configuration: +// 0. Ignore `FPNode.dtype`. +// 1. Compute the expression with MPFR at `prec` precision +// by calling `goldValueHelper`. When operand is a FPConst, use its +// lower bound. When operand is a FPLLValue, get its inputs from +// `inputs`. +// 2. Dynamically extend precisions +// until the first `groundTruthPrec` bits of significand don't change. +double getGoldValue(FPNode *output, + const SmallMapVector &inputValues, + const unsigned groundTruthPrec = 53) { + assert(output); + + unsigned currentPrec = 64; + + mpfr_t res, prevRes; + mpfr_init2(res, currentPrec); + mpfr_init2(prevRes, currentPrec); + mpfr_set_zero(prevRes, 1); + + bool timeout = false; + + while (!timeout) { + goldValueHelper(output, inputValues, currentPrec, res); + + // TODO: Not really eq. Need to check the first `groundTruthPrec` bits + if (mpfr_eq(res, prevRes, groundTruthPrec)) { + break; + } else { + mpfr_set(prevRes, res, MPFR_RNDN); + currentPrec += 16; + mpfr_set_prec(res, currentPrec); + mpfr_set_prec(prevRes, currentPrec); + } + + // TODO: Add a timeout mechanism + } + + mpfr_clears(res, prevRes, (mpfr_ptr)0); + return mpfr_get_d(res, MPFR_RNDN); +} + FPNode * parseHerbieExpr(const std::string &expr, std::unordered_map &valueToNodeMap, @@ -1210,8 +1345,9 @@ bool improveViaHerbie( input.close(); std::string Program = HERBIE_BINARY; - SmallVector Args = {Program, "report", "--seed", - "239778888", "--timeout", "60"}; + SmallVector Args = { + Program, "report", "--seed", std::to_string(FPOptRandomSeed), + "--timeout", "60"}; Args.push_back("--disable"); Args.push_back("generate:proofs"); // We can't show HTML reports @@ -1534,6 +1670,59 @@ std::string getPrecondition( return preconditions.empty() ? "TRUE" : "(and" + preconditions + ")"; } +bool getSampledPoints( + const SmallSet &args, + const std::unordered_map &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints) { + std::mt19937 gen(FPOptRandomSeed); + std::uniform_real_distribution<> dis; + + // Create a hypercube of input operands + SmallMapVector, 4> hypercube; + for (const auto &arg : args) { + const auto *node = valueToNodeMap.at(symbolToValueMap.at(arg)); + Value *val = symbolToValueMap.at(arg); + + double lower = node->getLowerBound(); + double upper = node->getUpperBound(); + + hypercube.insert({val, {lower, upper}}); + } + + llvm::errs() << "Hypercube:\n"; + for (const auto &entry : hypercube) { + Value *val = entry.first; + double lower = entry.second[0]; + double upper = entry.second[1]; + llvm::errs() << valueToNodeMap.at(val)->symbol << ": [" << lower << ", " + << upper << "]\n"; + } + + // Sample `FPOptNumSamples` points from the hypercube. Store it in + // `sampledPoints`. + sampledPoints.clear(); + sampledPoints.resize(FPOptNumSamples); + for (int i = 0; i < FPOptNumSamples; ++i) { + SmallMapVector point; + for (const auto &entry : hypercube) { + Value *val = entry.first; + double lower = entry.second[0]; + double upper = entry.second[1]; + double sample = dis(gen, decltype(dis)::param_type{lower, upper}); + point.insert({val, sample}); + } + sampledPoints[i] = point; + llvm::errs() << "Sample " << i << ":\n"; + for (const auto &entry : point) { + llvm::errs() << valueToNodeMap.at(entry.first)->symbol << ": " + << entry.second << "\n"; + } + } + + return true; +} + // Given the cost budget `FPOptComputationCostBudget`, we want to minimize the // accuracy cost of the rewritten expressions. bool accuracyGreedySolver( @@ -2118,6 +2307,9 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) { if (!FPOptLogPath.empty()) { std::string precondition = getPrecondition(args, valueToNodeMap, symbolToValueMap); + SmallVector, 4> sampledPoints; + getSampledPoints(args, valueToNodeMap, symbolToValueMap, + sampledPoints); properties += " :pre " + precondition; }