-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added function to create
std::variant
for multiple array views. (#220)
* Added array view variant class and tests. * Moved make_view_variant into array namespace. * Refactored ArrayViewVariant methods. * More refactoring. * Updated test. * Attempting to address gnu 7.3 compiler errors. * Typos in comments. * Added missing EXPECTs in test. * Refactored detial::VariantHelper template. * Merged in ArrayViewVariant refactor. * Refactored introspection helpers. * Refactor helper function signatures. Removed SFINAE test. * Cleaned up some garbage in test. * Removed reference qualifier on visitor template parameter. * Moved ValuesTypes and Ranks structs into array::detail namespace. * Tidied up naming consistency. * Renamed ValueType and Ranks structs. * Revert parameter names in ArrayViewVariant.h * Revert parameter names in ArrayViewVariant.h
- Loading branch information
Showing
6 changed files
with
352 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
* (C) Crown Copyright 2024 Met Office | ||
* | ||
* This software is licensed under the terms of the Apache Licence Version 2.0 | ||
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
*/ | ||
|
||
#include "atlas/array/ArrayViewVariant.h" | ||
|
||
#include <string> | ||
#include <type_traits> | ||
|
||
#include "atlas/runtime/Exception.h" | ||
|
||
namespace atlas { | ||
namespace array { | ||
|
||
using namespace detail; | ||
|
||
namespace { | ||
|
||
template <bool IsConst> | ||
struct VariantTypeHelper { | ||
using type = ArrayViewVariant; | ||
}; | ||
|
||
template <> | ||
struct VariantTypeHelper<true> { | ||
using type = ConstArrayViewVariant; | ||
}; | ||
|
||
template <typename ArrayType> | ||
using VariantType = | ||
typename VariantTypeHelper<std::is_const_v<ArrayType>>::type; | ||
|
||
// Match array.rank() and array.datatype() to variant types. Return result of | ||
// makeView on a successful pattern match. | ||
template <size_t TypeIndex = 0, typename ArrayType, typename MakeView> | ||
VariantType<ArrayType> executeMakeView(ArrayType& array, | ||
const MakeView& makeView) { | ||
using View = std::variant_alternative_t<TypeIndex, VariantType<ArrayType>>; | ||
using Value = typename View::non_const_value_type; | ||
constexpr auto Rank = View::rank(); | ||
|
||
if (array.datatype() == DataType::kind<Value>() && array.rank() == Rank) { | ||
return makeView(array, Value{}, std::integral_constant<int, Rank>{}); | ||
} | ||
|
||
if constexpr (TypeIndex < std::variant_size_v<VariantType<ArrayType>> - 1) { | ||
return executeMakeView<TypeIndex + 1>(array, makeView); | ||
} else { | ||
throw_Exception("ArrayView<" + array.datatype().str() + ", " + | ||
std::to_string(array.rank()) + | ||
"> is not an alternative in ArrayViewVariant.", | ||
Here()); | ||
} | ||
} | ||
|
||
template <typename ArrayType> | ||
VariantType<ArrayType> makeViewVariantImpl(ArrayType& array) { | ||
const auto makeView = [](auto& array, auto value, auto rank) { | ||
return make_view<decltype(value), decltype(rank)::value>(array); | ||
}; | ||
return executeMakeView<>(array, makeView); | ||
} | ||
|
||
template <typename ArrayType> | ||
VariantType<ArrayType> makeHostViewVariantImpl(ArrayType& array) { | ||
const auto makeView = [](auto& array, auto value, auto rank) { | ||
return make_host_view<decltype(value), decltype(rank)::value>(array); | ||
}; | ||
return executeMakeView<>(array, makeView); | ||
} | ||
|
||
template <typename ArrayType> | ||
VariantType<ArrayType> makeDeviceViewVariantImpl(ArrayType& array) { | ||
const auto makeView = [](auto& array, auto value, auto rank) { | ||
return make_device_view<decltype(value), decltype(rank)::value>(array); | ||
}; | ||
return executeMakeView<>(array, makeView); | ||
} | ||
|
||
} // namespace | ||
|
||
ArrayViewVariant make_view_variant(Array& array) { | ||
return makeViewVariantImpl(array); | ||
} | ||
|
||
ConstArrayViewVariant make_view_variant(const Array& array) { | ||
return makeViewVariantImpl(array); | ||
} | ||
|
||
ArrayViewVariant make_host_view_variant(Array& array) { | ||
return makeHostViewVariantImpl(array); | ||
} | ||
|
||
ConstArrayViewVariant make_host_view_variant(const Array& array) { | ||
return makeHostViewVariantImpl(array); | ||
} | ||
|
||
ArrayViewVariant make_device_view_variant(Array& array) { | ||
return makeDeviceViewVariantImpl(array); | ||
} | ||
|
||
ConstArrayViewVariant make_device_view_variant(const Array& array) { | ||
return makeDeviceViewVariantImpl(array); | ||
} | ||
|
||
} // namespace array | ||
} // namespace atlas |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/* | ||
* (C) Crown Copyright 2024 Met Office | ||
* | ||
* This software is licensed under the terms of the Apache Licence Version 2.0 | ||
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <variant> | ||
|
||
#include "atlas/array.h" | ||
|
||
namespace atlas { | ||
namespace array { | ||
|
||
namespace detail { | ||
|
||
using namespace array; | ||
|
||
// Container struct for a list of types. | ||
template <typename... Ts> | ||
struct Types { | ||
using add_const = Types<std::add_const_t<Ts>...>; | ||
}; | ||
|
||
// Container struct for a list of integers. | ||
template <int... Is> | ||
struct Ints {}; | ||
|
||
template <typename ValueTypes, typename Ranks, typename... ArrayViews> | ||
struct VariantHelper; | ||
|
||
// Recursively construct ArrayView std::variant from types Ts and ranks Is. | ||
template <typename T, typename... Ts, int... Is, typename... ArrayViews> | ||
struct VariantHelper<Types<T, Ts...>, Ints<Is...>, ArrayViews...> { | ||
using type = typename VariantHelper<Types<Ts...>, Ints<Is...>, ArrayViews..., | ||
ArrayView<T, Is>...>::type; | ||
}; | ||
|
||
// End recursion. | ||
template <int... Is, typename... ArrayViews> | ||
struct VariantHelper<Types<>, Ints<Is...>, ArrayViews...> { | ||
using type = std::variant<ArrayViews...>; | ||
}; | ||
|
||
template <typename ValueTypes, typename Ranks> | ||
using Variant = typename VariantHelper<ValueTypes, Ranks>::type; | ||
|
||
using VariantValueTypes = | ||
detail::Types<float, double, int, long, unsigned long>; | ||
|
||
using VariantRanks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>; | ||
|
||
} // namespace detail | ||
|
||
/// @brief Variant containing all supported non-const ArrayView alternatives. | ||
using ArrayViewVariant = | ||
detail::Variant<detail::VariantValueTypes, detail::VariantRanks>; | ||
|
||
/// @brief Variant containing all supported const ArrayView alternatives. | ||
using ConstArrayViewVariant = | ||
detail::Variant<detail::VariantValueTypes::add_const, detail::VariantRanks>; | ||
|
||
/// @brief Create an ArrayView and assign to an ArrayViewVariant. | ||
ArrayViewVariant make_view_variant(Array& array); | ||
|
||
/// @brief Create a const ArrayView and assign to an ArrayViewVariant. | ||
ConstArrayViewVariant make_view_variant(const Array& array); | ||
|
||
/// @brief Create a host ArrayView and assign to an ArrayViewVariant. | ||
ArrayViewVariant make_host_view_variant(Array& array); | ||
|
||
/// @brief Create a const host ArrayView and assign to an ArrayViewVariant. | ||
ConstArrayViewVariant make_host_view_variant(const Array& array); | ||
|
||
/// @brief Create a device ArrayView and assign to an ArrayViewVariant. | ||
ArrayViewVariant make_device_view_variant(Array& array); | ||
|
||
/// @brief Create a const device ArrayView and assign to an ArrayViewVariant. | ||
ConstArrayViewVariant make_device_view_variant(const Array& array); | ||
|
||
/// @brief Return true if View::rank() is any of Ranks... | ||
template <int... Ranks, typename View> | ||
constexpr bool is_rank(const View&) { | ||
return ((std::decay_t<View>::rank() == Ranks) || ...); | ||
} | ||
/// @brief Return true if View::value_type is any of ValuesTypes... | ||
template <typename... ValueTypes, typename View> | ||
constexpr bool is_value_type(const View&) { | ||
using ValueType = typename std::decay_t<View>::value_type; | ||
return ((std::is_same_v<ValueType, ValueTypes>) || ...); | ||
} | ||
|
||
/// @brief Return true if View::non_const_value_type is any of ValuesTypes... | ||
template <typename... ValueTypes, typename View> | ||
constexpr bool is_non_const_value_type(const View&) { | ||
using ValueType = typename std::decay_t<View>::non_const_value_type; | ||
return ((std::is_same_v<ValueType, ValueTypes>) || ...); | ||
} | ||
|
||
} // namespace array | ||
} // namespace atlas |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
/* | ||
* (C) Crown Copyright 2024 Met Office | ||
* | ||
* This software is licensed under the terms of the Apache Licence Version 2.0 | ||
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
*/ | ||
|
||
#include <type_traits> | ||
#include <variant> | ||
|
||
#include "atlas/array.h" | ||
#include "atlas/array/ArrayViewVariant.h" | ||
#include "eckit/utils/Overloaded.h" | ||
#include "tests/AtlasTestEnvironment.h" | ||
|
||
namespace atlas { | ||
namespace test { | ||
|
||
using namespace array; | ||
|
||
CASE("test variant assignment") { | ||
auto array1 = array::ArrayT<float>(2); | ||
auto array2 = array::ArrayT<double>(2, 3); | ||
auto array3 = array::ArrayT<int>(2, 3, 4); | ||
const auto& arrayRef = array1; | ||
|
||
array1.allocateDevice(); | ||
array2.allocateDevice(); | ||
array3.allocateDevice(); | ||
|
||
auto view1 = make_view_variant(array1); | ||
auto view2 = make_view_variant(array2); | ||
auto view3 = make_view_variant(array3); | ||
auto view4 = make_view_variant(arrayRef); | ||
|
||
const auto hostView1 = make_host_view_variant(array1); | ||
const auto hostView2 = make_host_view_variant(array2); | ||
const auto hostView3 = make_host_view_variant(array3); | ||
const auto hostView4 = make_host_view_variant(arrayRef); | ||
|
||
auto deviceView1 = make_device_view_variant(array1); | ||
auto deviceView2 = make_device_view_variant(array2); | ||
auto deviceView3 = make_device_view_variant(array3); | ||
auto deviceView4 = make_device_view_variant(arrayRef); | ||
|
||
const auto visitVariants = [](auto& var1, auto& var2, auto var3, auto var4) { | ||
std::visit( | ||
[](auto view) { | ||
EXPECT((is_rank<1>(view))); | ||
EXPECT((is_value_type<float>(view))); | ||
EXPECT((is_non_const_value_type<float>(view))); | ||
}, | ||
var1); | ||
|
||
std::visit( | ||
[](auto view) { | ||
EXPECT((is_rank<2>(view))); | ||
EXPECT((is_value_type<double>(view))); | ||
EXPECT((is_non_const_value_type<double>(view))); | ||
}, | ||
var2); | ||
|
||
std::visit( | ||
[](auto view) { | ||
EXPECT((is_rank<3>(view))); | ||
EXPECT((is_value_type<int>(view))); | ||
EXPECT((is_non_const_value_type<int>(view))); | ||
}, | ||
var3); | ||
|
||
std::visit( | ||
[](auto view) { | ||
EXPECT((is_rank<1>(view))); | ||
EXPECT((is_value_type<const float>(view))); | ||
EXPECT((is_non_const_value_type<float>(view))); | ||
}, | ||
var4); | ||
}; | ||
|
||
visitVariants(view1, view2, view3, view4); | ||
visitVariants(hostView1, hostView2, hostView3, hostView4); | ||
visitVariants(deviceView1, deviceView2, deviceView3, deviceView4); | ||
} | ||
|
||
CASE("test std::visit") { | ||
auto array1 = ArrayT<int>(10); | ||
make_view<int, 1>(array1).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); | ||
|
||
auto array2 = ArrayT<int>(5, 2); | ||
make_view<int, 2>(array2).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); | ||
|
||
const auto var1 = make_view_variant(array1); | ||
const auto var2 = make_view_variant(array2); | ||
auto rank1Tested = false; | ||
auto rank2Tested = false; | ||
|
||
const auto visitor = [&](auto view) { | ||
if constexpr (is_rank<1>(view)) { | ||
EXPECT((is_value_type<int>(view))); | ||
auto testValue = int{0}; | ||
for (auto i = size_t{0}; i < view.size(); ++i) { | ||
const auto value = view(i); | ||
EXPECT_EQ(value, static_cast<decltype(value)>(testValue++)); | ||
} | ||
rank1Tested = true; | ||
} else if constexpr (is_rank<2>(view)) { | ||
EXPECT((is_value_type<int>(view))); | ||
auto testValue = int{0}; | ||
for (auto i = idx_t{0}; i < view.shape(0); ++i) { | ||
for (auto j = idx_t{0}; j < view.shape(1); ++j) { | ||
const auto value = view(i, j); | ||
EXPECT_EQ(value, static_cast<decltype(value)>(testValue++)); | ||
} | ||
} | ||
rank2Tested = true; | ||
} else { | ||
// Test should not reach here. | ||
EXPECT(false); | ||
} | ||
}; | ||
|
||
std::visit(visitor, var1); | ||
EXPECT(rank1Tested); | ||
std::visit(visitor, var2); | ||
EXPECT(rank2Tested); | ||
} | ||
|
||
} // namespace test | ||
} // namespace atlas | ||
|
||
int main(int argc, char** argv) { return atlas::test::run(argc, argv); } |