From 3cf4007fea1aa61f0e7c42f1f2234b92cb9f43cd Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Sun, 18 Feb 2024 04:17:10 +0900 Subject: [PATCH 01/27] WIP MPFR truncation --- enzyme/Enzyme/Enzyme.cpp | 4 +- enzyme/Enzyme/EnzymeLogic.cpp | 195 +++++++++++++--------------------- enzyme/Enzyme/EnzymeLogic.h | 63 ++++++++++- 3 files changed, 137 insertions(+), 125 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 70f173f5734a..12fd64d91f1d 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -23,6 +23,7 @@ // the function passed as the first argument. // //===----------------------------------------------------------------------===// +#include "llvm/Support/ErrorHandling.h" #include #if LLVM_VERSION_MAJOR >= 16 @@ -2058,8 +2059,7 @@ class EnzymeBase { StringRef ConfigStr(EnzymeTruncateAll); auto Invalid = [=]() { // TODO emit better diagnostic - llvm::errs() << "error: invalid format for truncation config\n"; - abort(); + llvm::report_fatal_error("error: invalid format for truncation config") }; // "64" or "11-52" diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index f930d4e1375b..0c6c8f1e3903 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -33,6 +33,8 @@ #include "EnzymeLogic.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/Support/ErrorHandling.h" +#include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -5009,10 +5011,8 @@ static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, class TruncateGenerator : public llvm::InstVisitor { private: ValueToValueMapTy &originalToNewFn; - FloatRepresentation from; - FloatRepresentation to; + FloatTruncation truncation; Type *fromType; - Type *toType; Function *oldFunc; Function *newFunc; AllocaInst *tmpBlock; @@ -5021,20 +5021,29 @@ class TruncateGenerator : public llvm::InstVisitor { public: TruncateGenerator(ValueToValueMapTy &originalToNewFn, - FloatRepresentation from, FloatRepresentation to, - Function *oldFunc, Function *newFunc, TruncateMode mode, - EnzymeLogic &Logic) - : originalToNewFn(originalToNewFn), from(from), to(to), oldFunc(oldFunc), - newFunc(newFunc), mode(mode), Logic(Logic) { + FloatTruncation truncation, Function *oldFunc, + Function *newFunc, TruncateMode mode, EnzymeLogic &Logic) + : originalToNewFn(originalToNewFn), truncation(truncation), + oldFunc(oldFunc), newFunc(newFunc), mode(mode), Logic(Logic) { IRBuilder<> B(&newFunc->getEntryBlock().front()); - fromType = from.getBuiltinType(B.getContext()); - toType = to.getType(B.getContext()); + fromType = truncation.getFromType(B.getContext()); if (mode == TruncMemMode) tmpBlock = B.CreateAlloca(fromType); else tmpBlock = nullptr; + + if (truncation.isToMPFR()) { + switch (mode) { + case TruncMemMode: + llvm::report_fatal_error( + "truncation to MPFR not supported in memory mode."); + case TruncOpMode: + case TruncOpFullModuleMode: + break; + } + } } void checkHandled(llvm::Instruction &inst) { @@ -5069,9 +5078,8 @@ class TruncateGenerator : public llvm::InstVisitor { case TruncOpMode: case TruncOpFullModuleMode: return floatValTruncate(B, v, tmpBlock, from, to); - default: - llvm_unreachable("Unknown trunc mode"); } + llvm_unreachable("Unknown trunc mode"); } Value *expand(IRBuilder<> &B, Value *v) { @@ -5081,9 +5089,8 @@ class TruncateGenerator : public llvm::InstVisitor { case TruncOpMode: case TruncOpFullModuleMode: return floatValExpand(B, v, tmpBlock, from, to); - default: - llvm_unreachable("Unknown trunc mode"); } + llvm_unreachable("Unknown trunc mode"); } void todo(llvm::Instruction &I) { @@ -5129,26 +5136,35 @@ class TruncateGenerator : public llvm::InstVisitor { void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { - Value *newCI = nullptr; - auto newI = getNewFromOriginal(&CI); - std::string oldName = CI.getName().str(); - newI->setName(""); - if (CI.getSrcTy() == getFromType()) { - IRBuilder<> B(newI); - newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), - CI.getDestTy(), oldName); - } - if (CI.getDestTy() == getToType()) { + switch (mode) { + case TruncMemMode: { + Value *newCI = nullptr; auto newI = getNewFromOriginal(&CI); - IRBuilder<> B(newI); - newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), - CI.getDestTy(), oldName); + std::string oldName = CI.getName().str(); + newI->setName(""); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + newCI = + B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (CI.getDestTy() == getToType()) { + auto newI = getNewFromOriginal(&CI); + IRBuilder<> B(newI); + newCI = + B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (newCI) { + newI->replaceAllUsesWith(newCI); + newI->eraseFromParent(); + } + return; } - if (newCI) { - newI->replaceAllUsesWith(newCI); - newI->eraseFromParent(); + case TruncOpMode: + case TruncOpFullModuleMode: + return; } - return; } void visitSelectInst(llvm::SelectInst &SI) { switch (mode) { @@ -5168,9 +5184,8 @@ class TruncateGenerator : public llvm::InstVisitor { case TruncOpMode: case TruncOpFullModuleMode: return; - default: - llvm_unreachable(""); } + llvm_unreachable(""); } void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } @@ -5178,6 +5193,11 @@ class TruncateGenerator : public llvm::InstVisitor { void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } void visitBinaryOperator(llvm::BinaryOperator &BO) { + auto oldLHS = BO.getOperand(0); + auto oldRHS = BO.getOperand(1); + + if (oldLHS != getFromType() && oldRHS != getFromType()) + return; switch (BO.getOpcode()) { default: @@ -5198,57 +5218,25 @@ class TruncateGenerator : public llvm::InstVisitor { return; } - if (to.getBuiltinType(BO.getContext())) { + auto newI = getNewFromOriginal(&BO); + Instruction *nres = nullptr; + if (truncation.isToMPFR()) { auto newI = getNewFromOriginal(&BO); IRBuilder<> B(newI); - auto newLHS = truncate(B, getNewFromOriginal(BO.getOperand(0))); - auto newRHS = truncate(B, getNewFromOriginal(BO.getOperand(1))); - switch (BO.getOpcode()) { - default: - break; - case BinaryOperator::FMul: { - auto nres = cast(B.CreateFMul(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FAdd: { - auto nres = cast(B.CreateFAdd(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FSub: { - auto nres = cast(B.CreateFSub(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FDiv: { - auto nres = cast(B.CreateFDiv(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FRem: { - auto nres = cast(B.CreateFRem(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - } + auto newLHS = getNewFromOriginal(oldLHS); + auto newRHS = getNewFromOriginal(oldRHS); + nres = cast( + B.CreateCall(createMPFRCall(BO.getOpcode), {newLHS, newRHS})); + } else { + IRBuilder<> B(newI); + auto newLHS = truncate(B, getNewFromOriginal(oldLHS)); + auto newRHS = truncate(B, getNewFromOriginal(oldRHS)); + nres = cast(B.CreateBinOp(BO.getOpcode(), newLHS, newRHS)); } - todo(BO); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); return; } void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); } @@ -5471,13 +5459,9 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm::Function *totrunc, - FloatRepresentation from, - FloatRepresentation to, + FloatTruncation truncation, TruncateMode mode) { - if (from == to) - return totrunc; - - TruncateCacheKey tup(totrunc, from, to, mode); + TruncateCacheKey tup(totrunc, truncation, mode); if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { return TruncateCachedFunctions.find(tup)->second; } @@ -5492,10 +5476,9 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, Type *NewTy = totrunc->getReturnType(); FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); - std::string truncName = std::string("__enzyme_done_truncate_") + - (mode == TruncMemMode ? "mem" : "op") + "_func_" + - from.to_string() + "_" + to.to_string() + "_" + - totrunc->getName().str(); + std::string truncName = + std::string("__enzyme_done_truncate_") + truncateModeStr(mode) + + "_func_" + truncation.mangleString() + "_" + totrunc->getName().str(); Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, totrunc->getParent()); @@ -5530,34 +5513,6 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm_unreachable("attempting to truncate function without definition"); } - // TODO This is overloaded an doesnt do what it should do here - if (from < to) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "Cannot truncate into a large width\n"; - llvm::Value *toshow = totrunc; - if (context.req) { - toshow = context.req; - ss << " at context: " << *context.req; - } else { - ss << *totrunc << "\n"; - } - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(toshow), - ErrorType::NoDerivative, nullptr, wrap(totrunc), - wrap(context.ip)); - return NewF; - } - if (context.req) { - EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, - ss.str()); - return NewF; - } - llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; - llvm::errs() << *totrunc << "\n"; - llvm_unreachable("attempting to truncate function without definition"); - } - ValueToValueMapTy originalToNewFn; for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); @@ -5579,7 +5534,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - TruncateGenerator handle(originalToNewFn, from, to, totrunc, NewF, mode, + TruncateGenerator handle(originalToNewFn, truncation, totrunc, NewF, mode, *this); for (auto &BB : *totrunc) for (auto &I : BB) diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 1e9bf216b6e9..a25a0c778f59 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -42,6 +42,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" #include "ActivityAnalysis.h" #include "FunctionUtils.h" @@ -50,19 +51,21 @@ #include "Utils.h" extern "C" { -extern llvm::cl::opt EnzymePrint; +ext #include "llvm/Support/ErrorHandling.h" ern llvm::cl::opt EnzymePrint; extern llvm::cl::opt EnzymeJuliaAddrLoad; } enum class AugmentedStruct { Tape, Return, DifferentialReturn }; -static inline std::string str(AugmentedStruct c) { +static inline std::string str(A #include + "llvm/Support/ErrorHandling.h" ugmentedStruct c) { switch (c) { case AugmentedStruct::Tape: return "tape"; case AugmentedStruct::Return: return "return"; - case AugmentedStruct::DifferentialReturn: + case AugmentedStr #include + "llvm/Support/ErrorHandling.h" uct::DifferentialReturn: return "DifferentialReturn"; default: llvm_unreachable("unknown cache type"); @@ -287,6 +290,16 @@ getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { } enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; +static const char *truncateModeStr(TruncateMode mode) { + switch (mode) { + case TruncMemMode: + return "trunc_mem"; + case TruncOpMode: + return "trunc_op"; + case TruncOpFullModuleMode: + return "trunc_op_full_module"; + } +} struct FloatRepresentation { // |_|__________|_________________| @@ -336,6 +349,50 @@ struct FloatRepresentation { } }; +struct FloatTruncation { +private: + FloatRepresentation from, to; + +public: + FloatTruncation(FloatRepresentation From, FloatRepresentation To) + : from(From), to(To) { + if (!From.canBeBuiltin()) + llvm::report_fatal_error("Float truncation `from` type is not builtin."); + if (From.exponentWidth < To.exponentWidth) + llvm::report_fatal_error("Float truncation `from` type must have " + "a wider exponent than `to`."); + if (From.significandWidth < To.significandWidth) + llvm::report_fatal_error("Float truncation `from` type must have " + "a wider wsignificand than `to`."); + if (From == To) + llvm::report_fatal_error( + "Float truncation `from` and `to` type must not be the same."); + } + llvm::Type *getFromType(llvm::LLVMContext &ctx) { + return from.getBuiltinType(ctx); + } + bool isToMPFR() { return !to.canBeBuiltin(); } + llvm::Type *getToType(llvm::LLVMContext &ctx) { + if (to.canBeBuiltin()) { + return to.getBuiltinType(ctx); + } else { + assert(isToMPFR()); + // Currently we do not support TruncMemMode for MPFR, and we provide + // runtime wrappers around MPFR for each builtin `from` type + return from.getBuiltinType(ctx); + } + } + bool operator==(const FloatTruncation &other) const { + return from == other.from && to == other.to; + } + bool operator<(const FloatTruncation &other) const { + return std::tuple(from, to) < std::tuple(other.from, other.to); + } + std::string mangleString() const { + return from.to_string() + "to" + to.to_string(); + } +}; + class EnzymeLogic { public: PreProcessCache PPC; From 4160604556f5056799ace32d92c75cdbc4113139 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Wed, 21 Feb 2024 21:28:30 +0900 Subject: [PATCH 02/27] MPFR truncation --- enzyme/Enzyme/Enzyme.cpp | 16 ++--- enzyme/Enzyme/EnzymeLogic.cpp | 123 +++++++++++++++++++++++----------- enzyme/Enzyme/EnzymeLogic.h | 18 ++--- 3 files changed, 100 insertions(+), 57 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 12fd64d91f1d..6f5a5317f723 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1353,8 +1353,10 @@ class EnzymeBase { RequestContext context(CI, &Builder); llvm::Value *res = Logic.CreateTruncateFunc( context, F, - getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), - getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), mode); + FloatTruncation( + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue())), + mode); if (!res) return false; res = Builder.CreatePointerCast(res, CI->getType()); @@ -2053,13 +2055,12 @@ class EnzymeBase { } bool handleFullModuleTrunc(Function &F) { - typedef std::vector> - TruncationsTy; + typedef std::vector TruncationsTy; static TruncationsTy FullModuleTruncs = []() -> TruncationsTy { StringRef ConfigStr(EnzymeTruncateAll); auto Invalid = [=]() { // TODO emit better diagnostic - llvm::report_fatal_error("error: invalid format for truncation config") + llvm::report_fatal_error("error: invalid format for truncation config"); }; // "64" or "11-52" @@ -2102,9 +2103,8 @@ class EnzymeBase { for (auto Truncation : FullModuleTruncs) { IRBuilder<> Builder(F.getContext()); RequestContext context(&*F.getEntryBlock().begin(), &Builder); - Function *TruncatedFunc = - Logic.CreateTruncateFunc(context, &F, Truncation.first, - Truncation.second, TruncOpFullModuleMode); + Function *TruncatedFunc = Logic.CreateTruncateFunc( + context, &F, Truncation, TruncOpFullModuleMode); ValueToValueMapTy Mapping; for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args())) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 0c6c8f1e3903..f3b400b32f87 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -32,7 +32,10 @@ #include "AdjointGenerator.h" #include "EnzymeLogic.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/Support/ErrorHandling.h" #include @@ -4958,30 +4961,28 @@ Function *EnzymeLogic::CreateForwardDiff( } static Value *floatValTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, - FloatRepresentation to) { - Type *toTy = to.getType(B.getContext()); + FloatTruncation truncation) { + Type *toTy = truncation.getToType(B.getContext()); if (auto vty = dyn_cast(v->getType())) toTy = VectorType::get(toTy, vty->getElementCount()); return B.CreateFPTrunc(v, toTy, "enzyme_trunc"); } static Value *floatValExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, FloatRepresentation to) { - Type *fromTy = from.getBuiltinType(B.getContext()); + FloatTruncation truncation) { + Type *fromTy = truncation.getFromType(B.getContext()); if (auto vty = dyn_cast(v->getType())) fromTy = VectorType::get(fromTy, vty->getElementCount()); return B.CreateFPExt(v, fromTy, "enzyme_exp"); } static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, - FloatRepresentation to) { + FloatTruncation truncation) { if (isa(v->getType())) report_fatal_error("vector operations not allowed in mem trunc mode"); - Type *fromTy = from.getBuiltinType(B.getContext()); - Type *toTy = to.getType(B.getContext()); + Type *fromTy = truncation.getFromType(B.getContext()); + Type *toTy = truncation.getToType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); B.CreateStore( @@ -4991,15 +4992,15 @@ static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, } static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, FloatRepresentation to) { + FloatTruncation truncation) { if (isa(v->getType())) report_fatal_error("vector operations not allowed in mem trunc mode"); - Type *fromTy = from.getBuiltinType(B.getContext()); + Type *fromTy = truncation.getFromType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); auto c0 = Constant::getNullValue( - llvm::Type::getIntNTy(B.getContext(), from.getTypeWidth())); + llvm::Type::getIntNTy(B.getContext(), truncation.getFromTypeWidth())); B.CreateStore( c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); B.CreateStore( @@ -5013,21 +5014,27 @@ class TruncateGenerator : public llvm::InstVisitor { ValueToValueMapTy &originalToNewFn; FloatTruncation truncation; Type *fromType; + Type *toType; Function *oldFunc; Function *newFunc; AllocaInst *tmpBlock; TruncateMode mode; EnzymeLogic &Logic; + LLVMContext &ctx; public: TruncateGenerator(ValueToValueMapTy &originalToNewFn, FloatTruncation truncation, Function *oldFunc, Function *newFunc, TruncateMode mode, EnzymeLogic &Logic) : originalToNewFn(originalToNewFn), truncation(truncation), - oldFunc(oldFunc), newFunc(newFunc), mode(mode), Logic(Logic) { + oldFunc(oldFunc), newFunc(newFunc), mode(mode), Logic(Logic), + ctx(newFunc->getContext()) { IRBuilder<> B(&newFunc->getEntryBlock().front()); - fromType = truncation.getFromType(B.getContext()); + fromType = truncation.getFromType(ctx); + toType = truncation.getToType(ctx); + if (fromType == toType) + assert(truncation.isToMPFR()); if (mode == TruncMemMode) tmpBlock = B.CreateAlloca(fromType); @@ -5074,10 +5081,13 @@ class TruncateGenerator : public llvm::InstVisitor { Value *truncate(IRBuilder<> &B, Value *v) { switch (mode) { case TruncMemMode: - return floatMemTruncate(B, v, tmpBlock, from, to); + assert(!truncation.isToMPFR()); + return floatMemTruncate(B, v, tmpBlock, truncation); case TruncOpMode: case TruncOpFullModuleMode: - return floatValTruncate(B, v, tmpBlock, from, to); + if (truncation.isToMPFR()) + return v; + return floatValTruncate(B, v, tmpBlock, truncation); } llvm_unreachable("Unknown trunc mode"); } @@ -5085,10 +5095,10 @@ class TruncateGenerator : public llvm::InstVisitor { Value *expand(IRBuilder<> &B, Value *v) { switch (mode) { case TruncMemMode: - return floatMemExpand(B, v, tmpBlock, from, to); + return floatMemExpand(B, v, tmpBlock, truncation); case TruncOpMode: case TruncOpFullModuleMode: - return floatValExpand(B, v, tmpBlock, from, to); + return floatValExpand(B, v, tmpBlock, truncation); } llvm_unreachable("Unknown trunc mode"); } @@ -5192,11 +5202,41 @@ class TruncateGenerator : public llvm::InstVisitor { void visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { return; } void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } + CallInst *createMPFRCall(llvm::IRBuilder<> &B, llvm::Instruction &I, + llvm::Type *RetTy, SmallVectorImpl &Args) { + std::string Name; + if (auto BO = dyn_cast(&I)) { + Name = BO->getOpcodeName(); + } else if (auto CI = dyn_cast(&I)) { + if (auto F = CI->getFunction()) + Name = F->getName(); + else + llvm_unreachable( + "Unexpected indirect call inst for conversion to MPFR"); + } else { + llvm_unreachable("Unexpected instruction for conversion to MPFR"); + } + + std::string MangledName = + std::string("__enzyme_mpfr_") + truncation.mangleString() + "_" + Name; + auto F = newFunc->getParent()->getFunction(MangledName); + if (!F) { + SmallVector ArgTypes; + for (auto Arg : Args) + ArgTypes.push_back(Arg->getType()); + FunctionType *FnTy = + FunctionType::get(RetTy, ArgTypes, /*is_vararg*/ false); + F = Function::Create(FnTy, Function::ExternalLinkage, MangledName, + newFunc->getParent()); + } + return cast(B.CreateCall(F, Args)); + } void visitBinaryOperator(llvm::BinaryOperator &BO) { auto oldLHS = BO.getOperand(0); auto oldRHS = BO.getOperand(1); - if (oldLHS != getFromType() && oldRHS != getFromType()) + if (oldLHS->getType() != getFromType() && + oldRHS->getType() != getFromType()) return; switch (BO.getOpcode()) { @@ -5215,22 +5255,19 @@ class TruncateGenerator : public llvm::InstVisitor { case BinaryOperator::And: case BinaryOperator::Or: case BinaryOperator::Xor: + assert(0 && "Invalid binop opcode for float arg"); return; } auto newI = getNewFromOriginal(&BO); + IRBuilder<> B(newI); + auto newLHS = truncate(B, getNewFromOriginal(oldLHS)); + auto newRHS = truncate(B, getNewFromOriginal(oldRHS)); Instruction *nres = nullptr; if (truncation.isToMPFR()) { - auto newI = getNewFromOriginal(&BO); - IRBuilder<> B(newI); - auto newLHS = getNewFromOriginal(oldLHS); - auto newRHS = getNewFromOriginal(oldRHS); - nres = cast( - B.CreateCall(createMPFRCall(BO.getOpcode), {newLHS, newRHS})); + SmallVector Args({newLHS, newRHS}); + nres = createMPFRCall(B, BO, truncation.getToType(ctx), Args); } else { - IRBuilder<> B(newI); - auto newLHS = truncate(B, getNewFromOriginal(oldLHS)); - auto newRHS = truncate(B, getNewFromOriginal(oldRHS)); nres = cast(B.CreateBinOp(BO.getOpcode(), newLHS, newRHS)); } nres->takeName(newI); @@ -5259,13 +5296,14 @@ class TruncateGenerator : public llvm::InstVisitor { void visitFenceInst(llvm::FenceInst &FI) { return; } bool handleIntrinsic(llvm::CallInst &CI, Intrinsic::ID ID) { + auto newI = cast(getNewFromOriginal(&CI)); + IRBuilder<> B(newI); + SmallVector orig_ops(CI.arg_size()); for (unsigned i = 0; i < CI.arg_size(); ++i) orig_ops[i] = CI.getOperand(i); bool hasFromType = false; - auto newI = cast(getNewFromOriginal(&CI)); - IRBuilder<> B(newI); SmallVector new_ops(CI.arg_size()); for (unsigned i = 0; i < CI.arg_size(); ++i) { if (orig_ops[i]->getType() == getFromType()) { @@ -5284,12 +5322,16 @@ class TruncateGenerator : public llvm::InstVisitor { if (!hasFromType) return false; - // TODO check that the intrinsic is overloaded - - CallInst *intr; - Value *nres = intr = - createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); - if (CI.getType() == getFromType()) + Instruction *intr = nullptr; + Value *nres = nullptr; + if (truncation.isToMPFR()) { + nres = intr = createMPFRCall(B, CI, retTy, new_ops); + } else { + // TODO check that the intrinsic is overloaded + nres = intr = + createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); + } + if (newI->getType() == getFromType()) nres = expand(B, nres); intr->copyIRFlags(newI); newI->replaceAllUsesWith(nres); @@ -5378,7 +5420,7 @@ class TruncateGenerator : public llvm::InstVisitor { Value *GetShadow(RequestContext &ctx, Value *v) { if (auto F = dyn_cast(v)) - return Logic.CreateTruncateFunc(ctx, F, from, to, mode); + return Logic.CreateTruncateFunc(ctx, F, truncation, mode); llvm::errs() << " unknown get truncated func: " << *v << "\n"; llvm_unreachable("unknown get truncated func"); return v; @@ -5445,10 +5487,11 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, Value *converted = nullptr; if (isTruncate) - converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, from, to); + converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, + FloatTruncation(from, to)); else - converted = - B.CreateFPExt(floatMemTruncate(B, v, nullptr, from, to), fromTy); + converted = B.CreateFPExt( + floatMemTruncate(B, v, nullptr, FloatTruncation(from, to)), fromTy); assert(converted); context.req->replaceAllUsesWith(converted); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index a25a0c778f59..50972a70ce47 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -51,21 +51,19 @@ #include "Utils.h" extern "C" { -ext #include "llvm/Support/ErrorHandling.h" ern llvm::cl::opt EnzymePrint; +extern llvm::cl::opt EnzymePrint; extern llvm::cl::opt EnzymeJuliaAddrLoad; } enum class AugmentedStruct { Tape, Return, DifferentialReturn }; -static inline std::string str(A #include - "llvm/Support/ErrorHandling.h" ugmentedStruct c) { +static inline std::string str(AugmentedStruct c) { switch (c) { case AugmentedStruct::Tape: return "tape"; case AugmentedStruct::Return: return "return"; - case AugmentedStr #include - "llvm/Support/ErrorHandling.h" uct::DifferentialReturn: + case AugmentedStruct::DifferentialReturn: return "DifferentialReturn"; default: llvm_unreachable("unknown cache type"); @@ -368,6 +366,8 @@ struct FloatTruncation { llvm::report_fatal_error( "Float truncation `from` and `to` type must not be the same."); } + unsigned getFromTypeWidth() { return from.getTypeWidth(); } + unsigned getToTypeWidth() { return to.getTypeWidth(); } llvm::Type *getFromType(llvm::LLVMContext &ctx) { return from.getBuiltinType(ctx); } @@ -640,13 +640,13 @@ class EnzymeLogic { llvm::ArrayRef arg_types, BATCH_TYPE ret_type); - using TruncateCacheKey = std::tuple; + using TruncateCacheKey = + std::tuple; std::map TruncateCachedFunctions; llvm::Function *CreateTruncateFunc(RequestContext context, llvm::Function *tobatch, - FloatRepresentation from, - FloatRepresentation to, TruncateMode mode); + FloatTruncation truncation, + TruncateMode mode); bool CreateTruncateValue(RequestContext context, llvm::Value *addr, FloatRepresentation from, FloatRepresentation to, bool isTruncate); From 5045fffac72591cf51ad55b1c976a032f78b88ed Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 01:30:39 +0900 Subject: [PATCH 03/27] Fix mpfr function mangling --- enzyme/Enzyme/Enzyme.cpp | 41 +++++++++++++++++------- enzyme/Enzyme/EnzymeLogic.cpp | 14 ++++++-- enzyme/Enzyme/EnzymeLogic.h | 8 ++--- enzyme/test/Enzyme/Truncate/cmp.ll | 11 ++++--- enzyme/test/Enzyme/Truncate/intrinsic.ll | 19 ++++++++++- enzyme/test/Enzyme/Truncate/simple.ll | 25 ++++++++------- 6 files changed, 81 insertions(+), 37 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 6f5a5317f723..2e91430e9cfd 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1340,23 +1340,40 @@ class EnzymeBase { Function *F = parseFunctionParameter(CI); if (!F) return false; - if (CI->arg_size() != 3) { + unsigned ArgSize = CI->arg_size(); + if (ArgSize != 4 && ArgSize != 3) { EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, "Had incorrect number of args to __enzyme_truncate_func", *CI, - " - expected 3"); + " - expected 3 or 4"); return false; } - auto Cfrom = cast(CI->getArgOperand(1)); - assert(Cfrom); - auto Cto = cast(CI->getArgOperand(2)); - assert(Cto); - RequestContext context(CI, &Builder); - llvm::Value *res = Logic.CreateTruncateFunc( - context, F, - FloatTruncation( + FloatTruncation truncation = [&]() -> FloatTruncation { + if (ArgSize == 3) { + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto = cast(CI->getArgOperand(2)); + assert(Cto); + return FloatTruncation( getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), - getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue())), - mode); + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue())); + } else if (ArgSize == 4) { + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto_exponent = cast(CI->getArgOperand(2)); + assert(Cto_exponent); + auto Cto_significand = cast(CI->getArgOperand(3)); + assert(Cto_significand); + return FloatTruncation( + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + FloatRepresentation( + (unsigned)Cto_exponent->getValue().getZExtValue(), + (unsigned)Cto_significand->getValue().getZExtValue())); + } + llvm_unreachable("??"); + }(); + + RequestContext context(CI, &Builder); + llvm::Value *res = Logic.CreateTruncateFunc(context, F, truncation, mode); if (!res) return false; res = Builder.CreatePointerCast(res, CI->getType()); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index f3b400b32f87..1e0b5584d48a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -34,6 +34,7 @@ #include "llvm/IR/GlobalValue.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/ErrorHandling.h" @@ -5206,10 +5207,17 @@ class TruncateGenerator : public llvm::InstVisitor { llvm::Type *RetTy, SmallVectorImpl &Args) { std::string Name; if (auto BO = dyn_cast(&I)) { - Name = BO->getOpcodeName(); + Name = "binop_" + std::string(BO->getOpcodeName()); + } else if (auto II = dyn_cast(&I)) { + auto FOp = II->getCalledFunction(); + assert(FOp); + Name = "intr_" + std::string(FOp->getName()); + for (auto &C : Name) + if (C == '.') + C = '_'; } else if (auto CI = dyn_cast(&I)) { - if (auto F = CI->getFunction()) - Name = F->getName(); + if (auto F = CI->getCalledFunction()) + Name = "func_" + std::string(F->getName()); else llvm_unreachable( "Unexpected indirect call inst for conversion to MPFR"); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 50972a70ce47..190b4dabc60a 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -288,14 +288,14 @@ getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { } enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; -static const char *truncateModeStr(TruncateMode mode) { +[[maybe_unused]] static const char *truncateModeStr(TruncateMode mode) { switch (mode) { case TruncMemMode: - return "trunc_mem"; + return "mem"; case TruncOpMode: - return "trunc_op"; + return "op"; case TruncOpFullModuleMode: - return "trunc_op_full_module"; + return "op_full_module"; } } diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index c96efa70660a..8d501cd24f7d 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -21,11 +21,12 @@ entry: %res = call i1 %ptr(double %x, double %y) ret i1 %res } - -; CHECK: define i1 @tester(double %x, double %y) { -; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) -; CHECK-NEXT: ret i1 %res +define i1 @tester_op_mpfr(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_op_func(i1 (double, double)* @f, i64 64, i64 3, i64 7) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} ; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { ; CHECK-DAG: %1 = alloca double, align 8 diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index 99568539c3f3..0a51f088c78d 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -1,11 +1,13 @@ ; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi ; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi +declare double @pow(double %Val, double %Power) declare double @llvm.pow.f64(double %Val, double %Power) declare double @llvm.powi.f64.i16(double %Val, i16 %power) declare void @llvm.nvvm.barrier0() define double @f(double %x, double %y) { + %res0 = call double @pow(double %x, double %y) %res1 = call double @llvm.pow.f64(double %x, double %y) %res2 = call double @llvm.powi.f64.i16(double %x, i16 2) %res = fadd double %res1, %res2 @@ -22,12 +24,18 @@ entry: %res = call double %ptr(double %x, double %y) ret double %res } -define double @tester2(double %x, double %y) { +define double @tester_op(double %x, double %y) { entry: %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y) ret double %res } +define double @tester_op_mpfr(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 3, i64 7) + %res = call double %ptr(double %x, double %y) + ret double %res +} ; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { ; CHECK-NEXT: %1 = alloca double, align 8 @@ -81,3 +89,12 @@ entry: ; CHECK-DAG: %enzyme_exp8 = fpext float %res to double ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %enzyme_exp8 + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x, double %y) { +; CHECK-DAG: %1 = call double @__enzyme_mpfr_64_52to11_7_func_pow(double %x, double %y) +; CHECK-DAG: %2 = call double @__enzyme_mpfr_64_52to11_7_intr_llvm_pow_f64(double %x, double %y) +; CHECK-DAG: %3 = call double @__enzyme_mpfr_64_52to11_7_intr_llvm_powi_f64_i16(double %x, i16 2) +; CHECK-DAG: %res = call double @__enzyme_mpfr_64_52to11_7_binop_fadd(double %2, double %3) +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %res +; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 19d6cf1f3a23..4c5a85ce2ef1 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -17,23 +17,18 @@ entry: call void %ptr(double* %data) ret void } - -define void @tester2(double* %data) { +define void @tester_op(double* %data) { entry: %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 32) call void %ptr(double* %data) ret void } - -; CHECK: define void @tester(double* %data) -; CHECK-NEXT: entry: -; CHECK-NEXT: call void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %data) -; CHECK-NEXT: ret void - -; CHECK: define void @tester2(double* %data) { -; CHECK-NEXT: entry: -; CHECK-NEXT: call void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %data) -; CHECK-NEXT: ret void +define void @tester_op_mpfr(double* %data) { +entry: + %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 3, i64 7) + call void %ptr(double* %data) + ret void +} ; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %x) ; CHECK-DAG: %1 = alloca double, align 8 @@ -61,3 +56,9 @@ entry: ; CHECK-DAG: %enzyme_exp = fpext float %m to double ; CHECK-DAG: store double %enzyme_exp, double* %x, align 8 ; CHECK-DAG: ret void +_ +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to11_7_f(double* %x) { +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: %m = call double @__enzyme_mpfr_64_52to11_7_fmul(double %y, double %y) +; CHECK-DAG: store double %m, double* %x, align 8 +; CHECK-DAG: ret void From 1572431341304e64d92011b378fb3dcb99a56d87 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 08:35:35 +0900 Subject: [PATCH 04/27] Mangling --- enzyme/Enzyme/EnzymeLogic.cpp | 9 ++++++--- enzyme/Enzyme/EnzymeLogic.h | 7 ++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 1e0b5584d48a..2a08f4a698bf 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5204,7 +5204,7 @@ class TruncateGenerator : public llvm::InstVisitor { void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } CallInst *createMPFRCall(llvm::IRBuilder<> &B, llvm::Instruction &I, - llvm::Type *RetTy, SmallVectorImpl &Args) { + llvm::Type *RetTy, SmallVectorImpl &ArgsIn) { std::string Name; if (auto BO = dyn_cast(&I)) { Name = "binop_" + std::string(BO->getOpcodeName()); @@ -5226,8 +5226,11 @@ class TruncateGenerator : public llvm::InstVisitor { } std::string MangledName = - std::string("__enzyme_mpfr_") + truncation.mangleString() + "_" + Name; + std::string("__enzyme_mpfr_") + truncation.mangleFrom() + "_" + Name; auto F = newFunc->getParent()->getFunction(MangledName); + SmallVector Args(ArgsIn.begin(), ArgsIn.end()); + Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); + Args.push_back(B.getInt64(truncation.getTo().significandWidth)); if (!F) { SmallVector ArgTypes; for (auto Arg : Args) @@ -5529,7 +5532,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); std::string truncName = std::string("__enzyme_done_truncate_") + truncateModeStr(mode) + - "_func_" + truncation.mangleString() + "_" + totrunc->getName().str(); + "_func_" + truncation.mangleTruncation() + "_" + totrunc->getName().str(); Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, totrunc->getParent()); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 190b4dabc60a..ce329c429c0b 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -297,6 +297,7 @@ enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; case TruncOpFullModuleMode: return "op_full_module"; } + llvm_unreachable("Invalid truncation mode"); } struct FloatRepresentation { @@ -366,6 +367,7 @@ struct FloatTruncation { llvm::report_fatal_error( "Float truncation `from` and `to` type must not be the same."); } + FloatRepresentation getTo() { return to;} unsigned getFromTypeWidth() { return from.getTypeWidth(); } unsigned getToTypeWidth() { return to.getTypeWidth(); } llvm::Type *getFromType(llvm::LLVMContext &ctx) { @@ -388,9 +390,12 @@ struct FloatTruncation { bool operator<(const FloatTruncation &other) const { return std::tuple(from, to) < std::tuple(other.from, other.to); } - std::string mangleString() const { + std::string mangleTruncation() const { return from.to_string() + "to" + to.to_string(); } + std::string mangleFrom() const { + return from.to_string(); + } }; class EnzymeLogic { From 772c7f6748856f96df9a0a7cc7c928a3999217af Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 08:36:19 +0900 Subject: [PATCH 05/27] MPFR Wrappers --- enzyme/Enzyme/Runtime/EnzymeMPFR.cpp | 37 ++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 enzyme/Enzyme/Runtime/EnzymeMPFR.cpp diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp b/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp new file mode 100644 index 000000000000..cea570339802 --- /dev/null +++ b/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp @@ -0,0 +1,37 @@ +#include +#include + +extern "C" { + +#define BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, RET, MPFR_GET, ARG1, \ + MPFR_SET_ARG1, ARG2, MPFR_SET_ARG2, ROUNDING_MODE) \ + __attribute__((weak)) RET __enzyme_mpfr_##FROM_TYPE_binop_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } + +#define DEFAULT_ROUNDING_MODE GMP_RNDN +#define DBL_MANGLE 64_52 +#define DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, ROUNDING_MODE) \ + BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, DBL_MANGLE, double, d, double, d, \ + double, d, ROUNDING_MODE) +#define DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, MPFR_FUNC_NAME) \ + DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, DEFAULT_ROUNDING_MODE) + + // BINOP(fmul, mul, 64_52, double, d, double, d, double, d, GMP_RNDN) + DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) + DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) + DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) + +} From 3378e0e1aba8f0e2d5d582c1cd31c6708fa53041 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 08:40:28 +0900 Subject: [PATCH 06/27] clang-format --- enzyme/Enzyme/EnzymeLogic.cpp | 3 ++- enzyme/Enzyme/EnzymeLogic.h | 6 ++---- enzyme/Enzyme/Runtime/EnzymeMPFR.cpp | 13 ++++++------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 2a08f4a698bf..d372b46d8813 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5204,7 +5204,8 @@ class TruncateGenerator : public llvm::InstVisitor { void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } CallInst *createMPFRCall(llvm::IRBuilder<> &B, llvm::Instruction &I, - llvm::Type *RetTy, SmallVectorImpl &ArgsIn) { + llvm::Type *RetTy, + SmallVectorImpl &ArgsIn) { std::string Name; if (auto BO = dyn_cast(&I)) { Name = "binop_" + std::string(BO->getOpcodeName()); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index ce329c429c0b..4bb61e94c8ef 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -367,7 +367,7 @@ struct FloatTruncation { llvm::report_fatal_error( "Float truncation `from` and `to` type must not be the same."); } - FloatRepresentation getTo() { return to;} + FloatRepresentation getTo() { return to; } unsigned getFromTypeWidth() { return from.getTypeWidth(); } unsigned getToTypeWidth() { return to.getTypeWidth(); } llvm::Type *getFromType(llvm::LLVMContext &ctx) { @@ -393,9 +393,7 @@ struct FloatTruncation { std::string mangleTruncation() const { return from.to_string() + "to" + to.to_string(); } - std::string mangleFrom() const { - return from.to_string(); - } + std::string mangleFrom() const { return from.to_string(); } }; class EnzymeLogic { diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp b/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp index cea570339802..170167269d95 100644 --- a/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp +++ b/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp @@ -26,12 +26,11 @@ extern "C" { #define DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, ROUNDING_MODE) \ BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, DBL_MANGLE, double, d, double, d, \ double, d, ROUNDING_MODE) -#define DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, MPFR_FUNC_NAME) \ - DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, DEFAULT_ROUNDING_MODE) - - // BINOP(fmul, mul, 64_52, double, d, double, d, double, d, GMP_RNDN) - DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) - DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) - DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) +#define DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, MPFR_FUNC_NAME) \ + DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, DEFAULT_ROUNDING_MODE) +// BINOP(fmul, mul, 64_52, double, d, double, d, double, d, GMP_RNDN) +DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) +DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) +DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) } From 23ea6459edc83604fc709d8c76f65fdadb4320cb Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 08:46:50 +0900 Subject: [PATCH 07/27] Make header work in C --- enzyme/Enzyme/Runtime/EnzymeMPFR.cpp | 36 -------------- enzyme/Enzyme/Runtime/EnzymeMPFR.h | 72 ++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 36 deletions(-) delete mode 100644 enzyme/Enzyme/Runtime/EnzymeMPFR.cpp create mode 100644 enzyme/Enzyme/Runtime/EnzymeMPFR.h diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp b/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp deleted file mode 100644 index 170167269d95..000000000000 --- a/enzyme/Enzyme/Runtime/EnzymeMPFR.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include -#include - -extern "C" { - -#define BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, RET, MPFR_GET, ARG1, \ - MPFR_SET_ARG1, ARG2, MPFR_SET_ARG2, ROUNDING_MODE) \ - __attribute__((weak)) RET __enzyme_mpfr_##FROM_TYPE_binop_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ - mpfr_t ma, mb, mc; \ - mpfr_init2(ma, significand); \ - mpfr_init2(mb, significand); \ - mpfr_init2(mc, significand); \ - mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ - mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \ - mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ - RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ - mpfr_clear(ma); \ - mpfr_clear(mb); \ - mpfr_clear(mc); \ - return c; \ - } - -#define DEFAULT_ROUNDING_MODE GMP_RNDN -#define DBL_MANGLE 64_52 -#define DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, ROUNDING_MODE) \ - BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, DBL_MANGLE, double, d, double, d, \ - double, d, ROUNDING_MODE) -#define DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, MPFR_FUNC_NAME) \ - DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, DEFAULT_ROUNDING_MODE) - -// BINOP(fmul, mul, 64_52, double, d, double, d, double, d, GMP_RNDN) -DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) -DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) -DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) -} diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.h b/enzyme/Enzyme/Runtime/EnzymeMPFR.h new file mode 100644 index 000000000000..175ea317bef3 --- /dev/null +++ b/enzyme/Enzyme/Runtime/EnzymeMPFR.h @@ -0,0 +1,72 @@ +//===- EnzymeMPFR.h - Implementation of forward and reverse pass generation==// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains easy to use wrappers around MPFR functions. +// +//===----------------------------------------------------------------------===// +#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +#define __ENZYME_RUNTIME_ENZYME_MPFR__ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define __ENZYME_MPFR_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, RET, \ + MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ + MPFR_SET_ARG2, ROUNDING_MODE) \ + __attribute__((weak)) RET __enzyme_mpfr_##FROM_TYPE_binop_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN +#define __ENZYME_MPFR_DBL_MANGLE 64_52 +#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, __ENZYME_MPFR_DBL_MANGLE, \ + double, d, double, d, double, d, ROUNDING_MODE) +#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \ + MPFR_FUNC_NAME) \ + __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ From f51efe476d8b14701b67cfa7cd203bf425f30bcc Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 08:49:52 +0900 Subject: [PATCH 08/27] File header --- enzyme/Enzyme/Runtime/EnzymeMPFR.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.h b/enzyme/Enzyme/Runtime/EnzymeMPFR.h index 175ea317bef3..30127d39633f 100644 --- a/enzyme/Enzyme/Runtime/EnzymeMPFR.h +++ b/enzyme/Enzyme/Runtime/EnzymeMPFR.h @@ -1,4 +1,4 @@ -//===- EnzymeMPFR.h - Implementation of forward and reverse pass generation==// +//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===// // // Enzyme Project // From 5d1a0f21f70357da7d694e637e5cc5eb9a89a098 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 08:53:48 +0900 Subject: [PATCH 09/27] Make it compile on llvm 11 --- enzyme/Enzyme/Enzyme.cpp | 2 +- enzyme/Enzyme/EnzymeLogic.cpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 2e91430e9cfd..4ab84bd6e67e 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -23,7 +23,6 @@ // the function passed as the first argument. // //===----------------------------------------------------------------------===// -#include "llvm/Support/ErrorHandling.h" #include #if LLVM_VERSION_MAJOR >= 16 @@ -59,6 +58,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Analysis/BasicAliasAnalysis.h" diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index d372b46d8813..f8fdf3b3124a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5229,11 +5229,11 @@ class TruncateGenerator : public llvm::InstVisitor { std::string MangledName = std::string("__enzyme_mpfr_") + truncation.mangleFrom() + "_" + Name; auto F = newFunc->getParent()->getFunction(MangledName); - SmallVector Args(ArgsIn.begin(), ArgsIn.end()); + SmallVector Args(ArgsIn.begin(), ArgsIn.end()); Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); Args.push_back(B.getInt64(truncation.getTo().significandWidth)); if (!F) { - SmallVector ArgTypes; + SmallVector ArgTypes; for (auto Arg : Args) ArgTypes.push_back(Arg->getType()); FunctionType *FnTy = @@ -5277,7 +5277,7 @@ class TruncateGenerator : public llvm::InstVisitor { auto newRHS = truncate(B, getNewFromOriginal(oldRHS)); Instruction *nres = nullptr; if (truncation.isToMPFR()) { - SmallVector Args({newLHS, newRHS}); + SmallVector Args({newLHS, newRHS}); nres = createMPFRCall(B, BO, truncation.getToType(ctx), Args); } else { nres = cast(B.CreateBinOp(BO.getOpcode(), newLHS, newRHS)); From e5e6303d44c5f64b905abf534283d4ccde96ea74 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 10:01:32 +0900 Subject: [PATCH 10/27] header --- enzyme/Enzyme/Runtime/EnzymeMPFR.h | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.h b/enzyme/Enzyme/Runtime/EnzymeMPFR.h index 30127d39633f..c4de03e169a7 100644 --- a/enzyme/Enzyme/Runtime/EnzymeMPFR.h +++ b/enzyme/Enzyme/Runtime/EnzymeMPFR.h @@ -31,10 +31,11 @@ extern "C" { #endif -#define __ENZYME_MPFR_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, RET, \ - MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ +#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ MPFR_SET_ARG2, ROUNDING_MODE) \ - __attribute__((weak)) RET __enzyme_mpfr_##FROM_TYPE_binop_##LLVM_OP_NAME( \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE_##OP_TYPE_##LLVM_OP_NAME( \ ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ mpfr_t ma, mb, mc; \ mpfr_init2(ma, significand); \ @@ -54,8 +55,9 @@ extern "C" { #define __ENZYME_MPFR_DBL_MANGLE 64_52 #define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ ROUNDING_MODE) \ - __ENZYME_MPFR_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, __ENZYME_MPFR_DBL_MANGLE, \ - double, d, double, d, double, d, ROUNDING_MODE) + __ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, \ + __ENZYME_MPFR_DBL_MANGLE, double, d, double, d, double, \ + d, ROUNDING_MODE) #define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \ MPFR_FUNC_NAME) \ __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ From b7d652320e8946dc5ff9f67b4df91b83473d11fc Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 10:01:49 +0900 Subject: [PATCH 11/27] fix tests --- enzyme/test/Enzyme/Truncate/cmp.ll | 4 +- enzyme/test/Enzyme/Truncate/intrinsic.ll | 118 +++++++++++++---------- enzyme/test/Enzyme/Truncate/select.ll | 6 +- enzyme/test/Enzyme/Truncate/simple.ll | 8 +- 4 files changed, 76 insertions(+), 60 deletions(-) diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index 8d501cd24f7d..68f0ef473a9b 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -28,7 +28,7 @@ entry: ret i1 %res } -; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -39,7 +39,7 @@ entry: ; CHECK-DAG: %res = fcmp olt float %3, %5 ; CHECK-DAG: ret i1 %res -; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float ; CHECK-DAG: %res = fcmp olt float %enzyme_trunc, %enzyme_trunc1 diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index 0a51f088c78d..2299c9fb1ab3 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -37,64 +37,80 @@ entry: ret double %res } -; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { -; CHECK-NEXT: %1 = alloca double, align 8 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %2 = bitcast double* %1 to float* -; CHECK-NEXT: %3 = load float, float* %2, align 4 -; CHECK-NEXT: store double %y, double* %1, align 8 -; CHECK-NEXT: %4 = bitcast double* %1 to float* -; CHECK-NEXT: %5 = load float, float* %4, align 4 -; CHECK-NEXT: %res11 = call float @llvm.pow.f32(float %3, float %5) -; CHECK-NEXT: %6 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %6, align 4 -; CHECK-NEXT: %7 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res11, float* %7, align 4 -; CHECK-NEXT: %8 = load double, double* %1, align 8 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %9 = bitcast double* %1 to float* -; CHECK-NEXT: %10 = load float, float* %9, align 4 -; CHECK-NEXT: %res22 = call float @llvm.powi.f32.i16(float %10, i16 2) -; CHECK-NEXT: %11 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %11, align 4 -; CHECK-NEXT: %12 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res22, float* %12, align 4 -; CHECK-NEXT: %13 = load double, double* %1, align 8 -; CHECK-NEXT: store double %8, double* %1, align 8 -; CHECK-NEXT: %14 = bitcast double* %1 to float* -; CHECK-NEXT: %15 = load float, float* %14, align 4 -; CHECK-NEXT: store double %13, double* %1, align 8 -; CHECK-NEXT: %16 = bitcast double* %1 to float* -; CHECK-NEXT: %17 = load float, float* %16, align 4 -; CHECK-NEXT: %res = fadd float %15, %17 -; CHECK-NEXT: %18 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %18, align 4 -; CHECK-NEXT: %19 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res, float* %19, align 4 -; CHECK-NEXT: %20 = load double, double* %1, align 8 -; CHECK-NEXT: call void @llvm.nvvm.barrier0() -; CHECK-NEXT: ret double %20 +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %res01 = call float @llvm.pow.f32(float %3, float %5) +; CHECK-DAG: %6 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %6, align 4 +; CHECK-DAG: %7 = bitcast double* %1 to float* +; CHECK-DAG: store float %res01, float* %7, align 4 +; CHECK-DAG: %8 = load double, double* %1, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %9 = bitcast double* %1 to float* +; CHECK-DAG: %10 = load float, float* %9, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %11 = bitcast double* %1 to float* +; CHECK-DAG: %12 = load float, float* %11, align 4 +; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %10, float %12) +; CHECK-DAG: %13 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %13, align 4 +; CHECK-DAG: %14 = bitcast double* %1 to float* +; CHECK-DAG: store float %res12, float* %14, align 4 +; CHECK-DAG: %15 = load double, double* %1, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %16 = bitcast double* %1 to float* +; CHECK-DAG: %17 = load float, float* %16, align 4 +; CHECK-DAG: %res23 = call float @llvm.powi.f32.i16(float %17, i16 2) +; CHECK-DAG: %18 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %18, align 4 +; CHECK-DAG: %19 = bitcast double* %1 to float* +; CHECK-DAG: store float %res23, float* %19, align 4 +; CHECK-DAG: %20 = load double, double* %1, align 8 +; CHECK-DAG: store double %15, double* %1, align 8 +; CHECK-DAG: %21 = bitcast double* %1 to float* +; CHECK-DAG: %22 = load float, float* %21, align 4 +; CHECK-DAG: store double %20, double* %1, align 8 +; CHECK-DAG: %23 = bitcast double* %1 to float* +; CHECK-DAG: %24 = load float, float* %23, align 4 +; CHECK-DAG: %res = fadd float %22, %24 +; CHECK-DAG: %25 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %25, align 4 +; CHECK-DAG: %26 = bitcast double* %1 to float* +; CHECK-DAG: store float %res, float* %26, align 4 +; CHECK-DAG: %27 = load double, double* %1, align 8 +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %27 -; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float -; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) -; CHECK-DAG: %enzyme_exp = fpext float %res12 to double +; CHECK-DAG: %res02 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) +; CHECK-DAG: %enzyme_exp = fpext float %res02 to double ; CHECK-DAG: %enzyme_trunc3 = fptrunc double %x to float -; CHECK-DAG: %res24 = call float @llvm.powi.f32.i16(float %enzyme_trunc3, i16 2) -; CHECK-DAG: %enzyme_exp5 = fpext float %res24 to double -; CHECK-DAG: %enzyme_trunc6 = fptrunc double %enzyme_exp to float -; CHECK-DAG: %enzyme_trunc7 = fptrunc double %enzyme_exp5 to float -; CHECK-DAG: %res = fadd float %enzyme_trunc6, %enzyme_trunc7 -; CHECK-DAG: %enzyme_exp8 = fpext float %res to double +; CHECK-DAG: %enzyme_trunc4 = fptrunc double %y to float +; CHECK-DAG: %res15 = call float @llvm.pow.f32(float %enzyme_trunc3, float %enzyme_trunc4) +; CHECK-DAG: %enzyme_exp6 = fpext float %res15 to double +; CHECK-DAG: %enzyme_trunc7 = fptrunc double %x to float +; CHECK-DAG: %res28 = call float @llvm.powi.f32.i16(float %enzyme_trunc7, i16 2) +; CHECK-DAG: %enzyme_exp9 = fpext float %res28 to double +; CHECK-DAG: %enzyme_trunc10 = fptrunc double %enzyme_exp6 to float +; CHECK-DAG: %enzyme_trunc11 = fptrunc double %enzyme_exp9 to float +; CHECK-DAG: %res = fadd float %enzyme_trunc10, %enzyme_trunc11 +; CHECK-DAG: %enzyme_exp12 = fpext float %res to double ; CHECK-DAG: call void @llvm.nvvm.barrier0() -; CHECK-DAG: ret double %enzyme_exp8 +; CHECK-DAG: ret double %enzyme_exp12 ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_mpfr_64_52to11_7_func_pow(double %x, double %y) -; CHECK-DAG: %2 = call double @__enzyme_mpfr_64_52to11_7_intr_llvm_pow_f64(double %x, double %y) -; CHECK-DAG: %3 = call double @__enzyme_mpfr_64_52to11_7_intr_llvm_powi_f64_i16(double %x, i16 2) -; CHECK-DAG: %res = call double @__enzyme_mpfr_64_52to11_7_binop_fadd(double %2, double %3) +; CHECK-DAG: %1 = call double @__enzyme_mpfr_64_52_func_pow(double %x, double %y, i64 3, i64 7) +; CHECK-DAG: %2 = call double @__enzyme_mpfr_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7) +; CHECK-DAG: %3 = call double @__enzyme_mpfr_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7) +; CHECK-DAG: %res = call double @__enzyme_mpfr_64_52_binop_fadd(double %2, double %3, i64 3, i64 7) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll index 365d21ab5913..afc41219fed8 100644 --- a/enzyme/test/Enzyme/Truncate/select.ll +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -25,10 +25,10 @@ entry: ; CHECK: define double @tester(double %x, double %y, i1 %cond) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) +; CHECK-NEXT: %res = call double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y, i1 %cond) ; CHECK-NEXT: ret double %res -; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -44,6 +44,6 @@ entry: ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: ret double %8 -; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %res = select i1 %cond, double %x, double %y ; CHECK-DAG: ret double %res diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 4c5a85ce2ef1..a57f33fcdfdb 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -30,7 +30,7 @@ entry: ret void } -; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %x) +; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52to32_23_f(double* %x) ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: store double %y, double* %1, align 8 @@ -48,7 +48,7 @@ entry: ; CHECK-DAG: store double %8, double* %x, align 8 ; CHECK-DAG: ret void -; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %x) { +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: %enzyme_trunc = fptrunc double %y to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float @@ -56,9 +56,9 @@ entry: ; CHECK-DAG: %enzyme_exp = fpext float %m to double ; CHECK-DAG: store double %enzyme_exp, double* %x, align 8 ; CHECK-DAG: ret void -_ + ; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to11_7_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_mpfr_64_52to11_7_fmul(double %y, double %y) +; CHECK-DAG: %m = call double @__enzyme_mpfr_64_52_binop_fmul(double %y, double %y, i64 3, i64 7) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void From 6b486ccae2982ec80adbf745eece5c82fbcbc3ea Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 10:14:40 +0900 Subject: [PATCH 12/27] Add TODO comment --- enzyme/Enzyme/Runtime/EnzymeMPFR.h | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.h b/enzyme/Enzyme/Runtime/EnzymeMPFR.h index c4de03e169a7..4081fef15a0a 100644 --- a/enzyme/Enzyme/Runtime/EnzymeMPFR.h +++ b/enzyme/Enzyme/Runtime/EnzymeMPFR.h @@ -31,6 +31,32 @@ extern "C" { #endif +// TODO s +// +// (for MPFR ver. 2.1) +// +// We need to set the range of the allowed exponent using `mpfr_set_emin` and +// `mpfr_set_emax`. (This means we can also play with whether the range is +// centered around 0 (1?) or somewhere else) +// +// For that we need to do this check: +// If the user changes the exponent range, it is her/his responsibility to +// check that all current floating-point variables are in the new allowed +// range (for example using mpfr_check_range), otherwise the subsequent +// behavior will be undefined, in the sense of the ISO C standard. +// +// MPFR docs state the following: +// Note: Overflow handling is still experimental and currently implemented +// partially. If an overflow occurs internally at the wrong place, anything +// can happen (crash, wrong results, etc). +// +// Which we would like to avoid somehow. +// +// MPFR also has this limitation that we need to address for accurate +// simulation: +// [...] subnormal numbers are not implemented. +// + #define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ MPFR_SET_ARG2, ROUNDING_MODE) \ From d966d940656f1e522e0177a0f8c8ef8e1df9b7c0 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 22 Feb 2024 13:25:43 +0900 Subject: [PATCH 13/27] more comments --- enzyme/Enzyme/Runtime/EnzymeMPFR.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.h b/enzyme/Enzyme/Runtime/EnzymeMPFR.h index 4081fef15a0a..1df417215a20 100644 --- a/enzyme/Enzyme/Runtime/EnzymeMPFR.h +++ b/enzyme/Enzyme/Runtime/EnzymeMPFR.h @@ -39,6 +39,10 @@ extern "C" { // `mpfr_set_emax`. (This means we can also play with whether the range is // centered around 0 (1?) or somewhere else) // +// (also these need to be mutex'ed as the exponent change is global in mpfr and +// not float-specific) ... (mpfr seems to have thread safe mode - check if it is +// enabled or if it is enabled by default) +// // For that we need to do this check: // If the user changes the exponent range, it is her/his responsibility to // check that all current floating-point variables are in the new allowed From b1ad60bead15140fd462e8ee1f435fa60dce43b1 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 26 Feb 2024 16:35:03 +0900 Subject: [PATCH 14/27] MPFR header fix --- enzyme/Enzyme/Runtime/EnzymeMPFR.h | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.h b/enzyme/Enzyme/Runtime/EnzymeMPFR.h index 1df417215a20..dd148d2ebfe1 100644 --- a/enzyme/Enzyme/Runtime/EnzymeMPFR.h +++ b/enzyme/Enzyme/Runtime/EnzymeMPFR.h @@ -61,11 +61,28 @@ extern "C" { // [...] subnormal numbers are not implemented. // +#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } + #define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ MPFR_SET_ARG2, ROUNDING_MODE) \ __attribute__((weak)) \ - RET __enzyme_mpfr_##FROM_TYPE_##OP_TYPE_##LLVM_OP_NAME( \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ mpfr_t ma, mb, mc; \ mpfr_init2(ma, significand); \ @@ -82,12 +99,10 @@ extern "C" { } #define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN -#define __ENZYME_MPFR_DBL_MANGLE 64_52 #define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ ROUNDING_MODE) \ - __ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, \ - __ENZYME_MPFR_DBL_MANGLE, double, d, double, d, double, \ - d, ROUNDING_MODE) + __ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \ + double, d, double, d, ROUNDING_MODE) #define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \ MPFR_FUNC_NAME) \ __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ @@ -97,6 +112,9 @@ __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) +__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, d, + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + #ifdef __cplusplus } #endif From 4ba6b4b3288dde9fbe8aea947cbc98d1a072573a Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 26 Feb 2024 16:35:52 +0900 Subject: [PATCH 15/27] Add mpfr test --- enzyme/test/Integration/Truncate/truncate-all.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index ad5df438842f..863212852750 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -4,6 +4,7 @@ // Truncated // RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli -)" == "900000000.000000" ] ; fi // RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli -)" == "900000000.000000" ] ; fi +// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to3-7" | %lli -)" == "897581056.000000" ] ; fi #include From 3b6ec5ead9de7fee4533c613478359c480083a73 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Tue, 27 Feb 2024 16:19:52 +0900 Subject: [PATCH 16/27] Move mpfr runtime --- enzyme/Enzyme/Runtime/{EnzymeMPFR.h => MPFR.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename enzyme/Enzyme/Runtime/{EnzymeMPFR.h => MPFR.cpp} (100%) diff --git a/enzyme/Enzyme/Runtime/EnzymeMPFR.h b/enzyme/Enzyme/Runtime/MPFR.cpp similarity index 100% rename from enzyme/Enzyme/Runtime/EnzymeMPFR.h rename to enzyme/Enzyme/Runtime/MPFR.cpp From 563d579347f7f5da6b8cbd29a90d1bf1c100eb93 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Tue, 27 Feb 2024 16:23:17 +0900 Subject: [PATCH 17/27] Add another type of include header --- enzyme/Enzyme/Clang/include_utils.td | 17 +++++--- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 41 +++++++++++++++----- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index 1c99d219ce69..6f57a243f765 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -1,9 +1,14 @@ -class Headers { +class InlineHeader { string filename = filename_; string contents = contents_; } -def : Headers<"/enzymeroot/enzyme/utils", [{ +class FileHeader { + string filename_out = filename_out_; + string filename_in = filename_in_; +} + +def : InlineHeader<"/enzymeroot/enzyme/utils", [{ #pragma once extern int enzyme_dup; @@ -263,7 +268,7 @@ namespace enzyme { } }]>; -def : Headers<"/enzymeroot/enzyme/type_traits", [{ +def : InlineHeader<"/enzymeroot/enzyme/type_traits", [{ #pragma once #include @@ -312,7 +317,7 @@ namespace impl { } }]>; -def : Headers<"/enzymeroot/enzyme/tuple", [{ +def : InlineHeader<"/enzymeroot/enzyme/tuple", [{ #pragma once ///////////// @@ -449,10 +454,12 @@ constexpr auto tuple_cat(Tuples&&... tuples) { #undef _NOEXCEPT }]>; -def : Headers<"/enzymeroot/enzyme/enzyme", [{ +def : InlineHeader<"/enzymeroot/enzyme/enzyme", [{ #ifdef __cplusplus #include "enzyme/utils" #else #warning "Enzyme wrapper templates only available in C++" #endif }]>; + +def : FileHeader<"/enzymeroot/enzyme/mpfr", "Runtime/MPFR.cpp">; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 143c85ea684e..2a334e0fd95a 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -1252,18 +1253,38 @@ void printDiffUse( static void emitHeaderIncludes(const RecordKeeper &recordKeeper, raw_ostream &os) { - const auto &patterns = recordKeeper.getAllDerivedDefinitions("Headers"); os << "const char* include_headers[][2] = {\n"; bool seen = false; - for (Record *pattern : patterns) { - if (seen) - os << ",\n"; - auto filename = pattern->getValueAsString("filename"); - auto contents = pattern->getValueAsString("contents"); - os << "{\"" << filename << "\"\n,"; - os << "R\"(" << contents << ")\"\n"; - os << "}"; - seen = true; + { + const auto &patterns = recordKeeper.getAllDerivedDefinitions("InlineHeader"); + for (Record *pattern : patterns) { + if (seen) + os << ",\n"; + auto filename = pattern->getValueAsString("filename"); + auto contents = pattern->getValueAsString("contents"); + os << "{\"" << filename << "\"\n,"; + os << "R\"(" << contents << ")\"\n"; + os << "}"; + seen = true; + } + } + { + const auto &patterns = recordKeeper.getAllDerivedDefinitions("FileHeader"); + for (Record *pattern : patterns) { + if (seen) + os << ",\n"; + auto filename_out = pattern->getValueAsString("filename_out"); + std::string filename_in = pattern->getValueAsString("filename_in").str(); + std::string included_file; + auto contents = llvm::SrcMgr.OpenIncludeFile(filename_in, included_file); + //llvm::MemoryBuffer::getFile(filename_in, /*IsText=*/true); + if (!contents) + PrintFatalError(pattern->getLoc(), Twine("Could not read file ") + filename_in); + os << "{\"" << filename_out << "\"\n,"; + os << "R\"(" << contents.get()->getBuffer() << ")\"\n"; + os << "}"; + seen = true; + } } os << "};\n"; } From c9400654813eab572c4fa75528342bdf8a73ff86 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Tue, 27 Feb 2024 16:41:56 +0900 Subject: [PATCH 18/27] fix tests --- .../Integration/Truncate/truncate-all.cpp | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index 863212852750..0e1fc11b6750 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -1,13 +1,26 @@ // Baseline -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli -)" == "900000000.560000" ] ; fi + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli - | FileCheck --check-prefix BASELINE %s; fi +// BASELINE: 900000000.560000 + // Truncated -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli -)" == "900000000.000000" ] ; fi -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli -)" == "900000000.000000" ] ; fi -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to3-7" | %lli -)" == "897581056.000000" ] ; fi + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli - | FileCheck --check-prefix TO_32 %s; fi +// TO_32: 900000000.000000 + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli - | FileCheck --check-prefix TO_28_23 %s; fi +// TO_28_23: 900000000.000000 + +// RUN: if [ %llvmver -ge 12 ]; then %clang -DENZYME_TEST_TO_MPFR -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr; %s.a.out | FileCheck --check-prefix TO_3_7 %s; fi +// TO_3_7: 897581056.000000 #include +#ifdef ENZYME_TEST_TO_MPFR +#include +#endif + #include "../test_utils.h" #define N 10 From d22b6ef21a224f69aa96db9a448442e4598a2a5f Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Tue, 27 Feb 2024 16:49:27 +0900 Subject: [PATCH 19/27] clang-format --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 2a334e0fd95a..299221aa9baf 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1256,7 +1256,8 @@ static void emitHeaderIncludes(const RecordKeeper &recordKeeper, os << "const char* include_headers[][2] = {\n"; bool seen = false; { - const auto &patterns = recordKeeper.getAllDerivedDefinitions("InlineHeader"); + const auto &patterns = + recordKeeper.getAllDerivedDefinitions("InlineHeader"); for (Record *pattern : patterns) { if (seen) os << ",\n"; @@ -1277,9 +1278,10 @@ static void emitHeaderIncludes(const RecordKeeper &recordKeeper, std::string filename_in = pattern->getValueAsString("filename_in").str(); std::string included_file; auto contents = llvm::SrcMgr.OpenIncludeFile(filename_in, included_file); - //llvm::MemoryBuffer::getFile(filename_in, /*IsText=*/true); + // llvm::MemoryBuffer::getFile(filename_in, /*IsText=*/true); if (!contents) - PrintFatalError(pattern->getLoc(), Twine("Could not read file ") + filename_in); + PrintFatalError(pattern->getLoc(), + Twine("Could not read file ") + filename_in); os << "{\"" << filename_out << "\"\n,"; os << "R\"(" << contents.get()->getBuffer() << ")\"\n"; os << "}"; From 4e745df6fc5eb36cd351b75dc51c5f68f548bd93 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Tue, 27 Feb 2024 17:52:49 +0900 Subject: [PATCH 20/27] Check for MPFR --- .github/workflows/ccpp.yml | 1 + enzyme/CMakeLists.txt | 7 +++++++ enzyme/test/Integration/Truncate/truncate-all.cpp | 2 +- enzyme/test/lit.site.cfg.py.in | 6 ++++++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index 6efe084fdc07..1b5b293a24b7 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -27,6 +27,7 @@ jobs: - name: add llvm run: | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + sudo apt-get install -y libmpfr-dev sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev lld-${{ matrix.llvm }} clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev libzstd-dev sudo python3 -m pip install --upgrade pip lit diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 077f4f3a9554..2ab67dd3d2fb 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13) project(Enzyme) include(CMakePackageConfigHelpers) +include(CheckIncludeFile) +include(CheckIncludeFileCXX) set(ENZYME_MAJOR_VERSION 0) set(ENZYME_MINOR_VERSION 0) @@ -265,6 +267,11 @@ string(REPLACE "};\n}" "};\n}}" INPUT_TEXT "${INPUT_TEXT}") string(REPLACE "const SCEV* S;\n};\n" "const SCEV* S;\n};\n}\n" INPUT_TEXT "${INPUT_TEXT}") endif() +find_library(MPFR_LIB_PATH mpfr) +CHECK_INCLUDE_FILE("mpfr.h" HAS_MPFR_H) +message("MPFR lib: " ${MPFR_LIB_PATH}) +message("MPFR header: " ${HAS_MPFR_H}) + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/ScalarEvolutionExpander.h" "${INPUT_TEXT}") include_directories("${CMAKE_CURRENT_BINARY_DIR}/include") diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index 0e1fc11b6750..39e5965bda0d 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -12,7 +12,7 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli - | FileCheck --check-prefix TO_28_23 %s; fi // TO_28_23: 900000000.000000 -// RUN: if [ %llvmver -ge 12 ]; then %clang -DENZYME_TEST_TO_MPFR -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr; %s.a.out | FileCheck --check-prefix TO_3_7 %s; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DENZYME_TEST_TO_MPFR -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr; %s.a.out | FileCheck --check-prefix TO_3_7 %s; fi // TO_3_7: 897581056.000000 #include diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 0b8a0f831d6e..0cc5e6f28f38 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -16,6 +16,10 @@ config.llvm_shlib_ext = "@LLVM_SHLIBEXT@" config.targets_to_build = "@TARGETS_TO_BUILD@" +has_mpfr_h = "@HAS_MPFR_H@" +mpfr_lib_path = "@MPFR_LIB_PATH@" +has_mpfr = "yes" if mpfr_lib_path != "MPFR_LIB_PATH-NOTFOUND" and has_mpfr_h == "1" else "no" + ## Check the current platform with regex import re EAT_ERR_ON_X86 = ' ' @@ -112,6 +116,8 @@ if len("@ENZYME_BINARY_DIR@") == 0: config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < 15 else newPM)) config.substitutions.append(('%newLoadClangEnzyme', newPM)) +config.substitutions.append(('%hasMPFR', has_mpfr)) + # Let the main config do the real work. cfgfile = "@ENZYME_SOURCE_DIR@/test/lit.cfg.py" if len("@ENZYME_SOURCE_DIR@") == 0: From 61c3e32ac42bec371a1f99877048b291cfb6e45f Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Tue, 27 Feb 2024 18:28:31 +0900 Subject: [PATCH 21/27] Fix older llvm vers --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 299221aa9baf..67b2e48ab605 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1277,13 +1277,23 @@ static void emitHeaderIncludes(const RecordKeeper &recordKeeper, auto filename_out = pattern->getValueAsString("filename_out"); std::string filename_in = pattern->getValueAsString("filename_in").str(); std::string included_file; - auto contents = llvm::SrcMgr.OpenIncludeFile(filename_in, included_file); - // llvm::MemoryBuffer::getFile(filename_in, /*IsText=*/true); - if (!contents) +#if LLVM_VERSION_MAJOR >= 15 + auto contents_or_err = + llvm::SrcMgr.OpenIncludeFile(filename_in, included_file); + if (!contents_or_err) PrintFatalError(pattern->getLoc(), Twine("Could not read file ") + filename_in); + auto &contents = contents_or_err.get(); +#else + auto buf = llvm::SrcMgr.AddIncludeFile( + filename_in, pattern->getFieldLoc("filename_in"), included_file); + if (!buf) + PrintFatalError(pattern->getLoc(), + Twine("Could not read file ") + filename_in); + auto contents = llvm::SrcMgr.getMemoryBuffer(buf); +#endif os << "{\"" << filename_out << "\"\n,"; - os << "R\"(" << contents.get()->getBuffer() << ")\"\n"; + os << "R\"(" << contents->getBuffer() << ")\"\n"; os << "}"; seen = true; } From 259224c1b96c4f2564f21073ae6c2993ca9c3ebb Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Tue, 27 Feb 2024 21:24:46 +0900 Subject: [PATCH 22/27] llvm 11 --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 67b2e48ab605..e66d01dd9ab0 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1285,8 +1285,14 @@ static void emitHeaderIncludes(const RecordKeeper &recordKeeper, Twine("Could not read file ") + filename_in); auto &contents = contents_or_err.get(); #else - auto buf = llvm::SrcMgr.AddIncludeFile( - filename_in, pattern->getFieldLoc("filename_in"), included_file); + auto buf = + llvm::SrcMgr.AddIncludeFile(filename_in, +#if LLVM_VERSION_MAJOR >= 12 + pattern->getFieldLoc("filename_in"), +#else + pattern->getLoc()[1], +#endif + included_file); if (!buf) PrintFatalError(pattern->getLoc(), Twine("Could not read file ") + filename_in); From 6a06b797625ae19d3c673c83792d01b8e300c31a Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Tue, 27 Feb 2024 21:27:29 +0900 Subject: [PATCH 23/27] . --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index e66d01dd9ab0..328da7aa14b6 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1290,7 +1290,7 @@ static void emitHeaderIncludes(const RecordKeeper &recordKeeper, #if LLVM_VERSION_MAJOR >= 12 pattern->getFieldLoc("filename_in"), #else - pattern->getLoc()[1], + SMLoc::getFromPointer(nullptr), #endif included_file); if (!buf) From 5ba264e6810571aabb898d3c5ebb2de79233069b Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Wed, 28 Feb 2024 13:41:18 +0900 Subject: [PATCH 24/27] WIP deps --- enzyme/BUILD | 5 ++++- enzyme/Enzyme/CMakeLists.txt | 10 +++++++++- enzyme/Enzyme/Clang/include_utils.td | 2 +- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/enzyme/BUILD b/enzyme/BUILD index dc36cb4ad0c5..74612696a469 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -150,7 +150,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/Clang/include_utils.td", - td_srcs = ["Enzyme/Clang/include_utils.td"], + td_srcs = [ + "Enzyme/Clang/include_utils.td", + "Enzyme/Runtime/MPFR.cpp", + ], deps = [ ":enzyme-tblgen", ], diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index b27e4beb08cd..4ced425cf44f 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -38,8 +38,16 @@ add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) -enzyme_tablegen(IncludeUtils.inc -gen-header-strings) +enzyme_tablegen(IncludeUtils.inc -gen-header-strings -I${CMAKE_CURRENT_SOURCE_DIR}/Clang/) add_public_tablegen_target(IncludeUtilsIncGen) +# add_custom_target(RuntimeMPFRcpp +# OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/RuntimeMPFR.cpp" +# COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/Runtime/MPFR.cpp") + +# add_custom_target(RuntimeMPFRcpp DEPENDS Runtime/MPFR.cpp) +# add_dependencies(IncludeUtilsIncGen RuntimeMPFRcpp) +#target_sources(IncludeUtilsIncGen INTERFACE Runtime/MPFR.cpp) +set_property(TARGET IncludeUtilsIncGen APPEND PROPERTY OBJECT_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/Runtime/MPFR.cpp") include_directories(${CMAKE_CURRENT_BINARY_DIR}) diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index 6f57a243f765..a2d97508f2b8 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -462,4 +462,4 @@ def : InlineHeader<"/enzymeroot/enzyme/enzyme", [{ #endif }]>; -def : FileHeader<"/enzymeroot/enzyme/mpfr", "Runtime/MPFR.cpp">; +def : FileHeader<"/enzymeroot/enzyme/mpfr", "../Runtime/MPFR.cpp">; From 4eb5d42983d0f201c9398af4952a201d895d6edb Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Wed, 28 Feb 2024 13:46:59 +0900 Subject: [PATCH 25/27] Proper include --- enzyme/Enzyme/CMakeLists.txt | 8 -------- 1 file changed, 8 deletions(-) diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 4ced425cf44f..a55be9525bec 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -40,14 +40,6 @@ add_public_tablegen_target(BlasDiffUseIncGen) set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) enzyme_tablegen(IncludeUtils.inc -gen-header-strings -I${CMAKE_CURRENT_SOURCE_DIR}/Clang/) add_public_tablegen_target(IncludeUtilsIncGen) -# add_custom_target(RuntimeMPFRcpp -# OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/RuntimeMPFR.cpp" -# COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/Runtime/MPFR.cpp") - -# add_custom_target(RuntimeMPFRcpp DEPENDS Runtime/MPFR.cpp) -# add_dependencies(IncludeUtilsIncGen RuntimeMPFRcpp) -#target_sources(IncludeUtilsIncGen INTERFACE Runtime/MPFR.cpp) -set_property(TARGET IncludeUtilsIncGen APPEND PROPERTY OBJECT_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/Runtime/MPFR.cpp") include_directories(${CMAKE_CURRENT_BINARY_DIR}) From 7787b85b21e97559be1399446b427539266e3523 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Wed, 28 Feb 2024 14:46:15 +0900 Subject: [PATCH 26/27] Dep --- enzyme/Enzyme/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index a55be9525bec..7d42aa7d8742 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -38,8 +38,13 @@ add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) +# Need to explicitly set included files as dependencies +set(ARG_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/Runtime/MPFR.cpp" CACHE INTERNAL "deps") +# Cmake tablegen adds the current cmake dir to the include path and bazel adds +# the directory that contains the .td file, that's why we need the include here enzyme_tablegen(IncludeUtils.inc -gen-header-strings -I${CMAKE_CURRENT_SOURCE_DIR}/Clang/) add_public_tablegen_target(IncludeUtilsIncGen) +unset(ARG_DEPENDS) include_directories(${CMAKE_CURRENT_BINARY_DIR}) From f7fc98cb48880170044ca1898da10a661f832b9c Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Wed, 28 Feb 2024 15:17:24 +0900 Subject: [PATCH 27/27] Switch to inline header --- enzyme/BUILD | 5 +- enzyme/Enzyme/CMakeLists.txt | 7 +- enzyme/Enzyme/Clang/include_utils.td | 140 +++++++++++++++++-- enzyme/Enzyme/Runtime/MPFR.cpp | 122 ---------------- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 59 ++------ 5 files changed, 141 insertions(+), 192 deletions(-) delete mode 100644 enzyme/Enzyme/Runtime/MPFR.cpp diff --git a/enzyme/BUILD b/enzyme/BUILD index 74612696a469..dc36cb4ad0c5 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -150,10 +150,7 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/Clang/include_utils.td", - td_srcs = [ - "Enzyme/Clang/include_utils.td", - "Enzyme/Runtime/MPFR.cpp", - ], + td_srcs = ["Enzyme/Clang/include_utils.td"], deps = [ ":enzyme-tblgen", ], diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 7d42aa7d8742..b27e4beb08cd 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -38,13 +38,8 @@ add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) -# Need to explicitly set included files as dependencies -set(ARG_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/Runtime/MPFR.cpp" CACHE INTERNAL "deps") -# Cmake tablegen adds the current cmake dir to the include path and bazel adds -# the directory that contains the .td file, that's why we need the include here -enzyme_tablegen(IncludeUtils.inc -gen-header-strings -I${CMAKE_CURRENT_SOURCE_DIR}/Clang/) +enzyme_tablegen(IncludeUtils.inc -gen-header-strings) add_public_tablegen_target(IncludeUtilsIncGen) -unset(ARG_DEPENDS) include_directories(${CMAKE_CURRENT_BINARY_DIR}) diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index a2d97508f2b8..cb7cdd839c20 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -1,14 +1,9 @@ -class InlineHeader { +class Headers { string filename = filename_; string contents = contents_; } -class FileHeader { - string filename_out = filename_out_; - string filename_in = filename_in_; -} - -def : InlineHeader<"/enzymeroot/enzyme/utils", [{ +def : Headers<"/enzymeroot/enzyme/utils", [{ #pragma once extern int enzyme_dup; @@ -268,7 +263,7 @@ namespace enzyme { } }]>; -def : InlineHeader<"/enzymeroot/enzyme/type_traits", [{ +def : Headers<"/enzymeroot/enzyme/type_traits", [{ #pragma once #include @@ -317,7 +312,7 @@ namespace impl { } }]>; -def : InlineHeader<"/enzymeroot/enzyme/tuple", [{ +def : Headers<"/enzymeroot/enzyme/tuple", [{ #pragma once ///////////// @@ -454,7 +449,7 @@ constexpr auto tuple_cat(Tuples&&... tuples) { #undef _NOEXCEPT }]>; -def : InlineHeader<"/enzymeroot/enzyme/enzyme", [{ +def : Headers<"/enzymeroot/enzyme/enzyme", [{ #ifdef __cplusplus #include "enzyme/utils" #else @@ -462,4 +457,127 @@ def : InlineHeader<"/enzymeroot/enzyme/enzyme", [{ #endif }]>; -def : FileHeader<"/enzymeroot/enzyme/mpfr", "../Runtime/MPFR.cpp">; +def : Headers<"/enzymeroot/enzyme/mpfr", [{ +//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains easy to use wrappers around MPFR functions. +// +//===----------------------------------------------------------------------===// +#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +#define __ENZYME_RUNTIME_ENZYME_MPFR__ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// TODO s +// +// (for MPFR ver. 2.1) +// +// We need to set the range of the allowed exponent using `mpfr_set_emin` and +// `mpfr_set_emax`. (This means we can also play with whether the range is +// centered around 0 (1?) or somewhere else) +// +// (also these need to be mutex'ed as the exponent change is global in mpfr and +// not float-specific) ... (mpfr seems to have thread safe mode - check if it is +// enabled or if it is enabled by default) +// +// For that we need to do this check: +// If the user changes the exponent range, it is her/his responsibility to +// check that all current floating-point variables are in the new allowed +// range (for example using mpfr_check_range), otherwise the subsequent +// behavior will be undefined, in the sense of the ISO C standard. +// +// MPFR docs state the following: +// Note: Overflow handling is still experimental and currently implemented +// partially. If an overflow occurs internally at the wrong place, anything +// can happen (crash, wrong results, etc). +// +// Which we would like to avoid somehow. +// +// MPFR also has this limitation that we need to address for accurate +// simulation: +// [...] subnormal numbers are not implemented. +// + +#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ + MPFR_SET_ARG2, ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN +#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \ + double, d, double, d, ROUNDING_MODE) +#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \ + MPFR_FUNC_NAME) \ + __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) + +__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, d, + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +}]>; diff --git a/enzyme/Enzyme/Runtime/MPFR.cpp b/enzyme/Enzyme/Runtime/MPFR.cpp deleted file mode 100644 index dd148d2ebfe1..000000000000 --- a/enzyme/Enzyme/Runtime/MPFR.cpp +++ /dev/null @@ -1,122 +0,0 @@ -//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===// -// -// Enzyme Project -// -// Part of the Enzyme Project, under the Apache License v2.0 with LLVM -// Exceptions. See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, -// } -// -//===----------------------------------------------------------------------===// -// -// This file contains easy to use wrappers around MPFR functions. -// -//===----------------------------------------------------------------------===// -#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ -#define __ENZYME_RUNTIME_ENZYME_MPFR__ - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// TODO s -// -// (for MPFR ver. 2.1) -// -// We need to set the range of the allowed exponent using `mpfr_set_emin` and -// `mpfr_set_emax`. (This means we can also play with whether the range is -// centered around 0 (1?) or somewhere else) -// -// (also these need to be mutex'ed as the exponent change is global in mpfr and -// not float-specific) ... (mpfr seems to have thread safe mode - check if it is -// enabled or if it is enabled by default) -// -// For that we need to do this check: -// If the user changes the exponent range, it is her/his responsibility to -// check that all current floating-point variables are in the new allowed -// range (for example using mpfr_check_range), otherwise the subsequent -// behavior will be undefined, in the sense of the ISO C standard. -// -// MPFR docs state the following: -// Note: Overflow handling is still experimental and currently implemented -// partially. If an overflow occurs internally at the wrong place, anything -// can happen (crash, wrong results, etc). -// -// Which we would like to avoid somehow. -// -// MPFR also has this limitation that we need to address for accurate -// simulation: -// [...] subnormal numbers are not implemented. -// - -#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ - RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ - ROUNDING_MODE) \ - __attribute__((weak)) \ - RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, int64_t exponent, int64_t significand) { \ - mpfr_t ma, mc; \ - mpfr_init2(ma, significand); \ - mpfr_init2(mc, significand); \ - mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ - mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ - RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ - mpfr_clear(ma); \ - mpfr_clear(mc); \ - return c; \ - } - -#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ - RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ - MPFR_SET_ARG2, ROUNDING_MODE) \ - __attribute__((weak)) \ - RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ - mpfr_t ma, mb, mc; \ - mpfr_init2(ma, significand); \ - mpfr_init2(mb, significand); \ - mpfr_init2(mc, significand); \ - mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ - mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \ - mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ - RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ - mpfr_clear(ma); \ - mpfr_clear(mb); \ - mpfr_clear(mc); \ - return c; \ - } - -#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN -#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ - ROUNDING_MODE) \ - __ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \ - double, d, double, d, ROUNDING_MODE) -#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \ - MPFR_FUNC_NAME) \ - __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ - __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) - -__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) -__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) -__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) - -__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, d, - __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) - -#ifdef __cplusplus -} -#endif - -#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 328da7aa14b6..143c85ea684e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -16,7 +16,6 @@ #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/Path.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -1253,56 +1252,18 @@ void printDiffUse( static void emitHeaderIncludes(const RecordKeeper &recordKeeper, raw_ostream &os) { + const auto &patterns = recordKeeper.getAllDerivedDefinitions("Headers"); os << "const char* include_headers[][2] = {\n"; bool seen = false; - { - const auto &patterns = - recordKeeper.getAllDerivedDefinitions("InlineHeader"); - for (Record *pattern : patterns) { - if (seen) - os << ",\n"; - auto filename = pattern->getValueAsString("filename"); - auto contents = pattern->getValueAsString("contents"); - os << "{\"" << filename << "\"\n,"; - os << "R\"(" << contents << ")\"\n"; - os << "}"; - seen = true; - } - } - { - const auto &patterns = recordKeeper.getAllDerivedDefinitions("FileHeader"); - for (Record *pattern : patterns) { - if (seen) - os << ",\n"; - auto filename_out = pattern->getValueAsString("filename_out"); - std::string filename_in = pattern->getValueAsString("filename_in").str(); - std::string included_file; -#if LLVM_VERSION_MAJOR >= 15 - auto contents_or_err = - llvm::SrcMgr.OpenIncludeFile(filename_in, included_file); - if (!contents_or_err) - PrintFatalError(pattern->getLoc(), - Twine("Could not read file ") + filename_in); - auto &contents = contents_or_err.get(); -#else - auto buf = - llvm::SrcMgr.AddIncludeFile(filename_in, -#if LLVM_VERSION_MAJOR >= 12 - pattern->getFieldLoc("filename_in"), -#else - SMLoc::getFromPointer(nullptr), -#endif - included_file); - if (!buf) - PrintFatalError(pattern->getLoc(), - Twine("Could not read file ") + filename_in); - auto contents = llvm::SrcMgr.getMemoryBuffer(buf); -#endif - os << "{\"" << filename_out << "\"\n,"; - os << "R\"(" << contents->getBuffer() << ")\"\n"; - os << "}"; - seen = true; - } + for (Record *pattern : patterns) { + if (seen) + os << ",\n"; + auto filename = pattern->getValueAsString("filename"); + auto contents = pattern->getValueAsString("contents"); + os << "{\"" << filename << "\"\n,"; + os << "R\"(" << contents << ")\"\n"; + os << "}"; + seen = true; } os << "};\n"; }