Skip to content

Commit

Permalink
custom cost model parsing & disable FP16 for now
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Sep 29, 2024
1 parent 498ce49 commit d97de6c
Showing 1 changed file with 97 additions and 28 deletions.
125 changes: 97 additions & 28 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ void changePrecision(Instruction *I, PrecisionChange &change,
Value *newI = nullptr;

if (isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) {
llvm::errs() << "PT Changing: " << *I << " to " << *newType << "\n";
// llvm::errs() << "PT Changing: " << *I << " to " << *newType << "\n";
SmallVector<Value *, 2> newOps;
for (auto &operand : I->operands()) {
Value *newOp = nullptr;
Expand Down Expand Up @@ -741,7 +741,7 @@ void changePrecision(Instruction *I, PrecisionChange &change,
}

oldToNew[I] = newI;
llvm::errs() << "PT Changing: " << *I << " to " << *newI << "\n";
// llvm::errs() << "PT Changing: " << *I << " to " << *newI << "\n";
}

struct PTCandidate {
Expand Down Expand Up @@ -1492,7 +1492,7 @@ getOperandValueProperties(const Value *V) {
}

InstructionCost getInstructionCompCost(const Instruction *I,
const TargetTransformInfo &TTI) {
const TargetTransformInfo &TTI) {
if (!FPOptCostModelPath.empty()) {
static std::map<std::pair<std::string, std::string>, InstructionCost>
CostModel;
Expand Down Expand Up @@ -1524,6 +1524,8 @@ InstructionCost getInstructionCompCost(const Instruction *I,
std::string msg = "Unexpected line in custom cost model: " + Line;
llvm_unreachable(msg.c_str());
}
// llvm::errs() << "Cost model: " << OpcodeStr << ", " << PrecisionStr
// << ", " << CostStr << "\n";

CostModel[{OpcodeStr, PrecisionStr}] = std::stoi(CostStr);
}
Expand Down Expand Up @@ -1559,9 +1561,75 @@ InstructionCost getInstructionCompCost(const Instruction *I,
break;
case Instruction::PHI:
return 0;
case Instruction::Call:
// TODO: complete
case Instruction::Call: {
auto *Call = cast<CallInst>(I);
if (auto CalledFunc = Call->getCalledFunction()) {
if (CalledFunc->isIntrinsic()) {
switch (CalledFunc->getIntrinsicID()) {
case Intrinsic::sin:
OpcodeName = "sin";
break;
case Intrinsic::cos:
OpcodeName = "cos";
break;
case Intrinsic::exp:
OpcodeName = "exp";
break;
case Intrinsic::log:
OpcodeName = "log";
break;
case Intrinsic::sqrt:
OpcodeName = "sqrt";
break;
case Intrinsic::fabs:
OpcodeName = "fabs";
break;
case Intrinsic::fmuladd:
OpcodeName = "fmuladd";
break;
default: {
std::string msg = "Custom cost model: unsupported intrinsic " +
CalledFunc->getName().str();
llvm_unreachable(msg.c_str());
}
}
} else {
std::string FuncName = CalledFunc->getName().str();
if (FuncName == "sin") {
OpcodeName = "sin";
} else if (FuncName == "cos") {
OpcodeName = "cos";
} else if (FuncName == "tan") {
OpcodeName = "tan";
} else if (FuncName == "exp") {
OpcodeName = "exp";
} else if (FuncName == "log") {
OpcodeName = "log";
} else if (FuncName == "sqrt") {
OpcodeName = "sqrt";
} else if (FuncName == "expm1") {
OpcodeName = "expm1";
} else if (FuncName == "log1p") {
OpcodeName = "log1p";
} else if (FuncName == "cbrt") {
OpcodeName = "cbrt";
} else if (FuncName == "pow") {
OpcodeName = "pow";
} else if (FuncName == "fabs") {
OpcodeName = "fabs";
} else if (FuncName == "hypot") {
OpcodeName = "hypot";
} else {
std::string msg =
"Custom cost model: unknown function call " + FuncName;
llvm_unreachable(msg.c_str());
}
}
} else {
llvm_unreachable("Custom cost model: unknown function call");
}
break;
}
default:
std::string msg = "Custom cost model: unexpected opcode " +
std::string(I->getOpcodeName());
Expand All @@ -1570,6 +1638,10 @@ InstructionCost getInstructionCompCost(const Instruction *I,

std::string PrecisionName;
Type *Ty = I->getType();
if (I->getOpcode() == Instruction::FPExt ||
I->getOpcode() == Instruction::FPTrunc) {
Ty = I->getOperand(0)->getType();
}
if (Ty->isDoubleTy()) {
PrecisionName = "double";
} else if (Ty->isFloatTy()) {
Expand All @@ -1589,6 +1661,7 @@ InstructionCost getInstructionCompCost(const Instruction *I,

std::string msg = "Custom cost model: entry not found for " + OpcodeName +
" @ " + PrecisionName;
llvm::errs() << "Unexpected Intruction: " << *I << "\n";
llvm_unreachable(msg.c_str());
}

Expand Down Expand Up @@ -1742,7 +1815,6 @@ InstructionCost getCompCost(const FPCC &component,
ValueToValueMapTy VMap;
Function *FClone = CloneFunction(F, VMap);
FClone->setName(F->getName() + "_clone");
FClone->print(llvm::errs());

pt.apply(component, &VMap);
// output values in VMap are changed to the newly cast values
Expand Down Expand Up @@ -1771,7 +1843,7 @@ InstructionCost getCompCost(const FPCC &component,

if (auto *I = dyn_cast<Instruction>(cur)) {
auto instCost = getInstructionCompCost(I, TTI);
llvm::errs() << "Cost of " << *I << " is: " << instCost << "\n";
// llvm::errs() << "Cost of " << *I << " is: " << instCost << "\n";

cost += instCost;

Expand Down Expand Up @@ -2111,22 +2183,23 @@ void setUnifiedAccuracyCost(
// TODO: Consider using a geometric average?
assert(valueToNodeMap.count(AO.oldOutput));

llvm::errs() << "Computing real output for candidate: " << expr << "\n";
// llvm::errs() << "Computing real output for candidate: " << expr <<
// "\n";

llvm::errs() << "Current input values:\n";
for (const auto &entry : pair.value()) {
llvm::errs() << valueToNodeMap[entry.first]->symbol << ": "
<< entry.second << "\n";
}
// llvm::errs() << "Current input values:\n";
// for (const auto &entry : pair.value()) {
// llvm::errs() << valueToNodeMap[entry.first]->symbol << ": "
// << entry.second << "\n";
// }

llvm::errs() << "Gold value: " << goldVals[pair.index()] << "\n";
// llvm::errs() << "Gold value: " << goldVals[pair.index()] << "\n";

ArrayRef<FPNode *> outputs = {parsedNode.get()};
SmallVector<double, 1> results;
getMPFRValues(outputs, pair.value(), results, false);
double realVal = results[0];

llvm::errs() << "Real value: " << realVal << "\n";
// llvm::errs() << "Real value: " << realVal << "\n";
ac += std::fabs((goldVals[pair.index()] - realVal) * AO.grad);
}
candidate.accuracyCost = ac;
Expand Down Expand Up @@ -2176,16 +2249,12 @@ void setUnifiedAccuracyCost(
getMPFRValues(outputs, pair.value(), results, false, 0, &candidate);

for (const auto &[output, result] : zip(outputs, results)) {
// llvm::errs() << "DEBUG gold value: " <<
// goldVals[output][pair.index()]
// << "\n";
// llvm::errs() << "DEBUG real value: " << result << "\n";
ac +=
std::fabs((goldVals[output][pair.index()] - result) * output->grad);
}
}
candidate.accuracyCost = ac;
llvm::errs() << "Accuracy cost for PT candidate: " << ac << "\n";
// llvm::errs() << "Accuracy cost for PT candidate: " << ac << "\n";
}
}

Expand Down Expand Up @@ -3190,10 +3259,10 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) {

// Sort the operations by gradient magnitude (|grad|), in ascending order
llvm::sort(operations, [](const auto &a, const auto &b) {
llvm::errs() << "Gradient of " << *(a->value) << " is " << a->grad
<< "\n";
llvm::errs() << "Gradient of " << *(b->value) << " is " << b->grad
<< "\n";
// llvm::errs() << "Gradient of " << *(a->value) << " is " << a->grad
// << "\n";
// llvm::errs() << "Gradient of " << *(b->value) << " is " << b->grad
// << "\n";
assert(!std::isnan(a->grad) && "Gradient is NaN for an operation");
assert(!std::isnan(b->grad) && "Gradient is NaN for an operation");
return std::fabs(a->grad) < std::fabs(b->grad);
Expand All @@ -3206,17 +3275,17 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) {
SetVector<FPLLValue *> opsToChange(operations.begin(),
operations.begin() + numToChange);

if (!opsToChange.empty()) {
if (EnzymePrintFPOpt && !opsToChange.empty()) {
llvm::errs() << "Created PrecisionChange for " << percent
<< "% of operations (" << numToChange << ")\n";
llvm::errs() << "Subset gradient range: ["
<< std::fabs(opsToChange.front()->grad) << ", "
<< std::fabs(opsToChange.back()->grad) << "]\n";
}

SmallVector<PrecisionChangeType> precTypes{PrecisionChangeType::FP16,
PrecisionChangeType::FP32,
PrecisionChangeType::FP64};
SmallVector<PrecisionChangeType> precTypes{
/*PrecisionChangeType::FP16,*/
PrecisionChangeType::FP32, PrecisionChangeType::FP64};

for (auto prec : precTypes) {
StringRef precStr = getPrecisionChangeTypeString(prec);
Expand Down

0 comments on commit d97de6c

Please sign in to comment.