-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Disco] Pipe-based Multi-processing Session (#15727)
This PR introduces `ProcessSession`, a new session implementation based on multi-processing. `ProcessSession` shares exactly the same communication protocol with `ThreadedSession`, but all workers except for worker 0 are launched in a separate process than thread. Workers communicate with the controller via pipe provided by the OS, rather than SPSC message queue between threads. In our implementation, Python's `subproces.popen` is used to create subprocesses, and the Python executable, or more specifically, `sys.executable` calls into `tvm.exec.disco_worker` as the entrypoint. Besides the launching logic that is only executed once in the very beginning, the rest of the implementation resides in a C++-only environment, including reads/writes to pipe file descriptors, serialization and deserialization of messages, worker interpretation of each message, etc. Detailed engineering elements included in this PR: - Refactors the MinRPC-based communication protocol out to be shared by `ProcessSession` and `ThreadedSession` as `protocol.h`; - Refactors a controller-side worker thread into `DiscoWorkerThread`, which is shared by both session implementation to launch worker-0; - Added two instructions `kDebugGetFromRemote` and `kDebugSetRegister`, which are used to communicate with workers other than worker-0 in debug mode; - Introduces multi-processing infra including: `tvm.exec.disco_worker` serving as the entrypoint that launches workers, and `tvm/runtime/disco/process_pool.py` that exposes APIs to launch worker processes. `tvm.exec.disco_worker` calls into a global function `runtime.disco.WorkerProcess` that executes the worker main loop in pure C++; - Introduces `src/support/process_id.h` that provides cross-platform pid and tid printing utilities; - Refactors Disco's NCCL integration that get rids of initialized-once global NCCL context, and switches to broadcasting `ncclUniqueId` from controller to all workers, and then create NCCL communicators in each worker thread/process accordingly. This is a thread/process-agnostic way of using NCCL.
- Loading branch information
Showing
21 changed files
with
1,149 additions
and
290 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=invalid-name | ||
"""Internal DiscoWorker for Disco ProcessSession.""" | ||
import os | ||
import sys | ||
|
||
from tvm import runtime as _ # pylint: disable=unused-import | ||
from tvm._ffi import get_global_func | ||
from tvm.testing import disco as _ # pylint: disable=unused-import | ||
|
||
|
||
def main(): | ||
"""Main worker function""" | ||
if len(sys.argv) != 5: | ||
print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>") | ||
return | ||
worker_id = int(sys.argv[1]) | ||
num_workers = int(sys.argv[2]) | ||
if sys.platform == "win32": | ||
import msvcrt # pylint: disable=import-outside-toplevel,import-error | ||
|
||
reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY) | ||
writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) | ||
else: | ||
reader = int(sys.argv[3]) | ||
writer = int(sys.argv[4]) | ||
|
||
worker_func = get_global_func("runtime.disco.WorkerProcess") | ||
worker_func(worker_id, num_workers, reader, writer) | ||
|
||
|
||
if __name__ == "__main__": | ||
try: | ||
main() | ||
except (KeyboardInterrupt, IOError): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=invalid-name | ||
"""Pipe worker for multi-processing.""" | ||
import os | ||
import subprocess | ||
import sys | ||
|
||
import psutil | ||
|
||
from tvm._ffi import register_func | ||
from tvm.runtime import ShapeTuple | ||
|
||
|
||
class DiscoPopenWorker: | ||
"""A subprocess worker via Popen. | ||
PopenWorker provides a low-level | ||
API to interact with a separate process via Popen. | ||
Parameters | ||
---------- | ||
worker_id : int | ||
The worker id of the current worker. | ||
num_workers : int | ||
The total number of workers. | ||
stdout: Union[None, int, IO[Any]] | ||
The standard output streams handler specified for the popen process. | ||
stderr: Union[None, int, IO[Any]] | ||
The standard error streams handler specified for the popen process. | ||
""" | ||
|
||
def __init__(self, worker_id: int, num_workers: int, stdout=None, stderr=None): | ||
self.worker_id = worker_id | ||
self.num_workers = num_workers | ||
self._proc = None | ||
self._stdout = stdout | ||
self._stderr = stderr | ||
|
||
def __del__(self): | ||
try: | ||
self.kill() | ||
except ImportError: | ||
pass | ||
|
||
def kill(self): | ||
"""Kill the current running process and cleanup. | ||
Note | ||
---- | ||
The worker can start a new process when send is called again. | ||
""" | ||
if self._proc is not None: | ||
# kill all child processes recursively | ||
try: | ||
_kill_child_processes(self._proc.pid) | ||
except TypeError: | ||
pass | ||
try: | ||
self._proc.kill() | ||
except OSError: | ||
pass | ||
|
||
# Join the child process to avoid zombie processes | ||
self.join(timeout=1.0) | ||
self._proc = None | ||
|
||
def join(self, timeout=None): | ||
"""Join the current process worker before it terminates. | ||
Parameters | ||
---------- | ||
timeout: Optional[number] | ||
Timeout value, block at most timeout seconds if it | ||
is a positive number. | ||
""" | ||
if self._proc: | ||
try: | ||
self._proc.wait(timeout) | ||
except subprocess.TimeoutExpired: | ||
pass | ||
|
||
def start(self): | ||
"""Start a new subprocess if nothing is available""" | ||
if self._proc is not None: | ||
return None, None | ||
|
||
# connect subprocess with a pair of pipes | ||
main_read, worker_write = os.pipe() | ||
worker_read, main_write = os.pipe() | ||
|
||
cmd = [ | ||
sys.executable, | ||
"-m", | ||
"tvm.exec.disco_worker", | ||
str(self.worker_id), | ||
str(self.num_workers), | ||
] | ||
if sys.platform == "win32": | ||
import msvcrt # pylint: disable=import-error,import-outside-toplevel | ||
|
||
worker_read_handle = msvcrt.get_osfhandle(worker_read) | ||
worker_write_handle = msvcrt.get_osfhandle(worker_write) | ||
os.set_handle_inheritable(worker_read_handle, True) | ||
os.set_handle_inheritable(worker_write_handle, True) | ||
cmd += [str(worker_read_handle), str(worker_write_handle)] | ||
self._proc = subprocess.Popen( | ||
cmd, | ||
close_fds=False, | ||
stdout=self._stdout, | ||
stderr=self._stderr, | ||
) | ||
else: | ||
cmd += [str(worker_read), str(worker_write)] | ||
self._proc = subprocess.Popen( # pylint: disable=consider-using-with | ||
cmd, | ||
pass_fds=(worker_read, worker_write), | ||
stdout=self._stdout, | ||
stderr=self._stderr, | ||
) | ||
|
||
# close worker side of the pipe | ||
os.close(worker_read) | ||
os.close(worker_write) | ||
return main_read, main_write | ||
|
||
|
||
def _kill_child_processes(pid): | ||
"""Kill all child processes recursively for a given pid. | ||
Parameters | ||
---------- | ||
pid : int | ||
The given parameter id. | ||
""" | ||
try: | ||
parent = psutil.Process(pid) | ||
children = parent.children(recursive=True) | ||
except psutil.NoSuchProcess: | ||
return | ||
|
||
for process in children: | ||
try: | ||
process.kill() | ||
except psutil.NoSuchProcess: | ||
pass | ||
|
||
|
||
@register_func("runtime.disco.create_process_pool") | ||
def _create_process_pool(num_workers: int): | ||
"""Create a process pool where the workers' are are [1, num_workers).""" | ||
pool = [DiscoPopenWorker(i, num_workers) for i in range(1, num_workers)] | ||
|
||
def result_func(worker_id: int): | ||
nonlocal pool | ||
if worker_id != 0: | ||
read_fd, write_fd = pool[worker_id - 1].start() | ||
return ShapeTuple([read_fd, write_fd]) | ||
print("Shutting down the process pool") | ||
del pool | ||
return None | ||
|
||
return result_func |
Oops, something went wrong.