From 2295e3380c52b08f4554bb54ef24cee65bc6fcac Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 27 Jun 2023 15:40:46 +0800 Subject: [PATCH 01/19] init --- .clang-format | 21 + .cmake-format.yaml | 73 + .gitignore | 7 + .gitmodules | 6 + .pre-commit-config.yaml | 28 +- CMakeLists.txt | 16 + CPPLINT.cfg | 1 + cpp/CMakeLists.txt | 13 + cpp/collective/rendezvous/CMakeLists.txt | 25 + cpp/collective/rendezvous/include/error.h | 68 + cpp/collective/rendezvous/include/exception.h | 51 + cpp/collective/rendezvous/include/socket.h | 108 ++ cpp/collective/rendezvous/include/store.hpp | 119 ++ .../rendezvous/include/tcp_store.hpp | 136 ++ .../rendezvous/include/unix_sock_utils.hpp | 39 + cpp/collective/rendezvous/include/utils.hpp | 205 +++ .../rendezvous/src/bind_tcp_store.cpp | 59 + cpp/collective/rendezvous/src/exception.cpp | 22 + cpp/collective/rendezvous/src/socket.cpp | 919 ++++++++++++ cpp/collective/rendezvous/src/store.cpp | 84 ++ cpp/collective/rendezvous/src/tcp_store.cpp | 1321 +++++++++++++++++ python/setup.py | 44 + python/xoscar/collective/__init__.py | 13 + .../xoscar/collective/rendezvous/__init__.py | 13 + .../collective/rendezvous/test/__init__.py | 13 + .../rendezvous/test/test_tcp_store.py | 61 + .../collective/rendezvous/xoscar_store.pyi | 29 + third_party/fmt | 1 + third_party/pybind11 | 1 + 29 files changed, 3495 insertions(+), 1 deletion(-) create mode 100644 .clang-format create mode 100644 .cmake-format.yaml create mode 100644 .gitmodules create mode 100644 CMakeLists.txt create mode 100644 CPPLINT.cfg create mode 100644 cpp/CMakeLists.txt create mode 100644 cpp/collective/rendezvous/CMakeLists.txt create mode 100644 cpp/collective/rendezvous/include/error.h create mode 100644 cpp/collective/rendezvous/include/exception.h create mode 100644 cpp/collective/rendezvous/include/socket.h create mode 100644 cpp/collective/rendezvous/include/store.hpp create mode 100644 cpp/collective/rendezvous/include/tcp_store.hpp create mode 100644 cpp/collective/rendezvous/include/unix_sock_utils.hpp create mode 100644 cpp/collective/rendezvous/include/utils.hpp create mode 100644 cpp/collective/rendezvous/src/bind_tcp_store.cpp create mode 100644 cpp/collective/rendezvous/src/exception.cpp create mode 100644 cpp/collective/rendezvous/src/socket.cpp create mode 100644 cpp/collective/rendezvous/src/store.cpp create mode 100644 cpp/collective/rendezvous/src/tcp_store.cpp create mode 100644 python/xoscar/collective/__init__.py create mode 100644 python/xoscar/collective/rendezvous/__init__.py create mode 100644 python/xoscar/collective/rendezvous/test/__init__.py create mode 100644 python/xoscar/collective/rendezvous/test/test_tcp_store.py create mode 100644 python/xoscar/collective/rendezvous/xoscar_store.pyi create mode 160000 third_party/fmt create mode 160000 third_party/pybind11 diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..e4d4dec4 --- /dev/null +++ b/.clang-format @@ -0,0 +1,21 @@ +# See all possible options and defaults with: +# clang-format --style=llvm --dump-config +BasedOnStyle: LLVM +AccessModifierOffset: -4 +AllowShortLambdasOnASingleLine: Inline +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BreakBeforeBinaryOperators: All +BreakConstructorInitializers: BeforeColon +ColumnLimit: 80 +SpacesBeforeTrailingComments: 2 +IncludeBlocks: Regroup +IndentCaseLabels: true +IndentPPDirectives: AfterHash +IndentWidth: 4 +Language: Cpp +SpaceAfterCStyleCast: true +Standard: c++20 +StatementMacros: ['PyObject_HEAD'] +TabWidth: 4 diff --git a/.cmake-format.yaml b/.cmake-format.yaml new file mode 100644 index 00000000..fd097e58 --- /dev/null +++ b/.cmake-format.yaml @@ -0,0 +1,73 @@ +parse: + additional_commands: + pybind11_add_module: + flags: + - THIN_LTO + - MODULE + - SHARED + - NO_EXTRAS + - EXCLUDE_FROM_ALL + - SYSTEM + +format: + line_width: 99 + tab_size: 2 + + # If an argument group contains more than this many sub-groups + # (parg or kwarg groups) then force it to a vertical layout. + max_subgroups_hwrap: 2 + + # If a positional argument group contains more than this many + # arguments, then force it to a vertical layout. + max_pargs_hwrap: 6 + + # If a cmdline positional group consumes more than this many + # lines without nesting, then invalidate the layout (and nest) + max_rows_cmdline: 2 + separate_ctrl_name_with_space: false + separate_fn_name_with_space: false + dangle_parens: false + + # If the trailing parenthesis must be 'dangled' on its on + # 'line, then align it to this reference: `prefix`: the start' + # 'of the statement, `prefix-indent`: the start of the' + # 'statement, plus one indentation level, `child`: align to' + # the column of the arguments + dangle_align: prefix + # If the statement spelling length (including space and + # parenthesis) is smaller than this amount, then force reject + # nested layouts. + min_prefix_chars: 4 + + # If the statement spelling length (including space and + # parenthesis) is larger than the tab width by more than this + # amount, then force reject un-nested layouts. + max_prefix_chars: 10 + + # If a candidate layout is wrapped horizontally but it exceeds + # this many lines, then reject the layout. + max_lines_hwrap: 2 + + line_ending: unix + + # Format command names consistently as 'lower' or 'upper' case + command_case: canonical + + # Format keywords consistently as 'lower' or 'upper' case + # unchanged is valid too + keyword_case: 'upper' + + # A list of command names which should always be wrapped + always_wrap: [] + + # If true, the argument lists which are known to be sortable + # will be sorted lexicographically + enable_sort: true + + # If true, the parsers may infer whether or not an argument + # list is sortable (without annotation). + autosort: false + +# Causes a few issues - can be solved later, possibly. +markup: + enable_markup: false \ No newline at end of file diff --git a/.gitignore b/.gitignore index 34bbb8c3..1547d4ef 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,10 @@ dmypy.json # cython compiled files python/xoscar/**/*.c* + +# cmake +cmake-* +CMakeFiles +CMakeCache.txt +*.cmake +Makefile diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..fd30701d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "third_party/fmt"] + path = third_party/fmt + url = https://github.com/fmtlib/fmt.git +[submodule "third_party/pybind11"] + path = third_party/pybind11 + url = https://github.com/pybind/pybind11.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c9b9f1a3..29d27bfb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,32 +1,58 @@ -files: python/xoscar repos: - repo: https://github.com/psf/black rev: 23.1.0 hooks: - id: black + files: python/xoscar - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - id: end-of-file-fixer + files: python/xoscar - id: trailing-whitespace + files: python/xoscar - repo: https://github.com/PyCQA/flake8 rev: 6.0.0 hooks: - id: flake8 args: [--config, python/setup.cfg] + files: python/xoscar - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort args: [--sp, python/setup.cfg] + files: python/xoscar - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.3.0 hooks: - id: mypy additional_dependencies: [tokenize-rt==3.2.0] args: [--ignore-missing-imports, --follow-imports, skip] + files: python/xoscar - repo: https://github.com/codespell-project/codespell rev: v2.2.2 hooks: - id: codespell args: [ --config, python/setup.cfg] + files: python/xoscar + + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: "v15.0.7" + hooks: + - id: clang-format + files: cpp + + - repo: https://github.com/cheshirekow/cmake-format-precommit + rev: "v0.6.13" + hooks: + - id: cmake-format + additional_dependencies: [ pyyaml ] + types: [ file ] + files: (\.cmake|CMakeLists.txt)(.in)?$ + + - repo: https://github.com/pocc/pre-commit-hooks + rev: v1.3.5 + hooks: + - id: cpplint + files: cpp diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..d6accd7c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.11...3.21) + +project(XoscarCollective) + +if(NOT ${PYTHON_EXECUTABLE}) + find_package(Python COMPONENTS Interpreter Development) +endif() + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") + +include_directories(${CMAKE_SOURCE_DIR}) + +add_subdirectory(third_party/fmt) +add_subdirectory(third_party/pybind11) +add_subdirectory(cpp) diff --git a/CPPLINT.cfg b/CPPLINT.cfg new file mode 100644 index 00000000..dd15eb40 --- /dev/null +++ b/CPPLINT.cfg @@ -0,0 +1 @@ +filter=-build/c++11,-build/include_subdir,-build/include_order,-build/include_what_you_use,-readability/todo,-readability/nolint,-runtime/int,-runtime/references,-whitespace/indent \ No newline at end of file diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt new file mode 100644 index 00000000..569ce1fd --- /dev/null +++ b/cpp/CMakeLists.txt @@ -0,0 +1,13 @@ +cmake_minimum_required(VERSION 3.11...3.21) + +project(XoscarCollective) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") + +include_directories(${CMAKE_SOURCE_DIR}/cpp/collective/rendezvous/include) + +add_subdirectory(collective/rendezvous) + +pybind11_add_module(xoscar_store collective/rendezvous/src/bind_tcp_store.cpp) +target_link_libraries(xoscar_store PRIVATE StoreLib fmt::fmt) +set_target_properties(xoscar_store PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_DIRECTORY}) diff --git a/cpp/collective/rendezvous/CMakeLists.txt b/cpp/collective/rendezvous/CMakeLists.txt new file mode 100644 index 00000000..2757139d --- /dev/null +++ b/cpp/collective/rendezvous/CMakeLists.txt @@ -0,0 +1,25 @@ +cmake_minimum_required(VERSION 3.11...3.21) + +project( + XoscarRendezvous + VERSION 0.0.1 + LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) + +include_directories(include) +include_directories(../../../third_party/fmt/include) + +add_library( + StoreLib + include/error.h + include/exception.h + src/exception.cpp + include/socket.h + src/socket.cpp + include/store.hpp + src/store.cpp + include/tcp_store.hpp + src/tcp_store.cpp + include/unix_sock_utils.hpp + include/utils.hpp) diff --git a/cpp/collective/rendezvous/include/error.h b/cpp/collective/rendezvous/include/error.h new file mode 100644 index 00000000..b29ffb14 --- /dev/null +++ b/cpp/collective/rendezvous/include/error.h @@ -0,0 +1,68 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "third_party/fmt/include/fmt/format.h" + +#include +#include + +namespace fmt { + +template <> +struct formatter { + constexpr decltype(auto) parse(format_parse_context &ctx) const { + return ctx.begin(); + } + + template + decltype(auto) format(const std::error_category &cat, + FormatContext &ctx) const { + if (std::strcmp(cat.name(), "generic") == 0) { + return format_to(ctx.out(), "errno"); + } else { + return format_to(ctx.out(), "{} error", cat.name()); + } + } +}; + +template <> +struct formatter { + constexpr decltype(auto) parse(format_parse_context &ctx) const { + return ctx.begin(); + } + + template + decltype(auto) format(const std::error_code &err, + FormatContext &ctx) const { + return format_to(ctx.out(), + "({}: {} - {})", + err.category(), + err.value(), + err.message()); + } +}; + +} // namespace fmt + +namespace xoscar { +namespace detail { + +inline std::error_code lastError() noexcept { + return std::error_code{errno, std::generic_category()}; +} + +} // namespace detail +} // namespace xoscar diff --git a/cpp/collective/rendezvous/include/exception.h b/cpp/collective/rendezvous/include/exception.h new file mode 100644 index 00000000..0c8bc01d --- /dev/null +++ b/cpp/collective/rendezvous/include/exception.h @@ -0,0 +1,51 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +namespace xoscar { + +class XoscarError : public std::runtime_error { +public: + using std::runtime_error::runtime_error; + + XoscarError(const XoscarError &) = default; + + XoscarError &operator=(const XoscarError &) = default; + + XoscarError(XoscarError &&) = default; + + XoscarError &operator=(XoscarError &&) = default; + + ~XoscarError() override; +}; + +class TimeoutError : public XoscarError { +public: + using XoscarError::XoscarError; + + TimeoutError(const TimeoutError &) = default; + + TimeoutError &operator=(const TimeoutError &) = default; + + TimeoutError(TimeoutError &&) = default; + + TimeoutError &operator=(TimeoutError &&) = default; + + ~TimeoutError() override; +}; + +} // namespace xoscar diff --git a/cpp/collective/rendezvous/include/socket.h b/cpp/collective/rendezvous/include/socket.h new file mode 100644 index 00000000..e39c1d81 --- /dev/null +++ b/cpp/collective/rendezvous/include/socket.h @@ -0,0 +1,108 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "exception.h" + +#include +#include +#include +#include +#include + +namespace xoscar { +namespace detail { + +class SocketOptions { +public: + SocketOptions &prefer_ipv6(bool value) noexcept { + prefer_ipv6_ = value; + + return *this; + } + + bool prefer_ipv6() const noexcept { return prefer_ipv6_; } + + SocketOptions &connect_timeout(std::chrono::seconds value) noexcept { + connect_timeout_ = value; + + return *this; + } + + std::chrono::seconds connect_timeout() const noexcept { + return connect_timeout_; + } + +private: + bool prefer_ipv6_ = true; + std::chrono::seconds connect_timeout_{30}; +}; + +class SocketImpl; + +class Socket { +public: + // This function initializes the underlying socket library and must be + // called before any other socket function. + static void initialize(); + + static Socket listen(std::uint16_t port, const SocketOptions &opts = {}); + + static Socket connect(const std::string &host, + std::uint16_t port, + const SocketOptions &opts = {}); + + Socket() noexcept = default; + + Socket(const Socket &other) = delete; + + Socket &operator=(const Socket &other) = delete; + + Socket(Socket &&other) noexcept; + + Socket &operator=(Socket &&other) noexcept; + + ~Socket(); + + Socket accept() const; + + int handle() const noexcept; + + std::uint16_t port() const; + +private: + explicit Socket(std::unique_ptr &&impl) noexcept; + + std::unique_ptr impl_; +}; + +} // namespace detail + +class SocketError : public XoscarError { +public: + using XoscarError::XoscarError; + + SocketError(const SocketError &) = default; + + SocketError &operator=(const SocketError &) = default; + + SocketError(SocketError &&) = default; + + SocketError &operator=(SocketError &&) = default; + + ~SocketError() override; +}; + +} // namespace xoscar diff --git a/cpp/collective/rendezvous/include/store.hpp b/cpp/collective/rendezvous/include/store.hpp new file mode 100644 index 00000000..be4e943b --- /dev/null +++ b/cpp/collective/rendezvous/include/store.hpp @@ -0,0 +1,119 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace xoscar { + +// callback function will be given arguments (optional oldValue, +// optional newValue) +using WatchKeyCallback = std::function, + std::optional)>; + +class Store { +public: + static constexpr std::chrono::milliseconds kDefaultTimeout + = std::chrono::seconds(300); + static constexpr std::chrono::milliseconds kNoTimeout + = std::chrono::milliseconds::zero(); + + Store() : timeout_(kDefaultTimeout) {} + + explicit Store(const std::chrono::milliseconds &timeout) + : timeout_(timeout) {} + + ~Store(); + + void set(const std::string &key, const std::string &value); + + virtual void set(const std::string &key, const std::vector &value) + = 0; + + std::string compareSet(const std::string &key, + const std::string ¤tValue, + const std::string &newValue); + + virtual std::vector + compareSet(const std::string &key, + const std::vector ¤tValue, + const std::vector &newValue) { + // TORCH_INTERNAL_ASSERT(false, "Not implemented."); + throw std::runtime_error("Not implemented."); + } + + std::string get_to_str(const std::string &key); + + virtual std::vector get(const std::string &key) = 0; + + virtual int64_t add(const std::string &key, int64_t value) = 0; + + virtual bool deleteKey(const std::string &key) = 0; + + virtual bool check(const std::vector &keys) = 0; + + virtual int64_t getNumKeys() = 0; + + virtual void wait(const std::vector &keys) = 0; + + virtual void wait(const std::vector &keys, + const std::chrono::milliseconds &timeout) + = 0; + + virtual const std::chrono::milliseconds &getTimeout() const noexcept; + + virtual void setTimeout(const std::chrono::milliseconds &timeout); + + // watchKey() takes two arguments: key and callback function. The callback + // should be run whenever the key is changed (create, update, or delete). + // The callback function takes two parameters: currentValue and newValue, + // which are optional depending on how the key is changed. These key updates + // should trigger the callback as follows: CREATE: callback(c10::nullopt, + // newValue) // null currentValue UPDATE: callback(currentValue, newValue) + // DELETE: callback(currentValue, c10::nullopt) // null newValue + virtual void watchKey(const std::string & /* unused */, + WatchKeyCallback /* unused */) { + // TORCH_CHECK( + // false, + // "watchKey only implemented for TCPStore and PrefixStore that + // wraps TCPStore."); + throw std::runtime_error("watchKey only implemented for TCPStore and " + "PrefixStore that wraps TCPStore."); + } + + virtual void append(const std::string &key, + const std::vector &value); + + virtual std::vector> + multiGet(const std::vector &keys); + + virtual void multiSet(const std::vector &keys, + const std::vector> &values); + + // Returns true if this store support watchKey, append, multiGet and + // multiSet + virtual bool hasExtendedApi() const; + +protected: + std::chrono::milliseconds timeout_; +}; + +} // namespace xoscar diff --git a/cpp/collective/rendezvous/include/tcp_store.hpp b/cpp/collective/rendezvous/include/tcp_store.hpp new file mode 100644 index 00000000..95b5c205 --- /dev/null +++ b/cpp/collective/rendezvous/include/tcp_store.hpp @@ -0,0 +1,136 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include "store.hpp" + +#include +#include +#include +#include + +namespace xoscar { +namespace detail { + +class TCPServer; + +class TCPClient; + +class TCPCallbackClient; + +struct SocketAddress { + std::string host{}; + std::uint16_t port{}; +}; + +} // namespace detail + +struct TCPStoreOptions { + static constexpr std::uint16_t kDefaultPort = 29500; + + std::uint16_t port = kDefaultPort; + bool isServer = false; + std::optional numWorkers = std::nullopt; + bool waitWorkers = true; + std::chrono::milliseconds timeout = Store::kDefaultTimeout; + + // A boolean value indicating whether multiple store instances can be + // initialized with the same host:port pair. + bool multiTenant = false; +}; + +class TCPStore : public Store { +public: + explicit TCPStore(std::string host, const TCPStoreOptions &opts = {}); + + [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore( + const std::string &masterAddr, + std::uint16_t masterPort, + std::optional numWorkers = std::nullopt, + bool isServer = false, + const std::chrono::milliseconds &timeout = kDefaultTimeout, + bool waitWorkers = true); + + ~TCPStore(); + + void set(const std::string &key, + const std::vector &value) override; + + std::vector + compareSet(const std::string &key, + const std::vector &expectedValue, + const std::vector &desiredValue) override; + + std::vector get(const std::string &key) override; + + int64_t add(const std::string &key, int64_t value) override; + + bool deleteKey(const std::string &key) override; + + // NOTE: calling other TCPStore APIs inside the callback is NOT threadsafe + // watchKey() is a blocking operation. It will register the socket on + // TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will + // return once it has verified the callback is registered on both background + // threads. Only one thread can call watchKey() at a time. + void watchKey(const std::string &key, WatchKeyCallback callback) override; + + bool check(const std::vector &keys) override; + + int64_t getNumKeys() override; + + void wait(const std::vector &keys) override; + + void wait(const std::vector &keys, + const std::chrono::milliseconds &timeout) override; + + void append(const std::string &key, + const std::vector &value) override; + + std::vector> + multiGet(const std::vector &keys) override; + + void multiSet(const std::vector &keys, + const std::vector> &values) override; + + bool hasExtendedApi() const override; + + // Waits for all workers to join. + void waitForWorkers(); + + // Returns the hostname used by the TCPStore. + const std::string &getHost() const noexcept { return addr_.host; } + + // Returns the port used by the TCPStore. + std::uint16_t getPort() const noexcept { return addr_.port; } + +private: + int64_t incrementValueBy(const std::string &key, int64_t delta); + + std::vector doGet(const std::string &key); + + void doWait(std::vector keys, + std::chrono::milliseconds timeout); + + detail::SocketAddress addr_; + std::shared_ptr server_; + std::unique_ptr client_; + std::unique_ptr callbackClient_; + std::optional numWorkers_; + + const std::string initKey_ = "init/"; + const std::string keyPrefix_ = "/"; + std::mutex activeOpLock_; +}; + +} // namespace xoscar diff --git a/cpp/collective/rendezvous/include/unix_sock_utils.hpp b/cpp/collective/rendezvous/include/unix_sock_utils.hpp new file mode 100644 index 00000000..de991b6a --- /dev/null +++ b/cpp/collective/rendezvous/include/unix_sock_utils.hpp @@ -0,0 +1,39 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include "utils.hpp" + +#include +#include + +namespace xoscar::tcputil { + +#define CONNECT_SOCKET_OFFSET 2 + +inline int poll(struct pollfd *fds, unsigned long nfds, int timeout) { + return ::poll(fds, nfds, timeout); +} + +inline void +addPollfd(std::vector &fds, int socket, short events) { + fds.push_back({.fd = socket, .events = events}); +} + +inline struct ::pollfd getPollfd(int socket, short events) { + struct ::pollfd res = {.fd = socket, .events = events}; + return res; +} + +} // namespace xoscar::tcputil diff --git a/cpp/collective/rendezvous/include/utils.hpp b/cpp/collective/rendezvous/include/utils.hpp new file mode 100644 index 00000000..4468058a --- /dev/null +++ b/cpp/collective/rendezvous/include/utils.hpp @@ -0,0 +1,205 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace xoscar { + +using RankType = uint32_t; +using SizeType = uint64_t; + +// `errno` is only meaningful when it fails. E.g., a successful `fork()` sets +// `errno` to `EINVAL` in child process on some macos +// (https://stackoverflow.com/a/20295079), and thus `errno` should really only +// be inspected if an error occurred. +// +// `success_cond` is an expression used to check if an error has happend. So for +// `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function output +// is stored in variable `__output` and may be used in `success_cond`. +#ifdef _WIN32 +# define SYSCHECK(expr, success_cond) \ + while (true) { \ + auto __output = (expr); \ + auto errno_local = WSAGetLastError(); \ + (void) __output; \ + if (!(success_cond)) { \ + if (errno == EINTR) { \ + continue; \ + } else if (errno_local == WSAETIMEDOUT \ + || errno_local == WSAEWOULDBLOCK) { \ + TORCH_CHECK(false, "Socket Timeout"); \ + } else { \ + throw std::system_error(errno_local, \ + std::system_category()); \ + } \ + } else { \ + break; \ + } \ + } +#else +# define SYSCHECK(expr, success_cond) \ + while (true) { \ + auto __output = (expr); \ + (void) __output; \ + if (!(success_cond)) { \ + if (errno == EINTR) { \ + continue; \ + } else if (errno == EAGAIN || errno == EWOULDBLOCK) { \ + throw std::runtime_error("Socket Timeout"); \ + } else { \ + throw std::system_error(errno, std::system_category()); \ + } \ + } else { \ + break; \ + } \ + } +#endif + +// Most functions indicate error by returning `-1`. This is a helper macro for +// this common case with `SYSCHECK`. +// Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1 +#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1) + +namespace tcputil { +// Send and receive +template +void sendBytes(int socket, + const T *buffer, + size_t length, + bool moreData = false) { + size_t bytesToSend = sizeof(T) * length; + if (bytesToSend == 0) { + return; + } + + auto bytes = reinterpret_cast(buffer); + uint8_t *currentBytes = const_cast(bytes); + + int flags = 0; + +#ifdef MSG_MORE + if (moreData) { // there is more data to send + flags |= MSG_MORE; + } +#endif + +// Ignore SIGPIPE as the send() return value is always checked for error +#ifdef MSG_NOSIGNAL + flags |= MSG_NOSIGNAL; +#endif + + while (bytesToSend > 0) { + ssize_t bytesSent; + SYSCHECK_ERR_RETURN_NEG1( + bytesSent + = ::send(socket, (const char *) currentBytes, bytesToSend, flags)) + if (bytesSent == 0) { + throw std::system_error(ECONNRESET, std::system_category()); + } + + bytesToSend -= bytesSent; + currentBytes += bytesSent; + } +} + +template +void recvBytes(int socket, T *buffer, size_t length) { + size_t bytesToReceive = sizeof(T) * length; + if (bytesToReceive == 0) { + return; + } + + auto bytes = reinterpret_cast(buffer); + uint8_t *currentBytes = bytes; + + while (bytesToReceive > 0) { + ssize_t bytesReceived; + SYSCHECK_ERR_RETURN_NEG1(bytesReceived + = recv(socket, + reinterpret_cast(currentBytes), + bytesToReceive, + 0)) + if (bytesReceived == 0) { + throw std::system_error(ECONNRESET, std::system_category()); + } + + bytesToReceive -= bytesReceived; + currentBytes += bytesReceived; + } +} + +// send a vector's length and data +template +void sendVector(int socket, const std::vector &vec, bool moreData = false) { + SizeType size = vec.size(); + sendBytes(socket, &size, 1, true); + sendBytes(socket, vec.data(), size, moreData); +} + +// receive a vector as sent in sendVector +template +std::vector recvVector(int socket) { + SizeType valueSize; + recvBytes(socket, &valueSize, 1); + std::vector value(valueSize); + recvBytes(socket, value.data(), value.size()); + return value; +} + +// this is only for convenience when sending rvalues +template +void sendValue(int socket, const T &value, bool moreData = false) { + sendBytes(socket, &value, 1, moreData); +} + +template +T recvValue(int socket) { + T value; + recvBytes(socket, &value, 1); + return value; +} + +// send a string's length and data +inline void +sendString(int socket, const std::string &str, bool moreData = false) { + SizeType size = str.size(); + sendBytes(socket, &size, 1, true); + sendBytes(socket, str.data(), size, moreData); +} + +// receive a string as sent in sendString +inline std::string recvString(int socket) { + SizeType valueSize; + recvBytes(socket, &valueSize, 1); + std::vector value(valueSize); + recvBytes(socket, value.data(), value.size()); + return std::string(value.data(), value.size()); +} +} // namespace tcputil +} // namespace xoscar diff --git a/cpp/collective/rendezvous/src/bind_tcp_store.cpp b/cpp/collective/rendezvous/src/bind_tcp_store.cpp new file mode 100644 index 00000000..019e0b02 --- /dev/null +++ b/cpp/collective/rendezvous/src/bind_tcp_store.cpp @@ -0,0 +1,59 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include "tcp_store.hpp" + +#include +#include +#include + +namespace py = pybind11; + +namespace xoscar { +PYBIND11_MODULE(xoscar_store, m) { + py::class_(m, "TCPStoreOptions") + .def(py::init()) + .def_readwrite("port", &TCPStoreOptions::port) + .def_readwrite("isServer", &TCPStoreOptions::isServer) + .def_readwrite("numWorkers", &TCPStoreOptions::numWorkers) + .def_readwrite("waitWorkers", &TCPStoreOptions::waitWorkers) + .def_readwrite("timeout", &TCPStoreOptions::timeout) + .def_readwrite("multiTenant", &TCPStoreOptions::multiTenant); + + py::class_(m, "Store"); + + py::class_(m, "TCPStore") + .def(py::init()) + .def("wait", + py::overload_cast &>( + &TCPStore::wait)) + .def("wait", + py::overload_cast &, + const std::chrono::milliseconds &>( + &TCPStore::wait)) + .def("set", + [](TCPStore &self, const std::string &key, py::bytes &bytes) { + const py::buffer_info info(py::buffer(bytes).request()); + const char *data = reinterpret_cast(info.ptr); + auto length = static_cast(info.size); + self.set(key, std::vector(data, data + length)); + }) + .def("get", [](TCPStore &self, const std::string &key) { + auto result = self.get(key); + const std::string str_result(result.begin(), result.end()); + return py::bytes(str_result); + }); +} +} // namespace xoscar diff --git a/cpp/collective/rendezvous/src/exception.cpp b/cpp/collective/rendezvous/src/exception.cpp new file mode 100644 index 00000000..4e690f70 --- /dev/null +++ b/cpp/collective/rendezvous/src/exception.cpp @@ -0,0 +1,22 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "exception.h" + +namespace xoscar { + +XoscarError::~XoscarError() = default; + +TimeoutError::~TimeoutError() = default; + +} // namespace xoscar diff --git a/cpp/collective/rendezvous/src/socket.cpp b/cpp/collective/rendezvous/src/socket.cpp new file mode 100644 index 00000000..244adcfa --- /dev/null +++ b/cpp/collective/rendezvous/src/socket.cpp @@ -0,0 +1,919 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "socket.h" + +#include "error.h" +#include "exception.h" +#include "fmt/chrono.h" + +#include +#include +#include + +#ifdef _WIN32 +# include +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif + +namespace xoscar { +namespace detail { +namespace { +#ifdef _WIN32 + +// Since Winsock uses the name `WSAPoll` instead of `poll`, we alias it here +// to avoid #ifdefs in the source code. +const auto pollFd = ::WSAPoll; + +// Winsock's `getsockopt()` and `setsockopt()` functions expect option values to +// be passed as `char*` instead of `void*`. We wrap them here to avoid redundant +// casts in the source code. +int getSocketOption( + SOCKET s, int level, int optname, void *optval, int *optlen) { + return ::getsockopt(s, level, optname, static_cast(optval), optlen); +} + +int setSocketOption( + SOCKET s, int level, int optname, const void *optval, int optlen) { + return ::setsockopt( + s, level, optname, static_cast(optval), optlen); +} + +// Winsock has its own error codes which differ from Berkeley's. Fortunately the +// C++ Standard Library on Windows can map them to standard error codes. +inline std::error_code getSocketError() noexcept { + return std::error_code{::WSAGetLastError(), std::system_category()}; +} + +inline void setSocketError(int val) noexcept { ::WSASetLastError(val); } + +#else + +const auto pollFd = ::poll; + +const auto getSocketOption = ::getsockopt; +const auto setSocketOption = ::setsockopt; + +inline std::error_code getSocketError() noexcept { return lastError(); } + +inline void setSocketError(int val) noexcept { errno = val; } + +#endif + +// Suspends the current thread for the specified duration. +void delay(std::chrono::seconds d) { +#ifdef _WIN32 + std::this_thread::sleep_for(d); +#else + ::timespec req{}; + req.tv_sec = d.count(); + + // The C++ Standard does not specify whether `sleep_for()` should be signal- + // aware; therefore, we use the `nanosleep()` syscall. + if (::nanosleep(&req, nullptr) != 0) { + std::error_code err = getSocketError(); + // We don't care about error conditions other than EINTR since a failure + // here is not critical. + if (err == std::errc::interrupted) { + throw std::system_error{err}; + } + } +#endif +} + +class SocketListenOp; +class SocketConnectOp; +} // namespace + +class SocketImpl { + friend class SocketListenOp; + friend class SocketConnectOp; + +public: +#ifdef _WIN32 + using Handle = SOCKET; +#else + using Handle = int; +#endif + +#ifdef _WIN32 + static constexpr Handle invalid_socket = INVALID_SOCKET; +#else + static constexpr Handle invalid_socket = -1; +#endif + + explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {} + + SocketImpl(const SocketImpl &other) = delete; + + SocketImpl &operator=(const SocketImpl &other) = delete; + + SocketImpl(SocketImpl &&other) noexcept = delete; + + SocketImpl &operator=(SocketImpl &&other) noexcept = delete; + + ~SocketImpl(); + + std::unique_ptr accept() const; + + void closeOnExec() noexcept; + + void enableNonBlocking(); + + void disableNonBlocking(); + + bool enableNoDelay() noexcept; + + bool enableDualStack() noexcept; + +#ifndef _WIN32 + bool enableAddressReuse() noexcept; +#endif + +#ifdef _WIN32 + bool enableExclusiveAddressUse() noexcept; +#endif + + std::uint16_t getPort() const; + + Handle handle() const noexcept { return hnd_; } + +private: + bool setSocketFlag(int level, int optname, bool value) noexcept; + + Handle hnd_; +}; +} // namespace detail +} // namespace xoscar + +// +// libfmt formatters for `addrinfo` and `Socket` +// +namespace fmt { + +template <> +struct formatter<::addrinfo> { + constexpr decltype(auto) parse(format_parse_context &ctx) const { + return ctx.begin(); + } + + template + decltype(auto) format(const ::addrinfo &addr, FormatContext &ctx) const { + char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT + + int r = ::getnameinfo(addr.ai_addr, + addr.ai_addrlen, + host, + NI_MAXHOST, + port, + NI_MAXSERV, + NI_NUMERICSERV); + if (r != 0) { + return format_to(ctx.out(), "?UNKNOWN?"); + } + + if (addr.ai_addr->sa_family == AF_INET) { + return format_to(ctx.out(), "{}:{}", host, port); + } else { + return format_to(ctx.out(), "[{}]:{}", host, port); + } + } +}; + +template <> +struct formatter { + constexpr decltype(auto) parse(format_parse_context &ctx) const { + return ctx.begin(); + } + + template + decltype(auto) format(const xoscar::detail::SocketImpl &socket, + FormatContext &ctx) const { + ::sockaddr_storage addr_s{}; + + auto addr_ptr = reinterpret_cast<::sockaddr *>(&addr_s); + + ::socklen_t addr_len = sizeof(addr_s); + + if (::getsockname(socket.handle(), addr_ptr, &addr_len) != 0) { + return format_to(ctx.out(), "?UNKNOWN?"); + } + + ::addrinfo addr{}; + addr.ai_addr = addr_ptr; + addr.ai_addrlen = addr_len; + + return format_to(ctx.out(), "{}", addr); + } +}; + +} // namespace fmt + +namespace xoscar { +namespace detail { + +SocketImpl::~SocketImpl() { +#ifdef _WIN32 + ::closesocket(hnd_); +#else + ::close(hnd_); +#endif +} + +std::unique_ptr SocketImpl::accept() const { + ::sockaddr_storage addr_s{}; + + auto addr_ptr = reinterpret_cast<::sockaddr *>(&addr_s); + + ::socklen_t addr_len = sizeof(addr_s); + + Handle hnd = ::accept(hnd_, addr_ptr, &addr_len); + if (hnd == invalid_socket) { + std::error_code err = getSocketError(); + if (err == std::errc::interrupted) { + throw std::system_error{err}; + } + + std::string msg{}; + if (err == std::errc::invalid_argument) { + msg = fmt::format( + "The server socket on {} is not listening for connections.", + *this); + } else { + msg = fmt::format( + "The server socket on {} has failed to accept a connection {}.", + *this, + err); + } + + // xoscar_ERROR(msg); + + throw SocketError{msg}; + } + + ::addrinfo addr{}; + addr.ai_addr = addr_ptr; + addr.ai_addrlen = addr_len; + + // xoscar_DEBUG( + // "The server socket on {} has accepted a connection from {}.", + // *this, + // addr); + + auto impl = std::make_unique(hnd); + + // Make sure that we do not "leak" our file descriptors to child processes. + impl->closeOnExec(); + + if (!impl->enableNoDelay()) { + // xoscar_WARNING( + // "The no-delay option cannot be enabled for the client socket + // on + // {}.", addr); + } + + return impl; +} + +void SocketImpl::closeOnExec() noexcept { +#ifndef _WIN32 + ::fcntl(hnd_, F_SETFD, FD_CLOEXEC); +#endif +} + +void SocketImpl::enableNonBlocking() { +#ifdef _WIN32 + unsigned long value = 1; + if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) { + return; + } +#else + int flg = ::fcntl(hnd_, F_GETFL); + if (flg != -1) { + if (::fcntl(hnd_, F_SETFL, flg | O_NONBLOCK) == 0) { + return; + } + } +#endif + throw SocketError{"The socket cannot be switched to non-blocking mode."}; +} + +// TODO: Remove once we migrate everything to non-blocking mode. +void SocketImpl::disableNonBlocking() { +#ifdef _WIN32 + unsigned long value = 0; + if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) { + return; + } +#else + int flg = ::fcntl(hnd_, F_GETFL); + if (flg != -1) { + if (::fcntl(hnd_, F_SETFL, flg & ~O_NONBLOCK) == 0) { + return; + } + } +#endif + throw SocketError{"The socket cannot be switched to blocking mode."}; +} + +bool SocketImpl::enableNoDelay() noexcept { + return setSocketFlag(IPPROTO_TCP, TCP_NODELAY, true); +} + +bool SocketImpl::enableDualStack() noexcept { + return setSocketFlag(IPPROTO_IPV6, IPV6_V6ONLY, false); +} + +#ifndef _WIN32 +bool SocketImpl::enableAddressReuse() noexcept { + return setSocketFlag(SOL_SOCKET, SO_REUSEADDR, true); +} +#endif + +#ifdef _WIN32 +bool SocketImpl::enableExclusiveAddressUse() noexcept { + return setSocketFlag(SOL_SOCKET, SO_EXCLUSIVEADDRUSE, true); +} +#endif + +std::uint16_t SocketImpl::getPort() const { + ::sockaddr_storage addr_s{}; + + ::socklen_t addr_len = sizeof(addr_s); + + if (::getsockname(hnd_, reinterpret_cast<::sockaddr *>(&addr_s), &addr_len) + != 0) { + throw SocketError{"The port number of the socket cannot be retrieved."}; + } + + if (addr_s.ss_family == AF_INET) { + return ntohs(reinterpret_cast<::sockaddr_in *>(&addr_s)->sin_port); + } else { + return ntohs(reinterpret_cast<::sockaddr_in6 *>(&addr_s)->sin6_port); + } +} + +bool SocketImpl::setSocketFlag(int level, int optname, bool value) noexcept { +#ifdef _WIN32 + auto buf = value ? TRUE : FALSE; +#else + auto buf = value ? 1 : 0; +#endif + return setSocketOption(hnd_, level, optname, &buf, sizeof(buf)) == 0; +} + +namespace { + +struct addrinfo_delete { + void operator()(::addrinfo *addr) const noexcept { ::freeaddrinfo(addr); } +}; + +using addrinfo_ptr = std::unique_ptr<::addrinfo, addrinfo_delete>; + +class SocketListenOp { +public: + SocketListenOp(std::uint16_t port, const SocketOptions &opts); + + std::unique_ptr run(); + +private: + bool tryListen(int family); + + bool tryListen(const ::addrinfo &addr); + + template + void recordError(fmt::string_view format, Args &&...args) { + auto msg = fmt::vformat(format, fmt::make_format_args(args...)); + + // xoscar_WARNING(msg); + + errors_.emplace_back(std::move(msg)); + } + + std::string port_; + const SocketOptions *opts_; + std::vector errors_{}; + std::unique_ptr socket_{}; +}; + +SocketListenOp::SocketListenOp(std::uint16_t port, const SocketOptions &opts) + : port_{fmt::to_string(port)}, opts_{&opts} {} + +std::unique_ptr SocketListenOp::run() { + if (opts_->prefer_ipv6()) { + // xoscar_DEBUG("The server socket will attempt to listen on an IPv6 + // address."); + if (tryListen(AF_INET6)) { + return std::move(socket_); + } + + // xoscar_DEBUG("The server socket will attempt to listen on an IPv4 + // address."); + if (tryListen(AF_INET)) { + return std::move(socket_); + } + } else { + // xoscar_DEBUG( + // "The server socket will attempt to listen on an IPv4 or IPv6 + // address."); + if (tryListen(AF_UNSPEC)) { + return std::move(socket_); + } + } + + constexpr auto *msg = "The server socket has failed to listen on any local " + "network address."; + + // xoscar_ERROR(msg); + + throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))}; +} + +bool SocketListenOp::tryListen(int family) { + ::addrinfo hints{}, *naked_result = nullptr; + + hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV; + hints.ai_family = family; + hints.ai_socktype = SOCK_STREAM; + + int r = ::getaddrinfo(nullptr, port_.c_str(), &hints, &naked_result); + if (r != 0) { + const char *gai_err = ::gai_strerror(r); + + recordError( + "The local {}network addresses cannot be retrieved (gai error: " + "{} - {}).", + family == AF_INET ? "IPv4 " + : family == AF_INET6 ? "IPv6 " + : "", + r, + gai_err); + + return false; + } + + addrinfo_ptr result{naked_result}; + + for (::addrinfo *addr = naked_result; addr != nullptr; + addr = addr->ai_next) { + // xoscar_DEBUG("The server socket is attempting to listen on {}.", + // *addr); + if (tryListen(*addr)) { + return true; + } + } + + return false; +} + +bool SocketListenOp::tryListen(const ::addrinfo &addr) { + SocketImpl::Handle hnd + = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); + if (hnd == SocketImpl::invalid_socket) { + recordError("The server socket cannot be initialized on {} {}.", + addr, + getSocketError()); + + return false; + } + + socket_ = std::make_unique(hnd); + +#ifndef _WIN32 + if (!socket_->enableAddressReuse()) { + // xoscar_WARNING( + // "The address reuse option cannot be enabled for the server + // socket on {}.", addr); + } +#endif + +#ifdef _WIN32 + // The SO_REUSEADDR flag has a significantly different behavior on Windows + // compared to Unix-like systems. It allows two or more processes to share + // the same port simultaneously, which is totally unsafe. + // + // Here we follow the recommendation of Microsoft and use the non-standard + // SO_EXCLUSIVEADDRUSE flag instead. + if (!socket_->enableExclusiveAddressUse()) { + xoscar_WARNING( + "The exclusive address use option cannot be enabled for the " + "server socket on {}.", + addr); + } +#endif + + // Not all operating systems support dual-stack sockets by default. Since we + // wish to use our IPv6 socket for IPv4 communication as well, we explicitly + // ask the system to enable it. + if (addr.ai_family == AF_INET6 && !socket_->enableDualStack()) { + // xoscar_WARNING( + // "The server socket does not support IPv4 communication on + // {}.", addr); + } + + if (::bind(socket_->handle(), addr.ai_addr, addr.ai_addrlen) != 0) { + recordError("The server socket has failed to bind to {} {}.", + addr, + getSocketError()); + + return false; + } + + // NOLINTNEXTLINE(bugprone-argument-comment) + if (::listen(socket_->handle(), /*backlog=*/2048) != 0) { + recordError("The server socket has failed to listen on {} {}.", + addr, + getSocketError()); + + return false; + } + + socket_->closeOnExec(); + + // xoscar_INFO("The server socket has started to listen on {}.", addr); + + return true; +} + +class SocketConnectOp { + using Clock = std::chrono::steady_clock; + using Duration = std::chrono::steady_clock::duration; + using TimePoint = std::chrono::time_point; + + static const std::chrono::seconds delay_duration_; + + enum class ConnectResult { Success, Error, Retry }; + +public: + SocketConnectOp(const std::string &host, + std::uint16_t port, + const SocketOptions &opts); + + std::unique_ptr run(); + +private: + bool tryConnect(int family); + + ConnectResult tryConnect(const ::addrinfo &addr); + + ConnectResult tryConnectCore(const ::addrinfo &addr); + + [[noreturn]] void throwTimeoutError() const; + + template + void recordError(fmt::string_view format, Args &&...args) { + auto msg = fmt::vformat(format, fmt::make_format_args(args...)); + + // xoscar_WARNING(msg); + + errors_.emplace_back(std::move(msg)); + } + + const char *host_; + std::string port_; + const SocketOptions *opts_; + TimePoint deadline_{}; + std::vector errors_{}; + std::unique_ptr socket_{}; +}; + +const std::chrono::seconds SocketConnectOp::delay_duration_{1}; + +SocketConnectOp::SocketConnectOp(const std::string &host, + std::uint16_t port, + const SocketOptions &opts) + : host_{host.c_str()}, port_{fmt::to_string(port)}, opts_{&opts} {} + +std::unique_ptr SocketConnectOp::run() { + if (opts_->prefer_ipv6()) { + // xoscar_DEBUG( + // "The client socket will attempt to connect to an IPv6 address + // of + // ({}, {}).", host_, port_); + + if (tryConnect(AF_INET6)) { + return std::move(socket_); + } + + // xoscar_DEBUG( + // "The client socket will attempt to connect to an IPv4 address + // of + // ({}, {}).", host_, port_); + + if (tryConnect(AF_INET)) { + return std::move(socket_); + } + } else { + // xoscar_DEBUG( + // "The client socket will attempt to connect to an IPv4 or IPv6 + // address of ({}, {}).", host_, port_); + + if (tryConnect(AF_UNSPEC)) { + return std::move(socket_); + } + } + + auto msg = fmt::format("The client socket has failed to connect to any " + "network address of ({}, {}).", + host_, + port_); + + // xoscar_ERROR(msg); + + throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))}; +} + +bool SocketConnectOp::tryConnect(int family) { + ::addrinfo hints{}; + hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV; + hints.ai_family = family; + hints.ai_socktype = SOCK_STREAM; + + deadline_ = Clock::now() + opts_->connect_timeout(); + + std::size_t retry_attempt = 1; + + bool retry; // NOLINT(cppcoreguidelines-init-variables) + do { + retry = false; + + errors_.clear(); + + ::addrinfo *naked_result = nullptr; + // patternlint-disable cpp-dns-deps + int r = ::getaddrinfo(host_, port_.c_str(), &hints, &naked_result); + if (r != 0) { + const char *gai_err = ::gai_strerror(r); + + recordError( + "The {}network addresses of ({}, {}) cannot be retrieved " + "(gai error: {} - {}).", + family == AF_INET ? "IPv4 " + : family == AF_INET6 ? "IPv6 " + : "", + host_, + port_, + r, + gai_err); + retry = true; + } else { + addrinfo_ptr result{naked_result}; + + for (::addrinfo *addr = naked_result; addr != nullptr; + addr = addr->ai_next) { + // xoscar_TRACE("The client socket is attempting to + // connect to + // {}.", *addr); + + ConnectResult cr = tryConnect(*addr); + if (cr == ConnectResult::Success) { + return true; + } + + if (cr == ConnectResult::Retry) { + retry = true; + } + } + } + + if (retry) { + if (Clock::now() < deadline_ - delay_duration_) { + // Prevent our log output to be too noisy, warn only every 30 + // seconds. + if (retry_attempt == 30) { + // xoscar_INFO( + // "No socket on ({}, {}) is listening yet, + // will retry.", host_, port_); + + retry_attempt = 0; + } + + // Wait one second to avoid choking the server. + delay(delay_duration_); + + retry_attempt++; + } else { + throwTimeoutError(); + } + } + } while (retry); + + return false; +} + +SocketConnectOp::ConnectResult +SocketConnectOp::tryConnect(const ::addrinfo &addr) { + if (Clock::now() >= deadline_) { + throwTimeoutError(); + } + + SocketImpl::Handle hnd + = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); + if (hnd == SocketImpl::invalid_socket) { + recordError( + "The client socket cannot be initialized to connect to {} {}.", + addr, + getSocketError()); + + return ConnectResult::Error; + } + + socket_ = std::make_unique(hnd); + + socket_->enableNonBlocking(); + + ConnectResult cr = tryConnectCore(addr); + if (cr == ConnectResult::Error) { + std::error_code err = getSocketError(); + if (err == std::errc::interrupted) { + throw std::system_error{err}; + } + + // Retry if the server is not yet listening or if its backlog is + // exhausted. + if (err == std::errc::connection_refused + || err == std::errc::connection_reset) { + // xoscar_TRACE( + // "The server socket on {} is not yet listening {}, will + // retry.", addr, err); + + return ConnectResult::Retry; + } else { + recordError( + "The client socket has failed to connect to {} {}.", addr, err); + + return ConnectResult::Error; + } + } + + socket_->closeOnExec(); + + // TODO: Remove once we fully migrate to non-blocking mode. + socket_->disableNonBlocking(); + + // xoscar_INFO("The client socket has connected to {} on {}.", addr, + // *socket_); + + if (!socket_->enableNoDelay()) { + // xoscar_WARNING( + // "The no-delay option cannot be enabled for the client socket + // on + // {}.", *socket_); + } + + return ConnectResult::Success; +} + +SocketConnectOp::ConnectResult +SocketConnectOp::tryConnectCore(const ::addrinfo &addr) { + int r = ::connect(socket_->handle(), addr.ai_addr, addr.ai_addrlen); + if (r == 0) { + return ConnectResult::Success; + } + + std::error_code err = getSocketError(); + if (err == std::errc::already_connected) { + return ConnectResult::Success; + } + + if (err != std::errc::operation_in_progress + && err != std::errc::operation_would_block) { + return ConnectResult::Error; + } + + Duration remaining = deadline_ - Clock::now(); + if (remaining <= Duration::zero()) { + throwTimeoutError(); + } + + ::pollfd pfd{}; + pfd.fd = socket_->handle(); + pfd.events = POLLOUT; + + auto ms = std::chrono::duration_cast(remaining); + + r = pollFd(&pfd, 1, static_cast(ms.count())); + if (r == 0) { + throwTimeoutError(); + } + if (r == -1) { + return ConnectResult::Error; + } + + int err_code = 0; + + ::socklen_t err_len = sizeof(int); + + r = getSocketOption( + socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len); + if (r != 0) { + return ConnectResult::Error; + } + + if (err_code != 0) { + setSocketError(err_code); + + return ConnectResult::Error; + } else { + return ConnectResult::Success; + } +} + +void SocketConnectOp::throwTimeoutError() const { + auto msg = fmt::format("The client socket has timed out after {} while " + "trying to connect to ({}, {}).", + opts_->connect_timeout(), + host_, + port_); + + // xoscar_ERROR(msg); + + throw TimeoutError{msg}; +} + +} // namespace + +void Socket::initialize() { +#ifdef _WIN32 + static c10::once_flag init_flag{}; + + // All processes that call socket functions on Windows must first initialize + // the Winsock library. + c10::call_once(init_flag, []() { + WSADATA data{}; + if (::WSAStartup(MAKEWORD(2, 2), &data) != 0) { + throw SocketError{"The initialization of Winsock has failed."}; + } + }); +#endif +} + +Socket Socket::listen(std::uint16_t port, const SocketOptions &opts) { + SocketListenOp op{port, opts}; + + return Socket{op.run()}; +} + +Socket Socket::connect(const std::string &host, + std::uint16_t port, + const SocketOptions &opts) { + SocketConnectOp op{host, port, opts}; + + return Socket{op.run()}; +} + +Socket::Socket(Socket &&other) noexcept = default; + +Socket &Socket::operator=(Socket &&other) noexcept = default; + +Socket::~Socket() = default; + +Socket Socket::accept() const { + if (impl_) { + return Socket{impl_->accept()}; + } + + throw SocketError{"The socket is not initialized."}; +} + +int Socket::handle() const noexcept { + if (impl_) { + return impl_->handle(); + } + return SocketImpl::invalid_socket; +} + +std::uint16_t Socket::port() const { + if (impl_) { + return impl_->getPort(); + } + return 0; +} + +Socket::Socket(std::unique_ptr &&impl) noexcept + : impl_{std::move(impl)} {} + +} // namespace detail + +SocketError::~SocketError() = default; + +} // namespace xoscar diff --git a/cpp/collective/rendezvous/src/store.cpp b/cpp/collective/rendezvous/src/store.cpp new file mode 100644 index 00000000..1309d06f --- /dev/null +++ b/cpp/collective/rendezvous/src/store.cpp @@ -0,0 +1,84 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "store.hpp" + +namespace xoscar { + +constexpr std::chrono::milliseconds Store::kDefaultTimeout; +constexpr std::chrono::milliseconds Store::kNoTimeout; + +// Define destructor symbol for abstract base class. +Store::~Store() = default; + +const std::chrono::milliseconds &Store::getTimeout() const noexcept { + return timeout_; +} + +// Set timeout function +void Store::setTimeout(const std::chrono::milliseconds &timeout) { + timeout_ = timeout; +} + +void Store::set(const std::string &key, const std::string &value) { + set(key, std::vector(value.begin(), value.end())); +} + +std::string Store::compareSet(const std::string &key, + const std::string ¤tValue, + const std::string &newValue) { + auto value = compareSet( + key, + std::vector(currentValue.begin(), currentValue.end()), + std::vector(newValue.begin(), newValue.end())); + return std::string(value.begin(), value.end()); +} + +std::string Store::get_to_str(const std::string &key) { + auto value = get(key); + return std::string(value.begin(), value.end()); +} + +void Store::append(const std::string &key, const std::vector &value) { + // This fallback depends on compareSet + std::vector expected = value; + std::vector current; + // cannot use get(key) as it might block forever if the key doesn't exist + current = compareSet(key, current, expected); + while (current != expected) { + expected = current; + expected.insert(expected.end(), value.begin(), value.end()); + current = compareSet(key, current, expected); + } +} + +std::vector> +Store::multiGet(const std::vector &keys) { + std::vector> result; + result.reserve(keys.size()); + for (auto &key : keys) { + result.emplace_back(get(key)); + } + return result; +} + +void Store::multiSet(const std::vector &keys, + const std::vector> &values) { + for (int i = 0; i < keys.size(); i++) { + set(keys[i], values[i]); + } +} + +bool Store::hasExtendedApi() const { return false; } + +} // namespace xoscar diff --git a/cpp/collective/rendezvous/src/tcp_store.cpp b/cpp/collective/rendezvous/src/tcp_store.cpp new file mode 100644 index 00000000..d32ff13c --- /dev/null +++ b/cpp/collective/rendezvous/src/tcp_store.cpp @@ -0,0 +1,1321 @@ +/* Copyright 2022-2023 XProbe Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "tcp_store.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +# include +# include +#else +# include +# include +#endif + +#ifdef _WIN32 +# include +#else +# include "unix_sock_utils.hpp" +#endif + +#include "socket.h" + +namespace xoscar { +namespace detail { +namespace { + +// Abstract base class to handle thread state for TCPStoreMasterDaemon and +// TCPStoreWorkerDaemon. Contains the windows/unix implementations to signal a +// shutdown sequence for the thread +class BackgroundThread { +public: + explicit BackgroundThread(Socket &&storeListenSocket); + + virtual ~BackgroundThread() = 0; + +protected: + void dispose(); + + Socket storeListenSocket_; + std::thread daemonThread_{}; + std::vector sockets_{}; +#ifdef _WIN32 + const std::chrono::milliseconds checkTimeout_ + = std::chrono::milliseconds{10}; + HANDLE ghStopEvent_{}; +#else + std::array controlPipeFd_{{-1, -1}}; +#endif + +private: + // Initialization for shutdown signal + void initStopSignal(); + // Triggers the shutdown signal + void stop(); + // Joins the thread + void join(); + // Clean up the shutdown signal + void closeStopSignal(); +}; + +// Background thread parent class methods +BackgroundThread::BackgroundThread(Socket &&storeListenSocket) + : storeListenSocket_{std::move(storeListenSocket)} { + // Signal instance destruction to the daemon thread. + initStopSignal(); +} + +BackgroundThread::~BackgroundThread() = default; + +// WARNING: +// Since we rely on the subclass for the daemon thread clean-up, we cannot +// destruct our member variables in the destructor. The subclass must call +// dispose() in its own destructor. +void BackgroundThread::dispose() { + // Stop the run + stop(); + // Join the thread + join(); + // Close unclosed sockets + sockets_.clear(); + // Now close the rest control pipe + closeStopSignal(); +} + +void BackgroundThread::join() { daemonThread_.join(); } + +#ifdef _WIN32 +void BackgroundThread::initStopSignal() { + ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL); + if (ghStopEvent_ == NULL) { + TORCH_CHECK(false, + "Failed to create the control pipe to start the " + "BackgroundThread run"); + } +} + +void BackgroundThread::closeStopSignal() { CloseHandle(ghStopEvent_); } + +void BackgroundThread::stop() { SetEvent(ghStopEvent_); } +#else +void BackgroundThread::initStopSignal() { + if (pipe(controlPipeFd_.data()) == -1) { + // TORCH_CHECK( + // false, + // "Failed to create the control pipe to start the " + // "BackgroundThread run"); + throw std::runtime_error("Failed to create the control pipe to start " + "the BackgroundThread run"); + } +} + +void BackgroundThread::closeStopSignal() { + for (int fd : controlPipeFd_) { + if (fd != -1) { + ::close(fd); + } + } +} + +void BackgroundThread::stop() { + if (controlPipeFd_[1] != -1) { + ::write(controlPipeFd_[1], "\0", 1); + // close the write end of the pipe + ::close(controlPipeFd_[1]); + controlPipeFd_[1] = -1; + } +} +#endif + +enum class QueryType : uint8_t { + SET, + COMPARE_SET, + GET, + ADD, + CHECK, + WAIT, + GETNUMKEYS, + WATCH_KEY, + DELETE_KEY, + APPEND, + MULTI_GET, + MULTI_SET, +}; + +enum class CheckResponseType : uint8_t { READY, NOT_READY }; + +enum class WaitResponseType : uint8_t { STOP_WAITING }; + +enum class WatchResponseType : uint8_t { + KEY_UPDATED, + KEY_CREATED, + KEY_DELETED, + KEY_CALLBACK_REGISTERED, + KEY_APPENDED, +}; + +// Separate thread that is only launched on master +class TCPStoreMasterDaemon : public BackgroundThread { +public: + explicit TCPStoreMasterDaemon(Socket &&storeListenSocket); + + ~TCPStoreMasterDaemon() override; + +private: + void run(); + void queryFds(std::vector &fds); + void query(int socket); + + // The master runs on a single thread so only + // one handler can be executed at a time + void setHandler(int socket); + void compareSetHandler(int socket); + void addHandler(int socket); + void getHandler(int socket) const; + void checkHandler(int socket) const; + void getNumKeysHandler(int socket) const; + void deleteHandler(int socket); + void waitHandler(int socket); + void watchHandler(int socket); + void appendHandler(int socket); + void multiGetHandler(int socket); + void multiSetHandler(int socket); + + bool checkKeys(const std::vector &keys) const; + // Helper function to alerts waiting workers, used in setHandler, getHandler + void wakeupWaitingClients(const std::string &key); + // Helper function used when the key is changed + // used in setHandler, addHandler, getHandler, deleteHandler + void sendKeyUpdatesToClients(const std::string &key, + const enum WatchResponseType &type, + const std::vector &oldData, + const std::vector &newData); + void doSet(const std::string &key, const std::vector &newData); + + std::unordered_map> tcpStore_; + // From key -> the list of sockets waiting on the key + std::unordered_map> waitingSockets_; + // From socket -> number of keys awaited + std::unordered_map keysAwaited_; + // From key -> the list of sockets watching the key + std::unordered_map> watchedSockets_; +}; + +// Simply start the daemon thread +TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket &&storeListenSocket) + : BackgroundThread{std::move(storeListenSocket)} { + daemonThread_ = std::thread{&TCPStoreMasterDaemon::run, this}; +} + +TCPStoreMasterDaemon::~TCPStoreMasterDaemon() { dispose(); } + +void TCPStoreMasterDaemon::queryFds(std::vector &fds) { + // Skipping the fds[0] and fds[1], + // fds[0] is master's listening socket + // fds[1] is control pipe's reading fd, it is not for Windows platform + for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) { + if (fds[fdIdx].revents == 0) { + continue; + } + + // Now query the socket that has the event + try { + query(fds[fdIdx].fd); + } catch (...) { + // There was an error when processing query. Probably an exception + // occurred in recv/send what would indicate that socket on the + // other side has been closed. If the closing was due to normal + // exit, then the store should continue executing. Otherwise, if it + // was different exception, other connections will get an exception + // once they try to use the store. We will go ahead and close this + // connection whenever we hit an exception here. + + // Remove all the tracking state of the close FD + for (auto it = waitingSockets_.begin(); + it != waitingSockets_.end();) { + for (auto vecIt = it->second.begin(); + vecIt != it->second.end();) { + if (*vecIt == fds[fdIdx].fd) { + vecIt = it->second.erase(vecIt); + } else { + ++vecIt; + } + } + if (it->second.empty()) { + it = waitingSockets_.erase(it); + } else { + ++it; + } + } + for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) { + if (it->first == fds[fdIdx].fd) { + it = keysAwaited_.erase(it); + } else { + ++it; + } + } + fds.erase(fds.begin() + fdIdx); + sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET); + --fdIdx; + continue; + } + } +} + +// query communicates with the worker. The format +// of the query is as follows: +// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ... +// or, in the case of wait +// type of query | number of args | size of arg1 | arg1 | ... +void TCPStoreMasterDaemon::query(int socket) { + QueryType qt; + tcputil::recvBytes(socket, &qt, 1); + if (qt == QueryType::SET) { + setHandler(socket); + + } else if (qt == QueryType::COMPARE_SET) { + compareSetHandler(socket); + + } else if (qt == QueryType::ADD) { + addHandler(socket); + + } else if (qt == QueryType::GET) { + getHandler(socket); + + } else if (qt == QueryType::CHECK) { + checkHandler(socket); + + } else if (qt == QueryType::WAIT) { + waitHandler(socket); + + } else if (qt == QueryType::GETNUMKEYS) { + getNumKeysHandler(socket); + + } else if (qt == QueryType::DELETE_KEY) { + deleteHandler(socket); + + } else if (qt == QueryType::WATCH_KEY) { + watchHandler(socket); + } else if (qt == QueryType::APPEND) { + appendHandler(socket); + } else if (qt == QueryType::MULTI_GET) { + multiGetHandler(socket); + } else if (qt == QueryType::MULTI_SET) { + multiSetHandler(socket); + } else { + // TORCH_CHECK(false, "Unexpected query type"); + throw std::runtime_error("Unexpected query type"); + } +} + +void TCPStoreMasterDaemon::wakeupWaitingClients(const std::string &key) { + auto socketsToWait = waitingSockets_.find(key); + if (socketsToWait != waitingSockets_.end()) { + for (int socket : socketsToWait->second) { + if (--keysAwaited_[socket] == 0) { + tcputil::sendValue( + socket, WaitResponseType::STOP_WAITING); + } + } + waitingSockets_.erase(socketsToWait); + } +} + +void TCPStoreMasterDaemon::sendKeyUpdatesToClients( + const std::string &key, + const enum WatchResponseType &type, + const std::vector &oldData, + const std::vector &newData) { + for (int socket : watchedSockets_[key]) { + tcputil::sendValue(socket, type); + tcputil::sendString(socket, key, true); + tcputil::sendVector(socket, oldData); + tcputil::sendVector(socket, newData); + } +} + +void TCPStoreMasterDaemon::doSet(const std::string &key, + const std::vector &newData) { + std::vector oldData; + bool newKey = true; + auto it = tcpStore_.find(key); + if (it != tcpStore_.end()) { + oldData = it->second; + newKey = false; + } + tcpStore_[key] = newData; + // On "set", wake up all clients that have been waiting + wakeupWaitingClients(key); + // Send key update to all watching clients + newKey ? sendKeyUpdatesToClients( + key, WatchResponseType::KEY_CREATED, oldData, newData) + : sendKeyUpdatesToClients( + key, WatchResponseType::KEY_UPDATED, oldData, newData); +} + +void TCPStoreMasterDaemon::setHandler(int socket) { + std::string key = tcputil::recvString(socket); + std::vector newData = tcputil::recvVector(socket); + doSet(key, newData); +} + +void TCPStoreMasterDaemon::compareSetHandler(int socket) { + std::string key = tcputil::recvString(socket); + std::vector currentValue = tcputil::recvVector(socket); + std::vector newValue = tcputil::recvVector(socket); + + auto pos = tcpStore_.find(key); + if (pos == tcpStore_.end()) { + if (currentValue.empty()) { + tcpStore_[key] = newValue; + + // Send key update to all watching clients + sendKeyUpdatesToClients( + key, WatchResponseType::KEY_CREATED, currentValue, newValue); + tcputil::sendVector(socket, newValue); + } else { + // TODO: This code path is not ideal as we are "lying" to the caller + // in case the key does not exist. We should come up with a working + // solution. + tcputil::sendVector(socket, currentValue); + } + } else { + if (pos->second == currentValue) { + pos->second = std::move(newValue); + + // Send key update to all watching clients + sendKeyUpdatesToClients( + key, WatchResponseType::KEY_UPDATED, currentValue, pos->second); + } + tcputil::sendVector(socket, pos->second); + } +} + +void TCPStoreMasterDaemon::addHandler(int socket) { + std::string key = tcputil::recvString(socket); + int64_t addVal = tcputil::recvValue(socket); + + bool newKey = true; + std::vector oldData; + auto it = tcpStore_.find(key); + if (it != tcpStore_.end()) { + oldData = it->second; + auto buf = reinterpret_cast(it->second.data()); + auto len = it->second.size(); + addVal += std::stoll(std::string(buf, len)); + newKey = false; + } + auto addValStr = std::to_string(addVal); + std::vector newData + = std::vector(addValStr.begin(), addValStr.end()); + tcpStore_[key] = newData; + // Now send the new value + tcputil::sendValue(socket, addVal); + // On "add", wake up all clients that have been waiting + wakeupWaitingClients(key); + // Send key update to all watching clients + newKey ? sendKeyUpdatesToClients( + key, WatchResponseType::KEY_CREATED, oldData, newData) + : sendKeyUpdatesToClients( + key, WatchResponseType::KEY_UPDATED, oldData, newData); +} + +void TCPStoreMasterDaemon::getHandler(int socket) const { + std::string key = tcputil::recvString(socket); + auto data = tcpStore_.at(key); + tcputil::sendVector(socket, data); +} + +void TCPStoreMasterDaemon::getNumKeysHandler(int socket) const { + tcputil::sendValue(socket, tcpStore_.size()); +} + +void TCPStoreMasterDaemon::deleteHandler(int socket) { + std::string key = tcputil::recvString(socket); + auto it = tcpStore_.find(key); + if (it != tcpStore_.end()) { + std::vector oldData = it->second; + // Send key update to all watching clients + std::vector newData; + sendKeyUpdatesToClients( + key, WatchResponseType::KEY_DELETED, oldData, newData); + } + auto numDeleted = tcpStore_.erase(key); + tcputil::sendValue(socket, numDeleted); +} + +void TCPStoreMasterDaemon::checkHandler(int socket) const { + SizeType nargs = 0; + tcputil::recvBytes(socket, &nargs, 1); + std::vector keys(nargs); + // for (const auto i : c10::irange(nargs)) { + // keys[i] = tcputil::recvString(socket); + // } + for (auto &key : keys) { + key = tcputil::recvString(socket); + } + // Now we have received all the keys + if (checkKeys(keys)) { + tcputil::sendValue(socket, CheckResponseType::READY); + } else { + tcputil::sendValue(socket, + CheckResponseType::NOT_READY); + } +} + +void TCPStoreMasterDaemon::waitHandler(int socket) { + SizeType nargs = 0; + tcputil::recvBytes(socket, &nargs, 1); + std::vector keys(nargs); + // for (const auto i : c10::irange(nargs)) { + // keys[i] = tcputil::recvString(socket); + // } + for (auto &key : keys) { + key = tcputil::recvString(socket); + } + if (checkKeys(keys)) { + tcputil::sendValue(socket, + WaitResponseType::STOP_WAITING); + } else { + int numKeysToAwait = 0; + for (auto &key : keys) { + // Only count keys that have not already been set + if (tcpStore_.find(key) == tcpStore_.end()) { + waitingSockets_[key].push_back(socket); + numKeysToAwait++; + } + } + keysAwaited_[socket] = numKeysToAwait; + } +} + +void TCPStoreMasterDaemon::watchHandler(int socket) { + std::string key = tcputil::recvString(socket); + + // Record the socket to respond to when the key is updated + watchedSockets_[key].push_back(socket); + + // Send update to TCPStoreWorkerDaemon on client + tcputil::sendValue( + socket, WatchResponseType::KEY_CALLBACK_REGISTERED); +} + +void TCPStoreMasterDaemon::appendHandler(int socket) { + std::string key = tcputil::recvString(socket); + std::vector newData = tcputil::recvVector(socket); + bool newKey = true; + auto it = tcpStore_.find(key); + if (it != tcpStore_.end()) { + it->second.insert(it->second.end(), newData.begin(), newData.end()); + newKey = false; + } else { + tcpStore_[key] = newData; + } + // we should not have clients waiting if we're appending, so it's all fine + wakeupWaitingClients(key); + // Send key update to all watching clients + std::vector oldData; + newKey ? sendKeyUpdatesToClients( + key, WatchResponseType::KEY_CREATED, oldData, newData) + : sendKeyUpdatesToClients( + key, WatchResponseType::KEY_APPENDED, oldData, newData); +} + +void TCPStoreMasterDaemon::multiGetHandler(int socket) { + SizeType nargs = 0; + tcputil::recvBytes(socket, &nargs, 1); + for (int i = 0; i < nargs; i++) { + auto key = tcputil::recvString(socket); + auto &data = tcpStore_.at(key); + tcputil::sendVector(socket, data, i < (nargs - 1)); + } +} + +void TCPStoreMasterDaemon::multiSetHandler(int socket) { + SizeType nargs = 0; + tcputil::recvBytes(socket, &nargs, 1); + for (int i = 0; i < nargs; i++) { + // (void)_; // Suppress unused variable warning + auto key = tcputil::recvString(socket); + auto value = tcputil::recvVector(socket); + doSet(key, value); + } +} + +bool TCPStoreMasterDaemon::checkKeys( + const std::vector &keys) const { + return std::all_of(keys.begin(), keys.end(), [this](const std::string &s) { + return tcpStore_.count(s) > 0; + }); +} + +#ifdef _WIN32 +void TCPStoreMasterDaemon::run() { + std::vector fds; + tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN); + + // receive the queries + bool finished = false; + while (!finished) { + for (const auto i : c10::irange(sockets_.size())) { + fds[i].revents = 0; + } + + int res; + SYSCHECK_ERR_RETURN_NEG1( + res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count())) + if (res == 0) { + auto rv = WaitForSingleObject(ghStopEvent_, 0); + if (rv != WAIT_TIMEOUT) { + finished = true; + break; + } + continue; + } + + // TCPStore's listening socket has an event and it should now be able to + // accept new connections. + if (fds[0].revents != 0) { + if (!(fds[0].revents & POLLIN)) { + throw std::system_error( + ECONNABORTED, + std::system_category(), + "Unexpected poll revent on the master's listening socket: " + + std::to_string(fds[0].revents)); + } + Socket socket = storeListenSocket_.accept(); + int rawSocket = socket.handle(); + sockets_.emplace_back(std::move(socket)); + tcputil::addPollfd(fds, rawSocket, POLLIN); + } + queryFds(fds); + } +} +#else +void TCPStoreMasterDaemon::run() { + std::vector fds; + tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN); + // Although we haven't found any documentation or literature describing + // this, we've seen cases that, under certain circumstances, the read end of + // the pipe won't receive POLLHUP when the write end is closed. However, + // under the same circumstances, writing to the pipe will guarantee POLLIN + // to be received on the read end. + // + // For more reliable termination, the main thread will write a byte to the + // pipe before closing it, and the background thread will poll for both + // POLLIN and POLLHUP. + tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP); + + // receive the queries + bool finished = false; + while (!finished) { + // for (const auto i : c10::irange(sockets_.size())) { + // fds[i].revents = 0; + // } + for (auto &fd : fds) { + fd.revents = 0; + } + + SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1)); + + // TCPStore's listening socket has an event and it should now be able to + // accept new connections. + if (fds[0].revents != 0) { + if (fds[0].revents ^ POLLIN) { + throw std::system_error( + ECONNABORTED, + std::system_category(), + "Unexpected poll revent on the master's listening socket: " + + std::to_string(fds[0].revents)); + } + Socket socket = storeListenSocket_.accept(); + int rawSocket = socket.handle(); + sockets_.emplace_back(std::move(socket)); + tcputil::addPollfd(fds, rawSocket, POLLIN); + } + + // The pipe receives an event which tells us to shutdown the daemon + if (fds[1].revents != 0) { + // The main thread will write a byte to the pipe then close it + // before joining the background thread + if (fds[1].revents & ~(POLLIN | POLLHUP)) { + throw std::system_error( + ECONNABORTED, + std::system_category(), + "Unexpected poll revent on the control pipe's reading fd: " + + std::to_string(fds[1].revents)); + } + finished = true; + break; + } + queryFds(fds); + } +} +#endif + +// Separate thread that is launched on all instances (including master) +// Right now only handles callbacks registered from watchKey() +class TCPStoreWorkerDaemon : public BackgroundThread { +public: + explicit TCPStoreWorkerDaemon(Socket &&listenSocket); + ~TCPStoreWorkerDaemon() override; + // Set the callback to run key change + void setCallback(std::string key, WatchKeyCallback cb); + void waitForCallbackRegistration() { + // Block until callback has been registered successfully + std::unique_lock callbackRegistrationLock( + callbackRegistrationMutex_); + callbackRegisteredCV_.wait(callbackRegistrationLock, + [&] { return callbackRegisteredData_; }); + + // Reset payload for next callback + callbackRegisteredData_ = false; + } + void setCallbackRegistered() { + { + std::unique_lock callbackRegistrationLock( + callbackRegistrationMutex_); + callbackRegisteredData_ = true; + } + callbackRegisteredCV_.notify_one(); + } + +private: + void run(); + void callbackHandler(int socket); + // List of callbacks map each watched key + std::unordered_map keyToCallbacks_{}; + std::mutex keyToCallbacksMutex_{}; + std::mutex callbackRegistrationMutex_{}; + std::condition_variable callbackRegisteredCV_{}; + bool callbackRegisteredData_ = false; +}; + +// TCPStoreListener class methods +TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(Socket &&listenSocket) + : BackgroundThread{std::move(listenSocket)} { + daemonThread_ = std::thread{&TCPStoreWorkerDaemon::run, this}; +} + +TCPStoreWorkerDaemon::~TCPStoreWorkerDaemon() { dispose(); } + +void TCPStoreWorkerDaemon::setCallback(std::string key, + WatchKeyCallback callback) { + const std::lock_guard lock(keyToCallbacksMutex_); + keyToCallbacks_[key] = callback; +} + +// Runs all the callbacks that the worker has registered +void TCPStoreWorkerDaemon::callbackHandler(int socket) { + auto watchResponse = tcputil::recvValue(socket); + if (watchResponse == WatchResponseType::KEY_CALLBACK_REGISTERED) { + // Notify the waiting "watchKey" operation to return + setCallbackRegistered(); + return; + } + std::string key = tcputil::recvString(socket); + std::vector currentValueVec = tcputil::recvVector(socket); + std::vector newValueVec = tcputil::recvVector(socket); + std::optional currentValue; + if (watchResponse == WatchResponseType::KEY_CREATED) { + assert(currentValueVec.empty()); + currentValue = std::nullopt; + } else { + currentValue + = std::string(currentValueVec.begin(), currentValueVec.end()); + } + std::optional newValue; + if (watchResponse == WatchResponseType::KEY_DELETED) { + assert(newValueVec.empty()); + newValue = std::nullopt; + } else { + newValue = std::string(newValueVec.begin(), newValueVec.end()); + } + const std::lock_guard lock(keyToCallbacksMutex_); + keyToCallbacks_.at(key)(currentValue, newValue); +} + +#ifdef _WIN32 +void TCPStoreWorkerDaemon::run() { + std::vector fds; + tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN); + + while (true) { + // Check control and exit early if triggered + int res; + SYSCHECK_ERR_RETURN_NEG1( + res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count())) + if (res == 0) { + auto rvPoll = WaitForSingleObject(ghStopEvent_, 0); + if (rvPoll != WAIT_TIMEOUT) { + break; + } + continue; + } + + // if connection is closed gracefully by master, peeked data will return + // 0 + char data; + int ret = recv(fds[0].fd, &data, 1, MSG_PEEK); + if (ret == 0) { + auto rvData = WaitForSingleObject(ghStopEvent_, 0); + if (rvData != WAIT_TIMEOUT) { + break; + } + continue; + } + + // valid request, perform callback logic + callbackHandler(fds[0].fd); + } +} +#else +void TCPStoreWorkerDaemon::run() { + std::vector fds; + // Although we haven't found any documentation or literature describing + // this, we've seen cases that, under certain circumstances, the read end of + // the pipe won't receive POLLHUP when the write end is closed. However, + // under the same circumstances, writing to the pipe will guarantee POLLIN + // to be received on the read end. + // + // For more reliable termination, the main thread will write a byte to the + // pipe before closing it, and the background thread will poll for both + // POLLIN and POLLHUP. + tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP); + tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN); + + while (true) { + SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1)); + + // Check control and exit early if triggered + // The pipe receives an event which tells us to shutdown the listener + // thread + if (fds[0].revents != 0) { + // The main thread will write a byte to the pipe then close it + // before joining the background thread + if (fds[0].revents & ~(POLLIN | POLLHUP)) { + throw std::system_error( + ECONNABORTED, + std::system_category(), + "Unexpected poll revent on the control pipe's reading fd: " + + std::to_string(fds[0].revents)); + } + break; + } + + // if connection is closed gracefully by master, peeked data will return + // 0 + char data = 0; + int ret = recv(fds[1].fd, &data, 1, MSG_PEEK); + if (ret == 0) { + continue; + } + + // valid request, perform callback logic + callbackHandler(fds[1].fd); + } +} +#endif + +} // namespace + +// Manages the lifecycle of a server daemon. +class TCPServer { +public: + static std::shared_ptr start(const TCPStoreOptions &opts); + + std::uint16_t port() const noexcept { return port_; } + + explicit TCPServer(std::uint16_t port, + std::unique_ptr &&daemon) + : port_{port}, daemon_{std::move(daemon)} {} + +private: + std::uint16_t port_; + std::unique_ptr daemon_; + + // We store weak references to all TCPServers for which the caller requested + // multi-tenancy. + static std::unordered_map> + cachedServers_; + + static std::mutex cache_mutex_; +}; + +std::unordered_map> + TCPServer::cachedServers_{}; + +std::mutex TCPServer::cache_mutex_{}; + +std::shared_ptr TCPServer::start(const TCPStoreOptions &opts) { + auto startCore = [&opts]() { + Socket socket = Socket::listen(opts.port); + + std::uint16_t port = socket.port(); + + auto daemon = std::make_unique(std::move(socket)); + + return std::make_shared(port, std::move(daemon)); + }; + + std::shared_ptr server{}; + + if (opts.multiTenant) { + std::lock_guard guard{cache_mutex_}; + + // If the caller is okay with a multi-tenant store, first check if we + // already have a TCPServer running on the specified port. + if (opts.port > 0) { + auto pos = cachedServers_.find(opts.port); + if (pos != cachedServers_.end()) { + server = pos->second.lock(); + if (server != nullptr) { + return server; + } + + // Looks like the TCPStore has been disposed, make sure that we + // release the control block. + cachedServers_.erase(pos); + } + } + + server = startCore(); + + cachedServers_.emplace(server->port(), server); + } else { + server = startCore(); + } + + return server; +} + +class TCPClient { +public: + static std::unique_ptr connect(const SocketAddress &addr, + const TCPStoreOptions &opts); + + void sendRaw(uint8_t *data, size_t lenght) { + tcputil::sendBytes(socket_.handle(), data, lenght); + } + + std::vector receiveBits() { + return tcputil::recvVector(socket_.handle()); + } + + template + T receiveValue() { + return tcputil::recvValue(socket_.handle()); + } + + void setTimeout(std::chrono::milliseconds value); + + explicit TCPClient(Socket &&socket) : socket_{std::move(socket)} {} + +private: + Socket socket_; +}; + +std::unique_ptr TCPClient::connect(const SocketAddress &addr, + const TCPStoreOptions &opts) { + auto timeout + = std::chrono::duration_cast(opts.timeout); + Socket socket = Socket::connect( + addr.host, addr.port, SocketOptions{}.connect_timeout(timeout)); + + return std::make_unique(std::move(socket)); +} + +void TCPClient::setTimeout(std::chrono::milliseconds value) { + if (value == std::chrono::milliseconds::zero()) { + return; + } + +#ifdef _WIN32 + struct timeval timeoutTV + = {static_cast(value.count() / 1000), + static_cast((value.count() % 1000) * 1000)}; +#else + struct timeval timeoutTV = { + .tv_sec = value.count() / 1000, + .tv_usec = static_cast((value.count() % 1000) * 1000), + }; +#endif + SYSCHECK_ERR_RETURN_NEG1(::setsockopt(socket_.handle(), + SOL_SOCKET, + SO_RCVTIMEO, + reinterpret_cast(&timeoutTV), + sizeof(timeoutTV))); +} + +class TCPCallbackClient { +public: + static std::unique_ptr + connect(const SocketAddress &addr, const TCPStoreOptions &opts); + + void setCallback(const std::string &key, WatchKeyCallback callback); + + explicit TCPCallbackClient(int rawSocket, + std::unique_ptr &&daemon) + : rawSocket_{rawSocket}, daemon_{std::move(daemon)} {} + +private: + int rawSocket_; + std::unique_ptr daemon_; + std::mutex mutex_; +}; + +std::unique_ptr +TCPCallbackClient::connect(const SocketAddress &addr, + const TCPStoreOptions &opts) { + auto timeout + = std::chrono::duration_cast(opts.timeout); + Socket socket = Socket::connect( + addr.host, addr.port, SocketOptions{}.connect_timeout(timeout)); + + int rawSocket = socket.handle(); + + auto daemon = std::make_unique(std::move(socket)); + + return std::make_unique(rawSocket, std::move(daemon)); +} + +void TCPCallbackClient::setCallback(const std::string &key, + WatchKeyCallback callback) { + std::lock_guard guard{mutex_}; + + daemon_->setCallback(key, callback); + + tcputil::sendValue(rawSocket_, QueryType::WATCH_KEY); + + tcputil::sendString(rawSocket_, key); + + daemon_->waitForCallbackRegistration(); +} + +class SendBuffer { + // ethernet mtu 1500 - 40 (ip v6 header) - 20 (tcp header) + const size_t FLUSH_WATERMARK = 1440; + std::vector buffer; + detail::TCPClient &client; + + void maybeFlush() { + if (buffer.size() >= FLUSH_WATERMARK) { + flush(); + } + } + +public: + SendBuffer(detail::TCPClient &client, detail::QueryType cmd) + : client(client) { + buffer.reserve(32); // enough for most commands + buffer.push_back((uint8_t) cmd); + } + + void appendString(const std::string &str) { + appendValue(str.size()); + buffer.insert(buffer.end(), str.begin(), str.end()); + maybeFlush(); + } + + void appendBytes(const std::vector &vec) { + appendValue(vec.size()); + buffer.insert(buffer.end(), vec.begin(), vec.end()); + maybeFlush(); + } + + template + void appendValue(T value) { + uint8_t *begin = reinterpret_cast(&value); + buffer.insert(buffer.end(), begin, begin + sizeof(T)); + maybeFlush(); + } + + void flush() { + if (buffer.size() > 0) { + client.sendRaw(buffer.data(), buffer.size()); + buffer.clear(); + } + } +}; + +} // namespace detail + +using detail::Socket; + +// TCPStore class methods +TCPStore::TCPStore(const std::string &masterAddr, + std::uint16_t masterPort, + std::optional numWorkers, + bool isServer, + const std::chrono::milliseconds &timeout, + bool waitWorkers) + : TCPStore{masterAddr, + TCPStoreOptions{masterPort, + isServer, + numWorkers + ? std::optional(*numWorkers) + : std::nullopt, + waitWorkers, + timeout}} {} + +TCPStore::TCPStore(std::string host, const TCPStoreOptions &opts) + : Store{opts.timeout}, addr_{std::move(host)}, numWorkers_{ + opts.numWorkers} { + Socket::initialize(); + + if (opts.isServer) { + server_ = detail::TCPServer::start(opts); + + addr_.port = server_->port(); + } else { + addr_.port = opts.port; + } + + client_ = detail::TCPClient::connect(addr_, opts); + + if (opts.waitWorkers) { + waitForWorkers(); + } + + callbackClient_ = detail::TCPCallbackClient::connect(addr_, opts); +} + +TCPStore::~TCPStore() = default; + +void TCPStore::waitForWorkers() { + if (numWorkers_ == std::nullopt) { + return; + } + + incrementValueBy(initKey_, 1); + + // Let server block until all workers have completed, this ensures that + // the server daemon thread is always running until the very end + if (server_) { + const auto start = std::chrono::steady_clock::now(); + while (true) { + // TODO: Any chance to make this cleaner? + std::vector value = doGet(initKey_); + auto buf = reinterpret_cast(value.data()); + auto len = value.size(); + int numWorkersCompleted = std::stoi(std::string(buf, len)); + if (numWorkersCompleted >= static_cast(*numWorkers_)) { + break; + } + const auto elapsed + = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start); + if (timeout_ != kNoTimeout && elapsed > timeout_) { + break; + } + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } +} + +void TCPStore::set(const std::string &key, const std::vector &data) { + const std::lock_guard lock(activeOpLock_); + detail::SendBuffer buffer(*client_, detail::QueryType::SET); + buffer.appendString(keyPrefix_ + key); + buffer.appendBytes(data); + buffer.flush(); +} + +std::vector +TCPStore::compareSet(const std::string &key, + const std::vector &expectedValue, + const std::vector &desiredValue) { + const std::lock_guard lock(activeOpLock_); + detail::SendBuffer buffer(*client_, detail::QueryType::COMPARE_SET); + buffer.appendString(keyPrefix_ + key); + buffer.appendBytes(expectedValue); + buffer.appendBytes(desiredValue); + buffer.flush(); + + return client_->receiveBits(); +} + +std::vector TCPStore::get(const std::string &key) { + const std::lock_guard lock(activeOpLock_); + return doGet(keyPrefix_ + key); +} + +std::vector TCPStore::doGet(const std::string &key) { + std::vector keys; + keys.push_back(key); + doWait(keys, timeout_); + detail::SendBuffer buffer(*client_, detail::QueryType::GET); + buffer.appendString(key); + buffer.flush(); + + return client_->receiveBits(); +} + +int64_t TCPStore::add(const std::string &key, int64_t value) { + const std::lock_guard lock(activeOpLock_); + return incrementValueBy(keyPrefix_ + key, value); +} + +bool TCPStore::deleteKey(const std::string &key) { + const std::lock_guard lock(activeOpLock_); + detail::SendBuffer buffer(*client_, detail::QueryType::DELETE_KEY); + buffer.appendString(keyPrefix_ + key); + buffer.flush(); + + auto numDeleted = client_->receiveValue(); + return numDeleted == 1; +} + +void TCPStore::watchKey(const std::string &key, WatchKeyCallback callback) { + const std::lock_guard lock(activeOpLock_); + callbackClient_->setCallback(keyPrefix_ + key, callback); +} + +int64_t TCPStore::incrementValueBy(const std::string &key, int64_t delta) { + detail::SendBuffer buff(*client_, detail::QueryType::ADD); + buff.appendString(key); + buff.appendValue(delta); + buff.flush(); + + return client_->receiveValue(); +} + +int64_t TCPStore::getNumKeys() { + const std::lock_guard lock(activeOpLock_); + detail::SendBuffer buffer(*client_, detail::QueryType::GETNUMKEYS); + buffer.flush(); + + return client_->receiveValue(); +} + +bool TCPStore::check(const std::vector &keys) { + const std::lock_guard lock(activeOpLock_); + detail::SendBuffer buffer(*client_, detail::QueryType::CHECK); + buffer.appendValue(keys.size()); + + for (const std::string &key : keys) { + buffer.appendString(keyPrefix_ + key); + } + buffer.flush(); + + auto response = client_->receiveValue(); + if (response == detail::CheckResponseType::READY) { + return true; + } + if (response == detail::CheckResponseType::NOT_READY) { + return false; + } + // TORCH_CHECK(false, "ready or not_ready response expected"); + throw std::runtime_error("ready or not_ready response expected"); +} + +void TCPStore::wait(const std::vector &keys) { + wait(keys, timeout_); +} + +void TCPStore::wait(const std::vector &keys, + const std::chrono::milliseconds &timeout) { + const std::lock_guard lock(activeOpLock_); + std::vector prefixedKeys{}; + prefixedKeys.reserve(keys.size()); + for (const std::string &key : keys) { + prefixedKeys.emplace_back(keyPrefix_ + key); + } + + doWait(prefixedKeys, timeout); +} + +void TCPStore::doWait(std::vector keys, + std::chrono::milliseconds timeout) { + // TODO: Should we revert to the original timeout at the end of the call? + client_->setTimeout(timeout); + + detail::SendBuffer buffer(*client_, detail::QueryType::WAIT); + buffer.appendValue(keys.size()); + for (const std::string &key : keys) { + buffer.appendString(key); + } + buffer.flush(); + + auto response = client_->receiveValue(); + if (response != detail::WaitResponseType::STOP_WAITING) { + // TORCH_CHECK(false, "Stop_waiting response is expected"); + throw std::runtime_error("Stop_waiting response is expected"); + } +} + +void TCPStore::append(const std::string &key, + const std::vector &data) { + const std::lock_guard lock(activeOpLock_); + detail::SendBuffer buffer(*client_, detail::QueryType::APPEND); + buffer.appendString(keyPrefix_ + key); + buffer.appendBytes(data); + buffer.flush(); +} + +std::vector> +TCPStore::multiGet(const std::vector &keys) { + const std::lock_guard lock(activeOpLock_); + std::vector prefixedKeys; + prefixedKeys.reserve(keys.size()); + for (const std::string &key : keys) { + prefixedKeys.emplace_back(keyPrefix_ + key); + } + doWait(prefixedKeys, timeout_); + + detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_GET); + buffer.appendValue(keys.size()); + for (auto &key : prefixedKeys) { + buffer.appendString(key); + } + buffer.flush(); + + std::vector> result; + result.reserve(keys.size()); + for (size_t i = 0; i < keys.size(); ++i) { + result.emplace_back(client_->receiveBits()); + } + return result; +} + +void TCPStore::multiSet(const std::vector &keys, + const std::vector> &values) { + // TORCH_CHECK( + // keys.size() == values.size(), + // "multiSet keys and values vectors must be of same size"); + assert(keys.size() == values.size()); + const std::lock_guard lock(activeOpLock_); + + detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_SET); + buffer.appendValue(keys.size()); + for (int i = 0; i < keys.size(); i++) { + buffer.appendString(keyPrefix_ + keys[i]); + buffer.appendBytes(values[i]); + } + buffer.flush(); +} + +bool TCPStore::hasExtendedApi() const { return true; } + +} // namespace xoscar diff --git a/python/setup.py b/python/setup.py index 8de0939f..ef03f739 100644 --- a/python/setup.py +++ b/python/setup.py @@ -14,13 +14,17 @@ import os import platform +import re +import subprocess import sys +from pathlib import Path from sysconfig import get_config_vars import numpy as np from Cython.Build import cythonize from pkg_resources import parse_version from setuptools import Extension, setup + try: import distutils.ccompiler @@ -116,6 +120,45 @@ def build_long_description(): return f.read() +def build_cpp(): + source_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + debug = int(os.environ.get("DEBUG", 0)) + cfg = "Debug" if debug else "Release" + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + output_directory = Path(source_dir) / "python" / "xoscar" / "collective" / "rendezvous" + cmake_args = [ + f"-DLIBRARY_OUTPUT_DIRECTORY={output_directory}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + + build_args = [] + # Adding CMake arguments set as environment variable + # (needed e.g. to build for ARM OSx on conda-forge) + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + if sys.platform.startswith("darwin"): + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) + if archs: + cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] + + build_temp = Path(source_dir) / "build" + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + subprocess.run( + ["cmake", source_dir, *cmake_args], cwd=build_temp, check=True + ) + subprocess.run( + ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True + ) + + setup_options = dict( version=versioneer.get_version(), ext_modules=extensions, @@ -123,3 +166,4 @@ def build_long_description(): long_description_content_type="text/markdown", ) setup(**setup_options) +build_cpp() diff --git a/python/xoscar/collective/__init__.py b/python/xoscar/collective/__init__.py new file mode 100644 index 00000000..37f6558d --- /dev/null +++ b/python/xoscar/collective/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/xoscar/collective/rendezvous/__init__.py b/python/xoscar/collective/rendezvous/__init__.py new file mode 100644 index 00000000..37f6558d --- /dev/null +++ b/python/xoscar/collective/rendezvous/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/xoscar/collective/rendezvous/test/__init__.py b/python/xoscar/collective/rendezvous/test/__init__.py new file mode 100644 index 00000000..37f6558d --- /dev/null +++ b/python/xoscar/collective/rendezvous/test/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/xoscar/collective/rendezvous/test/test_tcp_store.py b/python/xoscar/collective/rendezvous/test/test_tcp_store.py new file mode 100644 index 00000000..595d7df9 --- /dev/null +++ b/python/xoscar/collective/rendezvous/test/test_tcp_store.py @@ -0,0 +1,61 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing + +import pytest + +from .. import xoscar_store as xs + + +def test_tcp_store_options(): + opt = xs.TCPStoreOptions() + assert opt.numWorkers is None + assert opt.isServer is False + + opt.numWorkers = 2 + assert opt.numWorkers == 2 + + with pytest.raises(TypeError): + opt.numWorkers = [5] + + +def server(): + opt = xs.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 2 + opt.isServer = True + + store = xs.TCPStore("127.0.0.1", opt) + val = store.get("test_key") + assert val == b"test_12345" + + +def worker(): + opt = xs.TCPStoreOptions() + opt.port = 25001 + opt.numWorkers = 2 + opt.isServer = False + + store = xs.TCPStore("127.0.0.1", opt) + store.set("test_key", b"test_12345") + + +def test_tcp_store(): + process1 = multiprocessing.Process(target=server) + process1.start() + process2 = multiprocessing.Process(target=worker) + process2.start() + + process1.join() + process2.join() diff --git a/python/xoscar/collective/rendezvous/xoscar_store.pyi b/python/xoscar/collective/rendezvous/xoscar_store.pyi new file mode 100644 index 00000000..ecfaa103 --- /dev/null +++ b/python/xoscar/collective/rendezvous/xoscar_store.pyi @@ -0,0 +1,29 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +from typing import List, Optional + +class TCPStoreOptions: + port: int + isServer: bool + numWorkers: Optional[int] + waitWorkers: bool + timeout: datetime.timedelta + multiTenant: bool + +class TCPStore: + def __init__(self, host: str, opts: TCPStoreOptions = TCPStoreOptions()): ... + def set(self, key: str, value: bytes): ... + def get(self, key: str) -> bytes: ... + def wait(self, keys: List[str]): ... diff --git a/third_party/fmt b/third_party/fmt new file mode 160000 index 00000000..13156e54 --- /dev/null +++ b/third_party/fmt @@ -0,0 +1 @@ +Subproject commit 13156e54bf91e44641ce3aac041d31f9a15a8042 diff --git a/third_party/pybind11 b/third_party/pybind11 new file mode 160000 index 00000000..e10da79b --- /dev/null +++ b/third_party/pybind11 @@ -0,0 +1 @@ +Subproject commit e10da79b6ee2554be364ef14df1c988f94df02ea From 923e88f447ff30fd6d6e7ef470cb5d9910c6c480 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 27 Jun 2023 17:17:19 +0800 Subject: [PATCH 02/19] fix setup.py --- python/setup.py | 206 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 167 insertions(+), 39 deletions(-) diff --git a/python/setup.py b/python/setup.py index ef03f739..cb9e1f2e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -17,6 +17,7 @@ import re import subprocess import sys +from distutils.file_util import copy_file from pathlib import Path from sysconfig import get_config_vars @@ -24,6 +25,9 @@ from Cython.Build import cythonize from pkg_resources import parse_version from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext +from setuptools.extension import Library +from distutils.command.build_ext import build_ext as _du_build_ext try: import distutils.ccompiler @@ -35,6 +39,15 @@ except ImportError: pass +try: + # Attempt to use Cython for building extensions, if available + from Cython.Distutils.build_ext import build_ext as _build_ext + # Additionally, assert that the compiler module will load + # also. Ref #1229. + __import__('Cython.Compiler.Main') +except ImportError: + _build_ext = _du_build_ext + # From https://github.com/pandas-dev/pandas/pull/24274: # For mac, ensure extensions are built for macos 10.9 when compiling on a # 10.9 system or above, overriding distuitls behaviour which is to target @@ -120,50 +133,165 @@ def build_long_description(): return f.read() -def build_cpp(): - source_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - debug = int(os.environ.get("DEBUG", 0)) - cfg = "Debug" if debug else "Release" - - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code - # from Python. - output_directory = Path(source_dir) / "python" / "xoscar" / "collective" / "rendezvous" - cmake_args = [ - f"-DLIBRARY_OUTPUT_DIRECTORY={output_directory}", - f"-DPYTHON_EXECUTABLE={sys.executable}", - f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm - ] - - build_args = [] - # Adding CMake arguments set as environment variable - # (needed e.g. to build for ARM OSx on conda-forge) - if "CMAKE_ARGS" in os.environ: - cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] - - if sys.platform.startswith("darwin"): - # Cross-compile support for macOS - respect ARCHFLAGS if set - archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) - if archs: - cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] - - build_temp = Path(source_dir) / "build" - if not build_temp.exists(): - build_temp.mkdir(parents=True) - - subprocess.run( - ["cmake", source_dir, *cmake_args], cwd=build_temp, check=True - ) - subprocess.run( - ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True - ) +# Convert distutils Windows platform specifiers to CMake -A arguments +PLAT_TO_CMAKE = { + "win32": "Win32", + "win-amd64": "x64", + "win-arm32": "ARM", + "win-arm64": "ARM64", +} + + +# A CMakeExtension needs a sourcedir instead of a file list. +# The name must be the _single_ output extension from the CMake build. +# If you need multiple extensions, see scikit-build. +class XoscarStoreExtension(Extension): + def __init__(self, name: str, sourcedir: str = "") -> None: + super().__init__(name, sources=[]) + self.sourcedir = os.fspath(Path(sourcedir).resolve()) + + +class CMakeBuild(build_ext): + def copy_extensions_to_source(self): + build_py = self.get_finalized_command('build_py') + for ext in self.extensions: + if not isinstance(ext, XoscarStoreExtension): + fullname = self.get_ext_fullname(ext.name) + filename = self.get_ext_filename(fullname) + modpath = fullname.split('.') + package = '.'.join(modpath[:-1]) + package_dir = build_py.get_package_dir(package) + dest_filename = os.path.join(package_dir, + os.path.basename(filename)) + src_filename = os.path.join(self.build_lib, filename) + + # Always copy, even if source is older than destination, to ensure + # that the right extensions for the current Python/platform are + # used. + copy_file( + src_filename, dest_filename, verbose=self.verbose, + dry_run=self.dry_run + ) + if ext._needs_stub: + self.write_stub(package_dir or os.curdir, ext, True) + + def build_extension(self, ext): + if isinstance(ext, XoscarStoreExtension): + self.build_store(ext) + else: + ext._convert_pyx_sources_to_lang() + _compiler = self.compiler + try: + if isinstance(ext, Library): + self.compiler = self.shlib_compiler + _build_ext.build_extension(self, ext) + if ext._needs_stub: + build_lib = self.get_finalized_command('build_py').build_lib + self.write_stub(build_lib, ext) + finally: + self.compiler = _compiler + + def build_store(self, ext: XoscarStoreExtension) -> None: + # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ + ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) + extdir = ext_fullpath.parent.resolve() + source_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + output_directory = Path(source_dir) / "python" / "xoscar" / "collective" / "rendezvous" + + # Using this requires trailing slash for auto-detection & inclusion of + # auxiliary "native" libs + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + # CMake lets you override the generator - we need to check this. + # Can be set with Conda-Build, for example. + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + cmake_args = [ + f"-DLIBRARY_OUTPUT_DIRECTORY={output_directory}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + build_args = [] + # Adding CMake arguments set as environment variable + # (needed e.g. to build for ARM OSx on conda-forge) + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + if self.compiler.compiler_type != "msvc": + # Using Ninja-build since it a) is available as a wheel and b) + # multithreads automatically. MSVC would require all variables be + # exported for Ninja to pick it up, which is a little tricky to do. + # Users can override the generator with CMAKE_GENERATOR in CMake + # 3.15+. + if not cmake_generator or cmake_generator == "Ninja": + try: + import ninja + + ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" + cmake_args += [ + "-GNinja", + f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + ] + except ImportError: + pass + + else: + # Single config generators are handled "normally" + single_config = any(x in cmake_generator for x in {"NMake", "Ninja"}) + + # CMake allows an arch-in-generator style for backward compatibility + contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) + + # Specify the arch if using MSVC generator, but only if it doesn't + # contain a backward-compatibility arch spec already in the + # generator name. + if not single_config and not contains_arch: + cmake_args += ["-A", PLAT_TO_CMAKE[self.plat_name]] + + # Multi-config generators have a different way to specify configs + if not single_config: + cmake_args += [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}" + ] + build_args += ["--config", cfg] + + if sys.platform.startswith("darwin"): + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) + if archs: + cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. + build_args += [f"-j{self.parallel}"] + + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + subprocess.run( + ["cmake", source_dir, *cmake_args], cwd=build_temp, check=True + ) + subprocess.run( + ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True + ) setup_options = dict( version=versioneer.get_version(), - ext_modules=extensions, + ext_modules=extensions + [XoscarStoreExtension("xoscar_store")], + cmdclass={"build_ext": CMakeBuild}, long_description=build_long_description(), long_description_content_type="text/markdown", ) setup(**setup_options) -build_cpp() From 564c17574d6a89cb2dcaa1994f7229006d07b963 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 27 Jun 2023 17:17:44 +0800 Subject: [PATCH 03/19] fix setup.py --- python/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index cb9e1f2e..38261236 100644 --- a/python/setup.py +++ b/python/setup.py @@ -17,6 +17,7 @@ import re import subprocess import sys +from distutils.command.build_ext import build_ext as _du_build_ext from distutils.file_util import copy_file from pathlib import Path from sysconfig import get_config_vars @@ -27,7 +28,6 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext from setuptools.extension import Library -from distutils.command.build_ext import build_ext as _du_build_ext try: import distutils.ccompiler @@ -42,6 +42,7 @@ try: # Attempt to use Cython for building extensions, if available from Cython.Distutils.build_ext import build_ext as _build_ext + # Additionally, assert that the compiler module will load # also. Ref #1229. __import__('Cython.Compiler.Main') From 96464ed20851eb2776c5932db52bc315cd98e2f9 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 27 Jun 2023 18:41:42 +0800 Subject: [PATCH 04/19] not support UT on windows --- .github/workflows/python.yaml | 6 ++++++ python/setup.py | 6 +++++- .../collective/rendezvous/test/test_tcp_store.py | 3 +++ python/xoscar/tests/core.py | 10 +++++++++- python/xoscar/utils.py | 4 ++++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 7382a007..88cdc705 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -85,6 +85,12 @@ jobs: python-version: ${{ matrix.python-version }} activate-environment: ${{ env.CONDA_ENV }} + # Fix "version `GLIBCXX_3.4.30' not found (required by xoscar_store.cpython-311-x86_64-linux-gnu.so)" issue in Python 3.11 + - name: Install libstdcxx-ng for Python 3.11 + if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version == '3.11') }} + run: | + conda install -c conda-forge libstdcxx-ng + - name: Install dependencies env: MODULE: ${{ matrix.module }} diff --git a/python/setup.py b/python/setup.py index 38261236..0663f5ef 100644 --- a/python/setup.py +++ b/python/setup.py @@ -177,8 +177,12 @@ def copy_extensions_to_source(self): self.write_stub(package_dir or os.curdir, ext, True) def build_extension(self, ext): - if isinstance(ext, XoscarStoreExtension): + # TODO: support windows compilation + is_windows = sys.platform.startswith('win') + if isinstance(ext, XoscarStoreExtension) and not is_windows: self.build_store(ext) + elif isinstance(ext, XoscarStoreExtension) and is_windows: + pass else: ext._convert_pyx_sources_to_lang() _compiler = self.compiler diff --git a/python/xoscar/collective/rendezvous/test/test_tcp_store.py b/python/xoscar/collective/rendezvous/test/test_tcp_store.py index 595d7df9..01c119d3 100644 --- a/python/xoscar/collective/rendezvous/test/test_tcp_store.py +++ b/python/xoscar/collective/rendezvous/test/test_tcp_store.py @@ -15,9 +15,11 @@ import pytest +from ....tests.core import require_unix from .. import xoscar_store as xs +@require_unix def test_tcp_store_options(): opt = xs.TCPStoreOptions() assert opt.numWorkers is None @@ -51,6 +53,7 @@ def worker(): store.set("test_key", b"test_12345") +@require_unix def test_tcp_store(): process1 = multiprocessing.Process(target=server) process1.start() diff --git a/python/xoscar/tests/core.py b/python/xoscar/tests/core.py index e8337007..3bcdff39 100644 --- a/python/xoscar/tests/core.py +++ b/python/xoscar/tests/core.py @@ -17,7 +17,7 @@ import pytest -from ..utils import lazy_import +from ..utils import is_windows, lazy_import cupy = lazy_import("cupy") cudf = lazy_import("cudf") @@ -45,6 +45,14 @@ def require_ucx(func): return func +def require_unix(func): + if pytest: + func = pytest.mark.unix(func) + + func = pytest.mark.skipif(is_windows(), reason="only unix is supported")(func) + return func + + DICT_NOT_EMPTY = type("DICT_NOT_EMPTY", (object,), {}) # is check works for deepcopy diff --git a/python/xoscar/utils.py b/python/xoscar/utils.py index a92ba582..3a5f3a54 100644 --- a/python/xoscar/utils.py +++ b/python/xoscar/utils.py @@ -454,3 +454,7 @@ def retry_call(*args, **kwargs): def is_cuda_buffer(cuda_buffer: Union["_cupy.ndarray", "_rmm.DeviceBuffer"]) -> bool: # type: ignore return hasattr(cuda_buffer, "__cuda_array_interface__") + + +def is_windows(): + return sys.platform.startswith("win") From 16c89f66049d83703dd3f68c384fa26594961190 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 27 Jun 2023 18:56:06 +0800 Subject: [PATCH 05/19] fix import error on windows --- python/xoscar/collective/rendezvous/test/test_tcp_store.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/xoscar/collective/rendezvous/test/test_tcp_store.py b/python/xoscar/collective/rendezvous/test/test_tcp_store.py index 01c119d3..ab588c36 100644 --- a/python/xoscar/collective/rendezvous/test/test_tcp_store.py +++ b/python/xoscar/collective/rendezvous/test/test_tcp_store.py @@ -16,7 +16,11 @@ import pytest from ....tests.core import require_unix -from .. import xoscar_store as xs + +try: + from .. import xoscar_store as xs +except ImportError: # windows case + xs = None # type: ignore @require_unix From cb4f15228f6147cb90757b938d6960d979bca3a2 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 27 Jun 2023 18:56:41 +0800 Subject: [PATCH 06/19] fix import error on windows --- .../collective/rendezvous/test/test_tcp_store.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/xoscar/collective/rendezvous/test/test_tcp_store.py b/python/xoscar/collective/rendezvous/test/test_tcp_store.py index ab588c36..c23df536 100644 --- a/python/xoscar/collective/rendezvous/test/test_tcp_store.py +++ b/python/xoscar/collective/rendezvous/test/test_tcp_store.py @@ -17,14 +17,11 @@ from ....tests.core import require_unix -try: - from .. import xoscar_store as xs -except ImportError: # windows case - xs = None # type: ignore - @require_unix def test_tcp_store_options(): + from .. import xoscar_store as xs + opt = xs.TCPStoreOptions() assert opt.numWorkers is None assert opt.isServer is False @@ -37,6 +34,8 @@ def test_tcp_store_options(): def server(): + from .. import xoscar_store as xs + opt = xs.TCPStoreOptions() opt.port = 25001 opt.numWorkers = 2 @@ -48,6 +47,8 @@ def server(): def worker(): + from .. import xoscar_store as xs + opt = xs.TCPStoreOptions() opt.port = 25001 opt.numWorkers = 2 From 139684c06428e5e870309ab1dde94062f2d882d0 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 27 Jun 2023 19:00:22 +0800 Subject: [PATCH 07/19] fix read the docs --- .readthedocs.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index b162ef80..980922b1 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,6 +8,8 @@ build: os: ubuntu-20.04 tools: python: "3.9" + apt_packages: + - cmake python: install: From 9cd6baf95ccd745ab7ea429162511710aa8b1d81 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Wed, 28 Jun 2023 12:02:22 +0800 Subject: [PATCH 08/19] github ci to check files --- .github/workflows/python.yaml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 88cdc705..27bfc16b 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -49,6 +49,21 @@ jobs: - name: codespell run: pip install codespell && cd python && codespell xoscar + - name: clang-format + uses: jidicula/clang-format-action@v4.11.0 + with: + clang-format-version: '14' + fallback-style: 'LLVM' + check-path: 'cpp' + - name: cmake-format + uses: PuneetMatharu/cmake-format-lint-action@v1.0.0 + with: + args: --config-files .cmake-format.yaml --check + - name: cpplint + run: | + pip install cpplint + cpplint --recursive cpp + build_test_job: runs-on: ${{ matrix.os }} needs: lint From 156b45b6dca5d9099fa453b2a3c5ef7c41db9b10 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Wed, 28 Jun 2023 12:07:43 +0800 Subject: [PATCH 09/19] fix cmake-format --- .github/workflows/python.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 27bfc16b..116c9371 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -55,14 +55,14 @@ jobs: clang-format-version: '14' fallback-style: 'LLVM' check-path: 'cpp' + - name: cpplint + run: | + pip install cpplint pyyaml + cpplint --recursive cpp - name: cmake-format uses: PuneetMatharu/cmake-format-lint-action@v1.0.0 with: args: --config-files .cmake-format.yaml --check - - name: cpplint - run: | - pip install cpplint - cpplint --recursive cpp build_test_job: runs-on: ${{ matrix.os }} From 1f82c0f1eb082b89fc3a7f2aeacebf6bfc414e7f Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Wed, 28 Jun 2023 12:19:46 +0800 Subject: [PATCH 10/19] fix cmake-format --- .github/workflows/python.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 116c9371..e0026bd7 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -57,10 +57,10 @@ jobs: check-path: 'cpp' - name: cpplint run: | - pip install cpplint pyyaml + pip install cpplint cpplint --recursive cpp - name: cmake-format - uses: PuneetMatharu/cmake-format-lint-action@v1.0.0 + uses: PuneetMatharu/cmake-format-lint-action@v1.0.2 with: args: --config-files .cmake-format.yaml --check From 460d6ffd3eb6f71563cb43fa2feaeb92fcf64bdd Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Wed, 28 Jun 2023 12:34:52 +0800 Subject: [PATCH 11/19] add clang format ignore --- .clang-format-ignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 .clang-format-ignore diff --git a/.clang-format-ignore b/.clang-format-ignore new file mode 100644 index 00000000..28e05f62 --- /dev/null +++ b/.clang-format-ignore @@ -0,0 +1 @@ +third_party/ \ No newline at end of file From 6845c3fb48116b5d74be18d29e2e61e6005d6bab Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Wed, 28 Jun 2023 13:02:03 +0800 Subject: [PATCH 12/19] fix cmake format check --- .clang-format-ignore | 1 - .github/workflows/python.yaml | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) delete mode 100644 .clang-format-ignore diff --git a/.clang-format-ignore b/.clang-format-ignore deleted file mode 100644 index 28e05f62..00000000 --- a/.clang-format-ignore +++ /dev/null @@ -1 +0,0 @@ -third_party/ \ No newline at end of file diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index e0026bd7..1c7cc8d3 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -60,9 +60,9 @@ jobs: pip install cpplint cpplint --recursive cpp - name: cmake-format - uses: PuneetMatharu/cmake-format-lint-action@v1.0.2 - with: - args: --config-files .cmake-format.yaml --check + run: | + pip install cmakelang + find . -name "CMakeLists.txt" -not -path "*third_party/*" | xargs cmake-format -c .cmake-format.yaml --check build_test_job: runs-on: ${{ matrix.os }} From 860e3ed2a6eb16ebdcfca9a74445036ce97f9c59 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Wed, 28 Jun 2023 13:03:48 +0800 Subject: [PATCH 13/19] fix cmake format check --- .github/workflows/python.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 1c7cc8d3..28e4d4a2 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -61,7 +61,7 @@ jobs: cpplint --recursive cpp - name: cmake-format run: | - pip install cmakelang + pip install cmakelang[YAML] find . -name "CMakeLists.txt" -not -path "*third_party/*" | xargs cmake-format -c .cmake-format.yaml --check build_test_job: From 3135a5115ab778a4fedc43e262153ec1028ad5c6 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Wed, 28 Jun 2023 13:15:27 +0800 Subject: [PATCH 14/19] remove torch --- cpp/collective/rendezvous/include/store.hpp | 5 ----- cpp/collective/rendezvous/src/tcp_store.cpp | 13 ++----------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/cpp/collective/rendezvous/include/store.hpp b/cpp/collective/rendezvous/include/store.hpp index be4e943b..57f430c2 100644 --- a/cpp/collective/rendezvous/include/store.hpp +++ b/cpp/collective/rendezvous/include/store.hpp @@ -56,7 +56,6 @@ class Store { compareSet(const std::string &key, const std::vector ¤tValue, const std::vector &newValue) { - // TORCH_INTERNAL_ASSERT(false, "Not implemented."); throw std::runtime_error("Not implemented."); } @@ -91,10 +90,6 @@ class Store { // DELETE: callback(currentValue, c10::nullopt) // null newValue virtual void watchKey(const std::string & /* unused */, WatchKeyCallback /* unused */) { - // TORCH_CHECK( - // false, - // "watchKey only implemented for TCPStore and PrefixStore that - // wraps TCPStore."); throw std::runtime_error("watchKey only implemented for TCPStore and " "PrefixStore that wraps TCPStore."); } diff --git a/cpp/collective/rendezvous/src/tcp_store.cpp b/cpp/collective/rendezvous/src/tcp_store.cpp index d32ff13c..0fcb0c9a 100644 --- a/cpp/collective/rendezvous/src/tcp_store.cpp +++ b/cpp/collective/rendezvous/src/tcp_store.cpp @@ -24,6 +24,7 @@ limitations under the License. */ #include #include +// TODO: Currently not support windows #ifdef _WIN32 # include # include @@ -33,7 +34,7 @@ limitations under the License. */ #endif #ifdef _WIN32 -# include +// # include #else # include "unix_sock_utils.hpp" #endif @@ -120,10 +121,6 @@ void BackgroundThread::stop() { SetEvent(ghStopEvent_); } #else void BackgroundThread::initStopSignal() { if (pipe(controlPipeFd_.data()) == -1) { - // TORCH_CHECK( - // false, - // "Failed to create the control pipe to start the " - // "BackgroundThread run"); throw std::runtime_error("Failed to create the control pipe to start " "the BackgroundThread run"); } @@ -323,7 +320,6 @@ void TCPStoreMasterDaemon::query(int socket) { } else if (qt == QueryType::MULTI_SET) { multiSetHandler(socket); } else { - // TORCH_CHECK(false, "Unexpected query type"); throw std::runtime_error("Unexpected query type"); } } @@ -1226,7 +1222,6 @@ bool TCPStore::check(const std::vector &keys) { if (response == detail::CheckResponseType::NOT_READY) { return false; } - // TORCH_CHECK(false, "ready or not_ready response expected"); throw std::runtime_error("ready or not_ready response expected"); } @@ -1260,7 +1255,6 @@ void TCPStore::doWait(std::vector keys, auto response = client_->receiveValue(); if (response != detail::WaitResponseType::STOP_WAITING) { - // TORCH_CHECK(false, "Stop_waiting response is expected"); throw std::runtime_error("Stop_waiting response is expected"); } } @@ -1301,9 +1295,6 @@ TCPStore::multiGet(const std::vector &keys) { void TCPStore::multiSet(const std::vector &keys, const std::vector> &values) { - // TORCH_CHECK( - // keys.size() == values.size(), - // "multiSet keys and values vectors must be of same size"); assert(keys.size() == values.size()); const std::lock_guard lock(activeOpLock_); From 4a7482a70ea1e1e46fa0a765fd11b8c42781ca78 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Wed, 28 Jun 2023 13:20:04 +0800 Subject: [PATCH 15/19] fix include --- cpp/collective/rendezvous/include/error.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/collective/rendezvous/include/error.h b/cpp/collective/rendezvous/include/error.h index b29ffb14..17645dc4 100644 --- a/cpp/collective/rendezvous/include/error.h +++ b/cpp/collective/rendezvous/include/error.h @@ -14,9 +14,8 @@ limitations under the License. */ #pragma once -#include "third_party/fmt/include/fmt/format.h" - #include +#include #include namespace fmt { From 66be2d3988778644037347c0668e2711d7c72853 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Mon, 3 Jul 2023 17:50:39 +0800 Subject: [PATCH 16/19] add license --- cpp/collective/rendezvous/LICENSE | 83 +++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 cpp/collective/rendezvous/LICENSE diff --git a/cpp/collective/rendezvous/LICENSE b/cpp/collective/rendezvous/LICENSE new file mode 100644 index 00000000..52ef620b --- /dev/null +++ b/cpp/collective/rendezvous/LICENSE @@ -0,0 +1,83 @@ +Apache 2.0 License with Code from PyTorch Repository + +This software includes code primarily sourced from the PyTorch repository, which is governed by its own original license. +The original license for the PyTorch repository, as well as the additional terms below, apply to the code in this software. + +----------------------------------------------------------------------------- +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All contributions by Cruise LLC: +Copyright (c) 2022 Cruise LLC. +All rights reserved. + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file From e1231a9865b579f44b76fac839052def9505140b Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 4 Jul 2023 14:52:20 +0800 Subject: [PATCH 17/19] fix compile --- CMakeLists.txt | 4 +++- python/setup.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d6accd7c..6fde062e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,8 +2,10 @@ cmake_minimum_required(VERSION 3.11...3.21) project(XoscarCollective) -if(NOT ${PYTHON_EXECUTABLE}) +if(NOT DEFINED PYTHON_PATH) find_package(Python COMPONENTS Interpreter Development) +else() + set(PYTHON_EXECUTABLE ${PYTHON_PATH}) endif() set(CMAKE_CXX_STANDARD 20) diff --git a/python/setup.py b/python/setup.py index 0663f5ef..e619ed2f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -218,7 +218,7 @@ def build_store(self, ext: XoscarStoreExtension) -> None: # from Python. cmake_args = [ f"-DLIBRARY_OUTPUT_DIRECTORY={output_directory}", - f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DPYTHON_PATH={sys.executable}", f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm ] build_args = [] From 34a6d69ea5c39e551fa3fb59481c31936698e8e1 Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 4 Jul 2023 15:14:03 +0800 Subject: [PATCH 18/19] for local development --- cpp/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 569ce1fd..d421cfa5 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -3,6 +3,9 @@ cmake_minimum_required(VERSION 3.11...3.21) project(XoscarCollective) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +if(NOT DEFINED LIBRARY_OUTPUT_DIRECTORY) + set(LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/python/xoscar/collective/rendezvous) +endif() include_directories(${CMAKE_SOURCE_DIR}/cpp/collective/rendezvous/include) From 77c2fef18574eae77ad2faebd9e2885649289cea Mon Sep 17 00:00:00 2001 From: ChengjieLi Date: Tue, 4 Jul 2023 16:35:24 +0800 Subject: [PATCH 19/19] add comments --- cpp/collective/rendezvous/src/tcp_store.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/collective/rendezvous/src/tcp_store.cpp b/cpp/collective/rendezvous/src/tcp_store.cpp index 0fcb0c9a..0af12378 100644 --- a/cpp/collective/rendezvous/src/tcp_store.cpp +++ b/cpp/collective/rendezvous/src/tcp_store.cpp @@ -1096,6 +1096,10 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions &opts) callbackClient_ = detail::TCPCallbackClient::connect(addr_, opts); } +// The destructor does nothing means that +// the background thread does not stop working when the ``tcp store`` object is +// recycled, and it will continue to occupy the port. Care is required when +// using it. So here is the implementation inherited from ``Torch``. TCPStore::~TCPStore() = default; void TCPStore::waitForWorkers() {