From 1be1476c5487f5bce39dc48f9f442993027c7a20 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 5 Nov 2024 07:07:56 -0800 Subject: [PATCH] Ensure all asyncio tasks are canceled before closing event loop --- .../_lib_async/continuous_ucx_progress.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py b/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py index 6507290d..02cde6fd 100644 --- a/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py +++ b/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py @@ -11,6 +11,15 @@ from ucxx._lib.libucxx import UCXWorker +def _cancel_task(event_loop, task): + if task is not None: + try: + task.cancel() + event_loop.run_until_complete(task) + except asyncio.exceptions.CancelledError: + pass + + class ProgressTask(object): def __init__(self, worker, event_loop): """Creates a task that keeps calling worker.progress() @@ -28,20 +37,20 @@ def __init__(self, worker, event_loop): """ self.worker = worker self.event_loop = event_loop - self.asyncio_task = None + self.asyncio_tasks = dict() event_loop_close_original = self.event_loop.close def _event_loop_close(event_loop_close_original, *args, **kwargs): - if not self.event_loop.is_closed() and self.asyncio_task is not None: - try: - self.asyncio_task.cancel() - self.event_loop.run_until_complete(self.asyncio_task) - except asyncio.exceptions.CancelledError: - pass - finally: - self.asyncio_task = None - event_loop_close_original(*args, **kwargs) + if self.event_loop.is_closed(): + return + + try: + for task in self.asyncio_tasks.values(): + _cancel_task(event_loop, task) + finally: + event_loop_close_original(*args, **kwargs) + self.asyncio_tasks = None self.event_loop.close = partial(_event_loop_close, event_loop_close_original) @@ -72,7 +81,7 @@ def __del__(self): class PollingMode(ProgressTask): def __init__(self, worker, event_loop): super().__init__(worker, event_loop) - self.asyncio_task = event_loop.create_task(self._progress_task()) + self.asyncio_tasks["progress"] = event_loop.create_task(self._progress_task()) self.worker.init_blocking_progress_mode() async def _progress_task(self): @@ -135,9 +144,11 @@ def __init__( weakref.finalize(self, self.rsock.close) self.armed = False - self.blocking_asyncio_task = self.event_loop.create_task(self._arm_worker()) + self.asyncio_tasks["arm"] = self.event_loop.create_task(self._arm_worker()) self.last_progress_time = time.monotonic() - self._progress_timeout - self.asyncio_task = event_loop.create_task(self._progress_with_timeout()) + self.asyncio_tasks["progress"] = event_loop.create_task( + self._progress_with_timeout() + ) def _fd_reader_callback(self): """Schedule new progress task upon worker event. @@ -198,10 +209,11 @@ async def _progress_with_timeout(self): # seem to respect timeout with `asyncio.wait_for`, thus we cancel # it here instead. It will get recreated after a new event on # `worker.epoll_file_descriptor`. - if self.blocking_asyncio_task is not None: - self.blocking_asyncio_task.cancel() + arm_task = self.asyncio_tasks["arm"] + if arm_task is not None: + arm_task.cancel() try: - await self.blocking_asyncio_task + await arm_task except asyncio.exceptions.CancelledError: pass