Skip to content

Commit

Permalink
Implement PointCloud::gather/scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
wdeconinck committed Sep 22, 2023
1 parent 13ad416 commit 66d8e5f
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 3 deletions.
161 changes: 160 additions & 1 deletion src/atlas/functionspace/PointCloud.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,32 @@ namespace functionspace {

namespace detail {

namespace {

template <typename T, typename Field>
array::LocalView<T, 3> make_leveled_view(Field& field) {
using namespace array;
if (field.levels()) {
if (field.variables()) {
return make_view<T, 3>(field).slice(Range::all(), Range::all(), Range::all());
}
else {
return make_view<T, 2>(field).slice(Range::all(), Range::all(), Range::dummy());
}
}
else {
if (field.variables()) {
return make_view<T, 2>(field).slice(Range::all(), Range::dummy(), Range::all());
}
else {
return make_view<T, 1>(field).slice(Range::all(), Range::dummy(), Range::dummy());
}
}
}

} // namespace


static std::string get_mpi_comm(const eckit::Configuration& config) {
if(config.has("mpi_comm")) {
return config.getString("mpi_comm");
Expand Down Expand Up @@ -84,6 +110,7 @@ PointCloud::PointCloud(const Field& lonlat, const eckit::Configuration& config):
PointCloud::PointCloud(const Field& lonlat, const Field& ghost, const eckit::Configuration& config): lonlat_(lonlat), ghost_(ghost) {
mpi_comm_ = get_mpi_comm(config);
setupHaloExchange();
setupGatherScatter();
}

PointCloud::PointCloud(const FieldSet& flds, const eckit::Configuration& config): lonlat_(flds["lonlat"]) {
Expand All @@ -102,6 +129,7 @@ PointCloud::PointCloud(const FieldSet& flds, const eckit::Configuration& config)
}
if( ghost_ && remote_index_ && partition_ ) {
setupHaloExchange();
setupGatherScatter();
}
}

Expand Down Expand Up @@ -243,6 +271,7 @@ PointCloud::PointCloud(const Grid& grid, const grid::Partitioner& _partitioner,
}

setupHaloExchange();
setupGatherScatter();
}


Expand All @@ -255,9 +284,20 @@ Field PointCloud::ghost() const {
}

array::ArrayShape PointCloud::config_shape(const eckit::Configuration& config) const {
idx_t _size = size();
bool global(false);
if (config.get("global", global)) {
if (global) {
idx_t owner(0);
config.get("owner", owner);
idx_t rank = mpi::comm(mpi_comm()).rank();
_size = (rank == owner ? size_global_ : 0);
}
}

array::ArrayShape shape;

shape.emplace_back(size());
shape.emplace_back(_size);

idx_t levels(levels_);
config.get("levels", levels);
Expand Down Expand Up @@ -302,6 +342,110 @@ const parallel::HaloExchange& PointCloud::halo_exchange() const {
return *halo_exchange_;
}

void PointCloud::gather(const FieldSet& local_fieldset, FieldSet& global_fieldset) const {
ATLAS_ASSERT(local_fieldset.size() == global_fieldset.size());

for (idx_t f = 0; f < local_fieldset.size(); ++f) {
const Field& loc = local_fieldset[f];
Field& glb = global_fieldset[f];
const idx_t nb_fields = 1;
idx_t root(0);
glb.metadata().get("owner", root);

if (loc.datatype() == array::DataType::kind<int>()) {
parallel::Field<int const> loc_field(make_leveled_view<const int>(loc));
parallel::Field<int> glb_field(make_leveled_view<int>(glb));
gather().gather(&loc_field, &glb_field, nb_fields, root);
}
else if (loc.datatype() == array::DataType::kind<long>()) {
parallel::Field<long const> loc_field(make_leveled_view<const long>(loc));
parallel::Field<long> glb_field(make_leveled_view<long>(glb));
gather().gather(&loc_field, &glb_field, nb_fields, root);
}
else if (loc.datatype() == array::DataType::kind<float>()) {
parallel::Field<float const> loc_field(make_leveled_view<const float>(loc));
parallel::Field<float> glb_field(make_leveled_view<float>(glb));
gather().gather(&loc_field, &glb_field, nb_fields, root);
}
else if (loc.datatype() == array::DataType::kind<double>()) {
parallel::Field<double const> loc_field(make_leveled_view<const double>(loc));
parallel::Field<double> glb_field(make_leveled_view<double>(glb));
gather().gather(&loc_field, &glb_field, nb_fields, root);
}
else {
throw_Exception("datatype not supported", Here());
}
}
}

void PointCloud::gather(const Field& local, Field& global) const {
FieldSet local_fields;
FieldSet global_fields;
local_fields.add(local);
global_fields.add(global);
gather(local_fields, global_fields);
}
const parallel::GatherScatter& PointCloud::gather() const {
ATLAS_ASSERT(gather_scatter_);
return *gather_scatter_;
}
const parallel::GatherScatter& PointCloud::scatter() const {
ATLAS_ASSERT(gather_scatter_);
return *gather_scatter_;
}

void PointCloud::scatter(const FieldSet& global_fieldset, FieldSet& local_fieldset) const {
ATLAS_ASSERT(local_fieldset.size() == global_fieldset.size());

for (idx_t f = 0; f < local_fieldset.size(); ++f) {
const Field& glb = global_fieldset[f];
Field& loc = local_fieldset[f];
const idx_t nb_fields = 1;
idx_t root(0);
glb.metadata().get("owner", root);

if (loc.datatype() == array::DataType::kind<int>()) {
parallel::Field<int const> glb_field(make_leveled_view<const int>(glb));
parallel::Field<int> loc_field(make_leveled_view<int>(loc));
scatter().scatter(&glb_field, &loc_field, nb_fields, root);
}
else if (loc.datatype() == array::DataType::kind<long>()) {
parallel::Field<long const> glb_field(make_leveled_view<const long>(glb));
parallel::Field<long> loc_field(make_leveled_view<long>(loc));
scatter().scatter(&glb_field, &loc_field, nb_fields, root);
}
else if (loc.datatype() == array::DataType::kind<float>()) {
parallel::Field<float const> glb_field(make_leveled_view<const float>(glb));
parallel::Field<float> loc_field(make_leveled_view<float>(loc));
scatter().scatter(&glb_field, &loc_field, nb_fields, root);
}
else if (loc.datatype() == array::DataType::kind<double>()) {
parallel::Field<double const> glb_field(make_leveled_view<const double>(glb));
parallel::Field<double> loc_field(make_leveled_view<double>(loc));
scatter().scatter(&glb_field, &loc_field, nb_fields, root);
}
else {
throw_Exception("datatype not supported", Here());
}

auto name = loc.name();
glb.metadata().broadcast(loc.metadata(), root);
loc.metadata().set("global", false);
if( !name.empty() ) {
loc.metadata().set("name", name);
}
}
}

void PointCloud::scatter(const Field& global, Field& local) const {
FieldSet global_fields;
FieldSet local_fields;
global_fields.add(global);
local_fields.add(local);
scatter(global_fields, local_fields);
}


void PointCloud::set_field_metadata(const eckit::Configuration& config, Field& field) const {
field.set_functionspace(this);

Expand Down Expand Up @@ -849,6 +993,21 @@ void PointCloud::setupHaloExchange() {
ghost_.size());
}

void PointCloud::setupGatherScatter() {
if (ghost_ and partition_ and remote_index_ and global_index_) {
gather_scatter_.reset(new parallel::GatherScatter());
gather_scatter_->setup(mpi_comm_,
array::make_view<int, 1>(partition_).data(),
array::make_view<idx_t, 1>(remote_index_).data(),
REMOTE_IDX_BASE,
array::make_view<gidx_t, 1>(global_index_).data(),
array::make_view<idx_t, 1>(ghost_).data(),
ghost_.size());
size_global_ = gather_scatter_->glb_dof();
}
}


void PointCloud::adjointHaloExchange(const FieldSet& fieldset, bool on_device) const {
if (halo_exchange_) {
for (idx_t f = 0; f < fieldset.size(); ++f) {
Expand Down
18 changes: 17 additions & 1 deletion src/atlas/functionspace/PointCloud.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "atlas/functionspace/FunctionSpace.h"
#include "atlas/functionspace/detail/FunctionSpaceImpl.h"
#include "atlas/parallel/HaloExchange.h"
#include "atlas/parallel/GatherScatter.h"
#include "atlas/runtime/Exception.h"
#include "atlas/util/Config.h"
#include "atlas/util/Point.h"
Expand All @@ -28,6 +29,7 @@
namespace atlas {
namespace parallel {
class HaloExchange;
class GatherScatter;
} // namespace parallel
} // namespace atlas

Expand Down Expand Up @@ -77,6 +79,15 @@ class PointCloud : public functionspace::FunctionSpaceImpl {

const parallel::HaloExchange& halo_exchange() const;

void gather(const FieldSet&, FieldSet&) const override;
void gather(const Field&, Field&) const override;
const parallel::GatherScatter& gather() const override;

void scatter(const FieldSet&, FieldSet&) const override;
void scatter(const Field&, Field&) const override;
const parallel::GatherScatter& scatter() const override;


template <typename Point>
class IteratorT {
public:
Expand Down Expand Up @@ -165,14 +176,19 @@ class PointCloud : public functionspace::FunctionSpaceImpl {
Field global_index_;
Field partition_;
idx_t size_owned_;
idx_t size_global_{0};
idx_t max_glb_idx_{0};
std::unique_ptr<parallel::HaloExchange> halo_exchange_;

idx_t levels_{0};
idx_t part_{0};
idx_t nb_partitions_{1};
std::string mpi_comm_;

mutable std::unique_ptr<parallel::HaloExchange> halo_exchange_;
mutable std::unique_ptr<parallel::GatherScatter> gather_scatter_;

void setupHaloExchange();
void setupGatherScatter();

};

Expand Down
33 changes: 33 additions & 0 deletions src/tests/functionspace/test_functionspace_splitcomm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "atlas/functionspace/NodeColumns.h"
#include "atlas/functionspace/StructuredColumns.h"
#include "atlas/functionspace/BlockStructuredColumns.h"
#include "atlas/functionspace/PointCloud.h"
#include "atlas/grid/Partitioner.h"
#include "atlas/field/for_each.h"

Expand Down Expand Up @@ -173,6 +174,38 @@ CASE("test FunctionSpace BlockStructuredColumns") {

//-----------------------------------------------------------------------------

CASE("test FunctionSpace PointCloud") {
Fixture fixture;

auto fs = functionspace::PointCloud(grid(),util::Config("halo_radius",400*1000)|option::mpi_split_comm());
EXPECT_EQUAL(fs.part(),mpi::comm("split").rank());
EXPECT_EQUAL(fs.nb_parts(),mpi::comm("split").size());

auto field = fs.createField<double>();
field_init(field);

// HaloExchange
field.haloExchange();
// TODO CHECK

// Gather
auto fieldg = fs.createField<double>(atlas::option::global());
fs.gather(field,fieldg);

if (fieldg.size()) {
idx_t g{0};
field::for_each_value(fieldg,[&](double x) {
EXPECT_EQ(++g,x);
});
}

// Checksum
// auto checksum = fs.checksum(field);
// EXPECT_EQ(checksum, expected_checksum());
}

//-----------------------------------------------------------------------------

CASE("test FunctionSpace StructuredColumns with MatchingPartitioner") {
Fixture fixture;

Expand Down
39 changes: 38 additions & 1 deletion src/tests/functionspace/test_pointcloud_halo_creation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,27 @@ auto ghost = array::make_view<int,1>(pointcloud.ghost());

auto view = array::make_view<double,1>(field);


auto fieldg_init = pointcloud.createField<double>(option::name("fieldg_init")|option::global());

if (mpi::rank() == 0) {
auto viewg = array::make_view<double,1>(fieldg_init);
gidx_t g=0;
for (auto& p: grid.lonlat()) {
double lat = p.lat() * M_PI/180.;
viewg(g) = std::cos(4.*lat);
g++;
}
}

pointcloud.scatter(fieldg_init,field);

size_t count_ghost{0};
for (idx_t i=0; i<pointcloud.size(); ++i) {
if( not ghost(i) ) {
++count_ghost;
double lat = lonlat(i,1) * M_PI/180.;
view(i) = std::cos(4.*lat);
EXPECT_EQ(view(i), std::cos(4.*lat));
}
else {
view(i) = 0.;
Expand Down Expand Up @@ -258,6 +273,28 @@ for (idx_t i=0; i<pointcloud.size(); ++i) {
EXPECT_EQ( view(i), std::cos(4.*lat));
}



auto fieldg = pointcloud.createField<double>(option::name("field")|option::global());
if( mpi::rank() == 0 ) {
EXPECT_EQ(fieldg.size(), grid.size());
}
else {
EXPECT_EQ(fieldg.size(), 0);
}

pointcloud.gather(field, fieldg);

if (mpi::rank() == 0) {
auto viewg = array::make_view<double,1>(fieldg);
gidx_t g=0;
for (auto& p: grid.lonlat()) {
double lat = p.lat() * M_PI/180.;
EXPECT_EQ( viewg(g), std::cos(4.*lat));
g++;
}
}

}


Expand Down

0 comments on commit 66d8e5f

Please sign in to comment.