Skip to content

Commit

Permalink
Ensure all asyncio tasks are canceled before closing event loop
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Nov 5, 2024
1 parent d719ce3 commit 1be1476
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
from ucxx._lib.libucxx import UCXWorker


def _cancel_task(event_loop, task):
if task is not None:
try:
task.cancel()
event_loop.run_until_complete(task)
except asyncio.exceptions.CancelledError:
pass


class ProgressTask(object):
def __init__(self, worker, event_loop):
"""Creates a task that keeps calling worker.progress()
Expand All @@ -28,20 +37,20 @@ def __init__(self, worker, event_loop):
"""
self.worker = worker
self.event_loop = event_loop
self.asyncio_task = None
self.asyncio_tasks = dict()

event_loop_close_original = self.event_loop.close

def _event_loop_close(event_loop_close_original, *args, **kwargs):
if not self.event_loop.is_closed() and self.asyncio_task is not None:
try:
self.asyncio_task.cancel()
self.event_loop.run_until_complete(self.asyncio_task)
except asyncio.exceptions.CancelledError:
pass
finally:
self.asyncio_task = None
event_loop_close_original(*args, **kwargs)
if self.event_loop.is_closed():
return

try:
for task in self.asyncio_tasks.values():
_cancel_task(event_loop, task)
finally:
event_loop_close_original(*args, **kwargs)
self.asyncio_tasks = None

self.event_loop.close = partial(_event_loop_close, event_loop_close_original)

Expand Down Expand Up @@ -72,7 +81,7 @@ def __del__(self):
class PollingMode(ProgressTask):
def __init__(self, worker, event_loop):
super().__init__(worker, event_loop)
self.asyncio_task = event_loop.create_task(self._progress_task())
self.asyncio_tasks["progress"] = event_loop.create_task(self._progress_task())
self.worker.init_blocking_progress_mode()

async def _progress_task(self):
Expand Down Expand Up @@ -135,9 +144,11 @@ def __init__(
weakref.finalize(self, self.rsock.close)

self.armed = False
self.blocking_asyncio_task = self.event_loop.create_task(self._arm_worker())
self.asyncio_tasks["arm"] = self.event_loop.create_task(self._arm_worker())
self.last_progress_time = time.monotonic() - self._progress_timeout
self.asyncio_task = event_loop.create_task(self._progress_with_timeout())
self.asyncio_tasks["progress"] = event_loop.create_task(
self._progress_with_timeout()
)

def _fd_reader_callback(self):
"""Schedule new progress task upon worker event.
Expand Down Expand Up @@ -198,10 +209,11 @@ async def _progress_with_timeout(self):
# seem to respect timeout with `asyncio.wait_for`, thus we cancel
# it here instead. It will get recreated after a new event on
# `worker.epoll_file_descriptor`.
if self.blocking_asyncio_task is not None:
self.blocking_asyncio_task.cancel()
arm_task = self.asyncio_tasks["arm"]
if arm_task is not None:
arm_task.cancel()
try:
await self.blocking_asyncio_task
await arm_task
except asyncio.exceptions.CancelledError:
pass

Expand Down

0 comments on commit 1be1476

Please sign in to comment.