Skip to content

Commit

Permalink
sparse maybe working
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Nov 4, 2024
1 parent 207c2dc commit ac01d92
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 30 deletions.
8 changes: 4 additions & 4 deletions src/bvals/comms/boundary_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,16 @@ TaskStatus SendBoundBufs(std::shared_ptr<MeshData<Real>> &md) {
Kokkos::fence();
#endif

// Send the combined buffers
pmesh->pcombined_buffers->PackAndSend(md.get(), bound_type);

for (int ibuf = 0; ibuf < cache.buf_vec.size(); ++ibuf) {
auto &buf = *cache.buf_vec[ibuf];
if (sending_nonzero_flags_h(ibuf) || !Globals::sparse_config.enabled)
buf.SendLocal();
else
buf.SendNull();
buf.SendNullLocal();
}

// Send the combined buffers
pmesh->pcombined_buffers->PackAndSend(md.get(), bound_type);

return TaskStatus::complete;
}
Expand Down
84 changes: 62 additions & 22 deletions src/bvals/comms/combined_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ namespace parthenon {
void CombinedBuffersRankPartition::AllocateCombinedBuffer() {
int send_rank = sender ? Globals::my_rank : other_rank;
int recv_rank = sender ? other_rank : Globals::my_rank;
combined_comm_buffer = CommBuffer<buf_t>(partition, send_rank, recv_rank, comm_);
combined_comm_buffer = CommBuffer<buf_t>(2 * partition, send_rank, recv_rank, comm_);
combined_comm_buffer.ConstructBuffer("combined send buffer",
current_size); // Actually allocate the thing
current_size + 1); // Actually allocate the thing
sparse_status_buffer = CommBuffer<std::vector<int>>(2 * partition + 1, send_rank, recv_rank, comm_);
sparse_status_buffer.ConstructBuffer(current_size + 1);
//PARTHENON_REQUIRE(current_size > 0, "Are we bigger than zero?");
// Point the BndId objects to the combined buffer
for (auto uid : all_vars) {
for (auto &[bnd_id, pvbbuf] : combined_info_buf.at(uid)) {
Expand Down Expand Up @@ -65,7 +68,10 @@ ParArray1D<BndId> &CombinedBuffersRankPartition::GetBndIdsOnDevice(const std::se
for (auto uid : var_set) {
for (auto &[bnd_id, pvbbuf] : combined_info_buf.at(uid)) {
auto &bid_h = bnd_ids_host[idx];
const bool alloc = pvbbuf->IsActive();
auto buf_state = pvbbuf->GetState();
PARTHENON_REQUIRE(buf_state != BufferState::stale, "Trying to work with a stale buffer.");

const bool alloc = (buf_state == BufferState::sending) || (buf_state == BufferState::received);
// Test if this boundary has changed
if (!bid_h.SameBVChannel(bnd_id) ||
(bid_h.buf_allocated != alloc) ||
Expand Down Expand Up @@ -96,18 +102,35 @@ void CombinedBuffersRankPartition::PackAndSend(const std::set<Uid_t> &vars) {
Kokkos::TeamPolicy<>(parthenon::DevExecSpace(), bids.size(), Kokkos::AUTO),
KOKKOS_LAMBDA(parthenon::team_mbr_t team_member) {
const int b = team_member.league_rank();
const int buf_size = bids[b].size();
Real *com_buf = &(bids[b].combined_buf(bids[b].start_idx()));
Real *buf = &(bids[b].buf(0));
Kokkos::parallel_for(Kokkos::TeamThreadRange<>(team_member, buf_size),
[&](const int idx) { com_buf[idx] = buf[idx]; });
if (bids[b].buf_allocated) {
const int buf_size = bids[b].size();
Real *com_buf = &(bids[b].combined_buf(bids[b].start_idx()));
Real *buf = &(bids[b].buf(0));
Kokkos::parallel_for(Kokkos::TeamThreadRange<>(team_member, buf_size),
[&](const int idx) { com_buf[idx] = buf[idx]; });
}
});
#ifdef MPI_PARALLEL
Kokkos::fence();
#endif
combined_comm_buffer.Send();

// Send the sparse null info as well
if (bids.size() != sparse_status_buffer.buffer().size()) {
sparse_status_buffer.ConstructBuffer(bids.size());
}

const auto &var_set = vars.size() == 0 ? all_vars : vars;
auto &stat = sparse_status_buffer.buffer();
int idx{0};
for (auto uid : var_set) {
for (auto &[bnd_id, pvbbuf] : combined_info_buf.at(uid)) {
stat[idx] = (pvbbuf->GetState() == BufferState::sending);
++idx;
}
}
sparse_status_buffer.Send();

// Information in these send buffers is no longer required
for (auto uid : var_set) {
for (auto &[bnd_id, pvbbuf] : combined_info_buf.at(uid)) {
Expand All @@ -121,19 +144,39 @@ bool CombinedBuffersRankPartition::TryReceiveAndUnpack(mpi_message_t *message,
const std::set<Uid_t> &vars) {
const auto &var_set = vars.size() == 0 ? all_vars : vars;
// Make sure the var-boundary buffers are available to write to
int nbuf{0};
for (auto uid : var_set) {
for (auto &[bnd_id, pvbbuf] : combined_info_buf.at(uid)) {
if (pvbbuf->GetState() != BufferState::stale) return false;
nbuf++;
}
}

if (nbuf != sparse_status_buffer.buffer().size()) {
sparse_status_buffer.ConstructBuffer(nbuf);
}
auto received_sparse = sparse_status_buffer.TryReceive();
auto received = combined_comm_buffer.TryReceive(message);
if (!received) return false;
if (!received || !received_sparse) return false;

// TODO(LFR): Update this to allocate based on second received message
// Allocate and free buffers as required
int idx{0};
auto &stat = sparse_status_buffer.buffer();
for (auto uid : var_set) {
for (auto &[bnd_id, pvbbuf] : combined_info_buf.at(uid)) {
if (!pvbbuf->IsActive()) pvbbuf->Allocate();
if (pvbbuf->IsActive()) {
if (stat[idx] == 0)
pvbbuf->Free();
} else {
if (stat[idx] == 1)
pvbbuf->Allocate();
}
if (stat[idx]) {
pvbbuf->SetReceived();
} else {
pvbbuf->SetReceivedNull();
}
idx++;
}
}

Expand All @@ -143,19 +186,16 @@ bool CombinedBuffersRankPartition::TryReceiveAndUnpack(mpi_message_t *message,
Kokkos::TeamPolicy<>(parthenon::DevExecSpace(), bids.size(), Kokkos::AUTO),
KOKKOS_LAMBDA(parthenon::team_mbr_t team_member) {
const int b = team_member.league_rank();
const int buf_size = bids[b].size();
Real *com_buf = &(bids[b].combined_buf(bids[b].start_idx()));
Real *buf = &(bids[b].buf(0));
Kokkos::parallel_for(Kokkos::TeamThreadRange<>(team_member, buf_size),
[&](const int idx) { buf[idx] = com_buf[idx]; });
if (bids[b].buf_allocated) {
const int buf_size = bids[b].size();
Real *com_buf = &(bids[b].combined_buf(bids[b].start_idx()));
Real *buf = &(bids[b].buf(0));
Kokkos::parallel_for(Kokkos::TeamThreadRange<>(team_member, buf_size),
[&](const int idx) { buf[idx] = com_buf[idx]; });
}
});
combined_comm_buffer.Stale();

for (auto uid : var_set) {
for (auto &[bnd_id, pvbbuf] : combined_info_buf.at(uid)) {
pvbbuf->SetReceived();
}
}
sparse_status_buffer.Stale();

return true;
}
Expand Down
3 changes: 2 additions & 1 deletion src/bvals/comms/combined_buffers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ struct CombinedBuffersRankPartition {
ParArray1D<BndId> bnd_ids_device;
ParArray1D<BndId>::host_mirror_type bnd_ids_host;
CommBuffer<buf_t> combined_comm_buffer;
CommBuffer<std::vector<int>> sparse_status_buffer;
int current_size;

CombinedBuffersRankPartition(bool sender, int partition, int other_rank,
Expand All @@ -72,7 +73,7 @@ struct CombinedBuffersRankPartition {

void AllocateCombinedBuffer();

bool IsAvailableForWrite() { return combined_comm_buffer.IsAvailableForWrite(); }
// Reports whether this partition can be written to: true only when both the
// sparse status buffer and the combined data buffer report that they are
// available for write. The status buffer is queried first; if it is not
// available, the combined buffer is not queried (short-circuit &&).
bool IsAvailableForWrite() { return sparse_status_buffer.IsAvailableForWrite() && combined_comm_buffer.IsAvailableForWrite(); }

ParArray1D<BndId> &GetBndIdsOnDevice(const std::set<Uid_t> &vars);

Expand Down
20 changes: 17 additions & 3 deletions src/utils/communication_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class CommBuffer {
void Send() noexcept;
void SendLocal() noexcept;
void SendNull() noexcept;
void SendNullLocal() noexcept;

bool IsAvailableForWrite();

Expand All @@ -138,6 +139,12 @@ class CommBuffer {
"This doesn't make sense for a non-receiver.");
*state_ = BufferState::received;
}
// Puts this buffer into the BufferState::received_null state.
// Only meaningful for buffers whose communication type is receiver or
// sparse_receiver; any other type trips the PARTHENON_REQUIRE below.
// NOTE(review): presumably "received_null" means the sender had no allocated
// data for this boundary (sparse variable not allocated) -- confirm against
// the sending side (SendNull / SendNullLocal).
void SetReceivedNull() noexcept {
// Guard against calling this on a sending buffer.
PARTHENON_REQUIRE(*comm_type_ == BuffCommType::receiver ||
*comm_type_ == BuffCommType::sparse_receiver,
"This doesn't make sense for a non-receiver.");
// state_ is dereferenced, so the new state is written through the pointer.
*state_ = BufferState::received_null;
}
bool IsSafeToDelete() {
if (*comm_type_ == BuffCommType::sparse_receiver ||
*comm_type_ == BuffCommType::receiver) {
Expand Down Expand Up @@ -263,10 +270,17 @@ void CommBuffer<T>::SendLocal() noexcept {
PARTHENON_DEBUG_REQUIRE(*state_ == BufferState::stale,
"Trying to send from buffer that hasn't been staled.");
*state_ = BufferState::sending;
if (*comm_type_ == BuffCommType::sender) {
// This buffer has been sent in some other way
*state_ = BufferState::stale;
if (*comm_type_ == BuffCommType::receiver) {
// This is an error
PARTHENON_FAIL("Trying to send from a receiver");
}
}

template <class T>
void CommBuffer<T>::SendNullLocal() noexcept {
PARTHENON_DEBUG_REQUIRE(*state_ == BufferState::stale,
"Trying to send from buffer that hasn't been staled.");
*state_ = BufferState::sending_null;
if (*comm_type_ == BuffCommType::receiver) {
// This is an error
PARTHENON_FAIL("Trying to send from a receiver");
Expand Down

0 comments on commit ac01d92

Please sign in to comment.