From ca06ef36c94015695a13b786cdce89a269ec338f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 27 Feb 2024 01:16:36 -0500 Subject: [PATCH] Allow custom importing of files and syntactic sugar (#1752) * Allow custom importing of files and syntactic sugar * Fix build on older llvm vers * Update sugar.cpp * Update EnzymeClang.cpp * fix * fixup * dump * more printing * print * fixup --------- Co-authored-by: Ivan Radanov Ivanov --- enzyme/BUILD | 15 + enzyme/Enzyme/CMakeLists.txt | 6 + enzyme/Enzyme/Clang/EnzymeClang.cpp | 37 ++ enzyme/Enzyme/Clang/include_utils.td | 458 ++++++++++++++++++ enzyme/Enzyme/Enzyme.cpp | 14 +- enzyme/Enzyme/Utils.cpp | 333 ++++++++++--- enzyme/Enzyme/Utils.h | 2 + enzyme/test/Integration/ReverseMode/sugar.cpp | 91 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 29 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.h | 1 + 10 files changed, 916 insertions(+), 70 deletions(-) create mode 100644 enzyme/Enzyme/Clang/include_utils.td create mode 100644 enzyme/test/Integration/ReverseMode/sugar.cpp diff --git a/enzyme/BUILD b/enzyme/BUILD index e582b435d387..dc36cb4ad0c5 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -142,6 +142,20 @@ gentbl( ], ) +gentbl( + name = "include-utils", + tbl_outs = [( + "-gen-header-strings", + "IncludeUtils.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/Clang/include_utils.td", + td_srcs = ["Enzyme/Clang/include_utils.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeStatic", srcs = glob( @@ -167,6 +181,7 @@ cc_library( data = ["@llvm-project//clang:builtin_headers_gen"], visibility = ["//visibility:public"], deps = [ + "include-utils", ":binop-derivatives", ":blas-attributor", ":blas-derivatives", diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 1cd6e84c5be1..b27e4beb08cd 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -37,6 +37,10 @@ add_public_tablegen_target(BlasDeclarationsIncGen) 
add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) +set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) +enzyme_tablegen(IncludeUtils.inc -gen-header-strings) +add_public_tablegen_target(IncludeUtilsIncGen) + include_directories(${CMAKE_CURRENT_BINARY_DIR}) set(LLVM_LINK_COMPONENTS Demangle) @@ -74,6 +78,7 @@ if (${Clang_FOUND}) LLVM ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp @@ -107,6 +112,7 @@ if (${Clang_FOUND}) clang ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index a34a6429dcf7..0072c958b517 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -25,16 +25,20 @@ #include "clang/AST/Attr.h" #include "clang/AST/DeclGroup.h" #include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Basic/FileManager.h" #include "clang/Basic/MacroBuilder.h" #include "clang/Frontend/CompilerInstance.h" #include "clang/Frontend/FrontendAction.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Lex/HeaderSearch.h" #include "clang/Lex/PreprocessorOptions.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" #include "../Utils.h" +#include "IncludeUtils.inc" + using namespace clang; #if LLVM_VERSION_MAJOR >= 18 @@ -134,6 +138,39 @@ class EnzymePlugin final : public clang::ASTConsumer { Builder.defineMacro("ENZYME_VERSION_PATCH", std::to_string(ENZYME_VERSION_PATCH)); CI.getPreprocessor().setPredefines(Predefines.str()); + + auto baseFS = 
&CI.getFileManager().getVirtualFileSystem(); + llvm::vfs::OverlayFileSystem *fuseFS( + new llvm::vfs::OverlayFileSystem(baseFS)); + IntrusiveRefCntPtr fs( + new llvm::vfs::InMemoryFileSystem()); + + struct tm y2k = {}; + + y2k.tm_hour = 0; + y2k.tm_min = 0; + y2k.tm_sec = 0; + y2k.tm_year = 100; + y2k.tm_mon = 0; + y2k.tm_mday = 1; + time_t timer = mktime(&y2k); + for (const auto &pair : include_headers) { + fs->addFile(StringRef(pair[0]), timer, + llvm::MemoryBuffer::getMemBuffer( + StringRef(pair[1]), StringRef(pair[0]), + /*RequiresNullTerminator*/ true)); + } + + fuseFS->pushOverlay(fs); + fuseFS->pushOverlay(baseFS); + CI.getFileManager().setVirtualFileSystem(fuseFS); + + auto DE = CI.getFileManager().getDirectoryRef("/enzymeroot"); + assert(DE); + auto DL = DirectoryLookup(*DE, SrcMgr::C_User, + /*isFramework=*/false); + CI.getPreprocessor().getHeaderSearchInfo().AddSearchPath(DL, + /*isAngled=*/true); } ~EnzymePlugin() {} void HandleTranslationUnit(ASTContext &context) override {} diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td new file mode 100644 index 000000000000..1c99d219ce69 --- /dev/null +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -0,0 +1,458 @@ +class Headers { + string filename = filename_; + string contents = contents_; +} + +def : Headers<"/enzymeroot/enzyme/utils", [{ +#pragma once + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +extern int enzyme_const_return; +extern int enzyme_active_return; +extern int enzyme_dup_return; + +extern int enzyme_primal_return; +extern int enzyme_noret; + +template +Return __enzyme_autodiff(T...); + +template +Return __enzyme_fwddiff(T...); + +#include + +namespace enzyme { + + struct nodiff{}; + + template + struct ReverseMode { + + }; + using Reverse = ReverseMode; + using ReverseWithPrimal = ReverseMode; + + template < typename T > + struct Active{ + T value; + Active(T &&v) : value(v) {} + operator T&() { 
return value; } + }; + + template < typename T > + struct Duplicated{ + T value; + T shadow; + Duplicated(T &&v, T&& s) : value(v), shadow(s) {} + }; + + template < typename T > + struct Const{ + T value; + Const(T &&v) : value(v) {} + operator T&() { return value; } + }; + + template < typename T > + struct type_info { + static constexpr bool is_active = false; + using type = nodiff; + }; + + template < typename T > + struct type_info < Active >{ + static constexpr bool is_active = true; + using type = T; + }; + + template < typename ... T > + struct concatenated; + + template < typename ... S, typename T, typename ... rest > + struct concatenated < tuple < S ... >, T, rest ... > { + using type = typename concatenated< tuple< S ..., T>, rest ... >::type; + }; + + template < typename T > + struct concatenated < T > { + using type = T; + }; + + // Yikes! + // slightly cleaner in C++20, with std::remove_cvref + template < typename ... T > + struct autodiff_return; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple, + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type>; + }; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple< + typename type_info::type, + typename concatenated< tuple< >, + typename type_info< + typename remove_cvref< T >::type + >::type ... 
+ >::type + >; + }; + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Active & arg) { + return enzyme::tuple{enzyme_out, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Const & arg) { + return enzyme::tuple{enzyme_const, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Active & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Const & arg) { + return enzyme::tuple{arg.value}; + } + + namespace detail { + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(T &&t); + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple>{get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple{get<1>(t), get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))...)); + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence) { + return f(enzyme::get(impl::forward(t))...); + } + + template < typename T > + struct default_ret_activity { + using type = Const; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template <> + struct 
default_ret_activity { + using type = Active; + }; + + template < typename T > + struct ret_global; + + template + struct ret_global> { + static constexpr int* value = &enzyme_const_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_active_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_dup_return; + }; + + template + struct ret_used; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_primal_return; + }; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_noret; + }; + + } // namespace detail + + + + template < typename return_type, typename function, typename ... enz_arg_types > + __attribute__((always_inline)) + auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::primal_apply_impl(f, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename function, typename ... arg_types> + auto primal_call(function && f, arg_types && ... args) { + return primal_impl(impl::forward(f), enzyme::tuple_cat(primal_args(args)...)); + } + + template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types > + __attribute__((always_inline)) + auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::rev_apply_impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... 
args) { + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } + + template < typename DiffMode, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using primal_return_type = decltype(primal_call(impl::forward(f), impl::forward(args)...)); + using RetActivity = typename detail::default_ret_activity::type; + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } +} +}]>; + +def : Headers<"/enzymeroot/enzyme/type_traits", [{ +#pragma once + +#include + +namespace enzyme { + +// this is already in C++20, but we reimplement it here for older C++ versions +template < typename T > +struct remove_cvref { + using type = + typename std::remove_reference< + typename std::remove_cv< + T + >::type + >::type; +}; + +template < typename T > +using remove_cvref_t = typename remove_cvref::type; + +namespace impl { + template + __attribute__((always_inline)) + constexpr _Tp&& + forward(std::remove_reference_t<_Tp>& __t) noexcept + { return static_cast<_Tp&&>(__t); } + + /** + * @brief Forward an rvalue. + * @return The parameter cast to the specified type. + * + * This function is used to implement "perfect forwarding". + */ + template + __attribute__((always_inline)) + constexpr _Tp&& + forward(std::remove_reference_t<_Tp>&& __t) noexcept + { + static_assert(!std::is_lvalue_reference<_Tp>::value, + "enzyme::impl::forward must not be used to convert an rvalue to an lvalue"); + return static_cast<_Tp&&>(__t); + } + +} + +} +}]>; + +def : Headers<"/enzymeroot/enzyme/tuple", [{ +#pragma once + +///////////// +// tuple.h // +///////////// + +// why reinvent the wheel and implement a tuple class? 
+// - ensure data is laid out in the same order the types are specified +// see: https://github.com/EnzymeAD/Enzyme/issues/1191#issuecomment-1556239213 +// - CUDA compatibility: std::tuple has some compatibility issues when used +// in a __device__ context (this may get better in c++20 with the improved +// constexpr support for std::tuple). Owning the implementation lets +// us add __host__ __device__ annotations to any part of it + +#include // for std::integer_sequence + +#include + +#define _NOEXCEPT noexcept +namespace enzyme { + +template +struct Index {}; + +template +struct value_at_position { + __attribute__((always_inline)) + T & operator[](Index) { return value; } + + __attribute__((always_inline)) + constexpr const T & operator[](Index) const { return value; } + T value; +}; + +template +struct tuple_base; + +template +struct tuple_base, T...> + : public value_at_position... { + using value_at_position::operator[]...; +}; + +template +struct tuple : public tuple_base, T...> {}; + +template +__attribute__((always_inline)) +tuple(T ...) -> tuple; + +template < int i, typename Tuple > +__attribute__((always_inline)) +decltype(auto) get(Tuple && tup) { + constexpr bool is_lvalue = std::is_lvalue_reference_v; + constexpr bool is_const = std::is_const_v>; + using T = remove_cvref_t< decltype(tup[Index{ } ]) >; + if constexpr ( is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr ( is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } +} + +template < int i, typename ... T> +__attribute__((always_inline)) +decltype(auto) get(const tuple< T ... 
> & tup) { + return tup[Index{} ]; +} + +template +struct tuple_size; + +template +struct tuple_size> : std::integral_constant {}; + +template +static constexpr size_t tuple_size_v = tuple_size::value; + +template +__attribute__((always_inline)) +constexpr auto forward_as_tuple(T&&... args) noexcept { + return tuple{impl::forward(args)...}; +} + +namespace impl { + +template +struct make_tuple_from_fwd_tuple; + +template +struct make_tuple_from_fwd_tuple> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd) { + return tuple{get(impl::forward(fwd))...}; + } +}; + +template +struct concat_with_fwd_tuple; + +template < typename Tuple > +using iseq = std::make_index_sequence > >; + +template +struct concat_with_fwd_tuple, std::index_sequence> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) { + return forward_as_tuple(get(impl::forward(fwd))..., get(impl::forward(t))...); + } +}; + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(Tuple&& ret) { + return make_tuple_from_fwd_tuple< iseq< Tuple > >::f(impl::forward< Tuple >(ret)); +} + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(FWD_TUPLE&& fwd, first&& t, rest&&... ts) { + return tuple_cat(concat_with_fwd_tuple< iseq, iseq >::f(impl::forward(fwd), impl::forward(t)), impl::forward(ts)...); +} + +} // namespace impl + +template +__attribute__((always_inline)) +constexpr auto tuple_cat(Tuples&&... 
tuples) { + return impl::tuple_cat(impl::forward(tuples)...); +} + +} // namespace enzyme +#undef _NOEXCEPT +}]>; + +def : Headers<"/enzymeroot/enzyme/enzyme", [{ +#ifdef __cplusplus +#include "enzyme/utils" +#else +#warning "Enzyme wrapper templates only available in C++" +#endif +}]>; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 47e9f2bba91f..70f173f5734a 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -435,6 +435,9 @@ std::optional getMetadataName(llvm::Value *res) Optional getMetadataName(llvm::Value *res) #endif { + if (auto S = simplifyLoad(res)) + return getMetadataName(S); + if (auto av = dyn_cast(res)) { return cast(av->getMetadata())->getString(); } else if ((isa(res) || isa(res)) && @@ -463,12 +466,11 @@ Optional getMetadataName(llvm::Value *res) return gv->getName(); } else if (auto gv = dyn_cast(res)) { return gv->getName(); - } else { - if (isa(res)) { - return recursePhiReads(cast(res)); - } - return {}; + } else if (isa(res)) { + return recursePhiReads(cast(res)); } + + return {}; } static Value *adaptReturnedVector(Value *ret, Value *diffret, @@ -3197,6 +3199,7 @@ AnalysisKey EnzymeNewPM::Key; #include "PreserveNVVM.h" #include "TypeAnalysis/TypeAnalysisPrinter.h" #include "llvm/Passes/PassBuilder.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" #if LLVM_VERSION_MAJOR >= 15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "llvm/Transforms/IPO/CalledValuePropagation.h" @@ -3427,6 +3430,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { #else prePass(MPM); #endif + MPM.addPass(llvm::AlwaysInlinerPass()); FunctionPassManager OptimizerPM; FunctionPassManager OptimizerPM2; #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index ff44cbaa715d..283460673922 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2264,7 +2264,258 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, 
llvm_unreachable("unknown inst2"); } -Function *GetFunctionFromValue(Value *fn) { +// Find the base pointer of ptr and the offset in bytes from the start of +// the returned base pointer to this value. +AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) { + offset = 0; + while (true) { + if (auto CI = dyn_cast(ptr)) { + ptr = CI->getOperand(0); + continue; + } + if (auto CI = dyn_cast(ptr)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + return nullptr; + } + offset += Offset.getZExtValue(); + ptr = CI->getOperand(0); + continue; + } + if (isa(ptr)) { + break; + } + if (auto LI = dyn_cast(ptr)) { + if (auto S = simplifyLoad(LI)) { + ptr = S; + continue; + } + } + return nullptr; + } + return cast(ptr); +} + +// Find all user instructions of AI, returning tuples of (user instruction, +// pointer value it uses, byte offset of that value from AI). Unlike a simple get +// users, this will recurse through any constant gep offsets and casts +SmallVector, 1> +findAllUsersOf(Value *AI) { + SmallVector, 1> todo; + todo.emplace_back(AI, 0); + + SmallVector, 1> users; + while (todo.size()) { + auto pair = todo.pop_back_val(); + Value *ptr = pair.first; + size_t suboff = pair.second; + + for (auto U : ptr->users()) { + if (auto CI = dyn_cast(U)) { + todo.emplace_back(CI, suboff); + continue; + } + if (auto CI = dyn_cast(U)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + users.emplace_back(cast(U), ptr, suboff); + continue; + } + todo.emplace_back(CI, suboff + Offset.getZExtValue()); + continue; + } + 
users.emplace_back(cast(U), ptr, suboff); + continue; + } + } + return users; +} + +// Given a pointer, find all values of size `valSz` which could be loaded from +// that pointer when indexed at offset. If it is impossible to guarantee that +// the set contains all such values, set legal to false +SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, + size_t valSz, bool &legal) { + SmallVector options; + + auto todo = findAllUsersOf(ptr0); + std::set> seen; + + while (todo.size()) { + auto pair = todo.pop_back_val(); + if (seen.count(pair)) + continue; + seen.insert(pair); + Instruction *U = std::get<0>(pair); + Value *ptr = std::get<1>(pair); + size_t suboff = std::get<2>(pair); + + // Read only users do not set the memory inside of ptr + if (isa(U)) { + continue; + } + if (auto MTI = dyn_cast(U)) + if (MTI->getOperand(0) != ptr) { + continue; + } + if (auto I = dyn_cast(U)) { + if (!I->mayWriteToMemory() && I->getType()->isVoidTy()) + continue; + } + + if (auto SI = dyn_cast(U)) { + auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout(); + + // We are storing into the ptr + if (SI->getPointerOperand() == ptr) { + auto storeSz = + (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) / + 8; + // If the store ends before the load would start + if (storeSz + suboff <= offset) + continue; + // If the store starts at or after where the load would end + if (offset + valSz <= suboff) + continue; + + if (valSz == storeSz) { + options.push_back(SI->getValueOperand()); + continue; + } + } + + // We capture our pointer of interest: if it is stored into an alloca, + // any value loaded back from that alloca could be used to store into it. 
+ if (SI->getValueOperand() == ptr) { + if (suboff == 0) { + size_t mid_offset = 0; + if (auto AI2 = + getBaseAndOffset(SI->getPointerOperand(), mid_offset)) { + bool sublegal = true; + auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8; + auto subPtrs = + getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal); + if (!sublegal) { + legal = false; + return options; + } + for (auto subPtr : subPtrs) { + for (const auto &pair3 : findAllUsersOf(subPtr)) { + todo.emplace_back(pair3); + } + } + continue; + } + } + } + } + + if (auto II = dyn_cast(U)) { + if (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + } + + // If we copy into the ptr at a location that includes the offset, consider + // all sub uses + if (auto MTI = dyn_cast(U)) { + if (auto CI = dyn_cast(MTI->getLength())) { + if (MTI->getOperand(0) == ptr && suboff == 0 && + CI->getValue().uge(offset + valSz)) { + size_t midoffset = 0; + auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset); + if (!AI2) { + legal = false; + return options; + } + if (midoffset != 0) { + legal = false; + return options; + } + for (const auto &pair3 : findAllUsersOf(AI2)) { + todo.emplace_back(pair3); + } + continue; + } + } + } + + legal = false; + return options; + } + + return options; +} + +// Perform mem2reg/sroa to identify the innermost value being represented. 
+Value *simplifyLoad(Value *V, size_t valSz) { + if (auto LI = dyn_cast(V)) { + if (valSz == 0) { + auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(LI->getType()) + 7) / 8; + } + + Value *ptr = LI->getPointerOperand(); + size_t offset = 0; + + if (auto ptr2 = simplifyLoad(ptr)) { + ptr = ptr2; + } + auto AI = getBaseAndOffset(ptr, offset); + if (!AI) { + return nullptr; + } + + bool legal = true; + auto opts = getAllLoadedValuesFrom(AI, offset, valSz, legal); + + if (!legal) { + return nullptr; + } + std::set res; + for (auto opt : opts) { + Value *v2 = simplifyLoad(opt, valSz); + if (v2) + res.insert(v2); + else + res.insert(opt); + } + if (res.size() != 1) { + return nullptr; + } + Value *retval = *res.begin(); + return retval; + } + if (auto EVI = dyn_cast(V)) { + bool allZero = true; + for (auto idx : EVI->getIndices()) { + if (idx != 0) + allZero = false; + } + if (valSz == 0) { + auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(EVI->getType()) + 7) / 8; + } + if (allZero) + if (auto LI = dyn_cast(EVI->getAggregateOperand())) { + return simplifyLoad(LI, valSz); + } + } + return nullptr; +} + +Value *GetFunctionValFromValue(Value *fn) { while (!isa(fn)) { if (auto ci = dyn_cast(fn)) { fn = ci->getOperand(0); @@ -2294,6 +2545,7 @@ Function *GetFunctionFromValue(Value *fn) { } if (ret.size() == 1) { auto val = *ret.begin(); + val = GetFunctionValFromValue(val); if (isa(val)) { fn = val; continue; @@ -2315,6 +2567,14 @@ Function *GetFunctionFromValue(Value *fn) { } if (ret.size() == 1) { auto val = *ret.begin(); + while (isa(val)) { + auto v2 = simplifyLoad(val); + if (v2) { + val = v2; + continue; + } + break; + } if (isa(val)) { fn = val; continue; @@ -2326,73 +2586,18 @@ Function *GetFunctionFromValue(Value *fn) { } } } - if (auto LI = dyn_cast(fn)) { - auto obj = getBaseObject(LI->getPointerOperand()); - if (isa(obj)) { - std::set> done; - 
SmallVector, 1> todo; - Value *stored = nullptr; - bool legal = true; - for (auto U : obj->users()) { - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, obj)); - else { - legal = false; - break; - } - } - while (legal && todo.size()) { - auto tup = todo.pop_back_val(); - if (done.count(tup)) - continue; - done.insert(tup); - auto cur = tup.first; - auto prev = tup.second; - if (auto SI = dyn_cast(cur)) - if (SI->getPointerOperand() == prev) { - if (stored == SI->getValueOperand()) - continue; - else if (stored == nullptr) { - stored = SI->getValueOperand(); - continue; - } else { - legal = false; - break; - } - } - - if (isPointerArithmeticInst(cur, /*includephi*/ true)) { - for (auto U : cur->users()) { - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, cur)); - else { - legal = false; - break; - } - } - continue; - } - - if (isa(cur)) - continue; - - if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy()) - continue; - - legal = false; - break; - } - - if (legal && stored) { - fn = stored; - continue; - } - } + if (auto S = simplifyLoad(fn)) { + fn = S; + continue; } break; } - return dyn_cast(fn); + return fn; +} + +Function *GetFunctionFromValue(Value *fn) { + return dyn_cast(GetFunctionValFromValue(fn)); } #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 37edfc1a985d..5a4b3c31cef7 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1248,6 +1248,8 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Function *GetFunctionFromValue(llvm::Value *fn); +llvm::Value *simplifyLoad(llvm::Value *LI, size_t valSz = 0); + static inline bool shouldDisableNoWrite(const llvm::CallInst *CI) { auto F = getFunctionFromCall(CI); auto funcName = getFuncNameFromCall(CI); diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp new file mode 100644 index 000000000000..8524342e1bfc --- /dev/null +++ 
b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -0,0 +1,91 @@ +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -mllvm -print-before-all -mllvm -print-after-all -mllvm -print-module-scope -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + +#include "../test_utils.h" + +#include + +double foo(double x, double y) { return x * y; } + +double square(double x) { return x * x; } + +struct pair { + double x; + double y; +}; + +int main() { + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq2 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq3 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq3_prim = %f\n", prim); 
+ APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq4 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq4_prim = %f\n", prim); + APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + printf("dmul %f %f\n", y1, y2); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + auto prim = enzyme::get<1>(y); + printf("dmul2 %f %f\n", y1, y2); + printf("dmul_prim %f\n", prim); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + APPROX_EQ(prim, 2.7*3.1, 1e-10); + } + + { + auto &&[z1, z2] = __enzyme_autodiff((void*)foo, enzyme_out, 3.1, enzyme_out, 2.7); + printf("dmul2 %f %f\n", z1, z2); + APPROX_EQ(z1, 2.7, 1e-10); + APPROX_EQ(z2, 3.1, 1e-10); + } + +} diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 04f63e1ad3ac..143c85ea684e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -65,7 +65,9 @@ static cl::opt cl::values(clEnumValN(MLIRDerivatives, "gen-mlir-derivatives", "Generate MLIR derivative")), cl::values(clEnumValN(CallDerivatives, "gen-call-derivatives", - "Generate call derivative"))); + "Generate call derivative")), + cl::values(clEnumValN(GenHeaderVariables, "gen-header-strings", + "Generate header strings"))); void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, StringRef FT, StringRef cconv, Init *func, @@ -1248,6 +1250,24 @@ void printDiffUse( } } +static void 
emitHeaderIncludes(const RecordKeeper &recordKeeper, + raw_ostream &os) { + const auto &patterns = recordKeeper.getAllDerivedDefinitions("Headers"); + os << "const char* include_headers[][2] = {\n"; + bool seen = false; + for (Record *pattern : patterns) { + if (seen) + os << ",\n"; + auto filename = pattern->getValueAsString("filename"); + auto contents = pattern->getValueAsString("contents"); + os << "{\"" << filename << "\"\n,"; + os << "R\"(" << contents << ")\"\n"; + os << "}"; + seen = true; + } + os << "};\n"; +} + static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { emitSourceFileHeader("Rewriters", os); @@ -1268,6 +1288,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case BinopDerivatives: patternNames = "BinopPattern"; break; + case GenHeaderVariables: case GenBlasDerivatives: case UpdateBlasDecl: case UpdateBlasTA: @@ -1299,6 +1320,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case MLIRDerivatives: { auto opName = pattern->getValueAsString("opName"); @@ -2089,6 +2111,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDiffUse"); case CallDerivatives: patternNames = "CallPattern"; @@ -2127,6 +2150,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case CallDerivatives: { os << " if (("; @@ -2283,6 +2307,9 @@ static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { case UpdateBlasTA: emitBlasTAUpdater(records, os); return false; 
+ case GenHeaderVariables: + emitHeaderIncludes(records, os); + return false; default: errs() << "unknown tablegen action!\n"; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h index 368644ba0b5d..742a96d023ae 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h @@ -24,6 +24,7 @@ enum ActionType { UpdateBlasDecl, UpdateBlasTA, GenBlasDiffUse, + GenHeaderVariables, }; void emitDiffUse(const llvm::RecordKeeper &recordKeeper, llvm::raw_ostream &os,