Skip to content

Commit

Permalink
Truncation to MPFR (#1750)
Browse files Browse the repository at this point in the history
* WIP MPFR truncation

* MPFR truncation

* Fix mpfr function mangling

* Mangling

* MPFR Wrappers

* clang-format

* Make header work in C

* File header

* Make it compile on llvm 11

* header

* fix tests

* Add TODO comment

* more comments

* MPFR header fix

* Add mpfr test

* Move mpfr runtime

* Add another type of include header

* fix tests

* clang-format

* Check for MPFR

* Fix older llvm vers

* llvm 11

* .

* WIP deps

* Proper include

* Dep

* Switch to inline header
  • Loading branch information
ivanradanov authored Feb 29, 2024
1 parent 070601e commit b97aa9d
Show file tree
Hide file tree
Showing 12 changed files with 520 additions and 245 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ccpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
- name: add llvm
run: |
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add -
sudo apt-get install -y libmpfr-dev
sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true
sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev lld-${{ matrix.llvm }} clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev libzstd-dev
sudo python3 -m pip install --upgrade pip lit
Expand Down
7 changes: 7 additions & 0 deletions enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13)
project(Enzyme)

include(CMakePackageConfigHelpers)
include(CheckIncludeFile)
include(CheckIncludeFileCXX)

set(ENZYME_MAJOR_VERSION 0)
set(ENZYME_MINOR_VERSION 0)
Expand Down Expand Up @@ -265,6 +267,11 @@ string(REPLACE "};\n}" "};\n}}" INPUT_TEXT "${INPUT_TEXT}")
string(REPLACE "const SCEV* S;\n};\n" "const SCEV* S;\n};\n}\n" INPUT_TEXT "${INPUT_TEXT}")
endif()

# Probe for MPFR: look for a library named `mpfr` (result in MPFR_LIB_PATH) and
# check whether the C header `mpfr.h` can be included (result in HAS_MPFR_H).
# These lines only print what was found; they do not stop configuration when
# either lookup fails.
find_library(MPFR_LIB_PATH mpfr)
CHECK_INCLUDE_FILE("mpfr.h" HAS_MPFR_H)
message("MPFR lib: " ${MPFR_LIB_PATH})
message("MPFR header: " ${HAS_MPFR_H})

file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/ScalarEvolutionExpander.h" "${INPUT_TEXT}")

include_directories("${CMAKE_CURRENT_BINARY_DIR}/include")
Expand Down
125 changes: 125 additions & 0 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,128 @@ def : Headers<"/enzymeroot/enzyme/enzyme", [{
#warning "Enzyme wrapper templates only available in C++"
#endif
}]>;

