Skip to content

Commit

Permalink
added docstrings and cleaned up task handling method
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Sep 17, 2024
1 parent f8fba54 commit 6966334
Showing 1 changed file with 39 additions and 14 deletions.
53 changes: 39 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,39 +60,60 @@ def configure_epics_environment():
os.environ["EPICS_PVA_AUTO_ADDR_LIST"] = "NO"


_ALLOWED_CORO_TASKS = {"async_finalizer", "async_setup", "async_teardown"}
_ALLOWED_PYTEST_TASKS = {"async_finalizer", "async_setup", "async_teardown"}
def _error_and_kill_pending_tasks(
loop: asyncio.AbstractEventLoop, request: FixtureRequest, fail_count: int
):
loop: asyncio.AbstractEventLoop, test_name: str, test_failed: bool
) -> set[asyncio.Task]:
"""Cancels pending tasks in the event loop for a test. Raises an exception if
the test hasn't already.
Args:
loop: The event loop to check for pending tasks.
test_name: The name of the test.
test_failed: Indicates whether the test has failed.
Returns:
set[asyncio.Task]: The set of unfinished tasks that were cancelled.
Raises:
RuntimeError: If there are unfinished tasks and the test didn't fail.
"""
unfinished_tasks = {

task
for task in asyncio.all_tasks(loop)
if task.get_coro().__name__ not in _ALLOWED_CORO_TASKS and not task.done()
if task.get_coro().__name__ not in _ALLOWED_PYTEST_TASKS and not task.done()
}
for task in unfinished_tasks:
task.cancel()

if unfinished_tasks:
# We only raise an exception here if the test failed.
if request.session.testsfailed == fail_count:
raise RuntimeError(
f"Not all tasks closed during test {request.node.name}:\n"
f"{pprint.pformat(unfinished_tasks)}"
)
# We only raise an exception here if the test didn't fail.
# If it did then it makes sense that there's some tasks we need to cancel,
# but an exception will already have been raised.
if unfinished_tasks and not test_failed:
raise RuntimeError(
f"Not all tasks closed during test {test_name}:\n"
f"{pprint.pformat(unfinished_tasks)}"
)

return unfinished_tasks



@pytest.fixture(autouse=True, scope="function")
def fail_test_on_unclosed_tasks(request: FixtureRequest):
"""
Used on every test to ensure failure if there are pending tasks
by the end of the test.
"""

fail_count = request.session.testsfailed
loop = asyncio.get_event_loop()
loop.set_debug(True)

request.addfinalizer(
lambda: _error_and_kill_pending_tasks(loop, request, fail_count)
lambda: _error_and_kill_pending_tasks(
loop, request.node.name, request.session.testsfailed != fail_count
)
)


Expand All @@ -115,7 +136,11 @@ def clean_event_loop():
RE._th.join()

try:
_error_and_kill_pending_tasks(loop, request, fail_count)
_error_and_kill_pending_tasks(
loop,
request.node.name,
request.session.testsfailed != fail_count
)
finally:
loop.close()

Expand Down

0 comments on commit 6966334

Please sign in to comment.