Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 26, 2024
1 parent 1520b63 commit 62756e3
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 316 deletions.
1 change: 1 addition & 0 deletions enzyme/Enzyme/Clang/EnzymeClang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Frontend/FrontendPluginRegistry.h"
#include "clang/Lex/HeaderSearch.h"
#include "clang/Lex/PreprocessorOptions.h"
#include "clang/Sema/Sema.h"
#include "clang/Sema/SemaDiagnostic.h"
Expand Down
43 changes: 34 additions & 9 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,22 @@ namespace enzyme {
// Expands tuple `t` by index and calls __enzyme_autodiff<return_type> with
// `f` followed by every tuple element, in index order. The tuple's value
// category is preserved through impl::forward.
template <class return_type, class Tuple, std::size_t... I>
__attribute__((always_inline))
constexpr decltype(auto) apply_impl(void* f, Tuple&& t, std::index_sequence<I...>) {
    return __enzyme_autodiff<return_type>(f, enzyme::get<I>(impl::forward<Tuple>(t))...);
}
} // namespace detail

// Casts the callable `f` to void* and passes it, together with the packed
// argument tuple, to detail::apply_impl, which unpacks one index per tuple
// element.
template < typename return_type, typename function, typename ... enz_arg_types >
__attribute__((always_inline))
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
    using Tuple = enzyme::tuple< enz_arg_types ... >;
    return detail::apply_impl<return_type>((void*)f, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}

// Entry point: derives the return type from autodiff_return<arg_types...>,
// expands every argument with expand_args, concatenates the results into one
// tuple with enzyme::tuple_cat, and forwards `f` plus that tuple to
// autodiff_impl.
template < typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
    using return_type = typename autodiff_return<arg_types...>::type;
    return autodiff_impl<return_type, function>(impl::forward<function>(f), enzyme::tuple_cat(expand_args(args)...));
}
}
}]>;
Expand All @@ -158,6 +158,30 @@ struct remove_cvref {
// Shorthand for the nested `type` member of remove_cvref<T>.
template < typename T >
using remove_cvref_t = typename remove_cvref<T>::type;

namespace impl {
/**
 * @brief Forward an lvalue.
 * @return The parameter cast to `T&&`: an lvalue reference when `T` is an
 *         lvalue reference type, an rvalue reference otherwise.
 *
 * This function is used to implement "perfect forwarding"; its signature
 * mirrors the lvalue overload of std::forward.
 */
template<typename T>
__attribute__((always_inline))
constexpr T&&
forward(typename std::remove_reference<T>::type& value) noexcept
{ return static_cast<T&&>(value); }

/**
 * @brief Forward an rvalue.
 * @return The parameter cast to the specified type.
 *
 * This function is used to implement "perfect forwarding". Instantiating it
 * with an lvalue reference type is rejected at compile time, because that
 * would turn an rvalue into an lvalue.
 */
template<typename T>
__attribute__((always_inline))
constexpr T&&
forward(typename std::remove_reference<T>::type&& value) noexcept
{
    static_assert(!std::is_lvalue_reference<T>::value,
                  "enzyme::impl::forward must not be used to convert an rvalue to an lvalue");
    return static_cast<T&&>(value);
}
}

}
}]>;

Expand Down Expand Up @@ -209,6 +233,7 @@ template <typename... T>
struct tuple : public tuple_base<std::make_integer_sequence<int, sizeof...(T)>, T...> {};

// Deduction guide: constructing `tuple` from a list of values deduces
// tuple<T...> from the argument types, which are taken by value.
// NOTE(review): always_inline is attached to a deduction guide here; confirm
// the attribute is accepted without diagnostics and has the intended effect.
template <typename... T>
__attribute__((always_inline))
tuple(T ...) -> tuple<T...>;

template < int i, typename Tuple >
Expand Down Expand Up @@ -267,7 +292,7 @@ struct make_tuple_from_fwd_tuple<std::index_sequence<indices...>> {
// Builds a `tuple` (element types deduced) from the elements of `fwd`
// selected by `indices...`, preserving the value category of `fwd`.
template <typename FWD_TUPLE>
__attribute__((always_inline))
static constexpr auto f(FWD_TUPLE&& fwd) {
    return tuple{get<indices>(impl::forward<FWD_TUPLE>(fwd))...};
}
};

