ENH: stop using daemon=True for subpool (#85)
qinxuye authored Mar 14, 2024
1 parent 2a041ae commit 297f8b8
Showing 11 changed files with 358 additions and 207 deletions.
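
At a high level, the commit stops marking sub-pool processes as daemon processes. Instead, the main pool passes its own PID (captured with os.getpid()) down to every sub pool, and each sub pool runs a watchdog task that polls that PID with psutil and shuts the sub pool down once the main pool is gone (see SubActorPoolBase._watch_main_pool in python/xoscar/backends/pool.py below). A minimal, self-contained sketch of that watchdog pattern, with illustrative names (watch_parent, shutdown) rather than xoscar's actual API:

import asyncio
import os

import psutil


async def watch_parent(parent_pid: int, shutdown) -> None:
    # Poll the parent process; once it disappears, trigger our own shutdown.
    parent = psutil.Process(parent_pid)
    while True:
        try:
            # status() raises psutil.NoSuchProcess after the parent has exited.
            await asyncio.to_thread(parent.status)
        except (psutil.NoSuchProcess, ProcessLookupError):
            break
        await asyncio.sleep(0.1)
    await shutdown()


async def main() -> None:
    async def shutdown() -> None:
        print("parent is gone, stopping the sub pool")

    # In the commit, the watched PID is the main pool's os.getpid(), forwarded to
    # the subprocess through the multiprocessing.Process args; here we simply
    # watch our own parent process as a stand-in.
    task = asyncio.create_task(watch_parent(os.getppid(), shutdown))
    await asyncio.sleep(0.5)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())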
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 23.1.0
+    rev: 24.2.0
     hooks:
       - id: black
         files: python/xoscar
@@ -24,7 +24,7 @@ repos:
         args: [--sp, python/setup.cfg]
         files: python/xoscar
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.4.1
+    rev: v1.9.0
     hooks:
       - id: mypy
         additional_dependencies: [tokenize-rt==3.2.0]
6 changes: 3 additions & 3 deletions python/xoscar/backends/communication/dummy.py
@@ -97,9 +97,9 @@ class DummyServer(Server):
         else tuple()
     )

-    _address_to_instances: weakref.WeakValueDictionary[
-        str, "DummyServer"
-    ] = weakref.WeakValueDictionary()
+    _address_to_instances: weakref.WeakValueDictionary[str, "DummyServer"] = (
+        weakref.WeakValueDictionary()
+    )
     _channels: list[ChannelType]
     _tasks: list[asyncio.Task]
     scheme: str | None = "dummy"
40 changes: 30 additions & 10 deletions python/xoscar/backends/indigen/pool.py
@@ -187,14 +187,19 @@ async def start_sub_pool(
         def start_pool_in_process():
             ctx = multiprocessing.get_context(method=start_method)
             status_queue = ctx.Queue()
+            main_pool_pid = os.getpid()

             with _suspend_init_main():
                 process = ctx.Process(
                     target=cls._start_sub_pool,
-                    args=(actor_pool_config, process_index, status_queue),
+                    args=(
+                        actor_pool_config,
+                        process_index,
+                        status_queue,
+                        main_pool_pid,
+                    ),
                     name=f"IndigenActorPool{process_index}",
                 )
-                process.daemon = True
                 process.start()

             # wait for sub actor pool to finish starting
@@ -209,15 +214,22 @@ def start_pool_in_process():

     @classmethod
     async def wait_sub_pools_ready(cls, create_pool_tasks: List[asyncio.Task]):
-        processes = []
+        processes: list[multiprocessing.Process] = []
         ext_addresses = []
+        error = None
         for task in create_pool_tasks:
             process, status = await task
+            processes.append(process)
             if status.status == 1:
                 # start sub pool failed
-                raise status.error.with_traceback(status.traceback)
-            processes.append(process)
-            ext_addresses.append(status.external_addresses)
+                error = status.error.with_traceback(status.traceback)
+            else:
+                ext_addresses.append(status.external_addresses)
+        if error:
+            for p in processes:
+                # error happens, kill all subprocesses
+                p.kill()
+            raise error
         return processes, ext_addresses

     @classmethod
@@ -226,6 +238,7 @@ def _start_sub_pool(
         actor_config: ActorPoolConfig,
         process_index: int,
         status_queue: multiprocessing.Queue,
+        main_pool_pid: int,
     ):
         ensure_coverage()

@@ -259,7 +272,9 @@
         else:
             asyncio.set_event_loop(asyncio.new_event_loop())

-        coro = cls._create_sub_pool(actor_config, process_index, status_queue)
+        coro = cls._create_sub_pool(
+            actor_config, process_index, status_queue, main_pool_pid
+        )
         asyncio.run(coro)

     @classmethod
@@ -268,6 +283,7 @@ async def _create_sub_pool(
         actor_config: ActorPoolConfig,
         process_index: int,
         status_queue: multiprocessing.Queue,
+        main_pool_pid: int,
     ):
         process_status = None
         try:
@@ -276,7 +292,11 @@
             if env:
                 os.environ.update(env)
             pool = await SubActorPool.create(
-                {"actor_pool_config": actor_config, "process_index": process_index}
+                {
+                    "actor_pool_config": actor_config,
+                    "process_index": process_index,
+                    "main_pool_pid": main_pool_pid,
+                }
             )
             external_addresses = cur_pool_config["external_address"]
             process_status = SubpoolStatus(
@@ -342,14 +362,14 @@ async def append_sub_pool(
         def start_pool_in_process():
             ctx = multiprocessing.get_context(method=start_method)
             status_queue = ctx.Queue()
+            main_pool_pid = os.getpid()

             with _suspend_init_main():
                 process = ctx.Process(
                     target=self._start_sub_pool,
-                    args=(self._config, process_index, status_queue),
+                    args=(self._config, process_index, status_queue, main_pool_pid),
                     name=f"IndigenActorPool{process_index}",
                 )
-                process.daemon = True
                 process.start()

             # wait for sub actor pool to finish starting
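
Two behaviors change in this file: the sub-pool Process is no longer started as a daemon and instead receives the main pool's PID through its args, and wait_sub_pools_ready no longer raises on the first failed sub pool — it records the error, keeps collecting the processes it already started, kills them all, and only then re-raises, so a partial startup does not leak live subprocesses. A condensed, hypothetical sketch of that collect-then-cleanup shape using plain functions instead of the pool classes (the real code learns about failures from a status queue, not a flag):

import multiprocessing
import time


def _worker(ok: bool) -> None:
    if not ok:
        raise RuntimeError("sub pool failed to start")
    time.sleep(30)  # stand-in for a long-running sub pool


def start_all(flags):
    processes = []
    error = None  # first startup failure, if any
    for ok in flags:
        p = multiprocessing.Process(target=_worker, args=(ok,))
        p.start()
        processes.append(p)  # track it even if it fails, so cleanup can reach it
        if not ok:  # stand-in for "status.status == 1" reported via the status queue
            error = RuntimeError("sub pool failed to start")
    if error:
        for p in processes:
            # one sub pool failed: kill every subprocess we already started
            p.kill()
        raise error
    return processes


if __name__ == "__main__":
    try:
        start_all([True, False, True])
    except RuntimeError as exc:
        print("startup aborted:", exc)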
79 changes: 79 additions & 0 deletions python/xoscar/backends/indigen/tests/test_pool.py
@@ -17,12 +17,14 @@

 import asyncio
 import logging
+import multiprocessing
 import os
 import re
 import sys
 import time
 from unittest import mock

+import psutil
 import pytest

 from .... import Actor, create_actor, create_actor_ref, get_pool_config, kill_actor
@@ -1099,3 +1101,80 @@ def test():
     assert process_index not in config.get_process_indexes()
     with pytest.raises(KeyError):
         config.get_pool_config(process_index)
+
+
+async def _run(started: multiprocessing.Event):  # type: ignore
+    pool = await create_actor_pool(  # type: ignore
+        "127.0.0.1", pool_cls=MainActorPool, n_process=1
+    )
+
+    class DummyActor(Actor):
+        @staticmethod
+        def test():
+            return "this is dummy!"
+
+    ref = await create_actor(
+        DummyActor, address=pool.external_address, allocate_strategy=RandomSubPool()
+    )
+    assert ref is not None
+
+    started.set()  # type: ignore
+    await pool.join()
+
+
+def _run_in_process(started: multiprocessing.Event):  # type: ignore
+    asyncio.run(_run(started))
+
+
+@pytest.mark.asyncio
+async def test_sub_pool_quit_with_main_pool():
+    s = multiprocessing.Event()
+    p = multiprocessing.Process(target=_run_in_process, args=(s,))
+    p.start()
+    s.wait()
+
+    processes = psutil.Process(p.pid).children()
+    assert len(processes) == 1
+
+    # kill main process
+    p.kill()
+    p.join()
+    await asyncio.sleep(1)
+
+    # subprocess should have died
+    assert not psutil.pid_exists(processes[0].pid)
+
+
+def _add(x: int) -> int:
+    return x + 1
+
+
+class _ProcessActor(Actor):
+    def run(self, x: int):
+        p = multiprocessing.Process(target=_add, args=(x,))
+        p.start()
+        p.join()
+        return x + 1
+
+
+@pytest.mark.asyncio
+async def test_process_in_actor():
+    start_method = (
+        os.environ.get("POOL_START_METHOD", "forkserver")
+        if sys.platform != "win32"
+        else None
+    )
+    pool = await create_actor_pool(  # type: ignore
+        "127.0.0.1",
+        pool_cls=MainActorPool,
+        n_process=1,
+        subprocess_start_method=start_method,
+    )
+
+    async with pool:
+        ref = await create_actor(
+            _ProcessActor,
+            address=pool.external_address,
+            allocate_strategy=RandomSubPool(),
+        )
+        assert 2 == await ref.run(1)
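
The two new tests cover both halves of the change: test_sub_pool_quit_with_main_pool verifies that sub pools still die with the main pool even without the daemon flag, and test_process_in_actor exercises the capability the daemon flag used to block, since CPython's multiprocessing does not allow a daemonic process to create children. A small standalone reproduction of that limitation (a hypothetical script, not part of the commit):

import multiprocessing


def _child() -> None:
    print("child ran")


def _spawn_from_daemon() -> None:
    # In a daemonic process this start() fails with
    # "daemonic processes are not allowed to have children".
    multiprocessing.Process(target=_child).start()


if __name__ == "__main__":
    p = multiprocessing.Process(target=_spawn_from_daemon)
    p.daemon = True
    p.start()
    p.join()
    print("daemon child exit code:", p.exitcode)  # non-zero: the nested start() failed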
34 changes: 33 additions & 1 deletion python/xoscar/backends/pool.py
@@ -27,6 +27,8 @@
 from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Coroutine, Optional, Type, TypeVar

+import psutil
+
 from .._utils import TypeDispatcher, create_actor_ref, to_binary
 from ..api import Actor
 from ..core import ActorRef, BufferRef, FileObjectRef, register_local_pool
@@ -821,7 +823,8 @@ def handle_channel(channel):


 class SubActorPoolBase(ActorPoolBase):
-    __slots__ = ("_main_address",)
+    __slots__ = ("_main_address", "_watch_main_pool_task")
+    _watch_main_pool_task: Optional[asyncio.Task]

     def __init__(
         self,
@@ -834,6 +837,7 @@ def __init__(
         config: ActorPoolConfig,
         servers: list[Server],
         main_address: str,
+        main_pool_pid: Optional[int],
     ):
         super().__init__(
             process_index,
@@ -846,6 +850,26 @@ def __init__(
             servers,
         )
         self._main_address = main_address
+        if main_pool_pid:
+            self._watch_main_pool_task = asyncio.create_task(
+                self._watch_main_pool(main_pool_pid)
+            )
+        else:
+            self._watch_main_pool_task = None
+
+    async def _watch_main_pool(self, main_pool_pid: int):
+        main_process = psutil.Process(main_pool_pid)
+        while not self.stopped:
+            try:
+                await asyncio.to_thread(main_process.status)
+                await asyncio.sleep(0.1)
+                continue
+            except (psutil.NoSuchProcess, ProcessLookupError, asyncio.CancelledError):
+                # main pool died
+                break
+
+        if not self.stopped:
+            await self.stop()

     async def notify_main_pool_to_destroy(
         self, message: DestroyActorMessage
@@ -900,14 +924,22 @@

     @staticmethod
     def _parse_config(config: dict, kw: dict) -> dict:
+        main_pool_pid = config.pop("main_pool_pid", None)
         kw = AbstractActorPool._parse_config(config, kw)
         pool_config: ActorPoolConfig = kw["config"]
         main_process_index = pool_config.get_process_indexes()[0]
         kw["main_address"] = pool_config.get_pool_config(main_process_index)[
             "external_address"
         ][0]
+        kw["main_pool_pid"] = main_pool_pid
         return kw

+    async def stop(self):
+        await super().stop()
+        if self._watch_main_pool_task:
+            self._watch_main_pool_task.cancel()
+            await self._watch_main_pool_task
+

 class MainActorPoolBase(ActorPoolBase):
     __slots__ = (
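
One detail of the new shutdown path: _watch_main_pool treats asyncio.CancelledError as a normal exit (it is caught in the polling loop), so stop() can cancel the watchdog task and then await it without the cancellation propagating back. A small standalone illustration of that cancel-and-await pattern, with illustrative names rather than xoscar's API:

import asyncio


async def watcher() -> None:
    while True:
        try:
            await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            # Swallow cancellation so callers can await us and get a clean return.
            break
    print("watcher exited cleanly")


async def main() -> None:
    task = asyncio.create_task(watcher())
    await asyncio.sleep(0.3)
    task.cancel()
    await task  # returns normally because watcher absorbed the CancelledError
    print("task.cancelled():", task.cancelled())  # False: it finished, not cancelled


asyncio.run(main())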
9 changes: 7 additions & 2 deletions python/xoscar/backends/test/pool.py
@@ -55,7 +55,7 @@ async def start_sub_pool(
         status_queue: multiprocessing.Queue = multiprocessing.Queue()
         return (
             asyncio.create_task(
-                cls._create_sub_pool(actor_pool_config, process_index, status_queue)
+                cls._create_sub_pool(actor_pool_config, process_index, status_queue, 0)
             ),
             status_queue,
         )
@@ -77,9 +77,14 @@ async def _create_sub_pool(
         actor_config: ActorPoolConfig,
         process_index: int,
         status_queue: multiprocessing.Queue,
+        main_pool_pid: int,
     ):
         pool: TestSubActorPool = await TestSubActorPool.create(
-            {"actor_pool_config": actor_config, "process_index": process_index}
+            {
+                "actor_pool_config": actor_config,
+                "process_index": process_index,
+                "main_pool_pid": main_pool_pid,
+            }
         )
         await pool.start()
         status_queue.put(
2 changes: 1 addition & 1 deletion python/xoscar/backends/test/tests/test_transfer.py
@@ -120,7 +120,7 @@ async def _copy_test(scheme1: Optional[str], scheme2: Optional[str], cpu: bool):
         external_address_schemes=[None, scheme1, scheme2],
     )

-    async with pool:
+    async with pool, pool2:
         ctx = get_context()

         # actor on main pool
36 changes: 18 additions & 18 deletions python/xoscar/collective/common.py
@@ -55,30 +55,30 @@ class AllReduceAlgorithm(IntEnum):


 TypeMappingGloo: Dict[Type[np.dtype], "xp.GlooDataType_t"] = {
-    np.int8: xp.GlooDataType_t.glooInt8,
-    np.uint8: xp.GlooDataType_t.glooUint8,
-    np.int32: xp.GlooDataType_t.glooInt32,
-    np.uint32: xp.GlooDataType_t.glooUint32,
-    np.int64: xp.GlooDataType_t.glooInt64,
-    np.uint64: xp.GlooDataType_t.glooUint64,
-    np.float16: xp.GlooDataType_t.glooFloat16,
-    np.float32: xp.GlooDataType_t.glooFloat32,
-    np.float64: xp.GlooDataType_t.glooFloat64,
+    np.int8: xp.GlooDataType_t.glooInt8,  # type: ignore
+    np.uint8: xp.GlooDataType_t.glooUint8,  # type: ignore
+    np.int32: xp.GlooDataType_t.glooInt32,  # type: ignore
+    np.uint32: xp.GlooDataType_t.glooUint32,  # type: ignore
+    np.int64: xp.GlooDataType_t.glooInt64,  # type: ignore
+    np.uint64: xp.GlooDataType_t.glooUint64,  # type: ignore
+    np.float16: xp.GlooDataType_t.glooFloat16,  # type: ignore
+    np.float32: xp.GlooDataType_t.glooFloat32,  # type: ignore
+    np.float64: xp.GlooDataType_t.glooFloat64,  # type: ignore
 }
 cupy = lazy_import("cupy")
 if cupy is not None:
     from cupy.cuda import nccl

     TypeMappingNCCL: Dict[Type[np.dtype], int] = {
-        np.int8: nccl.NCCL_INT8,
-        np.uint8: nccl.NCCL_UINT8,
-        np.int32: nccl.NCCL_INT32,
-        np.uint32: nccl.NCCL_UINT32,
-        np.int64: nccl.NCCL_INT64,
-        np.uint64: nccl.NCCL_UINT64,
-        np.float16: nccl.NCCL_FLOAT16,
-        np.float32: nccl.NCCL_FLOAT32,
-        np.float64: nccl.NCCL_FLOAT64,
+        np.int8: nccl.NCCL_INT8,  # type: ignore
+        np.uint8: nccl.NCCL_UINT8,  # type: ignore
+        np.int32: nccl.NCCL_INT32,  # type: ignore
+        np.uint32: nccl.NCCL_UINT32,  # type: ignore
+        np.int64: nccl.NCCL_INT64,  # type: ignore
+        np.uint64: nccl.NCCL_UINT64,  # type: ignore
+        np.float16: nccl.NCCL_FLOAT16,  # type: ignore
+        np.float32: nccl.NCCL_FLOAT32,  # type: ignore
+        np.float64: nccl.NCCL_FLOAT64,  # type: ignore
     }

     ReduceOpMappingNCCL: Dict[CollectiveReduceOp, int] = {