Skip to content

Commit

Permalink
Merged in ArrayViewVariant refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Sep 25, 2024
1 parent 529d362 commit aeb352f
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 33 deletions.
1 change: 1 addition & 0 deletions src/atlas/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "atlas/array/ArraySpec.h"
#include "atlas/array/ArrayStrides.h"
#include "atlas/array/ArrayView.h"
#include "atlas/array/ArrayViewVariant.h"
#include "atlas/array/DataType.h"
#include "atlas/array/LocalView.h"
#include "atlas/array/MakeView.h"
Expand Down
46 changes: 29 additions & 17 deletions src/atlas/array/ArrayViewVariant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,34 @@ using namespace detail;

namespace {

template <bool IsConst>
struct VariantTypeHelper {
using type = ArrayViewVariant;
};

template <>
struct VariantTypeHelper<true> {
using type = ConstArrayViewVariant;
};

template <typename ArrayType>
using VariantType =
typename VariantTypeHelper<std::is_const_v<ArrayType>>::type;

// Match array.rank() and array.datatype() to variant types. Return result of
// makeView on a successful pattern match.
template <size_t TypeIndex = 0, typename ArrayType, typename MakeView>
ArrayViewVariant executeMakeView(ArrayType& array, const MakeView& makeView) {
using View = std::variant_alternative_t<TypeIndex, ArrayViewVariant>;
constexpr auto Const = std::is_const_v<typename View::value_type>;

if constexpr (std::is_const_v<ArrayType> == Const) {
using Value = typename View::non_const_value_type;
constexpr auto Rank = View::rank();
if (array.datatype() == DataType::kind<Value>() && array.rank() == Rank) {
return makeView(array, Value{}, std::integral_constant<int, Rank>{});
}
VariantType<ArrayType> executeMakeView(ArrayType& array,
const MakeView& makeView) {
using View = std::variant_alternative_t<TypeIndex, VariantType<ArrayType>>;
using Value = typename View::non_const_value_type;
constexpr auto Rank = View::rank();

if (array.datatype() == DataType::kind<Value>() && array.rank() == Rank) {
return makeView(array, Value{}, std::integral_constant<int, Rank>{});
}

if constexpr (TypeIndex < std::variant_size_v<ArrayViewVariant> - 1) {
if constexpr (TypeIndex < std::variant_size_v<VariantType<ArrayType>> - 1) {
return executeMakeView<TypeIndex + 1>(array, makeView);
} else {
throw_Exception("ArrayView<" + array.datatype().str() + ", " +
Expand All @@ -45,23 +57,23 @@ ArrayViewVariant executeMakeView(ArrayType& array, const MakeView& makeView) {
}

template <typename ArrayType>
ArrayViewVariant makeViewVariantImpl(ArrayType& array) {
VariantType<ArrayType> makeViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_view<decltype(value), decltype(rank)::value>(array);
};
return executeMakeView<>(array, makeView);
}

template <typename ArrayType>
ArrayViewVariant makeHostViewVariantImpl(ArrayType& array) {
VariantType<ArrayType> makeHostViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_host_view<decltype(value), decltype(rank)::value>(array);
};
return executeMakeView<>(array, makeView);
}

template <typename ArrayType>
ArrayViewVariant makeDeviceViewVariantImpl(ArrayType& array) {
VariantType<ArrayType> makeDeviceViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_device_view<decltype(value), decltype(rank)::value>(array);
};
Expand All @@ -74,23 +86,23 @@ ArrayViewVariant make_view_variant(Array& array) {
return makeViewVariantImpl(array);
}

ArrayViewVariant make_view_variant(const Array& array) {
ConstArrayViewVariant make_view_variant(const Array& array) {
return makeViewVariantImpl(array);
}

ArrayViewVariant make_host_view_variant(Array& array) {
return makeHostViewVariantImpl(array);
}

ArrayViewVariant make_host_view_variant(const Array& array) {
ConstArrayViewVariant make_host_view_variant(const Array& array) {
return makeHostViewVariantImpl(array);
}

ArrayViewVariant make_device_view_variant(Array& array) {
return makeDeviceViewVariantImpl(array);
}

ArrayViewVariant make_device_view_variant(const Array& array) {
ConstArrayViewVariant make_device_view_variant(const Array& array) {
return makeDeviceViewVariantImpl(array);
}

Expand Down
29 changes: 23 additions & 6 deletions src/atlas/array/ArrayViewVariant.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ using namespace array;

// Container struct for a list of types.
template <typename... Ts>
struct Types {};
struct Types {
using add_const = Types<std::add_const_t<Ts>...>;
};

// Container struct for a list of integers.
template <int... Is>
Expand All @@ -33,7 +35,6 @@ struct VariantHelper;
template <typename T, typename... Ts, int... Is, typename... ArrayViews>
struct VariantHelper<Types<T, Ts...>, Ints<Is...>, ArrayViews...> {
using type = typename VariantHelper<Types<Ts...>, Ints<Is...>, ArrayViews...,
ArrayView<const T, Is>...,
ArrayView<T, Is>...>::type;
};

Expand All @@ -54,26 +55,42 @@ using Values = detail::Types<float, double, int, long, unsigned long>;
/// @brief Supported ArrayView ranks.
using Ranks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>;

/// @brief Variant containing all supported ArrayView alternatives.
/// @brief Variant containing all supported non-const ArrayView alternatives.
using ArrayViewVariant = detail::Variant<Values, Ranks>;

/// @brief Variant containing all supported const ArrayView alternatives.
using ConstArrayViewVariant = detail::Variant<Values::add_const, Ranks>;

/// @brief Create an ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_view_variant(Array& array);

/// @brief Create a const ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_view_variant(const Array& array);
ConstArrayViewVariant make_view_variant(const Array& array);

/// @brief Create a host ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_host_view_variant(Array& array);

/// @brief Create a const host ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_host_view_variant(const Array& array);
ConstArrayViewVariant make_host_view_variant(const Array& array);

/// @brief Create a device ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_device_view_variant(Array& array);

/// @brief Create a const device ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_device_view_variant(const Array& array);
ConstArrayViewVariant make_device_view_variant(const Array& array);

/// @brief Return true if ArrayView<typename, int>::rank() is any of Ranks...
template <typename View, int... Ranks>
constexpr bool RankIs() {
return ((std::decay_t<View>::rank() == Ranks) || ...);
}

/// @brief Return true if View::non_const_value_type is any of Values...
template <typename View, typename... Values>
constexpr bool ValueIs() {
using Value = typename std::decay_t<View>::non_const_value_type;
return ((std::is_same_v<Value, Values>) || ...);
}

} // namespace array
} // namespace atlas
15 changes: 5 additions & 10 deletions src/tests/array/test_array_view_variant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@ CASE("test variant assignment") {
visitVariants(deviceView1, deviceView2, deviceView3, deviceView4);
}

template <typename View>
constexpr auto Rank() {
return std::remove_reference_t<View>::rank();
}

CASE("test std::visit") {
auto array1 = ArrayT<int>(10);
make_view<int, 1>(array1).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
Expand Down Expand Up @@ -127,11 +122,11 @@ CASE("test std::visit") {
auto rank2Tested = false;

const auto visitor = [&](auto&& view) {
if constexpr (Rank<decltype(view)>() == 1) {
if constexpr (RankIs<decltype(view), 1>()) {
rank1Tested = true;
return testRank1(view);
}
if constexpr (Rank<decltype(view)>() == 2) {
if constexpr (RankIs<decltype(view), 2>()) {
rank2Tested = true;
return testRank2(view);
}
Expand All @@ -151,15 +146,15 @@ CASE("test std::visit") {
auto rank1Tested = false;
auto rank2Tested = false;
const auto visitor = eckit::Overloaded{
[&](auto&& view) -> std::enable_if_t<Rank<decltype(view)>() == 1> {
[&](auto&& view) -> std::enable_if_t<RankIs<decltype(view), 1>()> {
testRank1(view);
rank1Tested = true;
},
[&](auto&& view) -> std::enable_if_t<Rank<decltype(view)>() == 2> {
[&](auto&& view) -> std::enable_if_t<RankIs<decltype(view), 2>()> {
testRank2(view);
rank2Tested = true;
},
[](auto&& view) -> std::enable_if_t<(Rank<decltype(view)>() > 2)> {
[](auto&& view) -> std::enable_if_t<!RankIs<decltype(view), 1, 2>()> {
// Test should not reach here.
EXPECT(false);
}};
Expand Down

0 comments on commit aeb352f

Please sign in to comment.