use separate comm
lroberts36 committed Oct 23, 2024
1 parent e60b82c commit 4d49ce5
Showing 2 changed files with 41 additions and 25 deletions.
src/bvals/comms/combined_buffers.cpp (24 additions & 18 deletions)
@@ -30,8 +30,10 @@

 namespace parthenon {

-CombinedBuffersRank::CombinedBuffersRank(int o_rank, BoundaryType b_type, bool send)
-    : other_rank(o_rank), b_type(b_type), sender(send), buffers_built(false) {
+CombinedBuffersRank::CombinedBuffersRank(int o_rank, BoundaryType b_type, bool send,
+                                         mpi_comm_t comm)
+    : other_rank(o_rank), b_type(b_type), sender(send), buffers_built(false),
+      comm_(comm) {

   int tag = 1234 + static_cast<int>(GetAssociatedSender(b_type));
   if (sender) {
@@ -72,7 +74,7 @@ bool CombinedBuffersRank::TryReceiveBufInfo(Mesh *pmesh) {
     const int nbuf = mess_buf[idx++];
     const int total_size = mess_buf[idx++];
     combined_buffers[partition] =
-        CommBuffer<buf_t>(913 + partition, other_rank, Globals::my_rank, comm_);
+        CommBuffer<buf_t>(partition, other_rank, Globals::my_rank, comm_);
     combined_buffers[partition].ConstructBuffer("combined recv buffer", total_size);
     auto &cr_info = combined_info[partition];
     auto &bufs = buffers[partition];
@@ -140,7 +142,7 @@ void CombinedBuffersRank::ResolveSendBuffersAndSendInfo(Mesh *pmesh) {
   // Allocate the combined buffers
   for (auto &[partition, size] : current_size) {
     combined_buffers[partition] =
-        CommBuffer<buf_t>(913 + partition, Globals::my_rank, other_rank, comm_);
+        CommBuffer<buf_t>(partition, Globals::my_rank, other_rank, comm_);
     combined_buffers[partition].ConstructBuffer("combined send buffer", size);
   }

@@ -284,8 +286,11 @@ void CombinedBuffers::AddSendBuffer(int partition, MeshBlock *pmb,
                                     const std::shared_ptr<Variable<Real>> &var,
                                     BoundaryType b_type) {
   if (combined_send_buffers.count({nb.rank, b_type}) == 0)
-    combined_send_buffers[{nb.rank, b_type}] = CombinedBuffersRank(nb.rank, b_type, true);
-  combined_send_buffers[{nb.rank, b_type}].AddSendBuffer(partition, pmb, nb, var, b_type);
+    combined_send_buffers.emplace(
+        std::make_pair(std::make_pair(nb.rank, b_type),
+                       CombinedBuffersRank(nb.rank, b_type, true, comm_)));
+  combined_send_buffers.at({nb.rank, b_type})
+      .AddSendBuffer(partition, pmb, nb, var, b_type);
 }

 void CombinedBuffers::AddRecvBuffer(MeshBlock *pmb, const NeighborBlock &nb,
@@ -295,8 +300,9 @@ void CombinedBuffers::AddRecvBuffer(MeshBlock *pmb, const NeighborBlock &nb,
   // know that it's existence implies that we need to receive a message from the
   // neighbor block rank eventually telling us the details
   if (combined_recv_buffers.count({nb.rank, b_type}) == 0)
-    combined_recv_buffers[{nb.rank, b_type}] =
-        CombinedBuffersRank(nb.rank, b_type, false);
+    combined_recv_buffers.emplace(
+        std::make_pair(std::make_pair(nb.rank, b_type),
+                       CombinedBuffersRank(nb.rank, b_type, false, comm_)));
 }

 void CombinedBuffers::ResolveAndSendSendBuffers(Mesh *pmesh) {
@@ -325,7 +331,7 @@ bool CombinedBuffers::IsAvailableForWrite(int partition, BoundaryType b_type) {
   for (int rank = 0; rank < Globals::nranks; ++rank) {
     if (combined_send_buffers.count({rank, b_type})) {
       available = available &&
-                  combined_send_buffers[{rank, b_type}].IsAvailableForWrite(partition);
+                  combined_send_buffers.at({rank, b_type}).IsAvailableForWrite(partition);
     }
   }
   return available;
@@ -334,7 +340,7 @@ bool CombinedBuffers::IsAvailableForWrite(int partition, BoundaryType b_type) {
 void CombinedBuffers::PackAndSend(int partition, BoundaryType b_type) {
   for (int rank = 0; rank < Globals::nranks; ++rank) {
     if (combined_send_buffers.count({rank, b_type})) {
-      combined_send_buffers[{rank, b_type}].PackAndSend(partition);
+      combined_send_buffers.at({rank, b_type}).PackAndSend(partition);
     }
   }
 }
@@ -343,15 +349,15 @@ void CombinedBuffers::RepointSendBuffers(Mesh *pmesh, int partition,
                                          BoundaryType b_type) {
   for (int rank = 0; rank < Globals::nranks; ++rank) {
     if (combined_send_buffers.count({rank, b_type}))
-      combined_send_buffers[{rank, b_type}].RepointBuffers(pmesh, partition);
+      combined_send_buffers.at({rank, b_type}).RepointBuffers(pmesh, partition);
   }
 }

 void CombinedBuffers::RepointRecvBuffers(Mesh *pmesh, int partition,
                                          BoundaryType b_type) {
   for (int rank = 0; rank < Globals::nranks; ++rank) {
     if (combined_recv_buffers.count({rank, b_type}))
-      combined_recv_buffers[{rank, b_type}].RepointBuffers(pmesh, partition);
+      combined_recv_buffers.at({rank, b_type}).RepointBuffers(pmesh, partition);
   }
 }

@@ -361,22 +367,22 @@ void CombinedBuffers::TryReceiveAny(Mesh *pmesh, BoundaryType b_type) {
   int flag;
   do {
     // TODO(LFR): Switch to a different communicator for each BoundaryType
-    MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, &status);
+    MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, comm_, &flag, &status);
     if (flag) {
       const int rank = status.MPI_SOURCE;
-      const int partition = status.MPI_TAG - 913;
+      const int partition = status.MPI_TAG;
       bool finished =
-          combined_recv_buffers[{rank, b_type}].TryReceiveAndUnpack(pmesh, partition);
+          combined_recv_buffers.at({rank, b_type}).TryReceiveAndUnpack(pmesh, partition);
       if (!finished) processing_messages.insert({rank, partition});
     }
   } while (flag);

   // Process in flight messages
-  std::set<std::pair<int, int>> finished_messages;
+  std::vector<std::pair<int, int>> finished_messages;
   for (auto &[rank, partition] : processing_messages) {
     bool finished =
-        combined_recv_buffers[{rank, b_type}].TryReceiveAndUnpack(pmesh, partition);
-    if (finished) finished_messages.insert({rank, partition});
+        combined_recv_buffers.at({rank, b_type}).TryReceiveAndUnpack(pmesh, partition);
+    if (finished) finished_messages.push_back({rank, partition});
   }

   for (auto &m : finished_messages)
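Net effect of the .cpp changes: all combined-buffer traffic now moves over a dedicated communicator (comm_) instead of MPI_COMM_WORLD, so the MPI_Iprobe in TryReceiveAny can only ever match combined-buffer messages and the tag can carry the partition id directly, making the old 913 tag offset unnecessary. The finished_messages container also changes from std::set to std::vector, which suffices since it is only appended to and iterated once. Below is a minimal standalone sketch of the communicator pattern, assuming a plain two-rank MPI program rather than Parthenon's wrappers (payload and partition values are illustrative):

#include <mpi.h>

// Duplicate MPI_COMM_WORLD so probes on this communicator can only match
// messages sent on it, and the tag can carry an id with no magic offset.
// Run with at least two ranks, e.g. mpirun -n 2.
int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  MPI_Comm comm;
  MPI_Comm_dup(MPI_COMM_WORLD, &comm);  // private communication context

  int rank, size;
  MPI_Comm_rank(comm, &rank);
  MPI_Comm_size(comm, &size);

  if (size >= 2) {
    if (rank == 0) {
      const int partition = 7;  // the tag *is* the partition id on this comm
      int payload = 42;
      MPI_Send(&payload, 1, MPI_INT, 1, partition, comm);
    } else if (rank == 1) {
      int flag = 0;
      MPI_Status status;
      while (!flag) MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, comm, &flag, &status);
      const int partition = status.MPI_TAG;  // no "- 913" correction needed
      int payload;
      MPI_Recv(&payload, 1, MPI_INT, status.MPI_SOURCE, partition, comm,
               MPI_STATUS_IGNORE);
    }
  }

  MPI_Comm_free(&comm);  // mirror of the dup, before MPI_Finalize
  MPI_Finalize();
  return 0;
}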
src/bvals/comms/combined_buffers.hpp (17 additions & 7 deletions)
@@ -56,15 +56,10 @@ struct CombinedBuffersRank {
   using com_buf_t = CommBuffer<std::vector<int>>;
   com_buf_t message;

-#ifdef MPI_PARALLEL
-  mpi_comm_t comm_{MPI_COMM_WORLD};
-#else
-  mpi_comm_t comm_{0};
-#endif
+  mpi_comm_t comm_;

   bool sender{true};
-  CombinedBuffersRank() = default;
-  CombinedBuffersRank(int o_rank, BoundaryType b_type, bool send);
+  CombinedBuffersRank(int o_rank, BoundaryType b_type, bool send, mpi_comm_t comm);

   void AddSendBuffer(int partition, MeshBlock *pmb, const NeighborBlock &nb,
                      const std::shared_ptr<Variable<Real>> &var, BoundaryType b_type);
@@ -95,6 +90,21 @@ struct CombinedBuffers {

   std::set<std::pair<int, int>> processing_messages;

+  mpi_comm_t comm_;
+  CombinedBuffers() {
+#ifdef MPI_PARALLEL
+    PARTHENON_MPI_CHECK(MPI_Comm_dup(MPI_COMM_WORLD, &comm_));
+#else
+    comm_ = 0;
+#endif
+  }
+
+  ~CombinedBuffers() {
+#ifdef MPI_PARALLEL
+    PARTHENON_MPI_CHECK(MPI_Comm_free(&comm_));
+#endif
+  }
+
   void clear() {
     // TODO(LFR): Need to be careful here that the asynchronous send buffers are finished
     combined_send_buffers.clear();
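The header changes explain the map-access changes in the .cpp file: CombinedBuffersRank loses its default constructor (a communicator must now be supplied), and std::map::operator[] requires a default-constructible mapped type, hence the switch to emplace for insertion and .at() for lookup. The new CombinedBuffers constructor/destructor pair duplicates and frees the communicator exactly once per object. A tiny sketch of the operator[] constraint, using a hypothetical stand-in type (not from the repo):

#include <map>

struct NoDefault {            // stand-in for CombinedBuffersRank after this commit
  explicit NoDefault(int) {}  // only a non-default constructor
};

int main() {
  std::map<int, NoDefault> m;
  m.emplace(0, NoDefault(1));  // OK: constructs the mapped value in place
  m.at(0);                     // OK: throws std::out_of_range if the key is absent
  // m[0];                     // does not compile: operator[] must be able to
                               // default-construct the mapped type
  return 0;
}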
