Skip to content

Commit

Permalink
store
Browse files Browse the repository at this point in the history
  • Loading branch information
Christopher Mauney committed Aug 29, 2024
1 parent 36e848c commit 551910b
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 21 deletions.
26 changes: 5 additions & 21 deletions ports-of-call/portable_arrays.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
// NOTE THE TRAILING INDEX INSIDE THE PARENTHESES IS INDEXED FASTEST

#include "portability.hpp"
#include "utility/array.hpp"
#include <algorithm>
#include <array>
#include <assert.h>
#include <cstddef> // size_t
#include <cstring> // memset()
#include <functional>
#include <integer_sequence>
#include <numeric>
#include <type_traits>
#include <utility> // swap()
Expand All @@ -43,25 +45,6 @@ namespace detail {
// Variable template that always evaluates to V; the first parameter I does not
// influence the value.
// NOTE(review): presumably used to repeat one constant across an expanded
// index pack -- confirm against the callers outside this view.
template <std::size_t I, std::size_t V>
constexpr std::size_t to_const = V;

// array type of dimensions/strides
// multiply reduce array
// NOTE: we can do product of variadic parameters `vars...` as
// `arr_mul({vars...})`
// Product of every element of `a`, accumulated in T starting from T{1}
// (an empty array therefore yields T{1}).
template <typename T, std::size_t N>
PORTABLE_INLINE_FUNCTION auto arr_mul(const std::array<T, N> &a) {
  T product{1};
  for (std::size_t i = 0; i < N; ++i) {
    product *= a[i];
  }
  return product;
}
// Returns a callable that multiplies together all elements of a container
// exposing begin()/end() and yields the product as std::size_t.
// The fold is seeded with std::size_t{1}: std::accumulate keeps its running
// value in the type of the seed, so seeding with the int literal `1` would
// squeeze every partial product through an int and truncate large extents.
PORTABLE_FORCEINLINE_FUNCTION
decltype(auto) vp_prod() {
  return [](auto &&v) {
    return std::accumulate(v.begin(), v.end(), std::size_t{1},
                           std::multiplies<std::size_t>());
  };
}

} // namespace detail

template <typename T>
Expand Down Expand Up @@ -134,7 +117,7 @@ class PortableMDArray {
}

