From b97aa9d5f0a4264308d03901c81e733c13042106 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Wed, 28 Feb 2024 19:22:48 -0800 Subject: [PATCH] Trucation to MPFR (#1750) * WIP MPFR truncation * MPFR truncation * Fix mpfr function mangling * Mangling * MPFR Wrappers * clang-format * Make header work in C * File header * Make it compile on llvm 11 * header * fix tests * Add TODO comment * more comments * MPFR header fix * Add mpfr test * Move mpfr runtime * Add another type of include header * fix tests * clang-format * Check for MPFR * Fix older llvm vers * llvm 11 * . * WIP deps * Proper include * Dep * Switch to inline header --- .github/workflows/ccpp.yml | 1 + enzyme/CMakeLists.txt | 7 + enzyme/Enzyme/Clang/include_utils.td | 125 +++++++ enzyme/Enzyme/Enzyme.cpp | 51 ++- enzyme/Enzyme/EnzymeLogic.cpp | 308 +++++++++--------- enzyme/Enzyme/EnzymeLogic.h | 68 +++- enzyme/test/Enzyme/Truncate/cmp.ll | 15 +- enzyme/test/Enzyme/Truncate/intrinsic.ll | 129 +++++--- enzyme/test/Enzyme/Truncate/select.ll | 6 +- enzyme/test/Enzyme/Truncate/simple.ll | 29 +- .../Integration/Truncate/truncate-all.cpp | 20 +- enzyme/test/lit.site.cfg.py.in | 6 + 12 files changed, 520 insertions(+), 245 deletions(-) diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index 6efe084fdc07..1b5b293a24b7 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -27,6 +27,7 @@ jobs: - name: add llvm run: | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + sudo apt-get install -y libmpfr-dev sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev lld-${{ matrix.llvm }} clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev libzstd-dev sudo python3 -m pip install --upgrade pip lit diff --git a/enzyme/CMakeLists.txt 
b/enzyme/CMakeLists.txt index 077f4f3a9554..2ab67dd3d2fb 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13) project(Enzyme) include(CMakePackageConfigHelpers) +include(CheckIncludeFile) +include(CheckIncludeFileCXX) set(ENZYME_MAJOR_VERSION 0) set(ENZYME_MINOR_VERSION 0) @@ -265,6 +267,11 @@ string(REPLACE "};\n}" "};\n}}" INPUT_TEXT "${INPUT_TEXT}") string(REPLACE "const SCEV* S;\n};\n" "const SCEV* S;\n};\n}\n" INPUT_TEXT "${INPUT_TEXT}") endif() +find_library(MPFR_LIB_PATH mpfr) +CHECK_INCLUDE_FILE("mpfr.h" HAS_MPFR_H) +message("MPFR lib: " ${MPFR_LIB_PATH}) +message("MPFR header: " ${HAS_MPFR_H}) + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/ScalarEvolutionExpander.h" "${INPUT_TEXT}") include_directories("${CMAKE_CURRENT_BINARY_DIR}/include") diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index 1c99d219ce69..cb7cdd839c20 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -456,3 +456,128 @@ def : Headers<"/enzymeroot/enzyme/enzyme", [{ #warning "Enzyme wrapper templates only available in C++" #endif }]>; + +def : Headers<"/enzymeroot/enzyme/mpfr", [{ +//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. 
and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains easy to use wrappers around MPFR functions. +// +//===----------------------------------------------------------------------===// +#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +#define __ENZYME_RUNTIME_ENZYME_MPFR__ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// TODO s +// +// (for MPFR ver. 2.1) +// +// We need to set the range of the allowed exponent using `mpfr_set_emin` and +// `mpfr_set_emax`. (This means we can also play with whether the range is +// centered around 0 (1?) or somewhere else) +// +// (also these need to be mutex'ed as the exponent change is global in mpfr and +// not float-specific) ... (mpfr seems to have thread safe mode - check if it is +// enabled or if it is enabled by default) +// +// For that we need to do this check: +// If the user changes the exponent range, it is her/his responsibility to +// check that all current floating-point variables are in the new allowed +// range (for example using mpfr_check_range), otherwise the subsequent +// behavior will be undefined, in the sense of the ISO C standard. +// +// MPFR docs state the following: +// Note: Overflow handling is still experimental and currently implemented +// partially. If an overflow occurs internally at the wrong place, anything +// can happen (crash, wrong results, etc). +// +// Which we would like to avoid somehow. +// +// MPFR also has this limitation that we need to address for accurate +// simulation: +// [...] subnormal numbers are not implemented. 
+// + +#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ + MPFR_SET_ARG2, ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN +#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \ + double, d, double, d, ROUNDING_MODE) +#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \ + MPFR_FUNC_NAME) \ + __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) + +__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, 
d, + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +}]>; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 70f173f5734a..4ab84bd6e67e 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -58,6 +58,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -1339,21 +1340,40 @@ class EnzymeBase { Function *F = parseFunctionParameter(CI); if (!F) return false; - if (CI->arg_size() != 3) { + unsigned ArgSize = CI->arg_size(); + if (ArgSize != 4 && ArgSize != 3) { EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, "Had incorrect number of args to __enzyme_truncate_func", *CI, - " - expected 3"); + " - expected 3 or 4"); return false; } - auto Cfrom = cast(CI->getArgOperand(1)); - assert(Cfrom); - auto Cto = cast(CI->getArgOperand(2)); - assert(Cto); + FloatTruncation truncation = [&]() -> FloatTruncation { + if (ArgSize == 3) { + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto = cast(CI->getArgOperand(2)); + assert(Cto); + return FloatTruncation( + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue())); + } else if (ArgSize == 4) { + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto_exponent = cast(CI->getArgOperand(2)); + assert(Cto_exponent); + auto Cto_significand = cast(CI->getArgOperand(3)); + assert(Cto_significand); + return FloatTruncation( + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + FloatRepresentation( + (unsigned)Cto_exponent->getValue().getZExtValue(), + (unsigned)Cto_significand->getValue().getZExtValue())); + } + llvm_unreachable("??"); + }(); + RequestContext context(CI, &Builder); - llvm::Value *res = Logic.CreateTruncateFunc( - context, F, - 
getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), - getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), mode); + llvm::Value *res = Logic.CreateTruncateFunc(context, F, truncation, mode); if (!res) return false; res = Builder.CreatePointerCast(res, CI->getType()); @@ -2052,14 +2072,12 @@ class EnzymeBase { } bool handleFullModuleTrunc(Function &F) { - typedef std::vector> - TruncationsTy; + typedef std::vector TruncationsTy; static TruncationsTy FullModuleTruncs = []() -> TruncationsTy { StringRef ConfigStr(EnzymeTruncateAll); auto Invalid = [=]() { // TODO emit better diagnostic - llvm::errs() << "error: invalid format for truncation config\n"; - abort(); + llvm::report_fatal_error("error: invalid format for truncation config"); }; // "64" or "11-52" @@ -2102,9 +2120,8 @@ class EnzymeBase { for (auto Truncation : FullModuleTruncs) { IRBuilder<> Builder(F.getContext()); RequestContext context(&*F.getEntryBlock().begin(), &Builder); - Function *TruncatedFunc = - Logic.CreateTruncateFunc(context, &F, Truncation.first, - Truncation.second, TruncOpFullModuleMode); + Function *TruncatedFunc = Logic.CreateTruncateFunc( + context, &F, Truncation, TruncOpFullModuleMode); ValueToValueMapTy Mapping; for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args())) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index f930d4e1375b..f8fdf3b3124a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -32,7 +32,13 @@ #include "AdjointGenerator.h" #include "EnzymeLogic.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/ErrorHandling.h" +#include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -4956,30 +4962,28 @@ Function *EnzymeLogic::CreateForwardDiff( } static Value *floatValTruncate(IRBuilderBase &B, Value *v, 
Value *tmpBlock, - FloatRepresentation from, - FloatRepresentation to) { - Type *toTy = to.getType(B.getContext()); + FloatTruncation truncation) { + Type *toTy = truncation.getToType(B.getContext()); if (auto vty = dyn_cast(v->getType())) toTy = VectorType::get(toTy, vty->getElementCount()); return B.CreateFPTrunc(v, toTy, "enzyme_trunc"); } static Value *floatValExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, FloatRepresentation to) { - Type *fromTy = from.getBuiltinType(B.getContext()); + FloatTruncation truncation) { + Type *fromTy = truncation.getFromType(B.getContext()); if (auto vty = dyn_cast(v->getType())) fromTy = VectorType::get(fromTy, vty->getElementCount()); return B.CreateFPExt(v, fromTy, "enzyme_exp"); } static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, - FloatRepresentation to) { + FloatTruncation truncation) { if (isa(v->getType())) report_fatal_error("vector operations not allowed in mem trunc mode"); - Type *fromTy = from.getBuiltinType(B.getContext()); - Type *toTy = to.getType(B.getContext()); + Type *fromTy = truncation.getFromType(B.getContext()); + Type *toTy = truncation.getToType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); B.CreateStore( @@ -4989,15 +4993,15 @@ static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, } static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, FloatRepresentation to) { + FloatTruncation truncation) { if (isa(v->getType())) report_fatal_error("vector operations not allowed in mem trunc mode"); - Type *fromTy = from.getBuiltinType(B.getContext()); + Type *fromTy = truncation.getFromType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); auto c0 = Constant::getNullValue( - llvm::Type::getIntNTy(B.getContext(), from.getTypeWidth())); + llvm::Type::getIntNTy(B.getContext(), truncation.getFromTypeWidth())); B.CreateStore( c0, 
B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); B.CreateStore( @@ -5009,8 +5013,7 @@ static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, class TruncateGenerator : public llvm::InstVisitor { private: ValueToValueMapTy &originalToNewFn; - FloatRepresentation from; - FloatRepresentation to; + FloatTruncation truncation; Type *fromType; Type *toType; Function *oldFunc; @@ -5018,23 +5021,37 @@ class TruncateGenerator : public llvm::InstVisitor { AllocaInst *tmpBlock; TruncateMode mode; EnzymeLogic &Logic; + LLVMContext &ctx; public: TruncateGenerator(ValueToValueMapTy &originalToNewFn, - FloatRepresentation from, FloatRepresentation to, - Function *oldFunc, Function *newFunc, TruncateMode mode, - EnzymeLogic &Logic) - : originalToNewFn(originalToNewFn), from(from), to(to), oldFunc(oldFunc), - newFunc(newFunc), mode(mode), Logic(Logic) { + FloatTruncation truncation, Function *oldFunc, + Function *newFunc, TruncateMode mode, EnzymeLogic &Logic) + : originalToNewFn(originalToNewFn), truncation(truncation), + oldFunc(oldFunc), newFunc(newFunc), mode(mode), Logic(Logic), + ctx(newFunc->getContext()) { IRBuilder<> B(&newFunc->getEntryBlock().front()); - fromType = from.getBuiltinType(B.getContext()); - toType = to.getType(B.getContext()); + fromType = truncation.getFromType(ctx); + toType = truncation.getToType(ctx); + if (fromType == toType) + assert(truncation.isToMPFR()); if (mode == TruncMemMode) tmpBlock = B.CreateAlloca(fromType); else tmpBlock = nullptr; + + if (truncation.isToMPFR()) { + switch (mode) { + case TruncMemMode: + llvm::report_fatal_error( + "truncation to MPFR not supported in memory mode."); + case TruncOpMode: + case TruncOpFullModuleMode: + break; + } + } } void checkHandled(llvm::Instruction &inst) { @@ -5065,25 +5082,26 @@ class TruncateGenerator : public llvm::InstVisitor { Value *truncate(IRBuilder<> &B, Value *v) { switch (mode) { case TruncMemMode: - return floatMemTruncate(B, v, tmpBlock, from, to); 
+ assert(!truncation.isToMPFR()); + return floatMemTruncate(B, v, tmpBlock, truncation); case TruncOpMode: case TruncOpFullModuleMode: - return floatValTruncate(B, v, tmpBlock, from, to); - default: - llvm_unreachable("Unknown trunc mode"); + if (truncation.isToMPFR()) + return v; + return floatValTruncate(B, v, tmpBlock, truncation); } + llvm_unreachable("Unknown trunc mode"); } Value *expand(IRBuilder<> &B, Value *v) { switch (mode) { case TruncMemMode: - return floatMemExpand(B, v, tmpBlock, from, to); + return floatMemExpand(B, v, tmpBlock, truncation); case TruncOpMode: case TruncOpFullModuleMode: - return floatValExpand(B, v, tmpBlock, from, to); - default: - llvm_unreachable("Unknown trunc mode"); + return floatValExpand(B, v, tmpBlock, truncation); } + llvm_unreachable("Unknown trunc mode"); } void todo(llvm::Instruction &I) { @@ -5129,26 +5147,35 @@ class TruncateGenerator : public llvm::InstVisitor { void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { - Value *newCI = nullptr; - auto newI = getNewFromOriginal(&CI); - std::string oldName = CI.getName().str(); - newI->setName(""); - if (CI.getSrcTy() == getFromType()) { - IRBuilder<> B(newI); - newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), - CI.getDestTy(), oldName); - } - if (CI.getDestTy() == getToType()) { + switch (mode) { + case TruncMemMode: { + Value *newCI = nullptr; auto newI = getNewFromOriginal(&CI); - IRBuilder<> B(newI); - newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), - CI.getDestTy(), oldName); + std::string oldName = CI.getName().str(); + newI->setName(""); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + newCI = + B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (CI.getDestTy() == getToType()) { + auto newI = getNewFromOriginal(&CI); + IRBuilder<> B(newI); + 
newCI = + B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (newCI) { + newI->replaceAllUsesWith(newCI); + newI->eraseFromParent(); + } + return; } - if (newCI) { - newI->replaceAllUsesWith(newCI); - newI->eraseFromParent(); + case TruncOpMode: + case TruncOpFullModuleMode: + return; } - return; } void visitSelectInst(llvm::SelectInst &SI) { switch (mode) { @@ -5168,16 +5195,61 @@ class TruncateGenerator : public llvm::InstVisitor { case TruncOpMode: case TruncOpFullModuleMode: return; - default: - llvm_unreachable(""); } + llvm_unreachable(""); } void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } void visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { return; } void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } + CallInst *createMPFRCall(llvm::IRBuilder<> &B, llvm::Instruction &I, + llvm::Type *RetTy, + SmallVectorImpl &ArgsIn) { + std::string Name; + if (auto BO = dyn_cast(&I)) { + Name = "binop_" + std::string(BO->getOpcodeName()); + } else if (auto II = dyn_cast(&I)) { + auto FOp = II->getCalledFunction(); + assert(FOp); + Name = "intr_" + std::string(FOp->getName()); + for (auto &C : Name) + if (C == '.') + C = '_'; + } else if (auto CI = dyn_cast(&I)) { + if (auto F = CI->getCalledFunction()) + Name = "func_" + std::string(F->getName()); + else + llvm_unreachable( + "Unexpected indirect call inst for conversion to MPFR"); + } else { + llvm_unreachable("Unexpected instruction for conversion to MPFR"); + } + + std::string MangledName = + std::string("__enzyme_mpfr_") + truncation.mangleFrom() + "_" + Name; + auto F = newFunc->getParent()->getFunction(MangledName); + SmallVector Args(ArgsIn.begin(), ArgsIn.end()); + Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); + 
Args.push_back(B.getInt64(truncation.getTo().significandWidth)); + if (!F) { + SmallVector ArgTypes; + for (auto Arg : Args) + ArgTypes.push_back(Arg->getType()); + FunctionType *FnTy = + FunctionType::get(RetTy, ArgTypes, /*is_vararg*/ false); + F = Function::Create(FnTy, Function::ExternalLinkage, MangledName, + newFunc->getParent()); + } + return cast(B.CreateCall(F, Args)); + } void visitBinaryOperator(llvm::BinaryOperator &BO) { + auto oldLHS = BO.getOperand(0); + auto oldRHS = BO.getOperand(1); + + if (oldLHS->getType() != getFromType() && + oldRHS->getType() != getFromType()) + return; switch (BO.getOpcode()) { default: @@ -5195,60 +5267,25 @@ class TruncateGenerator : public llvm::InstVisitor { case BinaryOperator::And: case BinaryOperator::Or: case BinaryOperator::Xor: + assert(0 && "Invalid binop opcode for float arg"); return; } - if (to.getBuiltinType(BO.getContext())) { - auto newI = getNewFromOriginal(&BO); - IRBuilder<> B(newI); - auto newLHS = truncate(B, getNewFromOriginal(BO.getOperand(0))); - auto newRHS = truncate(B, getNewFromOriginal(BO.getOperand(1))); - switch (BO.getOpcode()) { - default: - break; - case BinaryOperator::FMul: { - auto nres = cast(B.CreateFMul(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FAdd: { - auto nres = cast(B.CreateFAdd(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FSub: { - auto nres = cast(B.CreateFSub(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FDiv: { - auto nres = cast(B.CreateFDiv(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - 
} - return; - case BinaryOperator::FRem: { - auto nres = cast(B.CreateFRem(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - } + auto newI = getNewFromOriginal(&BO); + IRBuilder<> B(newI); + auto newLHS = truncate(B, getNewFromOriginal(oldLHS)); + auto newRHS = truncate(B, getNewFromOriginal(oldRHS)); + Instruction *nres = nullptr; + if (truncation.isToMPFR()) { + SmallVector Args({newLHS, newRHS}); + nres = createMPFRCall(B, BO, truncation.getToType(ctx), Args); + } else { + nres = cast(B.CreateBinOp(BO.getOpcode(), newLHS, newRHS)); } - todo(BO); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); return; } void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); } @@ -5271,13 +5308,14 @@ class TruncateGenerator : public llvm::InstVisitor { void visitFenceInst(llvm::FenceInst &FI) { return; } bool handleIntrinsic(llvm::CallInst &CI, Intrinsic::ID ID) { + auto newI = cast(getNewFromOriginal(&CI)); + IRBuilder<> B(newI); + SmallVector orig_ops(CI.arg_size()); for (unsigned i = 0; i < CI.arg_size(); ++i) orig_ops[i] = CI.getOperand(i); bool hasFromType = false; - auto newI = cast(getNewFromOriginal(&CI)); - IRBuilder<> B(newI); SmallVector new_ops(CI.arg_size()); for (unsigned i = 0; i < CI.arg_size(); ++i) { if (orig_ops[i]->getType() == getFromType()) { @@ -5296,12 +5334,16 @@ class TruncateGenerator : public llvm::InstVisitor { if (!hasFromType) return false; - // TODO check that the intrinsic is overloaded - - CallInst *intr; - Value *nres = intr = - createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); - if (CI.getType() == getFromType()) + Instruction *intr = nullptr; + Value *nres = nullptr; + if (truncation.isToMPFR()) { + nres = intr = createMPFRCall(B, CI, retTy, new_ops); + } else { + // TODO check that the intrinsic is overloaded + nres = intr = + 
createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); + } + if (newI->getType() == getFromType()) nres = expand(B, nres); intr->copyIRFlags(newI); newI->replaceAllUsesWith(nres); @@ -5390,7 +5432,7 @@ class TruncateGenerator : public llvm::InstVisitor { Value *GetShadow(RequestContext &ctx, Value *v) { if (auto F = dyn_cast(v)) - return Logic.CreateTruncateFunc(ctx, F, from, to, mode); + return Logic.CreateTruncateFunc(ctx, F, truncation, mode); llvm::errs() << " unknown get truncated func: " << *v << "\n"; llvm_unreachable("unknown get truncated func"); return v; @@ -5457,10 +5499,11 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, Value *converted = nullptr; if (isTruncate) - converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, from, to); + converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, + FloatTruncation(from, to)); else - converted = - B.CreateFPExt(floatMemTruncate(B, v, nullptr, from, to), fromTy); + converted = B.CreateFPExt( + floatMemTruncate(B, v, nullptr, FloatTruncation(from, to)), fromTy); assert(converted); context.req->replaceAllUsesWith(converted); @@ -5471,13 +5514,9 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm::Function *totrunc, - FloatRepresentation from, - FloatRepresentation to, + FloatTruncation truncation, TruncateMode mode) { - if (from == to) - return totrunc; - - TruncateCacheKey tup(totrunc, from, to, mode); + TruncateCacheKey tup(totrunc, truncation, mode); if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { return TruncateCachedFunctions.find(tup)->second; } @@ -5492,10 +5531,9 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, Type *NewTy = totrunc->getReturnType(); FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); - std::string truncName = std::string("__enzyme_done_truncate_") + - (mode == 
TruncMemMode ? "mem" : "op") + "_func_" + - from.to_string() + "_" + to.to_string() + "_" + - totrunc->getName().str(); + std::string truncName = + std::string("__enzyme_done_truncate_") + truncateModeStr(mode) + + "_func_" + truncation.mangleTruncation() + "_" + totrunc->getName().str(); Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, totrunc->getParent()); @@ -5530,34 +5568,6 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm_unreachable("attempting to truncate function without definition"); } - // TODO This is overloaded an doesnt do what it should do here - if (from < to) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "Cannot truncate into a large width\n"; - llvm::Value *toshow = totrunc; - if (context.req) { - toshow = context.req; - ss << " at context: " << *context.req; - } else { - ss << *totrunc << "\n"; - } - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(toshow), - ErrorType::NoDerivative, nullptr, wrap(totrunc), - wrap(context.ip)); - return NewF; - } - if (context.req) { - EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, - ss.str()); - return NewF; - } - llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; - llvm::errs() << *totrunc << "\n"; - llvm_unreachable("attempting to truncate function without definition"); - } - ValueToValueMapTy originalToNewFn; for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); @@ -5579,7 +5589,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - TruncateGenerator handle(originalToNewFn, from, to, totrunc, NewF, mode, + TruncateGenerator handle(originalToNewFn, truncation, totrunc, NewF, mode, *this); for (auto &BB : *totrunc) for (auto &I : BB) diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 1e9bf216b6e9..4bb61e94c8ef 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h 
@@ -42,6 +42,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" #include "ActivityAnalysis.h" #include "FunctionUtils.h" @@ -287,6 +288,17 @@ getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { } enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; +[[maybe_unused]] static const char *truncateModeStr(TruncateMode mode) { + switch (mode) { + case TruncMemMode: + return "mem"; + case TruncOpMode: + return "op"; + case TruncOpFullModuleMode: + return "op_full_module"; + } + llvm_unreachable("Invalid truncation mode"); +} struct FloatRepresentation { // |_|__________|_________________| @@ -336,6 +348,54 @@ struct FloatRepresentation { } }; +struct FloatTruncation { +private: + FloatRepresentation from, to; + +public: + FloatTruncation(FloatRepresentation From, FloatRepresentation To) + : from(From), to(To) { + if (!From.canBeBuiltin()) + llvm::report_fatal_error("Float truncation `from` type is not builtin."); + if (From.exponentWidth < To.exponentWidth) + llvm::report_fatal_error("Float truncation `from` type must have " + "a wider exponent than `to`."); + if (From.significandWidth < To.significandWidth) + llvm::report_fatal_error("Float truncation `from` type must have " + "a wider wsignificand than `to`."); + if (From == To) + llvm::report_fatal_error( + "Float truncation `from` and `to` type must not be the same."); + } + FloatRepresentation getTo() { return to; } + unsigned getFromTypeWidth() { return from.getTypeWidth(); } + unsigned getToTypeWidth() { return to.getTypeWidth(); } + llvm::Type *getFromType(llvm::LLVMContext &ctx) { + return from.getBuiltinType(ctx); + } + bool isToMPFR() { return !to.canBeBuiltin(); } + llvm::Type *getToType(llvm::LLVMContext &ctx) { + if (to.canBeBuiltin()) { + return to.getBuiltinType(ctx); + } else { + assert(isToMPFR()); + // Currently we do not support TruncMemMode for MPFR, and we provide + // runtime wrappers 
around MPFR for each builtin `from` type + return from.getBuiltinType(ctx); + } + } + bool operator==(const FloatTruncation &other) const { + return from == other.from && to == other.to; + } + bool operator<(const FloatTruncation &other) const { + return std::tuple(from, to) < std::tuple(other.from, other.to); + } + std::string mangleTruncation() const { + return from.to_string() + "to" + to.to_string(); + } + std::string mangleFrom() const { return from.to_string(); } +}; + class EnzymeLogic { public: PreProcessCache PPC; @@ -583,13 +643,13 @@ class EnzymeLogic { llvm::ArrayRef arg_types, BATCH_TYPE ret_type); - using TruncateCacheKey = std::tuple; + using TruncateCacheKey = + std::tuple; std::map TruncateCachedFunctions; llvm::Function *CreateTruncateFunc(RequestContext context, llvm::Function *tobatch, - FloatRepresentation from, - FloatRepresentation to, TruncateMode mode); + FloatTruncation truncation, + TruncateMode mode); bool CreateTruncateValue(RequestContext context, llvm::Value *addr, FloatRepresentation from, FloatRepresentation to, bool isTruncate); diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index c96efa70660a..68f0ef473a9b 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -21,13 +21,14 @@ entry: %res = call i1 %ptr(double %x, double %y) ret i1 %res } +define i1 @tester_op_mpfr(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) 
@__enzyme_truncate_op_func(i1 (double, double)* @f, i64 64, i64 3, i64 7) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} -; CHECK: define i1 @tester(double %x, double %y) { -; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) -; CHECK-NEXT: ret i1 %res - -; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -38,7 +39,7 @@ entry: ; CHECK-DAG: %res = fcmp olt float %3, %5 ; CHECK-DAG: ret i1 %res -; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float ; CHECK-DAG: %res = fcmp olt float %enzyme_trunc, %enzyme_trunc1 diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index 99568539c3f3..2299c9fb1ab3 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -1,11 +1,13 @@ ; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi ; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi +declare double @pow(double %Val, double %Power) declare double @llvm.pow.f64(double %Val, double %Power) declare double @llvm.powi.f64.i16(double %Val, i16 %power) declare void @llvm.nvvm.barrier0() define double @f(double %x, double %y) { + %res0 = call double @pow(double %x, double %y) %res1 = call double @llvm.pow.f64(double %x, double %y) %res2 = call double 
@llvm.powi.f64.i16(double %x, i16 2) %res = fadd double %res1, %res2 @@ -22,62 +24,93 @@ entry: %res = call double %ptr(double %x, double %y) ret double %res } -define double @tester2(double %x, double %y) { +define double @tester_op(double %x, double %y) { entry: %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y) ret double %res } +define double @tester_op_mpfr(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 3, i64 7) + %res = call double %ptr(double %x, double %y) + ret double %res +} -; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { -; CHECK-NEXT: %1 = alloca double, align 8 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %2 = bitcast double* %1 to float* -; CHECK-NEXT: %3 = load float, float* %2, align 4 -; CHECK-NEXT: store double %y, double* %1, align 8 -; CHECK-NEXT: %4 = bitcast double* %1 to float* -; CHECK-NEXT: %5 = load float, float* %4, align 4 -; CHECK-NEXT: %res11 = call float @llvm.pow.f32(float %3, float %5) -; CHECK-NEXT: %6 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %6, align 4 -; CHECK-NEXT: %7 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res11, float* %7, align 4 -; CHECK-NEXT: %8 = load double, double* %1, align 8 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %9 = bitcast double* %1 to float* -; CHECK-NEXT: %10 = load float, float* %9, align 4 -; CHECK-NEXT: %res22 = call float @llvm.powi.f32.i16(float %10, i16 2) -; CHECK-NEXT: %11 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %11, align 4 -; CHECK-NEXT: %12 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res22, float* %12, align 4 -; CHECK-NEXT: %13 = load double, double* %1, align 8 -; CHECK-NEXT: store double %8, double* %1, align 8 -; 
CHECK-NEXT: %14 = bitcast double* %1 to float* -; CHECK-NEXT: %15 = load float, float* %14, align 4 -; CHECK-NEXT: store double %13, double* %1, align 8 -; CHECK-NEXT: %16 = bitcast double* %1 to float* -; CHECK-NEXT: %17 = load float, float* %16, align 4 -; CHECK-NEXT: %res = fadd float %15, %17 -; CHECK-NEXT: %18 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %18, align 4 -; CHECK-NEXT: %19 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res, float* %19, align 4 -; CHECK-NEXT: %20 = load double, double* %1, align 8 -; CHECK-NEXT: call void @llvm.nvvm.barrier0() -; CHECK-NEXT: ret double %20 +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %res01 = call float @llvm.pow.f32(float %3, float %5) +; CHECK-DAG: %6 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %6, align 4 +; CHECK-DAG: %7 = bitcast double* %1 to float* +; CHECK-DAG: store float %res01, float* %7, align 4 +; CHECK-DAG: %8 = load double, double* %1, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %9 = bitcast double* %1 to float* +; CHECK-DAG: %10 = load float, float* %9, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %11 = bitcast double* %1 to float* +; CHECK-DAG: %12 = load float, float* %11, align 4 +; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %10, float %12) +; CHECK-DAG: %13 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %13, align 4 +; CHECK-DAG: %14 = bitcast double* %1 to float* +; CHECK-DAG: store float %res12, float* %14, align 4 +; CHECK-DAG: %15 = load double, double* %1, align 
8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %16 = bitcast double* %1 to float* +; CHECK-DAG: %17 = load float, float* %16, align 4 +; CHECK-DAG: %res23 = call float @llvm.powi.f32.i16(float %17, i16 2) +; CHECK-DAG: %18 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %18, align 4 +; CHECK-DAG: %19 = bitcast double* %1 to float* +; CHECK-DAG: store float %res23, float* %19, align 4 +; CHECK-DAG: %20 = load double, double* %1, align 8 +; CHECK-DAG: store double %15, double* %1, align 8 +; CHECK-DAG: %21 = bitcast double* %1 to float* +; CHECK-DAG: %22 = load float, float* %21, align 4 +; CHECK-DAG: store double %20, double* %1, align 8 +; CHECK-DAG: %23 = bitcast double* %1 to float* +; CHECK-DAG: %24 = load float, float* %23, align 4 +; CHECK-DAG: %res = fadd float %22, %24 +; CHECK-DAG: %25 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %25, align 4 +; CHECK-DAG: %26 = bitcast double* %1 to float* +; CHECK-DAG: store float %res, float* %26, align 4 +; CHECK-DAG: %27 = load double, double* %1, align 8 +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %27 -; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float -; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) -; CHECK-DAG: %enzyme_exp = fpext float %res12 to double +; CHECK-DAG: %res02 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) +; CHECK-DAG: %enzyme_exp = fpext float %res02 to double ; CHECK-DAG: %enzyme_trunc3 = fptrunc double %x to float -; CHECK-DAG: %res24 = call float @llvm.powi.f32.i16(float %enzyme_trunc3, i16 2) -; CHECK-DAG: %enzyme_exp5 = fpext float %res24 to double -; CHECK-DAG: %enzyme_trunc6 = fptrunc double 
%enzyme_exp to float -; CHECK-DAG: %enzyme_trunc7 = fptrunc double %enzyme_exp5 to float -; CHECK-DAG: %res = fadd float %enzyme_trunc6, %enzyme_trunc7 -; CHECK-DAG: %enzyme_exp8 = fpext float %res to double +; CHECK-DAG: %enzyme_trunc4 = fptrunc double %y to float +; CHECK-DAG: %res15 = call float @llvm.pow.f32(float %enzyme_trunc3, float %enzyme_trunc4) +; CHECK-DAG: %enzyme_exp6 = fpext float %res15 to double +; CHECK-DAG: %enzyme_trunc7 = fptrunc double %x to float +; CHECK-DAG: %res28 = call float @llvm.powi.f32.i16(float %enzyme_trunc7, i16 2) +; CHECK-DAG: %enzyme_exp9 = fpext float %res28 to double +; CHECK-DAG: %enzyme_trunc10 = fptrunc double %enzyme_exp6 to float +; CHECK-DAG: %enzyme_trunc11 = fptrunc double %enzyme_exp9 to float +; CHECK-DAG: %res = fadd float %enzyme_trunc10, %enzyme_trunc11 +; CHECK-DAG: %enzyme_exp12 = fpext float %res to double +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %enzyme_exp12 + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x, double %y) { +; CHECK-DAG: %1 = call double @__enzyme_mpfr_64_52_func_pow(double %x, double %y, i64 3, i64 7) +; CHECK-DAG: %2 = call double @__enzyme_mpfr_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7) +; CHECK-DAG: %3 = call double @__enzyme_mpfr_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7) +; CHECK-DAG: %res = call double @__enzyme_mpfr_64_52_binop_fadd(double %2, double %3, i64 3, i64 7) ; CHECK-DAG: call void @llvm.nvvm.barrier0() -; CHECK-DAG: ret double %enzyme_exp8 +; CHECK-DAG: ret double %res +; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll index 365d21ab5913..afc41219fed8 100644 --- a/enzyme/test/Enzyme/Truncate/select.ll +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -25,10 +25,10 @@ entry: ; CHECK: define double @tester(double %x, double %y, i1 %cond) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call double 
@__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) +; CHECK-NEXT: %res = call double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y, i1 %cond) ; CHECK-NEXT: ret double %res -; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -44,6 +44,6 @@ entry: ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: ret double %8 -; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %res = select i1 %cond, double %x, double %y ; CHECK-DAG: ret double %res diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 19d6cf1f3a23..a57f33fcdfdb 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -17,25 +17,20 @@ entry: call void %ptr(double* %data) ret void } - -define void @tester2(double* %data) { +define void @tester_op(double* %data) { entry: %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 32) call void %ptr(double* %data) ret void } +define void @tester_op_mpfr(double* %data) { +entry: + %ptr = call void (double*)* (...) 
@__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 3, i64 7) + call void %ptr(double* %data) + ret void +} -; CHECK: define void @tester(double* %data) -; CHECK-NEXT: entry: -; CHECK-NEXT: call void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %data) -; CHECK-NEXT: ret void - -; CHECK: define void @tester2(double* %data) { -; CHECK-NEXT: entry: -; CHECK-NEXT: call void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %data) -; CHECK-NEXT: ret void - -; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %x) +; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52to32_23_f(double* %x) ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: store double %y, double* %1, align 8 @@ -53,7 +48,7 @@ entry: ; CHECK-DAG: store double %8, double* %x, align 8 ; CHECK-DAG: ret void -; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %x) { +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: %enzyme_trunc = fptrunc double %y to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float @@ -61,3 +56,9 @@ entry: ; CHECK-DAG: %enzyme_exp = fpext float %m to double ; CHECK-DAG: store double %enzyme_exp, double* %x, align 8 ; CHECK-DAG: ret void + +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to11_7_f(double* %x) { +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: %m = call double @__enzyme_mpfr_64_52_binop_fmul(double %y, double %y, i64 3, i64 7) +; CHECK-DAG: store double %m, double* %x, align 8 +; CHECK-DAG: ret void diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index ad5df438842f..39e5965bda0d 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -1,12 
+1,26 @@ // Baseline -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli -)" == "900000000.560000" ] ; fi + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli - | FileCheck --check-prefix BASELINE %s; fi +// BASELINE: 900000000.560000 + // Truncated -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli -)" == "900000000.000000" ] ; fi -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli -)" == "900000000.000000" ] ; fi + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli - | FileCheck --check-prefix TO_32 %s; fi +// TO_32: 900000000.000000 + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli - | FileCheck --check-prefix TO_8_23 %s; fi +// TO_8_23: 900000000.000000 + +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR = "yes" ] ; then %clang -DENZYME_TEST_TO_MPFR -O3 %s -o %t %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr && %t | FileCheck --check-prefix TO_3_7 %s; fi +// TO_3_7: 897581056.000000 #include +#ifdef ENZYME_TEST_TO_MPFR +#include +#endif + #include "../test_utils.h" #define N 10 diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 0b8a0f831d6e..0cc5e6f28f38 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -16,6 +16,10 @@ config.llvm_shlib_ext = "@LLVM_SHLIBEXT@" config.targets_to_build = "@TARGETS_TO_BUILD@" +has_mpfr_h = "@HAS_MPFR_H@" +mpfr_lib_path = "@MPFR_LIB_PATH@" +has_mpfr = "yes" if mpfr_lib_path != "MPFR_LIB_PATH-NOTFOUND" and
has_mpfr_h == "1" else "no" + ## Check the current platform with regex import re EAT_ERR_ON_X86 = ' ' @@ -112,6 +116,8 @@ if len("@ENZYME_BINARY_DIR@") == 0: config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < 15 else newPM)) config.substitutions.append(('%newLoadClangEnzyme', newPM)) +config.substitutions.append(('%hasMPFR', has_mpfr)) + # Let the main config do the real work. cfgfile = "@ENZYME_SOURCE_DIR@/test/lit.cfg.py" if len("@ENZYME_SOURCE_DIR@") == 0: