Skip to content

Commit

Permalink
Added function to create std::variant for multiple array views. (#220)
Browse files Browse the repository at this point in the history
* Added array view variant class and tests.

* Moved make_view_variant into array namespace.

* Refactored ArrayViewVariant methods.

* More refactoring.

* Updated test.

* Attempting to address gnu 7.3 compiler errors.

* Typos in comments.

* Added missing EXPECTs in test.

* Refactored detial::VariantHelper template.

* Merged in ArrayViewVariant refactor.

* Refactored introspection helpers.

* Refactor helper function signatures. Removed SFINAE test.

* Cleaned up some garbage in test.

* Removed reference qualifier on visitor template parameter.

* Moved ValuesTypes and Ranks structs into array::detail namespace.

* Tidied up naming consistency.

* Renamed ValueType and Ranks structs.

* Revert parameter names in ArrayViewVariant.h

* Revert parameter names in ArrayViewVariant.h
  • Loading branch information
odlomax authored Oct 8, 2024
1 parent 051dfc4 commit b944fba
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/atlas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,8 @@ array/Range.h
array/Vector.h
array/Vector.cc
array/SVector.h
array/ArrayViewVariant.h
array/ArrayViewVariant.cc
array/helpers/ArrayInitializer.h
array/helpers/ArrayAssigner.h
array/helpers/ArrayWriter.h
Expand Down
1 change: 1 addition & 0 deletions src/atlas/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "atlas/array/ArraySpec.h"
#include "atlas/array/ArrayStrides.h"
#include "atlas/array/ArrayView.h"
#include "atlas/array/ArrayViewVariant.h"
#include "atlas/array/DataType.h"
#include "atlas/array/LocalView.h"
#include "atlas/array/MakeView.h"
Expand Down
110 changes: 110 additions & 0 deletions src/atlas/array/ArrayViewVariant.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* (C) Crown Copyright 2024 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#include "atlas/array/ArrayViewVariant.h"

#include <string>
#include <type_traits>

#include "atlas/runtime/Exception.h"

namespace atlas {
namespace array {

using namespace detail;

namespace {

template <bool IsConst>
struct VariantTypeHelper {
using type = ArrayViewVariant;
};

template <>
struct VariantTypeHelper<true> {
using type = ConstArrayViewVariant;
};

template <typename ArrayType>
using VariantType =
typename VariantTypeHelper<std::is_const_v<ArrayType>>::type;

// Match array.rank() and array.datatype() to variant types. Return result of
// makeView on a successful pattern match.
template <size_t TypeIndex = 0, typename ArrayType, typename MakeView>
VariantType<ArrayType> executeMakeView(ArrayType& array,
const MakeView& makeView) {
using View = std::variant_alternative_t<TypeIndex, VariantType<ArrayType>>;
using Value = typename View::non_const_value_type;
constexpr auto Rank = View::rank();

if (array.datatype() == DataType::kind<Value>() && array.rank() == Rank) {
return makeView(array, Value{}, std::integral_constant<int, Rank>{});
}

if constexpr (TypeIndex < std::variant_size_v<VariantType<ArrayType>> - 1) {
return executeMakeView<TypeIndex + 1>(array, makeView);
} else {
throw_Exception("ArrayView<" + array.datatype().str() + ", " +
std::to_string(array.rank()) +
"> is not an alternative in ArrayViewVariant.",
Here());
}
}

template <typename ArrayType>
VariantType<ArrayType> makeViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_view<decltype(value), decltype(rank)::value>(array);
};
return executeMakeView<>(array, makeView);
}

template <typename ArrayType>
VariantType<ArrayType> makeHostViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_host_view<decltype(value), decltype(rank)::value>(array);
};
return executeMakeView<>(array, makeView);
}

template <typename ArrayType>
VariantType<ArrayType> makeDeviceViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_device_view<decltype(value), decltype(rank)::value>(array);
};
return executeMakeView<>(array, makeView);
}

} // namespace

ArrayViewVariant make_view_variant(Array& array) {
return makeViewVariantImpl(array);
}

