Skip to content

Commit

Permalink
[Disco] Pipe-based Multi-processing Session (#15727)
Browse files Browse the repository at this point in the history
This PR introduces `ProcessSession`, a new session implementation based
on multi-processing.

`ProcessSession` shares exactly the same communication protocol with
`ThreadedSession`, but all workers except for worker 0 are launched in a
separate process than thread. Workers communicate with the controller
via pipe provided by the OS, rather than SPSC message queue between
threads.

In our implementation, Python's `subproces.popen` is used to create
subprocesses, and the Python executable, or more specifically,
`sys.executable` calls into `tvm.exec.disco_worker` as the entrypoint.
Besides the launching logic that is only executed once in the very
beginning, the rest of the implementation resides in a C++-only
environment, including reads/writes to pipe file descriptors,
serialization and deserialization of messages, worker interpretation of
each message, etc.

Detailed engineering elements included in this PR:
- Refactors the MinRPC-based communication protocol out to be shared by
  `ProcessSession` and `ThreadedSession` as `protocol.h`;
- Refactors a controller-side worker thread into `DiscoWorkerThread`,
  which is shared by both session implementation to launch worker-0;
- Added two instructions `kDebugGetFromRemote` and `kDebugSetRegister`,
  which are used to communicate with workers other than worker-0 in
  debug mode;
- Introduces multi-processing infra including: `tvm.exec.disco_worker`
  serving as the entrypoint that launches workers, and
  `tvm/runtime/disco/process_pool.py` that exposes APIs to launch worker
  processes. `tvm.exec.disco_worker` calls into a global function
  `runtime.disco.WorkerProcess` that executes the worker main loop in
  pure C++;
- Introduces `src/support/process_id.h` that provides cross-platform pid
  and tid printing utilities;
- Refactors Disco's NCCL integration that get rids of initialized-once
  global NCCL context, and switches to broadcasting `ncclUniqueId` from
  controller to all workers, and then create NCCL communicators in each
  worker thread/process accordingly. This is a thread/process-agnostic
  way of using NCCL.
  • Loading branch information
junrushao authored Sep 14, 2023
1 parent e0e6c1d commit 40b9a92
Show file tree
Hide file tree
Showing 21 changed files with 1,149 additions and 290 deletions.
50 changes: 44 additions & 6 deletions include/tvm/runtime/disco/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
#ifndef TVM_RUNTIME_DISCO_SESSION_H_
#define TVM_RUNTIME_DISCO_SESSION_H_

#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

Expand All @@ -92,6 +93,8 @@ enum class DiscoAction : int32_t {
kSyncWorker = 4,
kCopyFromWorker0 = 5,
kCopyToWorker0 = 6,
kDebugGetFromRemote = 7,
kDebugSetRegister = 8,
};

/*! \brief Converts the enum class `DiscoAction` to string */
Expand All @@ -111,6 +114,10 @@ inline std::string DiscoAction2String(DiscoAction action) {
return "kCopyFromWorker0";
case DiscoAction::kCopyToWorker0:
return "kCopyToWorker0";
case DiscoAction::kDebugGetFromRemote:
return "kDebugGetFromRemote";
case DiscoAction::kDebugSetRegister:
return "kDebugSetRegister";
}
LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast<int>(action);
}
Expand All @@ -136,7 +143,7 @@ class DRefObj : public Object {
* \param worker_id The id of the worker to be copied to.
* \param source The NDArray to be copied.
*/
void DebugCopyFrom(int worker_id, NDArray source);
inline void DebugCopyFrom(int worker_id, TVMArgValue source);

static constexpr const char* _type_key = "runtime.disco.DRef";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef;
Expand Down Expand Up @@ -213,20 +220,32 @@ class SessionObj : public Object {
virtual void SyncWorker(int worker_id) = 0;
/*! \brief Signal all the workers to shutdown */
virtual void Shutdown() = 0;
/*!
* \brief Initialize the data plane between workers.
* \param ccl The name of the communication backend, e.g., nccl, rccl, mpi.
* \param device_ids The device ids of the workers.
*/
virtual void InitCCL(String ccl, ShapeTuple device_ids) = 0;
/*!
* \brief Get the value of a register from a remote worker.
* \param reg_id The id of the register to be fetched.
* \param worker_id The id of the worker to be fetched from.
* \return The value of the register.
*/
virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;

static constexpr const char* _type_key = "runtime.disco.Session";
TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object);
/*!
* \brief Set the value of a register on a remote worker.
* \param reg_id The id of the register to be set.
* \param value The value to be set.
* \param worker_id The id of the worker to be set.
*/
virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) = 0;

