Skip to content

Commit

Permalink
Refactored spherical vector interpolation method.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Oct 14, 2024
1 parent a7897b9 commit a065858
Showing 1 changed file with 34 additions and 50 deletions.
84 changes: 34 additions & 50 deletions src/atlas/interpolation/method/sphericalvector/SphericalVector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

#include "atlas/interpolation/method/sphericalvector/SphericalVector.h"

#include <cmath>
#include <variant>

#include "atlas/array/ArrayView.h"
#include "atlas/field/Field.h"
#include "atlas/field/FieldSet.h"
Expand Down Expand Up @@ -203,59 +206,40 @@ template <typename MatMul>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField,
const MatMul& matMul) {
ATLAS_ASSERT_MSG(sourceField.variables() == 2 || sourceField.variables() == 3,
"Vector field can only have 2 or 3 components.");

if (sourceField.datatype().kind() == array::DataType::KIND_REAL64) {
interpolate_vector_field<double>(sourceField, targetField, matMul);
return;
}

if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) {
interpolate_vector_field<float>(sourceField, targetField, matMul);
return;
}
const auto sourceViewVariant = array::make_view_variant(sourceField);

const auto sourceViewVisitor = [&](auto sourceView) {
if constexpr (array::is_rank<2, 3>(sourceView) &&
array::is_non_const_value_type<float, double>(sourceView)) {
using SourceView = std::decay_t<decltype(sourceView)>;
using Value = typename SourceView::non_const_value_type;
constexpr auto Rank = SourceView::rank();
auto targetView = array::make_view<Value, Rank>(targetField);

switch (sourceField.variables()) {
case 2:
return matMul.apply(sourceView, targetView, twoVector);
case 3:
return matMul.apply(sourceView, targetView, threeVector);
default:
ATLAS_THROW_EXCEPTION("Error: no support for " +
std::to_string(sourceField.variables()) +
" variable vector fields.\n" +
" Number of variables must be 2 or 3.");
}

} else {
ATLAS_THROW_EXCEPTION(
"Error: no support for rank = " + std::to_string(sourceField.rank()) +
" and value type = " + sourceField.datatype().str() + ".\n" +
"Vector field must have rank 2 or 3 with value type "
"float or double");
}
};

ATLAS_NOTIMPLEMENTED;
std::visit(sourceViewVisitor, sourceViewVariant);
};

template <typename Value, typename MatMul>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField,
const MatMul& matMul) {
if (sourceField.rank() == 2) {
interpolate_vector_field<Value, 2>(sourceField, targetField, matMul);
return;
}

if (sourceField.rank() == 3) {
interpolate_vector_field<Value, 3>(sourceField, targetField, matMul);
return;
}

ATLAS_NOTIMPLEMENTED;
}

template <typename Value, int Rank, typename MatMul>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField,
const MatMul& matMul) {
const auto sourceView = array::make_view<Value, Rank>(sourceField);
auto targetView = array::make_view<Value, Rank>(targetField);

if (sourceField.variables() == 2) {
matMul.apply(sourceView, targetView, twoVector);
return;
}

if (sourceField.variables() == 3) {
matMul.apply(sourceView, targetView, threeVector);
return;
}

ATLAS_NOTIMPLEMENTED;
}

} // namespace method
} // namespace interpolation
} // namespace atlas

0 comments on commit a065858

Please sign in to comment.