Skip to content

Commit

Permalink
Switch to inline header
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanradanov committed Feb 28, 2024
1 parent 7787b85 commit f7fc98c
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 192 deletions.
5 changes: 1 addition & 4 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,7 @@ gentbl(
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/Clang/include_utils.td",
td_srcs = [
"Enzyme/Clang/include_utils.td",
"Enzyme/Runtime/MPFR.cpp",
],
td_srcs = ["Enzyme/Clang/include_utils.td"],
deps = [
":enzyme-tblgen",
],
Expand Down
7 changes: 1 addition & 6 deletions enzyme/Enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,8 @@ add_public_tablegen_target(BlasTAIncGen)
add_public_tablegen_target(BlasDiffUseIncGen)

set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td)
# Need to explicitly set included files as dependencies
set(ARG_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/Runtime/MPFR.cpp" CACHE INTERNAL "deps")
# Cmake tablegen adds the current cmake dir to the include path and bazel adds
# the directory that contains the .td file, that's why we need the include here
enzyme_tablegen(IncludeUtils.inc -gen-header-strings -I${CMAKE_CURRENT_SOURCE_DIR}/Clang/)
enzyme_tablegen(IncludeUtils.inc -gen-header-strings)
add_public_tablegen_target(IncludeUtilsIncGen)
unset(ARG_DEPENDS)

include_directories(${CMAKE_CURRENT_BINARY_DIR})

Expand Down
140 changes: 129 additions & 11 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
class InlineHeader<string filename_, string contents_> {
class Headers<string filename_, string contents_> {
string filename = filename_;
string contents = contents_;
}

class FileHeader<string filename_out_, string filename_in_> {
string filename_out = filename_out_;
string filename_in = filename_in_;
}

def : InlineHeader<"/enzymeroot/enzyme/utils", [{
def : Headers<"/enzymeroot/enzyme/utils", [{
#pragma once

extern int enzyme_dup;
Expand Down Expand Up @@ -268,7 +263,7 @@ namespace enzyme {
}
}]>;

def : InlineHeader<"/enzymeroot/enzyme/type_traits", [{
def : Headers<"/enzymeroot/enzyme/type_traits", [{
#pragma once

#include <type_traits>
Expand Down Expand Up @@ -317,7 +312,7 @@ namespace impl {
}
}]>;

def : InlineHeader<"/enzymeroot/enzyme/tuple", [{
def : Headers<"/enzymeroot/enzyme/tuple", [{
#pragma once

/////////////
Expand Down Expand Up @@ -454,12 +449,135 @@ constexpr auto tuple_cat(Tuples&&... tuples) {
#undef _NOEXCEPT
}]>;

def : InlineHeader<"/enzymeroot/enzyme/enzyme", [{
def : Headers<"/enzymeroot/enzyme/enzyme", [{
#ifdef __cplusplus
#include "enzyme/utils"
#else
#warning "Enzyme wrapper templates only available in C++"
#endif
}]>;

def : FileHeader<"/enzymeroot/enzyme/mpfr", "../Runtime/MPFR.cpp">;
def : Headers<"/enzymeroot/enzyme/mpfr", [{
//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains easy to use wrappers around MPFR functions.
//
//===----------------------------------------------------------------------===//
#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__
#define __ENZYME_RUNTIME_ENZYME_MPFR__

#include <mpfr.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// TODO s
//
// (for MPFR ver. 2.1)
//
// We need to set the range of the allowed exponent using `mpfr_set_emin` and
// `mpfr_set_emax`. (This means we can also play with whether the range is
// centered around 0 (1?) or somewhere else)
//
// (also these need to be mutex'ed as the exponent change is global in mpfr and
// not float-specific) ... (mpfr seems to have thread safe mode - check if it is
// enabled or if it is enabled by default)
//
// For that we need to do this check:
// If the user changes the exponent range, it is her/his responsibility to
// check that all current floating-point variables are in the new allowed
// range (for example using mpfr_check_range), otherwise the subsequent
// behavior will be undefined, in the sense of the ISO C standard.
//
// MPFR docs state the following:
// Note: Overflow handling is still experimental and currently implemented
// partially. If an overflow occurs internally at the wrong place, anything
// can happen (crash, wrong results, etc).
//
// Which we would like to avoid somehow.
//
// MPFR also has this limitation that we need to address for accurate
// simulation:
// [...] subnormal numbers are not implemented.
//

#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \
RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \
ROUNDING_MODE) \
__attribute__((weak)) \
RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \
ARG1 a, int64_t exponent, int64_t significand) { \
mpfr_t ma, mc; \
mpfr_init2(ma, significand); \
mpfr_init2(mc, significand); \
mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \
mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \
RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \
mpfr_clear(ma); \
mpfr_clear(mc); \
return c; \
}

#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \
RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \
MPFR_SET_ARG2, ROUNDING_MODE) \
__attribute__((weak)) \
RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \
ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \
mpfr_t ma, mb, mc; \
mpfr_init2(ma, significand); \
mpfr_init2(mb, significand); \
mpfr_init2(mc, significand); \
mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \
mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \
mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \
RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \
mpfr_clear(ma); \
mpfr_clear(mb); \
mpfr_clear(mc); \
return c; \
}

#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN
#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \
ROUNDING_MODE) \
__ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \
double, d, double, d, ROUNDING_MODE)
#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \
MPFR_FUNC_NAME) \
__ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \
__ENZYME_MPFR_DEFAULT_ROUNDING_MODE)

__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul)
__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add)
__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div)

__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, d,
__ENZYME_MPFR_DEFAULT_ROUNDING_MODE)

#ifdef __cplusplus
}
#endif

