Skip to content

Commit

Permalink
Ensure all Python multiprocessing tests have timeouts (#300)
Browse files Browse the repository at this point in the history
Ensure all Python multiprocessing tests timeout during `join` and get terminated properly, appropriately raising an error if the subprocess failed to terminate cleanly.

To simplify joining of multiple processes, a new testing function `join_processes` was added, where all processes in the list will be joined with a common timeout, if the total time elapsed is longer than the timeout the process will still be joined but wait will be non-positive, meaning `join` returns immediately and the process is later terminated with `terminate_process`.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #300
  • Loading branch information
pentschev authored Oct 23, 2024
1 parent 122d2f4 commit 90c05dc
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 50 deletions.
5 changes: 2 additions & 3 deletions python/ucxx/ucxx/_lib/tests/test_cancel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import ucxx._lib.libucxx as ucx_api
from ucxx._lib.arr import Array
from ucxx.testing import terminate_process
from ucxx.testing import join_processes, terminate_process

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -80,7 +80,6 @@ def test_message_probe():
args=(queue,),
)
client.start()
client.join(timeout=10)
server.join(timeout=10)
join_processes([client, server], timeout=10)
terminate_process(client)
terminate_process(server)
5 changes: 2 additions & 3 deletions python/ucxx/ucxx/_lib/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import ucxx._lib.libucxx as ucx_api
from ucxx._lib.arr import Array
from ucxx.testing import terminate_process, wait_requests
from ucxx.testing import join_processes, terminate_process, wait_requests

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -108,7 +108,6 @@ def test_close_callback(server_close_callback):
args=(port, server_close_callback),
)
client.start()
client.join(timeout=10)
server.join(timeout=10)
join_processes([client, server], timeout=10)
terminate_process(client)
terminate_process(server)
5 changes: 2 additions & 3 deletions python/ucxx/ucxx/_lib/tests/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ucxx._lib import libucxx as ucx_api
from ucxx._lib.arr import Array
from ucxx.testing import terminate_process, wait_requests
from ucxx.testing import join_processes, terminate_process, wait_requests

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -128,7 +128,6 @@ def test_message_probe(transfer_api):
server.start()
client = mp.Process(target=_client_probe, args=(queue, transfer_api))
client.start()
client.join(timeout=10)
server.join(timeout=10)
join_processes([client, server], timeout=10)
terminate_process(client)
terminate_process(server)
35 changes: 34 additions & 1 deletion python/ucxx/ucxx/_lib/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import multiprocessing
import re
import time
from multiprocessing.queues import Empty

import pytest

from ucxx.testing import terminate_process
from ucxx.testing import join_processes, terminate_process


def _test_process(queue):
Expand Down Expand Up @@ -84,3 +85,35 @@ def test_terminate_process_kill_timeout(mp_context):
ValueError, match="Cannot close a process while it is still running.*"
):
terminate_process(proc, kill_wait=0.0)


@pytest.mark.parametrize("mp_context", ["default", "fork", "forkserver", "spawn"])
@pytest.mark.parametrize("num_processes", [1, 2, 4])
def test_join_processes(mp_context, num_processes):
mp = (
multiprocessing
if mp_context == "default"
else multiprocessing.get_context(mp_context)
)

queue = mp.Queue()
processes = []
for _ in range(num_processes):
proc = mp.Process(
target=_test_process,
args=(queue,),
)
proc.start()
processes.append(proc)

start = time.monotonic()
join_processes(processes, timeout=1.25)
total_time = time.monotonic() - start
assert total_time >= 1.25 and total_time < 2.5

for proc in processes:
try:
terminate_process(proc)
except RuntimeError:
# The process has to be killed and that will raise a `RuntimeError`
pass
9 changes: 4 additions & 5 deletions python/ucxx/ucxx/_lib_async/tests/test_benchmark_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from ucxx.benchmarks.utils import _run_cluster_server, _run_cluster_workers
from ucxx.testing import join_processes, terminate_process


async def _worker(rank, eps, args):
Expand Down Expand Up @@ -46,9 +47,7 @@ async def test_benchmark_cluster(n_chunks=1, n_nodes=2, n_workers=2):
)
)