Expand All @@ -282,28 +307,28 @@ struct concat_with_fwd_tuple<std::index_sequence<fwd_indices...>, std::index_seq
// Passes the elements of `fwd` (selected by `fwd_indices...`) followed by the
// elements of `t` (selected by `indices...`) to forward_as_tuple, preserving
// the value category of both input tuples.
template <typename FWD_TUPLE, typename TUPLE>
__attribute__((always_inline))
static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) {
    return forward_as_tuple(get<fwd_indices>(impl::forward<FWD_TUPLE>(fwd))..., get<indices>(impl::forward<TUPLE>(t))...);
}
};

// Base case of the concatenation: only one tuple is left, so its elements
// are handed to make_tuple_from_fwd_tuple to produce the final result.
template <typename Tuple>
__attribute__((always_inline))
static constexpr auto tuple_cat(Tuple&& ret) {
    return make_tuple_from_fwd_tuple< iseq< Tuple > >::f(impl::forward< Tuple >(ret));
}

// Recursive case: merges the first two tuples (`fwd` and `t`) through
// concat_with_fwd_tuple, then recurses with the merged tuple and the
// remaining tuples `ts...`.
template <typename FWD_TUPLE, typename first, typename... rest>
__attribute__((always_inline))
static constexpr auto tuple_cat(FWD_TUPLE&& fwd, first&& t, rest&&... ts) {
    return tuple_cat(concat_with_fwd_tuple< iseq<FWD_TUPLE>, iseq<first> >::f(impl::forward<FWD_TUPLE>(fwd), impl::forward<first>(t)), impl::forward<rest>(ts)...);
}

} // namespace impl

// Public concatenation entry point: forwards every tuple unchanged to
// impl::tuple_cat, which performs the actual merge.
template <typename... Tuples>
__attribute__((always_inline))
constexpr auto tuple_cat(Tuples&&... tuples) {
    return impl::tuple_cat(impl::forward<Tuples>(tuples)...);
}

} // namespace enzyme
Expand All @@ -316,4 +341,4 @@ def : Headers<"/enzymeroot/enzyme/enzyme", [{
#else
#warning "Enzyme wrapper templates only available in C++"
#endif
}]>;
}]>;
245 changes: 2 additions & 243 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,249 +429,6 @@ static Optional<StringRef> recursePhiReads(PHINode *val)
return finalMetadata;
}

// Forward declaration: getBaseAndOffset calls simplifyLoad before its
// definition appears. valSz is the size in bytes of the value being resolved;
// with the default of 0 the definition derives it from the store size of the
// value's type.
Value *simplifyLoad(Value *LI, size_t valSz = 0);

// Walk backwards from ptr through casts, GEPs with a constant non-negative
// offset, and loads that simplifyLoad can resolve, until an alloca is found.
// Returns that alloca and stores in `offset` the accumulated byte distance
// from the start of the alloca to ptr. Returns nullptr when the chain cannot
// be resolved; `offset` may then hold a partial sum.
AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) {
  offset = 0;
  for (;;) {
    // Reached the underlying stack slot: done.
    if (isa<AllocaInst>(ptr))
      return cast<AllocaInst>(ptr);

    // Casts do not move the pointer; look through them.
    if (auto Cast = dyn_cast<CastInst>(ptr)) {
      ptr = Cast->getOperand(0);
      continue;
    }

    if (auto GEP = dyn_cast<GetElementPtrInst>(ptr)) {
      auto &DL = GEP->getParent()->getParent()->getParent()->getDataLayout();
      auto width = sizeof(size_t) * 8;
      APInt Offset(width, 0);
      MapVector<Value *, APInt> VariableOffsets;
      bool success = collectOffset(cast<GEPOperator>(GEP), DL, width,
                                   VariableOffsets, Offset);
      // Only fully constant, non-negative offsets can be accumulated.
      bool usable =
          success && VariableOffsets.size() == 0 && !Offset.isNegative();
      if (!usable)
        return nullptr;
      offset += Offset.getZExtValue();
      ptr = GEP->getOperand(0);
      continue;
    }

    // The only remaining supported form is a load whose value can be
    // resolved to a single underlying value.
    auto LI = dyn_cast<LoadInst>(ptr);
    if (!LI)
      return nullptr;
    auto S = simplifyLoad(LI);
    if (!S)
      return nullptr;
    ptr = S;
  }
}