ConstArrayViewVariant make_view_variant(const Array& array) {
return makeViewVariantImpl(array);
}

ArrayViewVariant make_host_view_variant(Array& array) {
return makeHostViewVariantImpl(array);
}

ConstArrayViewVariant make_host_view_variant(const Array& array) {
return makeHostViewVariantImpl(array);
}

ArrayViewVariant make_device_view_variant(Array& array) {
return makeDeviceViewVariantImpl(array);
}

ConstArrayViewVariant make_device_view_variant(const Array& array) {
return makeDeviceViewVariantImpl(array);
}

} // namespace array
} // namespace atlas
103 changes: 103 additions & 0 deletions src/atlas/array/ArrayViewVariant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* (C) Crown Copyright 2024 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#pragma once

#include <variant>

#include "atlas/array.h"

namespace atlas {
namespace array {

namespace detail {

using namespace array;

// Container struct for a list of types.
template <typename... Ts>
struct Types {
using add_const = Types<std::add_const_t<Ts>...>;
};

// Container struct for a list of integers.
template <int... Is>
struct Ints {};

template <typename ValueTypes, typename Ranks, typename... ArrayViews>
struct VariantHelper;

// Recursively construct ArrayView std::variant from types Ts and ranks Is.
template <typename T, typename... Ts, int... Is, typename... ArrayViews>
struct VariantHelper<Types<T, Ts...>, Ints<Is...>, ArrayViews...> {
using type = typename VariantHelper<Types<Ts...>, Ints<Is...>, ArrayViews...,
ArrayView<T, Is>...>::type;
};

// End recursion.
template <int... Is, typename... ArrayViews>
struct VariantHelper<Types<>, Ints<Is...>, ArrayViews...> {
using type = std::variant<ArrayViews...>;
};

template <typename ValueTypes, typename Ranks>
using Variant = typename VariantHelper<ValueTypes, Ranks>::type;

using VariantValueTypes =
detail::Types<float, double, int, long, unsigned long>;

using VariantRanks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>;

} // namespace detail

/// @brief Variant containing all supported non-const ArrayView alternatives.
using ArrayViewVariant =
detail::Variant<detail::VariantValueTypes, detail::VariantRanks>;

/// @brief Variant containing all supported const ArrayView alternatives.
using ConstArrayViewVariant =
detail::Variant<detail::VariantValueTypes::add_const, detail::VariantRanks>;

/// @brief Create an ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_view_variant(Array& array);

/// @brief Create a const ArrayView and assign to an ArrayViewVariant.
ConstArrayViewVariant make_view_variant(const Array& array);

/// @brief Create a host ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_host_view_variant(Array& array);

/// @brief Create a const host ArrayView and assign to an ArrayViewVariant.
ConstArrayViewVariant make_host_view_variant(const Array& array);

/// @brief Create a device ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_device_view_variant(Array& array);

/// @brief Create a const device ArrayView and assign to an ArrayViewVariant.
ConstArrayViewVariant make_device_view_variant(const Array& array);

/// @brief Return true if View::rank() is any of Ranks...
template <int... Ranks, typename View>
constexpr bool is_rank(const View&) {
return ((std::decay_t<View>::rank() == Ranks) || ...);
}
/// @brief Return true if View::value_type is any of ValuesTypes...
template <typename... ValueTypes, typename View>
constexpr bool is_value_type(const View&) {
using ValueType = typename std::decay_t<View>::value_type;
return ((std::is_same_v<ValueType, ValueTypes>) || ...);
}

/// @brief Return true if View::non_const_value_type is any of ValuesTypes...
template <typename... ValueTypes, typename View>
constexpr bool is_non_const_value_type(const View&) {
using ValueType = typename std::decay_t<View>::non_const_value_type;
return ((std::is_same_v<ValueType, ValueTypes>) || ...);
}

} // namespace array
} // namespace atlas
5 changes: 5 additions & 0 deletions src/tests/array/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,8 @@ atlas_add_hic_test(
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)

