Skip to content

Commit

Permalink
Add forward mode c++ syntax (#1776)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Mar 6, 2024
1 parent 0b62188 commit 7e41c58
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 16 deletions.
111 changes: 97 additions & 14 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ namespace enzyme {
};
using Reverse = ReverseMode<false>;
using ReverseWithPrimal = ReverseMode<true>;

struct ForwardMode {

};
using Forward = ForwardMode;

template < typename T >
struct Active{
Expand All @@ -51,6 +56,13 @@ namespace enzyme {
Duplicated(T &&v, T&& s) : value(v), shadow(s) {}
};

template < typename T >
struct DuplicatedNoNeed{
T value;
T shadow;
DuplicatedNoNeed(T &&v, T&& s) : value(v), shadow(s) {}
};

template < typename T >
struct Const{
T value;
Expand Down Expand Up @@ -110,13 +122,37 @@ namespace enzyme {
>::type
>;
};

template < typename T0, typename ... T >
struct autodiff_return<ForwardMode, Const<T0>, T...>
{
using type = tuple<T0>;
};

template < typename T0, typename ... T >
struct autodiff_return<ForwardMode, Duplicated<T0>, T...>
{
using type = tuple<T0, T0>;
};

template < typename T0, typename ... T >
struct autodiff_return<ForwardMode, DuplicatedNoNeed<T0>, T...>
{
using type = tuple<T0>;
};

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple<int, T, T>{enzyme_dup, arg.value, arg.shadow};
}

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple<int, T, T>{enzyme_dupnoneed, arg.value, arg.shadow};
}

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::Active<T> & arg) {
Expand All @@ -135,6 +171,12 @@ namespace enzyme {
return enzyme::tuple<T>{arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple<T>{arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Active<T> & arg) {
Expand Down Expand Up @@ -164,32 +206,57 @@ namespace enzyme {
return tuple{get<1>(t), get<0>(t)};
}

template <typename Mode>
struct autodiff_apply {};

template <bool Mode>
struct autodiff_apply<ReverseMode<Mode>> {
template <class return_type, class Tuple, std::size_t... I>
__attribute__((always_inline))
constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...));
}
};

template <>
struct autodiff_apply<ForwardMode> {
template <class return_type, class Tuple, std::size_t... I>
__attribute__((always_inline))
static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return __enzyme_fwddiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...);
}
};

template <typename function, class Tuple, std::size_t... I>
__attribute__((always_inline))
constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence<I...>) {
return f(enzyme::get<I>(impl::forward<Tuple>(t))...);
}

template < typename T >
template < typename Mode, typename T >
struct default_ret_activity {
using type = Const<T>;
};

template <>
struct default_ret_activity<float> {
template <bool prim>
struct default_ret_activity<ReverseMode<prim>, float> {
using type = Active<float>;
};

template <>
struct default_ret_activity<double> {
template <bool prim>
struct default_ret_activity<ReverseMode<prim>, double> {
using type = Active<double>;
};

template<>
struct default_ret_activity<ForwardMode, float> {
using type = DuplicatedNoNeed<float>;
};

template<>
struct default_ret_activity<ForwardMode, double> {
using type = DuplicatedNoNeed<double>;
};

template < typename T >
struct ret_global;
Expand All @@ -209,6 +276,11 @@ namespace enzyme {
static constexpr int* value = &enzyme_dup_return;
};

template<typename T>
struct ret_global<DuplicatedNoNeed<T>> {
static constexpr int* value = &enzyme_dup_return;
};

template<typename Mode, typename RetAct>
struct ret_used;

Expand All @@ -222,9 +294,20 @@ namespace enzyme {
static constexpr int* value = &enzyme_noret;
};

} // namespace detail

template<typename T>
struct ret_used<ForwardMode, DuplicatedNoNeed<T>> {
static constexpr int* value = &enzyme_noret;
};
template<typename T>
struct ret_used<ForwardMode, Const<T>> {
static constexpr int* value = &enzyme_primal_return;
};
template<typename T>
struct ret_used<ForwardMode, Duplicated<T>> {
static constexpr int* value = &enzyme_primal_return;
};

} // namespace detail

template < typename return_type, typename function, typename ... enz_arg_types >
__attribute__((always_inline))
Expand All @@ -238,27 +321,27 @@ namespace enzyme {
return primal_impl<function>(impl::forward<function>(f), enzyme::tuple_cat(primal_args(args)...));
}

template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types >
template < typename return_type, typename DiffMode, typename function, typename RetActivity, typename ... enz_arg_types >
__attribute__((always_inline))
auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::rev_apply_impl<return_type>((void*)f, detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)f, detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return rev_autodiff_impl<return_type, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}

template < typename DiffMode, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using primal_return_type = decltype(primal_call<function, arg_types...>(impl::forward<function>(f), impl::forward<arg_types>(args)...));
using RetActivity = typename detail::default_ret_activity<primal_return_type>::type;
using RetActivity = typename detail::default_ret_activity<DiffMode, primal_return_type>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return rev_autodiff_impl<return_type, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}
}
}]>;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Integration/ForwardMode/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Run regression and unit tests
add_lit_testsuite(check-enzyme-integration-forward "Running enzyme forward mode integration tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${ENZYME_TEST_DEPS}
DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR}
ARGS -v
)

Expand Down
56 changes: 56 additions & 0 deletions enzyme/test/Integration/ForwardMode/sugar.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi

#include "../test_utils.h"

#include <enzyme/enzyme>

double foo(double x, double y) { return x * y; }

double square(double x) { return x * x; }

struct pair {
double x;
double y;
};

int main() {

{
enzyme::tuple< double, double > dsq = enzyme::autodiff<enzyme::Forward, enzyme::Duplicated<double>>(square, enzyme::Duplicated<double>(3.1, 1.0));
double dd = enzyme::get<1>(dsq);
printf("dsq = %f\n", dd);
APPROX_EQ(dd, 3.1*2, 1e-10);

double pp = enzyme::get<0>(dsq);
printf("sq = %f\n", pp);
APPROX_EQ(dd, 3.1*2, 1e-10);
}

{
enzyme::tuple< double > dsq = enzyme::autodiff<enzyme::Forward, enzyme::DuplicatedNoNeed<double>>(square, enzyme::Duplicated<double>(3.1, 1.0));
double dd = enzyme::get<0>(dsq);
printf("dsq = %f\n", dd);
APPROX_EQ(dd, 3.1*2, 1e-10);
}

{
enzyme::tuple< double > dsq = enzyme::autodiff<enzyme::Forward>(square, enzyme::Duplicated<double>(3.1, 1.0));
double dd = enzyme::get<0>(dsq);
printf("dsq = %f\n", dd);
APPROX_EQ(dd, 3.1*2, 1e-10);
}

{
enzyme::tuple< double > dsq = enzyme::autodiff<enzyme::Forward, enzyme::Const<double>>(square, enzyme::Duplicated<double>(3.1, 1.0));
double pp = enzyme::get<0>(dsq);
printf("sq = %f\n", pp);
APPROX_EQ(pp, 3.1*3.1, 1e-10);
}
}
2 changes: 1 addition & 1 deletion enzyme/test/Integration/ReverseMode/sugar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -mllvm -print-before-all -mllvm -print-after-all -mllvm -print-module-scope -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
Expand Down

0 comments on commit 7e41c58

Please sign in to comment.