diff --git a/icij-worker/icij_worker/backend/mp.py b/icij-worker/icij_worker/backend/mp.py index baea08e..424460e 100644 --- a/icij-worker/icij_worker/backend/mp.py +++ b/icij-worker/icij_worker/backend/mp.py @@ -11,9 +11,7 @@ as_completed, ) from contextlib import contextmanager -from typing import Callable, Dict, List, Optional, Set, Tuple - -from pydantic.class_validators import partial +from typing import Callable, Dict, List, Optional, Set import icij_worker from icij_common.logging_utils import setup_loggers @@ -93,7 +91,7 @@ def _get_mp_async_runner( worker_extras: Optional[Dict] = None, app_deps_extras: Optional[Dict] = None, worker_group: Optional[str], -) -> Tuple[Optional[TerminationCallback], List[Callable[[], Future]]]: +) -> List[Callable[[], Future]]: # This function is here to avoid code duplication, it will be removed # Here we set maxtasksperchild to 1. Each worker has a single never ending task @@ -112,7 +110,7 @@ def _get_mp_async_runner( futures = [] for _ in range(n_workers): futures.append(functools.partial(executor.submit, _mp_work_forever, **kwds)) - return partial(executor.shutdown, wait=False, cancel_futures=True), futures + return futures def _cancel_other_callback(errored: Future, others: List[Future]): @@ -186,7 +184,7 @@ def run_workers_with_multiprocessing_cm( ) return logger.info("Creating multiprocessing executor with %s workers", n_workers) - termination_cb, worker_runners = _get_mp_async_runner( + worker_runners = _get_mp_async_runner( app, config, n_workers, @@ -202,15 +200,14 @@ def run_workers_with_multiprocessing_cm( f.add_done_callback(functools.partial(_cancel_other_callback, others=futures)) logger.info("started %s workers for app %s", n_workers, app) original_error = None - with _handle_executor_termination(termination_cb, futures, True): - for f in as_completed(futures): - try: - f.result() - except CancelledError: - pass - except Exception as e: # pylint: disable=broad-exception-caught - original_error = e - del futures + for f in as_completed(futures): + try: + f.result() + except CancelledError: + pass + except Exception as e: # pylint: disable=broad-exception-caught + original_error = e + del futures if original_error: raise original_error @@ -237,7 +234,7 @@ def run_workers_with_multiprocessing( ) return logger.info("Creating multiprocessing executor with %s workers", n_workers) - termination_cb, worker_runners = _get_mp_async_runner( + worker_runners = _get_mp_async_runner( app, config, n_workers, @@ -254,14 +251,13 @@ def run_workers_with_multiprocessing( f.add_done_callback(functools.partial(_cancel_other_callback, others=futures)) logger.info("started %s workers for app %s", n_workers, app) original_error = None - with _handle_executor_termination(termination_cb, futures, True): - for f in as_completed(futures): - try: - f.result() - except CancelledError: - pass - except Exception as e: # pylint: disable=broad-exception-caught - original_error = e - del futures + for f in as_completed(futures): + try: + f.result() + except CancelledError: + pass + except Exception as e: # pylint: disable=broad-exception-caught + original_error = e + del futures if original_error: raise original_error