diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index cb7cdd839c20..cce7c3488390 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -36,6 +36,11 @@ namespace enzyme { }; using Reverse = ReverseMode; using ReverseWithPrimal = ReverseMode; + + struct ForwardMode { + + }; + using Forward = ForwardMode; template < typename T > struct Active{ @@ -51,6 +56,13 @@ namespace enzyme { Duplicated(T &&v, T&& s) : value(v), shadow(s) {} }; + template < typename T > + struct DuplicatedNoNeed{ + T value; + T shadow; + DuplicatedNoNeed(T &&v, T&& s) : value(v), shadow(s) {} + }; + template < typename T > struct Const{ T value; @@ -110,6 +122,24 @@ namespace enzyme { >::type >; }; + + template < typename T0, typename ... T > + struct autodiff_return, T...> + { + using type = tuple; + }; + + template < typename T0, typename ... T > + struct autodiff_return, T...> + { + using type = tuple; + }; + + template < typename T0, typename ... T > + struct autodiff_return, T...> + { + using type = tuple; + }; template < typename T > __attribute__((always_inline)) @@ -117,6 +147,12 @@ namespace enzyme { return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; } + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::DuplicatedNoNeed & arg) { + return enzyme::tuple{enzyme_dupnoneed, arg.value, arg.shadow}; + } + template < typename T > __attribute__((always_inline)) auto expand_args(const enzyme::Active & arg) { @@ -135,6 +171,12 @@ namespace enzyme { return enzyme::tuple{arg.value}; } + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::DuplicatedNoNeed & arg) { + return enzyme::tuple{arg.value}; + } + template < typename T > __attribute__((always_inline)) auto primal_args(const enzyme::Active & arg) { @@ -164,11 +206,26 @@ namespace enzyme { return tuple{get<1>(t), get<0>(t)}; } + template + struct autodiff_apply {}; + + template + struct autodiff_apply> { template __attribute__((always_inline)) - constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))...)); } + }; + + template <> + struct autodiff_apply { + template + __attribute__((always_inline)) + static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + return __enzyme_fwddiff(f, ret_attr, enzyme::get(impl::forward(t))...); + } + }; template __attribute__((always_inline)) @@ -176,20 +233,30 @@ namespace enzyme { return f(enzyme::get(impl::forward(t))...); } - template < typename T > + template < typename Mode, typename T > struct default_ret_activity { using type = Const; }; - template <> - struct default_ret_activity { + template + struct default_ret_activity, float> { using type = Active; }; - template <> - struct default_ret_activity { + template + struct default_ret_activity, double> { using type = Active; }; + + template<> + struct default_ret_activity { + using type = DuplicatedNoNeed; + }; + + template<> + struct default_ret_activity { + using type = DuplicatedNoNeed; + }; template < typename T > struct ret_global; @@ -209,6 +276,11 @@ namespace enzyme { static constexpr int* value = &enzyme_dup_return; }; + template + struct ret_global> { + static constexpr int* value = &enzyme_dup_return; + }; + template struct ret_used; @@ -222,9 +294,20 @@ namespace enzyme { static constexpr int* value = &enzyme_noret; }; - } // namespace detail - + template + struct ret_used> { + static constexpr int* value = &enzyme_noret; + }; + template + struct ret_used> { + static constexpr int* value = &enzyme_primal_return; + }; + template + struct ret_used> { + static constexpr int* value = &enzyme_primal_return; + }; + } // namespace detail template < typename return_type, typename function, typename ... enz_arg_types > __attribute__((always_inline)) @@ -238,27 +321,27 @@ namespace enzyme { return primal_impl(impl::forward(f), enzyme::tuple_cat(primal_args(args)...)); } - template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types > + template < typename return_type, typename DiffMode, typename function, typename RetActivity, typename ... enz_arg_types > __attribute__((always_inline)) - auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { using Tuple = enzyme::tuple< enz_arg_types ... >; - return detail::rev_apply_impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); + return detail::autodiff_apply::template impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); } template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> __attribute__((always_inline)) auto autodiff(function && f, arg_types && ... args) { using return_type = typename autodiff_return::type; - return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + return autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); } template < typename DiffMode, typename function, typename ... arg_types> __attribute__((always_inline)) auto autodiff(function && f, arg_types && ... args) { using primal_return_type = decltype(primal_call(impl::forward(f), impl::forward(args)...)); - using RetActivity = typename detail::default_ret_activity::type; + using RetActivity = typename detail::default_ret_activity::type; using return_type = typename autodiff_return::type; - return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + return autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); } } }]>; diff --git a/enzyme/test/Integration/ForwardMode/CMakeLists.txt b/enzyme/test/Integration/ForwardMode/CMakeLists.txt index 11d5d7a53b73..457f96782e40 100644 --- a/enzyme/test/Integration/ForwardMode/CMakeLists.txt +++ b/enzyme/test/Integration/ForwardMode/CMakeLists.txt @@ -1,7 +1,7 @@ # Run regression and unit tests add_lit_testsuite(check-enzyme-integration-forward "Running enzyme forward mode integration tests" ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${ENZYME_TEST_DEPS} + DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR} ARGS -v ) diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp index 8524342e1bfc..a57a994e1ca7 100644 --- a/enzyme/test/Integration/ReverseMode/sugar.cpp +++ b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -2,7 +2,7 @@ // RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -mllvm -print-before-all -mllvm -print-after-all -mllvm -print-module-scope -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi