diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 235ef1a4..31c9151f 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -134,9 +134,9 @@ jobs: if [[ "$MODULE" == "xoscar" ]]; then pytest --timeout=1500 \ -W ignore::PendingDeprecationWarning \ - --cov-config=setup.cfg --cov-report=xml --cov=xoscar xoscar + --cov-config=setup.cfg --cov-report=xml --cov=xoscar xoscar --capture=no else - pytest -m cuda --cov-config=setup.cfg --cov-report=xml --cov=xoscar + pytest -m cuda --cov-config=setup.cfg --cov-report=xml --cov=xoscar --capture=no fi working-directory: ./python diff --git a/.gitignore b/.gitignore index 1547d4ef..a6c7e7d3 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,9 @@ CMakeFiles CMakeCache.txt *.cmake Makefile + +#config.h +cpp/collective/gloo/include/config.h + +#filestore +python/xoscar/collective/tests/collective \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index fd30701d..45f06e7d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "third_party/pybind11"] path = third_party/pybind11 url = https://github.com/pybind/pybind11.git +[submodule "third_party/gloo"] + path = third_party/gloo + url = https://github.com/facebookincubator/gloo.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fde062e..93f922a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,8 +11,32 @@ endif() set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + option(USE_LIBUV "Build tcp transport on linux" OFF) +else() + option(USE_LIBUV "Build libuv transport on others" ON) +endif() + include_directories(${CMAKE_SOURCE_DIR}) +#find python3 include dir +execute_process(COMMAND python -c "import sysconfig; print(sysconfig.get_path('include'))" + OUTPUT_VARIABLE PYTHON_INCLUDE_PATH) + +# Set include directories +include_directories(${PYTHON_INCLUDE_PATH}) add_subdirectory(third_party/fmt) 
add_subdirectory(third_party/pybind11) +add_subdirectory(third_party/gloo) + +# set c++11 for gloo +set_target_properties( + gloo + PROPERTIES CXX_STANDARD 11 + CXX_STANDARD_REQUIRED ON + CXX_EXTENSIONS OFF) + +# copy config.h to cpp/gloo/include +file(COPY python/${BUILD_TMP_DIR}/third_party/gloo/gloo/config.h + DESTINATION ${CMAKE_SOURCE_DIR}/cpp/collective/gloo/include) add_subdirectory(cpp) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d421cfa5..2c9740b9 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -4,13 +4,18 @@ project(XoscarCollective) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") if(NOT DEFINED LIBRARY_OUTPUT_DIRECTORY) - set(LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/python/xoscar/collective/rendezvous) + set(LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/python/xoscar/collective) endif() include_directories(${CMAKE_SOURCE_DIR}/cpp/collective/rendezvous/include) +include_directories(${CMAKE_SOURCE_DIR}/cpp/collective/gloo/include) +include_directories(../third_party/gloo) +include_directories(../third_party/fmt/include) add_subdirectory(collective/rendezvous) +add_subdirectory(collective/gloo) -pybind11_add_module(xoscar_store collective/rendezvous/src/bind_tcp_store.cpp) -target_link_libraries(xoscar_store PRIVATE StoreLib fmt::fmt) -set_target_properties(xoscar_store PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_DIRECTORY}) +pybind11_add_module(xoscar_pygloo collective/gloo/main.cc) +target_link_libraries(xoscar_pygloo PRIVATE GlooLib gloo StoreLib fmt::fmt) +set_target_properties(xoscar_pygloo PROPERTIES LIBRARY_OUTPUT_DIRECTORY + ${LIBRARY_OUTPUT_DIRECTORY}) diff --git a/cpp/collective/gloo/CMakeLists.txt b/cpp/collective/gloo/CMakeLists.txt new file mode 100644 index 00000000..4cdbe019 --- /dev/null +++ b/cpp/collective/gloo/CMakeLists.txt @@ -0,0 +1,32 @@ +cmake_minimum_required(VERSION 3.11...3.21) + +project( + XoscarGloo + VERSION 0.0.1 + LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) + 
+include_directories(include) +include_directories(../rendezvous/include) +include_directories(../../../third_party/pybind11/include) + +add_library( + GlooLib + include/collective.h + include/rendezvous.h + include/transport.h + include/config.h + src/allgather.cc + src/allreduce.cc + src/barrier.cc + src/broadcast.cc + src/gather.cc + src/recv.cc + src/reduce_scatter.cc + src/reduce.cc + src/rendezvous.cc + src/scatter.cc + src/send.cc + src/transport.cc + src/all_to_all.cc) diff --git a/cpp/collective/gloo/LICENSE b/cpp/collective/gloo/LICENSE new file mode 100644 index 00000000..c944a02b --- /dev/null +++ b/cpp/collective/gloo/LICENSE @@ -0,0 +1,207 @@ +Apache 2.0 License with Code from pygloo Repository + +This software includes code primarily sourced from the pygloo repository, which is governed by its own original license. +The original license for the pygloo repository, as well as the additional terms below, apply to the code in this software. + +----------------------------------------------------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/cpp/collective/gloo/include/collective.h b/cpp/collective/gloo/include/collective.h new file mode 100644 index 00000000..be494190 --- /dev/null +++ b/cpp/collective/gloo/include/collective.h @@ -0,0 +1,171 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace xoscar { + +enum class ReduceOp : std::uint8_t { + SUM = 0, + PRODUCT, + MIN, + MAX, + BAND, // Bitwise AND + BOR, // Bitwise OR + BXOR, // Bitwise XOR + UNUSED, +}; + +typedef void (*ReduceFunc)(void *, const void *, const void *, size_t); + +template +ReduceFunc toFunction(const ReduceOp &r) { + switch (r) { + case ReduceOp::SUM: + return ReduceFunc(&gloo::sum); + case ReduceOp::PRODUCT: + return ReduceFunc(&gloo::product); + case ReduceOp::MIN: + return ReduceFunc(&gloo::min); + case ReduceOp::MAX: + return ReduceFunc(&gloo::max); + case ReduceOp::BAND: + throw std::runtime_error( + "Cannot use ReduceOp.BAND with non-integral dtype"); + break; + case ReduceOp::BOR: + throw std::runtime_error( + "Cannot use ReduceOp.BOR with non-integral dtype"); + break; + case ReduceOp::BXOR: + throw std::runtime_error( + "Cannot use ReduceOp.BXOR with non-integral dtype"); + break; + case ReduceOp::UNUSED: + break; + } + + throw std::runtime_error("Unhandled ReduceOp"); +} + +enum class glooDataType_t : std::uint8_t { + glooInt8 = 0, + glooUint8, + glooInt32, + glooUint32, + glooInt64, + glooUint64, + glooFloat16, + glooFloat32, + glooFloat64, +}; + +void allreduce_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + ReduceOp reduceop = ReduceOp::SUM, + gloo::AllreduceOptions::Algorithm algorithm + = gloo::AllreduceOptions::Algorithm::RING, + uint32_t tag = 0); + +void allgather_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + uint32_t tag = 0); + +void allgatherv_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + uint32_t tag = 0); + +void reduce_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + 
size_t size, + glooDataType_t datatype, + ReduceOp reduceop = xoscar::ReduceOp::SUM, + int root = 0, + uint32_t tag = 0); + +void scatter_wrapper(const std::shared_ptr &context, + std::vector sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + int root = 0, + uint32_t tag = 0); + +void gather_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + int root = 0, + uint32_t tag = 0); + +void send_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + size_t size, + glooDataType_t datatype, + int peer, + uint32_t tag = 0); + +void recv_wrapper(const std::shared_ptr &context, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + int peer, + uint32_t tag = 0); + +void broadcast_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + int root = 0, + uint32_t tag = 0); + +void reduce_scatter_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + std::vector recvElems, + glooDataType_t datatype, + ReduceOp reduceop = xoscar::ReduceOp::SUM); + +void barrier(const std::shared_ptr &context, uint32_t tag = 0); + +void all_to_all_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + uint32_t tag); +} // namespace xoscar diff --git a/cpp/collective/gloo/include/rendezvous.h b/cpp/collective/gloo/include/rendezvous.h new file mode 100644 index 00000000..a24bf203 --- /dev/null +++ b/cpp/collective/gloo/include/rendezvous.h @@ -0,0 +1,24 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include + +namespace xoscar { +namespace rendezvous { + +void def_rendezvous_module(pybind11::module &m); +} // namespace rendezvous +} // namespace xoscar diff --git a/cpp/collective/gloo/include/transport.h b/cpp/collective/gloo/include/transport.h new file mode 100644 index 00000000..629d1e22 --- /dev/null +++ b/cpp/collective/gloo/include/transport.h @@ -0,0 +1,110 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#if GLOO_HAVE_TRANSPORT_TCP + +# include +# include +# include +# include +# include +# include +# include +# include + +#endif + +#if GLOO_HAVE_TRANSPORT_UV + +# include +# include +# include +# include +# include + +#endif + +#if !GLOO_HAVE_TRANSPORT_UV +# if !GLOO_HAVE_TRANSPORT_UV +# include +# include +# include +# include +# include +# endif +#endif + +namespace xoscar { +namespace transport { +class PyDevice : public gloo::transport::Device { +public: + using gloo::transport::Device::Device; + + std::string str() const override { + PYBIND11_OVERRIDE_PURE( + std::string, // Return type + gloo::transport::Device, // Parent class + str, /* Name of function in C++ (must match Python name) */ + /* Argument(s) */); + } + + const std::string &getPCIBusID() const override { + PYBIND11_OVERRIDE_PURE( + const std::string &, /* Return type */ + gloo::transport::Device, /* Parent class */ + getPCIBusID, /* Name of function in C++ (must match Python name) */ + /* Argument(s) */); + } + + int getInterfaceSpeed() const override { + PYBIND11_OVERRIDE(int, /* Return type */ + gloo::transport::Device, // Parent class + getInterfaceSpeed, // Name of function in C++ (must + // match Python name) + /* Argument(s) */); + } + + bool hasGPUDirect() const override { + PYBIND11_OVERRIDE( + bool, /* Return type */ + gloo::transport::Device, /* Parent class */ + hasGPUDirect, /* Name of function in C++ (must match Python name) */ + /* Argument(s) */); + } + + std::shared_ptr createContext(int rank, + int size) override { + PYBIND11_OVERRIDE_PURE( + std::shared_ptr, /* Return type */ + gloo::transport::Device, /* Parent class */ + createContext, // Name of function in C++ (must match Python name) + rank, + size /* Argument(s) */); + } +}; + +void def_transport_module(pybind11::module &m); +void def_transport_tcp_module(pybind11::module &m); +void def_transport_uv_module(pybind11::module &m); +} 
// namespace transport +} // namespace xoscar diff --git a/cpp/collective/gloo/main.cc b/cpp/collective/gloo/main.cc new file mode 100644 index 00000000..3c26dba4 --- /dev/null +++ b/cpp/collective/gloo/main.cc @@ -0,0 +1,205 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace xoscar { +bool transport_tcp_available() { return GLOO_HAVE_TRANSPORT_TCP; } + +bool transport_uv_available() { return GLOO_HAVE_TRANSPORT_UV; } +} // namespace xoscar +PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MODULE(xoscar_pygloo, m) { + m.doc() = "binding gloo from c to python"; // optional module docstring + + m.def("transport_tcp_available", + &xoscar::transport_tcp_available, + "transport_tcp_available"); + + m.def("transport_uv_available", + &xoscar::transport_uv_available, + "transport_uv_available"); + pybind11::bind_vector>(m, "StringVector"); + + pybind11::enum_(m, "ReduceOp", pybind11::arithmetic()) + .value("SUM", xoscar::ReduceOp::SUM) + .value("PRODUCT", xoscar::ReduceOp::PRODUCT) + .value("MIN", xoscar::ReduceOp::MIN) + .value("MAX", xoscar::ReduceOp::MAX) + .value("BAND", xoscar::ReduceOp::BAND) + .value("BOR", xoscar::ReduceOp::BOR) + .value("BXOR", xoscar::ReduceOp::BXOR) + .value("UNUSED", xoscar::ReduceOp::UNUSED) + .export_values(); + + pybind11::enum_( + m, "AllreduceAlgorithm", pybind11::arithmetic()) + .value("SUM", + 
gloo::detail::AllreduceOptionsImpl::Algorithm::UNSPECIFIED) + .value("RING", gloo::detail::AllreduceOptionsImpl::Algorithm::RING) + .value("BCUBE", gloo::detail::AllreduceOptionsImpl::Algorithm::BCUBE) + .export_values(); + + pybind11::enum_( + m, "GlooDataType_t", pybind11::arithmetic()) + .value("glooInt8", xoscar::glooDataType_t::glooInt8) + .value("glooUint8", xoscar::glooDataType_t::glooUint8) + .value("glooInt32", xoscar::glooDataType_t::glooInt32) + .value("glooUint32", xoscar::glooDataType_t::glooUint32) + .value("glooInt64", xoscar::glooDataType_t::glooInt64) + .value("glooUint64", xoscar::glooDataType_t::glooUint64) + .value("glooFloat16", xoscar::glooDataType_t::glooFloat16) + .value("glooFloat32", xoscar::glooDataType_t::glooFloat32) + .value("glooFloat64", xoscar::glooDataType_t::glooFloat64) + .export_values(); + + m.def("allreduce", + &xoscar::allreduce_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("reduceop") = xoscar::ReduceOp::SUM, + pybind11::arg("algorithm") = gloo::AllreduceOptions::Algorithm::RING, + pybind11::arg("tag") = 0); + + m.def("allgather", + &xoscar::allgather_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("tag") = 0); + m.def("allgatherv", + &xoscar::allgatherv_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("tag") = 0); + + m.def("reduce", + &xoscar::reduce_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = 
nullptr, + pybind11::arg("reduceop") = xoscar::ReduceOp::SUM, + pybind11::arg("root") = 0, + pybind11::arg("tag") = 0); + + m.def("scatter", + &xoscar::scatter_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("root") = 0, + pybind11::arg("tag") = 0); + + m.def("gather", + &xoscar::gather_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("root") = 0, + pybind11::arg("tag") = 0); + + m.def("send", + &xoscar::send_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("peer") = nullptr, + pybind11::arg("tag") = 0); + m.def("recv", + &xoscar::recv_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("peer") = nullptr, + pybind11::arg("tag") = 0); + + m.def("broadcast", + &xoscar::broadcast_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("root") = 0, + pybind11::arg("tag") = 0); +#ifdef __linux__ + m.def("reduce_scatter", + &xoscar::reduce_scatter_wrapper, + pybind11::arg("context") = nullptr, + pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("recvElems") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("reduceop") = xoscar::ReduceOp::SUM); +#endif + m.def("all_to_all", + &xoscar::all_to_all_wrapper, + pybind11::arg("context") = nullptr, + 
pybind11::arg("sendbuf") = nullptr, + pybind11::arg("recvbuf") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("datatype") = nullptr, + pybind11::arg("tag") = 0); + + m.def("barrier", + &xoscar::barrier, + pybind11::arg("context") = nullptr, + pybind11::arg("tag") = 0); + + pybind11::class_>(m, + "Context") + .def(pybind11::init(), + pybind11::arg("rank") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("base") = 2) + .def("getDevice", &gloo::Context::getDevice) + .def_readonly("rank", &gloo::Context::rank) + .def_readonly("size", &gloo::Context::size) + .def_readwrite("base", &gloo::Context::base) + // .def("getPair", &gloo::Context::getPair) + .def("createUnboundBuffer", &gloo::Context::createUnboundBuffer) + .def("nextSlot", &gloo::Context::nextSlot) + .def("closeConnections", &gloo::Context::closeConnections) + .def("setTimeout", &gloo::Context::setTimeout) + .def("getTimeout", &gloo::Context::getTimeout); + + xoscar::transport::def_transport_module(m); + xoscar::rendezvous::def_rendezvous_module(m); +} diff --git a/cpp/collective/gloo/src/all_to_all.cc b/cpp/collective/gloo/src/all_to_all.cc new file mode 100644 index 00000000..57f3d047 --- /dev/null +++ b/cpp/collective/gloo/src/all_to_all.cc @@ -0,0 +1,79 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include + +namespace xoscar { + +template +void all_to_all(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + uint32_t tag) { + T *input_ptr = reinterpret_cast(sendbuf); + T *output_ptr = reinterpret_cast(recvbuf); + + // Configure AlltoallOptions struct and call alltoall function + gloo::AlltoallOptions opts_(context); + opts_.setInput(input_ptr, size); + opts_.setOutput(output_ptr, size); + opts_.setTag(tag); + + gloo::alltoall(opts_); +} + +void all_to_all_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint8: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooInt32: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint32: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooInt64: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint64: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooFloat16: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooFloat32: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooFloat64: + all_to_all(context, sendbuf, recvbuf, size, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} + +} // namespace xoscar diff --git a/cpp/collective/gloo/src/allgather.cc b/cpp/collective/gloo/src/allgather.cc new file mode 100644 index 00000000..8eec4d96 --- /dev/null +++ b/cpp/collective/gloo/src/allgather.cc @@ -0,0 +1,135 @@ +/* Copyright 2022-2023 XProbe Inc. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +namespace xoscar { + /* allgather: casts the sendbuf/recvbuf addresses to T pointers and configures a gloo::AllgatherOptions with input count `size` and output count size * context->size, sets `tag`, then runs gloo::allgather. */ +template +void allgather(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + uint32_t tag) { + T *input_ptr = reinterpret_cast(sendbuf); + T *output_ptr = reinterpret_cast(recvbuf); + + // Configure AllgatherOptions struct and call allgather function + gloo::AllgatherOptions opts_(context); + opts_.setInput(input_ptr, size); + opts_.setOutput(output_ptr, size * context->size); + opts_.setTag(tag); + + gloo::allgather(opts_); +} + /* allgather_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards to allgather, anything else throws std::runtime_error. */ +void allgather_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + allgather(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint8: + allgather(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooInt32: + allgather(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint32: + allgather(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooInt64: + allgather(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint64: + allgather(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooFloat16: + allgather(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooFloat32: + allgather(context, sendbuf, recvbuf, size, 
tag); + break; + case glooDataType_t::glooFloat64: + allgather(context, sendbuf, recvbuf, size, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} + /* allgatherv: performs the same steps as allgather in this file (one `size` for every rank, output count size * context->size, gloo::allgather). NOTE(review): no per-rank counts are accepted, so this is not yet a variable-length gather - confirm intent. */ +template +void allgatherv(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + uint32_t tag) { + T *input_ptr = reinterpret_cast(sendbuf); + T *output_ptr = reinterpret_cast(recvbuf); + + // Configure AllgatherOptions struct and call allgather function + gloo::AllgatherOptions opts_(context); + opts_.setInput(input_ptr, size); + opts_.setOutput(output_ptr, size * context->size); + opts_.setTag(tag); + + gloo::allgather(opts_); +} + /* allgatherv_wrapper: switch over `datatype`; each handled value forwards to allgatherv so that the allgatherv template is the one actually exercised; anything else throws std::runtime_error. */ +void allgatherv_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint8: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooInt32: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint32: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooInt64: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooUint64: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooFloat16: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooFloat32: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + case glooDataType_t::glooFloat64: + allgatherv(context, sendbuf, recvbuf, size, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/allreduce.cc b/cpp/collective/gloo/src/allreduce.cc new file mode 100644 index 00000000..ba8a16c9 --- /dev/null +++ b/cpp/collective/gloo/src/allreduce.cc @@ -0,0 +1,93 @@ +/* Copyright 
2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +namespace xoscar { + /* allreduce: wraps the sendbuf/recvbuf addresses as single-element pointer vectors, passes them to setInputs/setOutputs of a gloo::AllreduceOptions with count `size`, sets `algorithm`, the reduce function returned by toFunction(reduceop) and `tag`, then runs gloo::allreduce. */ +template +void allreduce(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + ReduceOp reduceop, + gloo::AllreduceOptions::Algorithm algorithm, + uint32_t tag) { + std::vector input_ptr{reinterpret_cast(sendbuf)}; + std::vector output_ptr{reinterpret_cast(recvbuf)}; + + // Configure AllreduceOptions struct and call allreduce function + gloo::AllreduceOptions opts_(context); + opts_.setInputs(input_ptr, size); + opts_.setOutputs(output_ptr, size); + opts_.setAlgorithm(algorithm); + gloo::ReduceOptions::Func fn = toFunction(reduceop); + opts_.setReduceFunction(fn); + opts_.setTag(tag); + + gloo::allreduce(opts_); +} + /* allreduce_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards all arguments to allreduce, anything else throws std::runtime_error. */ +void allreduce_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + ReduceOp reduceop, + gloo::AllreduceOptions::Algorithm algorithm, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + case glooDataType_t::glooUint8: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + case glooDataType_t::glooInt32: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + case glooDataType_t::glooUint32: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + 
case glooDataType_t::glooInt64: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + case glooDataType_t::glooUint64: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + case glooDataType_t::glooFloat16: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + case glooDataType_t::glooFloat32: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + case glooDataType_t::glooFloat64: + allreduce( + context, sendbuf, recvbuf, size, reduceop, algorithm, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/barrier.cc b/cpp/collective/gloo/src/barrier.cc new file mode 100644 index 00000000..1f65c63d --- /dev/null +++ b/cpp/collective/gloo/src/barrier.cc @@ -0,0 +1,27 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +namespace xoscar { + /* barrier: builds a gloo::BarrierOptions for `context`, sets `tag`, and runs gloo::barrier. */ +void barrier(const std::shared_ptr &context, uint32_t tag) { + gloo::BarrierOptions opts_(context); + + opts_.setTag(tag); + + gloo::barrier(opts_); +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/broadcast.cc b/cpp/collective/gloo/src/broadcast.cc new file mode 100644 index 00000000..649f19f8 --- /dev/null +++ b/cpp/collective/gloo/src/broadcast.cc @@ -0,0 +1,84 @@ +/* Copyright 2022-2023 XProbe Inc. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +namespace xoscar { + /* broadcast: sendbuf is registered as the input only when context->rank == root; recvbuf is registered as the output on every rank with count `size`; root and tag are then set and gloo::broadcast runs. */ +template +void broadcast(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + int root, + uint32_t tag) { + // Configure BroadcastOptions struct and call broadcast function + gloo::BroadcastOptions opts_(context); + + if (context->rank == root) { + T *input_ptr = reinterpret_cast(sendbuf); + opts_.setInput(input_ptr, size); + } + T *output_ptr = reinterpret_cast(recvbuf); + opts_.setOutput(output_ptr, size); + + opts_.setRoot(root); + opts_.setTag(tag); + + gloo::broadcast(opts_); +} + /* broadcast_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards all arguments to broadcast, anything else throws std::runtime_error. */ +void broadcast_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + int root, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + broadcast(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint8: + broadcast(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooInt32: + broadcast(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint32: + broadcast(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooInt64: + broadcast(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint64: + broadcast(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooFloat16: + broadcast( + context, sendbuf, recvbuf, size, 
root, tag); + break; + case glooDataType_t::glooFloat32: + broadcast(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooFloat64: + broadcast(context, sendbuf, recvbuf, size, root, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/gather.cc b/cpp/collective/gloo/src/gather.cc new file mode 100644 index 00000000..de32d26f --- /dev/null +++ b/cpp/collective/gloo/src/gather.cc @@ -0,0 +1,82 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include + +namespace xoscar { + /* gather: every rank registers sendbuf as the input with count `size`; only the rank equal to `root` registers recvbuf as the output, with count context->size * size; root and tag are then set and gloo::gather runs. */ +template +void gather(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + int root, + uint32_t tag) { + // Configure GatherOptions struct + gloo::GatherOptions opts_(context); + + T *input_ptr = reinterpret_cast(sendbuf); + opts_.setInput(input_ptr, size); + + if (root == context->rank) { + T *output_ptr = reinterpret_cast(recvbuf); + opts_.setOutput(output_ptr, context->size * size); + } + opts_.setRoot(root); + opts_.setTag(tag); + + gloo::gather(opts_); +} + /* gather_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards all arguments to gather, anything else throws std::runtime_error. */ +void gather_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + int root, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint8: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooInt32: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint32: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooInt64: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint64: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooFloat16: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooFloat32: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooFloat64: + gather(context, sendbuf, recvbuf, size, root, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/recv.cc b/cpp/collective/gloo/src/recv.cc new file mode 100644 index 00000000..8e999e04 --- /dev/null +++ b/cpp/collective/gloo/src/recv.cc @@ -0,0 +1,78 @@ +/* Copyright 2022-2023 XProbe Inc. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +namespace xoscar { + /* recv: throws std::runtime_error when `peer` equals context->rank; otherwise creates an unbound buffer over recvbuf of size * sizeof(T) bytes, posts a receive from `peer` on the slot built from prefix 0x09 and `tag`, and waits using context->getTimeout(). */ +template +void recv(const std::shared_ptr &context, + intptr_t recvbuf, + size_t size, + int peer, + uint32_t tag) { + if (context->rank == peer) + throw std::runtime_error( + "peer equals to current rank. Please specify other peer values."); + + auto outputBuffer = context->createUnboundBuffer( + reinterpret_cast(recvbuf), size * sizeof(T)); + + constexpr uint8_t kSendRecvSlotPrefix = 0x09; + gloo::Slot slot = gloo::Slot::build(kSendRecvSlotPrefix, tag); + + outputBuffer->recv(peer, slot); + outputBuffer->waitRecv(context->getTimeout()); +} + /* recv_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards all arguments to recv, anything else throws std::runtime_error. */ +void recv_wrapper(const std::shared_ptr &context, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + int peer, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + recv(context, recvbuf, size, peer, tag); + break; + case glooDataType_t::glooUint8: + recv(context, recvbuf, size, peer, tag); + break; + case glooDataType_t::glooInt32: + recv(context, recvbuf, size, peer, tag); + break; + case glooDataType_t::glooUint32: + recv(context, recvbuf, size, peer, tag); + break; + case glooDataType_t::glooInt64: + recv(context, recvbuf, size, peer, tag); + break; + case glooDataType_t::glooUint64: + recv(context, recvbuf, size, peer, tag); + break; + case glooDataType_t::glooFloat16: + recv(context, recvbuf, size, peer, tag); + break; + case glooDataType_t::glooFloat32: + recv(context, recvbuf, size, peer, tag); + break; + case 
glooDataType_t::glooFloat64: + recv(context, recvbuf, size, peer, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/reduce.cc b/cpp/collective/gloo/src/reduce.cc new file mode 100644 index 00000000..e1afe394 --- /dev/null +++ b/cpp/collective/gloo/src/reduce.cc @@ -0,0 +1,100 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +namespace xoscar { + /* reduce: registers sendbuf as the input (count `size`) of a gloo::ReduceOptions. On the root rank the output is recvbuf; on every other rank a temporary array of `size` elements is allocated as the output and released after gloo::reduce returns. NOTE(review): the temporary is not released if gloo::reduce throws. */ +template +void reduce(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + ReduceOp reduceop, + int root, + uint32_t tag) { + T *input_ptr = reinterpret_cast(sendbuf); + + T *output_ptr; + if (context->rank == root) + output_ptr = reinterpret_cast(recvbuf); + else + output_ptr = new T[size]; + + // Configure reduceOptions struct + gloo::ReduceOptions opts_(context); + opts_.setInput(input_ptr, size); + opts_.setOutput(output_ptr, size); + gloo::ReduceOptions::Func fn = toFunction(reduceop); + opts_.setReduceFunction(fn); + opts_.setRoot(root); + opts_.setTag(tag); + + gloo::reduce(opts_); + + if (context->rank != root) + delete[] output_ptr; /* array form: the buffer was allocated with new T[size] */ +} + /* reduce_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards all arguments to reduce, anything else throws std::runtime_error. */ +void reduce_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + ReduceOp reduceop, + int root, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; 
+ case glooDataType_t::glooUint8: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; + case glooDataType_t::glooInt32: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; + case glooDataType_t::glooUint32: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; + case glooDataType_t::glooInt64: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; + case glooDataType_t::glooUint64: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; + case glooDataType_t::glooFloat16: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; + case glooDataType_t::glooFloat32: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; + case glooDataType_t::glooFloat64: + reduce( + context, sendbuf, recvbuf, size, reduceop, root, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/reduce_scatter.cc b/cpp/collective/gloo/src/reduce_scatter.cc new file mode 100644 index 00000000..7f558fc2 --- /dev/null +++ b/cpp/collective/gloo/src/reduce_scatter.cc @@ -0,0 +1,129 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include + +namespace xoscar { + /* getReductionFunction: maps SUM, PRODUCT, MIN and MAX to the matching gloo::ReductionFunction member. BAND, BOR and BXOR always throw std::runtime_error in this function regardless of the element type; UNUSED falls through to the final "Unhandled ReduceOp" throw. */ +template +const gloo::ReductionFunction *getReductionFunction(ReduceOp reduceop) { + switch (reduceop) { + case ReduceOp::SUM: + return gloo::ReductionFunction::sum; + break; + case ReduceOp::PRODUCT: + return gloo::ReductionFunction::product; + break; + case ReduceOp::MIN: + return gloo::ReductionFunction::min; + break; + case ReduceOp::MAX: + return gloo::ReductionFunction::max; + break; + case ReduceOp::BAND: + throw std::runtime_error( + "Cannot use ReduceOp.BAND with non-integral dtype"); + break; + case ReduceOp::BOR: + throw std::runtime_error( + "Cannot use ReduceOp.BOR with non-integral dtype"); + break; + case ReduceOp::BXOR: + throw std::runtime_error( + "Cannot use ReduceOp.BXOR with non-integral dtype"); + break; + case ReduceOp::UNUSED: + break; + } + throw std::runtime_error("Unhandled ReduceOp"); +} + /* reduce_scatter: copies `size` elements from sendbuf into a temporary vector, runs gloo::ReduceScatterHalvingDoubling on that copy with `recvElems` and the function from getReductionFunction(reduceop), then copies the first recvElems[context->rank] elements of the temporary into recvbuf. */ +template +void reduce_scatter(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + std::vector recvElems, + ReduceOp reduceop) { + T *input_ptr = reinterpret_cast(sendbuf); + + std::vector inputbuf(size); + + memcpy(inputbuf.data(), input_ptr, size * sizeof(T)); + + std::vector dataPtrs{inputbuf.data()}; + + const gloo::ReductionFunction *fn = getReductionFunction(reduceop); + + gloo::ReduceScatterHalvingDoubling algorithm( + context, dataPtrs, size, recvElems, fn); + algorithm.run(); + + memcpy(reinterpret_cast(recvbuf), + inputbuf.data(), + recvElems[context->rank] * sizeof(T)); +} + /* reduce_scatter_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards all arguments to reduce_scatter, anything else throws std::runtime_error. */ +void reduce_scatter_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + intptr_t recvbuf, + size_t size, + std::vector recvElems, + glooDataType_t datatype, + ReduceOp reduceop) { + switch (datatype) { + case glooDataType_t::glooInt8: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + case glooDataType_t::glooUint8: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + case 
glooDataType_t::glooInt32: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + case glooDataType_t::glooUint32: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + case glooDataType_t::glooInt64: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + case glooDataType_t::glooUint64: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + case glooDataType_t::glooFloat16: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + case glooDataType_t::glooFloat32: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + case glooDataType_t::glooFloat64: + reduce_scatter( + context, sendbuf, recvbuf, size, recvElems, reduceop); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/rendezvous.cc b/cpp/collective/gloo/src/rendezvous.cc new file mode 100644 index 00000000..4c6f2ba3 --- /dev/null +++ b/cpp/collective/gloo/src/rendezvous.cc @@ -0,0 +1,162 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "tcp_store.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + /* def_rendezvous_module (below): creates the "rendezvous" submodule of m and registers the Context, Store, TCPStoreOptions, TCPStore, FileStore, HashStore, PrefixStore and CustomStore classes on it. kDefaultTimeout (30 s) is the value CustomStore::wait(keys) passes to the two-argument wait. */ +namespace xoscar { +namespace rendezvous { +constexpr std::chrono::milliseconds kDefaultTimeout = std::chrono::seconds(30); +void def_rendezvous_module(pybind11::module &m) { + pybind11::module rendezvous + = m.def_submodule("rendezvous", "This is a rendezvous module"); + + pybind11::class_>(rendezvous, + "Context") + .def(pybind11::init(), + pybind11::arg("rank") = nullptr, + pybind11::arg("size") = nullptr, + pybind11::arg("base") = 2) + .def("connectFullMesh", &gloo::rendezvous::Context::connectFullMesh); + + pybind11::class_>(rendezvous, + "Store") + .def("set", &gloo::rendezvous::Store::set) + .def("get", &gloo::rendezvous::Store::get); + + pybind11::class_(rendezvous, "TCPStoreOptions") + .def(pybind11::init()) + .def_readwrite("port", &TCPStoreOptions::port) + .def_readwrite("isServer", &TCPStoreOptions::isServer) + .def_readwrite("numWorkers", &TCPStoreOptions::numWorkers) + .def_readwrite("waitWorkers", &TCPStoreOptions::waitWorkers) + .def_readwrite("timeout", &TCPStoreOptions::timeout) + .def_readwrite("multiTenant", &TCPStoreOptions::multiTenant); + + pybind11::class_>( + rendezvous, // why we use pybind11::nodelete: + // https://github.com/pybind/pybind11/issues/3514 + "TCPStore") + .def(pybind11::init()) + .def("wait", + pybind11::overload_cast &>( + &TCPStore::wait)) + .def("wait", + pybind11::overload_cast &, + const std::chrono::milliseconds &>( + &TCPStore::wait)) + .def("set", &TCPStore::set) + .def("get", &TCPStore::get); + + pybind11::class_>(rendezvous, + "FileStore") + .def(pybind11::init()) + .def("set", &gloo::rendezvous::FileStore::set) + .def("get", &gloo::rendezvous::FileStore::get); + + pybind11::class_>(rendezvous, + "HashStore") + .def(pybind11::init([]() { return new gloo::rendezvous::HashStore(); })) + .def("set", &gloo::rendezvous::HashStore::set) + .def("get", 
&gloo::rendezvous::HashStore::get); + + pybind11::class_>( + rendezvous, "PrefixStore") + .def(pybind11::init()) + .def("set", &gloo::rendezvous::PrefixStore::set) + .def("get", &gloo::rendezvous::PrefixStore::get); + /* CustomStore: a gloo Store that forwards to a Python object: set calls its "set_tcp" attribute, get first waits on the key and then calls "get_tcp", wait calls "wait" with the key list only (the timeout argument is not forwarded), and delKeys calls "del_keys". */ + class CustomStore : public gloo::rendezvous::Store { + public: + explicit CustomStore(const pybind11::object &real_store_py_object) + : real_store_py_object_(real_store_py_object) {} + + virtual ~CustomStore() {} + + void set(const std::string &key, + const std::vector &data) override { + pybind11::str py_key(key.data(), key.size()); + pybind11::bytes py_data(data.data(), data.size()); + auto set_func = real_store_py_object_.attr("set_tcp"); + set_func(py_key, py_data); + } + + std::vector get(const std::string &key) override { + /// Wait until key being ready. + wait({key}); + + pybind11::str py_key(key.data(), key.size()); + auto get_func = real_store_py_object_.attr("get_tcp"); + pybind11::bytes data = get_func(py_key); + std::string ret_str = data; + std::vector ret(ret_str.data(), + ret_str.data() + ret_str.size()); + return ret; + } + + void wait(const std::vector &keys) override { + wait(keys, xoscar::rendezvous::kDefaultTimeout); + } + + void wait(const std::vector &keys, + const std::chrono::milliseconds &timeout) override { + // We now ignore the timeout_ms. 
+ + pybind11::list py_keys = pybind11::cast(keys); + auto wait_func = real_store_py_object_.attr("wait"); + wait_func(py_keys); + } + + void delKeys(const std::vector &keys) { + pybind11::list py_keys = pybind11::cast(keys); + auto del_keys_func = real_store_py_object_.attr("del_keys"); + del_keys_func(py_keys); + } + + protected: + const pybind11::object real_store_py_object_; + }; + + pybind11::class_>(rendezvous, "CustomStore") + .def(pybind11::init()) + .def("set", &CustomStore::set) + .def("get", &CustomStore::get) + .def("delKeys", &CustomStore::delKeys); +} +} // namespace rendezvous +} // namespace xoscar diff --git a/cpp/collective/gloo/src/scatter.cc b/cpp/collective/gloo/src/scatter.cc new file mode 100644 index 00000000..2a63e18a --- /dev/null +++ b/cpp/collective/gloo/src/scatter.cc @@ -0,0 +1,82 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include + +namespace xoscar { + /* scatter: converts each address in `sendbuf` to a T pointer, passes the resulting vector to ScatterOptions::setInputs with count `size`, registers recvbuf as the output with count `size`, sets tag and root, and runs gloo::scatter. */ +template +void scatter(const std::shared_ptr &context, + std::vector sendbuf, + intptr_t recvbuf, + size_t size, + int root, + uint32_t tag) { + std::vector input_ptr; + for (size_t i = 0; i < sendbuf.size(); ++i) + input_ptr.emplace_back(reinterpret_cast(sendbuf[i])); + + T *output_ptr = reinterpret_cast(recvbuf); + + // Configure ScatterOptions struct + gloo::ScatterOptions opts_(context); + opts_.setInputs(input_ptr, size); + opts_.setOutput(output_ptr, size); + opts_.setTag(tag); + opts_.setRoot(root); + + gloo::scatter(opts_); +} + /* scatter_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards all arguments to scatter, anything else throws std::runtime_error. */ +void scatter_wrapper(const std::shared_ptr &context, + std::vector sendbuf, + intptr_t recvbuf, + size_t size, + glooDataType_t datatype, + int root, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint8: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooInt32: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint32: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooInt64: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooUint64: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooFloat16: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooFloat32: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + case glooDataType_t::glooFloat64: + scatter(context, sendbuf, recvbuf, size, root, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/send.cc b/cpp/collective/gloo/src/send.cc new file mode 100644 index 00000000..412d5983 --- /dev/null +++ b/cpp/collective/gloo/src/send.cc @@ -0,0 +1,78 @@ +/* Copyright 2022-2023 XProbe Inc. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +namespace xoscar { + /* send: throws std::runtime_error when `peer` equals context->rank; otherwise creates an unbound buffer over sendbuf of size * sizeof(T) bytes, posts a send to `peer` on the slot built from prefix 0x09 and `tag`, and waits using context->getTimeout(). */ +template +void send(const std::shared_ptr &context, + intptr_t sendbuf, + size_t size, + int peer, + uint32_t tag) { + if (context->rank == peer) + throw std::runtime_error( + "peer equals to current rank. Please specify other peer values."); + + auto inputBuffer = context->createUnboundBuffer( + reinterpret_cast(sendbuf), size * sizeof(T)); + + constexpr uint8_t kSendRecvSlotPrefix = 0x09; + gloo::Slot slot = gloo::Slot::build(kSendRecvSlotPrefix, tag); + + inputBuffer->send(peer, slot); + inputBuffer->waitSend(context->getTimeout()); +} + /* send_wrapper: switch over `datatype`; each of the nine handled glooDataType_t values forwards all arguments to send, anything else throws std::runtime_error. */ +void send_wrapper(const std::shared_ptr &context, + intptr_t sendbuf, + size_t size, + glooDataType_t datatype, + int peer, + uint32_t tag) { + switch (datatype) { + case glooDataType_t::glooInt8: + send(context, sendbuf, size, peer, tag); + break; + case glooDataType_t::glooUint8: + send(context, sendbuf, size, peer, tag); + break; + case glooDataType_t::glooInt32: + send(context, sendbuf, size, peer, tag); + break; + case glooDataType_t::glooUint32: + send(context, sendbuf, size, peer, tag); + break; + case glooDataType_t::glooInt64: + send(context, sendbuf, size, peer, tag); + break; + case glooDataType_t::glooUint64: + send(context, sendbuf, size, peer, tag); + break; + case glooDataType_t::glooFloat16: + send(context, sendbuf, size, peer, tag); + break; + case glooDataType_t::glooFloat32: + send(context, sendbuf, size, peer, tag); + break; + 
case glooDataType_t::glooFloat64: + send(context, sendbuf, size, peer, tag); + break; + default: + throw std::runtime_error("Unhandled dataType"); + } +} +} // namespace xoscar diff --git a/cpp/collective/gloo/src/transport.cc b/cpp/collective/gloo/src/transport.cc new file mode 100644 index 00000000..193f8aea --- /dev/null +++ b/cpp/collective/gloo/src/transport.cc @@ -0,0 +1,116 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +namespace xoscar { +namespace transport { + +#if GLOO_HAVE_TRANSPORT_TCP +template +using overload_cast_ = pybind11::detail::overload_cast_impl; + /* def_transport_tcp_module: when GLOO_HAVE_TRANSPORT_TCP is set, adds a "tcp" submodule exposing CreateDevice and the attr, Context and Device classes; in the #else branch only the empty "tcp" submodule is created. Each attr field is registered exactly once. */ +void def_transport_tcp_module(pybind11::module &m) { + pybind11::module tcp = m.def_submodule("tcp", "This is a tcp module"); + + tcp.def("CreateDevice", &gloo::transport::tcp::CreateDevice); + + pybind11::class_(tcp, "attr") + .def(pybind11::init<>()) + .def(pybind11::init()) + .def_readwrite("hostname", &gloo::transport::tcp::attr::hostname) + .def_readwrite("iface", &gloo::transport::tcp::attr::iface) + .def_readwrite("ai_family", &gloo::transport::tcp::attr::ai_family) + .def_readwrite("ai_socktype", &gloo::transport::tcp::attr::ai_socktype) + .def_readwrite("ai_protocol", &gloo::transport::tcp::attr::ai_protocol) + .def_readwrite("ai_addr", &gloo::transport::tcp::attr::ai_addr) + .def_readwrite("ai_addrlen", &gloo::transport::tcp::attr::ai_addrlen); + + pybind11::class_>(tcp, + 
"Context") + .def(pybind11::init, + int, + int>()) + // .def("createPair", &gloo::transport::tcp::Context::createPair) + .def("createUnboundBuffer", + &gloo::transport::tcp::Context::createUnboundBuffer); + + pybind11::class_, + gloo::transport::Device>(tcp, "Device") + .def(pybind11::init()); +} +#else +void def_transport_tcp_module(pybind11::module &m) { + pybind11::module tcp = m.def_submodule("tcp", "This is a tcp module"); +} +#endif + /* def_transport_uv_module: uv counterpart of the tcp bindings, guarded by GLOO_HAVE_TRANSPORT_UV; the #else branch only creates the empty "uv" submodule. */ +#if GLOO_HAVE_TRANSPORT_UV +void def_transport_uv_module(pybind11::module &m) { + pybind11::module uv = m.def_submodule("uv", "This is a uv module"); + + uv.def("CreateDevice", &gloo::transport::uv::CreateDevice, "CreateDevice"); + + pybind11::class_(uv, "attr") + .def(pybind11::init<>()) + .def(pybind11::init()) + .def_readwrite("hostname", &gloo::transport::uv::attr::hostname) + .def_readwrite("iface", &gloo::transport::uv::attr::iface) + .def_readwrite("ai_family", &gloo::transport::uv::attr::ai_family) + .def_readwrite("ai_socktype", &gloo::transport::uv::attr::ai_socktype) + .def_readwrite("ai_protocol", &gloo::transport::uv::attr::ai_protocol) + .def_readwrite("ai_addr", &gloo::transport::uv::attr::ai_addr) + .def_readwrite("ai_addrlen", &gloo::transport::uv::attr::ai_addrlen); + + pybind11::class_>(uv, + "Context") + .def(pybind11:: + init, int, int>()) + .def("createUnboundBuffer", + &gloo::transport::uv::Context::createUnboundBuffer); + + pybind11::class_, + gloo::transport::Device>(uv, "Device") + .def(pybind11::init()); +} +#else +void def_transport_uv_module(pybind11::module &m) { + pybind11::module uv = m.def_submodule("uv", "This is a uv module"); +} +#endif + /* def_transport_module: creates the "transport" submodule of m, registers the base Device class on it (module-local), then adds the uv and tcp submodules. */ +void def_transport_module(pybind11::module &m) { + pybind11::module transport + = m.def_submodule("transport", "This is a transport module"); + + pybind11::class_, + xoscar::transport::PyDevice>( + transport, "Device", pybind11::module_local()) + .def("str", &gloo::transport::Device::str) + .def("getPCIBusID", &gloo::transport::Device::getPCIBusID) + 
.def("getInterfaceSpeed", &gloo::transport::Device::getInterfaceSpeed) + .def("hasGPUDirect", &gloo::transport::Device::hasGPUDirect) + .def("createContext", &gloo::transport::Device::createContext); + + def_transport_uv_module(transport); + def_transport_tcp_module(transport); +} +} // namespace transport +} // namespace xoscar diff --git a/cpp/collective/rendezvous/CMakeLists.txt b/cpp/collective/rendezvous/CMakeLists.txt index 2757139d..014d963e 100644 --- a/cpp/collective/rendezvous/CMakeLists.txt +++ b/cpp/collective/rendezvous/CMakeLists.txt @@ -8,7 +8,6 @@ project( set(CMAKE_CXX_STANDARD 20) include_directories(include) -include_directories(../../../third_party/fmt/include) add_library( StoreLib @@ -17,8 +16,6 @@ add_library( src/exception.cpp include/socket.h src/socket.cpp - include/store.hpp - src/store.cpp include/tcp_store.hpp src/tcp_store.cpp include/unix_sock_utils.hpp diff --git a/cpp/collective/rendezvous/include/store.hpp b/cpp/collective/rendezvous/include/store.hpp deleted file mode 100644 index 57f430c2..00000000 --- a/cpp/collective/rendezvous/include/store.hpp +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2022-2023 XProbe Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace xoscar { - -// callback function will be given arguments (optional oldValue, -// optional newValue) -using WatchKeyCallback = std::function, - std::optional)>; - -class Store { -public: - static constexpr std::chrono::milliseconds kDefaultTimeout - = std::chrono::seconds(300); - static constexpr std::chrono::milliseconds kNoTimeout - = std::chrono::milliseconds::zero(); - - Store() : timeout_(kDefaultTimeout) {} - - explicit Store(const std::chrono::milliseconds &timeout) - : timeout_(timeout) {} - - ~Store(); - - void set(const std::string &key, const std::string &value); - - virtual void set(const std::string &key, const std::vector &value) - = 0; - - std::string compareSet(const std::string &key, - const std::string ¤tValue, - const std::string &newValue); - - virtual std::vector - compareSet(const std::string &key, - const std::vector ¤tValue, - const std::vector &newValue) { - throw std::runtime_error("Not implemented."); - } - - std::string get_to_str(const std::string &key); - - virtual std::vector get(const std::string &key) = 0; - - virtual int64_t add(const std::string &key, int64_t value) = 0; - - virtual bool deleteKey(const std::string &key) = 0; - - virtual bool check(const std::vector &keys) = 0; - - virtual int64_t getNumKeys() = 0; - - virtual void wait(const std::vector &keys) = 0; - - virtual void wait(const std::vector &keys, - const std::chrono::milliseconds &timeout) - = 0; - - virtual const std::chrono::milliseconds &getTimeout() const noexcept; - - virtual void setTimeout(const std::chrono::milliseconds &timeout); - - // watchKey() takes two arguments: key and callback function. The callback - // should be run whenever the key is changed (create, update, or delete). - // The callback function takes two parameters: currentValue and newValue, - // which are optional depending on how the key is changed. 
These key updates - // should trigger the callback as follows: CREATE: callback(c10::nullopt, - // newValue) // null currentValue UPDATE: callback(currentValue, newValue) - // DELETE: callback(currentValue, c10::nullopt) // null newValue - virtual void watchKey(const std::string & /* unused */, - WatchKeyCallback /* unused */) { - throw std::runtime_error("watchKey only implemented for TCPStore and " - "PrefixStore that wraps TCPStore."); - } - - virtual void append(const std::string &key, - const std::vector &value); - - virtual std::vector> - multiGet(const std::vector &keys); - - virtual void multiSet(const std::vector &keys, - const std::vector> &values); - - // Returns true if this store support watchKey, append, multiGet and - // multiSet - virtual bool hasExtendedApi() const; - -protected: - std::chrono::milliseconds timeout_; -}; - -} // namespace xoscar diff --git a/cpp/collective/rendezvous/include/tcp_store.hpp b/cpp/collective/rendezvous/include/tcp_store.hpp index 95b5c205..8d72fe7b 100644 --- a/cpp/collective/rendezvous/include/tcp_store.hpp +++ b/cpp/collective/rendezvous/include/tcp_store.hpp @@ -12,13 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once - -#include "store.hpp" - +#include #include #include +#include +#include +#include #include #include +#include +#include +#include +#include namespace xoscar { namespace detail { @@ -36,6 +41,9 @@ struct SocketAddress { } // namespace detail +using WatchKeyCallback = std::function, + std::optional)>; + struct TCPStoreOptions { static constexpr std::uint16_t kDefaultPort = 29500; @@ -43,15 +51,19 @@ struct TCPStoreOptions { bool isServer = false; std::optional numWorkers = std::nullopt; bool waitWorkers = true; - std::chrono::milliseconds timeout = Store::kDefaultTimeout; + std::chrono::milliseconds timeout = std::chrono::seconds(300); // A boolean value indicating whether multiple store instances can be // initialized with the same host:port pair. bool multiTenant = false; }; -class TCPStore : public Store { +class TCPStore : public gloo::rendezvous::Store { public: + static constexpr std::chrono::milliseconds kDefaultTimeout + = std::chrono::seconds(300); + static constexpr std::chrono::milliseconds kNoTimeout + = std::chrono::milliseconds::zero(); explicit TCPStore(std::string host, const TCPStoreOptions &opts = {}); [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore( @@ -64,46 +76,43 @@ class TCPStore : public Store { ~TCPStore(); - void set(const std::string &key, - const std::vector &value) override; + void setTCP(const std::string &key, const std::vector &value); - std::vector - compareSet(const std::string &key, - const std::vector &expectedValue, - const std::vector &desiredValue) override; + std::vector compareSet(const std::string &key, + const std::vector &expectedValue, + const std::vector &desiredValue); - std::vector get(const std::string &key) override; + std::vector getTCP(const std::string &key); int64_t add(const std::string &key, int64_t value) override; - bool deleteKey(const std::string &key) override; + bool deleteKey(const std::string &key); // NOTE: calling other TCPStore APIs inside the callback is NOT 
threadsafe // watchKey() is a blocking operation. It will register the socket on // TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will // return once it has verified the callback is registered on both background // threads. Only one thread can call watchKey() at a time. - void watchKey(const std::string &key, WatchKeyCallback callback) override; + void watchKey(const std::string &key, WatchKeyCallback callback); - bool check(const std::vector &keys) override; + bool check(const std::vector &keys); - int64_t getNumKeys() override; + int64_t getNumKeys(); void wait(const std::vector &keys) override; void wait(const std::vector &keys, const std::chrono::milliseconds &timeout) override; - void append(const std::string &key, - const std::vector &value) override; + void append(const std::string &key, const std::vector &value); std::vector> - multiGet(const std::vector &keys) override; + multiGet(const std::vector &keys); void multiSet(const std::vector &keys, - const std::vector> &values) override; + const std::vector> &values); - bool hasExtendedApi() const override; + bool hasExtendedApi() const; // Waits for all workers to join. void waitForWorkers(); @@ -114,6 +123,13 @@ class TCPStore : public Store { // Returns the port used by the TCPStore. std::uint16_t getPort() const noexcept { return addr_.port; } + void set(const std::string &key, const std::vector &data) override; + + std::vector get(const std::string &key) override; + +protected: + std::chrono::milliseconds timeout_; + private: int64_t incrementValueBy(const std::string &key, int64_t delta); diff --git a/cpp/collective/rendezvous/src/bind_tcp_store.cpp b/cpp/collective/rendezvous/src/bind_tcp_store.cpp deleted file mode 100644 index 019e0b02..00000000 --- a/cpp/collective/rendezvous/src/bind_tcp_store.cpp +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2022-2023 XProbe Inc. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once - -#include "tcp_store.hpp" - -#include -#include -#include - -namespace py = pybind11; - -namespace xoscar { -PYBIND11_MODULE(xoscar_store, m) { - py::class_(m, "TCPStoreOptions") - .def(py::init()) - .def_readwrite("port", &TCPStoreOptions::port) - .def_readwrite("isServer", &TCPStoreOptions::isServer) - .def_readwrite("numWorkers", &TCPStoreOptions::numWorkers) - .def_readwrite("waitWorkers", &TCPStoreOptions::waitWorkers) - .def_readwrite("timeout", &TCPStoreOptions::timeout) - .def_readwrite("multiTenant", &TCPStoreOptions::multiTenant); - - py::class_(m, "Store"); - - py::class_(m, "TCPStore") - .def(py::init()) - .def("wait", - py::overload_cast &>( - &TCPStore::wait)) - .def("wait", - py::overload_cast &, - const std::chrono::milliseconds &>( - &TCPStore::wait)) - .def("set", - [](TCPStore &self, const std::string &key, py::bytes &bytes) { - const py::buffer_info info(py::buffer(bytes).request()); - const char *data = reinterpret_cast(info.ptr); - auto length = static_cast(info.size); - self.set(key, std::vector(data, data + length)); - }) - .def("get", [](TCPStore &self, const std::string &key) { - auto result = self.get(key); - const std::string str_result(result.begin(), result.end()); - return py::bytes(str_result); - }); -} -} // namespace xoscar diff --git a/cpp/collective/rendezvous/src/store.cpp b/cpp/collective/rendezvous/src/store.cpp deleted file mode 100644 index 1309d06f..00000000 --- 
a/cpp/collective/rendezvous/src/store.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2022-2023 XProbe Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include "store.hpp" - -namespace xoscar { - -constexpr std::chrono::milliseconds Store::kDefaultTimeout; -constexpr std::chrono::milliseconds Store::kNoTimeout; - -// Define destructor symbol for abstract base class. -Store::~Store() = default; - -const std::chrono::milliseconds &Store::getTimeout() const noexcept { - return timeout_; -} - -// Set timeout function -void Store::setTimeout(const std::chrono::milliseconds &timeout) { - timeout_ = timeout; -} - -void Store::set(const std::string &key, const std::string &value) { - set(key, std::vector(value.begin(), value.end())); -} - -std::string Store::compareSet(const std::string &key, - const std::string ¤tValue, - const std::string &newValue) { - auto value = compareSet( - key, - std::vector(currentValue.begin(), currentValue.end()), - std::vector(newValue.begin(), newValue.end())); - return std::string(value.begin(), value.end()); -} - -std::string Store::get_to_str(const std::string &key) { - auto value = get(key); - return std::string(value.begin(), value.end()); -} - -void Store::append(const std::string &key, const std::vector &value) { - // This fallback depends on compareSet - std::vector expected = value; - std::vector current; - // cannot use get(key) as it might block forever if the key doesn't exist - current = compareSet(key, current, expected); - while (current != 
expected) { - expected = current; - expected.insert(expected.end(), value.begin(), value.end()); - current = compareSet(key, current, expected); - } -} - -std::vector> -Store::multiGet(const std::vector &keys) { - std::vector> result; - result.reserve(keys.size()); - for (auto &key : keys) { - result.emplace_back(get(key)); - } - return result; -} - -void Store::multiSet(const std::vector &keys, - const std::vector> &values) { - for (int i = 0; i < keys.size(); i++) { - set(keys[i], values[i]); - } -} - -bool Store::hasExtendedApi() const { return false; } - -} // namespace xoscar diff --git a/cpp/collective/rendezvous/src/tcp_store.cpp b/cpp/collective/rendezvous/src/tcp_store.cpp index 0af12378..2643810f 100644 --- a/cpp/collective/rendezvous/src/tcp_store.cpp +++ b/cpp/collective/rendezvous/src/tcp_store.cpp @@ -1075,8 +1075,8 @@ TCPStore::TCPStore(const std::string &masterAddr, timeout}} {} TCPStore::TCPStore(std::string host, const TCPStoreOptions &opts) - : Store{opts.timeout}, addr_{std::move(host)}, numWorkers_{ - opts.numWorkers} { + : timeout_{opts.timeout}, addr_{std::move(host)}, numWorkers_{ + opts.numWorkers} { Socket::initialize(); if (opts.isServer) { @@ -1134,7 +1134,8 @@ void TCPStore::waitForWorkers() { } } -void TCPStore::set(const std::string &key, const std::vector &data) { +void TCPStore::setTCP(const std::string &key, + const std::vector &data) { const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::SET); buffer.appendString(keyPrefix_ + key); @@ -1156,7 +1157,7 @@ TCPStore::compareSet(const std::string &key, return client_->receiveBits(); } -std::vector TCPStore::get(const std::string &key) { +std::vector TCPStore::getTCP(const std::string &key) { const std::lock_guard lock(activeOpLock_); return doGet(keyPrefix_ + key); } @@ -1313,4 +1314,16 @@ void TCPStore::multiSet(const std::vector &keys, bool TCPStore::hasExtendedApi() const { return true; } +void TCPStore::set(const std::string &key, const 
std::vector &data) { + std::vector dataSet(data.begin(), data.end()); + setTCP(key, dataSet); +} + +std::vector TCPStore::get(const std::string &key) { + wait({key}); + std::vector dataUint8Get = getTCP(key); + std::vector dataCharGet(dataUint8Get.begin(), dataUint8Get.end()); + return dataCharGet; +} + } // namespace xoscar diff --git a/python/setup.py b/python/setup.py index e619ed2f..b52760b4 100644 --- a/python/setup.py +++ b/python/setup.py @@ -146,7 +146,7 @@ def build_long_description(): # A CMakeExtension needs a sourcedir instead of a file list. # The name must be the _single_ output extension from the CMake build. # If you need multiple extensions, see scikit-build. -class XoscarStoreExtension(Extension): +class XoscarCmakeExtension(Extension): def __init__(self, name: str, sourcedir: str = "") -> None: super().__init__(name, sources=[]) self.sourcedir = os.fspath(Path(sourcedir).resolve()) @@ -156,7 +156,7 @@ class CMakeBuild(build_ext): def copy_extensions_to_source(self): build_py = self.get_finalized_command('build_py') for ext in self.extensions: - if not isinstance(ext, XoscarStoreExtension): + if not isinstance(ext, XoscarCmakeExtension): fullname = self.get_ext_fullname(ext.name) filename = self.get_ext_filename(fullname) modpath = fullname.split('.') @@ -179,9 +179,9 @@ def copy_extensions_to_source(self): def build_extension(self, ext): # TODO: support windows compilation is_windows = sys.platform.startswith('win') - if isinstance(ext, XoscarStoreExtension) and not is_windows: - self.build_store(ext) - elif isinstance(ext, XoscarStoreExtension) and is_windows: + if isinstance(ext, XoscarCmakeExtension) and not is_windows: + self.build_Cmake(ext) + elif isinstance(ext, XoscarCmakeExtension) and is_windows: pass else: ext._convert_pyx_sources_to_lang() @@ -196,12 +196,15 @@ def build_extension(self, ext): finally: self.compiler = _compiler - def build_store(self, ext: XoscarStoreExtension) -> None: + def build_Cmake(self, ext: 
XoscarCmakeExtension) -> None: # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) extdir = ext_fullpath.parent.resolve() source_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - output_directory = Path(source_dir) / "python" / "xoscar" / "collective" / "rendezvous" + output_directory_collective = Path(source_dir) / "python" / "xoscar" / "collective" + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) # Using this requires trailing slash for auto-detection & inclusion of # auxiliary "native" libs @@ -217,7 +220,8 @@ def build_store(self, ext: XoscarStoreExtension) -> None: # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code # from Python. cmake_args = [ - f"-DLIBRARY_OUTPUT_DIRECTORY={output_directory}", + f"-DBUILD_TMP_DIR={build_temp}", + f"-DLIBRARY_OUTPUT_DIRECTORY={output_directory_collective}", f"-DPYTHON_PATH={sys.executable}", f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm ] @@ -280,9 +284,7 @@ def build_store(self, ext: XoscarStoreExtension) -> None: # CMake 3.12+ only. 
build_args += [f"-j{self.parallel}"] - build_temp = Path(self.build_temp) / ext.name - if not build_temp.exists(): - build_temp.mkdir(parents=True) + subprocess.run( ["cmake", source_dir, *cmake_args], cwd=build_temp, check=True @@ -294,7 +296,7 @@ def build_store(self, ext: XoscarStoreExtension) -> None: setup_options = dict( version=versioneer.get_version(), - ext_modules=extensions + [XoscarStoreExtension("xoscar_store")], + ext_modules=extensions + [XoscarCmakeExtension("xoscar_cmake")], cmdclass={"build_ext": CMakeBuild}, long_description=build_long_description(), long_description_content_type="text/markdown", diff --git a/python/xoscar/collective/rendezvous/test/__init__.py b/python/xoscar/collective/rendezvous/test/__init__.py deleted file mode 100644 index 37f6558d..00000000 --- a/python/xoscar/collective/rendezvous/test/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2022-2023 XProbe Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/python/xoscar/collective/rendezvous/test/test_tcp_store.py b/python/xoscar/collective/rendezvous/test/test_tcp_store.py deleted file mode 100644 index c23df536..00000000 --- a/python/xoscar/collective/rendezvous/test/test_tcp_store.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2022-2023 XProbe Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import multiprocessing - -import pytest - -from ....tests.core import require_unix - - -@require_unix -def test_tcp_store_options(): - from .. import xoscar_store as xs - - opt = xs.TCPStoreOptions() - assert opt.numWorkers is None - assert opt.isServer is False - - opt.numWorkers = 2 - assert opt.numWorkers == 2 - - with pytest.raises(TypeError): - opt.numWorkers = [5] - - -def server(): - from .. import xoscar_store as xs - - opt = xs.TCPStoreOptions() - opt.port = 25001 - opt.numWorkers = 2 - opt.isServer = True - - store = xs.TCPStore("127.0.0.1", opt) - val = store.get("test_key") - assert val == b"test_12345" - - -def worker(): - from .. import xoscar_store as xs - - opt = xs.TCPStoreOptions() - opt.port = 25001 - opt.numWorkers = 2 - opt.isServer = False - - store = xs.TCPStore("127.0.0.1", opt) - store.set("test_key", b"test_12345") - - -@require_unix -def test_tcp_store(): - process1 = multiprocessing.Process(target=server) - process1.start() - process2 = multiprocessing.Process(target=worker) - process2.start() - - process1.join() - process2.join() diff --git a/python/xoscar/collective/rendezvous/xoscar_store.pyi b/python/xoscar/collective/rendezvous/xoscar_store.pyi deleted file mode 100644 index ecfaa103..00000000 --- a/python/xoscar/collective/rendezvous/xoscar_store.pyi +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2022-2023 XProbe Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import datetime -from typing import List, Optional - -class TCPStoreOptions: - port: int - isServer: bool - numWorkers: Optional[int] - waitWorkers: bool - timeout: datetime.timedelta - multiTenant: bool - -class TCPStore: - def __init__(self, host: str, opts: TCPStoreOptions = TCPStoreOptions()): ... - def set(self, key: str, value: bytes): ... - def get(self, key: str) -> bytes: ... - def wait(self, keys: List[str]): ... diff --git a/python/xoscar/collective/rendezvous/__init__.py b/python/xoscar/collective/tests/__init__.py similarity index 100% rename from python/xoscar/collective/rendezvous/__init__.py rename to python/xoscar/collective/tests/__init__.py diff --git a/python/xoscar/collective/tests/test_pygloo.py b/python/xoscar/collective/tests/test_pygloo.py new file mode 100644 index 00000000..5af389d0 --- /dev/null +++ b/python/xoscar/collective/tests/test_pygloo.py @@ -0,0 +1,566 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import multiprocessing as mp +import platform +import tempfile + +import numpy as np + +from ...tests.core import require_linux, require_unix + +system_name = platform.system() + + +def worker_allgather(rank, fileStore_path): + from .. import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(2), fileStore) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + recvbuf = np.zeros([2] + list(sendbuf.shape), dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + assert sendbuf.size * 2 == recvbuf.size + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + + xp.allgather(context, sendptr, recvptr, data_size, datatype) + + np.testing.assert_array_equal(recvbuf, np.array([sendbuf] * 2)) + + +@require_unix +def test_allgather(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_allgather, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_allgather, args=(1, temp_dir)) + process2.start() + + process1.join() + process2.join() + + +def worker_allreduce(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(2), fileStore) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + recvbuf = np.zeros_like(sendbuf, dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + op = xp.ReduceOp.SUM + algorithm = xp.AllreduceAlgorithm.RING + + xp.allreduce(context, sendptr, recvptr, data_size, datatype, op, algorithm) + + np.testing.assert_array_equal(recvbuf, np.array(sendbuf * 2)) + + +@require_unix +def test_allreduce(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_allreduce, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_allreduce, args=(1, temp_dir)) + process2.start() + + process1.join() + process2.join() + + +def worker_barrier(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(2), fileStore) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + recvbuf = np.zeros_like(sendbuf, dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + op = xp.ReduceOp.SUM + algorithm = xp.AllreduceAlgorithm.RING + + xp.allreduce(context, sendptr, recvptr, data_size, datatype, op, algorithm) + xp.barrier(context) + + np.testing.assert_array_equal(recvbuf, np.array(sendbuf * 2)) + + +@require_unix +def test_barrier(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_barrier, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_barrier, args=(1, temp_dir)) + process2.start() + + process1.join() + process2.join() + + +def worker_broadcast(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(2), fileStore) + + context.connectFullMesh(store, dev) + + if rank == 0: + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + sendptr = sendbuf.ctypes.data + else: + sendbuf = np.zeros((2, 3), dtype=np.float32) + sendptr = -1 + recvbuf = np.zeros_like(sendbuf, dtype=np.float32) + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + root = 0 + + xp.broadcast(context, sendptr, recvptr, data_size, datatype, root) + + np.testing.assert_array_equal( + recvbuf, np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + ) + ## example output + # (pid=36435) rank 1 sends [[0. 0. 0.] + # (pid=36435) [0. 0. 0.]], receives [[1. 2. 3.] + # (pid=36435) [1. 2. 3.]] + # (pid=36432) rank 0 sends [[1. 2. 3.] + # (pid=36432) [1. 2. 3.]], receives [[1. 2. 3.] + # (pid=36432) [1. 2. 3.]] + + +@require_unix +def test_broadcast(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_broadcast, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_broadcast, args=(1, temp_dir)) + process2.start() + + process1.join() + process2.join() + + +def worker_gather(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 3) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(3), fileStore) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([rank, rank + 1], dtype=np.float32) + sendptr = sendbuf.ctypes.data + + recvbuf = np.zeros((1, 3 * 2), dtype=np.float32) + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + + xp.gather(context, sendptr, recvptr, data_size, datatype, root=0) + + if rank == 0: + np.testing.assert_array_equal( + recvbuf, np.array([[0.0, 1.0, 1.0, 2.0, 2.0, 3.0]]) + ) + ## example output + # (pid=23172) rank 2 sends [2. 3.], receives [[0. 0. 0. 0. 0. 0.]] + # (pid=23171) rank 1 sends [1. 2.], receives [[0. 0. 0. 0. 0. 0.]] + # (pid=23173) rank 0 sends [0. 1.], receives [[0. 1. 1. 2. 2. 3.]] + + +@require_unix +def test_gather(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_gather, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_gather, args=(1, temp_dir)) + process2.start() + process3 = mp.Process(target=worker_gather, args=(2, temp_dir)) + process3.start() + + process1.join() + process2.join() + process3.join() + + +def worker_reduce_scatter(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 3) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(3), fileStore) + + context.connectFullMesh(store, dev) + + sendbuf = np.array( + [i + 1 for i in range(sum([j + 1 for j in range(3)]))], dtype=np.float32 + ) + sendptr = sendbuf.ctypes.data + + recvbuf = np.zeros((rank + 1,), dtype=np.float32) + recvptr = recvbuf.ctypes.data + recvElems = [i + 1 for i in range(3)] + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + op = xp.ReduceOp.SUM + + xp.reduce_scatter(context, sendptr, recvptr, data_size, recvElems, datatype, op) + + if rank == 0: + np.testing.assert_array_equal( + recvbuf, + np.array( + [ + 3.0, + ] + ), + ) + elif rank == 1: + np.testing.assert_array_equal(recvbuf, np.array([6.0, 9.0])) + else: + np.testing.assert_array_equal(recvbuf, np.array([12.0, 15.0, 18.0])) + + +@require_linux +def test_reduce_scatter(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_reduce_scatter, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_reduce_scatter, args=(1, temp_dir)) + process2.start() + process3 = mp.Process(target=worker_reduce_scatter, args=(2, temp_dir)) + process3.start() + + process1.join() + process2.join() + process3.join() + + +def worker_reduce(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 3) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(3), fileStore) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + recvbuf = np.zeros_like(sendbuf, dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + op = xp.ReduceOp.SUM + root = 0 + + xp.reduce(context, sendptr, recvptr, data_size, datatype, op, root) + + if rank == 0: + np.testing.assert_array_equal( + recvbuf, + np.array( + [ + [ + 3.0, + 6.0, + 9.0, + ], + [3.0, 6.0, 9.0], + ] + ), + ) + + +@require_unix +def test_reduce(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_reduce, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_reduce, args=(1, temp_dir)) + process2.start() + process3 = mp.Process(target=worker_reduce, args=(2, temp_dir)) + process3.start() + + process1.join() + process2.join() + process3.join() + + +def worker_scatter(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(2), fileStore) + + context.connectFullMesh(store, dev) + + sendbuf = [np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)] * 2 + recvbuf = np.zeros((2, 3), dtype=np.float32) + sendptr = [] + for i in sendbuf: + sendptr.append(i.ctypes.data) + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf[0].size + if isinstance(sendbuf[0], np.ndarray) + else sendbuf[0].numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + root = 0 + + xp.scatter(context, sendptr, recvptr, data_size, datatype, root) + + np.testing.assert_array_equal(recvbuf, np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])) + ## example output, root is 0. + # (pid=18951) rank 1 sends [array([[1., 2., 3.], + # (pid=18951) [1., 2., 3.]], dtype=float32), array([[1., 2., 3.], + # (pid=18951) [1., 2., 3.]], dtype=float32)], receives [[1. 2. 3.] + # (pid=18951) [1. 2. 3.]] + # (pid=18952) rank 0 sends [array([[1., 2., 3.], + # (pid=18952) [1., 2., 3.]], dtype=float32), array([[1., 2., 3.], + # (pid=18952) [1., 2., 3.]], dtype=float32)], receives [[1. 2. 3.] + # (pid=18952) [1. 2. 3.]] + + +@require_unix +def test_scatter(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_scatter, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_scatter, args=(1, temp_dir)) + process2.start() + + process1.join() + process2.join() + + +def worker_send_recv(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(2), fileStore) + + context.connectFullMesh(store, dev) + + if rank == 0: + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + sendptr = sendbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + peer = 1 + xp.send(context, sendptr, data_size, datatype, peer) + + elif rank == 1: + recvbuf = np.zeros((2, 3), dtype=np.float32) + recvptr = recvbuf.ctypes.data + + data_size = ( + recvbuf.size if isinstance(recvbuf, np.ndarray) else recvbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + peer = 0 + + xp.recv(context, recvptr, data_size, datatype, peer) + np.testing.assert_array_equal( + recvbuf, np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) + ) + else: + raise Exception( + "Only support 2 process to test send function and recv function" + ) + + +@require_unix +def test_send_recv(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_send_recv, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_send_recv, args=(1, temp_dir)) + process2.start() + + process1.join() + process2.join() + + +def worker_all_to_all(rank, fileStore_path): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 3) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + fileStore = xp.rendezvous.FileStore(fileStore_path) + store = xp.rendezvous.PrefixStore(str(3), fileStore) + + context.connectFullMesh(store, dev) + + sendbuf = np.zeros((6,), dtype=np.float32) + rank + recvbuf = np.zeros(sendbuf.shape, dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + + xp.all_to_all(context, sendptr, recvptr, data_size, datatype) + + np.testing.assert_array_equal(recvbuf, np.array([0.0, 0.0, 1.0, 1.0, 2.0, 2.0])) + + +@require_unix +def test_all_to_all(): + with tempfile.TemporaryDirectory(prefix="collective") as temp_dir: + process1 = mp.Process(target=worker_all_to_all, args=(0, temp_dir)) + process1.start() + process2 = mp.Process(target=worker_all_to_all, args=(1, temp_dir)) + process2.start() + process3 = mp.Process(target=worker_all_to_all, args=(2, temp_dir)) + process3.start() + + process1.join() + process2.join() + process3.join() diff --git a/python/xoscar/collective/tests/test_pygloo_tcp_store.py b/python/xoscar/collective/tests/test_pygloo_tcp_store.py new file mode 100644 index 00000000..1465042d --- /dev/null +++ b/python/xoscar/collective/tests/test_pygloo_tcp_store.py @@ -0,0 +1,636 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing as mp +import platform + +import numpy as np + +from ...tests.core import require_linux, require_unix + +system_name = platform.system() + + +def worker_allgather(rank): + from .. import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 2 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(2), store) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + recvbuf = np.zeros([2] + list(sendbuf.shape), dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + assert sendbuf.size * 2 == recvbuf.size + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + + xp.allgather(context, sendptr, recvptr, data_size, datatype) + + np.testing.assert_array_equal(recvbuf, np.array([sendbuf] * 2)) + + +@require_unix +def test_allgather(): + process1 = mp.Process(target=worker_allgather, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_allgather, args=(1,)) + process2.start() + + process1.join() + process2.join() + + +def worker_allreduce(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 2 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(2), store) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + recvbuf = np.zeros_like(sendbuf, dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + op = xp.ReduceOp.SUM + algorithm = xp.AllreduceAlgorithm.RING + + xp.allreduce(context, sendptr, recvptr, data_size, datatype, op, algorithm) + + np.testing.assert_array_equal(recvbuf, np.array(sendbuf * 2)) + + +@require_unix +def test_allreduce(): + process1 = mp.Process(target=worker_allreduce, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_allreduce, args=(1,)) + process2.start() + + process1.join() + process2.join() + + +def worker_barrier(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 2 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(2), store) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + recvbuf = np.zeros_like(sendbuf, dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + op = xp.ReduceOp.SUM + algorithm = xp.AllreduceAlgorithm.RING + + xp.allreduce(context, sendptr, recvptr, data_size, datatype, op, algorithm) + xp.barrier(context) + + np.testing.assert_array_equal(recvbuf, sendbuf * 2) + + +@require_unix +def test_barrier(): + process1 = mp.Process(target=worker_barrier, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_barrier, args=(1,)) + process2.start() + + process1.join() + process2.join() + + +def worker_broadcast(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 2 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(2), store) + + context.connectFullMesh(store, dev) + + if rank == 0: + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + sendptr = sendbuf.ctypes.data + else: + sendbuf = np.zeros((2, 3), dtype=np.float32) + sendptr = -1 + recvbuf = np.zeros_like(sendbuf, dtype=np.float32) + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + root = 0 + + xp.broadcast(context, sendptr, recvptr, data_size, datatype, root) + + np.testing.assert_array_equal( + recvbuf, np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + ) + ## example output + # (pid=36435) rank 1 sends [[0. 0. 0.] + # (pid=36435) [0. 0. 0.]], receives [[1. 2. 3.] + # (pid=36435) [1. 2. 3.]] + # (pid=36432) rank 0 sends [[1. 2. 3.] + # (pid=36432) [1. 2. 3.]], receives [[1. 2. 3.] + # (pid=36432) [1. 2. 3.]] + + +@require_unix +def test_broadcast(): + process1 = mp.Process(target=worker_broadcast, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_broadcast, args=(1,)) + process2.start() + + process1.join() + process2.join() + + +def worker_gather(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 3) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 3 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(3), store) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([rank, rank + 1], dtype=np.float32) + sendptr = sendbuf.ctypes.data + + recvbuf = np.zeros((1, 3 * 2), dtype=np.float32) + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + + xp.gather(context, sendptr, recvptr, data_size, datatype, root=0) + + if rank == 0: + np.testing.assert_array_equal( + recvbuf, np.array([[0.0, 1.0, 1.0, 2.0, 2.0, 3.0]]) + ) + + ## example output + # (pid=23172) rank 2 sends [2. 3.], receives [[0. 0. 0. 0. 0. 0.]] + # (pid=23171) rank 1 sends [1. 2.], receives [[0. 0. 0. 0. 0. 0.]] + # (pid=23173) rank 0 sends [0. 1.], receives [[0. 1. 1. 2. 2. 3.]] + + +@require_unix +def test_gather(): + process1 = mp.Process(target=worker_gather, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_gather, args=(1,)) + process2.start() + process3 = mp.Process(target=worker_gather, args=(2,)) + process3.start() + + process1.join() + process2.join() + process3.join() + + +def worker_reduce_scatter(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 3) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 3 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(3), store) + + context.connectFullMesh(store, dev) + + sendbuf = np.array( + [i + 1 for i in range(sum([j + 1 for j in range(3)]))], dtype=np.float32 + ) + sendptr = sendbuf.ctypes.data + + recvbuf = np.zeros((rank + 1,), dtype=np.float32) + recvptr = recvbuf.ctypes.data + recvElems = [i + 1 for i in range(3)] + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + op = xp.ReduceOp.SUM + + xp.reduce_scatter(context, sendptr, recvptr, data_size, recvElems, datatype, op) + + # print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") + if rank == 0: + np.testing.assert_array_equal( + recvbuf, + np.array( + [ + 3.0, + ] + ), + ) + elif rank == 1: + np.testing.assert_array_equal(recvbuf, np.array([6.0, 9.0])) + else: + np.testing.assert_array_equal(recvbuf, np.array([12.0, 15.0, 18.0])) + + +@require_linux +def test_reduce_scatter(): + process1 = mp.Process(target=worker_reduce_scatter, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_reduce_scatter, args=(1,)) + process2.start() + process3 = mp.Process(target=worker_reduce_scatter, args=(2,)) + process3.start() + + process1.join() + process2.join() + process3.join() + + +def worker_reduce(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 3) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 3 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(3), store) + + context.connectFullMesh(store, dev) + + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + recvbuf = np.zeros_like(sendbuf, dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + op = xp.ReduceOp.SUM + root = 0 + + xp.reduce(context, sendptr, recvptr, data_size, datatype, op, root) + + if rank == 0: + np.testing.assert_array_equal( + recvbuf, + np.array( + [ + [ + 3.0, + 6.0, + 9.0, + ], + [3.0, 6.0, 9.0], + ] + ), + ) + + +@require_unix +def test_reduce(): + process1 = mp.Process(target=worker_reduce, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_reduce, args=(1,)) + process2.start() + process3 = mp.Process(target=worker_reduce, args=(2,)) + process3.start() + + process1.join() + process2.join() + process3.join() + + +def worker_scatter(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 2 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(2), store) + + context.connectFullMesh(store, dev) + + sendbuf = [np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)] * 2 + recvbuf = np.zeros((2, 3), dtype=np.float32) + sendptr = [] + for i in sendbuf: + sendptr.append(i.ctypes.data) + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf[0].size + if isinstance(sendbuf[0], np.ndarray) + else sendbuf[0].numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + root = 0 + + xp.scatter(context, sendptr, recvptr, data_size, datatype, root) + + np.testing.assert_array_equal(recvbuf, np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])) + ## example output, root is 0. + # (pid=18951) rank 1 sends [array([[1., 2., 3.], + # (pid=18951) [1., 2., 3.]], dtype=float32), array([[1., 2., 3.], + # (pid=18951) [1., 2., 3.]], dtype=float32)], receives [[1. 2. 3.] + # (pid=18951) [1. 2. 3.]] + # (pid=18952) rank 0 sends [array([[1., 2., 3.], + # (pid=18952) [1., 2., 3.]], dtype=float32), array([[1., 2., 3.], + # (pid=18952) [1., 2., 3.]], dtype=float32)], receives [[1. 2. 3.] + # (pid=18952) [1. 2. 3.]] + + +@require_unix +def test_scatter(): + process1 = mp.Process(target=worker_scatter, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_scatter, args=(1,)) + process2.start() + + process1.join() + process2.join() + + +def worker_send_recv(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 2) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 2 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(2), store) + + context.connectFullMesh(store, dev) + + if rank == 0: + sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + sendptr = sendbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + peer = 1 + xp.send(context, sendptr, data_size, datatype, peer) + + elif rank == 1: + recvbuf = np.zeros((2, 3), dtype=np.float32) + recvptr = recvbuf.ctypes.data + + data_size = ( + recvbuf.size if isinstance(recvbuf, np.ndarray) else recvbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + peer = 0 + + xp.recv(context, recvptr, data_size, datatype, peer) + np.testing.assert_array_equal( + recvbuf, np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) + ) + else: + raise Exception( + "Only support 2 process to test send function and recv function" + ) + ## example output + + +@require_unix +def test_send_recv(): + process1 = mp.Process(target=worker_send_recv, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_send_recv, args=(1,)) + process2.start() + + process1.join() + process2.join() + + +def worker_all_to_all(rank): + from .. 
import xoscar_pygloo as xp + + context = xp.rendezvous.Context(rank, 3) + + if system_name == "Linux": + attr = xp.transport.tcp.attr("localhost") + dev = xp.transport.tcp.CreateDevice(attr) + else: + attr = xp.transport.uv.attr("localhost") + dev = xp.transport.uv.CreateDevice(attr) + opt = xp.rendezvous.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 3 + if rank == 0: + opt.isServer = True + else: + opt.isServer = False + + store = xp.rendezvous.TCPStore("127.0.0.1", opt) + store = xp.rendezvous.PrefixStore(str(2), store) + + context.connectFullMesh(store, dev) + + sendbuf = np.zeros((6,), dtype=np.float32) + rank + recvbuf = np.zeros(sendbuf.shape, dtype=np.float32) + sendptr = sendbuf.ctypes.data + recvptr = recvbuf.ctypes.data + + data_size = ( + sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size + ) + datatype = xp.GlooDataType_t.glooFloat32 + + xp.all_to_all(context, sendptr, recvptr, data_size, datatype) + + np.testing.assert_array_equal(recvbuf, np.array([0.0, 0.0, 1.0, 1.0, 2.0, 2.0])) + + +@require_unix +def test_all_to_all(): + process1 = mp.Process(target=worker_all_to_all, args=(0,)) + process1.start() + process2 = mp.Process(target=worker_all_to_all, args=(1,)) + process2.start() + process3 = mp.Process(target=worker_all_to_all, args=(2,)) + process3.start() + + process1.join() + process2.join() + process3.join() diff --git a/python/xoscar/collective/xoscar_pygloo.pyi b/python/xoscar/collective/xoscar_pygloo.pyi new file mode 100644 index 00000000..7fb52c92 --- /dev/null +++ b/python/xoscar/collective/xoscar_pygloo.pyi @@ -0,0 +1,239 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from ctypes import c_void_p +from enum import IntEnum +from typing import Callable, List, Optional + +import xoscar_pygloo + +class ReduceOp(IntEnum): + SUM = 0 + PRODUCT = 1 + MIN = 2 + MAX = 3 + BAND = 4 + BOR = 5 + BXOR = 6 + UNUSED = 7 + +class GlooDataType_t(IntEnum): + glooInt8 = 0 + glooUint8 = 1 + glooInt32 = 2 + glooUint32 = 3 + glooInt64 = 4 + glooUint64 = 5 + glooFloat16 = 6 + glooFloat32 = 7 + glooFloat64 = 8 + +class AllreduceAlgorithm(IntEnum): + UNSPECIFIED = 0 + RING = 1 + BCUBE = 2 + +def transport_tcp_available() -> bool: ... +def transport_uv_available() -> bool: ... + +class Context: + rank: Optional[int] = None + size: Optional[int] = None + base: Optional[int] = 2 + def getDevice(self) -> int: ... + def createUnboundBuffer(self, ptr: c_void_p, size: int): ... + def nextSlot(self, numToSkip: int) -> int: ... + def closeConnections(self) -> None: ... + def setTimeout(self, timeout: datetime.timedelta) -> None: ... + def getTimeout(self) -> datetime.timedelta: ... + +def allreduce( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + reduceop: Optional[ReduceOp] = ReduceOp.SUM, + algorithm: Optional[AllreduceAlgorithm] = AllreduceAlgorithm.RING, + tag: int = 0, +) -> None: ... 
+def allgather( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + tag: Optional[int] = 0, +) -> None: ... +def all_to_all( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + tag: Optional[int] = 0, +) -> None: ... +def allgatherv( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + tag: Optional[int] = 0, +) -> None: ... +def reduce( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + reduceop: Optional[ReduceOp] = ReduceOp.SUM, + root: Optional[int] = 0, + tag: Optional[int] = 0, +) -> None: ... +def scatter( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + root: Optional[int] = 0, + tag: Optional[int] = 0, +) -> None: ... +def gather( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + root: Optional[int] = 0, + tag: Optional[int] = 0, +) -> None: ... +def send( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + peer: Optional[int] = None, + tag: Optional[int] = 0, +) -> None: ... +def recv( + context: Optional[Context] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + peer: Optional[int] = None, + tag: Optional[int] = 0, +) -> None: ... 
+def broadcast( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + datatype: Optional[GlooDataType_t] = None, + root: Optional[int] = 0, + tag: Optional[int] = 0, +) -> None: ... +def reduce_scatter( + context: Optional[Context] = None, + sendbuf: Optional[int] = None, + recvbuf: Optional[int] = None, + size: Optional[int] = None, + recvElems: Optional[List[int]] = None, + datatype: Optional[GlooDataType_t] = None, + reduceop: Optional[ReduceOp] = ReduceOp.SUM, +) -> None: ... +def barrier(context: Optional[Context] = None, tag: Optional[int] = 0) -> None: ... + +class rendezvous: + class Store: + def set(self, key: str, data: List[str]) -> None: ... + def get(self, key: str) -> str: ... + + class TCPStoreOptions: + port: int + isServer: bool + numWorkers: Optional[int] + waitWorkers: bool + timeout: datetime.timedelta + multiTenant: bool + + class TCPStore: + def __init__( + self, + host: str, + opts: rendezvous.TCPStoreOptions = rendezvous.TCPStoreOptions(), + ): ... + def set(self, key: str, value: bytes): ... + def get(self, key: str) -> bytes: ... + def wait(self, keys: List[str]): ... + + class FileStore(Store): + def __init__(self, path: str) -> None: ... + def set(self, key: str, data: List[str]) -> None: ... + def get(self, key: str) -> str: ... + + class HashStore(Store): + def __init__(self) -> None: ... + def set(self, key: str, data: List[str]) -> None: ... + def get(self, key: str) -> str: ... + + class PrefixStore(Store): + def __init__(self, prefix: str, store: rendezvous.Store) -> None: ... + def set(self, key: str, data: List[str]) -> None: ... + def get(self, key: str) -> str: ... + + class CustomStore(Store): + def __init__(self, real_store_py_object: object) -> None: ... + def delKeys(self, keys: List[str]) -> None: ... + def set(self, key: str, data: List[str]) -> None: ... + def get(self, key: str) -> str: ... 
+ + class Context(xoscar_pygloo.Context): + def connectFullMesh( + self, store: rendezvous.Store, dev: transport.Device + ) -> None: ... + +class transport: + class uv: + pass + + class tcp: + class Device(transport.Device): + def __init__(self, attr: transport.tcp.attr) -> None: ... + + class Context(xoscar_pygloo.Context): + def __init__( + self, device: transport.tcp.Device, rank: int, size: int + ) -> None: ... + + class attr: + hostname: str + iface: str + ai_family: int + ai_socktype: int + ai_protocol: int + ai_addr: object + ai_addrlen: int + def __init__(self, string: Optional[str] = None) -> None: ... + + def CreateDevice(self, src: transport.tcp.attr) -> transport.Device: ... + + class Device: + def __str__(self) -> str: ... + def getPCIBusID(self) -> Callable[[], str]: ... + def getInterfaceSpeed(self) -> int: ... + def hasGPUDirect(self) -> bool: ... + def createContext(self, rank: int, size: int): ... diff --git a/python/xoscar/tests/core.py b/python/xoscar/tests/core.py index 3bcdff39..538fa7f7 100644 --- a/python/xoscar/tests/core.py +++ b/python/xoscar/tests/core.py @@ -17,7 +17,7 @@ import pytest -from ..utils import is_windows, lazy_import +from ..utils import is_linux, is_windows, lazy_import cupy = lazy_import("cupy") cudf = lazy_import("cudf") @@ -53,6 +53,14 @@ def require_unix(func): return func +def require_linux(func): + if pytest: + func = pytest.mark.linux(func) + + func = pytest.mark.skipif(not is_linux(), reason="only linux is supported")(func) + return func + + DICT_NOT_EMPTY = type("DICT_NOT_EMPTY", (object,), {}) # is check works for deepcopy diff --git a/python/xoscar/utils.py b/python/xoscar/utils.py index 3a5f3a54..04e19ec0 100644 --- a/python/xoscar/utils.py +++ b/python/xoscar/utils.py @@ -458,3 +458,7 @@ def is_cuda_buffer(cuda_buffer: Union["_cupy.ndarray", "_rmm.DeviceBuffer"]) -> def is_windows(): return sys.platform.startswith("win") + + +def is_linux(): + return sys.platform.startswith("linux") diff --git 
a/third_party/gloo b/third_party/gloo new file mode 160000 index 00000000..c6f3a5bc --- /dev/null +++ b/third_party/gloo @@ -0,0 +1 @@ +Subproject commit c6f3a5bcf568dafc9a8ae482e8cc900633dd6db1