// Collect every user instruction of AI as a tuple of <instruction, value
// being used, byte offset of that value from AI>. Unlike a plain users()
// walk, this looks through casts and through GEPs whose offset is constant
// and non-negative, reporting the users of those derived pointers instead.
SmallVector<std::tuple<Instruction *, Value *, size_t>, 1>
findAllUsersOf(Value *AI) {
  SmallVector<std::tuple<Instruction *, Value *, size_t>, 1> users;

  // Worklist of (pointer, byte offset from AI) still to be expanded.
  SmallVector<std::pair<Value *, size_t>, 1> todo;
  todo.emplace_back(AI, 0);

  while (!todo.empty()) {
    auto item = todo.pop_back_val();
    Value *ptr = item.first;
    size_t suboff = item.second;

    for (auto U : ptr->users()) {
      // A cast yields the same address: expand its users at the same offset.
      if (auto Cast = dyn_cast<CastInst>(U)) {
        todo.emplace_back(Cast, suboff);
        continue;
      }

      if (auto GEP = dyn_cast<GetElementPtrInst>(U)) {
        auto &DL = GEP->getParent()->getParent()->getParent()->getDataLayout();
        auto width = sizeof(size_t) * 8;
        APInt Offset(width, 0);
        MapVector<Value *, APInt> VariableOffsets;
        bool success = collectOffset(cast<GEPOperator>(GEP), DL, width,
                                     VariableOffsets, Offset);

        bool constantOffset =
            success && VariableOffsets.size() == 0 && !Offset.isNegative();
        if (constantOffset) {
          todo.emplace_back(GEP, suboff + Offset.getZExtValue());
          continue;
        }
        // Otherwise fall through: the GEP itself is reported as a user.
      }

      users.emplace_back(cast<Instruction>(U), ptr, suboff);
    }
  }
  return users;
}

// Given an alloca, find all values of size `valSz` bytes which could be
// loaded from it at byte offset `offset`.
//
// ptr0   - the alloca whose contents are being reconstructed.
// offset - byte offset, from the start of ptr0, at which the load begins.
// valSz  - size in bytes of the loaded value.
// legal  - set to false when the returned set cannot be guaranteed to be
//          complete. It is never set to true here, so callers must
//          initialise it to true.
//
// Returns the candidate stored values found so far; the result is only
// meaningful when `legal` is still true on return.
SmallVector<Value *, 1> getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset,
                                               size_t valSz, bool &legal) {
  SmallVector<Value *, 1> options;

  auto todo = findAllUsersOf(ptr0);
  std::set<std::tuple<Instruction *, Value *, size_t>> seen;

  while (todo.size()) {
    auto pair = todo.pop_back_val();
    // Process each (user, pointer, offset) triple at most once.
    if (!seen.insert(pair).second)
      continue;
    Instruction *U = std::get<0>(pair);
    Value *ptr = std::get<1>(pair);
    size_t suboff = std::get<2>(pair);

    // Read only users do not set the memory inside of ptr
    if (isa<LoadInst>(U)) {
      continue;
    }
    // A memory transfer whose destination (operand 0) is not ptr only reads
    // from ptr.
    if (auto MTI = dyn_cast<MemTransferInst>(U))
      if (MTI->getOperand(0) != ptr) {
        continue;
      }
    // Void instructions that cannot write memory cannot change the contents.
    if (!U->mayWriteToMemory() && U->getType()->isVoidTy())
      continue;

    if (auto SI = dyn_cast<StoreInst>(U)) {
      auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout();

      // We are storing into the ptr
      if (SI->getPointerOperand() == ptr) {
        auto storeSz =
            (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) /
            8;
        // The store ends at or before the first byte of the load.
        if (storeSz + suboff <= offset)
          continue;
        // The store starts at or after the end of the load.
        if (offset + valSz <= suboff)
          continue;

        // Only a store covering exactly the loaded bytes provides the loaded
        // value. An equally sized store at a different offset overlaps the
        // load only partially and must not be treated as its value; it falls
        // through and is reported as illegal below.
        if (suboff == offset && valSz == storeSz) {
          options.push_back(SI->getValueOperand());
          continue;
        }
      }

      // ptr itself is being stored (captured). If it is stored into a known
      // alloca slot, every pointer loadable from that slot may refer to ptr,
      // so the users of those pointers must be examined as well.
      if (SI->getValueOperand() == ptr) {
        if (suboff == 0) {
          size_t mid_offset = 0;
          if (auto AI2 =
                  getBaseAndOffset(SI->getPointerOperand(), mid_offset)) {
            bool sublegal = true;
            auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8;
            auto subPtrs =
                getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal);
            if (!sublegal) {
              legal = false;
              return options;
            }
            for (auto subPtr : subPtrs) {
              for (const auto &pair3 : findAllUsersOf(subPtr)) {
                todo.emplace_back(pair3);
              }
            }
            continue;
          }
        }
      }
    }

    // If we copy into the ptr at a location that includes the offset, consider
    // all sub uses
    if (auto MTI = dyn_cast<MemTransferInst>(U)) {
      if (auto CI = dyn_cast<ConstantInt>(MTI->getLength())) {
        if (MTI->getOperand(0) == ptr && suboff == 0 &&
            CI->getValue().uge(offset + valSz)) {
          size_t midoffset = 0;
          auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset);
          // The source must be the very start of a known alloca so that
          // offsets in the source line up with offsets in ptr.
          if (!AI2) {
            legal = false;
            return options;
          }
          if (midoffset != 0) {
            legal = false;
            return options;
          }
          for (const auto &pair3 : findAllUsersOf(AI2)) {
            todo.emplace_back(pair3);
          }
          continue;
        }
      }
    }

    // Any other user may modify the memory in a way that cannot be modelled.
    legal = false;
    return options;
  }

  return options;
}

