From 50989686fb36b715d44ccfc11ccf7fff6c53a111 Mon Sep 17 00:00:00 2001 From: odlomax Date: Mon, 14 Oct 2024 11:04:34 +0100 Subject: [PATCH] Refactored spherical vector interpolation method. --- .../method/sphericalvector/SphericalVector.cc | 84 ++++++++----------- 1 file changed, 34 insertions(+), 50 deletions(-) diff --git a/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc index 65e4e41d0..8bcd0bef1 100644 --- a/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc +++ b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc @@ -5,6 +5,9 @@ * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. */ +#include +#include + #include "atlas/interpolation/method/sphericalvector/SphericalVector.h" #include "atlas/array/ArrayView.h" @@ -203,59 +206,40 @@ template void SphericalVector::interpolate_vector_field(const Field& sourceField, Field& targetField, const MatMul& matMul) { - ATLAS_ASSERT_MSG(sourceField.variables() == 2 || sourceField.variables() == 3, - "Vector field can only have 2 or 3 components."); - - if (sourceField.datatype().kind() == array::DataType::KIND_REAL64) { - interpolate_vector_field(sourceField, targetField, matMul); - return; - } - - if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) { - interpolate_vector_field(sourceField, targetField, matMul); - return; - } + const auto sourceViewVariant = array::make_view_variant(sourceField); + + const auto sourceViewVisitor = [&](auto sourceView) { + if constexpr (array::is_rank<2, 3>(sourceView) && + array::is_non_const_value_type(sourceView)) { + using SourceView = std::decay_t; + using Value = typename SourceView::non_const_value_type; + constexpr auto Rank = SourceView::rank(); + auto targetView = array::make_view(targetField); + + switch (sourceField.variables()) { + case 2: + return matMul.apply(sourceView, targetView, twoVector); + case 3: + return matMul.apply(sourceView, targetView, threeVector); + default: + ATLAS_THROW_EXCEPTION("Error: no support for " + + std::to_string(sourceField.variables()) + + " variable vector fields.\n" + + " Number of variables must be 2 or 3."); + } + + } else { + ATLAS_THROW_EXCEPTION( + "Error: no support for rank = " + std::to_string(sourceField.rank()) + + " and value type = " + sourceField.datatype().str() + ".\n" + + "Vector field must have rank 2 or 3 with value type " + "float or double"); + } + }; - ATLAS_NOTIMPLEMENTED; + std::visit(sourceViewVisitor, sourceViewVariant); }; -template -void SphericalVector::interpolate_vector_field(const Field& sourceField, - Field& targetField, - const MatMul& matMul) { - if (sourceField.rank() == 2) { - interpolate_vector_field(sourceField, targetField, matMul); - return; - } - - if (sourceField.rank() == 3) { - interpolate_vector_field(sourceField, targetField, matMul); - return; - } - - ATLAS_NOTIMPLEMENTED; -} - -template -void SphericalVector::interpolate_vector_field(const Field& sourceField, - Field& targetField, - const MatMul& matMul) { - const auto sourceView = array::make_view(sourceField); - auto targetView = array::make_view(targetField); - - if (sourceField.variables() == 2) { - matMul.apply(sourceView, targetView, twoVector); - return; - } - - if (sourceField.variables() == 3) { - matMul.apply(sourceView, targetView, threeVector); - return; - } - - ATLAS_NOTIMPLEMENTED; -} - } // namespace method } // namespace interpolation } // namespace atlas