join_processes(workers + [server], timeout=30)
for worker in workers:
worker.join()
assert not worker.exitcode

server.join()
assert not server.exitcode
terminate_process(worker)
terminate_process(server)
9 changes: 5 additions & 4 deletions python/ucxx/ucxx/_lib_async/tests/test_disconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ucxx
from ucxx._lib_async.utils import get_event_loop
from ucxx._lib_async.utils_test import wait_listener_client_handlers
from ucxx.testing import terminate_process

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -127,9 +128,9 @@ def test_shutdown_unexpected_closed_peer(caplog, endpoint_error_handling):
args=(client_queue, server_queue, endpoint_error_handling),
)
p2.start()
p2.join()
p2.join(timeout=30)
server_queue.put("client is down")
p1.join()
p1.join(timeout=30)

assert not p1.exitcode
assert not p2.exitcode
terminate_process(p2)
terminate_process(p1)
18 changes: 7 additions & 11 deletions python/ucxx/ucxx/_lib_async/tests/test_from_worker_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import ucxx
from ucxx._lib_async.utils import get_event_loop, hash64bits
from ucxx.testing import join_processes, terminate_process

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -90,11 +91,9 @@ def test_from_worker_address():
)
client.start()

client.join()
server.join()

assert not server.exitcode
assert not client.exitcode
join_processes([client, server], timeout=30)
terminate_process(client)
terminate_process(server)


def _get_address_info(address=None):
Expand Down Expand Up @@ -259,10 +258,7 @@ def test_from_worker_address_multinode(num_nodes):
client.start()
clients.append(client)

join_processes(clients + [server], timeout=30)
for client in clients:
client.join()

server.join()

assert not server.exitcode
assert not client.exitcode
terminate_process(client)
terminate_process(server)
32 changes: 17 additions & 15 deletions python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ucxx
from ucxx._lib_async.utils import get_event_loop
from ucxx.testing import join_processes, terminate_process

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -179,18 +180,19 @@ def test_from_worker_address_error(error_type):
)
client.start()

server.join()
client.join()

assert not server.exitcode

if ucxx.get_ucx_version() < (1, 12, 0) and client.exitcode == 1:
if all(t in error_type for t in ["timeout", "send"]):
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7527 with rc/ud."
)
elif all(t in error_type for t in ["timeout", "recv"]):
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7531 with rc/ud."
)
assert not client.exitcode
join_processes([client, server], timeout=30)
terminate_process(server)
try:
terminate_process(client)
except RuntimeError as e:
if ucxx.get_ucx_version() < (1, 12, 0):
if all(t in error_type for t in ["timeout", "send"]):
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7527 with rc/ud."
)
elif all(t in error_type for t in ["timeout", "recv"]):
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7531 with rc/ud."
)
else:
raise e
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
send,
wait_listener_client_handlers,
)
from ucxx.testing import join_processes, terminate_process

cupy = pytest.importorskip("cupy")
rmm = pytest.importorskip("rmm")
Expand Down Expand Up @@ -240,8 +241,6 @@ def test_send_recv_cu(cuda_obj_generator, comm_api):
os.environ.update(env_client)
client_process.start()

server_process.join()
client_process.join()

assert server_process.exitcode == 0
assert client_process.exitcode == 0
join_processes([client, server], timeout=30)
terminate_process(client)
terminate_process(server)
24 changes: 24 additions & 0 deletions python/ucxx/ucxx/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,30 @@
from typing import Type, Union


def join_processes(
processes: list[Type[BaseProcess]],
timeout: Union[float, int],
) -> None:
"""
Join a list of processes with a combined timeout.
Join a list of processes with a combined timeout, for each process `join()`
is called with a timeout equal to the difference of `timeout` and the time
elapsed since this function was called.
Parameters
----------
processes:
The list of processes to be joined.
timeout: float or integer
Maximum time to wait for all the processes to be joined.
"""
start = time.monotonic()
for p in processes:
t = timeout - (time.monotonic() - start)
p.join(timeout=t)


def terminate_process(
process: Type[BaseProcess], kill_wait: Union[float, int] = 3.0
) -> None:
Expand Down

0 comments on commit 90c05dc

Please sign in to comment.