diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index fd234bd2d744..994071f7ec18 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -112,6 +112,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ $(PKGROOT)/src/collective/loop.o \ + $(PKGROOT)/src/collective/topo.o \ $(PKGROOT)/src/collective/socket.o \ $(PKGROOT)/src/common/charconv.o \ $(PKGROOT)/src/common/column_matrix.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 924fbb6010c3..cb8ba1528510 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -112,6 +112,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ $(PKGROOT)/src/collective/loop.o \ + $(PKGROOT)/src/collective/topo.o \ $(PKGROOT)/src/collective/socket.o \ $(PKGROOT)/src/common/charconv.o \ $(PKGROOT)/src/common/column_matrix.o \ diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index a025edddd409..c368189919f5 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -3,6 +3,7 @@ */ #pragma once +#include // for max #include // errno, EINTR, EBADF #include // HOST_NAME_MAX #include // std::size_t @@ -539,7 +540,8 @@ class TCPSocket { /** * @brief Listen to incoming requests. Should be called after bind. */ - [[nodiscard]] Result Listen(std::int32_t backlog = 16) { + [[nodiscard]] Result Listen(std::int32_t backlog) { + backlog = std::max(backlog, 16); // Don't be too small. if (listen(handle_, backlog) != 0) { return system::FailWithCode("Failed to listen."); } diff --git a/src/collective/allgather.h b/src/collective/allgather.h index ca44c3916cc3..133d4fb84b66 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -13,6 +13,7 @@ #include "../common/type.h" // for EraseType #include "comm.h" // for Comm, Channel #include "comm_group.h" // for CommGroup +#include "topo.h" // for BootstrapNext, BootstrapPrev #include "xgboost/collective/result.h" // for Result #include "xgboost/linalg.h" // for MakeVec #include "xgboost/span.h" // for Span diff --git a/src/collective/broadcast.cc b/src/collective/broadcast.cc index e1ef60f86847..13a6eb75d268 100644 --- a/src/collective/broadcast.cc +++ b/src/collective/broadcast.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "broadcast.h" @@ -7,45 +7,13 @@ #include // for int32_t, int8_t #include // for move -#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32 + #include "comm.h" // for Comm #include "xgboost/collective/result.h" // for Result #include "xgboost/span.h" // for Span +#include "topo.h" // for ParentRank namespace xgboost::collective::cpu_impl { -namespace { -std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) { - std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff... - RBitField32 maskbits{common::Span{&mask, 1}}; - RBitField32 rankbits{ - common::Span{reinterpret_cast(&shifted_rank), 1}}; - // prepare for counting trailing zeros. - for (std::int32_t i = 0; i < depth + 1; ++i) { - if (rankbits.Check(i)) { - maskbits.Set(i); - } else { - maskbits.Clear(i); - } - } - - CHECK_NE(mask, 0); - auto k = TrailingZeroBits(mask); - auto shifted_parent = shifted_rank - (1 << k); - return shifted_parent; -} - -// Shift the root node to rank 0 -std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) { - auto shifted_rank = (rank + world - root) % world; - return shifted_rank; -} -// shift back to the original rank -std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) { - auto orig = (rank + root) % world; - return orig; -} -} // namespace - Result Broadcast(Comm const& comm, common::Span data, std::int32_t root) { // Binomial tree broadcast // * Wiki @@ -56,28 +24,47 @@ Result Broadcast(Comm const& comm, common::Span data, std::int32_t auto rank = comm.Rank(); auto world = comm.World(); - // shift root to rank 0 - auto shifted_rank = ShiftLeft(rank, world, root); + // Send data to the root to preserve the topology. Alternative is to shift the rank, but + // it requires a all-to-all connection. + // + // Most of the use of broadcasting in XGBoost are short messages, this should be + // fine. Otherwise, we can implement a linear pipeline broadcast. + if (root != 0) { + auto rc = Success() << [&] { + return (rank == 0) ? comm.Chan(root)->RecvAll(data) : Success(); + } << [&] { + return (rank == root) ? comm.Chan(0)->SendAll(data) : Success(); + } << [&] { + return comm.Block(); + }; + if (!rc.OK()) { + return Fail("Broadcast failed to send data to root.", std::move(rc)); + } + root = 0; + } + std::int32_t depth = std::ceil(std::log2(static_cast(world))) - 1; - if (shifted_rank != 0) { // not root - auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root); - auto rc = Success() << [&] { return comm.Chan(parent)->RecvAll(data); } - << [&] { return comm.Chan(parent)->Block(); }; + if (rank != 0) { // not root + auto parent = ParentRank(rank, depth); + auto rc = Success() << [&] { + return comm.Chan(parent)->RecvAll(data); + } << [&] { + return comm.Chan(parent)->Block(); + }; if (!rc.OK()) { - return Fail("broadcast failed.", std::move(rc)); + return Fail("Broadcast failed to send data to parent.", std::move(rc)); } } for (std::int32_t i = depth; i >= 0; --i) { CHECK_GE((i + 1), 0); // weird clang-tidy error that i might be negative - if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) { - auto sft_peer = shifted_rank + (1 << i); - auto peer = ShiftRight(sft_peer, world, root); + if (rank % (1 << (i + 1)) == 0 && rank + (1 << i) < world) { + auto peer = rank + (1 << i); CHECK_NE(peer, root); auto rc = comm.Chan(peer)->SendAll(data); if (!rc.OK()) { - return rc; + return Fail("Failed to seed to " + std::to_string(peer), std::move(rc)); } } } diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 32631442b88f..e64ed276f6d4 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -16,6 +16,7 @@ #endif // !defined(XGBOOST_USE_NCCL) #include "allgather.h" // for RingAllgather #include "protocol.h" // for kMagic +#include "topo.h" // for BootstrapNext #include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE #include "xgboost/collective/socket.h" // for TCPSocket #include "xgboost/json.h" // for Json, Object @@ -58,6 +59,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st this->Rank(), this->World()); } +// Connect ring and tree neighbors [[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport, proto::PeerInfo ninfo, std::chrono::seconds timeout, std::int32_t retry, @@ -80,10 +82,10 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st return prev->NonBlocking(true); }; if (!rc.OK()) { - return rc; + return Fail("Bootstrap failed to recv from ring prev.", std::move(rc)); } - // exchange host name and port + // Exchange host name and port std::vector buffer(HOST_NAME_MAX * comm.World(), 0); auto s_buffer = common::Span{buffer.data(), buffer.size()}; auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX); @@ -107,7 +109,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st rc = std::move(rc) << [&] { return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch); - } << [&] { return block(); }; + } << [&] { + return block(); + }; if (!rc.OK()) { return Fail("Failed to get host names from peers.", std::move(rc)); } @@ -118,7 +122,9 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st auto s_ports = common::Span{reinterpret_cast(peers_port.data()), peers_port.size() * sizeof(ninfo.port)}; return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch); - } << [&] { return block(); }; + } << [&] { + return block(); + }; if (!rc.OK()) { return Fail("Failed to get the port from peers.", std::move(rc)); } @@ -138,55 +144,94 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st std::vector>& workers = *out_workers; workers.resize(comm.World()); - - for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) { - auto const& peer = peers[r]; - auto worker = std::make_shared(); - rc = std::move(rc) - << [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); } - << [&] { return worker->RecvTimeout(timeout); }; - if (!rc.OK()) { - return rc; - } - - auto rank = comm.Rank(); - std::size_t n_bytes{0}; - auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes); - if (!rc.OK()) { - return rc; - } else if (n_bytes != sizeof(comm.Rank())) { - return Fail("Failed to send rank.", std::move(rc)); + workers[BootstrapNext(comm.Rank(), comm.World())] = next; + if (BootstrapNext(comm.Rank(), comm.World()) == BootstrapPrev(comm.Rank(), comm.World())) { + if (comm.Rank() == 0) { + if (comm.World() == 2) { + workers[BootstrapNext(comm.Rank(), comm.World())] = prev; + } else { + CHECK_EQ(comm.World(), 1); + } } - workers[r] = std::move(worker); + } else { + workers[BootstrapPrev(comm.Rank(), comm.World())] = prev; } - for (std::int32_t r = 0; r < comm.Rank(); ++r) { - auto peer = std::make_shared(); - rc = std::move(rc) << [&] { + /** + * Construct tree. + */ + // All workers connect to rank 0 so that we can always use rank 0 as broadcast root. + if (comm.Rank() == 0) { + for (std::int32_t i = 0; i < comm.World() - 3; ++i) { + auto worker = std::make_shared(); SockAddress addr; - return listener->Accept(peer.get(), &addr); - } << [&] { - return peer->RecvTimeout(timeout); - }; - if (!rc.OK()) { - return rc; + rc = listener->Accept(worker.get(), &addr); + if (!rc.OK()) { + return Fail("Failed to accept for rank 0.", std::move(rc)); + } + std::int32_t r{-1}; + std::size_t n_bytes{0}; + rc = worker->RecvAll(&r, sizeof(r), &n_bytes); + if (!rc.OK()) { + return Fail("Failed to recv rank.", std::move(rc)); + } + if (n_bytes != sizeof(r)) { + return Fail("Failed to recv rank due to size.", std::move(rc)); + } + workers[r] = worker; } - std::int32_t rank{-1}; - std::size_t n_bytes{0}; - auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes); - if (!rc.OK()) { - return rc; - } else if (n_bytes != sizeof(comm.Rank())) { - return Fail("Failed to recv rank."); + } else { + if (!workers[0]) { + auto worker = std::make_shared(); + rc = std::move(rc) << [&] { + return Connect(peers[0].host, peers[0].port, retry, timeout, worker.get()); + } << [&] { + auto rank = comm.Rank(); + std::size_t n_bytes = 0; + auto rc = worker->SendAll(&rank, sizeof(rank), &n_bytes); + if (n_bytes != sizeof(rank)) { + return Fail("Failed to send rank due to size.", std::move(rc)); + } + return rc; + }; + if (!rc.OK()) { + return Fail("Failed to connect to root.", std::move(rc)); + } + workers[0] = worker; + } + } + // Binomial tree connect + std::int32_t const kDepth = std::ceil(std::log2(static_cast(comm.World()))) - 1; + if (comm.Rank() != 0) { + auto prank = ParentRank(comm.Rank(), kDepth); + if (!workers[prank]) { // Skip if it's part of the ring. + auto parent = std::make_shared(); + SockAddress addr; + rc = listener->Accept(parent.get(), &addr); + if (!rc.OK()) { + return Fail("Failed to recv connection from tree parent.", std::move(rc)); + } + workers[prank] = parent; } - workers[rank] = std::move(peer); } - for (std::int32_t r = 0; r < comm.World(); ++r) { - if (r == comm.Rank()) { - continue; + for (std::int32_t i = kDepth; i >= 0; --i) { + if (comm.Rank() % (1 << (i + 1)) == 0 && comm.Rank() + (1 << i) < comm.World()) { + auto peer = comm.Rank() + (1 << i); + if (workers[peer]) { // skip if it's part of the ring. + continue; + } + auto worker = std::make_shared(); + rc = std::move(rc) << [&] { + return Connect(peers[peer].host, peers[peer].port, retry, timeout, worker.get()); + } << [&] { + return worker->RecvTimeout(timeout); + }; + if (!rc.OK()) { + return Fail("Failed to connect to tree neighbor", std::move(rc)); + } + workers[peer] = worker; } - CHECK(workers[r]); } return Success(); @@ -230,6 +275,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { [[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry, std::string task_id) { + common::Timer t; + t.Start(); + TCPSocket tracker; std::int32_t world{-1}; auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(), @@ -243,11 +291,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { // Start command TCPSocket listener = TCPSocket::Create(tracker.Domain()); std::int32_t lport{0}; - rc = std::move(rc) << [&] { - return listener.BindHost(&lport); - } << [&] { - return listener.Listen(); - }; + rc = listener.BindHost(&lport); if (!rc.OK()) { return rc; } @@ -259,8 +303,8 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { rc = std::move(rc) << [&] { return error_sock->BindHost(&eport); } << [&] { - return error_sock->Listen(); - }; + return error_sock->Listen(4); + };; if (!rc.OK()) { return rc; } @@ -304,8 +348,15 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { error_worker_.detach(); proto::Start start; - rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); } - << [&] { return start.WorkerRecv(&tracker, &world); }; + rc = std::move(rc) << [&] { + return start.WorkerSend(lport, &tracker, eport); + } << [&] { + return start.WorkerRecv(&tracker, &world); + } << [&] { + return listener.Listen(world); + } << [&] { + return start.WorkerFinish(&tracker); + }; if (!rc.OK()) { return rc; } @@ -347,6 +398,8 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { this->channels_.emplace_back(std::make_shared(*this, w)); } + t.Stop(); + LOG(DEBUG) << "Bootstrap took:" << t.ElapsedSeconds() << " secs."; LOG(CONSOLE) << InitLog(task_id_, rank_); return rc; } diff --git a/src/collective/comm.h b/src/collective/comm.h index 72fec2e816e9..d194b17abb54 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -20,21 +20,10 @@ namespace xgboost::collective { -inline constexpr std::int64_t DefaultTimeoutSec() { return 60 * 30; } // 30min -inline constexpr std::int32_t DefaultRetry() { return 3; } +constexpr std::int64_t DefaultTimeoutSec() { return 60 * 30; } // 30min +constexpr std::int32_t DefaultRetry() { return 3; } -// indexing into the ring -inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) { - auto nrank = (r + world + 1) % world; - return nrank; -} - -inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) { - auto nrank = (r + world - 1) % world; - return nrank; -} - -inline StringView DefaultNcclName() { return "libnccl.so.2"; } +constexpr StringView DefaultNcclName() { return "libnccl.so.2"; } class Channel; class Coll; diff --git a/src/collective/protocol.h b/src/collective/protocol.h index 2222594033f3..d87fbeace1a0 100644 --- a/src/collective/protocol.h +++ b/src/collective/protocol.h @@ -130,6 +130,17 @@ class Start { } return Success(); } + // Ensure the worker has started to listen before bootstrapping the coll group. + [[nodiscard]] Result WorkerFinish(TCPSocket* tracker) { + Json jcmd{Object{}}; + jcmd["done"] = true; + auto scmd = Json::Dump(jcmd); + auto n_bytes = tracker->Send(scmd); + if (n_bytes != scmd.size()) { + return Fail("Failed to send init command from worker."); + } + return Success(); + } [[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const { std::string scmd; auto rc = tracker->Recv(&scmd); @@ -158,6 +169,14 @@ class Start { *eport = get(jcmd["error_port"]); return TrackerSend(world, p_sock); } + + [[nodiscard]] Result TrackerFinish(Json jcmd) { + auto it = get(jcmd).find("done"); + if (IsA(it->second) && get(it->second)) { + return Success(); + } + return Fail("Failed to start."); + } }; // Protocol for communicating with the tracker for printing message. diff --git a/src/collective/topo.cc b/src/collective/topo.cc new file mode 100644 index 000000000000..385e2abec74e --- /dev/null +++ b/src/collective/topo.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2023-2024, XGBoost Contributors + */ +#include "topo.h" + +#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32 +namespace xgboost::collective { +std::int32_t ParentRank(std::int32_t rank, std::int32_t depth) { + std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff... + RBitField32 maskbits{common::Span{&mask, 1}}; + RBitField32 rankbits{common::Span{reinterpret_cast(&rank), 1}}; + // prepare for counting trailing zeros. + for (std::int32_t i = 0; i < depth + 1; ++i) { + if (rankbits.Check(i)) { + maskbits.Set(i); + } else { + maskbits.Clear(i); + } + } + + CHECK_NE(mask, 0); + auto k = TrailingZeroBits(mask); + auto parent = rank - (1 << k); + return parent; +} +} // namespace xgboost::collective diff --git a/src/collective/topo.h b/src/collective/topo.h new file mode 100644 index 000000000000..c78896848e21 --- /dev/null +++ b/src/collective/topo.h @@ -0,0 +1,19 @@ +/** + * Copyright 2023-2024, XGBoost Contributors + */ +#pragma once +#include // for int32_t + +namespace xgboost::collective { +inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) { + auto nrank = (r + world + 1) % world; + return nrank; +} + +inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) { + auto nrank = (r + world - 1) % world; + return nrank; +} + +std::int32_t ParentRank(std::int32_t rank, std::int32_t depth); +} // namespace xgboost::collective diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index b1081fe8e789..480ffff4e5a5 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -2,6 +2,8 @@ * Copyright 2023-2024, XGBoost Contributors */ +#include "tracker.h" + #if defined(__unix__) || defined(__APPLE__) #include // gethostbyname #include // socket, AF_INET6, AF_INET, connect, getsockname @@ -23,10 +25,10 @@ #include // for move, forward #include "../common/json_utils.h" -#include "../common/threading_utils.h" // for NameThread -#include "comm.h" -#include "protocol.h" // for kMagic, PeerInfo -#include "tracker.h" +#include "../common/threading_utils.h" // for NameThread +#include "../common/timer.h" // for Timer +#include "protocol.h" // for kMagic, PeerInfo +#include "topo.h" // for BootstrapNext #include "xgboost/collective/poll_utils.h" // for PollHelper #include "xgboost/collective/result.h" // for Result, Fail, Success #include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ... @@ -89,7 +91,15 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA } << [&] { if (cmd_ == proto::CMD::kStart) { proto::Start start; - return start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_); + std::string cmd1; + return Success() << [&] { + return start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_); + } << [&] { + return sock_.Recv(&cmd1); + } << [&] { + auto jcmd1 = Json::Load(StringView{cmd1}); + return start.TrackerFinish(jcmd1); + }; } else if (cmd_ == proto::CMD::kPrint) { proto::Print print; return print.TrackerHandle(jcmd, &msg_); @@ -123,7 +133,8 @@ RabitTracker::RabitTracker(Json const& config) : Tracker{config} { listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6); return listener_.Bind(host_, &this->port_); } << [&] { - return listener_.Listen(); + CHECK_GT(this->n_workers_, 0); + return listener_.Listen(this->n_workers_); }; SafeColl(rc); } @@ -155,6 +166,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { for (auto const& w : workers) { worker_error_handles_.emplace_back(w.Host(), w.ErrorPort()); } + LOG(CONSOLE) << "[tracker]: Bootstrap " << workers.size() << " workers."; return Success(); } diff --git a/src/learner.cc b/src/learner.cc index 542bf1dc6279..67f07fe91e08 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -23,13 +23,10 @@ #include // for numeric_limits #include // for allocator, unique_ptr, shared_ptr, operator== #include // for mutex, lock_guard -#include // for set #include // for operator<<, basic_ostream, basic_ostream::opera... #include // for stack #include // for basic_string, char_traits, operator<, string #include // for errc -#include // for get -#include // for operator!=, unordered_map #include // for pair, as_const, move, swap #include // for vector diff --git a/tests/cpp/collective/test_broadcast.cc b/tests/cpp/collective/test_broadcast.cc index 1b1d73428be1..6bc6e4461af4 100644 --- a/tests/cpp/collective/test_broadcast.cc +++ b/tests/cpp/collective/test_broadcast.cc @@ -40,11 +40,16 @@ class BroadcastTest : public SocketTest {}; } // namespace TEST_F(BroadcastTest, Basic) { - std::int32_t n_workers = std::min(2u, std::thread::hardware_concurrency()); - TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, - std::int32_t r) { - Worker worker{host, port, timeout, n_workers, r}; - worker.Run(); - }); + auto test_with = [](std::int32_t n_workers) { + TestDistributed(n_workers, [=](std::string host, std::int32_t port, + std::chrono::seconds timeout, std::int32_t r) { + Worker worker{host, port, timeout, n_workers, r}; + worker.Run(); + }); + }; + for (std::uint32_t n_workers = 1u; n_workers < 4u; ++n_workers) { + n_workers = std::min(n_workers, std::thread::hardware_concurrency()); + test_with(n_workers); + } } } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc index 622b350aaae8..577174908d5f 100644 --- a/tests/cpp/collective/test_loop.cc +++ b/tests/cpp/collective/test_loop.cc @@ -32,7 +32,7 @@ class LoopTest : public ::testing::Test { auto rc = Success() << [&] { return pair_.first.BindHost(&port); } << [&] { - return pair_.first.Listen(); + return pair_.first.Listen(16); }; SafeColl(rc); diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index 8e455d100f0d..9c29a442650f 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -25,7 +25,7 @@ TEST_F(SocketTest, Basic) { auto rc = Success() << [&] { return server.BindHost(&port); } << [&] { - return server.Listen(); + return server.Listen(16); }; SafeColl(rc); diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index 4f6dfc1ff6cc..c434d5855e10 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -49,7 +49,7 @@ class WorkerForTest { void LimitSockBuf(std::int32_t n_bytes) { for (std::int32_t i = 0; i < comm_.World(); ++i) { - if (i != comm_.Rank()) { + if (comm_.Chan(i)->Socket()) { ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking()); SafeColl(comm_.Chan(i)->Socket()->SetBufSize(n_bytes)); SafeColl(comm_.Chan(i)->Socket()->SetNoDelay());