ecbuild_add_test( TARGET atlas_test_array_view_variant
SOURCES test_array_view_variant.cc
LIBS atlas
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)
131 changes: 131 additions & 0 deletions src/tests/array/test_array_view_variant.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* (C) Crown Copyright 2024 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#include <type_traits>
#include <variant>

#include "atlas/array.h"
#include "atlas/array/ArrayViewVariant.h"
#include "eckit/utils/Overloaded.h"
#include "tests/AtlasTestEnvironment.h"

namespace atlas {
namespace test {

using namespace array;

CASE("test variant assignment") {
auto array1 = array::ArrayT<float>(2);
auto array2 = array::ArrayT<double>(2, 3);
auto array3 = array::ArrayT<int>(2, 3, 4);
const auto& arrayRef = array1;

array1.allocateDevice();
array2.allocateDevice();
array3.allocateDevice();

auto view1 = make_view_variant(array1);
auto view2 = make_view_variant(array2);
auto view3 = make_view_variant(array3);
auto view4 = make_view_variant(arrayRef);

const auto hostView1 = make_host_view_variant(array1);
const auto hostView2 = make_host_view_variant(array2);
const auto hostView3 = make_host_view_variant(array3);
const auto hostView4 = make_host_view_variant(arrayRef);

auto deviceView1 = make_device_view_variant(array1);
auto deviceView2 = make_device_view_variant(array2);
auto deviceView3 = make_device_view_variant(array3);
auto deviceView4 = make_device_view_variant(arrayRef);

const auto visitVariants = [](auto& var1, auto& var2, auto var3, auto var4) {
std::visit(
[](auto view) {
EXPECT((is_rank<1>(view)));
EXPECT((is_value_type<float>(view)));
EXPECT((is_non_const_value_type<float>(view)));
},
var1);

std::visit(
[](auto view) {
EXPECT((is_rank<2>(view)));
EXPECT((is_value_type<double>(view)));
EXPECT((is_non_const_value_type<double>(view)));
},
var2);

std::visit(
[](auto view) {
EXPECT((is_rank<3>(view)));
EXPECT((is_value_type<int>(view)));
EXPECT((is_non_const_value_type<int>(view)));
},
var3);

std::visit(
[](auto view) {
EXPECT((is_rank<1>(view)));
EXPECT((is_value_type<const float>(view)));
EXPECT((is_non_const_value_type<float>(view)));
},
var4);
};

visitVariants(view1, view2, view3, view4);
visitVariants(hostView1, hostView2, hostView3, hostView4);
visitVariants(deviceView1, deviceView2, deviceView3, deviceView4);
}

CASE("test std::visit") {
auto array1 = ArrayT<int>(10);
make_view<int, 1>(array1).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

auto array2 = ArrayT<int>(5, 2);
make_view<int, 2>(array2).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

const auto var1 = make_view_variant(array1);
const auto var2 = make_view_variant(array2);
auto rank1Tested = false;
auto rank2Tested = false;

const auto visitor = [&](auto view) {
if constexpr (is_rank<1>(view)) {
EXPECT((is_value_type<int>(view)));
auto testValue = int{0};
for (auto i = size_t{0}; i < view.size(); ++i) {
const auto value = view(i);
EXPECT_EQ(value, static_cast<decltype(value)>(testValue++));
}
rank1Tested = true;
} else if constexpr (is_rank<2>(view)) {
EXPECT((is_value_type<int>(view)));
auto testValue = int{0};
for (auto i = idx_t{0}; i < view.shape(0); ++i) {
for (auto j = idx_t{0}; j < view.shape(1); ++j) {
const auto value = view(i, j);
EXPECT_EQ(value, static_cast<decltype(value)>(testValue++));
}
}
rank2Tested = true;
} else {
// Test should not reach here.
EXPECT(false);
}
};

std::visit(visitor, var1);
EXPECT(rank1Tested);
std::visit(visitor, var2);
EXPECT(rank2Tested);
}

} // namespace test
} // namespace atlas

int main(int argc, char** argv) { return atlas::test::run(argc, argv); }

0 comments on commit b944fba

Please sign in to comment.