// Perform a mem2reg/sroa-like lookup to identify the single innermost value
// that V represents. V may be a load, or an extractvalue with all-zero
// indices applied to a load. valSz is the size in bytes of the value; when 0
// it is computed from the store size of V's type. Returns nullptr when no
// unique underlying value can be determined.
Value *simplifyLoad(Value *V, size_t valSz) {
  if (auto Load = dyn_cast<LoadInst>(V)) {
    if (valSz == 0) {
      auto &DL = Load->getParent()->getParent()->getParent()->getDataLayout();
      valSz = (DL.getTypeStoreSizeInBits(Load->getType()) + 7) / 8;
    }

    size_t offset = 0;
    auto AI = getBaseAndOffset(Load->getPointerOperand(), offset);
    if (!AI)
      return nullptr;

    bool legal = true;
    auto candidates = getAllLoadedValuesFrom(AI, offset, valSz, legal);
    if (!legal)
      return nullptr;

    // Resolve every candidate as far as possible; the answer is only usable
    // when all of them collapse to the same value.
    std::set<Value *> unique;
    for (auto candidate : candidates) {
      Value *inner = simplifyLoad(candidate, valSz);
      unique.insert(inner ? inner : candidate);
    }
    return unique.size() == 1 ? *unique.begin() : nullptr;
  }

  if (auto EVI = dyn_cast<ExtractValueInst>(V)) {
    if (valSz == 0) {
      auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout();
      valSz = (DL.getTypeStoreSizeInBits(EVI->getType()) + 7) / 8;
    }

    // Only extractions made up entirely of index 0 are looked through.
    bool allZero = true;
    for (auto idx : EVI->getIndices())
      allZero = allZero && (idx == 0);

    if (allZero) {
      if (auto Inner = dyn_cast<LoadInst>(EVI->getAggregateOperand()))
        return simplifyLoad(Inner, valSz);
    }
  }

  return nullptr;
}

#if LLVM_VERSION_MAJOR > 16
std::optional<StringRef> getMetadataName(llvm::Value *res)
#else
Expand Down Expand Up @@ -713,6 +470,8 @@ Optional<StringRef> getMetadataName(llvm::Value *res)
return recursePhiReads(cast<PHINode>(res));
}

llvm::errs() << " failed to simplify: " << *res << "\n";

return {};
}

Expand Down
Loading

0 comments on commit 62756e3

Please sign in to comment.