From 63ac32353411a26e5558209db15ec361e7c83718 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 21 Feb 2024 22:38:29 -0500 Subject: [PATCH] Allow custom importing of files and syntactic sugar --- enzyme/BUILD | 15 + enzyme/Enzyme/CMakeLists.txt | 6 + enzyme/Enzyme/Clang/EnzymeClang.cpp | 35 ++ enzyme/Enzyme/Clang/include_utils.td | 319 ++++++++++++++++++ enzyme/Enzyme/Enzyme.cpp | 255 +++++++++++++- enzyme/test/Integration/ReverseMode/sugar.cpp | 40 +++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 29 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.h | 1 + 8 files changed, 694 insertions(+), 6 deletions(-) create mode 100644 enzyme/Enzyme/Clang/include_utils.td create mode 100644 enzyme/test/Integration/ReverseMode/sugar.cpp diff --git a/enzyme/BUILD b/enzyme/BUILD index e582b435d387..dc36cb4ad0c5 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -142,6 +142,20 @@ gentbl( ], ) +gentbl( + name = "include-utils", + tbl_outs = [( + "-gen-header-strings", + "IncludeUtils.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/Clang/include_utils.td", + td_srcs = ["Enzyme/Clang/include_utils.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeStatic", srcs = glob( @@ -167,6 +181,7 @@ cc_library( data = ["@llvm-project//clang:builtin_headers_gen"], visibility = ["//visibility:public"], deps = [ + "include-utils", ":binop-derivatives", ":blas-attributor", ":blas-derivatives", diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 1cd6e84c5be1..b27e4beb08cd 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -37,6 +37,10 @@ add_public_tablegen_target(BlasDeclarationsIncGen) add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) +set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) +enzyme_tablegen(IncludeUtils.inc -gen-header-strings) +add_public_tablegen_target(IncludeUtilsIncGen) + include_directories(${CMAKE_CURRENT_BINARY_DIR}) set(LLVM_LINK_COMPONENTS Demangle) @@ -74,6 +78,7 @@ if (${Clang_FOUND}) LLVM ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp @@ -107,6 +112,7 @@ if (${Clang_FOUND}) clang ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index a34a6429dcf7..dfb41c45a290 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -35,6 +35,8 @@ #include "../Utils.h" +#include "IncludeUtils.inc" + using namespace clang; #if LLVM_VERSION_MAJOR >= 18 @@ -134,6 +136,39 @@ class EnzymePlugin final : public clang::ASTConsumer { Builder.defineMacro("ENZYME_VERSION_PATCH", std::to_string(ENZYME_VERSION_PATCH)); CI.getPreprocessor().setPredefines(Predefines.str()); + + auto baseFS = CI.getFileManager().getVirtualFileSystemPtr(); + llvm::vfs::OverlayFileSystem *fuseFS( + new llvm::vfs::OverlayFileSystem(baseFS)); + IntrusiveRefCntPtr fs( + new llvm::vfs::InMemoryFileSystem()); + + struct tm y2k = {}; + + y2k.tm_hour = 0; + y2k.tm_min = 0; + y2k.tm_sec = 0; + y2k.tm_year = 100; + y2k.tm_mon = 0; + y2k.tm_mday = 1; + time_t timer = mktime(&y2k); + for (const auto &pair : include_headers) { + fs->addFile(StringRef(pair[0]), timer, + llvm::MemoryBuffer::getMemBuffer( + StringRef(pair[1]), StringRef(pair[0]), + /*RequiresNullTerminator*/ true)); + } + + fuseFS->pushOverlay(fs); + fuseFS->pushOverlay(baseFS); + CI.getFileManager().setVirtualFileSystem(fuseFS); + + auto DE = CI.getFileManager().getDirectoryRef("/enzymeroot"); + assert(DE); + auto DL = DirectoryLookup(*DE, SrcMgr::C_User, + /*isFramework=*/false); + CI.getPreprocessor().getHeaderSearchInfo().AddSearchPath(DL, + /*isAngled=*/true); } ~EnzymePlugin() {} void HandleTranslationUnit(ASTContext &context) override {} diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td new file mode 100644 index 000000000000..71d461d09d55 --- /dev/null +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -0,0 +1,319 @@ +class Headers { + string filename = filename_; + string contents = contents_; +} + +def : Headers<"/enzymeroot/enzyme/utils", [{ +#pragma once + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +template +Return __enzyme_autodiff(T...); + +template +Return __enzyme_fwddiff(T...); + +#include + +namespace enzyme { + + enum ReturnActivity{ + INACTIVE, + ACTIVE, + DUPLICATED + }; + + struct nodiff{}; + + template < typename T > + struct active{ + T value; + operator T&() { return value; } + }; + + template < typename T > + active(T) -> active; + + template < typename T > + struct duplicated{ + T value; + T shadow; + }; + + template < typename T > + struct inactive{ + T value; + }; + + template < typename T > + struct type_info { + static constexpr bool is_active = false; + + #ifdef ENZYME_OMIT_INACTIVE + using type = tuple<>; + #else + using type = tuple; + #endif + }; + + template < typename T > + struct type_info < active >{ + static constexpr bool is_active = true; + using type = tuple; + }; + + template < typename ... T > + struct concatenated; + + template < typename ... S, typename ... T, typename ... rest > + struct concatenated < tuple < S ... >, tuple < T ... >, rest ... > { + using type = typename concatenated< tuple< S ..., T ... >, rest ... >::type; + }; + + template < typename ... T > + struct concatenated < tuple < T ... > > { + using type = tuple< T ... >; + }; + + template < typename T > + struct concatenated < tuple > { + using type = T; + }; + + // Yikes! + // slightly cleaner in C++20, with std::remove_cvref + template < typename ... T > + struct autodiff_return { + using type = typename concatenated< + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type; + }; + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::duplicated & arg) { + return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::active & arg) { + return enzyme::tuple{enzyme_out, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::inactive & arg) { + return enzyme::tuple{enzyme_const, arg.value}; + } + + namespace detail { + template + __attribute__((always_inline)) + constexpr decltype(auto) apply_impl(void* f, Tuple&& t, std::index_sequence) { + return __enzyme_autodiff(f, enzyme::get(std::forward(t))...); + } + } // namespace detail + + template < typename return_type, typename function, typename ... enz_arg_types > + __attribute__((always_inline)) + auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::apply_impl((void*)f, std::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using return_type = typename autodiff_return::type; + return autodiff_impl(std::forward(f), enzyme::tuple_cat(expand_args(args)...)); + } +} +}]>; + +def : Headers<"/enzymeroot/enzyme/type_traits", [{ +#pragma once + +#include + +namespace enzyme { + +// this is already in C++20, but we reimplement it here for older C++ versions +template < typename T > +struct remove_cvref { + using type = + typename std::remove_reference< + typename std::remove_cv< + T + >::type + >::type; +}; + +template < typename T > +using remove_cvref_t = typename remove_cvref::type; + +} +}]>; + +def : Headers<"/enzymeroot/enzyme/tuple", [{ +#pragma once + +///////////// +// tuple.h // +///////////// + +// why reinvent the wheel and implement a tuple class? +// - ensure data is laid out in the same order the types are specified +// see: https://github.com/EnzymeAD/Enzyme/issues/1191#issuecomment-1556239213 +// - CUDA compatibility: std::tuple has some compatibility issues when used +// in a __device__ context (this may get better in c++20 with the improved +// constexpr support for std::tuple). Owning the implementation lets +// us add __host__ __device__ annotations to any part of it + +#include // for std::integer_sequence + +#include + +#define _NOEXCEPT noexcept +namespace enzyme { + +template +struct Index {}; + +template +struct value_at_position { + __attribute__((always_inline)) + T & operator[](Index) { return value; } + + __attribute__((always_inline)) + constexpr const T & operator[](Index) const { return value; } + T value; +}; + +template +struct tuple_base; + +template +struct tuple_base, T...> + : public value_at_position... { + using value_at_position::operator[]...; +}; + +template +struct tuple : public tuple_base, T...> {}; + +template +tuple(T ...) -> tuple; + +template < int i, typename Tuple > +__attribute__((always_inline)) +decltype(auto) get(Tuple && tup) { + constexpr bool is_lvalue = std::is_lvalue_reference_v; + constexpr bool is_const = std::is_const_v>; + using T = remove_cvref_t< decltype(tup[Index{ } ]) >; + if constexpr ( is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr ( is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } +} + +template < int i, typename ... T> +__attribute__((always_inline)) +decltype(auto) get(const tuple< T ... > & tup) { + return tup[Index{} ]; +} + +template +struct tuple_size; + +template +struct tuple_size> : std::integral_constant {}; + +template +static constexpr size_t tuple_size_v = tuple_size::value; + +template +__attribute__((always_inline)) +auto forward(std::remove_reference_t& arg) _NOEXCEPT { + return static_cast(arg); +} + +template +__attribute__((always_inline)) +auto forward(std::remove_reference_t&& arg) _NOEXCEPT { + static_assert(!std::is_lvalue_reference::value, "cannot forward an rvalue as an lvalue"); + return static_cast(arg); +} + +template +__attribute__((always_inline)) +constexpr auto forward_as_tuple(T&&... args) noexcept { + return tuple{forward(args)...}; +} + +namespace impl { + +template +struct make_tuple_from_fwd_tuple; + +template +struct make_tuple_from_fwd_tuple> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd) { + return tuple{get(forward(fwd))...}; + } +}; + +template +struct concat_with_fwd_tuple; + +template < typename Tuple > +using iseq = std::make_index_sequence > >; + +template +struct concat_with_fwd_tuple, std::index_sequence> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) { + return forward_as_tuple(get(forward(fwd))..., get(std::forward(t))...); + } +}; + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(Tuple&& ret) { + return make_tuple_from_fwd_tuple< iseq< Tuple > >::f(forward< Tuple >(ret)); +} + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(FWD_TUPLE&& fwd, first&& t, rest&&... ts) { + return tuple_cat(concat_with_fwd_tuple< iseq, iseq >::f(forward(fwd), std::forward(t)), std::forward(ts)...); +} + +} // namespace impl + +template +__attribute__((always_inline)) +constexpr auto tuple_cat(Tuples&&... tuples) { + return impl::tuple_cat(std::forward(tuples)...); +} + +} // namespace enzyme +#undef _NOEXCEPT +}]>; + +def : Headers<"/enzymeroot/enzyme/enzyme", [{ +#ifdef __cplusplus +#include "enzyme/utils" +#else +#warning "Enzyme wrapper templates only available in C++" +#endif +}]>; \ No newline at end of file diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 055b6f394842..4141b539fe76 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -429,12 +429,258 @@ static Optional recursePhiReads(PHINode *val) return finalMetadata; } +Value *simplifyLoad(Value *LI, size_t valSz = 0); + +// Find the base pointer of ptr and the offset in bytes from the start of +// the returned base pointer to this value. +AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) { + offset = 0; + while (true) { + if (auto CI = dyn_cast(ptr)) { + ptr = CI->getOperand(0); + continue; + } + if (auto CI = dyn_cast(ptr)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + return nullptr; + } + offset += Offset.getZExtValue(); + ptr = CI->getOperand(0); + continue; + } + if (isa(ptr)) { + break; + } + if (auto LI = dyn_cast(ptr)) { + if (auto S = simplifyLoad(LI)) { + ptr = S; + continue; + } + } + return nullptr; + } + return cast(ptr); +} + +// Find all user instructions of AI, returning tuples of Unlike a simple get users, this will recurse through any +// constant gep offsets and casts +SmallVector, 1> +findAllUsersOf(Value *AI) { + SmallVector, 1> todo; + todo.emplace_back(AI, 0); + + SmallVector, 1> users; + while (todo.size()) { + auto pair = todo.pop_back_val(); + Value *ptr = pair.first; + size_t suboff = pair.second; + + for (auto U : ptr->users()) { + if (auto CI = dyn_cast(U)) { + todo.emplace_back(CI, suboff); + continue; + } + if (auto CI = dyn_cast(U)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + users.emplace_back(cast(U), ptr, suboff); + continue; + } + todo.emplace_back(CI, suboff + Offset.getZExtValue()); + continue; + } + users.emplace_back(cast(U), ptr, suboff); + continue; + } + } + return users; +} + +// Given a pointer, find all values of size `valSz` which could be loaded from +// that pointer when indexed at offset. If it is impossible to guarantee that +// the set contains all such values, set legal to false +SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, + size_t valSz, bool &legal) { + SmallVector options; + + auto todo = findAllUsersOf(ptr0); + std::set> seen; + + while (todo.size()) { + auto pair = todo.pop_back_val(); + if (seen.count(pair)) + continue; + seen.insert(pair); + Instruction *U = std::get<0>(pair); + Value *ptr = std::get<1>(pair); + size_t suboff = std::get<2>(pair); + + // Read only users do not set the memory inside of ptr + if (isa(U)) { + continue; + } + if (auto MTI = dyn_cast(U)) + if (MTI->getOperand(0) != ptr) { + continue; + } + if (auto I = dyn_cast(U)) { + if (!I->mayWriteToMemory() && I->getType()->isVoidTy()) + continue; + } + + if (auto SI = dyn_cast(U)) { + auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout(); + + // We are storing into the ptr + if (SI->getPointerOperand() == ptr) { + auto storeSz = + (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) / + 8; + // If store is before the load would start + if (storeSz + suboff <= offset) + continue; + // if store starts after load would start + if (offset + valSz <= suboff) + continue; + + if (valSz == storeSz) { + options.push_back(SI->getValueOperand()); + continue; + } + } + + // We capture our pointer of interest, if it is stored into an alloca, + // all loads of said alloca would potentially store into. + if (SI->getValueOperand() == ptr) { + if (suboff == 0) { + size_t mid_offset = 0; + if (auto AI2 = + getBaseAndOffset(SI->getPointerOperand(), mid_offset)) { + bool sublegal = true; + auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8; + auto subPtrs = + getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal); + if (!sublegal) { + legal = false; + return options; + } + for (auto subPtr : subPtrs) { + for (const auto &pair3 : findAllUsersOf(subPtr)) { + todo.emplace_back(pair3); + } + } + continue; + } + } + } + } + + // If we copy into the ptr at a location that includes the offset, consider + // all sub uses + if (auto MTI = dyn_cast(U)) { + if (auto CI = dyn_cast(MTI->getLength())) { + if (MTI->getOperand(0) == ptr && suboff == 0 && + CI->getValue().uge(offset + valSz)) { + size_t midoffset = 0; + auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset); + if (!AI2) { + legal = false; + return options; + } + if (midoffset != 0) { + legal = false; + return options; + } + for (const auto &pair3 : findAllUsersOf(AI2)) { + todo.emplace_back(pair3); + } + continue; + } + } + } + + legal = false; + return options; + } + + return options; +} + +// Perform mem2reg/sroa to identify the innermost value being represented. +Value *simplifyLoad(Value *V, size_t valSz) { + if (auto LI = dyn_cast(V)) { + if (valSz == 0) { + auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(LI->getType()) + 7) / 8; + } + + Value *ptr = LI->getPointerOperand(); + size_t offset = 0; + + auto AI = getBaseAndOffset(ptr, offset); + if (!AI) + return nullptr; + + bool legal = true; + auto opts = getAllLoadedValuesFrom(AI, offset, valSz, legal); + + if (!legal) { + return nullptr; + } + std::set res; + for (auto opt : opts) { + Value *v2 = simplifyLoad(opt, valSz); + if (v2) + res.insert(v2); + else + res.insert(opt); + } + if (res.size() != 1) { + return nullptr; + } + Value *retval = *res.begin(); + return retval; + } + if (auto EVI = dyn_cast(V)) { + bool allZero = true; + for (auto idx : EVI->getIndices()) { + if (idx != 0) + allZero = false; + } + if (valSz == 0) { + auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(EVI->getType()) + 7) / 8; + } + if (allZero) + if (auto LI = dyn_cast(EVI->getAggregateOperand())) { + return simplifyLoad(LI, valSz); + } + } + return nullptr; +} + #if LLVM_VERSION_MAJOR > 16 std::optional getMetadataName(llvm::Value *res) #else Optional getMetadataName(llvm::Value *res) #endif { + if (auto S = simplifyLoad(res)) + return getMetadataName(S); + if (auto av = dyn_cast(res)) { return cast(av->getMetadata())->getString(); } else if ((isa(res) || isa(res)) && @@ -463,12 +709,11 @@ Optional getMetadataName(llvm::Value *res) return gv->getName(); } else if (auto gv = dyn_cast(res)) { return gv->getName(); - } else { - if (isa(res)) { - return recursePhiReads(cast(res)); - } - return {}; + } else if (isa(res)) { + return recursePhiReads(cast(res)); } + + return {}; } static Value *adaptReturnedVector(Value *ret, Value *diffret, diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp new file mode 100644 index 000000000000..3ad7db8d170f --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -0,0 +1,40 @@ +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + +#include "../test_utils.h" + +#include + +double foo(double x, double y) { return x * y; } + +struct pair { + double x; + double y; +}; + +int main() { + + enzyme::autodiff_return< enzyme::active&& >::type q1; + double mo = q1; + + enzyme::active x1{3.1}; + enzyme::active x2{2.7}; + auto y = enzyme::autodiff(foo, x1, x2); + auto y1 = enzyme::get<0>(y); + auto y2 = enzyme::get<1>(y); + printf("%f %f\n", y1, y2); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + + auto &&[z1, z2] = __enzyme_autodiff((void*)foo, enzyme_out, x1.value, enzyme_out, x2.value); + printf("%f %f\n", z1, z2); + APPROX_EQ(z1, 2.7, 1e-10); + APPROX_EQ(z2, 3.1, 1e-10); + +} \ No newline at end of file diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 04f63e1ad3ac..143c85ea684e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -65,7 +65,9 @@ static cl::opt cl::values(clEnumValN(MLIRDerivatives, "gen-mlir-derivatives", "Generate MLIR derivative")), cl::values(clEnumValN(CallDerivatives, "gen-call-derivatives", - "Generate call derivative"))); + "Generate call derivative")), + cl::values(clEnumValN(GenHeaderVariables, "gen-header-strings", + "Generate header strings"))); void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, StringRef FT, StringRef cconv, Init *func, @@ -1248,6 +1250,24 @@ void printDiffUse( } } +static void emitHeaderIncludes(const RecordKeeper &recordKeeper, + raw_ostream &os) { + const auto &patterns = recordKeeper.getAllDerivedDefinitions("Headers"); + os << "const char* include_headers[][2] = {\n"; + bool seen = false; + for (Record *pattern : patterns) { + if (seen) + os << ",\n"; + auto filename = pattern->getValueAsString("filename"); + auto contents = pattern->getValueAsString("contents"); + os << "{\"" << filename << "\"\n,"; + os << "R\"(" << contents << ")\"\n"; + os << "}"; + seen = true; + } + os << "};\n"; +} + static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { emitSourceFileHeader("Rewriters", os); @@ -1268,6 +1288,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case BinopDerivatives: patternNames = "BinopPattern"; break; + case GenHeaderVariables: case GenBlasDerivatives: case UpdateBlasDecl: case UpdateBlasTA: @@ -1299,6 +1320,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case MLIRDerivatives: { auto opName = pattern->getValueAsString("opName"); @@ -2089,6 +2111,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDiffUse"); case CallDerivatives: patternNames = "CallPattern"; @@ -2127,6 +2150,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case CallDerivatives: { os << " if (("; @@ -2283,6 +2307,9 @@ static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { case UpdateBlasTA: emitBlasTAUpdater(records, os); return false; + case GenHeaderVariables: + emitHeaderIncludes(records, os); + return false; default: errs() << "unknown tablegen action!\n"; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h index 368644ba0b5d..742a96d023ae 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h @@ -24,6 +24,7 @@ enum ActionType { UpdateBlasDecl, UpdateBlasTA, GenBlasDiffUse, + GenHeaderVariables, }; void emitDiffUse(const llvm::RecordKeeper &recordKeeper, llvm::raw_ostream &os,