diff --git a/chia/util/task_referencer.py b/chia/util/task_referencer.py index 19e3a83159c6..ce95078bbc54 100644 --- a/chia/util/task_referencer.py +++ b/chia/util/task_referencer.py @@ -6,8 +6,6 @@ import asyncio import dataclasses import logging -import math -import time import typing T = typing.TypeVar("T") @@ -18,11 +16,11 @@ @dataclasses.dataclass(frozen=True) class _TaskInfo: task: asyncio.Task[object] - name: str + # retained for potential debugging use known_unreferenced: bool def __str__(self) -> str: - return self.name + return self.task.get_name() @dataclasses.dataclass @@ -32,12 +30,7 @@ class _TaskReferencer: task groups such as from anyio. """ - tasks: list[_TaskInfo] = dataclasses.field(default_factory=list) - clock: typing.Callable[[], float] = time.monotonic - last_cull_time: float = -math.inf - last_cull_length: int = 0 - cull_period: float = 30 - cull_count: int = 1000 + tasks: dict[asyncio.Task[object], _TaskInfo] = dataclasses.field(default_factory=dict) def create_task( self, @@ -46,30 +39,19 @@ def create_task( name: typing.Optional[str] = None, known_unreferenced: bool = False, ) -> asyncio.Task[T]: - self.maybe_cull() - task = asyncio.create_task(coro=coroutine, name=name) # noqa: TID251 - self.tasks.append( - _TaskInfo( - task=task, - name=task.get_name(), - known_unreferenced=known_unreferenced, - ) - ) - - return task + task.add_done_callback(self._task_done) - def maybe_cull(self) -> None: - now = self.clock() - since_last = now - self.last_cull_time + self.tasks[task] = _TaskInfo(task=task, known_unreferenced=known_unreferenced) - if len(self.tasks) <= self.last_cull_length + self.cull_count and since_last <= self.cull_period: - return + return task + def _task_done(self, task: asyncio.Task[object]) -> None: # TODO: consider collecting results and logging errors - self.tasks = [task_info for task_info in self.tasks if not task_info.task.done()] - self.last_cull_time = now - self.last_cull_length = len(self.tasks) + try: + del self.tasks[task] + except KeyError: + logger.warning("Task not found in task referencer: %s", task) _global_task_referencer = _TaskReferencer()