diff --git a/tests/conftest.py b/tests/conftest.py index a6fbc0876..6f042ac19 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -61,8 +61,10 @@ def configure_epics_environment(): _ALLOWED_PYTEST_TASKS = {"async_finalizer", "async_setup", "async_teardown"} + + def _error_and_kill_pending_tasks( - loop: asyncio.AbstractEventLoop, test_name: str, test_failed: bool + loop: asyncio.AbstractEventLoop, test_name: str, test_passed: bool ) -> set[asyncio.Task]: """Cancels pending tasks in the event loop for a test. Raises an exception if the test hasn't already. @@ -70,7 +72,7 @@ def _error_and_kill_pending_tasks( Args: loop: The event loop to check for pending tasks. test_name: The name of the test. - test_failed: Indicates whether the test has failed. + test_passed: Indicates whether the test passed. Returns: set[asyncio.Task]: The set of unfinished tasks that were cancelled. @@ -79,7 +81,6 @@ def _error_and_kill_pending_tasks( RuntimeError: If there are unfinished tasks and the test didn't fail. """ unfinished_tasks = { - task for task in asyncio.all_tasks(loop) if task.get_coro().__name__ not in _ALLOWED_PYTEST_TASKS and not task.done() @@ -87,13 +88,13 @@ def _error_and_kill_pending_tasks( for task in unfinished_tasks: task.cancel() - # We only raise an exception here if the test didn't fail. + # We only raise an exception here if the test didn't fail anyway. # If it did then it makes sense that there's some tasks we need to cancel, # but an exception will already have been raised. - if unfinished_tasks and not test_failed: + if unfinished_tasks and test_passed: raise RuntimeError( f"Not all tasks closed during test {test_name}:\n" - f"{pprint.pformat(unfinished_tasks)}" + f"{pprint.pformat(unfinished_tasks, width=88)}" ) return unfinished_tasks @@ -112,7 +113,7 @@ def fail_test_on_unclosed_tasks(request: FixtureRequest): request.addfinalizer( lambda: _error_and_kill_pending_tasks( - loop, request.node.name, request.session.testsfailed != fail_count + loop, request.node.name, request.session.testsfailed == fail_count ) ) @@ -131,20 +132,16 @@ def clean_event_loop(): except TransitionError: pass - loop.call_soon_threadsafe(loop.stop) RE._th.join() try: _error_and_kill_pending_tasks( - loop, - request.node.name, - request.session.testsfailed != fail_count + loop, request.node.name, request.session.testsfailed == fail_count ) finally: loop.close() - request.addfinalizer(clean_event_loop) return RE