From 4573ad3452833af4dcc49fb3751657aaaa9ce33a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 4 Mar 2024 20:06:36 -0500 Subject: [PATCH] Add forward mode c++ syntax --- enzyme/Enzyme/Clang/include_utils.td | 111 +++++++++++++++--- .../Integration/ForwardMode/CMakeLists.txt | 2 +- enzyme/test/Integration/ForwardMode/sugar.cpp | 56 +++++++++ enzyme/test/Integration/ReverseMode/sugar.cpp | 2 +- 4 files changed, 155 insertions(+), 16 deletions(-) create mode 100644 enzyme/test/Integration/ForwardMode/sugar.cpp diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index cb7cdd839c20..cce7c3488390 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -36,6 +36,11 @@ namespace enzyme { }; using Reverse = ReverseMode; using ReverseWithPrimal = ReverseMode; + + struct ForwardMode { + + }; + using Forward = ForwardMode; template < typename T > struct Active{ @@ -51,6 +56,13 @@ namespace enzyme { Duplicated(T &&v, T&& s) : value(v), shadow(s) {} }; + template < typename T > + struct DuplicatedNoNeed{ + T value; + T shadow; + DuplicatedNoNeed(T &&v, T&& s) : value(v), shadow(s) {} + }; + template < typename T > struct Const{ T value; @@ -110,6 +122,24 @@ namespace enzyme { >::type >; }; + + template < typename T0, typename ... T > + struct autodiff_return, T...> + { + using type = tuple; + }; + + template < typename T0, typename ... T > + struct autodiff_return, T...> + { + using type = tuple; + }; + + template < typename T0, typename ... T > + struct autodiff_return, T...> + { + using type = tuple; + }; template < typename T > __attribute__((always_inline)) @@ -117,6 +147,12 @@ namespace enzyme { return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; } + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::DuplicatedNoNeed & arg) { + return enzyme::tuple{enzyme_dupnoneed, arg.value, arg.shadow}; + } + template < typename T > __attribute__((always_inline)) auto expand_args(const enzyme::Active & arg) { @@ -135,6 +171,12 @@ namespace enzyme { return enzyme::tuple{arg.value}; } + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::DuplicatedNoNeed & arg) { + return enzyme::tuple{arg.value}; + } + template < typename T > __attribute__((always_inline)) auto primal_args(const enzyme::Active & arg) { @@ -164,11 +206,26 @@ namespace enzyme { return tuple{get<1>(t), get<0>(t)}; } + template + struct autodiff_apply {}; + + template + struct autodiff_apply> { template __attribute__((always_inline)) - constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))...)); } + }; + + template <> + struct autodiff_apply { + template + __attribute__((always_inline)) + static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + return __enzyme_fwddiff(f, ret_attr, enzyme::get(impl::forward(t))...); + } + }; template __attribute__((always_inline)) @@ -176,20 +233,30 @@ namespace enzyme { return f(enzyme::get(impl::forward(t))...); } - template < typename T > + template < typename Mode, typename T > struct default_ret_activity { using type = Const; }; - template <> - struct default_ret_activity { + template + struct default_ret_activity, float> { using type = Active; }; - template <> - struct default_ret_activity { + template + struct default_ret_activity, double> { using type = Active; }; + + template<> + struct default_ret_activity { + using type = DuplicatedNoNeed; + }; + + template<> + struct default_ret_activity { + using type = DuplicatedNoNeed; + }; template < typename T > struct ret_global; @@ -209,6 +276,11 @@ namespace enzyme { static constexpr int* value = &enzyme_dup_return; }; + template + struct ret_global> { + static constexpr int* value = &enzyme_dup_return; + }; + template struct ret_used; @@ -222,9 +294,20 @@ namespace enzyme { static constexpr int* value = &enzyme_noret; }; - } // namespace detail - + template + struct ret_used> { + static constexpr int* value = &enzyme_noret; + }; + template + struct ret_used> { + static constexpr int* value = &enzyme_primal_return; + }; + template + struct ret_used> { + static constexpr int* value = &enzyme_primal_return; + }; + } // namespace detail template < typename return_type, typename function, typename ... enz_arg_types > __attribute__((always_inline)) @@ -238,27 +321,27 @@ namespace enzyme { return primal_impl(impl::forward(f), enzyme::tuple_cat(primal_args(args)...)); } - template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types > + template < typename return_type, typename DiffMode, typename function, typename RetActivity, typename ... enz_arg_types > __attribute__((always_inline)) - auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { using Tuple = enzyme::tuple< enz_arg_types ... >; - return detail::rev_apply_impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); + return detail::autodiff_apply::template impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); } template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> __attribute__((always_inline)) auto autodiff(function && f, arg_types && ... args) { using return_type = typename autodiff_return::type; - return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + return autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); } template < typename DiffMode, typename function, typename ... arg_types> __attribute__((always_inline)) auto autodiff(function && f, arg_types && ... args) { using primal_return_type = decltype(primal_call(impl::forward(f), impl::forward(args)...)); - using RetActivity = typename detail::default_ret_activity::type; + using RetActivity = typename detail::default_ret_activity::type; using return_type = typename autodiff_return::type; - return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + return autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); } } }]>; diff --git a/enzyme/test/Integration/ForwardMode/CMakeLists.txt b/enzyme/test/Integration/ForwardMode/CMakeLists.txt index 11d5d7a53b73..457f96782e40 100644 --- a/enzyme/test/Integration/ForwardMode/CMakeLists.txt +++ b/enzyme/test/Integration/ForwardMode/CMakeLists.txt @@ -1,7 +1,7 @@ # Run regression and unit tests add_lit_testsuite(check-enzyme-integration-forward "Running enzyme forward mode integration tests" ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${ENZYME_TEST_DEPS} + DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR} ARGS -v ) diff --git a/enzyme/test/Integration/ForwardMode/sugar.cpp b/enzyme/test/Integration/ForwardMode/sugar.cpp new file mode 100644 index 000000000000..d37187b5922e --- /dev/null +++ b/enzyme/test/Integration/ForwardMode/sugar.cpp @@ -0,0 +1,56 @@ +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + +#include "../test_utils.h" + +#include + +double foo(double x, double y) { return x * y; } + +double square(double x) { return x * x; } + +struct pair { + double x; + double y; +}; + +int main() { + + { + enzyme::tuple< double, double > dsq = enzyme::autodiff>(square, enzyme::Duplicated(3.1, 1.0)); + double dd = enzyme::get<1>(dsq); + printf("dsq = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + + double pp = enzyme::get<0>(dsq); + printf("sq = %f\n", pp); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::tuple< double > dsq = enzyme::autodiff>(square, enzyme::Duplicated(3.1, 1.0)); + double dd = enzyme::get<0>(dsq); + printf("dsq = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::tuple< double > dsq = enzyme::autodiff(square, enzyme::Duplicated(3.1, 1.0)); + double dd = enzyme::get<0>(dsq); + printf("dsq = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::tuple< double > dsq = enzyme::autodiff>(square, enzyme::Duplicated(3.1, 1.0)); + double pp = enzyme::get<0>(dsq); + printf("sq = %f\n", pp); + APPROX_EQ(pp, 3.1*3.1, 1e-10); + } +} diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp index 8524342e1bfc..a57a994e1ca7 100644 --- a/enzyme/test/Integration/ReverseMode/sugar.cpp +++ b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -2,7 +2,7 @@ // RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -mllvm -print-before-all -mllvm -print-after-all -mllvm -print-module-scope -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi