Skip to content

Commit

Permalink
Fixed cyclic garbage that keeps traceback frames alive in task group …
Browse files Browse the repository at this point in the history
…exceptions (#806)
  • Loading branch information
graingert authored Oct 13, 2024
1 parent 0614b4f commit 163f10c
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 78 deletions.
3 changes: 3 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Fixed an async fixture's ``self`` being different than the test's ``self`` in
class-based tests (`#633 <https://github.com/agronholm/anyio/issues/633>`_)
(PR by @agronholm and @graingert)
- Fixed TaskGroup and CancelScope producing cyclic references in tracebacks
when raising exceptions (`#806 <https://github.com/agronholm/anyio/pull/806>`_)
(PR by @graingert)

**4.6.0**

Expand Down
175 changes: 102 additions & 73 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,8 @@ def __exit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
del exc_tb

if not self._active:
raise RuntimeError("This cancel scope is not active")
if current_task() is not self._host_task:
Expand All @@ -441,42 +443,46 @@ def __exit__(
"current cancel scope"
)

self._active = False
if self._timeout_handle:
self._timeout_handle.cancel()
self._timeout_handle = None

self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)

host_task_state.cancel_scope = self._parent_scope

# Undo all cancellations done by this scope
if self._cancelling is not None:
while self._cancel_calls:
self._cancel_calls -= 1
if self._host_task.uncancel() <= self._cancelling:
break
try:
self._active = False
if self._timeout_handle:
self._timeout_handle.cancel()
self._timeout_handle = None

# We only swallow the exception iff it was an AnyIO CancelledError, either
# directly as exc_val or inside an exception group and there are no cancelled
# parent cancel scopes visible to us here
not_swallowed_exceptions = 0
swallow_exception = False
if exc_val is not None:
for exc in iterate_exceptions(exc_val):
if self._cancel_called and isinstance(exc, CancelledError):
if not (swallow_exception := self._uncancel(exc)):
self._tasks.remove(self._host_task)
if self._parent_scope is not None:
self._parent_scope._child_scopes.remove(self)
self._parent_scope._tasks.add(self._host_task)

host_task_state.cancel_scope = self._parent_scope

# Undo all cancellations done by this scope
if self._cancelling is not None:
while self._cancel_calls:
self._cancel_calls -= 1
if self._host_task.uncancel() <= self._cancelling:
break

# We only swallow the exception iff it was an AnyIO CancelledError, either
# directly as exc_val or inside an exception group and there are no cancelled
# parent cancel scopes visible to us here
not_swallowed_exceptions = 0
swallow_exception = False
if exc_val is not None:
for exc in iterate_exceptions(exc_val):
if self._cancel_called and isinstance(exc, CancelledError):
if not (swallow_exception := self._uncancel(exc)):
not_swallowed_exceptions += 1
else:
not_swallowed_exceptions += 1
else:
not_swallowed_exceptions += 1

# Restart the cancellation effort in the closest visible, cancelled parent
# scope if necessary
self._restart_cancellation_in_parent()
return swallow_exception and not not_swallowed_exceptions
# Restart the cancellation effort in the closest visible, cancelled parent
# scope if necessary
self._restart_cancellation_in_parent()
return swallow_exception and not not_swallowed_exceptions
finally:
self._host_task = None
del exc_val

@property
def _effectively_cancelled(self) -> bool:
Expand Down Expand Up @@ -683,6 +689,26 @@ def started(self, value: T_contra | None = None) -> None:
_task_states[task].parent_id = self._parent_id


async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None:
tasks = set(tasks)
waiter = get_running_loop().create_future()

def on_completion(task: asyncio.Task[object]) -> None:
tasks.discard(task)
if not tasks and not waiter.done():
waiter.set_result(None)

for task in tasks:
task.add_done_callback(on_completion)
del task

try:
await waiter
finally:
while tasks:
tasks.pop().remove_done_callback(on_completion)


class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
Expand All @@ -701,50 +727,53 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
if exc_val is not None:
self.cancel_scope.cancel()
if not isinstance(exc_val, CancelledError):
self._exceptions.append(exc_val)

try:
if self._tasks:
with CancelScope() as wait_scope:
while self._tasks:
try:
await asyncio.wait(self._tasks)
except CancelledError as exc:
# Shield the scope against further cancellation attempts,
# as they're not productive (#695)
wait_scope.shield = True
self.cancel_scope.cancel()

# Set exc_val from the cancellation exception if it was
# previously unset. However, we should not replace a native
# cancellation exception with one raise by a cancel scope.
if exc_val is None or (
isinstance(exc_val, CancelledError)
and not is_anyio_cancellation(exc)
):
exc_val = exc
else:
# If there are no child tasks to wait on, run at least one checkpoint
# anyway
await AsyncIOBackend.cancel_shielded_checkpoint()
if exc_val is not None:
self.cancel_scope.cancel()
if not isinstance(exc_val, CancelledError):
self._exceptions.append(exc_val)

