diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index 3a41c7499768..1c99d219ce69 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -11,6 +11,13 @@ extern int enzyme_dupnoneed; extern int enzyme_out; extern int enzyme_const; +extern int enzyme_const_return; +extern int enzyme_active_return; +extern int enzyme_dup_return; + +extern int enzyme_primal_return; +extern int enzyme_noret; + template Return __enzyme_autodiff(T...); @@ -21,118 +28,237 @@ Return __enzyme_fwddiff(T...); namespace enzyme { - enum ReturnActivity{ - INACTIVE, - ACTIVE, - DUPLICATED - }; - struct nodiff{}; + template + struct ReverseMode { + + }; + using Reverse = ReverseMode; + using ReverseWithPrimal = ReverseMode; + template < typename T > - struct active{ + struct Active{ T value; + Active(T &&v) : value(v) {} operator T&() { return value; } }; template < typename T > - active(T) -> active; - - template < typename T > - struct duplicated{ + struct Duplicated{ T value; T shadow; + Duplicated(T &&v, T&& s) : value(v), shadow(s) {} }; template < typename T > - struct inactive{ + struct Const{ T value; + Const(T &&v) : value(v) {} + operator T&() { return value; } }; template < typename T > struct type_info { static constexpr bool is_active = false; - - #ifdef ENZYME_OMIT_INACTIVE - using type = tuple<>; - #else - using type = tuple; - #endif + using type = nodiff; }; template < typename T > - struct type_info < active >{ + struct type_info < Active >{ static constexpr bool is_active = true; - using type = tuple; + using type = T; }; template < typename ... T > struct concatenated; - template < typename ... S, typename ... T, typename ... rest > - struct concatenated < tuple < S ... >, tuple < T ... >, rest ... > { - using type = typename concatenated< tuple< S ..., T ... >, rest ... >::type; - }; - - template < typename ... T > - struct concatenated < tuple < T ... > > { - using type = tuple< T ... >; + template < typename ... S, typename T, typename ... rest > + struct concatenated < tuple < S ... >, T, rest ... > { + using type = typename concatenated< tuple< S ..., T>, rest ... >::type; }; template < typename T > - struct concatenated < tuple > { + struct concatenated < T > { using type = T; }; // Yikes! // slightly cleaner in C++20, with std::remove_cvref template < typename ... T > - struct autodiff_return { - using type = typename concatenated< + struct autodiff_return; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple, typename type_info< typename remove_cvref< T >::type >::type ... - >::type; + >::type>; + }; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple< + typename type_info::type, + typename concatenated< tuple< >, + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type + >; }; template < typename T > __attribute__((always_inline)) - auto expand_args(const enzyme::duplicated & arg) { + auto expand_args(const enzyme::Duplicated & arg) { return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; } template < typename T > __attribute__((always_inline)) - auto expand_args(const enzyme::active & arg) { + auto expand_args(const enzyme::Active & arg) { return enzyme::tuple{enzyme_out, arg.value}; } template < typename T > __attribute__((always_inline)) - auto expand_args(const enzyme::inactive & arg) { + auto expand_args(const enzyme::Const & arg) { return enzyme::tuple{enzyme_const, arg.value}; } + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Active & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Const & arg) { + return enzyme::tuple{arg.value}; + } + namespace detail { + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(T &&t); + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple>{get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple{get<1>(t), get<0>(t)}; + } + template __attribute__((always_inline)) - constexpr decltype(auto) apply_impl(void* f, Tuple&& t, std::index_sequence) { - return __enzyme_autodiff(f, enzyme::get(impl::forward(t))...); + constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))...)); } + + template + __attribute__((always_inline)) + constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence) { + return f(enzyme::get(impl::forward(t))...); + } + + template < typename T > + struct default_ret_activity { + using type = Const; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template < typename T > + struct ret_global; + + template + struct ret_global> { + static constexpr int* value = &enzyme_const_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_active_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_dup_return; + }; + + template + struct ret_used; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_primal_return; + }; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_noret; + }; + } // namespace detail + + template < typename return_type, typename function, typename ... enz_arg_types > __attribute__((always_inline)) - auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { using Tuple = enzyme::tuple< enz_arg_types ... >; - return detail::apply_impl((void*)f, impl::forward(arg_tup), std::make_index_sequence>{}); + return detail::primal_apply_impl(f, impl::forward(arg_tup), std::make_index_sequence>{}); } template < typename function, typename ... arg_types> + auto primal_call(function && f, arg_types && ... args) { + return primal_impl(impl::forward(f), enzyme::tuple_cat(primal_args(args)...)); + } + + template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types > + __attribute__((always_inline)) + auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::rev_apply_impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } + + template < typename DiffMode, typename function, typename ... arg_types> __attribute__((always_inline)) auto autodiff(function && f, arg_types && ... args) { - using return_type = typename autodiff_return::type; - return autodiff_impl(impl::forward(f), enzyme::tuple_cat(expand_args(args)...)); + using primal_return_type = decltype(primal_call(impl::forward(f), impl::forward(args)...)); + using RetActivity = typename detail::default_ret_activity::type; + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); } } }]>; @@ -162,7 +288,7 @@ namespace impl { template __attribute__((always_inline)) constexpr _Tp&& - forward(typename std::remove_reference<_Tp>::type& __t) noexcept + forward(std::remove_reference_t<_Tp>& __t) noexcept { return static_cast<_Tp&&>(__t); } /** @@ -174,12 +300,13 @@ namespace impl { template __attribute__((always_inline)) constexpr _Tp&& - forward(typename std::remove_reference<_Tp>::type&& __t) noexcept + forward(std::remove_reference_t<_Tp>&& __t) noexcept { static_assert(!std::is_lvalue_reference<_Tp>::value, "enzyme::impl::forward must not be used to convert an rvalue to an lvalue"); return static_cast<_Tp&&>(__t); } + } } @@ -263,23 +390,10 @@ struct tuple_size> : std::integral_constant {} template static constexpr size_t tuple_size_v = tuple_size::value; -template -__attribute__((always_inline)) -auto forward(std::remove_reference_t& arg) _NOEXCEPT { - return static_cast(arg); -} - -template -__attribute__((always_inline)) -auto forward(std::remove_reference_t&& arg) _NOEXCEPT { - static_assert(!std::is_lvalue_reference::value, "cannot forward an rvalue as an lvalue"); - return static_cast(arg); -} - template __attribute__((always_inline)) constexpr auto forward_as_tuple(T&&... args) noexcept { - return tuple{forward(args)...}; + return tuple{impl::forward(args)...}; } namespace impl { diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index e0856d6fc80a..bfe99d8bb14f 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -470,8 +470,6 @@ Optional getMetadataName(llvm::Value *res) return recursePhiReads(cast(res)); } - llvm::errs() << " failed to simplify: " << *res << "\n"; - return {}; } @@ -2848,7 +2846,6 @@ class EnzymeBase { } bool run(Module &M) { - llvm::errs() << M << "\n"; Logic.clear(); bool changed = false; diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index d0971656d80b..84b8b5b9540a 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp @@ -780,7 +780,6 @@ extern "C" void AddPreserveNVVMPass(LLVMPassManagerRef PM, uint8_t Begin) { PreserveNVVMNewPM::Result PreserveNVVMNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { - llvm::errs() << " PNPM: " << M << "\n"; bool changed = false; for (auto &F : M) changed |= preserveNVVM(Begin, F); diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp index 63a41f658dfb..8524342e1bfc 100644 --- a/enzyme/test/Integration/ReverseMode/sugar.cpp +++ b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -13,28 +13,79 @@ double foo(double x, double y) { return x * y; } +double square(double x) { return x * x; } + struct pair { double x; double y; }; int main() { + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq2 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq3 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq3_prim = %f\n", prim); + APPROX_EQ(prim, 3.1*3.1, 1e-10); + } - enzyme::autodiff_return< enzyme::active&& >::type q1; - double mo = q1; + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq4 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq4_prim = %f\n", prim); + APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + printf("dmul %f %f\n", y1, y2); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + } - enzyme::active x1{3.1}; - enzyme::active x2{2.7}; - auto y = enzyme::autodiff(foo, x1, x2); - auto y1 = enzyme::get<0>(y); - auto y2 = enzyme::get<1>(y); - printf("%f %f\n", y1, y2); + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + auto prim = enzyme::get<1>(y); + printf("dmul2 %f %f\n", y1, y2); + printf("dmul_prim %f\n", prim); APPROX_EQ(y1, 2.7, 1e-10); APPROX_EQ(y2, 3.1, 1e-10); + APPROX_EQ(prim, 2.7*3.1, 1e-10); + } - auto &&[z1, z2] = __enzyme_autodiff((void*)foo, enzyme_out, x1.value, enzyme_out, x2.value); - printf("%f %f\n", z1, z2); + { + auto &&[z1, z2] = __enzyme_autodiff((void*)foo, enzyme_out, 3.1, enzyme_out, 2.7); + printf("dmul2 %f %f\n", z1, z2); APPROX_EQ(z1, 2.7, 1e-10); APPROX_EQ(z2, 3.1, 1e-10); + } }