Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[coll] Avoid all-to-all connection. #10840

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/topo.o \
$(PKGROOT)/src/collective/socket.o \
$(PKGROOT)/src/common/charconv.o \
$(PKGROOT)/src/common/column_matrix.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/in_memory_handler.o \
$(PKGROOT)/src/collective/loop.o \
$(PKGROOT)/src/collective/topo.o \
$(PKGROOT)/src/collective/socket.o \
$(PKGROOT)/src/common/charconv.o \
$(PKGROOT)/src/common/column_matrix.o \
Expand Down
4 changes: 3 additions & 1 deletion include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
#pragma once

#include <algorithm> // for max
#include <cerrno> // errno, EINTR, EBADF
#include <climits> // HOST_NAME_MAX
#include <cstddef> // std::size_t
Expand Down Expand Up @@ -539,7 +540,8 @@ class TCPSocket {
/**
* @brief Listen to incoming requests. Should be called after bind.
*/
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
[[nodiscard]] Result Listen(std::int32_t backlog) {
backlog = std::max(backlog, 16); // Don't be too small.
if (listen(handle_, backlog) != 0) {
return system::FailWithCode("Failed to listen.");
}
Expand Down
1 change: 1 addition & 0 deletions src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel
#include "comm_group.h" // for CommGroup
#include "topo.h" // for BootstrapNext, BootstrapPrev
#include "xgboost/collective/result.h" // for Result
#include "xgboost/linalg.h" // for MakeVec
#include "xgboost/span.h" // for Span
Expand Down
79 changes: 33 additions & 46 deletions src/collective/broadcast.cc
Original file line number Diff line number Diff line change
@@ -1,51 +1,19 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "broadcast.h"

#include <cmath> // for ceil, log2
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move

#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32

#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
#include "topo.h" // for ParentRank

namespace xgboost::collective::cpu_impl {
namespace {
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
RBitField32 rankbits{
common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
// prepare for counting trailing zeros.
for (std::int32_t i = 0; i < depth + 1; ++i) {
if (rankbits.Check(i)) {
maskbits.Set(i);
} else {
maskbits.Clear(i);
}
}

CHECK_NE(mask, 0);
auto k = TrailingZeroBits(mask);
auto shifted_parent = shifted_rank - (1 << k);
return shifted_parent;
}

// Shift the root node to rank 0
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
auto shifted_rank = (rank + world - root) % world;
return shifted_rank;
}
// shift back to the original rank
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
auto orig = (rank + root) % world;
return orig;
}
} // namespace

Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
// Binomial tree broadcast
// * Wiki
Expand All @@ -56,28 +24,47 @@ Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t
auto rank = comm.Rank();
auto world = comm.World();

// shift root to rank 0
auto shifted_rank = ShiftLeft(rank, world, root);
// Send data to the root to preserve the topology. Alternative is to shift the rank, but
// it requires a all-to-all connection.
//
// Most of the use of broadcasting in XGBoost are short messages, this should be
// fine. Otherwise, we can implement a linear pipeline broadcast.
if (root != 0) {
auto rc = Success() << [&] {
return (rank == 0) ? comm.Chan(root)->RecvAll(data) : Success();
} << [&] {
return (rank == root) ? comm.Chan(0)->SendAll(data) : Success();
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return Fail("Broadcast failed to send data to root.", std::move(rc));
}
root = 0;
}

std::int32_t depth = std::ceil(std::log2(static_cast<double>(world))) - 1;

if (shifted_rank != 0) { // not root
auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); }
<< [&] { return comm.Chan(parent)->Block(); };
if (rank != 0) { // not root
auto parent = ParentRank(rank, depth);
auto rc = Success() << [&] {
return comm.Chan(parent)->RecvAll(data);
} << [&] {
return comm.Chan(parent)->Block();
};
if (!rc.OK()) {
return Fail("broadcast failed.", std::move(rc));
return Fail("Broadcast failed to send data to parent.", std::move(rc));
}
}

for (std::int32_t i = depth; i >= 0; --i) {
CHECK_GE((i + 1), 0); // weird clang-tidy error that i might be negative
if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) {
auto sft_peer = shifted_rank + (1 << i);
auto peer = ShiftRight(sft_peer, world, root);
if (rank % (1 << (i + 1)) == 0 && rank + (1 << i) < world) {
auto peer = rank + (1 << i);
CHECK_NE(peer, root);
auto rc = comm.Chan(peer)->SendAll(data);
if (!rc.OK()) {
return rc;
return Fail("Failed to seed to " + std::to_string(peer), std::move(rc));
}
}
}
Expand Down
Loading
Loading