self._active = False
if self._exceptions:
raise BaseExceptionGroup(
"unhandled errors in a TaskGroup", self._exceptions
)
elif exc_val:
raise exc_val
except BaseException as exc:
if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
return True
try:
if self._tasks:
with CancelScope() as wait_scope:
while self._tasks:
try:
await _wait(self._tasks)
except CancelledError as exc:
# Shield the scope against further cancellation attempts,
# as they're not productive (#695)
wait_scope.shield = True
self.cancel_scope.cancel()

# Set exc_val from the cancellation exception if it was
# previously unset. However, we should not replace a native
# cancellation exception with one raise by a cancel scope.
if exc_val is None or (
isinstance(exc_val, CancelledError)
and not is_anyio_cancellation(exc)
):
exc_val = exc
else:
# If there are no child tasks to wait on, run at least one checkpoint
# anyway
await AsyncIOBackend.cancel_shielded_checkpoint()

raise
self._active = False
if self._exceptions:
raise BaseExceptionGroup(
"unhandled errors in a TaskGroup", self._exceptions
)
elif exc_val:
raise exc_val
except BaseException as exc:
if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
return True

raise

return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
finally:
del exc_val, exc_tb, self._exceptions

def _spawn(
self,
Expand Down
7 changes: 3 additions & 4 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,12 @@ async def __aexit__(
try:
return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb)
except BaseExceptionGroup as exc:
_, rest = exc.split(trio.Cancelled)
if not rest:
cancelled_exc = trio.Cancelled._create()
raise cancelled_exc from exc
if not exc.split(trio.Cancelled)[1]:
raise trio.Cancelled._create() from exc

raise
finally:
del exc_val, exc_tb
self._active = False

def start_soon(
Expand Down
122 changes: 121 additions & 1 deletion tests/test_taskgroups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import gc
import math
import sys
import time
Expand All @@ -9,7 +10,7 @@
from typing import Any, NoReturn, cast

import pytest
from exceptiongroup import catch
from exceptiongroup import ExceptionGroup, catch
from pytest_mock import MockerFixture

import anyio
Expand Down Expand Up @@ -1548,6 +1549,125 @@ async def in_task_group(task_status: TaskStatus[None]) -> None:
assert not tg.cancel_scope.cancel_called


if sys.version_info <= (3, 11):

def no_other_refs() -> list[object]:
return [sys._getframe(1)]
else:

def no_other_refs() -> list[object]:
return []


@pytest.mark.skipif(
sys.implementation.name == "pypy",
reason=(
"gc.get_referrers is broken on PyPy see "
"https://github.com/pypy/pypy/issues/5075"
),
)
class TestRefcycles:
async def test_exception_refcycles_direct(self) -> None:
"""
Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup
Note: This test never failed on anyio, but keeping this test to align
with the tests from cpython.
"""
tg = create_task_group()
exc = None

class _Done(Exception):
pass

try:
async with tg:
raise _Done
except ExceptionGroup as e:
exc = e

assert exc is not None
assert gc.get_referrers(exc) == no_other_refs()

async def test_exception_refcycles_errors(self) -> None:
"""Test that TaskGroup deletes self._exceptions, and __aexit__ args"""
tg = create_task_group()
exc = None

class _Done(Exception):
pass

try:
async with tg:
raise _Done
except ExceptionGroup as excs:
exc = excs.exceptions[0]

assert isinstance(exc, _Done)
assert gc.get_referrers(exc) == no_other_refs()

async def test_exception_refcycles_parent_task(self) -> None:
"""Test that TaskGroup's cancel_scope deletes self._host_task"""
tg = create_task_group()
exc = None

class _Done(Exception):
pass

async def coro_fn() -> None:
async with tg:
raise _Done

try:
async with anyio.create_task_group() as tg2:
tg2.start_soon(coro_fn)
except ExceptionGroup as excs:
exc = excs.exceptions[0].exceptions[0]

assert isinstance(exc, _Done)
assert gc.get_referrers(exc) == no_other_refs()

async def test_exception_refcycles_propagate_cancellation_error(self) -> None:
"""Test that TaskGroup deletes cancelled_exc"""
tg = anyio.create_task_group()
exc = None

with CancelScope() as cs:
cs.cancel()
try:
async with tg:
await checkpoint()
except get_cancelled_exc_class() as e:
exc = e
raise

assert isinstance(exc, get_cancelled_exc_class())
assert gc.get_referrers(exc) == no_other_refs()

async def test_exception_refcycles_base_error(self) -> None:
"""
Test for BaseExceptions.
anyio doesn't treat these differently so this test is redundant
but copied from CPython's asyncio.TaskGroup tests for completion.
"""

class MyKeyboardInterrupt(KeyboardInterrupt):
pass

tg = create_task_group()
exc = None

try:
async with tg:
raise MyKeyboardInterrupt
except BaseExceptionGroup as excs:
exc = excs.exceptions[0]

assert isinstance(exc, MyKeyboardInterrupt)
assert gc.get_referrers(exc) == no_other_refs()


class TestTaskStatusTyping:
"""
These tests do not do anything at run time, but since the test suite is also checked
Expand Down

0 comments on commit 163f10c

Please sign in to comment.