From 66d8e5f9c44a2b5c25afbb2a3a6ce4d260f8baae Mon Sep 17 00:00:00 2001 From: Willem Deconinck Date: Thu, 21 Sep 2023 15:42:01 +0200 Subject: [PATCH] Implement PointCloud::gather/scatter --- src/atlas/functionspace/PointCloud.cc | 161 +++++++++++++++++- src/atlas/functionspace/PointCloud.h | 18 +- .../test_functionspace_splitcomm.cc | 33 ++++ .../test_pointcloud_halo_creation.cc | 39 ++++- 4 files changed, 248 insertions(+), 3 deletions(-) diff --git a/src/atlas/functionspace/PointCloud.cc b/src/atlas/functionspace/PointCloud.cc index 04642a29f..4083d60d1 100644 --- a/src/atlas/functionspace/PointCloud.cc +++ b/src/atlas/functionspace/PointCloud.cc @@ -45,6 +45,32 @@ namespace functionspace { namespace detail { +namespace { + +template +array::LocalView make_leveled_view(Field& field) { + using namespace array; + if (field.levels()) { + if (field.variables()) { + return make_view(field).slice(Range::all(), Range::all(), Range::all()); + } + else { + return make_view(field).slice(Range::all(), Range::all(), Range::dummy()); + } + } + else { + if (field.variables()) { + return make_view(field).slice(Range::all(), Range::dummy(), Range::all()); + } + else { + return make_view(field).slice(Range::all(), Range::dummy(), Range::dummy()); + } + } +} + +} // namespace + + static std::string get_mpi_comm(const eckit::Configuration& config) { if(config.has("mpi_comm")) { return config.getString("mpi_comm"); @@ -84,6 +110,7 @@ PointCloud::PointCloud(const Field& lonlat, const eckit::Configuration& config): PointCloud::PointCloud(const Field& lonlat, const Field& ghost, const eckit::Configuration& config): lonlat_(lonlat), ghost_(ghost) { mpi_comm_ = get_mpi_comm(config); setupHaloExchange(); + setupGatherScatter(); } PointCloud::PointCloud(const FieldSet& flds, const eckit::Configuration& config): lonlat_(flds["lonlat"]) { @@ -102,6 +129,7 @@ PointCloud::PointCloud(const FieldSet& flds, const eckit::Configuration& config) } if( ghost_ && remote_index_ && partition_ ) { setupHaloExchange(); + setupGatherScatter(); } } @@ -243,6 +271,7 @@ PointCloud::PointCloud(const Grid& grid, const grid::Partitioner& _partitioner, } setupHaloExchange(); + setupGatherScatter(); } @@ -255,9 +284,20 @@ Field PointCloud::ghost() const { } array::ArrayShape PointCloud::config_shape(const eckit::Configuration& config) const { + idx_t _size = size(); + bool global(false); + if (config.get("global", global)) { + if (global) { + idx_t owner(0); + config.get("owner", owner); + idx_t rank = mpi::comm(mpi_comm()).rank(); + _size = (rank == owner ? size_global_ : 0); + } + } + array::ArrayShape shape; - shape.emplace_back(size()); + shape.emplace_back(_size); idx_t levels(levels_); config.get("levels", levels); @@ -302,6 +342,110 @@ const parallel::HaloExchange& PointCloud::halo_exchange() const { return *halo_exchange_; } +void PointCloud::gather(const FieldSet& local_fieldset, FieldSet& global_fieldset) const { + ATLAS_ASSERT(local_fieldset.size() == global_fieldset.size()); + + for (idx_t f = 0; f < local_fieldset.size(); ++f) { + const Field& loc = local_fieldset[f]; + Field& glb = global_fieldset[f]; + const idx_t nb_fields = 1; + idx_t root(0); + glb.metadata().get("owner", root); + + if (loc.datatype() == array::DataType::kind()) { + parallel::Field loc_field(make_leveled_view(loc)); + parallel::Field glb_field(make_leveled_view(glb)); + gather().gather(&loc_field, &glb_field, nb_fields, root); + } + else if (loc.datatype() == array::DataType::kind()) { + parallel::Field loc_field(make_leveled_view(loc)); + parallel::Field glb_field(make_leveled_view(glb)); + gather().gather(&loc_field, &glb_field, nb_fields, root); + } + else if (loc.datatype() == array::DataType::kind()) { + parallel::Field loc_field(make_leveled_view(loc)); + parallel::Field glb_field(make_leveled_view(glb)); + gather().gather(&loc_field, &glb_field, nb_fields, root); + } + else if (loc.datatype() == array::DataType::kind()) { + parallel::Field loc_field(make_leveled_view(loc)); + parallel::Field glb_field(make_leveled_view(glb)); + gather().gather(&loc_field, &glb_field, nb_fields, root); + } + else { + throw_Exception("datatype not supported", Here()); + } + } +} + +void PointCloud::gather(const Field& local, Field& global) const { + FieldSet local_fields; + FieldSet global_fields; + local_fields.add(local); + global_fields.add(global); + gather(local_fields, global_fields); +} +const parallel::GatherScatter& PointCloud::gather() const { + ATLAS_ASSERT(gather_scatter_); + return *gather_scatter_; +} +const parallel::GatherScatter& PointCloud::scatter() const { + ATLAS_ASSERT(gather_scatter_); + return *gather_scatter_; +} + +void PointCloud::scatter(const FieldSet& global_fieldset, FieldSet& local_fieldset) const { + ATLAS_ASSERT(local_fieldset.size() == global_fieldset.size()); + + for (idx_t f = 0; f < local_fieldset.size(); ++f) { + const Field& glb = global_fieldset[f]; + Field& loc = local_fieldset[f]; + const idx_t nb_fields = 1; + idx_t root(0); + glb.metadata().get("owner", root); + + if (loc.datatype() == array::DataType::kind()) { + parallel::Field glb_field(make_leveled_view(glb)); + parallel::Field loc_field(make_leveled_view(loc)); + scatter().scatter(&glb_field, &loc_field, nb_fields, root); + } + else if (loc.datatype() == array::DataType::kind()) { + parallel::Field glb_field(make_leveled_view(glb)); + parallel::Field loc_field(make_leveled_view(loc)); + scatter().scatter(&glb_field, &loc_field, nb_fields, root); + } + else if (loc.datatype() == array::DataType::kind()) { + parallel::Field glb_field(make_leveled_view(glb)); + parallel::Field loc_field(make_leveled_view(loc)); + scatter().scatter(&glb_field, &loc_field, nb_fields, root); + } + else if (loc.datatype() == array::DataType::kind()) { + parallel::Field glb_field(make_leveled_view(glb)); + parallel::Field loc_field(make_leveled_view(loc)); + scatter().scatter(&glb_field, &loc_field, nb_fields, root); + } + else { + throw_Exception("datatype not supported", Here()); + } + + auto name = loc.name(); + glb.metadata().broadcast(loc.metadata(), root); + loc.metadata().set("global", false); + if( !name.empty() ) { + loc.metadata().set("name", name); + } + } +} + +void PointCloud::scatter(const Field& global, Field& local) const { + FieldSet global_fields; + FieldSet local_fields; + global_fields.add(global); + local_fields.add(local); + scatter(global_fields, local_fields); +} + + void PointCloud::set_field_metadata(const eckit::Configuration& config, Field& field) const { field.set_functionspace(this); @@ -849,6 +993,21 @@ void PointCloud::setupHaloExchange() { ghost_.size()); } +void PointCloud::setupGatherScatter() { + if (ghost_ and partition_ and remote_index_ and global_index_) { + gather_scatter_.reset(new parallel::GatherScatter()); + gather_scatter_->setup(mpi_comm_, + array::make_view(partition_).data(), + array::make_view(remote_index_).data(), + REMOTE_IDX_BASE, + array::make_view(global_index_).data(), + array::make_view(ghost_).data(), + ghost_.size()); + size_global_ = gather_scatter_->glb_dof(); + } +} + + void PointCloud::adjointHaloExchange(const FieldSet& fieldset, bool on_device) const { if (halo_exchange_) { for (idx_t f = 0; f < fieldset.size(); ++f) { diff --git a/src/atlas/functionspace/PointCloud.h b/src/atlas/functionspace/PointCloud.h index f603a3146..a58207554 100644 --- a/src/atlas/functionspace/PointCloud.h +++ b/src/atlas/functionspace/PointCloud.h @@ -20,6 +20,7 @@ #include "atlas/functionspace/FunctionSpace.h" #include "atlas/functionspace/detail/FunctionSpaceImpl.h" #include "atlas/parallel/HaloExchange.h" +#include "atlas/parallel/GatherScatter.h" #include "atlas/runtime/Exception.h" #include "atlas/util/Config.h" #include "atlas/util/Point.h" @@ -28,6 +29,7 @@ namespace atlas { namespace parallel { class HaloExchange; +class GatherScatter; } // namespace parallel } // namespace atlas @@ -77,6 +79,15 @@ class PointCloud : public functionspace::FunctionSpaceImpl { const parallel::HaloExchange& halo_exchange() const; + void gather(const FieldSet&, FieldSet&) const override; + void gather(const Field&, Field&) const override; + const parallel::GatherScatter& gather() const override; + + void scatter(const FieldSet&, FieldSet&) const override; + void scatter(const Field&, Field&) const override; + const parallel::GatherScatter& scatter() const override; + + template class IteratorT { public: @@ -165,14 +176,19 @@ class PointCloud : public functionspace::FunctionSpaceImpl { Field global_index_; Field partition_; idx_t size_owned_; + idx_t size_global_{0}; idx_t max_glb_idx_{0}; - std::unique_ptr halo_exchange_; + idx_t levels_{0}; idx_t part_{0}; idx_t nb_partitions_{1}; std::string mpi_comm_; + mutable std::unique_ptr halo_exchange_; + mutable std::unique_ptr gather_scatter_; + void setupHaloExchange(); + void setupGatherScatter(); }; diff --git a/src/tests/functionspace/test_functionspace_splitcomm.cc b/src/tests/functionspace/test_functionspace_splitcomm.cc index 68a10b88f..56ed9dae3 100644 --- a/src/tests/functionspace/test_functionspace_splitcomm.cc +++ b/src/tests/functionspace/test_functionspace_splitcomm.cc @@ -16,6 +16,7 @@ #include "atlas/functionspace/NodeColumns.h" #include "atlas/functionspace/StructuredColumns.h" #include "atlas/functionspace/BlockStructuredColumns.h" +#include "atlas/functionspace/PointCloud.h" #include "atlas/grid/Partitioner.h" #include "atlas/field/for_each.h" @@ -173,6 +174,38 @@ CASE("test FunctionSpace BlockStructuredColumns") { //----------------------------------------------------------------------------- +CASE("test FunctionSpace PointCloud") { + Fixture fixture; + + auto fs = functionspace::PointCloud(grid(),util::Config("halo_radius",400*1000)|option::mpi_split_comm()); + EXPECT_EQUAL(fs.part(),mpi::comm("split").rank()); + EXPECT_EQUAL(fs.nb_parts(),mpi::comm("split").size()); + + auto field = fs.createField(); + field_init(field); + + // HaloExchange + field.haloExchange(); + // TODO CHECK + + // Gather + auto fieldg = fs.createField(atlas::option::global()); + fs.gather(field,fieldg); + + if (fieldg.size()) { + idx_t g{0}; + field::for_each_value(fieldg,[&](double x) { + EXPECT_EQ(++g,x); + }); + } + + // Checksum + // auto checksum = fs.checksum(field); + // EXPECT_EQ(checksum, expected_checksum()); +} + +//----------------------------------------------------------------------------- + CASE("test FunctionSpace StructuredColumns with MatchingPartitioner") { Fixture fixture; diff --git a/src/tests/functionspace/test_pointcloud_halo_creation.cc b/src/tests/functionspace/test_pointcloud_halo_creation.cc index 6228c2a34..2ee1593a1 100644 --- a/src/tests/functionspace/test_pointcloud_halo_creation.cc +++ b/src/tests/functionspace/test_pointcloud_halo_creation.cc @@ -217,12 +217,27 @@ auto ghost = array::make_view(pointcloud.ghost()); auto view = array::make_view(field); + +auto fieldg_init = pointcloud.createField(option::name("fieldg_init")|option::global()); + +if (mpi::rank() == 0) { + auto viewg = array::make_view(fieldg_init); + gidx_t g=0; + for (auto& p: grid.lonlat()) { + double lat = p.lat() * M_PI/180.; + viewg(g) = std::cos(4.*lat); + g++; + } +} + +pointcloud.scatter(fieldg_init,field); + size_t count_ghost{0}; for (idx_t i=0; i(option::name("field")|option::global()); +if( mpi::rank() == 0 ) { + EXPECT_EQ(fieldg.size(), grid.size()); +} +else { + EXPECT_EQ(fieldg.size(), 0); +} + +pointcloud.gather(field, fieldg); + +if (mpi::rank() == 0) { + auto viewg = array::make_view(fieldg); + gidx_t g=0; + for (auto& p: grid.lonlat()) { + double lat = p.lat() * M_PI/180.; + EXPECT_EQ( viewg(g), std::cos(4.*lat)); + g++; + } +} + }