Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trucation to MPFR #1750

Merged
merged 27 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ccpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
- name: add llvm
run: |
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add -
sudo apt-get install -y libmpfr-dev
sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true
sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev lld-${{ matrix.llvm }} clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev libzstd-dev
sudo python3 -m pip install --upgrade pip lit
Expand Down
7 changes: 7 additions & 0 deletions enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13)
project(Enzyme)

include(CMakePackageConfigHelpers)
include(CheckIncludeFile)
include(CheckIncludeFileCXX)

set(ENZYME_MAJOR_VERSION 0)
set(ENZYME_MINOR_VERSION 0)
Expand Down Expand Up @@ -265,6 +267,11 @@ string(REPLACE "};\n}" "};\n}}" INPUT_TEXT "${INPUT_TEXT}")
string(REPLACE "const SCEV* S;\n};\n" "const SCEV* S;\n};\n}\n" INPUT_TEXT "${INPUT_TEXT}")
endif()

find_library(MPFR_LIB_PATH mpfr)
ivanradanov marked this conversation as resolved.
Show resolved Hide resolved
CHECK_INCLUDE_FILE("mpfr.h" HAS_MPFR_H)
message("MPFR lib: " ${MPFR_LIB_PATH})
message("MPFR header: " ${HAS_MPFR_H})

file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/ScalarEvolutionExpander.h" "${INPUT_TEXT}")

include_directories("${CMAKE_CURRENT_BINARY_DIR}/include")
Expand Down
125 changes: 125 additions & 0 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,128 @@ def : Headers<"/enzymeroot/enzyme/enzyme", [{
#warning "Enzyme wrapper templates only available in C++"
#endif
}]>;

def : Headers<"/enzymeroot/enzyme/mpfr", [{
//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains easy to use wrappers around MPFR functions.
//
//===----------------------------------------------------------------------===//
#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__
#define __ENZYME_RUNTIME_ENZYME_MPFR__

#include <mpfr.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// TODO s
//
// (for MPFR ver. 2.1)
//
// We need to set the range of the allowed exponent using `mpfr_set_emin` and
// `mpfr_set_emax`. (This means we can also play with whether the range is
// centered around 0 (1?) or somewhere else)
//
// (also these need to be mutex'ed as the exponent change is global in mpfr and
// not float-specific) ... (mpfr seems to have thread safe mode - check if it is
// enabled or if it is enabled by default)
//
// For that we need to do this check:
// If the user changes the exponent range, it is her/his responsibility to
// check that all current floating-point variables are in the new allowed
// range (for example using mpfr_check_range), otherwise the subsequent
// behavior will be undefined, in the sense of the ISO C standard.
//
// MPFR docs state the following:
// Note: Overflow handling is still experimental and currently implemented
// partially. If an overflow occurs internally at the wrong place, anything
// can happen (crash, wrong results, etc).
//
// Which we would like to avoid somehow.
//
// MPFR also has this limitation that we need to address for accurate
// simulation:
// [...] subnormal numbers are not implemented.
//

#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \
RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \
ROUNDING_MODE) \
__attribute__((weak)) \
RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \
ARG1 a, int64_t exponent, int64_t significand) { \
mpfr_t ma, mc; \
mpfr_init2(ma, significand); \
mpfr_init2(mc, significand); \
mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \
mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \
RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \
mpfr_clear(ma); \
mpfr_clear(mc); \
return c; \
}

#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \
RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \
MPFR_SET_ARG2, ROUNDING_MODE) \
__attribute__((weak)) \
RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \
ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \
mpfr_t ma, mb, mc; \
mpfr_init2(ma, significand); \
mpfr_init2(mb, significand); \
mpfr_init2(mc, significand); \
mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \
mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \
mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \
RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \
mpfr_clear(ma); \
mpfr_clear(mb); \
mpfr_clear(mc); \
return c; \
}

#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN
#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \
ROUNDING_MODE) \
__ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \
double, d, double, d, ROUNDING_MODE)
#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \
MPFR_FUNC_NAME) \
__ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \
__ENZYME_MPFR_DEFAULT_ROUNDING_MODE)

__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul)
__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add)
__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div)

__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, d,
__ENZYME_MPFR_DEFAULT_ROUNDING_MODE)

#ifdef __cplusplus
}
#endif

#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__
}]>;
51 changes: 34 additions & 17 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Scalar.h"

#include "llvm/Analysis/BasicAliasAnalysis.h"
Expand Down Expand Up @@ -1339,21 +1340,40 @@ class EnzymeBase {
Function *F = parseFunctionParameter(CI);
if (!F)
return false;
if (CI->arg_size() != 3) {
unsigned ArgSize = CI->arg_size();
if (ArgSize != 4 && ArgSize != 3) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args to __enzyme_truncate_func", *CI,
" - expected 3");
" - expected 3 or 4");
return false;
}
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
FloatTruncation truncation = [&]() -> FloatTruncation {
if (ArgSize == 3) {
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
return FloatTruncation(
getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()));
} else if (ArgSize == 4) {
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto_exponent = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto_exponent);
auto Cto_significand = cast<ConstantInt>(CI->getArgOperand(3));
assert(Cto_significand);
return FloatTruncation(
getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
FloatRepresentation(
(unsigned)Cto_exponent->getValue().getZExtValue(),
(unsigned)Cto_significand->getValue().getZExtValue()));
}
llvm_unreachable("??");
}();

RequestContext context(CI, &Builder);
llvm::Value *res = Logic.CreateTruncateFunc(
context, F,
getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), mode);
llvm::Value *res = Logic.CreateTruncateFunc(context, F, truncation, mode);
if (!res)
return false;
res = Builder.CreatePointerCast(res, CI->getType());
Expand Down Expand Up @@ -2052,14 +2072,12 @@ class EnzymeBase {
}

bool handleFullModuleTrunc(Function &F) {
typedef std::vector<std::pair<FloatRepresentation, FloatRepresentation>>
TruncationsTy;
typedef std::vector<FloatTruncation> TruncationsTy;
static TruncationsTy FullModuleTruncs = []() -> TruncationsTy {
StringRef ConfigStr(EnzymeTruncateAll);
auto Invalid = [=]() {
// TODO emit better diagnostic
llvm::errs() << "error: invalid format for truncation config\n";
abort();
llvm::report_fatal_error("error: invalid format for truncation config");
};

// "64" or "11-52"
Expand Down Expand Up @@ -2102,9 +2120,8 @@ class EnzymeBase {
for (auto Truncation : FullModuleTruncs) {
IRBuilder<> Builder(F.getContext());
RequestContext context(&*F.getEntryBlock().begin(), &Builder);
Function *TruncatedFunc =
Logic.CreateTruncateFunc(context, &F, Truncation.first,
Truncation.second, TruncOpFullModuleMode);
Function *TruncatedFunc = Logic.CreateTruncateFunc(
context, &F, Truncation, TruncOpFullModuleMode);

ValueToValueMapTy Mapping;
for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args()))
Expand Down
Loading
Loading