diff --git a/src/atlas/array.h b/src/atlas/array.h index c2cf7f720..a7ac48d07 100644 --- a/src/atlas/array.h +++ b/src/atlas/array.h @@ -23,6 +23,7 @@ #include "atlas/array/ArraySpec.h" #include "atlas/array/ArrayStrides.h" #include "atlas/array/ArrayView.h" +#include "atlas/array/ArrayViewVariant.h" #include "atlas/array/DataType.h" #include "atlas/array/LocalView.h" #include "atlas/array/MakeView.h" diff --git a/src/atlas/array/ArrayViewVariant.cc b/src/atlas/array/ArrayViewVariant.cc index 28d532117..f62efd2d8 100644 --- a/src/atlas/array/ArrayViewVariant.cc +++ b/src/atlas/array/ArrayViewVariant.cc @@ -19,22 +19,34 @@ using namespace detail; namespace { +template +struct VariantTypeHelper { + using type = ArrayViewVariant; +}; + +template <> +struct VariantTypeHelper { + using type = ConstArrayViewVariant; +}; + +template +using VariantType = + typename VariantTypeHelper>::type; + // Match array.rank() and array.datatype() to variant types. Return result of // makeView on a successful pattern match. template -ArrayViewVariant executeMakeView(ArrayType& array, const MakeView& makeView) { - using View = std::variant_alternative_t; - constexpr auto Const = std::is_const_v; - - if constexpr (std::is_const_v == Const) { - using Value = typename View::non_const_value_type; - constexpr auto Rank = View::rank(); - if (array.datatype() == DataType::kind() && array.rank() == Rank) { - return makeView(array, Value{}, std::integral_constant{}); - } +VariantType executeMakeView(ArrayType& array, + const MakeView& makeView) { + using View = std::variant_alternative_t>; + using Value = typename View::non_const_value_type; + constexpr auto Rank = View::rank(); + + if (array.datatype() == DataType::kind() && array.rank() == Rank) { + return makeView(array, Value{}, std::integral_constant{}); } - if constexpr (TypeIndex < std::variant_size_v - 1) { + if constexpr (TypeIndex < std::variant_size_v> - 1) { return executeMakeView(array, makeView); } else { throw_Exception("ArrayView<" + array.datatype().str() + ", " + @@ -45,7 +57,7 @@ ArrayViewVariant executeMakeView(ArrayType& array, const MakeView& makeView) { } template -ArrayViewVariant makeViewVariantImpl(ArrayType& array) { +VariantType makeViewVariantImpl(ArrayType& array) { const auto makeView = [](auto& array, auto value, auto rank) { return make_view(array); }; @@ -53,7 +65,7 @@ ArrayViewVariant makeViewVariantImpl(ArrayType& array) { } template -ArrayViewVariant makeHostViewVariantImpl(ArrayType& array) { +VariantType makeHostViewVariantImpl(ArrayType& array) { const auto makeView = [](auto& array, auto value, auto rank) { return make_host_view(array); }; @@ -61,7 +73,7 @@ ArrayViewVariant makeHostViewVariantImpl(ArrayType& array) { } template -ArrayViewVariant makeDeviceViewVariantImpl(ArrayType& array) { +VariantType makeDeviceViewVariantImpl(ArrayType& array) { const auto makeView = [](auto& array, auto value, auto rank) { return make_device_view(array); }; @@ -74,7 +86,7 @@ ArrayViewVariant make_view_variant(Array& array) { return makeViewVariantImpl(array); } -ArrayViewVariant make_view_variant(const Array& array) { +ConstArrayViewVariant make_view_variant(const Array& array) { return makeViewVariantImpl(array); } @@ -82,7 +94,7 @@ ArrayViewVariant make_host_view_variant(Array& array) { return makeHostViewVariantImpl(array); } -ArrayViewVariant make_host_view_variant(const Array& array) { +ConstArrayViewVariant make_host_view_variant(const Array& array) { return makeHostViewVariantImpl(array); } @@ -90,7 +102,7 @@ ArrayViewVariant make_device_view_variant(Array& array) { return makeDeviceViewVariantImpl(array); } -ArrayViewVariant make_device_view_variant(const Array& array) { +ConstArrayViewVariant make_device_view_variant(const Array& array) { return makeDeviceViewVariantImpl(array); } diff --git a/src/atlas/array/ArrayViewVariant.h b/src/atlas/array/ArrayViewVariant.h index d5dd196f8..0ea3f87c6 100644 --- a/src/atlas/array/ArrayViewVariant.h +++ b/src/atlas/array/ArrayViewVariant.h @@ -20,7 +20,9 @@ using namespace array; // Container struct for a list of types. template -struct Types {}; +struct Types { + using add_const = Types...>; +}; // Container struct for a list of integers. template @@ -33,7 +35,6 @@ struct VariantHelper; template struct VariantHelper, Ints, ArrayViews...> { using type = typename VariantHelper, Ints, ArrayViews..., - ArrayView..., ArrayView...>::type; }; @@ -54,26 +55,42 @@ using Values = detail::Types; /// @brief Supported ArrayView ranks. using Ranks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>; -/// @brief Variant containing all supported ArrayView alternatives. +/// @brief Variant containing all supported non-const ArrayView alternatives. using ArrayViewVariant = detail::Variant; +/// @brief Variant containing all supported const ArrayView alternatives. +using ConstArrayViewVariant = detail::Variant; + /// @brief Create an ArrayView and assign to an ArrayViewVariant. ArrayViewVariant make_view_variant(Array& array); /// @brief Create a const ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_view_variant(const Array& array); +ConstArrayViewVariant make_view_variant(const Array& array); /// @brief Create a host ArrayView and assign to an ArrayViewVariant. ArrayViewVariant make_host_view_variant(Array& array); /// @brief Create a const host ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_host_view_variant(const Array& array); +ConstArrayViewVariant make_host_view_variant(const Array& array); /// @brief Create a device ArrayView and assign to an ArrayViewVariant. ArrayViewVariant make_device_view_variant(Array& array); /// @brief Create a const device ArrayView and assign to an ArrayViewVariant. -ArrayViewVariant make_device_view_variant(const Array& array); +ConstArrayViewVariant make_device_view_variant(const Array& array); + +/// @brief Return true if ArrayView::rank() is any of Ranks... +template +constexpr bool RankIs() { + return ((std::decay_t::rank() == Ranks) || ...); +} + +/// @brief Return true if View::non_const_value_type is any of Values... +template +constexpr bool ValueIs() { + using Value = typename std::decay_t::non_const_value_type; + return ((std::is_same_v) || ...); +} } // namespace array } // namespace atlas diff --git a/src/tests/array/test_array_view_variant.cc b/src/tests/array/test_array_view_variant.cc index 9f60a10be..29a3cb044 100644 --- a/src/tests/array/test_array_view_variant.cc +++ b/src/tests/array/test_array_view_variant.cc @@ -82,11 +82,6 @@ CASE("test variant assignment") { visitVariants(deviceView1, deviceView2, deviceView3, deviceView4); } -template -constexpr auto Rank() { - return std::remove_reference_t::rank(); -} - CASE("test std::visit") { auto array1 = ArrayT(10); make_view(array1).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); @@ -127,11 +122,11 @@ CASE("test std::visit") { auto rank2Tested = false; const auto visitor = [&](auto&& view) { - if constexpr (Rank() == 1) { + if constexpr (RankIs()) { rank1Tested = true; return testRank1(view); } - if constexpr (Rank() == 2) { + if constexpr (RankIs()) { rank2Tested = true; return testRank2(view); } @@ -151,15 +146,15 @@ CASE("test std::visit") { auto rank1Tested = false; auto rank2Tested = false; const auto visitor = eckit::Overloaded{ - [&](auto&& view) -> std::enable_if_t() == 1> { + [&](auto&& view) -> std::enable_if_t()> { testRank1(view); rank1Tested = true; }, - [&](auto&& view) -> std::enable_if_t() == 2> { + [&](auto&& view) -> std::enable_if_t()> { testRank2(view); rank2Tested = true; }, - [](auto&& view) -> std::enable_if_t<(Rank() > 2)> { + [](auto&& view) -> std::enable_if_t()> { // Test should not reach here. EXPECT(false); }};