Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add forward mode c++ syntax #1776

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 97 additions & 14 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ namespace enzyme {
};
using Reverse = ReverseMode<false>;
using ReverseWithPrimal = ReverseMode<true>;

struct ForwardMode {

};
using Forward = ForwardMode;

template < typename T >
struct Active{
Expand All @@ -51,6 +56,13 @@ namespace enzyme {
Duplicated(T &&v, T&& s) : value(v), shadow(s) {}
};

template < typename T >
struct DuplicatedNoNeed{
T value;
T shadow;
DuplicatedNoNeed(T &&v, T&& s) : value(v), shadow(s) {}
};

template < typename T >
struct Const{
T value;
Expand Down Expand Up @@ -110,13 +122,37 @@ namespace enzyme {
>::type
>;
};

template < typename T0, typename ... T >
struct autodiff_return<ForwardMode, Const<T0>, T...>
{
using type = tuple<T0>;
};

template < typename T0, typename ... T >
struct autodiff_return<ForwardMode, Duplicated<T0>, T...>
{
using type = tuple<T0, T0>;
};

template < typename T0, typename ... T >
struct autodiff_return<ForwardMode, DuplicatedNoNeed<T0>, T...>
{
using type = tuple<T0>;
};

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple<int, T, T>{enzyme_dup, arg.value, arg.shadow};
}

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple<int, T, T>{enzyme_dupnoneed, arg.value, arg.shadow};
}

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::Active<T> & arg) {
Expand All @@ -135,6 +171,12 @@ namespace enzyme {
return enzyme::tuple<T>{arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple<T>{arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Active<T> & arg) {
Expand Down Expand Up @@ -164,32 +206,57 @@ namespace enzyme {
return tuple{get<1>(t), get<0>(t)};
}

template <typename Mode>
struct autodiff_apply {};

template <bool Mode>
struct autodiff_apply<ReverseMode<Mode>> {
template <class return_type, class Tuple, std::size_t... I>
__attribute__((always_inline))
constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...));
}
};

template <>
struct autodiff_apply<ForwardMode> {
template <class return_type, class Tuple, std::size_t... I>
__attribute__((always_inline))
static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return __enzyme_fwddiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...);
}
};

template <typename function, class Tuple, std::size_t... I>
__attribute__((always_inline))
constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence<I...>) {
return f(enzyme::get<I>(impl::forward<Tuple>(t))...);
}

template < typename T >
template < typename Mode, typename T >
struct default_ret_activity {
using type = Const<T>;
};

template <>
struct default_ret_activity<float> {
template <bool prim>
struct default_ret_activity<ReverseMode<prim>, float> {
using type = Active<float>;
};

template <>
struct default_ret_activity<double> {
template <bool prim>
struct default_ret_activity<ReverseMode<prim>, double> {
using type = Active<double>;
};

template<>
struct default_ret_activity<ForwardMode, float> {
using type = DuplicatedNoNeed<float>;
};

template<>
struct default_ret_activity<ForwardMode, double> {
using type = DuplicatedNoNeed<double>;
};

template < typename T >
struct ret_global;
Expand All @@ -209,6 +276,11 @@ namespace enzyme {
static constexpr int* value = &enzyme_dup_return;
};

template<typename T>
struct ret_global<DuplicatedNoNeed<T>> {
static constexpr int* value = &enzyme_dup_return;
};

template<typename Mode, typename RetAct>
struct ret_used;

Expand All @@ -222,9 +294,20 @@ namespace enzyme {
static constexpr int* value = &enzyme_noret;
};

} // namespace detail

template<typename T>
struct ret_used<ForwardMode, DuplicatedNoNeed<T>> {
static constexpr int* value = &enzyme_noret;
};
template<typename T>
struct ret_used<ForwardMode, Const<T>> {
static constexpr int* value = &enzyme_primal_return;
};
template<typename T>
struct ret_used<ForwardMode, Duplicated<T>> {
static constexpr int* value = &enzyme_primal_return;
};

} // namespace detail

template < typename return_type, typename function, typename ... enz_arg_types >
__attribute__((always_inline))
Expand All @@ -238,27 +321,27 @@ namespace enzyme {
return primal_impl<function>(impl::forward<function>(f), enzyme::tuple_cat(primal_args(args)...));
}

template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types >
template < typename return_type, typename DiffMode, typename function, typename RetActivity, typename ... enz_arg_types >
__attribute__((always_inline))
auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::rev_apply_impl<return_type>((void*)f, detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)f, detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return rev_autodiff_impl<return_type, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}

template < typename DiffMode, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using primal_return_type = decltype(primal_call<function, arg_types...>(impl::forward<function>(f), impl::forward<arg_types>(args)...));
using RetActivity = typename detail::default_ret_activity<primal_return_type>::type;
using RetActivity = typename detail::default_ret_activity<DiffMode, primal_return_type>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return rev_autodiff_impl<return_type, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}
}
}]>;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Integration/ForwardMode/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Run regression and unit tests
add_lit_testsuite(check-enzyme-integration-forward "Running enzyme forward mode integration tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${ENZYME_TEST_DEPS}
DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR}
ARGS -v
)

Expand Down
56 changes: 56 additions & 0 deletions enzyme/test/Integration/ForwardMode/sugar.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi

#include "../test_utils.h"

#include <enzyme/enzyme>

double foo(double x, double y) { return x * y; }

double square(double x) { return x * x; }

struct pair {
double x;
double y;
};

int main() {

{
enzyme::tuple< double, double > dsq = enzyme::autodiff<enzyme::Forward, enzyme::Duplicated<double>>(square, enzyme::Duplicated<double>(3.1, 1.0));
double dd = enzyme::get<1>(dsq);
printf("dsq = %f\n", dd);
APPROX_EQ(dd, 3.1*2, 1e-10);

double pp = enzyme::get<0>(dsq);
printf("sq = %f\n", pp);
APPROX_EQ(dd, 3.1*2, 1e-10);
}

{
enzyme::tuple< double > dsq = enzyme::autodiff<enzyme::Forward, enzyme::DuplicatedNoNeed<double>>(square, enzyme::Duplicated<double>(3.1, 1.0));
double dd = enzyme::get<0>(dsq);
printf("dsq = %f\n", dd);
APPROX_EQ(dd, 3.1*2, 1e-10);
}

{
enzyme::tuple< double > dsq = enzyme::autodiff<enzyme::Forward>(square, enzyme::Duplicated<double>(3.1, 1.0));
double dd = enzyme::get<0>(dsq);
printf("dsq = %f\n", dd);
APPROX_EQ(dd, 3.1*2, 1e-10);
}

{
enzyme::tuple< double > dsq = enzyme::autodiff<enzyme::Forward, enzyme::Const<double>>(square, enzyme::Duplicated<double>(3.1, 1.0));
double pp = enzyme::get<0>(dsq);
printf("sq = %f\n", pp);
APPROX_EQ(pp, 3.1*3.1, 1e-10);
}
}
2 changes: 1 addition & 1 deletion enzyme/test/Integration/ReverseMode/sugar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -mllvm -print-before-all -mllvm -print-after-all -mllvm -print-module-scope -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
Expand Down
Loading