diff --git a/src/atlas/CMakeLists.txt b/src/atlas/CMakeLists.txt index 366fdc1e5..4a18dec2a 100644 --- a/src/atlas/CMakeLists.txt +++ b/src/atlas/CMakeLists.txt @@ -525,8 +525,6 @@ functionspace/detail/CubedSphereStructure.cc # for cubedsphere matching mesh partitioner interpolation/method/cubedsphere/CellFinder.cc interpolation/method/cubedsphere/CellFinder.h -interpolation/method/paralleltransport/ParallelTransport.h -interpolation/method/paralleltransport/ParallelTransport.cc interpolation/Vector2D.cc interpolation/Vector2D.h interpolation/Vector3D.cc @@ -671,6 +669,15 @@ interpolation/nonlinear/NonLinear.cc interpolation/nonlinear/NonLinear.h ) +if (eckit_EIGEN_FOUND) + +list (APPEND atlas_interpolation_srcs +interpolation/method/sphericalvector/SphericalVector.h +interpolation/method/sphericalvector/SphericalVector.cc +) + +endif() + list( APPEND atlas_linalg_srcs linalg/Indexing.h linalg/Introspection.h diff --git a/src/atlas/interpolation/method/MethodFactory.cc b/src/atlas/interpolation/method/MethodFactory.cc index 7a22e72c4..9f2bff494 100644 --- a/src/atlas/interpolation/method/MethodFactory.cc +++ b/src/atlas/interpolation/method/MethodFactory.cc @@ -12,7 +12,7 @@ // for static linking #include "cubedsphere/CubedSphereBilinear.h" -#include "paralleltransport/ParallelTransport.h" +#include "sphericalvector/SphericalVector.h" #include "knn/GridBoxAverage.h" #include "knn/GridBoxMaximum.h" #include "knn/KNearestNeighbours.h" @@ -48,7 +48,7 @@ void force_link() { MethodBuilder(); MethodBuilder(); MethodBuilder(); - MethodBuilder(); + MethodBuilder(); } } link; } diff --git a/src/atlas/interpolation/method/paralleltransport/ParallelTransport.cc b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc similarity index 56% rename from src/atlas/interpolation/method/paralleltransport/ParallelTransport.cc rename to src/atlas/interpolation/method/sphericalvector/SphericalVector.cc index 59e23c18e..4376a98ae 100644 --- a/src/atlas/interpolation/method/paralleltransport/ParallelTransport.cc +++ b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc @@ -16,7 +16,7 @@ #include "atlas/interpolation/Cache.h" #include "atlas/interpolation/Interpolation.h" #include "atlas/interpolation/method/MethodFactory.h" -#include "atlas/interpolation/method/paralleltransport/ParallelTransport.h" +#include "atlas/interpolation/method/sphericalvector/SphericalVector.h" #include "atlas/linalg/sparse.h" #include "atlas/option/Options.h" #include "atlas/parallel/omp/omp.h" @@ -32,33 +32,14 @@ namespace interpolation { namespace method { namespace { -MethodBuilder __builder("parallel-transport"); +MethodBuilder __builder("spherical-vector"); template -void sparseMatrixForEach(MatrixT&& matrix, const Functor& functor) { - - const auto nRows = matrix.rows(); - const auto nCols = matrix.cols(); - const auto rowIndices = matrix.outer(); - const auto colIndices = matrix.inner(); - auto valData = matrix.data(); - - atlas_omp_parallel_for(auto i = size_t{}; i < nRows; ++i) { - for (auto dataIdx = rowIndices[i]; dataIdx < rowIndices[i + 1]; ++dataIdx) { - const auto j = size_t(colIndices[dataIdx]); - auto&& value = valData[dataIdx]; - - if constexpr( - std::is_invocable_v) { - functor(value, i, j); - } - else if constexpr(std::is_invocable_v) { - functor(value, i, j, dataIdx); - } - else { - ATLAS_NOTIMPLEMENTED; - } +void sparseMatrixForEach(const MatrixT& matrix, const Functor& functor) { + + atlas_omp_parallel_for (auto k = 0; k < matrix.outerSize(); ++k) { + for (auto it = typename MatrixT::InnerIterator(matrix, k); it; ++it) { + functor(it.value(), it.row(), it.col()); } } } @@ -70,12 +51,12 @@ void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView, sparseMatrixForEach(matrix, [&](const auto& weight, auto i, auto j) { - constexpr auto rank = std::decay_t::rank(); + constexpr auto rank = std::decay_t::rank(); if constexpr(rank == 2) { const auto sourceSlice = sourceView.slice(j, array::Range::all()); auto targetSlice = targetView.slice(i, array::Range::all()); mappingFunctor(weight, sourceSlice, targetSlice); - } + } else if constexpr(rank == 3) { const auto iterationFuctor = [&](auto&& sourceVars, auto&& targetVars) { mappingFunctor(weight, sourceVars, targetVars); @@ -86,7 +67,7 @@ void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView, targetView.slice(i, array::Range::all(), array::Range::all()); array::helpers::ArrayForEach<0>::apply( std::tie(sourceSlice, targetSlice), iterationFuctor); - } + } else { ATLAS_NOTIMPLEMENTED; } @@ -95,14 +76,14 @@ void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView, } // namespace -void ParallelTransport::do_setup(const Grid& source, const Grid& target, +void SphericalVector::do_setup(const Grid& source, const Grid& target, const Cache&) { ATLAS_NOTIMPLEMENTED; } -void ParallelTransport::do_setup(const FunctionSpace& source, +void SphericalVector::do_setup(const FunctionSpace& source, const FunctionSpace& target) { - ATLAS_TRACE("interpolation::method::ParallelTransport::do_setup"); + ATLAS_TRACE("interpolation::method::SphericalVector::do_setup"); source_ = source; target_ = target; @@ -114,20 +95,23 @@ void ParallelTransport::do_setup(const FunctionSpace& source, Interpolation(interpolationScheme_, source_, target_); setMatrix(MatrixCache(baseInterpolator)); - // Get matrix dimensions. + // Get matrix data. const auto nRows = matrix().rows(); const auto nCols = matrix().cols(); const auto nNonZeros = matrix().nonZeros(); - auto weightsReal = std::vector(nNonZeros); - auto weightsImag = std::vector(nNonZeros); + realWeights_ = + std::make_shared(nRows, nCols, nNonZeros, matrix().outer(), + matrix().inner(), matrix().data()); + + complexWeights_ = std::make_shared(nRows, nCols); + auto complexTriplets = ComplexTriplets(nNonZeros); - const auto sourceLonLats = array::make_view(source_.lonlat()); - const auto targetLonLats = array::make_view(target_.lonlat()); + sparseMatrixForEach(*realWeights_, [&](const auto& weight, auto i, auto j) { + + const auto sourceLonLats = array::make_view(source_.lonlat()); + const auto targetLonLats = array::make_view(target_.lonlat()); - // Make complex weights (would be nice if we could have a complex matrix). - sparseMatrixForEach(matrix(), - [&](auto&& weight, auto i, auto j, auto dataIdx) { const auto sourceLonLat = PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1)); const auto targetLonLat = @@ -138,26 +122,23 @@ void ParallelTransport::do_setup(const FunctionSpace& source, auto deltaAlpha = (alpha.first - alpha.second) * util::Constants::degreesToRadians(); - weightsReal[dataIdx] = {i, j, weight * std::cos(deltaAlpha)}; - weightsImag[dataIdx] = {i, j, weight * std::sin(deltaAlpha)}; - }); + auto idx = &weight - realWeights_->valuePtr(); - // Deal with slightly old fashioned Matrix interface - const auto buildMatrix = [&](auto& matrix, const auto& weights) { - auto tempMatrix = Matrix(nRows, nCols, weights); - matrix.swap(tempMatrix); - }; + complexTriplets[idx] = {i, j, std::polar(weight, deltaAlpha)}; + }); + complexWeights_->setFromTriplets(complexTriplets.begin(), + complexTriplets.end()); - buildMatrix(matrixReal_, weightsReal); - buildMatrix(matrixImag_, weightsImag); + ATLAS_ASSERT(complexWeights_->nonZeros() == matrix().nonZeros()); + ATLAS_ASSERT(realWeights_->nonZeros() == matrix().nonZeros()); } -void ParallelTransport::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; } +void SphericalVector::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; } -void ParallelTransport::do_execute(const FieldSet& sourceFieldSet, +void SphericalVector::do_execute(const FieldSet& sourceFieldSet, FieldSet& targetFieldSet, Metadata& metadata) const { - ATLAS_TRACE("atlas::interpolation::method::ParallelTransport::do_execute()"); + ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()"); const auto nFields = sourceFieldSet.size(); ATLAS_ASSERT(nFields == targetFieldSet.size()); @@ -167,9 +148,9 @@ void ParallelTransport::do_execute(const FieldSet& sourceFieldSet, } } -void ParallelTransport::do_execute(const Field& sourceField, Field& targetField, +void SphericalVector::do_execute(const Field& sourceField, Field& targetField, Metadata&) const { - ATLAS_TRACE("atlas::interpolation::method::ParallelTransport::do_execute()"); + ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()"); if (!(sourceField.variables() == 2 || sourceField.variables() == 3)) { @@ -197,7 +178,7 @@ void ParallelTransport::do_execute(const Field& sourceField, Field& targetField, } template -void ParallelTransport::interpolate_vector_field(const Field& sourceField, +void SphericalVector::interpolate_vector_field(const Field& sourceField, Field& targetField) const { if (sourceField.rank() == 2) { interpolate_vector_field(sourceField, targetField); @@ -209,32 +190,31 @@ void ParallelTransport::interpolate_vector_field(const Field& sourceField, } template -void ParallelTransport::interpolate_vector_field(const Field& sourceField, +void SphericalVector::interpolate_vector_field(const Field& sourceField, Field& targetField) const { const auto sourceView = array::make_view(sourceField); auto targetView = array::make_view(targetField); targetView.assign(0.); - // Matrix multiplication split in two to simulate complex variable - // multiplication. - matrixMultiply(matrixReal_, sourceView, targetView, - [](const auto& weight, auto&& sourceVars, auto&& targetVars) { - targetVars(0) += weight * sourceVars(0); - targetVars(1) += weight * sourceVars(1); - }); - matrixMultiply(matrixImag_, sourceView, targetView, - [](const auto& weight, auto&& sourceVars, auto&& targetVars) { - targetVars(0) -= weight * sourceVars(1); - targetVars(1) += weight * sourceVars(0); - }); + const auto horizontalComponent = [](const auto& weight, auto&& sourceVars, + auto&& targetVars) { + const auto targetVector = + weight * std::complex(sourceVars(0), sourceVars(1)); - if (sourceField.variables() == 3) { - matrixMultiply( - matrix(), sourceView, targetView, - [](const auto& weight, auto&& sourceVars, - auto&& targetVars) { targetVars(2) = weight * sourceVars(2); }); - } + targetVars(0) += targetVector.real(); + targetVars(1) += targetVector.imag(); + }; + + const auto verticalComponent = []( + const auto& weight, auto&& sourceVars, + auto&& targetVars) { targetVars(2) += weight * sourceVars(2); }; + + matrixMultiply(*complexWeights_, sourceView, targetView, horizontalComponent); + + if (sourceField.variables() == 2) return; + + matrixMultiply(*realWeights_, sourceView, targetView, verticalComponent); } } // namespace method diff --git a/src/atlas/interpolation/method/paralleltransport/ParallelTransport.h b/src/atlas/interpolation/method/sphericalvector/SphericalVector.h similarity index 73% rename from src/atlas/interpolation/method/paralleltransport/ParallelTransport.h rename to src/atlas/interpolation/method/sphericalvector/SphericalVector.h index fbc427104..a4d3703a2 100644 --- a/src/atlas/interpolation/method/paralleltransport/ParallelTransport.h +++ b/src/atlas/interpolation/method/sphericalvector/SphericalVector.h @@ -7,6 +7,11 @@ #pragma once +#include +#include + +#include + #include "atlas/functionspace/FunctionSpace.h" #include "atlas/interpolation/method/Method.h" #include "atlas/linalg/sparse.h" @@ -16,14 +21,19 @@ namespace atlas { namespace interpolation { namespace method { -class ParallelTransport : public Method { +using RealMatrix = Eigen::SparseMatrix; +using RealMatrixMap = Eigen::Map; +using ComplexMatrix = Eigen::SparseMatrix, Eigen::RowMajor>; +using ComplexTriplets = std::vector>>; + +class SphericalVector : public Method { public: - ParallelTransport(const Config& config): Method(config) { + SphericalVector(const Config& config): Method(config) { const auto& conf = dynamic_cast(config); interpolationScheme_ = conf.getSubConfiguration("scheme"); } - virtual ~ParallelTransport() override {} + virtual ~SphericalVector() override {} void print(std::ostream&) const override; const FunctionSpace& source() const override { return source_; } @@ -50,10 +60,8 @@ class ParallelTransport : public Method { FunctionSpace source_; FunctionSpace target_; - // Complex interpolation weights. We treat a (u, v) pair as a complex variable - // with u on the real number line and v on the imaginary number line. - Matrix matrixReal_; - Matrix matrixImag_; + std::shared_ptr realWeights_; + std::shared_ptr complexWeights_; }; diff --git a/src/tests/interpolation/CMakeLists.txt b/src/tests/interpolation/CMakeLists.txt index c680db212..dc4df9e29 100644 --- a/src/tests/interpolation/CMakeLists.txt +++ b/src/tests/interpolation/CMakeLists.txt @@ -81,9 +81,10 @@ ecbuild_add_test( TARGET atlas_test_interpolation_non_linear ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} ) -ecbuild_add_test( TARGET atlas_test_interpolation_parallel_transport +ecbuild_add_test( TARGET atlas_test_interpolation_spherical_vector + CONDITION eckit_EIGEN_FOUND OMP 4 - SOURCES test_interpolation_parallel_transport.cc + SOURCES test_interpolation_spherical_vector.cc LIBS atlas ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} ) diff --git a/src/tests/interpolation/test_interpolation_parallel_transport.cc b/src/tests/interpolation/test_interpolation_spherical_vector.cc similarity index 95% rename from src/tests/interpolation/test_interpolation_parallel_transport.cc rename to src/tests/interpolation/test_interpolation_spherical_vector.cc index 23967c6f8..7e09cdf02 100644 --- a/src/tests/interpolation/test_interpolation_parallel_transport.cc +++ b/src/tests/interpolation/test_interpolation_spherical_vector.cc @@ -154,13 +154,13 @@ CASE("cubed sphere vector interpolation") { const auto baseInterpScheme = util::Config("type", "cubedsphere-bilinear").set("adjoint", true); - const auto interpScheme = util::Config("type", "parallel-transport") + const auto interpScheme = util::Config("type", "spherical-vector") .set("scheme", baseInterpScheme); const auto cubedSphereConf = Config("source_grid", "CS-LFR-48") .set("source_mesh", "cubedsphere_dual") .set("target_grid", "O48") .set("target_mesh", "structured") - .set("file_id", "parallel_transport_cs") + .set("file_id", "spherical_vector_cs") .set("scheme", interpScheme); testInterpolation((cubedSphereConf)); @@ -170,13 +170,13 @@ CASE("finite element vector interpolation") { const auto baseInterpScheme = util::Config("type", "finite-element").set("adjoint", true); - const auto interpScheme = util::Config("type", "parallel-transport") + const auto interpScheme = util::Config("type", "spherical-vector") .set("scheme", baseInterpScheme); const auto cubedSphereConf = Config("source_grid", "O48") .set("source_mesh", "structured") .set("target_grid", "CS-LFR-48") .set("target_mesh", "cubedsphere_dual") - .set("file_id", "parallel_transport_fe") + .set("file_id", "spherical_vector_fe") .set("scheme", interpScheme); testInterpolation((cubedSphereConf));