Skip to content

Commit

Permalink
Renamed class to SphericalVector. Added Eigen3 matrices.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Nov 28, 2023
1 parent d5c9ef5 commit 088403e
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 92 deletions.
11 changes: 9 additions & 2 deletions src/atlas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,6 @@ functionspace/detail/CubedSphereStructure.cc
# for cubedsphere matching mesh partitioner
interpolation/method/cubedsphere/CellFinder.cc
interpolation/method/cubedsphere/CellFinder.h
interpolation/method/paralleltransport/ParallelTransport.h
interpolation/method/paralleltransport/ParallelTransport.cc
interpolation/Vector2D.cc
interpolation/Vector2D.h
interpolation/Vector3D.cc
Expand Down Expand Up @@ -671,6 +669,15 @@ interpolation/nonlinear/NonLinear.cc
interpolation/nonlinear/NonLinear.h
)

if (eckit_EIGEN_FOUND)

list (APPEND atlas_interpolation_srcs
interpolation/method/sphericalvector/SphericalVector.h
interpolation/method/sphericalvector/SphericalVector.cc
)

endif()

list( APPEND atlas_linalg_srcs
linalg/Indexing.h
linalg/Introspection.h
Expand Down
4 changes: 2 additions & 2 deletions src/atlas/interpolation/method/MethodFactory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

// for static linking
#include "cubedsphere/CubedSphereBilinear.h"
#include "paralleltransport/ParallelTransport.h"
#include "sphericalvector/SphericalVector.h"
#include "knn/GridBoxAverage.h"
#include "knn/GridBoxMaximum.h"
#include "knn/KNearestNeighbours.h"
Expand Down Expand Up @@ -48,7 +48,7 @@ void force_link() {
MethodBuilder<method::GridBoxAverage>();
MethodBuilder<method::GridBoxMaximum>();
MethodBuilder<method::CubedSphereBilinear>();
MethodBuilder<method::ParallelTransport>();
MethodBuilder<method::SphericalVector>();
}
} link;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "atlas/interpolation/Cache.h"
#include "atlas/interpolation/Interpolation.h"
#include "atlas/interpolation/method/MethodFactory.h"
#include "atlas/interpolation/method/paralleltransport/ParallelTransport.h"
#include "atlas/interpolation/method/sphericalvector/SphericalVector.h"
#include "atlas/linalg/sparse.h"
#include "atlas/option/Options.h"
#include "atlas/parallel/omp/omp.h"
Expand All @@ -32,33 +32,14 @@ namespace interpolation {
namespace method {

namespace {
MethodBuilder<ParallelTransport> __builder("parallel-transport");
MethodBuilder<SphericalVector> __builder("spherical-vector");

template <typename MatrixT, typename Functor>
void sparseMatrixForEach(MatrixT&& matrix, const Functor& functor) {

const auto nRows = matrix.rows();
const auto nCols = matrix.cols();
const auto rowIndices = matrix.outer();
const auto colIndices = matrix.inner();
auto valData = matrix.data();

atlas_omp_parallel_for(auto i = size_t{}; i < nRows; ++i) {
for (auto dataIdx = rowIndices[i]; dataIdx < rowIndices[i + 1]; ++dataIdx) {
const auto j = size_t(colIndices[dataIdx]);
auto&& value = valData[dataIdx];

if constexpr(
std::is_invocable_v<Functor, decltype(value), size_t, size_t>) {
functor(value, i, j);
}
else if constexpr(std::is_invocable_v<Functor, decltype(value), size_t,
size_t, size_t>) {
functor(value, i, j, dataIdx);
}
else {
ATLAS_NOTIMPLEMENTED;
}
void sparseMatrixForEach(const MatrixT& matrix, const Functor& functor) {

atlas_omp_parallel_for (auto k = 0; k < matrix.outerSize(); ++k) {
for (auto it = typename MatrixT::InnerIterator(matrix, k); it; ++it) {
functor(it.value(), it.row(), it.col());
}
}
}
Expand All @@ -70,12 +51,12 @@ void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView,

sparseMatrixForEach(matrix, [&](const auto& weight, auto i, auto j) {

constexpr auto rank = std::decay_t<decltype(sourceView)>::rank();
constexpr auto rank = std::decay_t<SourceView>::rank();
if constexpr(rank == 2) {
const auto sourceSlice = sourceView.slice(j, array::Range::all());
auto targetSlice = targetView.slice(i, array::Range::all());
mappingFunctor(weight, sourceSlice, targetSlice);
}
}
else if constexpr(rank == 3) {
const auto iterationFuctor = [&](auto&& sourceVars, auto&& targetVars) {
mappingFunctor(weight, sourceVars, targetVars);
Expand All @@ -86,7 +67,7 @@ void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView,
targetView.slice(i, array::Range::all(), array::Range::all());
array::helpers::ArrayForEach<0>::apply(
std::tie(sourceSlice, targetSlice), iterationFuctor);
}
}
else {
ATLAS_NOTIMPLEMENTED;
}
Expand All @@ -95,14 +76,14 @@ void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView,

} // namespace

void ParallelTransport::do_setup(const Grid& source, const Grid& target,
void SphericalVector::do_setup(const Grid& source, const Grid& target,
const Cache&) {
ATLAS_NOTIMPLEMENTED;
}

void ParallelTransport::do_setup(const FunctionSpace& source,
void SphericalVector::do_setup(const FunctionSpace& source,
const FunctionSpace& target) {
ATLAS_TRACE("interpolation::method::ParallelTransport::do_setup");
ATLAS_TRACE("interpolation::method::SphericalVector::do_setup");
source_ = source;
target_ = target;

Expand All @@ -114,20 +95,23 @@ void ParallelTransport::do_setup(const FunctionSpace& source,
Interpolation(interpolationScheme_, source_, target_);
setMatrix(MatrixCache(baseInterpolator));

// Get matrix dimensions.
// Get matrix data.
const auto nRows = matrix().rows();
const auto nCols = matrix().cols();
const auto nNonZeros = matrix().nonZeros();

auto weightsReal = std::vector<eckit::linalg::Triplet>(nNonZeros);
auto weightsImag = std::vector<eckit::linalg::Triplet>(nNonZeros);
realWeights_ =
std::make_shared<RealMatrixMap>(nRows, nCols, nNonZeros, matrix().outer(),
matrix().inner(), matrix().data());

complexWeights_ = std::make_shared<ComplexMatrix>(nRows, nCols);
auto complexTriplets = ComplexTriplets(nNonZeros);

const auto sourceLonLats = array::make_view<double, 2>(source_.lonlat());
const auto targetLonLats = array::make_view<double, 2>(target_.lonlat());
sparseMatrixForEach(*realWeights_, [&](const auto& weight, auto i, auto j) {

const auto sourceLonLats = array::make_view<double, 2>(source_.lonlat());
const auto targetLonLats = array::make_view<double, 2>(target_.lonlat());

// Make complex weights (would be nice if we could have a complex matrix).
sparseMatrixForEach(matrix(),
[&](auto&& weight, auto i, auto j, auto dataIdx) {
const auto sourceLonLat =
PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1));
const auto targetLonLat =
Expand All @@ -138,26 +122,23 @@ void ParallelTransport::do_setup(const FunctionSpace& source,
auto deltaAlpha =
(alpha.first - alpha.second) * util::Constants::degreesToRadians();

weightsReal[dataIdx] = {i, j, weight * std::cos(deltaAlpha)};
weightsImag[dataIdx] = {i, j, weight * std::sin(deltaAlpha)};
});
auto idx = &weight - realWeights_->valuePtr();

// Deal with slightly old fashioned Matrix interface
const auto buildMatrix = [&](auto& matrix, const auto& weights) {
auto tempMatrix = Matrix(nRows, nCols, weights);
matrix.swap(tempMatrix);
};
complexTriplets[idx] = {i, j, std::polar(weight, deltaAlpha)};
});
complexWeights_->setFromTriplets(complexTriplets.begin(),
complexTriplets.end());

buildMatrix(matrixReal_, weightsReal);
buildMatrix(matrixImag_, weightsImag);
ATLAS_ASSERT(complexWeights_->nonZeros() == matrix().nonZeros());
ATLAS_ASSERT(realWeights_->nonZeros() == matrix().nonZeros());
}

void ParallelTransport::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; }
void SphericalVector::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; }

void ParallelTransport::do_execute(const FieldSet& sourceFieldSet,
void SphericalVector::do_execute(const FieldSet& sourceFieldSet,
FieldSet& targetFieldSet,
Metadata& metadata) const {
ATLAS_TRACE("atlas::interpolation::method::ParallelTransport::do_execute()");
ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()");

const auto nFields = sourceFieldSet.size();
ATLAS_ASSERT(nFields == targetFieldSet.size());
Expand All @@ -167,9 +148,9 @@ void ParallelTransport::do_execute(const FieldSet& sourceFieldSet,
}
}

void ParallelTransport::do_execute(const Field& sourceField, Field& targetField,
void SphericalVector::do_execute(const Field& sourceField, Field& targetField,
Metadata&) const {
ATLAS_TRACE("atlas::interpolation::method::ParallelTransport::do_execute()");
ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()");

if (!(sourceField.variables() == 2 || sourceField.variables() == 3)) {

Expand Down Expand Up @@ -197,7 +178,7 @@ void ParallelTransport::do_execute(const Field& sourceField, Field& targetField,
}

template <typename Value>
void ParallelTransport::interpolate_vector_field(const Field& sourceField,
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField) const {
if (sourceField.rank() == 2) {
interpolate_vector_field<Value, 2>(sourceField, targetField);
Expand All @@ -209,32 +190,31 @@ void ParallelTransport::interpolate_vector_field(const Field& sourceField,
}

template <typename Value, int Rank>
void ParallelTransport::interpolate_vector_field(const Field& sourceField,
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField) const {

const auto sourceView = array::make_view<Value, Rank>(sourceField);
auto targetView = array::make_view<Value, Rank>(targetField);
targetView.assign(0.);

// Matrix multiplication split in two to simulate complex variable
// multiplication.
matrixMultiply(matrixReal_, sourceView, targetView,
[](const auto& weight, auto&& sourceVars, auto&& targetVars) {
targetVars(0) += weight * sourceVars(0);
targetVars(1) += weight * sourceVars(1);
});
matrixMultiply(matrixImag_, sourceView, targetView,
[](const auto& weight, auto&& sourceVars, auto&& targetVars) {
targetVars(0) -= weight * sourceVars(1);
targetVars(1) += weight * sourceVars(0);
});
const auto horizontalComponent = [](const auto& weight, auto&& sourceVars,
auto&& targetVars) {
const auto targetVector =
weight * std::complex<double>(sourceVars(0), sourceVars(1));

if (sourceField.variables() == 3) {
matrixMultiply(
matrix(), sourceView, targetView,
[](const auto& weight, auto&& sourceVars,
auto&& targetVars) { targetVars(2) = weight * sourceVars(2); });
}
targetVars(0) += targetVector.real();
targetVars(1) += targetVector.imag();
};

const auto verticalComponent = [](
const auto& weight, auto&& sourceVars,
auto&& targetVars) { targetVars(2) += weight * sourceVars(2); };

matrixMultiply(*complexWeights_, sourceView, targetView, horizontalComponent);

if (sourceField.variables() == 2) return;

matrixMultiply(*realWeights_, sourceView, targetView, verticalComponent);
}

} // namespace method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

#pragma once

#include <complex>
#include <memory>

#include <Eigen/Sparse>

#include "atlas/functionspace/FunctionSpace.h"
#include "atlas/interpolation/method/Method.h"
#include "atlas/linalg/sparse.h"
Expand All @@ -16,14 +21,19 @@ namespace atlas {
namespace interpolation {
namespace method {

class ParallelTransport : public Method {
using RealMatrix = Eigen::SparseMatrix<double, Eigen::RowMajor>;
using RealMatrixMap = Eigen::Map<const RealMatrix>;
using ComplexMatrix = Eigen::SparseMatrix<std::complex<double>, Eigen::RowMajor>;
using ComplexTriplets = std::vector<Eigen::Triplet<std::complex<double>>>;

class SphericalVector : public Method {
public:
ParallelTransport(const Config& config): Method(config) {
SphericalVector(const Config& config): Method(config) {
const auto& conf = dynamic_cast<const eckit::LocalConfiguration&>(config);
interpolationScheme_ = conf.getSubConfiguration("scheme");

}
virtual ~ParallelTransport() override {}
virtual ~SphericalVector() override {}

void print(std::ostream&) const override;
const FunctionSpace& source() const override { return source_; }
Expand All @@ -50,10 +60,8 @@ class ParallelTransport : public Method {
FunctionSpace source_;
FunctionSpace target_;

// Complex interpolation weights. We treat a (u, v) pair as a complex variable
// with u on the real number line and v on the imaginary number line.
Matrix matrixReal_;
Matrix matrixImag_;
std::shared_ptr<RealMatrixMap> realWeights_;
std::shared_ptr<ComplexMatrix> complexWeights_;

};

Expand Down
5 changes: 3 additions & 2 deletions src/tests/interpolation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ ecbuild_add_test( TARGET atlas_test_interpolation_non_linear
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)

ecbuild_add_test( TARGET atlas_test_interpolation_parallel_transport
ecbuild_add_test( TARGET atlas_test_interpolation_spherical_vector
CONDITION eckit_EIGEN_FOUND
OMP 4
SOURCES test_interpolation_parallel_transport.cc
SOURCES test_interpolation_spherical_vector.cc
LIBS atlas
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ CASE("cubed sphere vector interpolation") {

const auto baseInterpScheme =
util::Config("type", "cubedsphere-bilinear").set("adjoint", true);
const auto interpScheme = util::Config("type", "parallel-transport")
const auto interpScheme = util::Config("type", "spherical-vector")
.set("scheme", baseInterpScheme);
const auto cubedSphereConf = Config("source_grid", "CS-LFR-48")
.set("source_mesh", "cubedsphere_dual")
.set("target_grid", "O48")
.set("target_mesh", "structured")
.set("file_id", "parallel_transport_cs")
.set("file_id", "spherical_vector_cs")
.set("scheme", interpScheme);

testInterpolation((cubedSphereConf));
Expand All @@ -170,13 +170,13 @@ CASE("finite element vector interpolation") {

const auto baseInterpScheme =
util::Config("type", "finite-element").set("adjoint", true);
const auto interpScheme = util::Config("type", "parallel-transport")
const auto interpScheme = util::Config("type", "spherical-vector")
.set("scheme", baseInterpScheme);
const auto cubedSphereConf = Config("source_grid", "O48")
.set("source_mesh", "structured")
.set("target_grid", "CS-LFR-48")
.set("target_mesh", "cubedsphere_dual")
.set("file_id", "parallel_transport_fe")
.set("file_id", "spherical_vector_fe")
.set("scheme", interpScheme);

testInterpolation((cubedSphereConf));
Expand Down

0 comments on commit 088403e

Please sign in to comment.