Use mpi_comm in Redistribute
wdeconinck committed Sep 12, 2023
1 parent 0697ed1 commit 49a35f7
Showing 6 changed files with 348 additions and 53 deletions.
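In short: every MPI collective in the redistribution code now runs on the communicator named by the source/target function spaces instead of the global default. A minimal before/after sketch of the pattern, assuming only the atlas::mpi calls visible in the diff below (the header path and the function name `example` are assumptions, not from the commit):

#include "atlas/functionspace.h"
#include "atlas/parallel/mpi/mpi.h"  // assumed location of atlas::mpi::comm

void example(const atlas::FunctionSpace& functionspace) {
    // Before this commit: const auto& comm = atlas::mpi::comm();  // default
    // After: look the communicator up by the name the function space carries.
    const auto& comm = atlas::mpi::comm(functionspace.mpi_comm());
    const auto thisPart = static_cast<int>(comm.rank());  // as in getGhostField()
    const auto nParts   = static_cast<int>(comm.size());
    (void)thisPart;
    (void)nParts;
}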
51 changes: 32 additions & 19 deletions src/atlas/redistribution/detail/RedistributeGeneric.cc
@@ -42,6 +42,7 @@ Field getGhostField(const FunctionSpace& functionspace) {
}
if (functionspace::EdgeColumns(functionspace) || functionspace::CellColumns(functionspace)) {
// TODO: Move something like this into the functionspace::EdgeColumns and functionspace::CellColumns
+ auto& comm = mpi::comm(functionspace.mpi_comm());

// Get mesh elements.
const auto& elems = functionspace::EdgeColumns(functionspace)
@@ -58,7 +59,7 @@ Field getGhostField(const FunctionSpace& functionspace) {
auto partition = array::make_view<int, 1>(elems.partition());

// Set ghost field.
- const auto thisPart = static_cast<int>(mpi::comm().rank());
+ const auto thisPart = static_cast<int>(comm.rank());
for (idx_t i = 0; i < ghost.shape(0); ++i) {
ghost(i) = partition(i) != thisPart || remote_index(i) != i;
}
@@ -147,19 +148,20 @@ std::vector<uidx_t> getUidVal(const std::vector<IdxUid>& uidVec) {
}

// Communicate UID values, return receive buffer and displacements.
- std::pair<std::vector<uidx_t>, std::vector<int>> communicateUid(const std::vector<uidx_t>& sendBuffer) {
- auto counts = std::vector<int>(mpi::comm().size());
- mpi::comm().allGather(static_cast<int>(sendBuffer.size()), counts.begin(), counts.end());
+ std::pair<std::vector<uidx_t>, std::vector<int>> communicateUid(const std::string& mpi_comm, const std::vector<uidx_t>& sendBuffer) {
+ auto& comm = mpi::comm(mpi_comm);
+ auto counts = std::vector<int>(comm.size());
+ comm.allGather(static_cast<int>(sendBuffer.size()), counts.begin(), counts.end());

auto disps = std::vector<int>{};
- disps.reserve(mpi::comm().size() + 1);
+ disps.reserve(comm.size() + 1);
disps.push_back(0);
std::partial_sum(counts.begin(), counts.end(), std::back_inserter(disps));


auto recvBuffer = std::vector<uidx_t>(static_cast<size_t>(disps.back()));

- mpi::comm().allGatherv(sendBuffer.begin(), sendBuffer.end(), recvBuffer.begin(), counts.data(), disps.data());
+ comm.allGatherv(sendBuffer.begin(), sendBuffer.end(), recvBuffer.begin(), counts.data(), disps.data());

return std::make_pair(recvBuffer, disps);
}
@@ -175,18 +177,22 @@ bool operator<(const uidx_t& lhs, const IdxUid& rhs) {

// Find the intersection between local and global UIDs, then return the local
// indices of the intersections and the PE displacements in a vector.
- std::pair<std::vector<idx_t>, std::vector<int>> getUidIntersection(const std::vector<IdxUid>& localUids,
+ std::pair<std::vector<idx_t>, std::vector<int>> getUidIntersection(const std::string& mpi_comm,
+ const std::vector<IdxUid>& localUids,
const std::vector<uidx_t>& globalUids,
const std::vector<int>& globalDisps) {
auto uidIntersection = std::vector<IdxUid>{};
uidIntersection.reserve(localUids.size());

+ auto& comm = mpi::comm(mpi_comm);
+ auto mpi_size = comm.size();

auto disps = std::vector<int>{};
- disps.reserve(mpi::comm().size() + 1);
+ disps.reserve(mpi_size + 1);
disps.push_back(0);

// Loop over all PE and find UID intersection.
- for (size_t i = 0; i < mpi::comm().size(); ++i) {
+ for (size_t i = 0; i < mpi_size; ++i) {
// Get displaced iterators.
auto globalUidsBegin = globalUids.begin() + globalDisps[i];
auto globalUidsEnd = globalUids.begin() + globalDisps[i + 1];
@@ -255,21 +261,25 @@ struct ForEach<Rank, Rank> {
} // namespace

void RedistributeGeneric::do_setup() {
+ ATLAS_ASSERT( source().mpi_comm() == target().mpi_comm() );

+ mpi_comm_ = source().mpi_comm();

// get a unique ID (UID) for each owned member of functionspace.
const auto sourceUidVec = getUidVec(source());
const auto targetUidVec = getUidVec(target());

// Communicate UID vectors to all PEs.
auto sourceGlobalUids = std::vector<uidx_t>{};
auto sourceGlobalDisps = std::vector<int>{};
- std::tie(sourceGlobalUids, sourceGlobalDisps) = communicateUid(getUidVal(sourceUidVec));
+ std::tie(sourceGlobalUids, sourceGlobalDisps) = communicateUid(mpi_comm_, getUidVal(sourceUidVec));
auto targetGlobalUids = std::vector<uidx_t>{};
auto targetGlobalDisps = std::vector<int>{};
- std::tie(targetGlobalUids, targetGlobalDisps) = communicateUid(getUidVal(targetUidVec));
+ std::tie(targetGlobalUids, targetGlobalDisps) = communicateUid(mpi_comm_, getUidVal(targetUidVec));

// Get intersection of local UIDs and Global UIDs.
- std::tie(sourceLocalIdx_, sourceDisps_) = getUidIntersection(sourceUidVec, targetGlobalUids, targetGlobalDisps);
- std::tie(targetLocalIdx_, targetDisps_) = getUidIntersection(targetUidVec, sourceGlobalUids, sourceGlobalDisps);
+ std::tie(sourceLocalIdx_, sourceDisps_) = getUidIntersection(mpi_comm_, sourceUidVec, targetGlobalUids, targetGlobalDisps);
+ std::tie(targetLocalIdx_, targetDisps_) = getUidIntersection(mpi_comm_, targetUidVec, sourceGlobalUids, sourceGlobalDisps);
}

void RedistributeGeneric::execute(const Field& sourceField, Field& targetField) const {
@@ -369,6 +379,9 @@ void RedistributeGeneric::do_execute(const Field& sourceField, Field& targetFiel
auto sourceView = array::make_view<Value, Rank>(sourceField);
auto targetView = array::make_view<Value, Rank>(targetField);

+ const auto& comm = mpi::comm(mpi_comm_);
+ auto mpi_size = comm.size();

// Get number of elems per column.
int elemsPerCol = 1;
for (int i = 1; i < Rank; ++i) {
@@ -377,18 +390,18 @@ void RedistributeGeneric::do_execute(const Field& sourceField, Field& targetFiel

// Set send displacement and counts vectors.
auto sendDisps = std::vector<int>{};
- sendDisps.reserve(mpi::comm().size() + 1);
+ sendDisps.reserve(mpi_size + 1);
auto sendCounts = std::vector<int>{};
- sendCounts.reserve(mpi::comm().size());
+ sendCounts.reserve(mpi_size);
std::transform(sourceDisps_.begin(), sourceDisps_.end(), std::back_inserter(sendDisps),
[&](const int& disp) { return disp * elemsPerCol; });
std::adjacent_difference(sendDisps.begin() + 1, sendDisps.end(), std::back_inserter(sendCounts));

// Set recv displacement and counts vectors.
auto recvDisps = std::vector<int>{};
- recvDisps.reserve(mpi::comm().size() + 1);
+ recvDisps.reserve(mpi_size + 1);
auto recvCounts = std::vector<int>{};
- recvCounts.reserve(mpi::comm().size());
+ recvCounts.reserve(mpi_size);
std::transform(targetDisps_.begin(), targetDisps_.end(), std::back_inserter(recvDisps),
[&](const int& disp) { return disp * elemsPerCol; });
std::adjacent_difference(recvDisps.begin() + 1, recvDisps.end(), std::back_inserter(recvCounts));
@@ -403,8 +416,8 @@ void RedistributeGeneric::do_execute(const Field& sourceField, Field& targetFiel
ForEach<Rank>::apply(sourceLocalIdx_, sourceView, [&](const Value& elem) { *sendBufferIt++ = elem; });

// Perform MPI communication.
- mpi::comm().allToAllv(sendBuffer.data(), sendCounts.data(), sendDisps.data(), recvBuffer.data(), recvCounts.data(),
- recvDisps.data());
+ comm.allToAllv(sendBuffer.data(), sendCounts.data(), sendDisps.data(), recvBuffer.data(), recvCounts.data(),
+ recvDisps.data());

// Copy recvBuffer to targetField.
ForEach<Rank>::apply(targetLocalIdx_, targetView, [&](Value& elem) { elem = *recvBufferIt++; });
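Aside on the bookkeeping above: the displacement vectors in communicateUid() and do_execute() are an exclusive prefix sum of the per-rank counts, with one trailing entry giving the total buffer size, and counts are recovered from displacements by differencing. A standalone STL sketch (makeDisps and makeCounts are illustrative names, not from the commit):

#include <iterator>
#include <numeric>
#include <vector>

// disps[i] = offset where rank i's block starts; disps.back() = total length.
std::vector<int> makeDisps(const std::vector<int>& counts) {
    std::vector<int> disps;
    disps.reserve(counts.size() + 1);
    disps.push_back(0);
    std::partial_sum(counts.begin(), counts.end(), std::back_inserter(disps));
    return disps;  // e.g. {3, 1, 2} -> {0, 3, 4, 6}
}

// The inverse, as applied to sendDisps/recvDisps above.
std::vector<int> makeCounts(const std::vector<int>& disps) {
    std::vector<int> counts;
    counts.reserve(disps.size() - 1);
    std::adjacent_difference(disps.begin() + 1, disps.end(), std::back_inserter(counts));
    return counts;  // e.g. {0, 3, 4, 6} -> {3, 1, 2}
}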
4 changes: 4 additions & 0 deletions src/atlas/redistribution/detail/RedistributeGeneric.h
@@ -7,6 +7,8 @@

#pragma once

+ #include <string>

#include "atlas/redistribution/detail/RedistributionImpl.h"

namespace atlas {
@@ -46,6 +48,8 @@ class RedistributeGeneric : public RedistributionImpl {

// Partial sum of number of columns to receive from each PE.
std::vector<int> targetDisps_{};

+ std::string mpi_comm_;
};

} // namespace detail
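Note the design choice visible in this header: the class caches the communicator as a std::string name, not as a handle, and resolves the handle at each call site. A hypothetical mirror of the pattern (HasNamedComm is not an Atlas class; the mpi header path is an assumption):

#include <string>
#include <utility>
#include "atlas/parallel/mpi/mpi.h"  // assumed location of atlas::mpi::comm

class HasNamedComm {  // hypothetical; mirrors the mpi_comm_ member above
public:
    explicit HasNamedComm(std::string name) : mpi_comm_(std::move(name)) {}
    // Resolve on demand, exactly as RedistributeGeneric::do_execute() does.
    const auto& comm() const { return atlas::mpi::comm(mpi_comm_); }

private:
    std::string mpi_comm_;  // communicator name; cheap to copy and store
};

Storing the name keeps the class trivially copyable and avoids holding a reference whose lifetime the class does not control.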
20 changes: 14 additions & 6 deletions src/atlas/redistribution/detail/RedistributeStructuredColumns.cc
@@ -69,6 +69,10 @@ void RedistributeStructuredColumns::do_setup() {
// Check levels match.
ATLAS_ASSERT(source_.levels() == target_.levels());

+ // Check that communicators match.
+ ATLAS_ASSERT(source_.mpi_comm() == target_.mpi_comm());
+ mpi_comm_ = source_.mpi_comm();


// Get source and target range of this function space.
const auto sourceRange = StructuredIndexRange(source_);
@@ -225,8 +229,8 @@ void RedistributeStructuredColumns::do_execute(const Field& sourceField, Field&
forEachIndex(sendIntersections_, sendFunctor);

// Communicate.
- mpi::comm().allToAllv(sendBuffer.data(), sendCounts_.data(), sendDisplacements_.data(), recvBuffer.data(),
- recvCounts_.data(), recvDisplacements_.data());
+ mpi::comm(mpi_comm_).allToAllv(sendBuffer.data(), sendCounts_.data(), sendDisplacements_.data(), recvBuffer.data(),
+ recvCounts_.data(), recvDisplacements_.data());

// Read data from buffer.
forEachIndex(recvIntersections_, recvFunctor);
@@ -246,19 +250,23 @@ StructuredIndexRange::StructuredIndexRange(const functionspace::StructuredColumn
iBeginEnd_.push_back(std::make_pair(structuredColumns.i_begin(j), structuredColumns.i_end(j)));
}

+ mpi_comm_ = structuredColumns.mpi_comm();

return;
}

// Get index ranges from all PEs.
StructuredIndexRangeVector StructuredIndexRange::getStructuredIndexRanges() const {
+ auto& comm = mpi::comm(mpi_comm());

// Get MPI communicator size.
- const auto mpiSize = static_cast<size_t>(atlas::mpi::comm().size());
+ const auto mpiSize = static_cast<size_t>(comm.size());

// Set recv buffer for j range.
auto jRecvBuffer = idxPairVector(mpiSize);

// Perform all gather.
- atlas::mpi::comm().allGather(jBeginEnd_, jRecvBuffer.begin(), jRecvBuffer.end());
+ comm.allGather(jBeginEnd_, jRecvBuffer.begin(), jRecvBuffer.end());

// Set i receive counts.
auto iRecvCounts = transformVector<int>(
@@ -272,8 +280,8 @@ StructuredIndexRangeVector StructuredIndexRange::getStructuredIndexRanges() cons
auto irecvBuffer = idxPairVector(static_cast<size_t>(iRecvDisplacements.back() + iRecvCounts.back()));

// Perform all gather.
- atlas::mpi::comm().allGatherv(iBeginEnd_.cbegin(), iBeginEnd_.cend(), irecvBuffer.begin(), iRecvCounts.data(),
- iRecvDisplacements.data());
+ comm.allGatherv(iBeginEnd_.cbegin(), iBeginEnd_.cend(), irecvBuffer.begin(), iRecvCounts.data(),
+ iRecvDisplacements.data());

// Make vector of indexRange structs.
auto indexRanges = StructuredIndexRangeVector{};
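The execute-time pattern is now identical in both .cc files: resolve the communicator from the cached name, then issue the all-to-all on it. A trimmed sketch with buffer types simplified to double (`exchange` and the header path are assumptions; the allToAllv call mirrors the diff):

#include <string>
#include <vector>
#include "atlas/parallel/mpi/mpi.h"  // assumed location of atlas::mpi::comm

void exchange(const std::string& mpi_comm_,
              const std::vector<double>& sendBuffer, const std::vector<int>& sendCounts,
              const std::vector<int>& sendDisps, std::vector<double>& recvBuffer,
              const std::vector<int>& recvCounts, const std::vector<int>& recvDisps) {
    // Named communicator instead of the former mpi::comm() default.
    const auto& comm = atlas::mpi::comm(mpi_comm_);
    comm.allToAllv(sendBuffer.data(), sendCounts.data(), sendDisps.data(),
                   recvBuffer.data(), recvCounts.data(), recvDisps.data());
}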
src/atlas/redistribution/detail/RedistributeStructuredColumns.h
@@ -92,6 +92,8 @@ class RedistributeStructuredColumns : public RedistributionImpl {
std::vector<int> sendDisplacements_{};
std::vector<int> recvCounts_{};
std::vector<int> recvDisplacements_{};

+ std::string mpi_comm_;
};

/// \brief Helper class for function space intersections.
@@ -116,12 +118,16 @@ class StructuredIndexRange {
template <typename functorType>
void forEach(const functorType&) const;

+ const std::string& mpi_comm() const { return mpi_comm_; }

private:
// Begin and end of j range.
idxPair jBeginEnd_{};

// Begin and end of i range for each j.
idxPairVector iBeginEnd_{};

+ std::string mpi_comm_;
};

} // namespace detail
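StructuredIndexRange likewise remembers the communicator name it was built with and exposes it via mpi_comm(), so getStructuredIndexRanges() no longer falls back to the default communicator. A hypothetical one-liner using that accessor (commOf and the namespace qualification are assumptions, not Atlas code):

#include "atlas/parallel/mpi/mpi.h"  // assumed location of atlas::mpi::comm
#include "atlas/redistribution/detail/RedistributeStructuredColumns.h"

// Shows how a free function can reach the communicator a range was built on.
const auto& commOf(const atlas::redistribution::detail::StructuredIndexRange& range) {
    return atlas::mpi::comm(range.mpi_comm());
}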
Diff not shown for the remaining 2 changed files.
