diff --git a/asgiref/local.py b/asgiref/local.py index a8b9459b..784c5709 100644 --- a/asgiref/local.py +++ b/asgiref/local.py @@ -24,9 +24,8 @@ def __setattr__(self, key: str, value: Any) -> None: if key == "_data": return super().__setattr__(key, value) - storage_object = self._data.get({}) - storage_object[key] = value - self._data.set(storage_object) + # Update a copy of the existing storage. + self._data.set({**self._data.get({}), key: value}) def __delattr__(self, key: str) -> None: storage_object = self._data.get({}) diff --git a/tests/test_local.py b/tests/test_local.py index d50cba21..32cd6cd3 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -1,9 +1,18 @@ +# type: ignore + import asyncio import gc import threading +import unittest +from collections.abc import Callable +from contextvars import ContextVar +from threading import Thread +from typing import Any +from unittest import TestCase import pytest +import asgiref from asgiref.local import Local from asgiref.sync import async_to_sync, sync_to_async @@ -338,3 +347,124 @@ async def async_function(): # inner value was set inside a new async context, meaning that # we do not see it, as context vars don't propagate up the stack assert not hasattr(test_local_not_tc, "test_value") + + +from packaging.version import Version + +is_new_asgiref = Version(asgiref.__version__) >= Version("3.7") + + +def run_in_thread(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Run a callable in a thread.""" + result: Any = None + + def _thread_main() -> None: + nonlocal result + result = func(*args, **kwargs) + + thread = Thread(target=_thread_main) + thread.start() + thread.join() + + return result + + +def run_in_task(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Run a callable in an asyncio task.""" + + async def _task_main(): + return func(*args, **kwargs) + + async def _async_main(): + return await asyncio.create_task( + _task_main(), + ) + + return asyncio.run(_async_main()) + + +local = Local() +cvar: ContextVar[int] = ContextVar("cvar") + + +class Test(TestCase): + """Test Local visibility behaviour.""" + + def setUp(self): + setattr(local, "value", 0) + cvar.set(0) + + def test_visibility_thread_asgiref(self) -> None: + """Check visibility with subthreads.""" + self.assertEqual(local.value, 0) + + def _test(): + # Local() is cleared when changing thread + self.assertFalse(hasattr(local, "value")) + setattr(local, "value", 1) + self.assertEqual(local.value, 1) + + run_in_thread(_test) + + self.assertEqual(local.value, 0) + + def test_visibility_thread_contextvar(self) -> None: + """Check visibility with subthreads.""" + self.assertEqual(cvar.get(), 0) + + def _test(): + # ContextVar is cleared when changing thread + with self.assertRaises(LookupError): + cvar.get() + cvar.set(1) + self.assertEqual(cvar.get(), 1) + + run_in_thread(_test) + + self.assertEqual(cvar.get(), 0) + + @unittest.skipIf(is_new_asgiref, "test for old asgiref") + def test_visibility_task_asgiref_pre_37(self) -> None: + """Check visibility with asyncio tasks.""" + self.assertEqual(local.value, 0) + + def _test(): + # Local is cleared on pre-3.7 when changing task + self.assertFalse(hasattr(local, "value")) + setattr(local, "value", 1) + self.assertEqual(local.value, 1) + + run_in_task(_test) + + self.assertEqual(local.value, 0) + + @unittest.skipIf(not is_new_asgiref, "test for new asgiref") + def test_visibility_task_asgiref_post_37(self) -> None: + """Check visibility with asyncio tasks.""" + self.assertEqual(local.value, 0) + + def _test(): + # Local is inherited on 3.7+ when changing task + self.assertEqual(local.value, 0) + local.value = 1 + self.assertEqual(local.value, 1) + + run_in_task(_test) + + # Changes leak to the caller, and probably should not + self.assertEqual(local.value, 0) + self.assertEqual(getattr(local, "value"), 0) + + def test_visibility_task_contextvar(self) -> None: + """Check visibility with subthreads.""" + self.assertEqual(cvar.get(), 0) + + def _test(): + # ContextVar is inherited when changing task + self.assertEqual(cvar.get(), 0) + cvar.set(1) + self.assertEqual(cvar.get(), 1) + + run_in_task(_test) + + self.assertEqual(cvar.get(), 0)