
Commit

chore: simplify process pool shutdown
ClemDoum committed Nov 21, 2024
1 parent 4abc6db commit 00d209d
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions icij-worker/icij_worker/backend/mp.py
@@ -11,9 +11,7 @@
     as_completed,
 )
 from contextlib import contextmanager
-from typing import Callable, Dict, List, Optional, Set, Tuple
-
-from pydantic.class_validators import partial
+from typing import Callable, Dict, List, Optional, Set

 import icij_worker
 from icij_common.logging_utils import setup_loggers
@@ -93,7 +91,7 @@ def _get_mp_async_runner(
     worker_extras: Optional[Dict] = None,
     app_deps_extras: Optional[Dict] = None,
     worker_group: Optional[str],
-) -> Tuple[Optional[TerminationCallback], List[Callable[[], Future]]]:
+) -> List[Callable[[], Future]]:
     # This function is here to avoid code duplication, it will be removed

     # Here we set maxtasksperchild to 1. Each worker has a single never ending task
@@ -112,7 +110,7 @@ def _get_mp_async_runner(
     futures = []
     for _ in range(n_workers):
         futures.append(functools.partial(executor.submit, _mp_work_forever, **kwds))
-    return partial(executor.shutdown, wait=False, cancel_futures=True), futures
+    return futures


 def _cancel_other_callback(errored: Future, others: List[Future]):
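With this change, _get_mp_async_runner no longer hands back a termination callback: it only returns the deferred executor.submit calls, and shutting the pool down is left to whatever context owns the executor. A minimal sketch of that shape (hypothetical names; _work_forever stands in for the real _mp_work_forever, and a plain ProcessPoolExecutor stands in for whatever pool the backend actually uses):

import functools
from concurrent.futures import Future, ProcessPoolExecutor
from typing import Callable, List


def _work_forever(worker_id: int) -> None:
    # Stand-in for the real never-ending worker loop.
    print(f"worker {worker_id} running")


def get_runners(
    executor: ProcessPoolExecutor, n_workers: int
) -> List[Callable[[], Future]]:
    # One deferred submit per worker; nothing starts until the caller invokes
    # the callables, and nothing here knows how to shut the pool down.
    return [
        functools.partial(executor.submit, _work_forever, worker_id=i)
        for i in range(n_workers)
    ]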
@@ -186,7 +184,7 @@ def run_workers_with_multiprocessing_cm(
         )
         return
     logger.info("Creating multiprocessing executor with %s workers", n_workers)
-    termination_cb, worker_runners = _get_mp_async_runner(
+    worker_runners = _get_mp_async_runner(
         app,
         config,
         n_workers,
@@ -202,15 +200,14 @@ def run_workers_with_multiprocessing_cm(
         f.add_done_callback(functools.partial(_cancel_other_callback, others=futures))
     logger.info("started %s workers for app %s", n_workers, app)
     original_error = None
-    with _handle_executor_termination(termination_cb, futures, True):
-        for f in as_completed(futures):
-            try:
-                f.result()
-            except CancelledError:
-                pass
-            except Exception as e:  # pylint: disable=broad-exception-caught
-                original_error = e
-        del futures
+    for f in as_completed(futures):
+        try:
+            f.result()
+        except CancelledError:
+            pass
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            original_error = e
+    del futures
     if original_error:
         raise original_error
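The simplification is the same in both run_workers_with_* variants: drop the _handle_executor_termination wrapper, drain as_completed(futures) directly, swallow CancelledError, remember the first real error, and re-raise it once the loop has finished. A self-contained sketch of this pattern, assuming a plain ProcessPoolExecutor and a hypothetical _cancel_others helper playing the role of _cancel_other_callback:

import functools
from concurrent.futures import (
    CancelledError,
    Future,
    ProcessPoolExecutor,
    as_completed,
)
from typing import List


def _cancel_others(errored: Future, others: List[Future]) -> None:
    # Same role as _cancel_other_callback: when one worker future fails,
    # try to cancel its siblings so as_completed() can finish early.
    if not errored.cancelled() and errored.exception() is not None:
        for other in others:
            other.cancel()


def _work(i: int) -> int:
    if i == 0:
        raise RuntimeError("boom")
    return i


def main() -> None:
    original_error = None
    with ProcessPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(_work, i) for i in range(2)]
        for f in futures:
            f.add_done_callback(
                functools.partial(_cancel_others, others=futures)
            )
        # No explicit termination callback: just collect results and let the
        # executor's context manager shut the pool down on exit.
        for f in as_completed(futures):
            try:
                f.result()
            except CancelledError:
                pass
            except Exception as e:  # pylint: disable=broad-exception-caught
                original_error = e
    if original_error:
        raise original_error


if __name__ == "__main__":
    main()

Run as a script, this re-raises the worker's error only after the pool has exited its with block, which mirrors the flow the simplified functions follow.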

@@ -237,7 +234,7 @@ def run_workers_with_multiprocessing(
         )
         return
     logger.info("Creating multiprocessing executor with %s workers", n_workers)
-    termination_cb, worker_runners = _get_mp_async_runner(
+    worker_runners = _get_mp_async_runner(
         app,
         config,
         n_workers,
@@ -254,14 +251,13 @@ def run_workers_with_multiprocessing(
         f.add_done_callback(functools.partial(_cancel_other_callback, others=futures))
     logger.info("started %s workers for app %s", n_workers, app)
     original_error = None
-    with _handle_executor_termination(termination_cb, futures, True):
-        for f in as_completed(futures):
-            try:
-                f.result()
-            except CancelledError:
-                pass
-            except Exception as e:  # pylint: disable=broad-exception-caught
-                original_error = e
-        del futures
+    for f in as_completed(futures):
+        try:
+            f.result()
+        except CancelledError:
+            pass
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            original_error = e
+    del futures
     if original_error:
         raise original_error
