Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Projection base implementation derivatives performance/encapsulation … #185

Merged
merged 3 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/atlas/mesh/Mesh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
*/

#include "atlas/mesh/Mesh.h"
#include "atlas/mesh/Nodes.h"

#include "atlas/grid/Grid.h"
#include "atlas/grid/Partitioner.h"
#include "atlas/mesh/Nodes.h"
#include "atlas/meshgenerator/MeshGenerator.h"
#include "atlas/parallel/mpi/mpi.h"
#include "atlas/runtime/Exception.h"

namespace atlas {

Expand All @@ -23,13 +25,13 @@ Mesh::Mesh(): Handle(new Implementation()) {}

Mesh::Mesh(const Grid& grid, const eckit::Configuration& config):
Handle([&]() {
if(config.has("mpi_comm")) {
if (config.has("mpi_comm")) {
mpi::push(config.getString("mpi_comm"));
}
util::Config cfg = grid.meshgenerator()|util::Config(config);
auto meshgenerator = MeshGenerator{grid.meshgenerator()|config};
auto mesh = meshgenerator.generate(grid, grid::Partitioner(grid.partitioner()|config));
if(config.has("mpi_comm")) {
auto cfg = grid.meshgenerator() | util::Config(config);
auto meshgenerator = MeshGenerator{grid.meshgenerator() | config};
auto mesh = meshgenerator.generate(grid, grid::Partitioner(grid.partitioner() | config));
if (config.has("mpi_comm")) {
mpi::pop();
}
mesh.get()->attach();
Expand All @@ -40,13 +42,13 @@ Mesh::Mesh(const Grid& grid, const eckit::Configuration& config):

Mesh::Mesh(const Grid& grid, const grid::Partitioner& partitioner, const eckit::Configuration& config):
Handle([&]() {
std::string mpi_comm = partitioner.mpi_comm();
if(config.has("mpi_comm")) {
auto mpi_comm = partitioner.mpi_comm();
if (config.has("mpi_comm")) {
mpi_comm = config.getString("mpi_comm");
ATLAS_ASSERT(mpi_comm == partitioner.mpi_comm());
}
mpi::Scope mpi_scope(mpi_comm);
auto meshgenerator = MeshGenerator{grid.meshgenerator()|config};
auto meshgenerator = MeshGenerator{grid.meshgenerator() | config};
auto mesh = meshgenerator.generate(grid, partitioner);
mesh.get()->attach();
return mesh.get();
Expand All @@ -57,7 +59,9 @@ Mesh::Mesh(const Grid& grid, const grid::Partitioner& partitioner, const eckit::

Mesh::Mesh(eckit::Stream& stream): Handle(new Implementation(stream)) {}

Mesh::operator bool() const { return get()->nodes().size() > 0; }
Mesh::operator bool() const {
return get()->nodes().size() > 0;
}

//----------------------------------------------------------------------------------------------------------------------

Expand Down
14 changes: 14 additions & 0 deletions src/atlas/projection/detail/MercatorProjection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

#include <cmath>
#include <functional>
#include <limits>
#include <sstream>

#include "eckit/config/Parametrisation.h"
#include "eckit/types/FloatCompare.h"
#include "eckit/utils/Hash.h"

#include "atlas/projection/detail/MercatorProjection.h"
Expand Down Expand Up @@ -96,6 +98,18 @@ void MercatorProjectionT<Rotation>::lonlat2xy(double crd[]) const {
rotation_.unrotate(crd);

// then project
if (eckit::types::is_approximately_equal<double>(crd[LAT], 90., 1e-3)) {
wdeconinck marked this conversation as resolved.
Show resolved Hide resolved
crd[XX] = false_easting_;
crd[YY] = std::numeric_limits<double>::infinity();
return;
}

if (eckit::types::is_approximately_equal<double>(crd[LAT], -90., 1e-3)) {
crd[XX] = false_easting_;
crd[YY] = -std::numeric_limits<double>::infinity();
return;
}

crd[XX] = k_radius_ * (D2R(normalise_mercator_(crd[LON]) - lon0_));
crd[YY] = k_radius_ * 0.5 * std::log(t(crd[LAT]));
crd[XX] += false_easting_;
Expand Down
76 changes: 30 additions & 46 deletions src/atlas/projection/detail/ProjectionImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,23 @@
* nor does it submit to any jurisdiction.
*/

#include "atlas/projection/detail/ProjectionImpl.h"

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <cstring>
#include <memory>
#include <vector>
#include <utility>

#include "eckit/types/FloatCompare.h"
#include "eckit/utils/Hash.h"
#include "eckit/utils/MD5.h"

#include "atlas/projection/detail/ProjectionImpl.h"

#include "atlas/projection/detail/ProjectionFactory.h"
#include "atlas/runtime/Exception.h"
#include "atlas/util/Config.h"

namespace atlas {
namespace projection {
namespace detail {
namespace atlas::projection::detail {

// --------------------------------------------------------------------------------------------------------------------

Expand All @@ -53,39 +51,27 @@ struct DerivateBuilder : public ProjectionImpl::DerivateFactory {

struct DerivateForwards final : ProjectionImpl::Derivate {
using Derivate::Derivate;
PointLonLat d(PointXY P) const override {
PointLonLat A(xy2lonlat(P));
PointLonLat B(xy2lonlat(PointXY::add(P, H_)));
return PointLonLat::div(PointLonLat::sub(B, A), normH_);
}
PointLonLat d(PointXY P) const override { return (xy2lonlat(P + H_) - xy2lonlat(P)) * invnH_; }
};

struct DerivateBackwards final : ProjectionImpl::Derivate {
using Derivate::Derivate;
PointLonLat d(PointXY P) const override {
PointLonLat A(xy2lonlat(PointXY::sub(P, H_)));
PointLonLat B(xy2lonlat(P));
return PointLonLat::div(PointLonLat::sub(B, A), normH_);
}
PointLonLat d(PointXY P) const override { return (xy2lonlat(P) - xy2lonlat(P - H_)) * invnH_; }
};

struct DerivateCentral final : ProjectionImpl::Derivate {
DerivateCentral(const ProjectionImpl& p, PointXY A, PointXY B, double h, double refLongitude):
Derivate(p, A, B, h, refLongitude), H2_{PointXY::mul(H_, 0.5)} {}
Derivate(p, A, B, h, refLongitude), H2_{H_ * 0.5} {}
const PointXY H2_;
PointLonLat d(PointXY P) const override {
PointLonLat A(xy2lonlat(PointXY::sub(P, H2_)));
PointLonLat B(xy2lonlat(PointXY::add(P, H2_)));
return PointLonLat::div(PointLonLat::sub(B, A), normH_);
}
PointLonLat d(PointXY P) const override { return (xy2lonlat(P + H2_) - xy2lonlat(P - H2_)) * invnH_; }
};

} // namespace

ProjectionImpl::Derivate::Derivate(const ProjectionImpl& p, PointXY A, PointXY B, double h, double refLongitude):
projection_(p),
H_{PointXY::mul(PointXY::normalize(PointXY::sub(B, A)), h)},
normH_(PointXY::norm(H_)),
invnH_(1. / PointXY::norm(H_)),
wdeconinck marked this conversation as resolved.
Show resolved Hide resolved
refLongitude_(refLongitude) {}

ProjectionImpl::Derivate::~Derivate() = default;
Expand All @@ -107,7 +93,7 @@ ProjectionImpl::Derivate* ProjectionImpl::DerivateFactory::build(const std::stri
return new DerivateDegenerate(p, A, B, h, refLongitude);
}

auto factory = get(type);
auto* factory = get(type);
return factory->make(p, A, B, h);
}

Expand Down Expand Up @@ -184,15 +170,15 @@ ProjectionImpl::Normalise::Normalise(const eckit::Parametrisation& p) {
provided = true;
}
if (provided) {
normalise_.reset(new util::NormaliseLongitude(values_[0], values_[1]));
normalise_ = std::make_unique<util::NormaliseLongitude>(values_[0], values_[1]);
}
}

ProjectionImpl::Normalise::Normalise(double west) {
values_.resize(2);
values_[0] = west;
values_[1] = values_[0] + 360.;
normalise_.reset(new util::NormaliseLongitude(values_[0], values_[1]));
normalise_ = std::make_unique<util::NormaliseLongitude>(values_[0], values_[1]);
}


Expand Down Expand Up @@ -246,8 +232,8 @@ RectangularLonLatDomain ProjectionImpl::lonlatBoundingBox(const Domain& domain)
ATLAS_ASSERT(rect);

// use central longitude as absolute reference (keep points within +-180 longitude range)
const auto centre = lonlat({(rect.xmin() + rect.xmax()) / 2., (rect.ymin() + rect.ymax()) / 2.});

const PointXY centre_xy{(rect.xmin() + rect.xmax()) / 2., (rect.ymin() + rect.ymax()) / 2.};
const auto centre_lon = lonlat(centre_xy).lon();

const std::string derivative = "central";
constexpr double h_deg = 0.5e-6; // precision to microdegrees
Expand All @@ -256,28 +242,31 @@ RectangularLonLatDomain ProjectionImpl::lonlatBoundingBox(const Domain& domain)

const double h = units() == "degrees" ? h_deg : h_meters;


// 1. determine box from projected corners

const std::vector<PointXY> corners{
{rect.xmin(), rect.ymax()}, {rect.xmax(), rect.ymax()}, {rect.xmax(), rect.ymin()}, {rect.xmin(), rect.ymin()}};
const std::pair<PointXY, PointXY> segments[] = {{{rect.xmin(), rect.ymax()}, {rect.xmax(), rect.ymax()}},
{{rect.xmax(), rect.ymax()}, {rect.xmax(), rect.ymin()}},
{{rect.xmax(), rect.ymin()}, {rect.xmin(), rect.ymin()}},
{{rect.xmin(), rect.ymin()}, {rect.xmin(), rect.ymax()}}};

BoundLonLat bounds;
for (auto& p : corners) {
for (const auto [p, dummy] : segments) {
auto q = lonlat(p);
longitude_in_range(centre.lon(), q.lon());
longitude_in_range(centre_lon, q.lon());
bounds.extend(q, PointLonLat{h_deg, h_deg});
}


// 2. locate latitude extrema by checking if poles are included (in the un-projected frame) and if not, find extrema
// not at the corners by refining iteratively

for (size_t i = 0; i < corners.size(); ++i) {
if (!bounds.includesNorthPole() || !bounds.includesSouthPole()) {
PointXY A = corners[i];
PointXY B = corners[(i + 1) % corners.size()];
bounds.includesNorthPole(bounds.includesNorthPole() || rect.contains(xy(PointLonLat{0, 90 - h_deg})));
bounds.includesSouthPole(bounds.includesSouthPole() || rect.contains(xy(PointLonLat{0, -90 + h_deg})));

std::unique_ptr<Derivate> derivate(DerivateFactory::build(derivative, *this, A, B, h, centre.lon()));
for (auto [A, B] : segments) {
if (!bounds.includesNorthPole() || !bounds.includesSouthPole()) {
std::unique_ptr<Derivate> derivate(DerivateFactory::build(derivative, *this, A, B, h, centre_lon));
double dAdy = derivate->d(A).lat();
double dBdy = derivate->d(B).lat();

Expand Down Expand Up @@ -307,12 +296,9 @@ RectangularLonLatDomain ProjectionImpl::lonlatBoundingBox(const Domain& domain)

// 3. locate longitude extrema not at the corners by refining iteratively

for (size_t i = 0; i < corners.size(); ++i) {
for (auto [A, B] : segments) {
if (!bounds.crossesDateLine()) {
PointXY A = corners[i];
PointXY B = corners[(i + 1) % corners.size()];

std::unique_ptr<Derivate> derivate(DerivateFactory::build(derivative, *this, A, B, h, centre.lon()));
std::unique_ptr<Derivate> derivate(DerivateFactory::build(derivative, *this, A, B, h, centre_lon));
double dAdx = derivate->d(A).lon();
double dBdx = derivate->d(B).lon();

Expand Down Expand Up @@ -410,6 +396,4 @@ void atlas__Projection__lonlat2xy(const ProjectionImpl* This, const double lon,

} // extern "C"

} // namespace detail
} // namespace projection
} // namespace atlas
} // namespace atlas::projection::detail
21 changes: 6 additions & 15 deletions src/atlas/projection/detail/ProjectionImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@

#pragma once

#include <initializer_list>
#include <iostream>
#include <memory>
#include <string>

#include "atlas/projection/Jacobian.h"
#include "atlas/runtime/Exception.h"
#include "atlas/util/Factory.h"
#include "atlas/util/NormaliseLongitude.h"
#include "atlas/util/Object.h"
Expand All @@ -36,9 +34,7 @@ class Config;
}
} // namespace atlas

namespace atlas {
namespace projection {
namespace detail {
namespace atlas::projection::detail {

class ProjectionImpl : public util::Object {
public:
Expand All @@ -49,8 +45,7 @@ class ProjectionImpl : public util::Object {
static const ProjectionImpl* create(const eckit::Parametrisation& p);
static const ProjectionImpl* create(const std::string& type, const eckit::Parametrisation& p);

ProjectionImpl() = default;
virtual ~ProjectionImpl() = default; // destructor should be virtual
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no need for a virtual destructor?

ProjectionImpl() = default;

virtual std::string type() const = 0;

Expand Down Expand Up @@ -123,7 +118,7 @@ class ProjectionImpl : public util::Object {
protected:
const ProjectionImpl& projection_;
const PointXY H_;
const double normH_;
const double invnH_;
const double refLongitude_;
PointLonLat xy2lonlat(const PointXY& p) const;
};
Expand Down Expand Up @@ -190,10 +185,8 @@ class NotRotated {
static std::string classNamePrefix() { return ""; } // deliberately empty
static std::string typePrefix() { return ""; } // deliberately empty

void rotate(double*) const { /* do nothing */
}
void unrotate(double*) const { /* do nothing */
}
void rotate(double*) const { /* do nothing */ }
void unrotate(double*) const { /* do nothing */ }

bool rotated() const { return false; }

Expand All @@ -211,6 +204,4 @@ void atlas__Projection__xy2lonlat(const ProjectionImpl* This, const double x, co
void atlas__Projection__lonlat2xy(const ProjectionImpl* This, const double lon, const double lat, double& x, double& y);
}

} // namespace detail
} // namespace projection
} // namespace atlas
} // namespace atlas::projection::detail
Loading