def : Headers<"/enzymeroot/enzyme/mpfr", [{
//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains easy to use wrappers around MPFR functions.
//
//===----------------------------------------------------------------------===//
#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__
#define __ENZYME_RUNTIME_ENZYME_MPFR__

#include <mpfr.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// TODOs
//
// (for MPFR ver. 2.1)
//
// We need to set the range of the allowed exponent using `mpfr_set_emin` and
// `mpfr_set_emax`. (This means we can also play with whether the range is
// centered around 0 (1?) or somewhere else)
//
// (also these need to be mutex'ed as the exponent change is global in mpfr and
// not float-specific) ... (mpfr seems to have thread safe mode - check if it is
// enabled or if it is enabled by default)
//
// For that we need to do this check:
// If the user changes the exponent range, it is her/his responsibility to
// check that all current floating-point variables are in the new allowed
// range (for example using mpfr_check_range), otherwise the subsequent
// behavior will be undefined, in the sense of the ISO C standard.
//
// MPFR docs state the following:
// Note: Overflow handling is still experimental and currently implemented
// partially. If an overflow occurs internally at the wrong place, anything
// can happen (crash, wrong results, etc).
//
// Which we would like to avoid somehow.
//
// MPFR also has this limitation that we need to address for accurate
// simulation:
// [...] subnormal numbers are not implemented.
//

// Stamps out a weak function that runs a one-operand MPFR routine.
//
// The generated function is named
//   __enzyme_mpfr_<FROM_TYPE>_<OP_TYPE>_<LLVM_OP_NAME>
// and has the signature
//   RET (ARG1 a, int64_t exponent, int64_t significand)
//
// What the generated body does, in order:
//   1. creates two MPFR values (input and output) and initialises each with
//      mpfr_init2 using `significand` as the precision argument,
//   2. stores `a` into the input with mpfr_set_<MPFR_SET_ARG1>,
//   3. applies mpfr_<MPFR_FUNC_NAME> from the input into the output,
//   4. reads the output back as RET with mpfr_get_<MPFR_GET>,
//   5. clears both MPFR values and returns the converted result.
//
// ROUNDING_MODE is forwarded unchanged to the set, compute and get calls.
// `exponent` is part of the signature but is not read by the body.
#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \
                             RET, MPFR_GET, ARG1, MPFR_SET_ARG1,               \
                             ROUNDING_MODE)                                    \
  __attribute__((weak))                                                        \
  RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(                  \
      ARG1 a, int64_t exponent, int64_t significand) {                         \
    mpfr_t mpfr_in, mpfr_out;                                                  \
    mpfr_init2(mpfr_in, significand);                                          \
    mpfr_init2(mpfr_out, significand);                                         \
    mpfr_set_##MPFR_SET_ARG1(mpfr_in, a, ROUNDING_MODE);                       \
    mpfr_##MPFR_FUNC_NAME(mpfr_out, mpfr_in, ROUNDING_MODE);                   \
    RET converted = mpfr_get_##MPFR_GET(mpfr_out, ROUNDING_MODE);              \
    mpfr_clear(mpfr_in);                                                       \
    mpfr_clear(mpfr_out);                                                      \
    return converted;                                                          \
  }

// Defines a weak function that evaluates a two-operand MPFR routine.
//
// The generated function is named
//   __enzyme_mpfr_<FROM_TYPE>_<OP_TYPE>_<LLVM_OP_NAME>
// and has the signature
//   RET (ARG1 a, ARG2 b, int64_t exponent, int64_t significand)
//
// Macro parameters:
//   OP_TYPE, LLVM_OP_NAME, FROM_TYPE - pasted into the function name.
//   MPFR_FUNC_NAME - suffix of the mpfr_* routine to call (e.g. `add`).
//   RET / MPFR_GET - return type and the mpfr_get_* suffix that produces it.
//   ARG1 / MPFR_SET_ARG1 - type of `a` and the mpfr_set_* suffix that loads it.
//   ARG2 / MPFR_SET_ARG2 - type of `b` and the mpfr_set_* suffix that loads it.
//   ROUNDING_MODE - forwarded to every set, compute and get call.
//
// The body initialises three MPFR values with mpfr_init2, passing
// `significand` as the precision, loads both operands, runs the operation,
// converts the result to RET, and clears all three values before returning.
// `exponent` is part of the signature but is not read by the body.
//
// Each operand is loaded with its own setter so that ARG1 and ARG2 may differ.
#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE,  \
                            RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2,          \
                            MPFR_SET_ARG2, ROUNDING_MODE)                      \
  __attribute__((weak))                                                        \
  RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(                  \
      ARG1 a, ARG2 b, int64_t exponent, int64_t significand) {                 \
    mpfr_t ma, mb, mc;                                                         \
    mpfr_init2(ma, significand);                                               \
    mpfr_init2(mb, significand);                                               \
    mpfr_init2(mc, significand);                                               \
    mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE);                            \
    mpfr_set_##MPFR_SET_ARG2(mb, b, ROUNDING_MODE);                            \
    mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE);                          \
    RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE);                            \
    mpfr_clear(ma);                                                            \
    mpfr_clear(mb);                                                            \
    mpfr_clear(mc);                                                            \
    return c;                                                                  \
  }

