From 4fdbd580c5c290bd53a085ace572501c40183c10 Mon Sep 17 00:00:00 2001 From: ttys0dev <126845556+ttys0dev@users.noreply.github.com> Date: Fri, 12 Jan 2024 17:17:43 -0700 Subject: [PATCH] Test asyncio.shield with sync and async middlware --- tests/test_sync.py | 101 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/tests/test_sync.py b/tests/test_sync.py index 7391c08b..5e58ae25 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -849,3 +849,104 @@ def sync_task(): await sync_to_async(sync_middleware)() assert task_executed + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="deadlocks") +async def test_inner_shield_sync_middleware(): + """ + Test that asyncio.shield works when using sync middleware. + """ + + def sync_middleware(): + async_to_sync(async_view)() + + task_complete = False + task_cancel_caught = False + task_blocker = asyncio.Future() + + async def async_view(): + nonlocal task_complete, task_cancel_caught, task_blocker + task = asyncio.create_task(async_task()) + try: + await asyncio.shield(task) + except asyncio.CancelledError: + task_cancel_caught = True + task_blocker.set_result(True) + await task + task_complete = True + + task_executed = False + task_started_future = asyncio.Future() + + async def async_task(): + nonlocal task_started_future, task_executed, task_blocker + task_started_future.set_result(True) + await task_blocker + task_executed = True + + task_cancel_propagated = False + + async with ThreadSensitiveContext(): + task = asyncio.create_task(sync_to_async(sync_middleware)()) + await task_started_future + task.cancel() + try: + await task + except asyncio.CancelledError: + task_cancel_propagated = True + assert not task_cancel_propagated + assert task_cancel_caught + assert task_complete + + assert task_executed + + +@pytest.mark.asyncio +async def test_inner_shield_async_middleware(): + """ + Test that asyncio.shield works when using async middleware. + """ + + async def async_middleware(): + await async_view() + + task_complete = False + task_cancel_caught = False + task_blocker = asyncio.Future() + + async def async_view(): + nonlocal task_complete, task_cancel_caught, task_blocker + task = asyncio.create_task(async_task()) + try: + await asyncio.shield(task) + except asyncio.CancelledError: + task_cancel_caught = True + task_blocker.set_result(True) + await task + task_complete = True + + task_executed = False + task_started_future = asyncio.Future() + + async def async_task(): + nonlocal task_started_future, task_executed, task_blocker + task_started_future.set_result(True) + await task_blocker + task_executed = True + + task_cancel_propagated = False + + async with ThreadSensitiveContext(): + task = asyncio.create_task(async_middleware()) + await task_started_future + task.cancel() + try: + await task + except asyncio.CancelledError: + task_cancel_propagated = True + assert not task_cancel_propagated + assert task_cancel_caught + assert task_complete + + assert task_executed