struct FFI;
friend struct SessionObj::FFI;
friend class DRefObj;
static constexpr const char* _type_key = "runtime.disco.Session";
TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object);

protected:
/*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */
Expand All @@ -239,8 +258,22 @@ class SessionObj : public Object {
*/
class Session : public ObjectRef {
public:
/*! \brief Create a session backed by a thread pool of workers */
static Session ThreadedSession(int num_workers);
/*!
* \brief Create a session backed by a thread pool of workers
* \param num_workers The number of workers.
*/
TVM_DLL static Session ThreadedSession(int num_workers);
/*!
* \brief Create a session backed by pipe-based multiprocessing
* \param num_workers The number of workers.
* \param process_pool_creator The name of a global function that takes `num_workers` as an input,
* and returns a PackedFunc, which takes an integer `worker_id` as the input and returns None.
* When `worker-id` is 0, it shuts down the process pool; Otherwise, it retursn a tuple
* (read_fd, writefd) used to communicate with the corresponding worker.
* \note Worker-0 is always co-located with the controler as a separate thread, and therefore
* worker-0 does not exist in the process pool.
*/
TVM_DLL static Session ProcessSession(int num_workers, String process_pool_creator);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj);
};

Expand All @@ -250,6 +283,7 @@ class Session : public ObjectRef {
*/
class DiscoChannel {
public:
virtual ~DiscoChannel() = default;
/*! \brief Send a packed sequence to the receiver */
virtual void Send(const TVMArgs& args) = 0;
/*! \brief Receive a packed sequence from worker */
Expand All @@ -272,6 +306,10 @@ TVMRetValue DRefObj::DebugGetFromRemote(int worker_id) {
return Downcast<Session>(this->session)->DebugGetFromRemote(this->reg_id, worker_id);
}

void DRefObj::DebugCopyFrom(int worker_id, TVMArgValue value) {
return Downcast<Session>(this->session)->DebugSetRegister(this->reg_id, value, worker_id);
}

template <typename... Args>
DRef SessionObj::CallPacked(const DRef& func, Args&&... args) {
constexpr int offset = 3;
Expand Down
51 changes: 51 additions & 0 deletions python/tvm/exec/disco_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Internal DiscoWorker for Disco ProcessSession."""
import os
import sys

from tvm import runtime as _ # pylint: disable=unused-import
from tvm._ffi import get_global_func
from tvm.testing import disco as _ # pylint: disable=unused-import


def main():
"""Main worker function"""
if len(sys.argv) != 5:
print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>")
return
worker_id = int(sys.argv[1])
num_workers = int(sys.argv[2])
if sys.platform == "win32":
import msvcrt # pylint: disable=import-outside-toplevel,import-error

reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY)
writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
else:
reader = int(sys.argv[3])
writer = int(sys.argv[4])

worker_func = get_global_func("runtime.disco.WorkerProcess")
worker_func(worker_id, num_workers, reader, writer)


if __name__ == "__main__":
try:
main()
except (KeyboardInterrupt, IOError):
pass
9 changes: 8 additions & 1 deletion python/tvm/runtime/disco/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""TVM distributed runtime API."""
from .session import DModule, DPackedFunc, DRef, Session, ThreadedSession
from .session import (
DModule,
DPackedFunc,
DRef,
ProcessSession,
Session,
ThreadedSession,
)
180 changes: 180 additions & 0 deletions python/tvm/runtime/disco/process_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Pipe worker for multi-processing."""
import os
import subprocess
import sys

import psutil

from tvm._ffi import register_func
from tvm.runtime import ShapeTuple


class DiscoPopenWorker:
"""A subprocess worker via Popen.
PopenWorker provides a low-level
API to interact with a separate process via Popen.
Parameters
----------
worker_id : int
The worker id of the current worker.
num_workers : int
The total number of workers.
stdout: Union[None, int, IO[Any]]
The standard output streams handler specified for the popen process.
stderr: Union[None, int, IO[Any]]
The standard error streams handler specified for the popen process.
"""

def __init__(self, worker_id: int, num_workers: int, stdout=None, stderr=None):
self.worker_id = worker_id
self.num_workers = num_workers
self._proc = None
self._stdout = stdout
self._stderr = stderr

def __del__(self):
try:
self.kill()
except ImportError:
pass

def kill(self):
"""Kill the current running process and cleanup.
Note
----
The worker can start a new process when send is called again.
"""
if self._proc is not None:
# kill all child processes recursively
try:
_kill_child_processes(self._proc.pid)
except TypeError:
pass
try:
self._proc.kill()
except OSError:
pass

# Join the child process to avoid zombie processes
self.join(timeout=1.0)
self._proc = None

def join(self, timeout=None):
"""Join the current process worker before it terminates.
Parameters
----------
timeout: Optional[number]
Timeout value, block at most timeout seconds if it
is a positive number.
"""
if self._proc:
try:
self._proc.wait(timeout)
except subprocess.TimeoutExpired:
pass

def start(self):
"""Start a new subprocess if nothing is available"""
if self._proc is not None:
return None, None

# connect subprocess with a pair of pipes
main_read, worker_write = os.pipe()
worker_read, main_write = os.pipe()

cmd = [
sys.executable,
"-m",
"tvm.exec.disco_worker",
str(self.worker_id),
str(self.num_workers),
]
if sys.platform == "win32":
import msvcrt # pylint: disable=import-error,import-outside-toplevel

worker_read_handle = msvcrt.get_osfhandle(worker_read)
worker_write_handle = msvcrt.get_osfhandle(worker_write)
os.set_handle_inheritable(worker_read_handle, True)
os.set_handle_inheritable(worker_write_handle, True)
cmd += [str(worker_read_handle), str(worker_write_handle)]
self._proc = subprocess.Popen(
cmd,
close_fds=False,
stdout=self._stdout,
stderr=self._stderr,
)
else:
cmd += [str(worker_read), str(worker_write)]
self._proc = subprocess.Popen( # pylint: disable=consider-using-with
cmd,
pass_fds=(worker_read, worker_write),
stdout=self._stdout,
stderr=self._stderr,
)

# close worker side of the pipe
os.close(worker_read)
os.close(worker_write)
return main_read, main_write


def _kill_child_processes(pid):
"""Kill all child processes recursively for a given pid.
Parameters
----------
pid : int
The given parameter id.
"""
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
except psutil.NoSuchProcess:
return

for process in children:
try:
process.kill()
except psutil.NoSuchProcess:
pass


@register_func("runtime.disco.create_process_pool")
def _create_process_pool(num_workers: int):
"""Create a process pool where the workers' are are [1, num_workers)."""
pool = [DiscoPopenWorker(i, num_workers) for i in range(1, num_workers)]

def result_func(worker_id: int):
nonlocal pool
if worker_id != 0:
read_fd, write_fd = pool[worker_id - 1].start()
return ShapeTuple([read_fd, write_fd])
print("Shutting down the process pool")
del pool
return None

return result_func
Loading

0 comments on commit 40b9a92

Please sign in to comment.