#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__
}]>;
122 changes: 0 additions & 122 deletions enzyme/Enzyme/Runtime/MPFR.cpp

This file was deleted.

59 changes: 10 additions & 49 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
Expand Down Expand Up @@ -1253,56 +1252,18 @@ void printDiffUse(

static void emitHeaderIncludes(const RecordKeeper &recordKeeper,
raw_ostream &os) {
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Headers");
os << "const char* include_headers[][2] = {\n";
bool seen = false;
{
const auto &patterns =
recordKeeper.getAllDerivedDefinitions("InlineHeader");
for (Record *pattern : patterns) {
if (seen)
os << ",\n";
auto filename = pattern->getValueAsString("filename");
auto contents = pattern->getValueAsString("contents");
os << "{\"" << filename << "\"\n,";
os << "R\"(" << contents << ")\"\n";
os << "}";
seen = true;
}
}
{
const auto &patterns = recordKeeper.getAllDerivedDefinitions("FileHeader");
for (Record *pattern : patterns) {
if (seen)
os << ",\n";
auto filename_out = pattern->getValueAsString("filename_out");
std::string filename_in = pattern->getValueAsString("filename_in").str();
std::string included_file;
#if LLVM_VERSION_MAJOR >= 15
auto contents_or_err =
llvm::SrcMgr.OpenIncludeFile(filename_in, included_file);
if (!contents_or_err)
PrintFatalError(pattern->getLoc(),
Twine("Could not read file ") + filename_in);
auto &contents = contents_or_err.get();
#else
auto buf =
llvm::SrcMgr.AddIncludeFile(filename_in,
#if LLVM_VERSION_MAJOR >= 12
pattern->getFieldLoc("filename_in"),
#else
SMLoc::getFromPointer(nullptr),
#endif
included_file);
if (!buf)
PrintFatalError(pattern->getLoc(),
Twine("Could not read file ") + filename_in);
auto contents = llvm::SrcMgr.getMemoryBuffer(buf);
#endif
os << "{\"" << filename_out << "\"\n,";
os << "R\"(" << contents->getBuffer() << ")\"\n";
os << "}";
seen = true;
}
for (Record *pattern : patterns) {
if (seen)
os << ",\n";
auto filename = pattern->getValueAsString("filename");
auto contents = pattern->getValueAsString("contents");
os << "{\"" << filename << "\"\n,";
os << "R\"(" << contents << ")\"\n";
os << "}";
seen = true;
}
os << "};\n";
}
Expand Down

0 comments on commit f7fc98c

Please sign in to comment.