diff --git a/ariadne_django_ext/decorators.py b/ariadne_django_ext/decorators.py index 9f131ee..2272ec1 100644 --- a/ariadne_django_ext/decorators.py +++ b/ariadne_django_ext/decorators.py @@ -9,16 +9,14 @@ def allow_basic_auth(view_func): @wraps(view_func) def wrapper(request, *args, **kwargs): - if not is_authenticated(request): + if getattr(request, "user", None) is None: http_auth = request.META.get("HTTP_AUTHORIZATION") if http_auth and http_auth.startswith("Basic"): try: _, token = http_auth.split() username, password = b64decode(token).decode().split(":") - user = authenticate( - request=request, username=username, password=password - ) - if user and user.is_active: + user = authenticate(request, username=username, password=password) + if user: request.user = user except Exception: pass @@ -30,7 +28,7 @@ def wrapper(request, *args, **kwargs): def login_required(view_func): @wraps(view_func) def wrapper(request, *args, **kwargs): - if is_authenticated(request, is_active=True, raise_exception=True): + if is_authenticated(request): return view_func(request, *args, **kwargs) return wrapper diff --git a/ariadne_django_ext/utils.py b/ariadne_django_ext/utils.py index 62befd8..618748a 100644 --- a/ariadne_django_ext/utils.py +++ b/ariadne_django_ext/utils.py @@ -2,7 +2,7 @@ def is_authenticated(request, is_active=True, raise_exception=True): - user = getattr(request, "user") + user = getattr(request, "user", None) if user and user.is_authenticated and (not is_active or user.is_active): return user if raise_exception: diff --git a/tests/conftest.py b/tests/conftest.py index 36b35c7..9e2d318 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ import pytest +password = "something" + @pytest.fixture( ids=["anonymous", "not-active", "authenticated"], @@ -7,9 +9,12 @@ ) def user(request, django_user_model): if request.param is not None: - return django_user_model.objects.create( - username="someone", password="something", is_active=request.param + user = django_user_model.objects.create( + username="someone", is_active=request.param ) + user.set_password(password) + user.save(update_fields=["password"]) + return user @pytest.fixture diff --git a/tests/settings.py b/tests/settings.py index 14c4ed1..9126471 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,4 +1,4 @@ -SECRET_KEY = "asda" +SECRET_KEY = "SECRET_KEY" INSTALLED_APPS = ["django.contrib.auth", "django.contrib.contenttypes"] DATABASES = { "default": { @@ -6,3 +6,4 @@ "NAME": "test", } } +USE_TZ = True diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 6634345..d64e632 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -1,14 +1,40 @@ +from base64 import b64encode + import pytest from django.core.exceptions import PermissionDenied -from ariadne_django_ext.decorators import login_required +from ariadne_django_ext import decorators + +from .conftest import password + + +def test_allow_basic_auth(user, rf, django_user_model): + return_value = "return_value" + + @decorators.allow_basic_auth + @decorators.login_required + def view(_): + return return_value + + request = rf.get("/") + if user: + request.META["HTTP_AUTHORIZATION"] = "Basic {}".format( + b64encode(f"{user.username}:{password}".encode()).decode() + ) + + assert getattr(request, "user", None) is None + if user and user.is_active: + assert view(request) == return_value + else: + with pytest.raises(PermissionDenied): + view(request) def test_login_required(user_request): return_value = "return_value" - @login_required - def view(request): + @decorators.login_required + def view(_): return return_value user = getattr(user_request, "user", None) @@ -17,3 +43,11 @@ def view(request): else: with pytest.raises(PermissionDenied): view(user_request) + + +def test_wrap_result(): + @decorators.wrap_result(key="key") + def resolver(): + return "result" + + assert resolver() == {"key": "result"} diff --git a/tests/test_wrap_result.py b/tests/test_wrap_result.py deleted file mode 100644 index d4efd13..0000000 --- a/tests/test_wrap_result.py +++ /dev/null @@ -1,9 +0,0 @@ -from ariadne_django_ext import wrap_result - - -def test_wrap_result(): - @wrap_result(key="key") - def resolver(): - return "result" - - assert resolver() == {"key": "result"}