Skip to content

Commit

Permalink
feat: auth required decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
alexogeny committed Aug 18, 2024
1 parent d4286af commit f406c96
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 1 deletion.
54 changes: 54 additions & 0 deletions src/zara/login/auth_required.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from functools import wraps
from http import HTTPStatus
from typing import Any, Awaitable, Callable, Dict, List, Optional

from zara.login.jwt import verify_jwt
from zara.types.http import Http


def auth_required(
permissions: Optional[List[str]] = None,
roles: Optional[List[str]] = None,
):
def decorator(func: Callable[..., Awaitable[None]]):
@wraps(func)
async def wrapper(scope: Dict[str, Any], receive: Callable, send: Callable):
headers = dict(scope.get("headers", []))
authorization = headers.get(b"authorization", b"").decode()

if not authorization.startswith("Bearer "):
await send(Http.Response.Start(status=HTTPStatus.FORBIDDEN))
await send(
Http.Response.Detail(
message="Authorization header missing or malformed."
)
)
return

token = authorization[7:]
jwt_payload = verify_jwt(token)
if jwt_payload is None:
await send(Http.Response.Start(status=HTTPStatus.FORBIDDEN))
await send(Http.Response.Detail(message="Invalid token"))
return

user_permissions = jwt_payload.get("permissions", [])
user_roles = jwt_payload.get("roles", [])

if permissions:
if not all(p in user_permissions for p in permissions):
await send(Http.Response.Start(status=HTTPStatus.FORBIDDEN))
await send(Http.Response.Detail(message="Insufficient permissions"))
return

if roles:
if not any(r in user_roles for r in roles):
await send(Http.Response.Start(status=HTTPStatus.FORBIDDEN))
await send(Http.Response.Detail(message="Insufficient roles"))
return

await func(scope, receive, send)

return wrapper

return decorator
12 changes: 11 additions & 1 deletion src/zara/server/router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import http
from typing import Any, Awaitable, Callable, Dict, Self, Union

from zara.login.auth_required import auth_required
from zara.types.asgi import ASGI, CallableAwaitable
from zara.utils import camel_to_snake

Expand Down Expand Up @@ -54,8 +55,17 @@ def add_route(
route = Route(router, path, method, handler, **kwargs)
self.routes.append(route)

def route(self, path: str, method: str, **kwargs: dict[Any, Any]) -> WrappedRoute:
def route(
self,
path: str,
method: str,
permissions=None,
roles=None,
**kwargs: dict[Any, Any],
) -> WrappedRoute:
def decorator(func: CallableAwaitable) -> CallableAwaitable:
if permissions or roles:
func = auth_required(permissions=permissions, roles=roles)(func)
self.add_route(self, path, method, func, **kwargs)
return func

Expand Down
8 changes: 8 additions & 0 deletions src/zara/types/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

class Response:
def Start(status: int = HTTPStatus.OK, headers=[(b"content-type", b"text/plain")]):
if status in [HTTPStatus.FORBIDDEN]:
headers = [(b"content-type", b"application/json")]
return {
"type": "http.response.start",
"status": status,
Expand All @@ -20,6 +22,12 @@ def Body(content: bytes):
def Error(status: HTTPStatus):
return {"detail": status.phrase}

def Detail(message: str):
return {
"type": "http.response.body",
"body": b'{"detail": "' + message.encode("utf-8") + b'"}',
}


class Http:
Response = Response
Expand Down
Empty file.

0 comments on commit f406c96

Please sign in to comment.