-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Christopher Mauney
committed
Aug 29, 2024
1 parent
36e848c
commit 551910b
Showing
3 changed files
with
187 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
#ifndef _PORTSOFCALL_UTILITY_ARRAY_HPP_ | ||
#define _PORTSOFCALL_UTILITY_ARRAY_HPP_ | ||
|
||
#include "../portability.hpp" | ||
#include <array> | ||
#include <type_traits> | ||
|
||
namespace util { | ||
|
||
namespace detail { | ||
// Applies `f` to each element of `x` selected by `Is...` and returns the
// results, in that order, as a new std::array with sizeof...(Is) entries.
//
// The element type is named explicitly (the decayed result of calling `f` with
// a `T const &`) instead of being left to class template argument deduction:
// `std::array{}` with an empty pack has nothing to deduce from, so relying on
// deduction makes the N == 0 case ill-formed.
template <class T, std::size_t N, class F, std::size_t... Is>
PORTABLE_FORCEINLINE_FUNCTION constexpr auto
array_map_impl(std::array<T, N> const &x, F f, std::index_sequence<Is...>) {
  using result_type = std::decay_t<std::invoke_result_t<F &, T const &>>;
  return std::array<result_type, sizeof...(Is)>{f(x[Is])...};
}
|
||
// Two-array form: calls `f(x[i], y[i])` for every index in `Is...` and returns
// the results, in that order, as a new std::array with sizeof...(Is) entries.
//
// The element type is named explicitly (the decayed result of calling `f` with
// a `T const &` and a `U const &`) because `std::array{}` with an empty pack
// cannot deduce it, which would make the N == 0 case ill-formed.
template <class T, class U, std::size_t N, class F, std::size_t... Is>
PORTABLE_FORCEINLINE_FUNCTION constexpr auto
array_map_impl(std::array<T, N> const &x, std::array<U, N> const &y, F f,
               std::index_sequence<Is...>) {
  using result_type =
      std::decay_t<std::invoke_result_t<F &, T const &, U const &>>;
  return std::array<result_type, sizeof...(Is)>{f(x[Is], y[Is])...};
}
|
||
// Reduces the half-open index range [f, l) of `x` with the binary operation
// `op` by recursive halving: the range is split at f + (l - f) / 2, each half
// is reduced on its own, and the two partial results are combined as
// `op(left, right)`. A one-element range yields that element unchanged.
//
// The result of `op` is converted to T at every level of the recursion.
template <std::size_t f, std::size_t l, class T, std::size_t N, class Op>
PORTABLE_INLINE_FUNCTION constexpr T
array_reduce_impl(std::array<T, N> const &x, Op op) {
  // An empty or reversed range would recurse without ever reaching the
  // one-element base case, and l > N would index past the end of `x`;
  // reject both at compile time.
  static_assert(f < l, "array_reduce_impl requires a non-empty range");
  static_assert(l <= N, "array_reduce_impl range exceeds the array size");

  if constexpr ((l - f) == 1) {
    return x[f];
  } else {
    constexpr std::size_t mid = f + (l - f) / 2;
    T left_part = array_reduce_impl<f, mid>(x, op);
    T right_part = array_reduce_impl<mid, l>(x, op);
    return op(left_part, right_part);
  }
}
|
||
} // namespace detail | ||
|
||
template <class T, std::size_t N, class F> | ||
PORTABLE_FORCEINLINE_FUNCTION constexpr auto | ||
array_map(std::array<T, N> const &x, F f) { | ||
return detail::array_map_impl(x, f, std::make_index_sequence<N>{}); | ||
} | ||
|
||
template <class T, class U, std::size_t N, class F> | ||
PORTABLE_FORCEINLINE_FUNCTION constexpr auto | ||
array_map(std::array<T, N> const &x, std::array<U, N> const &y, F f) { | ||
return detail::array_map_impl(x, y, f, std::make_index_sequence<N>{}); | ||
} | ||
|
||
template <std::size_t I, class T, std::size_t N, class Op> | ||
PORTABLE_FORCEINLINE_FUNCTION constexpr T | ||
array_partial_reduce(std::array<T, N> x, T initial_value, Op op) { | ||
static_assert(I <= N); | ||
if constexpr (I == 0) | ||
return initial_value; | ||
else | ||
return detail::array_reduce_impl<0, I>(x, op); | ||
} | ||
|
||
template <class T, std::size_t N, class Op> | ||
PORTABLE_FORCEINLINE_FUNCTION constexpr T array_reduce(std::array<T, N> x, | ||
T initial_value, Op op) { | ||
return array_partial_reduce<N>(x, initial_value, op); | ||
} | ||
// Converts a compile-time index sequence into a std::array<std::size_t, n>
// holding the same values in the same order.
//
// The parameter is `std::index_sequence<Is...>` so that `Is` is deduced from
// the argument, e.g. `as_array(std::make_index_sequence<3>{})` yields
// {0, 1, 2}. The element type and size are spelled out so that an empty
// sequence produces an empty array instead of failing deduction.
template <std::size_t... Is>
PORTABLE_FORCEINLINE_FUNCTION constexpr auto
as_array(std::index_sequence<Is...>) {
  return std::array<std::size_t, sizeof...(Is)>{Is...};
}
|
||
} // namespace util | ||
|
||
#endif // _PORTSOFCALL_UTILITY_ARRAY_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#ifndef _PORTSOFCALL_UTILITY_INDEX_ALGO_HPP_ | ||
#define _PORTSOFCALL_UTILITY_INDEX_ALGO_HPP_ | ||
|
||
#include "../portability.hpp"
#include "array.hpp"
#include <array>
#include <cstddef>
#include <functional>
#include <numeric>
#include <type_traits>
#include <utility>
|
||
namespace util { | ||
|
||
// Returns the column-major stride of dimension `I`: the product of the extents
// dim[0] * ... * dim[I-1], accumulated in std::size_t. The stride of
// dimension 0 is 1.
template <auto I, class T, std::size_t N>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
get_stride(std::array<T, N> const &dim) {
  // Compare against N rather than dim.size(): `dim` is a reference parameter,
  // which cannot appear in a constant expression before C++23.
  static_assert(static_cast<std::size_t>(I) < N, "Dim index is out of bounds");

  // column major
  // A plain loop is used instead of std::accumulate, which is not constexpr
  // until C++20, so the function is usable in constant expressions in C++17.
  std::size_t stride = 1;
  for (std::size_t d = 0; d < static_cast<std::size_t>(I); ++d) {
    stride *= static_cast<std::size_t>(dim[d]);
  }
  return stride;
}
|
||
namespace detail { | ||
template <class T, std::size_t N, std::size_t... Is> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
get_strides_impl(std::array<T, N> const &dim, std::index_sequence<Is...>) { | ||
return std::array{get_stride<Is>(dim)...}; | ||
} | ||
|
||
} // namespace detail | ||
|
||
template <class T, std::size_t N> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
get_strides(std::array<T, N> const &dim) { | ||
return detail::get_strides_impl(dim, std::make_index_sequence<dim.size()>{}); | ||
} | ||
|
||
template <class T, std::size_t N> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
fast_findex(std::array<T, N> const &ijk, std::array<T, N> const &dim, | ||
std::array<T, N> const &stride) { | ||
// TODO: assert ijk in bounds | ||
return std::inner_product(std::begin(ijk), std::end(ijk), std::begin(stride), | ||
std::size_t{0}); | ||
} | ||
|
||
template <class T, std::size_t N> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
findex(std::array<T, N> const &ijk, std::array<T, N> const &dim) { | ||
return fast_findex(ijk, dim, get_strides(dim)); | ||
} | ||
|
||
namespace handroll { | ||
template <auto I, class T, std::size_t N> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
get_stride(std::array<T, N> const &dim) { | ||
|
||
// column major | ||
return array_partial_reduce<I>(dim, T{1}, std::multiplies<std::size_t>{}); | ||
} | ||
|
||
namespace detail { | ||
template <class T, std::size_t N, std::size_t... Is> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
get_strides_impl(std::array<T, N> const &dim, std::index_sequence<Is...>) { | ||
return std::array{get_stride<Is>(dim)...}; | ||
} | ||
template <class T, std::size_t N> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
get_strides(std::array<T, N> const &dim) { | ||
return detail::get_strides_impl(dim, std::make_index_sequence<dim.size()>{}); | ||
} // namespace handroll::detail | ||
template <class T, std::size_t N> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
fast_findex(std::array<T, N> const &ijk, std::array<T, N> const &dim, | ||
std::array<T, N> const &stride) { | ||
// TODO: assert ijk in bounds | ||
return array_reduce( | ||
array_map(ijk, stride, [](auto a, auto b) { return a * b; }), T{1}, | ||
std::plus<std::size_t>{}); | ||
} | ||
template <class T, std::size_t N> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto | ||
findex(std::array<T, N> const &ijk, std::array<T, N> const &dim) { | ||
return fast_findex(ijk, dim, get_strides(dim)); | ||
} | ||
} // namespace handroll | ||
|
||
template <class T, std::size_t N> | ||
PORTABLE_FORCEINLINE_FUNCTION static constexpr std::array<T, N> | ||
fast_mindices(std::size_t idx, std::array<T, N> const &dim, | ||
std::array<T, N> const &stride) { | ||
std::array<T, N> mdidx; | ||
for (std::int64_t i = dim.size() - 1; i >= 0; --i) { | ||
mdidx[i] = idx / std::size_t(stride[i]); | ||
idx -= mdidx[i] * std::size_t(stride[i]); | ||
} | ||
return mdidx; | ||
} | ||
|
||
// Splits the flat index `idx` into per-dimension indices for an array with
// extents `dim`: derives the strides from `dim` via get_strides, then defers
// to fast_mindices.
template <class Array>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto mindices(std::size_t idx,
                                                             Array dim) {
  auto const strides = get_strides(dim);
  return fast_mindices(idx, dim, strides);
}
|
||
} // namespace util | ||
|
||
#endif // _PORTSOFCALL_UTILITY_INDEX_ALGO_HPP_ |