Skip to content

Commit

Permalink
Fix Local behaviour with asyncio Task (#478)
Browse files Browse the repository at this point in the history
Redid CVar storage as a dict of contextvars
  • Loading branch information
spanezz authored Oct 11, 2024
1 parent 8e39bcc commit 85d2445
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 14 deletions.
29 changes: 15 additions & 14 deletions asgiref/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,38 @@
import contextlib
import contextvars
import threading
from typing import Any, Dict, Union
from typing import Any, Union


class _CVar:
"""Storage utility for Local."""

def __init__(self) -> None:
self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar(
"asgiref.local"
)
self._data: dict[str, contextvars.ContextVar[Any]] = {}

def __getattr__(self, key):
storage_object = self._data.get({})
def __getattr__(self, key: str) -> Any:
try:
return storage_object[key]
var = self._data[key]
except KeyError:
raise AttributeError(f"{self!r} object has no attribute {key!r}")

try:
return var.get()
except LookupError:
raise AttributeError(f"{self!r} object has no attribute {key!r}")

def __setattr__(self, key: str, value: Any) -> None:
if key == "_data":
return super().__setattr__(key, value)

storage_object = self._data.get({})
storage_object[key] = value
self._data.set(storage_object)
var = self._data.get(key)
if var is None:
self._data[key] = var = contextvars.ContextVar(key)
var.set(value)

def __delattr__(self, key: str) -> None:
storage_object = self._data.get({})
if key in storage_object:
del storage_object[key]
self._data.set(storage_object)
if key in self._data:
del self._data[key]
else:
raise AttributeError(f"{self!r} object has no attribute {key!r}")

Expand Down
37 changes: 37 additions & 0 deletions tests/test_local.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import gc
import threading
from threading import Thread

import pytest

Expand Down Expand Up @@ -338,3 +339,39 @@ async def async_function():
# inner value was set inside a new async context, meaning that
# we do not see it, as context vars don't propagate up the stack
assert not hasattr(test_local_not_tc, "test_value")


def test_visibility_thread_asgiref() -> None:
"""Check visibility with subthreads."""
test_local = Local()
test_local.value = 0

def _test() -> None:
# Local() is cleared when changing thread
assert not hasattr(test_local, "value")
setattr(test_local, "value", 1)
assert test_local.value == 1

thread = Thread(target=_test)
thread.start()
thread.join()

assert test_local.value == 0


@pytest.mark.asyncio
async def test_visibility_task() -> None:
"""Check visibility with asyncio tasks."""
test_local = Local()
test_local.value = 0

async def _test() -> None:
# Local is inherited when changing task
assert test_local.value == 0
test_local.value = 1
assert test_local.value == 1

await asyncio.create_task(_test())

# Changes should not leak to the caller
assert test_local.value == 0

0 comments on commit 85d2445

Please sign in to comment.