From 40b9a926b9df375a6ca5ce51ea93e4d849b85517 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 14 Sep 2023 04:43:44 -0700 Subject: [PATCH] [Disco] Pipe-based Multi-processing Session (#15727) This PR introduces `ProcessSession`, a new session implementation based on multi-processing. `ProcessSession` shares exactly the same communication protocol with `ThreadedSession`, but all workers except for worker 0 are launched in a separate process than thread. Workers communicate with the controller via pipe provided by the OS, rather than SPSC message queue between threads. In our implementation, Python's `subproces.popen` is used to create subprocesses, and the Python executable, or more specifically, `sys.executable` calls into `tvm.exec.disco_worker` as the entrypoint. Besides the launching logic that is only executed once in the very beginning, the rest of the implementation resides in a C++-only environment, including reads/writes to pipe file descriptors, serialization and deserialization of messages, worker interpretation of each message, etc. Detailed engineering elements included in this PR: - Refactors the MinRPC-based communication protocol out to be shared by `ProcessSession` and `ThreadedSession` as `protocol.h`; - Refactors a controller-side worker thread into `DiscoWorkerThread`, which is shared by both session implementation to launch worker-0; - Added two instructions `kDebugGetFromRemote` and `kDebugSetRegister`, which are used to communicate with workers other than worker-0 in debug mode; - Introduces multi-processing infra including: `tvm.exec.disco_worker` serving as the entrypoint that launches workers, and `tvm/runtime/disco/process_pool.py` that exposes APIs to launch worker processes. `tvm.exec.disco_worker` calls into a global function `runtime.disco.WorkerProcess` that executes the worker main loop in pure C++; - Introduces `src/support/process_id.h` that provides cross-platform pid and tid printing utilities; - Refactors Disco's NCCL integration that get rids of initialized-once global NCCL context, and switches to broadcasting `ncclUniqueId` from controller to all workers, and then create NCCL communicators in each worker thread/process accordingly. This is a thread/process-agnostic way of using NCCL. --- include/tvm/runtime/disco/session.h | 50 ++++- python/tvm/exec/disco_worker.py | 51 +++++ python/tvm/runtime/disco/__init__.py | 9 +- python/tvm/runtime/disco/process_pool.py | 180 ++++++++++++++++ python/tvm/runtime/disco/session.py | 27 ++- python/tvm/testing/__init__.py | 3 +- python/tvm/testing/disco.py | 53 +++++ src/runtime/disco/bcast_session.cc | 7 + src/runtime/disco/bcast_session.h | 2 + src/runtime/disco/builtin.cc | 6 +- src/runtime/disco/nccl/nccl.cc | 91 +++----- src/runtime/disco/nccl/utils.h | 2 + src/runtime/disco/process_session.cc | 213 +++++++++++++++++++ src/runtime/disco/protocol.h | 254 +++++++++++++++++++++++ src/runtime/disco/session.cc | 11 +- src/runtime/disco/threaded_session.cc | 128 +++--------- src/runtime/disco/worker.cc | 55 ++++- src/runtime/disco/worker.h | 43 ++++ src/support/process_id.h | 67 ++++++ tests/python/disco/test_nccl.py | 92 ++++---- tests/python/disco/test_session.py | 95 ++++----- 21 files changed, 1149 insertions(+), 290 deletions(-) create mode 100644 python/tvm/exec/disco_worker.py create mode 100644 python/tvm/runtime/disco/process_pool.py create mode 100644 python/tvm/testing/disco.py create mode 100644 src/runtime/disco/process_session.cc create mode 100644 src/runtime/disco/protocol.h create mode 100644 src/support/process_id.h diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index e28fb7144c..984ea026d8 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -72,6 +72,7 @@ #ifndef TVM_RUNTIME_DISCO_SESSION_H_ #define TVM_RUNTIME_DISCO_SESSION_H_ +#include #include #include @@ -92,6 +93,8 @@ enum class DiscoAction : int32_t { kSyncWorker = 4, kCopyFromWorker0 = 5, kCopyToWorker0 = 6, + kDebugGetFromRemote = 7, + kDebugSetRegister = 8, }; /*! \brief Converts the enum class `DiscoAction` to string */ @@ -111,6 +114,10 @@ inline std::string DiscoAction2String(DiscoAction action) { return "kCopyFromWorker0"; case DiscoAction::kCopyToWorker0: return "kCopyToWorker0"; + case DiscoAction::kDebugGetFromRemote: + return "kDebugGetFromRemote"; + case DiscoAction::kDebugSetRegister: + return "kDebugSetRegister"; } LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast(action); } @@ -136,7 +143,7 @@ class DRefObj : public Object { * \param worker_id The id of the worker to be copied to. * \param source The NDArray to be copied. */ - void DebugCopyFrom(int worker_id, NDArray source); + inline void DebugCopyFrom(int worker_id, TVMArgValue source); static constexpr const char* _type_key = "runtime.disco.DRef"; static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef; @@ -213,6 +220,12 @@ class SessionObj : public Object { virtual void SyncWorker(int worker_id) = 0; /*! \brief Signal all the workers to shutdown */ virtual void Shutdown() = 0; + /*! + * \brief Initialize the data plane between workers. + * \param ccl The name of the communication backend, e.g., nccl, rccl, mpi. + * \param device_ids The device ids of the workers. + */ + virtual void InitCCL(String ccl, ShapeTuple device_ids) = 0; /*! * \brief Get the value of a register from a remote worker. * \param reg_id The id of the register to be fetched. @@ -220,13 +233,19 @@ class SessionObj : public Object { * \return The value of the register. */ virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0; - - static constexpr const char* _type_key = "runtime.disco.Session"; - TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object); + /*! + * \brief Set the value of a register on a remote worker. + * \param reg_id The id of the register to be set. + * \param value The value to be set. + * \param worker_id The id of the worker to be set. + */ + virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) = 0; struct FFI; friend struct SessionObj::FFI; friend class DRefObj; + static constexpr const char* _type_key = "runtime.disco.Session"; + TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object); protected: /*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */ @@ -239,8 +258,22 @@ class SessionObj : public Object { */ class Session : public ObjectRef { public: - /*! \brief Create a session backed by a thread pool of workers */ - static Session ThreadedSession(int num_workers); + /*! + * \brief Create a session backed by a thread pool of workers + * \param num_workers The number of workers. + */ + TVM_DLL static Session ThreadedSession(int num_workers); + /*! + * \brief Create a session backed by pipe-based multiprocessing + * \param num_workers The number of workers. + * \param process_pool_creator The name of a global function that takes `num_workers` as an input, + * and returns a PackedFunc, which takes an integer `worker_id` as the input and returns None. + * When `worker-id` is 0, it shuts down the process pool; Otherwise, it retursn a tuple + * (read_fd, writefd) used to communicate with the corresponding worker. + * \note Worker-0 is always co-located with the controler as a separate thread, and therefore + * worker-0 does not exist in the process pool. + */ + TVM_DLL static Session ProcessSession(int num_workers, String process_pool_creator); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); }; @@ -250,6 +283,7 @@ class Session : public ObjectRef { */ class DiscoChannel { public: + virtual ~DiscoChannel() = default; /*! \brief Send a packed sequence to the receiver */ virtual void Send(const TVMArgs& args) = 0; /*! \brief Receive a packed sequence from worker */ @@ -272,6 +306,10 @@ TVMRetValue DRefObj::DebugGetFromRemote(int worker_id) { return Downcast(this->session)->DebugGetFromRemote(this->reg_id, worker_id); } +void DRefObj::DebugCopyFrom(int worker_id, TVMArgValue value) { + return Downcast(this->session)->DebugSetRegister(this->reg_id, value, worker_id); +} + template DRef SessionObj::CallPacked(const DRef& func, Args&&... args) { constexpr int offset = 3; diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py new file mode 100644 index 0000000000..9faa5742ae --- /dev/null +++ b/python/tvm/exec/disco_worker.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Internal DiscoWorker for Disco ProcessSession.""" +import os +import sys + +from tvm import runtime as _ # pylint: disable=unused-import +from tvm._ffi import get_global_func +from tvm.testing import disco as _ # pylint: disable=unused-import + + +def main(): + """Main worker function""" + if len(sys.argv) != 5: + print("Usage: ") + return + worker_id = int(sys.argv[1]) + num_workers = int(sys.argv[2]) + if sys.platform == "win32": + import msvcrt # pylint: disable=import-outside-toplevel,import-error + + reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY) + writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) + else: + reader = int(sys.argv[3]) + writer = int(sys.argv[4]) + + worker_func = get_global_func("runtime.disco.WorkerProcess") + worker_func(worker_id, num_workers, reader, writer) + + +if __name__ == "__main__": + try: + main() + except (KeyboardInterrupt, IOError): + pass diff --git a/python/tvm/runtime/disco/__init__.py b/python/tvm/runtime/disco/__init__.py index 57c0548e2e..856e69bc35 100644 --- a/python/tvm/runtime/disco/__init__.py +++ b/python/tvm/runtime/disco/__init__.py @@ -15,4 +15,11 @@ # specific language governing permissions and limitations # under the License. """TVM distributed runtime API.""" -from .session import DModule, DPackedFunc, DRef, Session, ThreadedSession +from .session import ( + DModule, + DPackedFunc, + DRef, + ProcessSession, + Session, + ThreadedSession, +) diff --git a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py new file mode 100644 index 0000000000..44348577f7 --- /dev/null +++ b/python/tvm/runtime/disco/process_pool.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Pipe worker for multi-processing.""" +import os +import subprocess +import sys + +import psutil + +from tvm._ffi import register_func +from tvm.runtime import ShapeTuple + + +class DiscoPopenWorker: + """A subprocess worker via Popen. + + PopenWorker provides a low-level + API to interact with a separate process via Popen. + + Parameters + ---------- + worker_id : int + The worker id of the current worker. + + num_workers : int + The total number of workers. + + stdout: Union[None, int, IO[Any]] + The standard output streams handler specified for the popen process. + + stderr: Union[None, int, IO[Any]] + The standard error streams handler specified for the popen process. + """ + + def __init__(self, worker_id: int, num_workers: int, stdout=None, stderr=None): + self.worker_id = worker_id + self.num_workers = num_workers + self._proc = None + self._stdout = stdout + self._stderr = stderr + + def __del__(self): + try: + self.kill() + except ImportError: + pass + + def kill(self): + """Kill the current running process and cleanup. + + Note + ---- + The worker can start a new process when send is called again. + """ + if self._proc is not None: + # kill all child processes recursively + try: + _kill_child_processes(self._proc.pid) + except TypeError: + pass + try: + self._proc.kill() + except OSError: + pass + + # Join the child process to avoid zombie processes + self.join(timeout=1.0) + self._proc = None + + def join(self, timeout=None): + """Join the current process worker before it terminates. + + Parameters + ---------- + timeout: Optional[number] + Timeout value, block at most timeout seconds if it + is a positive number. + """ + if self._proc: + try: + self._proc.wait(timeout) + except subprocess.TimeoutExpired: + pass + + def start(self): + """Start a new subprocess if nothing is available""" + if self._proc is not None: + return None, None + + # connect subprocess with a pair of pipes + main_read, worker_write = os.pipe() + worker_read, main_write = os.pipe() + + cmd = [ + sys.executable, + "-m", + "tvm.exec.disco_worker", + str(self.worker_id), + str(self.num_workers), + ] + if sys.platform == "win32": + import msvcrt # pylint: disable=import-error,import-outside-toplevel + + worker_read_handle = msvcrt.get_osfhandle(worker_read) + worker_write_handle = msvcrt.get_osfhandle(worker_write) + os.set_handle_inheritable(worker_read_handle, True) + os.set_handle_inheritable(worker_write_handle, True) + cmd += [str(worker_read_handle), str(worker_write_handle)] + self._proc = subprocess.Popen( + cmd, + close_fds=False, + stdout=self._stdout, + stderr=self._stderr, + ) + else: + cmd += [str(worker_read), str(worker_write)] + self._proc = subprocess.Popen( # pylint: disable=consider-using-with + cmd, + pass_fds=(worker_read, worker_write), + stdout=self._stdout, + stderr=self._stderr, + ) + + # close worker side of the pipe + os.close(worker_read) + os.close(worker_write) + return main_read, main_write + + +def _kill_child_processes(pid): + """Kill all child processes recursively for a given pid. + + Parameters + ---------- + pid : int + The given parameter id. + """ + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + except psutil.NoSuchProcess: + return + + for process in children: + try: + process.kill() + except psutil.NoSuchProcess: + pass + + +@register_func("runtime.disco.create_process_pool") +def _create_process_pool(num_workers: int): + """Create a process pool where the workers' are are [1, num_workers).""" + pool = [DiscoPopenWorker(i, num_workers) for i in range(1, num_workers)] + + def result_func(worker_id: int): + nonlocal pool + if worker_id != 0: + read_fd, write_fd = pool[worker_id - 1].start() + return ShapeTuple([read_fd, write_fd]) + print("Shutting down the process pool") + del pool + return None + + return result_func diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index eab5a5268d..d05561c2d1 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -27,7 +27,7 @@ from ..ndarray import NDArray from ..ndarray import array as _as_NDArray from ..object import Object -from . import _ffi_api +from . import _ffi_api, process_pool # pylint: disable=unused-import @register_object("runtime.disco.DRef") @@ -250,22 +250,21 @@ def load_vm_module( func = self._get_cached_method("runtime.disco.load_vm_module") return DModule(func(path, device)) - def init_ccl(self, api: str, *args): + def init_ccl(self, ccl: str, *device_ids): """Initialize the underlying communication collective library. Parameters ---------- - api : str + ccl : str The name of the communication collective library. Currently supported libraries are: - nccl - rccl - mpi - *args : various types - The arguments to be passed to the initialization function of the communication + *device_ids : int + The device IDs to be used by the underlying communication library. """ - assert api in ("nccl", "rccl"), f"Unsupported CCL backend: {api}" - func = self.get_global_func(f"runtime.disco.{api}.init_ccl") - func(*args) + assert ccl in ("nccl", "rccl"), f"Unsupported CCL backend: {ccl}" + return _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef: """Broadcast an array from worker-0 to all other workers. @@ -343,6 +342,18 @@ def __init__(self, num_workers: int) -> None: ) +@register_object("runtime.disco.ProcessSession") +class ProcessSession(Session): + """A Disco session backed by pipe-based multi-processing.""" + + def __init__(self, num_workers: int) -> None: + self.__init_handle_by_constructor__( + _ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member + num_workers, + "runtime.disco.create_process_pool", + ) + + REDUCE_OPS = { "sum": 0, "prod": 1, diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index 3e5f838a27..9aa1a31933 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -17,8 +17,7 @@ # pylint: disable=redefined-builtin, wildcard-import """Utility Python functions for TVM testing""" - -from . import auto_scheduler, autotvm +from . import auto_scheduler, autotvm, disco from ._ffi_api import ( ErrorTest, FrontendTestModule, diff --git a/python/tvm/testing/disco.py b/python/tvm/testing/disco.py new file mode 100644 index 0000000000..c13e83b7c4 --- /dev/null +++ b/python/tvm/testing/disco.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, missing-function-docstring, missing-class-docstring +"""Common utilities for testing disco""" +from tvm._ffi import register_func +from tvm.runtime import NDArray, ShapeTuple, String +from tvm.runtime.ndarray import array + + +@register_func("tests.disco.add_one") +def add_one(x: int) -> int: # pylint: disable=invalid-name + return x + 1 + + +@register_func("tests.disco.add_one_float", override=True) +def add_one_float(x: float): # pylint: disable=invalid-name + return x + 0.5 + + +@register_func("tests.disco.add_one_ndarray", override=True) +def add_one_ndarray(x: NDArray) -> NDArray: # pylint: disable=invalid-name + return array(x.numpy() + 1) + + +@register_func("tests.disco.str", override=True) +def str_func(x: str): # pylint: disable=invalid-name + return x + "_suffix" + + +@register_func("tests.disco.str_obj", override=True) +def str_obj_func(x: String): # pylint: disable=invalid-name + assert isinstance(x, String) + return String(x + "_suffix") + + +@register_func("tests.disco.shape_tuple", override=True) +def shape_tuple_func(x: ShapeTuple): # pylint: disable=invalid-name + assert isinstance(x, ShapeTuple) + return ShapeTuple(list(x) + [4, 5]) diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 0625c1157e..9b553c319c 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -68,6 +68,13 @@ void BcastSessionObj::Shutdown() { BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kShutDown, 0); } +void BcastSessionObj::InitCCL(String ccl, ShapeTuple device_ids) { + const auto* pf = runtime::Registry::Get("runtime.disco." + ccl + ".init_ccl"); + CHECK(pf) << "ValueError: Cannot initialize CCL `" << ccl + << "`, because cannot find function: runtime.disco." << ccl << ".init_ccl"; + (*pf)(GetRef(this), device_ids); +} + void BcastSessionObj::SyncWorker(int worker_id) { BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kSyncWorker, worker_id); TVMArgs args = this->RecvReplyPacked(worker_id); diff --git a/src/runtime/disco/bcast_session.h b/src/runtime/disco/bcast_session.h index 0221207b96..d064b30f5a 100644 --- a/src/runtime/disco/bcast_session.h +++ b/src/runtime/disco/bcast_session.h @@ -42,7 +42,9 @@ class BcastSessionObj : public SessionObj { void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) override; void SyncWorker(int worker_id) override; void Shutdown() override; + void InitCCL(String ccl, ShapeTuple device_ids) override; TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0; + void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) override = 0; protected: /*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */ diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 64e3fd4b28..06408c723a 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -100,7 +100,11 @@ void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; } -void SyncWorker() { GetCCLFunc("sync_worker")(); } +void SyncWorker() { + if (DiscoWorker::ThreadLocal()->ccl != "") { + GetCCLFunc("sync_worker")(); + } +} TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body_typed(DiscoEmptyNDArray); diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 0212923cef..e404e3c2bb 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -19,13 +19,16 @@ #include #include #include +#include #include #include +#include #include #include #include +#include "../../../support/process_id.h" #include "../../cuda/cuda_common.h" #include "./utils.h" @@ -33,48 +36,6 @@ namespace tvm { namespace runtime { namespace nccl { -struct NCCLGlobalContext { - std::vector communicators; - - static NCCLGlobalContext* Get() { - static NCCLGlobalContext ctx; - return &ctx; - } - - void Initialize(const std::vector& device_ids) { - { - std::ostringstream os; - bool is_first = true; - for (int device_id : device_ids) { - if (!is_first) { - os << ","; - } else { - is_first = false; - } - os << device_id; - } - LOG(INFO) << "Initializing NCCL with devices: " << os.str() << "."; - } - // TODO(@junrushao): support more flexible communicator pattern for generic SPMD usecases - DiscoWorker* worker = DiscoWorker::ThreadLocal(); - int num_workers = worker->num_workers; - CHECK_EQ(device_ids.size(), num_workers) - << "ValueError: There are " << num_workers << " worker(s), but " << device_ids.size() - << " device id(s) are provided."; - ncclUniqueId id; - NCCL_CALL(ncclGetUniqueId(&id)); - NCCL_CALL(ncclGroupStart()); - for (int worker_id = 0; worker_id < num_workers; ++worker_id) { - int device_id = device_ids[worker_id]; - ncclComm_t comm; - CUDA_CALL(cudaSetDevice(device_id)); - NCCL_CALL(ncclCommInitRank(&comm, num_workers, id, worker_id)); - this->communicators.push_back(comm); - } - NCCL_CALL(ncclGroupEnd()); - } -}; - struct NCCLThreadLocalContext { DiscoWorker* worker; int device_id; @@ -92,23 +53,38 @@ struct NCCLThreadLocalContext { } }; -void InitCCL(const std::vector& device_ids) { - // Set up global context only once - static std::once_flag flag; - std::call_once(flag, [&]() { NCCLGlobalContext::Get()->Initialize(device_ids); }); - // Set up thread-local context for each thread - DiscoWorker* worker = DiscoWorker::ThreadLocal(); +void InitCCL(Session sess, ShapeTuple device_ids) { + DRef func = sess->GetGlobalFunc("runtime.disco.nccl.init_ccl_per_worker"); + LOG(INFO) << "Initializing NCCL with devices: " << device_ids; + ncclUniqueId id; + TVMByteArray array; + NCCL_CALL(ncclGetUniqueId(&id)); + array.data = id.internal; + array.size = NCCL_UNIQUE_ID_BYTES; + sess->CallPacked(func, device_ids, array); +} + +void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) { NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get(); + DiscoWorker* worker = DiscoWorker::ThreadLocal(); + ICHECK(worker != nullptr); + CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES) + << "ValueError: The length of unique_id must be " << NCCL_UNIQUE_ID_BYTES << ", but got " + << unique_id_bytes.size() << "."; + // Step up local context of NCCL int device_id = device_ids[worker->worker_id]; CUDA_CALL(cudaSetDevice(device_id)); + CUDA_CALL(cudaStreamCreate(&ctx->stream)); Device device{DLDeviceType::kDLCUDA, device_id}; + DeviceAPI::Get(device)->SetStream(device, ctx->stream); worker->default_device = device; worker->ccl = "nccl"; ctx->worker = worker; ctx->device_id = device_id; - ctx->comm = NCCLGlobalContext::Get()->communicators[worker->worker_id]; - CUDA_CALL(cudaStreamCreate(&ctx->stream)); - DeviceAPI::Get(device)->SetStream(device, ctx->stream); + // Initialize the communicator + ncclUniqueId id; + std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES); + NCCL_CALL(ncclCommInitRank(&ctx->comm, worker->num_workers, id, worker->worker_id)); } void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) { @@ -158,7 +134,7 @@ void ScatterFromWorker0(Optional send, NDArray recv) { } } else { if (send.defined()) { - LOG(WARNING) << "ValueError: buffer `send` must be None when worker_id != 0. However, got " + LOG(WARNING) << "Buffer `send` must be None when worker_id != 0, but got " "send = " << send.get() << ". This will be ignored."; } @@ -222,17 +198,12 @@ void RecvFromWorker0(NDArray buffer) { void SyncWorker() { NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get(); + ICHECK(ctx->worker != nullptr); CUDA_CALL(cudaStreamSynchronize(ctx->stream)); } -TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl") - .set_body([](TVMArgs args, TVMRetValue* rv) -> void { - std::vector device_ids; - for (int i = 0; i < args.num_args; ++i) { - device_ids.push_back(args[i].operator int()); - } - InitCCL(device_ids); - }); +TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl").set_body_typed(InitCCL); +TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl_per_worker").set_body_typed(InitCCLPerWorker); TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce") .set_body_typed([](NDArray send, int kind, NDArray recv) { CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; diff --git a/src/runtime/disco/nccl/utils.h b/src/runtime/disco/nccl/utils.h index 4e5fb8cd74..7f40365136 100644 --- a/src/runtime/disco/nccl/utils.h +++ b/src/runtime/disco/nccl/utils.h @@ -69,6 +69,7 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { return ncclBfloat16; } LOG(FATAL) << "ValueError: Unsupported data type " << dtype; + throw; } inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) { @@ -85,6 +86,7 @@ inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) { return ncclAvg; } LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast(kind); + throw; } } // namespace nccl diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc new file mode 100644 index 0000000000..8ddfdce812 --- /dev/null +++ b/src/runtime/disco/process_session.cc @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include +#include +#include +#include + +#include "../../support/pipe.h" +#include "../minrpc/rpc_reference.h" +#include "./bcast_session.h" +#include "./protocol.h" +#include "./worker.h" +#include "tvm/runtime/c_runtime_api.h" + +namespace tvm { +namespace runtime { + +class DiscoPipeMessageQueue : private ::tvm::support::Pipe, + private DiscoProtocol { + public: + explicit DiscoPipeMessageQueue(int64_t handle) : ::tvm::support::Pipe(handle) {} + + ~DiscoPipeMessageQueue() = default; + + void Send(const TVMArgs& args) { + RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this); + } + + TVMArgs Recv() { + { + this->RecycleAll(); + uint64_t packet_nbytes = 0; + RPCCode code = RPCCode::kReturn; + this->Read(&packet_nbytes); + this->Read(&code); + } + TVMValue* values = nullptr; + int* type_codes = nullptr; + int num_args = 0; + RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this); + return TVMArgs(values, type_codes, num_args); + } + + using dmlc::Stream::Read; + using dmlc::Stream::ReadArray; + using dmlc::Stream::Write; + using dmlc::Stream::WriteArray; + friend struct RPCReference; + friend struct DiscoProtocol; +}; + +class DiscoProcessChannel final : public DiscoChannel { + public: + DiscoProcessChannel(int64_t controler_to_worker_fd, int64_t worker_to_controler_fd) + : controler_to_worker_(controler_to_worker_fd), + worker_to_controler_(worker_to_controler_fd) {} + + DiscoProcessChannel(DiscoProcessChannel&& other) = delete; + DiscoProcessChannel(const DiscoProcessChannel& other) = delete; + + void Send(const TVMArgs& args) { controler_to_worker_.Send(args); } + TVMArgs Recv() { return controler_to_worker_.Recv(); } + void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); } + TVMArgs RecvReply() { return worker_to_controler_.Recv(); } + + DiscoPipeMessageQueue controler_to_worker_; + DiscoPipeMessageQueue worker_to_controler_; +}; + +class ProcessSessionObj final : public BcastSessionObj { + public: + explicit ProcessSessionObj(int num_workers, PackedFunc process_pool) + : process_pool_(process_pool), + worker_0_(std::make_unique(0, num_workers, &worker_zero_data_)) { + std::vector read_fds; + std::vector write_fds; + read_fds.reserve(num_workers - 1); + write_fds.reserve(num_workers - 1); + for (int i = 1; i < num_workers; ++i) { + ShapeTuple fds = process_pool(i); + CHECK_EQ(fds.size(), 2) << "ValueError: process_pool(" << i << ") should return a tuple of " + << "size 2, but got a tuple of size " << fds.size() << "."; + read_fds.push_back(fds[0]); + write_fds.push_back(fds[1]); + } + for (int i = 0; i < num_workers - 1; ++i) { + workers_.emplace_back(std::make_unique(write_fds[i], read_fds[i])); + } + } + + void Kill() { + if (this->worker_0_ != nullptr) { + this->Shutdown(); + this->worker_0_.reset(); + this->workers_.clear(); + this->process_pool_(0); + } + } + + ~ProcessSessionObj() { Kill(); } + + TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) { + if (worker_id == 0) { + this->SyncWorker(worker_id); + return worker_0_->worker->register_file.at(reg_id); + } + { + TVMValue values[3]; + int type_codes[3]; + PackArgs(values, type_codes, static_cast(DiscoAction::kDebugGetFromRemote), reg_id, + worker_id); + workers_[worker_id - 1]->Send(TVMArgs(values, type_codes, 3)); + } + TVMArgs args = this->RecvReplyPacked(worker_id); + ICHECK_EQ(args.size(), 2); + ICHECK(static_cast(args[0].operator int()) == DiscoAction::kDebugGetFromRemote); + TVMRetValue result; + result = args[1]; + return result; + } + + void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) { + if (worker_id == 0) { + this->SyncWorker(worker_id); + worker_0_->worker->SetRegister(reg_id, value); + return; + } + ObjectRef wrapped{nullptr}; + if (value.type_code() == kTVMNDArrayHandle || value.type_code() == kTVMObjectHandle) { + wrapped = DiscoDebugObject::Wrap(value); + TVMValue tvm_value; + int type_code = kTVMObjectHandle; + tvm_value.v_handle = const_cast(wrapped.get()); + value = TVMArgValue(tvm_value, type_code); + } + { + TVMValue values[4]; + int type_codes[4]; + PackArgs(values, type_codes, static_cast(DiscoAction::kDebugSetRegister), reg_id, + worker_id, value); + workers_[worker_id - 1]->Send(TVMArgs(values, type_codes, 4)); + } + TVMRetValue result; + TVMArgs args = this->RecvReplyPacked(worker_id); + ICHECK_EQ(args.size(), 1); + ICHECK(static_cast(args[0].operator int()) == DiscoAction::kDebugSetRegister); + } + + void BroadcastPacked(const TVMArgs& args) final { + worker_0_->channel->Send(args); + for (std::unique_ptr& channel : workers_) { + channel->Send(args); + } + } + + TVMArgs RecvReplyPacked(int worker_id) final { + if (worker_id == 0) { + return worker_0_->channel->RecvReply(); + } + return this->workers_.at(worker_id - 1)->RecvReply(); + } + + PackedFunc process_pool_; + std::unique_ptr worker_0_; + std::vector> workers_; + + static constexpr const char* _type_key = "runtime.disco.ProcessSession"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProcessSessionObj, SessionObj); +}; + +TVM_REGISTER_OBJECT_TYPE(DiscoDebugObject); +TVM_REGISTER_OBJECT_TYPE(ProcessSessionObj); + +Session Session::ProcessSession(int num_workers, String process_pool_creator) { + const PackedFunc* pf = Registry::Get(process_pool_creator); + CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator + << " in the registry. Please check if it is registered."; + PackedFunc process_pool = (*pf)(num_workers); + auto n = make_object(num_workers, process_pool); + return Session(n); +} + +void WorkerProcess(int worker_id, int num_workers, int64_t read_fd, int64_t write_fd) { + DiscoProcessChannel channel(read_fd, write_fd); + DiscoWorker worker(worker_id, num_workers, nullptr, &channel); + worker.MainLoop(); +} + +TVM_REGISTER_GLOBAL("runtime.disco.SessionProcess").set_body_typed(Session::ProcessSession); +TVM_REGISTER_GLOBAL("runtime.disco.WorkerProcess").set_body_typed(WorkerProcess); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h new file mode 100644 index 0000000000..50a6b091af --- /dev/null +++ b/src/runtime/disco/protocol.h @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_DISCO_PROTOCOL_H_ +#define TVM_RUNTIME_DISCO_PROTOCOL_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../support/arena.h" +#include "../../support/base64.h" +#include "../minrpc/rpc_reference.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief The communication protocol used by Disco message channel. + * \tparam SubClassType The subclass type that inherits this protocol. + */ +template +struct DiscoProtocol { + protected: + /*! \brief Virtual destructor */ + virtual ~DiscoProtocol() = default; + + /*! \brief Recycle all the memory used in the arena */ + inline void RecycleAll() { + this->object_arena_.clear(); + this->arena_.RecycleAll(); + } + + /*! \brief Get the length of the object being serialized. Used by RPCReference. */ + inline uint64_t GetObjectBytes(Object* obj); + + /*! \brief Write the object to stream. Used by RPCReference. */ + inline void WriteObject(Object* obj); + + /*! \brief Read the object from stream. Used by RPCReference. */ + inline void ReadObject(int* tcode, TVMValue* value); + + /*! \brief Callback method used when starting a new message. Used by RPCReference. */ + void MessageStart(uint64_t packet_nbytes) {} + + /*! \brief Callback method used when a new message is complete. Used by RPCReference. */ + void MessageDone() {} + + /*! \brief Callback method when an error occurs in (de)-serialization. Used by RPCReference. */ + void ThrowError(RPCServerStatus status) { + LOG(FATAL) << "InternalError: Unexpected error in RPC: " << RPCServerStatusToString(status); + } + + /*!\ brief Arena used by RPCReference to allocate POD memory */ + template + T* ArenaAlloc(int count) { + static_assert(std::is_pod::value, "need to be trival"); + return arena_.template allocate_(count); + } + + support::Arena arena_; + std::vector object_arena_; + friend struct RPCReference; +}; + +/*! + * \brief The debug extension of the communication protocol that allows serialization and + * deserialization of NDArrays and reflection-capable TVM objects. + */ +struct DiscoDebugObject : public Object { + public: + /*! \brief The data to be serialized */ + TVMRetValue data; + + /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */ + static ObjectRef Wrap(const TVMRetValue& data) { + ObjectPtr n = make_object(); + n->data = data; + return ObjectRef(n); + } + + /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */ + static ObjectRef Wrap(const TVMArgValue& data) { + TVMRetValue rv; + rv = data; + return Wrap(std::move(rv)); + } + + /*! \brief Serialize the debug object to string */ + inline std::string SaveToStr() const; + /*! \brief Deserialize the debug object from string */ + static inline ObjectPtr LoadFromStr(std::string json_str); + /*! \brief Get the size of the debug object in bytes */ + inline uint64_t GetObjectBytes() const { return sizeof(uint64_t) + this->SaveToStr().size(); } + + static constexpr const char* _type_key = "runtime.disco.DiscoDebugObject"; + TVM_DECLARE_FINAL_OBJECT_INFO(DiscoDebugObject, SessionObj); +}; + +template +inline uint64_t DiscoProtocol::GetObjectBytes(Object* obj) { + if (obj->IsInstance()) { + return sizeof(uint32_t) + sizeof(int64_t); + } else if (obj->IsInstance()) { + uint64_t size = static_cast(obj)->size; + return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char); + } else if (obj->IsInstance()) { + uint64_t ndim = static_cast(obj)->size; + return sizeof(uint32_t) + sizeof(uint64_t) + ndim * sizeof(ShapeTupleObj::index_type); + } else if (obj->IsInstance()) { + return sizeof(uint32_t) + static_cast(obj)->GetObjectBytes(); + } else { + LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " + << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; + } +} +template +inline void DiscoProtocol::WriteObject(Object* obj) { + SubClassType* self = static_cast(this); + if (obj->IsInstance()) { + int64_t reg_id = static_cast(obj)->reg_id; + self->template Write(TypeIndex::kRuntimeDiscoDRef); + self->template Write(reg_id); + } else if (obj->IsInstance()) { + StringObj* str = static_cast(obj); + self->template Write(TypeIndex::kRuntimeString); + self->template Write(str->size); + self->template WriteArray(str->data, str->size); + } else if (obj->IsInstance()) { + ShapeTupleObj* shape = static_cast(obj); + self->template Write(TypeIndex::kRuntimeShapeTuple); + self->template Write(shape->size); + self->template WriteArray(shape->data, shape->size); + } else if (obj->IsInstance()) { + self->template Write(TypeIndex::kRoot); + std::string str = static_cast(obj)->SaveToStr(); + self->template Write(str.size()); + self->template WriteArray(str.data(), str.size()); + } else { + LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " + << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; + } +} + +template +inline void DiscoProtocol::ReadObject(int* tcode, TVMValue* value) { + SubClassType* self = static_cast(this); + ObjectRef result{nullptr}; + uint32_t type_index; + self->template Read(&type_index); + if (type_index == TypeIndex::kRuntimeDiscoDRef) { + ObjectPtr dref = make_object(); + self->template Read(&dref->reg_id); + dref->session = Session{nullptr}; + result = ObjectRef(std::move(dref)); + } else if (type_index == TypeIndex::kRuntimeString) { + uint64_t size = 0; + self->template Read(&size); + std::string data(size, '\0'); + self->template ReadArray(data.data(), size); + result = String(std::move(data)); + } else if (type_index == TypeIndex::kRuntimeShapeTuple) { + uint64_t ndim = 0; + self->template Read(&ndim); + std::vector data(ndim); + self->template ReadArray(data.data(), ndim); + result = ShapeTuple(std::move(data)); + } else if (type_index == TypeIndex::kRoot) { + uint64_t size = 0; + self->template Read(&size); + std::string data(size, '\0'); + self->template ReadArray(data.data(), size); + result = DiscoDebugObject::LoadFromStr(std::move(data))->data; + } else { + LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " + << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; + } + TVMArgsSetter(value, tcode)(0, result); + object_arena_.push_back(result); +} + +inline std::string DiscoDebugObject::SaveToStr() const { + if (this->data.type_code() == kTVMObjectHandle) { + ObjectRef obj = this->data; + const PackedFunc* f = runtime::Registry::Get("node.SaveJSON"); + CHECK(f) << "ValueError: Cannot serialize object in non-debugging mode: " << obj->GetTypeKey(); + std::string result = (*f)(obj); + result.push_back('0'); + return result; + } else if (this->data.type_code() == kTVMNDArrayHandle) { + NDArray array = this->data; + std::string result; + { + dmlc::MemoryStringStream mstrm(&result); + support::Base64OutStream b64strm(&mstrm); + runtime::SaveDLTensor(&b64strm, array.operator->()); + b64strm.Finish(); + } + result.push_back('1'); + return result; + } + LOG(FATAL) << "ValueError: Cannot serialize the following type code in non-debugging mode: " + << this->data.type_code() << "(" << ArgTypeCode2Str(this->data.type_code()); +} + +inline ObjectPtr DiscoDebugObject::LoadFromStr(std::string json_str) { + ICHECK(!json_str.empty()); + char control_bit = json_str.back(); + json_str.pop_back(); + ObjectPtr result = make_object(); + if (control_bit == '0') { + const PackedFunc* f = runtime::Registry::Get("node.LoadJSON"); + CHECK(f) << "ValueError: Cannot deserialize object in non-debugging mode"; + result->data = (*f)(json_str); + } else if (control_bit == '1') { + dmlc::MemoryStringStream mstrm(&json_str); + support::Base64InStream b64strm(&mstrm); + b64strm.InitPosition(); + runtime::NDArray array; + ICHECK(array.Load(&b64strm)); + result->data = std::move(array); + } else { + LOG(FATAL) << "ValueError: Unsupported control bit: " << control_bit + << ". Full string: " << json_str; + } + return result; +} + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_DISCO_PROTOCOL_H_ diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index e22b6c6d26..2cc027151a 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -31,15 +31,6 @@ struct SessionObj::FFI { } }; -void DRefObj::DebugCopyFrom(int worker_id, NDArray source) { - TVMRetValue target_array = this->DebugGetFromRemote(worker_id); - CHECK(target_array.type_code() == kTVMNDArrayHandle) - << "ValueError: The DRef on the remote is not an NDArray, instead, its type code is: " - << ArgTypeCode2Str(target_array.type_code()); - NDArray target = target_array.operator NDArray(); - target.CopyFrom(source); -} - TVM_REGISTER_OBJECT_TYPE(DRefObj); TVM_REGISTER_OBJECT_TYPE(SessionObj); TVM_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession); @@ -58,6 +49,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0") .set_body_method(&SessionObj::CopyToWorker0); TVM_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker") .set_body_method(&SessionObj::SyncWorker); +TVM_REGISTER_GLOBAL("runtime.disco.SessionInitCCL") // + .set_body_method(&SessionObj::InitCCL); TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked").set_body([](TVMArgs args, TVMRetValue* rv) { Session self = args[0]; *rv = SessionObj::FFI::CallWithPacked( diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index cb84918d2d..349601fd03 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -27,16 +27,17 @@ #include #include -#include "../../support/arena.h" #include "../../support/ring_buffer.h" #include "../minrpc/rpc_reference.h" #include "./bcast_session.h" +#include "./protocol.h" #include "./worker.h" namespace tvm { namespace runtime { -class DiscoThreadedMessageQueue : public dmlc::Stream { +class DiscoThreadedMessageQueue : private dmlc::Stream, + private DiscoProtocol { public: void Send(const TVMArgs& args) { RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this); @@ -67,10 +68,7 @@ class DiscoThreadedMessageQueue : public dmlc::Stream { condition_.wait(lock, [this] { return msg_cnt_.load() > 0; }); --msg_cnt_; } - { - this->arena_.RecycleAll(); - this->object_arena_.clear(); - } + this->RecycleAll(); uint64_t packet_nbytes = 0; RPCCode code = RPCCode::kReturn; this->Read(&packet_nbytes); @@ -84,18 +82,6 @@ class DiscoThreadedMessageQueue : public dmlc::Stream { this->ring_buffer_.Reserve(n); } - void MessageDone() {} - - void ThrowError(RPCServerStatus status) { - LOG(FATAL) << "InternalError: Unexpected error in RPC: " << RPCServerStatusToString(status); - } - - template - T* ArenaAlloc(int count) { - static_assert(std::is_pod::value, "need to be trival"); - return arena_.template allocate_(count); - } - size_t Read(void* data, size_t size) final { std::lock_guard lock(mutex_); ring_buffer_.Read(data, size); @@ -107,85 +93,17 @@ class DiscoThreadedMessageQueue : public dmlc::Stream { ring_buffer_.Write(data, size); } - uint64_t GetObjectBytes(Object* obj) { - if (obj->IsInstance()) { - return sizeof(uint32_t) + sizeof(int64_t); - } else if (obj->IsInstance()) { - uint64_t size = static_cast(obj)->size; - return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char); - } else if (obj->IsInstance()) { - uint64_t ndim = static_cast(obj)->size; - return sizeof(uint32_t) + sizeof(uint64_t) + ndim * sizeof(ShapeTupleObj::index_type); - } else { - LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; - } - } - - void WriteObject(Object* obj) { - if (obj->IsInstance()) { - int64_t reg_id = static_cast(obj)->reg_id; - this->Write(TypeIndex::kRuntimeDiscoDRef); - this->Write(reg_id); - } else if (obj->IsInstance()) { - StringObj* str = static_cast(obj); - this->Write(TypeIndex::kRuntimeString); - this->Write(str->size); - this->WriteArray(str->data, str->size); - } else if (obj->IsInstance()) { - ShapeTupleObj* shape = static_cast(obj); - this->Write(TypeIndex::kRuntimeShapeTuple); - this->Write(shape->size); - this->WriteArray(shape->data, shape->size); - } else { - LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; - } - } - - void ReadObject(int* tcode, TVMValue* value) { - ObjectRef result{nullptr}; - uint32_t type_index; - this->Read(&type_index); - if (type_index == TypeIndex::kRuntimeDiscoDRef) { - ObjectPtr dref = make_object(); - this->Read(&dref->reg_id); - dref->session = Session{nullptr}; - result = ObjectRef(std::move(dref)); - } else if (type_index == TypeIndex::kRuntimeString) { - uint64_t size = 0; - this->Read(&size); - std::string data(size, '\0'); - this->ReadArray(data.data(), size); - result = String(std::move(data)); - } else if (type_index == TypeIndex::kRuntimeShapeTuple) { - uint64_t ndim = 0; - this->Read(&ndim); - std::vector data(ndim); - this->ReadArray(data.data(), ndim); - result = ShapeTuple(std::move(data)); - } else { - LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; - } - *tcode = kTVMObjectHandle; - value->v_handle = const_cast(result.get()); - object_arena_.push_back(result); - } - using dmlc::Stream::Read; using dmlc::Stream::ReadArray; using dmlc::Stream::Write; using dmlc::Stream::WriteArray; friend struct RPCReference; + friend struct DiscoProtocol; std::mutex mutex_; std::atomic msg_cnt_{0}; std::condition_variable condition_; - support::RingBuffer ring_buffer_; - support::Arena arena_; - std::vector object_arena_; }; class DiscoThreadChannel final : public DiscoChannel { @@ -199,44 +117,52 @@ class DiscoThreadChannel final : public DiscoChannel { DiscoThreadedMessageQueue worker_to_controler_; }; +DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers, + WorkerZeroData* worker_zero_data_) + : channel(std::make_unique()), + worker( + std::make_unique(worker_id, num_workers, worker_zero_data_, channel.get())), + thread(std::make_unique([worker = this->worker.get()] { worker->MainLoop(); })) { +} + class ThreadedSessionObj final : public BcastSessionObj { public: explicit ThreadedSessionObj(int num_workers) { for (int i = 0; i < num_workers; ++i) { - std::unique_ptr channel = std::make_unique(); WorkerZeroData* data = (i == 0) ? &worker_zero_data_ : nullptr; - workers_.emplace_back(std::make_unique(i, num_workers, data, channel.get())); - channels_.emplace_back(std::move(channel)); - worker_threads_.emplace_back([worker = workers_.back().get()] { worker->MainLoop(); }); + workers_.emplace_back(i, num_workers, data); } } ~ThreadedSessionObj() { this->Shutdown(); - for (std::thread& worker : this->worker_threads_) { - worker.join(); - } + workers_.clear(); } TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) { this->SyncWorker(worker_id); - return this->workers_.at(worker_id)->register_file.at(reg_id); + return this->workers_.at(worker_id).worker->register_file.at(reg_id); + } + + void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) { + this->SyncWorker(worker_id); + this->workers_.at(worker_id).worker->SetRegister(reg_id, value); } void BroadcastPacked(const TVMArgs& args) final { - for (const std::unique_ptr& channel : this->channels_) { - channel->Send(args); + for (const DiscoWorkerThread& worker : this->workers_) { + worker.channel->Send(args); } } - TVMArgs RecvReplyPacked(int worker_id) final { return channels_[worker_id]->RecvReply(); } + TVMArgs RecvReplyPacked(int worker_id) final { + return this->workers_.at(worker_id).channel->RecvReply(); + } static constexpr const char* _type_key = "runtime.disco.ThreadedSession"; TVM_DECLARE_FINAL_OBJECT_INFO(ThreadedSessionObj, SessionObj); - std::vector> channels_; - std::vector> workers_; - std::vector worker_threads_; + std::vector workers_; }; TVM_REGISTER_OBJECT_TYPE(ThreadedSessionObj); diff --git a/src/runtime/disco/worker.cc b/src/runtime/disco/worker.cc index 63e814a7e2..3100985f18 100644 --- a/src/runtime/disco/worker.cc +++ b/src/runtime/disco/worker.cc @@ -19,11 +19,15 @@ #include "./worker.h" #include +#include +#include #include #include +#include "../../support/process_id.h" #include "./builtin.h" +#include "./protocol.h" namespace tvm { namespace runtime { @@ -43,11 +47,23 @@ DiscoWorker* DiscoWorker::ThreadLocal() { return ret; } +void DiscoWorker::SetRegister(int reg_id, TVMArgValue value) { + ICHECK(0 <= reg_id && reg_id < static_cast(register_file.size())); + TVMRetValue& rv = register_file.at(reg_id); + if (rv.type_code() == kTVMNDArrayHandle && value.type_code() == kTVMNDArrayHandle) { + NDArray dst = rv; + NDArray src = value; + dst.CopyFrom(src); + } else { + rv = value; + } +} + struct DiscoWorker::Impl { static void MainLoop(DiscoWorker* self) { ThreadLocalDiscoWorker::Get()->worker = self; - LOG(INFO) << "[Thread " << std::this_thread::get_id() << "] Worker #" << self->worker_id - << " Launched"; + LOG(INFO) << "[Worker #" << self->worker_id << "] " << support::GetProcessIdAndThreadIdHeader() + << " started"; while (true) { TVMArgs args = self->channel->Recv(); DiscoAction action = static_cast(args[0].operator int()); @@ -84,6 +100,17 @@ struct DiscoWorker::Impl { SyncWorker(self, reg_id); break; } + case DiscoAction::kDebugGetFromRemote: { + int worker_id = args[2]; + DebugGetFromRemote(self, reg_id, worker_id); + break; + } + case DiscoAction::kDebugSetRegister: { + int worker_id = args[2]; + TVMArgValue value = args[3]; + DebugSetRegister(self, reg_id, worker_id, value); + break; + } } } } @@ -131,6 +158,30 @@ struct DiscoWorker::Impl { } } + static void DebugGetFromRemote(DiscoWorker* self, int reg_id, int worker_id) { + if (worker_id == self->worker_id) { + TVMRetValue rv = GetReg(self, reg_id); + if (rv.type_code() == kTVMNDArrayHandle || rv.type_code() == kTVMObjectHandle) { + rv = DiscoDebugObject::Wrap(rv); + } + TVMValue values[2]; + int type_codes[2]; + PackArgs(values, type_codes, static_cast(DiscoAction::kDebugGetFromRemote), rv); + self->channel->Reply(TVMArgs(values, type_codes, 2)); + } + } + + static void DebugSetRegister(DiscoWorker* self, int reg_id, int worker_id, TVMArgValue value) { + if (worker_id == self->worker_id) { + ::tvm::runtime::SyncWorker(); + self->SetRegister(reg_id, value); + TVMValue values[1]; + int type_codes[1]; + PackArgs(values, type_codes, static_cast(DiscoAction::kDebugSetRegister)); + self->channel->Reply(TVMArgs(values, type_codes, 1)); + } + } + static void CallPacked(DiscoWorker* self, int64_t ret_reg_id, PackedFunc func, const TVMArgs& args) { TVMValue* values = const_cast(args.values); diff --git a/src/runtime/disco/worker.h b/src/runtime/disco/worker.h index f10382b068..e948fa1668 100644 --- a/src/runtime/disco/worker.h +++ b/src/runtime/disco/worker.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -81,6 +82,8 @@ class DiscoWorker { void MainLoop(); /*! \brief Get the worker instance on the current thread */ static DiscoWorker* ThreadLocal(); + /*! \brief Set the specific register to a specific value */ + void SetRegister(int reg_id, TVMArgValue value); /*! \brief The id of the worker.*/ int worker_id; @@ -108,6 +111,46 @@ class DiscoWorker { friend struct DiscoWorker::Impl; }; +/*! + * \brief A worker thread in Disco, which upon creation, launches a new thread to run the + * DiscoWorker. + * \sa DiscoWorker + */ +class DiscoWorkerThread { + public: + /*! + * \brief Construct a worker thread. + * \param worker_id The id of the worker. + * \param num_workers The total number of workers. + * \param worker_zero_data_ The data shared between worker-0 and the controler. It's a nullptr if + * the worker is not worker-0. + */ + explicit DiscoWorkerThread(int worker_id, int num_workers, WorkerZeroData* worker_zero_data_); + + /*! \brief Move constructor. */ + explicit DiscoWorkerThread(DiscoWorkerThread&& other) + : channel(std::move(other.channel)), + worker(std::move(other.worker)), + thread(std::move(other.thread)) {} + + /*! \brief Copy constructor is disabled */ + DiscoWorkerThread(const DiscoWorkerThread& other) = delete; + + /*! \brief Destructor that joins the thread before destruction */ + ~DiscoWorkerThread() { + if (this->thread != nullptr) { + this->thread->join(); + } + } + + /*! \brief The communication channel between the controler and the worker */ + std::unique_ptr channel; + /*! \brief The worker whose internal state is visible to the controler */ + std::unique_ptr worker; + /*! \brief The thread that runs the worker's main loop. */ + std::unique_ptr thread; +}; + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_DISCO_WORKER_H_ diff --git a/src/support/process_id.h b/src/support/process_id.h new file mode 100644 index 0000000000..8462ae0dd2 --- /dev/null +++ b/src/support/process_id.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file pipe.h + * \brief Platform independent pipe, used for IPC. + */ +#ifndef TVM_SUPPORT_PROCESS_ID_H_ +#define TVM_SUPPORT_PROCESS_ID_H_ + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +namespace tvm { +namespace support { + +/*! \brief Returns the PID of the current process as an 64-bit signed integer. */ +inline int64_t GetProcessId() { + int64_t result; +#ifdef _WIN32 + DWORD pid = GetCurrentProcessId(); + result = static_cast(pid); +#else + pid_t pid = getpid(); + result = static_cast(pid); +#endif + return result; +} + +/*! \brief Returns the PID and TIR of the current process/thread as a formatted string */ +inline std::string GetProcessIdAndThreadIdHeader() { + std::ostringstream os; + os << "[PID " << GetProcessId() << " TID 0x" << std::setw(16) << std::setfill('0') << std::hex + << std::this_thread::get_id() << "]"; + return os.str(); +} + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_PROCESS_ID_H_ diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py index 6507af5699..f0f949ab80 100644 --- a/tests/python/disco/test_nccl.py +++ b/tests/python/disco/test_nccl.py @@ -19,30 +19,33 @@ import tempfile import numpy as np +import pytest import tvm -import tvm.testing from tvm import dlight as dl from tvm import relax as rx from tvm.runtime import disco as di from tvm.runtime.relax_vm import VirtualMachine from tvm.script import relax as R +_all_session_kinds = [di.ThreadedSession, di.ProcessSession] -def test_init(): - devices = [0, 1] - sess = di.ThreadedSession(num_workers=len(devices)) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_init(session_kind): + devices = [0, 1] + sess = session_kind(num_workers=len(devices)) sess.init_ccl("nccl", *devices) -def test_allreduce(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_allreduce(session_kind): devices = [0, 1] + sess = session_kind(num_workers=len(devices)) + sess.init_ccl("nccl", *devices) + array_1 = np.arange(12, dtype="float32").reshape(3, 4) array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4) - - sess = di.ThreadedSession(num_workers=len(devices)) - sess.init_ccl("nccl", *devices) d_array = sess.empty((3, 4), "float32") d_array.debug_copy_from(0, array_1) d_array.debug_copy_from(1, array_2) @@ -60,12 +63,13 @@ def test_allreduce(): np.testing.assert_equal(result, expected) -def test_broadcast_from_worker0(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_broadcast_from_worker0(session_kind): devices = [0, 1] - array = np.arange(12, dtype="float32").reshape(3, 4) - - sess = di.ThreadedSession(num_workers=len(devices)) + sess = session_kind(num_workers=len(devices)) sess.init_ccl("nccl", *devices) + + array = np.arange(12, dtype="float32").reshape(3, 4) d_array = sess.empty((3, 4), "float32") d_array.debug_copy_from(0, array) dst_array = sess.empty((3, 4), "float32") @@ -74,12 +78,13 @@ def test_broadcast_from_worker0(): np.testing.assert_equal(result, array) -def test_scatter(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_scatter(session_kind): devices = [0, 1] - array = np.arange(36, dtype="float32").reshape(3, 4, 3) - - sess = di.ThreadedSession(num_workers=len(devices)) + sess = session_kind(num_workers=len(devices)) sess.init_ccl("nccl", *devices) + + array = np.arange(36, dtype="float32").reshape(3, 4, 3) d_src = sess.empty((3, 4, 3), "float32") d_dst = sess.empty((3, 3, 2), "float32") @@ -96,29 +101,29 @@ def test_scatter(): ) -# def test_gather(): -# num_workers = 2 -# devices = [1, 2] -# array = np.arange(36, dtype="float32") - -# sess = di.ThreadedSession(num_workers=num_workers) -# sess.init_ccl("nccl", *devices) -# d_src = sess.empty((3, 3, 2), "float32") -# d_dst = sess.empty((3, 4, 3), "float32") - -# d_src.debug_copy_from(0, array[:18]) -# d_src.debug_copy_from(1, array[18:]) - -# sess.gather_to_worker0(d_src, d_dst) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_gather(session_kind): + devices = [1, 2] + sess = session_kind(num_workers=len(devices)) + sess.init_ccl("nccl", *devices) -# np.testing.assert_equal( -# d_dst.debug_get_from_remote(0).numpy(), -# array.reshape(3, 4, 3), -# ) + array = np.arange(36, dtype="float32") + d_src = sess.empty((3, 3, 2), "float32") + d_dst = sess.empty((3, 4, 3), "float32") + d_src.debug_copy_from(0, array[:18]) + d_src.debug_copy_from(1, array[18:]) + sess.gather_to_worker0(d_src, d_dst) + np.testing.assert_equal( + d_dst.debug_get_from_remote(0).numpy(), + array.reshape(3, 4, 3), + ) -def test_mlp(): # pylint: disable=too-many-locals +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_mlp(session_kind): # pylint: disable=too-many-locals devices = [0, 1] + sess = session_kind(num_workers=len(devices)) + sess.init_ccl("nccl", *devices) # pylint: disable=invalid-name @tvm.script.ir_module @@ -193,8 +198,6 @@ def relax_build(mod, target): path = tmpdir + "/test.so" relax_build(ShardedMLP, target).export_library(path) - sess = di.ThreadedSession(num_workers=len(devices)) - sess.init_ccl("nccl", *devices) mod = sess.load_vm_module(path) d_X = sess.empty((128, 128), "float32") @@ -215,8 +218,11 @@ def relax_build(mod, target): np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4) -def test_attention(): # pylint: disable=too-many-locals,too-many-statements +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_attention(session_kind): # pylint: disable=too-many-locals,too-many-statements devices = [0, 1] + sess = session_kind(num_workers=len(devices)) + sess.init_ccl("nccl", *devices) # pylint: disable=invalid-name @tvm.script.ir_module @@ -343,8 +349,6 @@ def relax_build(mod, target): path = tmpdir + "/test.so" relax_build(ShardedAttention, target).export_library(path) - sess = di.ThreadedSession(num_workers=len(devices)) - sess.init_ccl("nccl", *devices) mod = sess.load_vm_module(path) d_X = sess.empty((1, 10, 128), "float32") @@ -372,4 +376,10 @@ def relax_build(mod, target): if __name__ == "__main__": - tvm.testing.main() + test_init(di.ProcessSession) + test_allreduce(di.ProcessSession) + test_broadcast_from_worker0(di.ProcessSession) + test_scatter(di.ProcessSession) + test_gather(di.ProcessSession) + test_mlp(di.ProcessSession) + test_attention(di.ProcessSession) diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index a2c0906f22..40dcb04911 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -19,15 +19,16 @@ import tempfile import numpy as np +import pytest import tvm from tvm import relax as rx -from tvm._ffi import register_func from tvm.runtime import ShapeTuple, String from tvm.runtime import disco as di from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T +from tvm.testing import disco as _ def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device): @@ -44,29 +45,23 @@ def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype): return host_array.numpy() -def test_int(): - num_workers = 4 +_all_session_kinds = [di.ThreadedSession, di.ProcessSession] - @register_func("tests.disco.add_one", override=True) - def add_one(x: int) -> int: # pylint: disable=invalid-name - return x + 1 - sess = di.ThreadedSession(num_workers=num_workers) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_int(session_kind): # pylint: disable=invalid-name + num_workers = 4 + sess = session_kind(num_workers=num_workers) func: di.DPackedFunc = sess.get_global_func("tests.disco.add_one") result: di.DRef = func(1) - for i in range(num_workers): assert result.debug_get_from_remote(i) == 2 -def test_float(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_float(session_kind): num_workers = 4 - - @register_func("tests.disco.add_one_float", override=True) - def add_one(x: float): # pylint: disable=invalid-name - return x + 0.5 - - sess = di.ThreadedSession(num_workers=num_workers) + sess = session_kind(num_workers=num_workers) func: di.DPackedFunc = sess.get_global_func("tests.disco.add_one_float") result: di.DRef = func(1.5) @@ -74,32 +69,23 @@ def add_one(x: float): # pylint: disable=invalid-name assert result.debug_get_from_remote(i) == 2.0 -def test_ndarray(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_ndarray(session_kind): num_workers = 4 - - @register_func("tests.disco.add_one_ndarray", override=True) - def add_one(x: tvm.runtime.NDArray) -> tvm.runtime.NDArray: # pylint: disable=invalid-name - return tvm.nd.array(x.numpy() + 1) - + sess = session_kind(num_workers=num_workers) device = tvm.cpu(0) x_np = np.arange(6).astype("float32").reshape([2, 3]) y_np = np.arange(6).astype("float32").reshape([2, 3]) + 1 - - sess = di.ThreadedSession(num_workers=num_workers) x_disc = _numpy_to_worker_0(sess, x_np, device=device) y_disc = sess.get_global_func("tests.disco.add_one_ndarray")(x_disc) y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype) np.testing.assert_equal(y_nd, y_np) -def test_string(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_string(session_kind): num_workers = 4 - - @register_func("tests.disco.str", override=True) - def my_str_func(x: str): # pylint: disable=invalid-name - return x + "_suffix" - - sess = di.ThreadedSession(num_workers=num_workers) + sess = session_kind(num_workers=num_workers) func: di.DPackedFunc = sess.get_global_func("tests.disco.str") result: di.DRef = func("hello") @@ -107,15 +93,10 @@ def my_str_func(x: str): # pylint: disable=invalid-name assert result.debug_get_from_remote(i) == "hello_suffix" -def test_string_obj(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_string_obj(session_kind): num_workers = 4 - - @register_func("tests.disco.str_obj", override=True) - def my_str_func(x: String): # pylint: disable=invalid-name - assert isinstance(x, String) - return String(x + "_suffix") - - sess = di.ThreadedSession(num_workers=num_workers) + sess = session_kind(num_workers=num_workers) func: di.DPackedFunc = sess.get_global_func("tests.disco.str_obj") result: di.DRef = func(String("hello")) @@ -125,26 +106,22 @@ def my_str_func(x: String): # pylint: disable=invalid-name assert value == "hello_suffix" -def test_shape_tuple(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_shape_tuple(session_kind): num_workers = 4 - - @register_func("tests.disco.shape_tuple", override=True) - def my_str_func(x: ShapeTuple): # pylint: disable=invalid-name - assert isinstance(x, ShapeTuple) - return ShapeTuple(list(x) + [4, 5]) - - sess = di.ThreadedSession(num_workers=num_workers) + sess = session_kind(num_workers=num_workers) func: di.DPackedFunc = sess.get_global_func("tests.disco.shape_tuple") result: di.DRef = func(ShapeTuple([1, 2, 3])) - for i in range(num_workers): value = result.debug_get_from_remote(i) assert isinstance(value, ShapeTuple) assert list(value) == [1, 2, 3, 4, 5] -def test_vm_module(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_vm_module(session_kind): num_workers = 4 + sess = session_kind(num_workers=num_workers) # pylint: disable=invalid-name @I.ir_module @@ -172,7 +149,6 @@ def main(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor((16, 8), dtype="floa y_np = x_np.transpose() rx.build(TestMod, target="llvm").export_library(path) - sess = di.ThreadedSession(num_workers=num_workers) mod = sess.load_vm_module(path, device=device) x_disc = _numpy_to_worker_0(sess, x_np, device=device) @@ -181,8 +157,10 @@ def main(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor((16, 8), dtype="floa np.testing.assert_equal(y_nd, y_np) -def test_vm_multi_func(): +@pytest.mark.parametrize("session_kind", _all_session_kinds) +def test_vm_multi_func(session_kind): num_workers = 4 + sess = session_kind(num_workers=num_workers) # pylint: disable=invalid-name @I.ir_module @@ -231,7 +209,6 @@ def transpose_2( y_np = x_np.transpose() rx.build(TestMod, target="llvm").export_library(path) - sess = di.ThreadedSession(num_workers=num_workers) mod = sess.load_vm_module(path, device=device) x_disc = _numpy_to_worker_0(sess, x_np, device=device) @@ -244,11 +221,11 @@ def transpose_2( if __name__ == "__main__": - test_int() - test_float() - test_string() - test_string_obj() - test_shape_tuple() - test_ndarray() - test_vm_module() - test_vm_multi_func() + test_int(di.ProcessSession) + test_float(di.ProcessSession) + test_string(di.ProcessSession) + test_string_obj(di.ProcessSession) + test_shape_tuple(di.ProcessSession) + test_ndarray(di.ProcessSession) + test_vm_module(di.ProcessSession) + test_vm_multi_func(di.ProcessSession)