diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c78cdc5..dafbf63 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,7 @@ jobs: strategy: fail-fast: true matrix: - python-version: ["3.6", "3.7", "3.8", "3.9"] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] steps: - name: Check out repository uses: actions/checkout@v2 diff --git a/README.md b/README.md index 6a058e7..1ac913a 100644 --- a/README.md +++ b/README.md @@ -39,9 +39,11 @@ pip install ariadne-django-ext ### cache -**cache** decorator will cache a result returned from resolver using Django cache framework. You can it accepts **timeout** and **version** parameters and passed down. +**cache** decorator will cache a result returned from resolver using Django cache framework. You can it accepts same keyword arguments and passed down to Django cache. -Cache key must be either str or callable. Callable will receive same argument as resolver then return cache key. +Cache key must be either value or callable. Callable will receive same arguments as resolver then return cache key. Callable may return **None** to bypass the cache. + +It uses **typename** and **key** from **info.path** as cache key prefix. ```python from ariadne_django_ext import cache diff --git a/ariadne_django_ext/cache.py b/ariadne_django_ext/cache.py index 451d7f4..285fad6 100644 --- a/ariadne_django_ext/cache.py +++ b/ariadne_django_ext/cache.py @@ -1,22 +1,39 @@ -from functools import wraps +from functools import partial, wraps from typing import Callable, Union from django.core.cache import cache as dj_cache +default = "ariadne-django-ext" +delimiter = "::" -def cache(key: Union[str, Callable], **djkwargs): + +def cache_keygen(info, key): + return "{}{delimiter}{}{delimiter}{}".format( + info.path.typename, + info.path.key, + delimiter.join(str(k) for k in key) + if isinstance(key, tuple) or isinstance(key, list) + else key, + delimiter=delimiter, + ) + + +def cache(key: Union[str, Callable], **cache_kwargs): def wrap_resolver(resolver: Callable): @wraps(resolver) - def wrapper(*args, **kwargs): - cache_key = key(*args, **kwargs) if callable(key) else key - cached = dj_cache.get(cache_key) - if cached is not None: - return cached - - result = resolver(*args, **kwargs) - if result is not None: - dj_cache.add(cache_key, result, **djkwargs) - return result + def wrapper(parent, info, **kwargs): + resolve = partial(resolver, parent, info, **kwargs) + cache_key = key(parent, info, **kwargs) if callable(key) else key + if cache_key is not None: + cache_key = cache_keygen(info, cache_key) + cached = dj_cache.get(cache_key, default, **cache_kwargs) + if cached != default: + return cached + + result = resolve() + dj_cache.add(cache_key, result, **cache_kwargs) + return result + return resolve() return wrapper diff --git a/ariadne_django_ext/directives.py b/ariadne_django_ext/directives.py index fac38e2..1650ba5 100644 --- a/ariadne_django_ext/directives.py +++ b/ariadne_django_ext/directives.py @@ -1,6 +1,6 @@ from ariadne import SchemaDirectiveVisitor -from graphql import default_field_resolver from django.core.exceptions import PermissionDenied +from graphql import default_field_resolver class IsAuthenticatedDirective(SchemaDirectiveVisitor): diff --git a/noxfile.py b/noxfile.py index 3a39e1b..543ae2a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -7,7 +7,7 @@ def lint(session): session.run("black", "--check", ".") session.run("flake8", ".") - session.run("isort", "-q", ".") + session.run("isort", "-c", "-q", ".") @nox.session diff --git a/tests/test_cache.py b/tests/test_cache.py index 8e985c2..ff8f433 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,58 +1,85 @@ from unittest.mock import Mock from django.core.cache import cache as dj_cache -from django.db import models -from ariadne_django_ext import cache +from ariadne_django_ext.cache import cache, cache_keygen, delimiter -class MyModel(models.Model): - name = models.CharField(max_length=10) +class GraphQLInfo: + def __init__(self, typename, key): + self.path = GraphQLInfoPath(typename, key) - class Meta: - app_label = "tests" +class GraphQLInfoPath: + def __init__(self, typename, key): + self.key = key + self.typename = typename -def test_cache_key_callable(): - return_value = "result" - resolver = Mock(return_value=return_value) - cached_resolver = cache(key=lambda parent, _: parent)(resolver) - for i in range(5): - assert cached_resolver(i, None) == return_value - assert cached_resolver(i, False) == return_value - assert cached_resolver(i, True) == return_value - assert dj_cache.get(i) == return_value +def test_cache_keygen(): + typename, key = "typename", "key" + prefix = "{}{delimiter}{}".format(typename, key, delimiter=delimiter) + info = GraphQLInfo(typename, key) - resolver.assert_called_once() - resolver.assert_called_with(i, None) - resolver.reset_mock() + assert cache_keygen(info, 1) == "{}{delimiter}{}".format( + prefix, str(1), delimiter=delimiter + ) + assert cache_keygen(info, "1") == "{}{delimiter}{}".format( + prefix, "1", delimiter=delimiter + ) + assert cache_keygen(info, True) == "{}{delimiter}{}".format( + prefix, str(True), delimiter=delimiter + ) + assert cache_keygen(info, (1, 2)) == "{}{delimiter}{}{delimiter}{}".format( + prefix, str(1), str(2), delimiter=delimiter + ) + assert cache_keygen(info, [1, 2]) == "{}{delimiter}{}{delimiter}{}".format( + prefix, str(1), str(2), delimiter=delimiter + ) -def test_cache_key_str(): - key = "key_str" +def test_cache_key_callable(): + dj_cache.clear() return_value = "result" - resolver = Mock(return_value=return_value) - cached_resolver = cache(key=key)(resolver) + info = GraphQLInfo("typename", "key") + for key_callable in ( + lambda parent, _: parent, + lambda *_: "key", + lambda *_: True, + lambda *_: (1, 2, 3), + lambda *_: [1, 2, 4], + ): + resolver = Mock(return_value=return_value) + cached_resolver = cache(key=key_callable)(resolver) + assert cached_resolver(0, info) == return_value + assert cached_resolver(0, info) == return_value + resolver.assert_called_once() - assert cached_resolver(None, None) == return_value - assert cached_resolver(None, None) == cached_resolver(True, False) - assert cached_resolver(True, False) == cached_resolver(False, True) - assert dj_cache.get(key) == return_value - resolver.assert_called_once() +def test_cache_key_none(): + dj_cache.clear() + return_value = "result" + info = GraphQLInfo("typename", "key") + resolver = Mock(return_value=return_value) + for cache_key in (None, lambda *_: None): + cached_resolver = cache(key=cache_key)(resolver) + for parent in (None, False, True): + assert cached_resolver(parent, info) == return_value + resolver.assert_called_with(parent, info) -def test_cache_result_none(): - # None value won't be cached - key = "result_none" - resolver = Mock(return_value=None) - cached_resolver = cache(key=key)(resolver) +def test_cache_key_value(): + dj_cache.clear() + info = GraphQLInfo("typename", "key") + return_value = "result" - assert cached_resolver(None, None) is None - assert cached_resolver(None, None) is None - assert dj_cache.get(key, True) is True + for cache_key in ("key", 1, True, False, (1, 2, 3), [1, 2, 4]): + resolver = Mock(return_value=return_value) + cached_resolver = cache(key=cache_key)(resolver) + assert cached_resolver(0, info) == return_value + assert cached_resolver(1, info) == return_value - resolver.assert_called() + resolver.assert_called_once() + resolver.assert_called_with(0, info) diff --git a/tests/test_directives.py b/tests/test_directives.py new file mode 100644 index 0000000..e69de29