// Rounding mode used by the *_DEFAULT_ROUNDING helper below. GMP_RNDN is
// MPFR's round-to-nearest mode under its older GMP_-prefixed spelling.
#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN
// Shorthand for __ENZYME_MPFR_BINOP where both operands and the result are
// `double`, moved in and out of MPFR with the `d` setter/getter. OP_TYPE is
// fixed to `binop` and FROM_TYPE to `64_52`, so the generated symbol is
// __enzyme_mpfr_64_52_binop_<LLVM_OP_NAME>.
#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \
ROUNDING_MODE) \
__ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \
double, d, double, d, ROUNDING_MODE)
// Same as __ENZYME_MPFR_DOUBLE_BINOP with the rounding mode fixed to
// __ENZYME_MPFR_DEFAULT_ROUNDING_MODE.
#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \
MPFR_FUNC_NAME) \
__ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \
__ENZYME_MPFR_DEFAULT_ROUNDING_MODE)

// Instantiate the double-precision binary wrappers. These expand to
// __enzyme_mpfr_64_52_binop_fmul, __enzyme_mpfr_64_52_binop_fadd and
// __enzyme_mpfr_64_52_binop_fdiv, backed by mpfr_mul, mpfr_add and mpfr_div.
// NOTE(review): no fsub/mpfr_sub wrapper is generated here - confirm whether
// the truncation pass can emit a call to one.
__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul)
__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add)
__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div)

// Unary wrapper for sqrt on double: expands to __enzyme_mpfr_64_52_func_sqrt,
// backed by mpfr_sqrt with the default rounding mode.
__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, d,
__ENZYME_MPFR_DEFAULT_ROUNDING_MODE)

#ifdef __cplusplus
}
#endif

#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__
}]>;
51 changes: 34 additions & 17 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Scalar.h"

#include "llvm/Analysis/BasicAliasAnalysis.h"
Expand Down Expand Up @@ -1339,21 +1340,40 @@ class EnzymeBase {
Function *F = parseFunctionParameter(CI);
if (!F)
return false;
if (CI->arg_size() != 3) {
unsigned ArgSize = CI->arg_size();
if (ArgSize != 4 && ArgSize != 3) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args to __enzyme_truncate_func", *CI,
" - expected 3");
" - expected 3 or 4");
return false;
}
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
FloatTruncation truncation = [&]() -> FloatTruncation {
if (ArgSize == 3) {
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
return FloatTruncation(
getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()));
} else if (ArgSize == 4) {
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto_exponent = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto_exponent);
auto Cto_significand = cast<ConstantInt>(CI->getArgOperand(3));
assert(Cto_significand);
return FloatTruncation(
getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
FloatRepresentation(
(unsigned)Cto_exponent->getValue().getZExtValue(),
(unsigned)Cto_significand->getValue().getZExtValue()));
}
llvm_unreachable("??");
}();

RequestContext context(CI, &Builder);
llvm::Value *res = Logic.CreateTruncateFunc(
context, F,
getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()),
getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), mode);
llvm::Value *res = Logic.CreateTruncateFunc(context, F, truncation, mode);
if (!res)
return false;
res = Builder.CreatePointerCast(res, CI->getType());
Expand Down Expand Up @@ -2052,14 +2072,12 @@ class EnzymeBase {
}

bool handleFullModuleTrunc(Function &F) {
typedef std::vector<std::pair<FloatRepresentation, FloatRepresentation>>
TruncationsTy;
typedef std::vector<FloatTruncation> TruncationsTy;
static TruncationsTy FullModuleTruncs = []() -> TruncationsTy {
StringRef ConfigStr(EnzymeTruncateAll);
auto Invalid = [=]() {
// TODO emit better diagnostic
llvm::errs() << "error: invalid format for truncation config\n";
abort();
llvm::report_fatal_error("error: invalid format for truncation config");
};

// "64" or "11-52"
Expand Down Expand Up @@ -2102,9 +2120,8 @@ class EnzymeBase {
for (auto Truncation : FullModuleTruncs) {
IRBuilder<> Builder(F.getContext());
RequestContext context(&*F.getEntryBlock().begin(), &Builder);
Function *TruncatedFunc =
Logic.CreateTruncateFunc(context, &F, Truncation.first,
Truncation.second, TruncOpFullModuleMode);
Function *TruncatedFunc = Logic.CreateTruncateFunc(
context, &F, Truncation, TruncOpFullModuleMode);

ValueToValueMapTy Mapping;
for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args()))
Expand Down
Loading

0 comments on commit b97aa9d

Please sign in to comment.