Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Local usage in asyncio Tasks. #477

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions asgiref/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ def __setattr__(self, key: str, value: Any) -> None:
if key == "_data":
return super().__setattr__(key, value)

storage_object = self._data.get({})
storage_object[key] = value
self._data.set(storage_object)
# Update a copy of the existing storage.
self._data.set({**self._data.get({}), key: value})

def __delattr__(self, key: str) -> None:
storage_object = self._data.get({})
Expand Down
130 changes: 130 additions & 0 deletions tests/test_local.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
# type: ignore

import asyncio
import gc
import threading
import unittest
from collections.abc import Callable
from contextvars import ContextVar
from threading import Thread
from typing import Any
from unittest import TestCase

import pytest

import asgiref
from asgiref.local import Local
from asgiref.sync import async_to_sync, sync_to_async

Expand Down Expand Up @@ -338,3 +347,124 @@ async def async_function():
# inner value was set inside a new async context, meaning that
# we do not see it, as context vars don't propagate up the stack
assert not hasattr(test_local_not_tc, "test_value")


from packaging.version import Version

is_new_asgiref = Version(asgiref.__version__) >= Version("3.7")


def run_in_thread(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Run a callable in a thread."""
result: Any = None

def _thread_main() -> None:
nonlocal result
result = func(*args, **kwargs)

thread = Thread(target=_thread_main)
thread.start()
thread.join()

return result


def run_in_task(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Run a callable in an asyncio task."""

async def _task_main():
return func(*args, **kwargs)

async def _async_main():
return await asyncio.create_task(
_task_main(),
)

return asyncio.run(_async_main())


local = Local()
cvar: ContextVar[int] = ContextVar("cvar")


class Test(TestCase):
"""Test Local visibility behaviour."""

def setUp(self):
setattr(local, "value", 0)
cvar.set(0)

def test_visibility_thread_asgiref(self) -> None:
"""Check visibility with subthreads."""
self.assertEqual(local.value, 0)

def _test():
# Local() is cleared when changing thread
self.assertFalse(hasattr(local, "value"))
setattr(local, "value", 1)
self.assertEqual(local.value, 1)

run_in_thread(_test)

self.assertEqual(local.value, 0)

def test_visibility_thread_contextvar(self) -> None:
"""Check visibility with subthreads."""
self.assertEqual(cvar.get(), 0)

def _test():
# ContextVar is cleared when changing thread
with self.assertRaises(LookupError):
cvar.get()
cvar.set(1)
self.assertEqual(cvar.get(), 1)

run_in_thread(_test)

self.assertEqual(cvar.get(), 0)

@unittest.skipIf(is_new_asgiref, "test for old asgiref")
def test_visibility_task_asgiref_pre_37(self) -> None:
"""Check visibility with asyncio tasks."""
self.assertEqual(local.value, 0)

def _test():
# Local is cleared on pre-3.7 when changing task
self.assertFalse(hasattr(local, "value"))
setattr(local, "value", 1)
self.assertEqual(local.value, 1)

run_in_task(_test)

self.assertEqual(local.value, 0)

@unittest.skipIf(not is_new_asgiref, "test for new asgiref")
def test_visibility_task_asgiref_post_37(self) -> None:
"""Check visibility with asyncio tasks."""
self.assertEqual(local.value, 0)

def _test():
# Local is inherited on 3.7+ when changing task
self.assertEqual(local.value, 0)
local.value = 1
self.assertEqual(local.value, 1)

run_in_task(_test)

# Changes leak to the caller, and probably should not
self.assertEqual(local.value, 0)
self.assertEqual(getattr(local, "value"), 0)

def test_visibility_task_contextvar(self) -> None:
"""Check visibility with subthreads."""
self.assertEqual(cvar.get(), 0)

def _test():
# ContextVar is inherited when changing task
self.assertEqual(cvar.get(), 0)
cvar.set(1)
self.assertEqual(cvar.get(), 1)

run_in_task(_test)

self.assertEqual(cvar.get(), 0)
Loading