Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 27, 2024
1 parent 11682bb commit dbd38e7
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 71 deletions.
228 changes: 171 additions & 57 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ extern int enzyme_dupnoneed;
extern int enzyme_out;
extern int enzyme_const;

extern int enzyme_const_return;
extern int enzyme_active_return;
extern int enzyme_dup_return;

extern int enzyme_primal_return;
extern int enzyme_noret;

template<typename Return, typename... T>
Return __enzyme_autodiff(T...);

Expand All @@ -21,118 +28,237 @@ Return __enzyme_fwddiff(T...);

namespace enzyme {

enum ReturnActivity{
INACTIVE,
ACTIVE,
DUPLICATED
};

struct nodiff{};

template<bool ReturnPrimal = false>
struct ReverseMode {

};
using Reverse = ReverseMode<false>;
using ReverseWithPrimal = ReverseMode<true>;

template < typename T >
struct active{
struct Active{
T value;
Active(T &&v) : value(v) {}
operator T&() { return value; }
};

template < typename T >
active(T) -> active<T>;

template < typename T >
struct duplicated{
struct Duplicated{
T value;
T shadow;
Duplicated(T &&v, T&& s) : value(v), shadow(s) {}
};

template < typename T >
struct inactive{
struct Const{
T value;
Const(T &&v) : value(v) {}
operator T&() { return value; }
};

template < typename T >
struct type_info {
static constexpr bool is_active = false;

#ifdef ENZYME_OMIT_INACTIVE
using type = tuple<>;
#else
using type = tuple<nodiff>;
#endif
using type = nodiff;
};

template < typename T >
struct type_info < active<T> >{
struct type_info < Active<T> >{
static constexpr bool is_active = true;
using type = tuple<T>;
using type = T;
};

template < typename ... T >
struct concatenated;

template < typename ... S, typename ... T, typename ... rest >
struct concatenated < tuple < S ... >, tuple < T ... >, rest ... > {
using type = typename concatenated< tuple< S ..., T ... >, rest ... >::type;
};

template < typename ... T >
struct concatenated < tuple < T ... > > {
using type = tuple< T ... >;
template < typename ... S, typename T, typename ... rest >
struct concatenated < tuple < S ... >, T, rest ... > {
using type = typename concatenated< tuple< S ..., T>, rest ... >::type;
};

template < typename T >
struct concatenated < tuple<T> > {
struct concatenated < T > {
using type = T;
};

// Yikes!
// slightly cleaner in C++20, with std::remove_cvref
template < typename ... T >
struct autodiff_return {
using type = typename concatenated<
struct autodiff_return;

template < typename RetType, typename ... T >
struct autodiff_return<ReverseMode<false>, RetType, T...>
{
using type = tuple<typename concatenated< tuple< >,
typename type_info<
typename remove_cvref< T >::type
>::type ...
>::type;
>::type>;
};

template < typename RetType, typename ... T >
struct autodiff_return<ReverseMode<true>, RetType, T...>
{
using type = tuple<
typename type_info<RetType>::type,
typename concatenated< tuple< >,
typename type_info<
typename remove_cvref< T >::type
>::type ...
>::type
>;
};

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::duplicated<T> & arg) {
auto expand_args(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple<int, T, T>{enzyme_dup, arg.value, arg.shadow};
}

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::active<T> & arg) {
auto expand_args(const enzyme::Active<T> & arg) {
return enzyme::tuple<int, T>{enzyme_out, arg.value};
}

template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::inactive<T> & arg) {
auto expand_args(const enzyme::Const<T> & arg) {
return enzyme::tuple<int, T>{enzyme_const, arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple<T>{arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Active<T> & arg) {
return enzyme::tuple<T>{arg.value};
}

template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Const<T> & arg) {
return enzyme::tuple<T>{arg.value};
}

namespace detail {
template<typename T>
__attribute__((always_inline))
constexpr decltype(auto) push_return_last(T &&t);

template<typename ...T>
__attribute__((always_inline))
constexpr decltype(auto) push_return_last(tuple<tuple<T...>> &&t) {
return tuple<tuple<T...>>{get<0>(t)};
}

template<typename ...T, typename R>
__attribute__((always_inline))
constexpr decltype(auto) push_return_last(tuple<R, tuple<T...>> &&t) {
return tuple{get<1>(t), get<0>(t)};
}

template <class return_type, class Tuple, std::size_t... I>
__attribute__((always_inline))
constexpr decltype(auto) apply_impl(void* f, Tuple&& t, std::index_sequence<I...>) {
return __enzyme_autodiff<return_type>(f, enzyme::get<I>(impl::forward<Tuple>(t))...);
constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>) {
return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, enzyme::get<I>(impl::forward<Tuple>(t))...));
}

template <typename function, class Tuple, std::size_t... I>
__attribute__((always_inline))
constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence<I...>) {
return f(enzyme::get<I>(impl::forward<Tuple>(t))...);
}

template < typename T >
struct default_ret_activity {
using type = Const<T>;
};

template <>
struct default_ret_activity<float> {
using type = Active<float>;
};

template <>
struct default_ret_activity<double> {
using type = Active<double>;
};

template < typename T >
struct ret_global;

template<typename T>
struct ret_global<Const<T>> {
static constexpr int* value = &enzyme_const_return;
};

template<typename T>
struct ret_global<Active<T>> {
static constexpr int* value = &enzyme_active_return;
};

template<typename T>
struct ret_global<Duplicated<T>> {
static constexpr int* value = &enzyme_dup_return;
};

template<typename Mode, typename RetAct>
struct ret_used;

template<typename RetAct>
struct ret_used<ReverseMode<true>, RetAct> {
static constexpr int* value = &enzyme_primal_return;
};

template<typename RetAct>
struct ret_used<ReverseMode<false>, RetAct> {
static constexpr int* value = &enzyme_noret;
};

} // namespace detail



template < typename return_type, typename function, typename ... enz_arg_types >
__attribute__((always_inline))
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::apply_impl<return_type>((void*)f, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
return detail::primal_apply_impl<return_type>(f, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

template < typename function, typename ... arg_types>
auto primal_call(function && f, arg_types && ... args) {
return primal_impl<function>(impl::forward<function>(f), enzyme::tuple_cat(primal_args(args)...));
}

template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types >
__attribute__((always_inline))
auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
using Tuple = enzyme::tuple< enz_arg_types ... >;
return detail::rev_apply_impl<return_type>((void*)f, detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return rev_autodiff_impl<return_type, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}

template < typename DiffMode, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
using return_type = typename autodiff_return<arg_types...>::type;
return autodiff_impl<return_type, function>(impl::forward<function>(f), enzyme::tuple_cat(expand_args(args)...));
using primal_return_type = decltype(primal_call<function, arg_types...>(impl::forward<function>(f), impl::forward<arg_types>(args)...));
using RetActivity = typename detail::default_ret_activity<primal_return_type>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return rev_autodiff_impl<return_type, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}
}
}]>;
Expand Down Expand Up @@ -162,7 +288,7 @@ namespace impl {
template<typename _Tp>
__attribute__((always_inline))
constexpr _Tp&&
forward(typename std::remove_reference<_Tp>::type& __t) noexcept
forward(std::remove_reference_t<_Tp>& __t) noexcept
{ return static_cast<_Tp&&>(__t); }

/**
Expand All @@ -174,12 +300,13 @@ namespace impl {
template<typename _Tp>
__attribute__((always_inline))
constexpr _Tp&&
forward(typename std::remove_reference<_Tp>::type&& __t) noexcept
forward(std::remove_reference_t<_Tp>&& __t) noexcept
{
static_assert(!std::is_lvalue_reference<_Tp>::value,
"enzyme::impl::forward must not be used to convert an rvalue to an lvalue");
return static_cast<_Tp&&>(__t);
}

}

}
Expand Down Expand Up @@ -263,23 +390,10 @@ struct tuple_size<tuple<T...>> : std::integral_constant<size_t, sizeof...(T)> {}
template <typename Tuple>
static constexpr size_t tuple_size_v = tuple_size<Tuple>::value;

template <typename T>
__attribute__((always_inline))
auto forward(std::remove_reference_t<T>& arg) _NOEXCEPT {
return static_cast<T&&>(arg);
}

template <typename T>
__attribute__((always_inline))
auto forward(std::remove_reference_t<T>&& arg) _NOEXCEPT {
static_assert(!std::is_lvalue_reference<T>::value, "cannot forward an rvalue as an lvalue");
return static_cast<T&&>(arg);
}

template <typename... T>
__attribute__((always_inline))
constexpr auto forward_as_tuple(T&&... args) noexcept {
return tuple<T&&...>{forward<T>(args)...};
return tuple<T&&...>{impl::forward<T>(args)...};
}

namespace impl {
Expand Down
3 changes: 0 additions & 3 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,6 @@ Optional<StringRef> getMetadataName(llvm::Value *res)
return recursePhiReads(cast<PHINode>(res));
}

llvm::errs() << " failed to simplify: " << *res << "\n";

return {};
}

Expand Down Expand Up @@ -2848,7 +2846,6 @@ class EnzymeBase {
}

bool run(Module &M) {
llvm::errs() << M << "\n";
Logic.clear();

bool changed = false;
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/PreserveNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,6 @@ extern "C" void AddPreserveNVVMPass(LLVMPassManagerRef PM, uint8_t Begin) {

PreserveNVVMNewPM::Result
PreserveNVVMNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) {
llvm::errs() << " PNPM: " << M << "\n";
bool changed = false;
for (auto &F : M)
changed |= preserveNVVM(Begin, F);
Expand Down
Loading

0 comments on commit dbd38e7

Please sign in to comment.