PORTABLE_FORCEINLINE_FUNCTION int GetSize() const {
return detail::vp_prod()(nxs_);
return util::array_reduce(nxs_, std::multiplies<std::size_t>{});
}
PORTABLE_FORCEINLINE_FUNCTION std::size_t GetSizeInBytes() const {
return GetSize() * sizeof(T);
Expand All @@ -143,7 +126,8 @@ class PortableMDArray {
// Returns the stored rank_ member.
PORTABLE_INLINE_FUNCTION size_t GetRank() const { return rank_; }
// Forwards the new extents `nxs...` to update_layout(). In builds where assert
// is active, first checks that the product of the new extents equals
// GetSize(), i.e. that the element count does not change.
template <typename... NXs>
PORTABLE_INLINE_FUNCTION void Reshape(NXs... nxs) {
  // Each extent is converted explicitly so that mixed integer argument types
  // still form a single std::array, and the multiplicative identity is passed
  // because util::array_reduce takes (array, initial_value, op).
  assert(util::array_reduce(std::array<std::size_t, sizeof...(NXs)>{
                                static_cast<std::size_t>(nxs)...},
                            std::size_t{1}, std::multiplies<std::size_t>{}) ==
         static_cast<std::size_t>(GetSize()));
  update_layout(nxs...);
}
// Always reports true.
PORTABLE_FORCEINLINE_FUNCTION bool IsShallowSlice() { return true; }
Expand Down
74 changes: 74 additions & 0 deletions ports-of-call/utility/array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#ifndef _PORTSOFCALL_UTILITY_ARRAY_HPP_
#define _PORTSOFCALL_UTILITY_ARRAY_HPP_

#include "../portability.hpp"
#include <array>
#include <type_traits>

namespace util {

namespace detail {
// Applies `f` to each element of `x` selected by the index pack `Is...` and
// returns the results, in pack order, as a std::array.
// The element type is spelled out (the decayed result type of `f`) instead of
// relying on class template argument deduction, because `std::array{}` cannot
// be deduced when the pack is empty (N == 0).
template <class T, std::size_t N, class F, std::size_t... Is>
PORTABLE_FORCEINLINE_FUNCTION constexpr auto
array_map_impl(std::array<T, N> const &x, F f, std::index_sequence<Is...>) {
  // x[0] is only named inside decltype and is never evaluated.
  using R = std::decay_t<decltype(f(x[0]))>;
  return std::array<R, sizeof...(Is)>{f(x[Is])...};
}

// Binary form: applies `f` to the pairs (x[i], y[i]) for every i in `Is...`
// and returns the results, in pack order, as a std::array.
// The element type is spelled out (the decayed result type of `f`) instead of
// relying on class template argument deduction, because `std::array{}` cannot
// be deduced when the pack is empty (N == 0).
template <class T, class U, std::size_t N, class F, std::size_t... Is>
PORTABLE_FORCEINLINE_FUNCTION constexpr auto
array_map_impl(std::array<T, N> const &x, std::array<U, N> const &y, F f,
               std::index_sequence<Is...>) {
  // x[0] / y[0] are only named inside decltype and are never evaluated.
  using R = std::decay_t<decltype(f(x[0], y[0]))>;
  return std::array<R, sizeof...(Is)>{f(x[Is], y[Is])...};
}

// Reduces the elements x[f] .. x[l-1] with the binary operation `op` by
// splitting the range in half recursively: a one-element range yields that
// element, otherwise the results of the two halves are combined as
// op(left, right).
// The range must be non-empty and lie inside the array. Without the checks
// below an empty range would make the function call itself forever and an
// oversized one would read past the end of `x`.
template <std::size_t f, std::size_t l, class T, std::size_t N, class Op>
PORTABLE_INLINE_FUNCTION constexpr T
array_reduce_impl(std::array<T, N> const &x, Op op) {
  static_assert(f < l, "array_reduce_impl requires a non-empty range [f, l)");
  static_assert(l <= N, "array_reduce_impl range exceeds the array size");
  if constexpr ((l - f) == 1) {
    return x[f];
  } else {
    // Split point: the left half receives floor(n / 2) elements.
    constexpr std::size_t mid = f + (l - f) / 2;
    const T left_sum = array_reduce_impl<f, mid>(x, op);
    const T right_sum = array_reduce_impl<mid, l>(x, op);
    return op(left_sum, right_sum);
  }
}

} // namespace detail

// Element-wise map over a single array: hands `x`, `f` and the index sequence
// 0 .. N-1 to detail::array_map_impl and returns its result.
template <class T, std::size_t N, class F>
PORTABLE_FORCEINLINE_FUNCTION constexpr auto
array_map(std::array<T, N> const &x, F f) {
  using all_indices = std::make_index_sequence<N>;
  return detail::array_map_impl(x, f, all_indices{});
}

// Element-wise map over two equally sized arrays: hands `x`, `y`, `f` and the
// index sequence 0 .. N-1 to detail::array_map_impl and returns its result.
template <class T, class U, std::size_t N, class F>
PORTABLE_FORCEINLINE_FUNCTION constexpr auto
array_map(std::array<T, N> const &x, std::array<U, N> const &y, F f) {
  using all_indices = std::make_index_sequence<N>;
  return detail::array_map_impl(x, y, f, all_indices{});
}

// Reduces the first I elements of `x` with the binary operation `op`.
//   I == 0: there is nothing to combine and `initial_value` is returned.
//   I  > 0: the result comes from detail::array_reduce_impl<0, I> alone;
//           `initial_value` takes no part in that computation.
// I may not exceed N (checked at compile time).
template <std::size_t I, class T, std::size_t N, class Op>
PORTABLE_FORCEINLINE_FUNCTION constexpr T
array_partial_reduce(std::array<T, N> x, T initial_value, Op op) {
  static_assert(I <= N);
  if constexpr (I != 0) {
    return detail::array_reduce_impl<0, I>(x, op);
  } else {
    return initial_value;
  }
}

// Reduces the whole array: forwards to array_partial_reduce with I equal to
// the array length N, passing `initial_value` and `op` through unchanged.
template <class T, std::size_t N, class Op>
PORTABLE_FORCEINLINE_FUNCTION constexpr T array_reduce(std::array<T, N> x,
                                                       T initial_value, Op op) {
  constexpr std::size_t element_count = N;
  return array_partial_reduce<element_count>(x, initial_value, op);
}
// Converts a compile-time index pack into a std::array<std::size_t, K> that
// holds the same values in the same order (K = number of indices).
// The parameter is spelled std::index_sequence<Is...> so that `Is` is deduced
// from the argument; written as index_sequence<sizeof...(Is)> the pack sits in
// a non-deduced context and is always deduced as empty.
// The array type is explicit so that an empty sequence also works.
template <std::size_t... Is>
PORTABLE_FORCEINLINE_FUNCTION constexpr auto
as_array(std::index_sequence<Is...>) {
  return std::array<std::size_t, sizeof...(Is)>{Is...};
}

} // namespace util

#endif // _PORTSOFCALL_UTILITY_ARRAY_HPP_
108 changes: 108 additions & 0 deletions ports-of-call/utility/index_algo.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#ifndef _PORTSOFCALL_UTILITY_INDEX_ALGO_HPP_
#define _PORTSOFCALL_UTILITY_INDEX_ALGO_HPP_

#include "../portability.hpp"
#include "array.hpp"
#include <array>
#include <cstddef>    // std::size_t
#include <cstdint>    // std::int64_t
#include <functional> // std::multiplies, std::plus
#include <numeric>
#include <type_traits>
#include <utility> // std::index_sequence, std::make_index_sequence

namespace util {

// Stride of dimension I for a layout in which dimension 0 varies fastest:
// the product of the extents dim[0] .. dim[I-1], returned as std::size_t
// (1 when I == 0).
template <auto I, class T, std::size_t N>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
get_stride(std::array<T, N> const &dim) {
  // Checked against N rather than dim.size(): `dim` is a reference parameter,
  // and naming it inside a constant expression is not allowed by the C++17
  // rules (only relaxed later by P2280).
  static_assert(static_cast<std::size_t>(I) < N, "Dim index is out of bounds");

  // column major
  // A plain loop keeps the function usable in constant expressions;
  // std::accumulate is not constexpr before C++20.
  std::size_t stride = 1;
  for (std::size_t d = 0; d < static_cast<std::size_t>(I); ++d) {
    stride *= static_cast<std::size_t>(dim[d]);
  }
  return stride;
}

namespace detail {
// Builds the array of strides by evaluating get_stride<Is>(dim) for every
// index in the pack.
// The result type is written out as std::array<std::size_t, K> (get_stride
// yields std::size_t) because `std::array{}` cannot be deduced for an empty
// pack, which would make a rank-0 `dim` fail to compile.
template <class T, std::size_t N, std::size_t... Is>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
get_strides_impl(std::array<T, N> const &dim, std::index_sequence<Is...>) {
  return std::array<std::size_t, sizeof...(Is)>{get_stride<Is>(dim)...};
}

} // namespace detail

// Returns the strides of all N dimensions of `dim`, one get_stride<I> result
// per dimension, via detail::get_strides_impl.
template <class T, std::size_t N>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
get_strides(std::array<T, N> const &dim) {
  // N is used instead of dim.size(): `dim` is a reference parameter and may
  // not appear in a template argument (a constant expression) under C++17.
  return detail::get_strides_impl(dim, std::make_index_sequence<N>{});
}

// Flat (1-D) index of the multi-dimensional index `ijk`, given precomputed
// strides: the sum over all dimensions of ijk[d] * stride[d], returned as
// std::size_t.
// `stride` has its own element type S so that the std::size_t strides produced
// by get_strides can be combined with indices of any integer type T. S is
// deduced and placed last, so explicit <T, N> arguments keep working.
// `dim` is currently unused (see the TODO below).
template <class T, std::size_t N, class S>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
fast_findex(std::array<T, N> const &ijk, std::array<T, N> const &dim,
            std::array<S, N> const &stride) {
  // TODO: assert ijk in bounds
  // A plain loop keeps the function usable in constant expressions;
  // std::inner_product is not constexpr before C++20.
  std::size_t flat = 0;
  for (std::size_t d = 0; d < N; ++d) {
    flat = flat + ijk[d] * stride[d];
  }
  return flat;
}

// Flat index of `ijk` inside an array with extents `dim`: computes the strides
// with get_strides(dim) and passes them on to fast_findex.
template <class T, std::size_t N>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
findex(std::array<T, N> const &ijk, std::array<T, N> const &dim) {
  const auto strides = get_strides(dim);
  return fast_findex(ijk, dim, strides);
}

// Variants of the index helpers built on array_map / array_reduce /
// array_partial_reduce (from array.hpp) instead of the <numeric> algorithms.
namespace handroll {
// Stride of dimension I: the product of the first I extents of `dim`.
// T{1} is the value array_partial_reduce returns when I == 0.
template <auto I, class T, std::size_t N>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
get_stride(std::array<T, N> const &dim) {

  // column major
  return array_partial_reduce<I>(dim, T{1}, std::multiplies<std::size_t>{});
}

namespace detail {
// Evaluates handroll::get_stride<Is>(dim) for every index in the pack.
template <class T, std::size_t N, std::size_t... Is>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
get_strides_impl(std::array<T, N> const &dim, std::index_sequence<Is...>) {
  return std::array{get_stride<Is>(dim)...};
}
} // namespace detail

// Strides of all N dimensions of `dim`.
// This lives in handroll, not handroll::detail, so that handroll::findex
// below picks it up. N is used instead of dim.size() because `dim` is a
// reference parameter and may not appear in a constant expression in C++17.
template <class T, std::size_t N>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
get_strides(std::array<T, N> const &dim) {
  return detail::get_strides_impl(dim, std::make_index_sequence<N>{});
}

// Flat index of `ijk` given precomputed strides: the element-wise products
// ijk[d] * stride[d] are formed with array_map and summed with array_reduce.
template <class T, std::size_t N>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
fast_findex(std::array<T, N> const &ijk, std::array<T, N> const &dim,
            std::array<T, N> const &stride) {
  // TODO: assert ijk in bounds
  // The initial value of a sum is T{0}. array_reduce only hands the initial
  // value back for an empty array, so non-empty results are unaffected.
  return array_reduce(
      array_map(ijk, stride, [](auto a, auto b) { return a * b; }), T{0},
      std::plus<std::size_t>{});
}

// Flat index of `ijk` inside an array with extents `dim`.
template <class T, std::size_t N>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto
findex(std::array<T, N> const &ijk, std::array<T, N> const &dim) {
  return fast_findex(ijk, dim, get_strides(dim));
}
} // namespace handroll

// Splits the flat index `idx` into one index per dimension using precomputed
// strides. Dimensions are processed from the last to the first: each step
// divides the remaining index by that dimension's stride and subtracts what
// was accounted for.
// `stride` has its own element type S so the std::size_t strides returned by
// get_strides can be used with extents of any integer type T. S is deduced and
// placed last, so explicit <T, N> arguments keep working.
// `dim` only fixes T and N; its values are not read.
template <class T, std::size_t N, class S>
PORTABLE_FORCEINLINE_FUNCTION static constexpr std::array<T, N>
fast_mindices(std::size_t idx, std::array<T, N> const &dim,
              std::array<S, N> const &stride) {
  // Value-initialized: an uninitialized local is not permitted in a C++17
  // constant expression.
  std::array<T, N> mdidx{};
  // Counts N-1 .. 0 without needing a signed loop variable.
  for (std::size_t i = N; i-- > 0;) {
    const std::size_t step = std::size_t(stride[i]);
    mdidx[i] = idx / step;
    idx -= mdidx[i] * step;
  }
  return mdidx;
}

// Splits the flat index `idx` into per-dimension indices for an array with
// extents `dim`: computes the strides with get_strides(dim) and passes them on
// to fast_mindices.
template <class Array>
PORTABLE_FORCEINLINE_FUNCTION static constexpr auto mindices(std::size_t idx,
                                                             Array dim) {
  const auto strides = get_strides(dim);
  return fast_mindices(idx, dim, strides);
}

} // namespace util

#endif // _PORTSOFCALL_UTILITY_INDEX_ALGO_HPP_

0 comments on commit 551910b

Please sign in to comment.