Skip to content

Commit

Permalink
Feature/refactor auth flow (#60)
Browse files Browse the repository at this point in the history
Redesign auth flow to use AbandonAuth UI and redirects to create a secure auth flow
  • Loading branch information
fisher60 authored Nov 5, 2023
1 parent cd358ff commit 29c260d
Show file tree
Hide file tree
Showing 20 changed files with 499 additions and 119 deletions.
9 changes: 3 additions & 6 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,10 @@ JWT_HASHING_ALGO=HS512
JWT_EXPIRES_IN_SECONDS_LONG_LIVED=2592000
JWT_EXPIRES_IN_SECONDS_SHORT_LIVED=120


ABANDON_AUTH_DISCORD_REDIRECT='https://discord.com/api/oauth2/authorize?client_id=<your_discord_client_id>&redirect_uri=http%3A%2F%2Flocalhost%3A8001/ui/discord-callback&response_type=code&scope=identify'
ABANDON_AUTH_DISCORD_CALLBACK='http://localhost:8001/ui/discord-callback'
DISCORD_CLIENT_ID=
ABANDON_AUTH_DISCORD_REDIRECT=
ABANDON_AUTH_DISCORD_CALLBACK='http://localhost:8000/ui/discord-callback'
ABANDON_AUTH_DEVELOPER_APP_ID=
DISCORD_CLIENT_SECRET=
DISCORD_CALLBACK=http://localhost:8000/discord


GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=
Expand Down
3 changes: 3 additions & 0 deletions abandonauth/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from prisma import Prisma

prisma_db = Prisma(auto_register=True)
131 changes: 95 additions & 36 deletions abandonauth/dependencies/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
from fastapi import HTTPException, Request
from fastapi.security import HTTPBearer
from jose import JWTError, jwt
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing import Any

from abandonauth.models.auth import JwtClaimsDataDto, ScopeEnum, LifespanEnum
from abandonauth.settings import settings

# Cache of all valid issued tokens. Tokens should be removed after their first use
valid_token_cache = set()

IGNORE_AUD_DECODE_OPTIONS = {"verify_aud": False}

def _generate_jwt(user_id: str, long_lived: bool = False) -> str:

def _generate_jwt(user_id: str, application_id_aud: str, long_lived: bool = False) -> str:
"""Generate an AbandonAuth long-lived or short-lived JWT for the given user.
Creates a JWT containing the user ID and expiration of the token.
application_id_aud is the verified developer application ID that will be consuming the token
long-lived = True should be used for user login sessions (i.e. website user or internal application login).
long-lived = False should be used for any token exchange (i.e. Discord OAuth login).
"""
Expand All @@ -26,12 +30,21 @@ def _generate_jwt(user_id: str, long_lived: bool = False) -> str:

expiration = datetime.now(timezone.utc) + timedelta(seconds=exp_seconds)

if long_lived and application_id_aud == settings.ABANDON_AUTH_DEVELOPER_APP_ID:
scope = " ".join((ScopeEnum.abandonauth, ScopeEnum.identify))
else:
scope = ScopeEnum.identify

claims = JwtClaimsDataDto(
user_id=user_id,
exp=expiration,
scope=scope,
aud=application_id_aud,
lifespan=LifespanEnum.long if long_lived else LifespanEnum.short
)

token = jwt.encode(
claims={
"user_id": user_id,
"exp": expiration,
"lifespan": "long" if long_lived else "short"
},
claims=dict(claims),
key=settings.JWT_SECRET.get_secret_value(),
algorithm=settings.JWT_HASHING_ALGO
)
Expand All @@ -41,32 +54,85 @@ def _generate_jwt(user_id: str, long_lived: bool = False) -> str:
return token


def generate_long_lived_jwt(user_id: str) -> str:
return _generate_jwt(user_id, long_lived=True)


def generate_short_lived_jwt(user_id: str) -> str:
def decode_jwt(
token: str,
aud: str | None = None,
required_scope: ScopeEnum = ScopeEnum.abandonauth
) -> JwtClaimsDataDto:
try:
if aud:
decode_kwargs = {
"audience": aud,
}
else:
decode_kwargs = {"options": IGNORE_AUD_DECODE_OPTIONS}

token_data = jwt.decode(
token,
settings.JWT_SECRET.get_secret_value(),
**decode_kwargs
)
except JWTError:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid token format"
)

if token_data["exp"] < datetime.utcnow().timestamp():
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Token has expired"
)

if required_scope != ScopeEnum.none and required_scope not in token_data["scope"]:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="JWT lacks the required scope to access this endpoint."
)

# If token is short-lived/exchange token check if it currently exists in the token cache
# This means short-lived tokens will only work with a single worker
# This is a hack and a future version will resolve this https://github.com/AbandonTech/abandonauth/issues/12
if token_data["lifespan"] == "short" and token not in valid_token_cache:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Token is not valid.")

return JwtClaimsDataDto(**dict(token_data))


def generate_long_lived_jwt(user_id: str, application_id_aud: str) -> str:
return _generate_jwt(user_id, application_id_aud, long_lived=True)


def generate_short_lived_jwt(user_id: str, application_id_aud: str) -> str:
"""Create a JWT token using the given user ID."""
return _generate_jwt(user_id, long_lived=False)
return _generate_jwt(user_id, application_id_aud, long_lived=False)


class JWTBearer(HTTPBearer):
"""Dependency for routes to enforce JWT auth."""

def __init__(self, **kwargs: Any) -> None:
def __init__(
self,
scope: ScopeEnum = ScopeEnum.abandonauth,
aud: str | None = None,
**kwargs: Any
) -> None:
super().__init__(**kwargs)

self.token_data: dict[str, Any] | None = None
self.token_data: dict[str, Any]
self.aud = aud
self.required_scope = scope

async def __call__(self, request: Request) -> str:
async def __call__(self, request: Request) -> JwtClaimsDataDto:
"""
Retrieve user from a jwt token provided in headers.
If no token is present, a 403 will be raised
If the token is invalid, a 403 will be raised
If the token has expired, a 400 will be raised
If the token has expired, a 403 will be raised
"""
credentials = await super().__call__(request)

if credentials is None:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
Expand All @@ -75,24 +141,17 @@ async def __call__(self, request: Request) -> str:

credentials_string = credentials.credentials

try:
self.token_data = jwt.decode(credentials_string, settings.JWT_SECRET.get_secret_value())
except JWTError:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid token format"
)

if self.token_data["exp"] < datetime.utcnow().timestamp():
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Token has expired"
)
return decode_jwt(credentials_string, self.aud, self.required_scope)

# If token is short-lived/exchange token check if it currently exists in the token cache
# This means short-lived tokens will only work with a single worker
# This is a hack and a future version will resolve this https://github.com/AbandonTech/abandonauth/issues/12
if self.token_data["lifespan"] == "short" and credentials_string not in valid_token_cache:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Token is not valid.")

return self.token_data["user_id"]
class DeveloperAppJwtBearer(JWTBearer):
"""JWTBearer class for authorizing developer application tokens"""
def __init__(
self,
**kwargs: Any
) -> None:
super().__init__(
scope=ScopeEnum.abandonauth,
aud=settings.ABANDON_AUTH_DEVELOPER_APP_ID,
**kwargs
)
14 changes: 7 additions & 7 deletions abandonauth/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os

from abandonauth.routers import routers
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from prisma import Prisma

from abandonauth.database import prisma_db
from abandonauth.routers import routers


app = FastAPI(
title="AbandonAuth",
Expand All @@ -15,17 +17,15 @@
for router in routers:
app.include_router(router)

prisma = Prisma(auto_register=True)


@app.on_event("startup")
async def startup() -> None:
"""On startup connect prisma to the database."""
await prisma.connect()
await prisma_db.connect()


@app.on_event("shutdown")
async def shutdown() -> None:
"""On shutdown disconnect prisma from the database."""
if prisma.is_connected():
await prisma.disconnect()
if prisma_db.is_connected():
await prisma_db.disconnect()
9 changes: 8 additions & 1 deletion abandonauth/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from .developer_application import CreateDeveloperApplicationDto, DeveloperApplicationDto, LoginDeveloperApplicationDto
from .developer_application import (
CallbackUriDto,
CreateCallbackUriDto,
CreateDeveloperApplicationDto,
DeveloperApplicationDto,
DeveloperApplicationWithCallbackUriDto,
LoginDeveloperApplicationDto
)
from .discord import DiscordLoginDto
from .auth import JwtDto
from .user import UserDto
24 changes: 24 additions & 0 deletions abandonauth/models/auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
from datetime import datetime
from enum import Enum

from pydantic import BaseModel


class JwtDto(BaseModel):
"""Contains jwt token data to be sent to a client."""

token: str


class ScopeEnum(str, Enum):
identify = "identify"
abandonauth = "abandonauth"
none = None


class LifespanEnum(str, Enum):
long = "long"
short = "short"


class JwtClaimsDataDto(BaseModel):
"""All claim data for an Abandon Auth JWT"""

user_id: str
exp: datetime
scope: str
aud: str
lifespan: LifespanEnum
17 changes: 17 additions & 0 deletions abandonauth/models/developer_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ class DeveloperApplicationDto(BaseModel):
owner_id: str


class DeveloperApplicationWithCallbackUriDto(DeveloperApplicationDto):
"""Basic data for developer applications as well as the Callback URIs for the app"""

callback_uris: list[str]


class CreateDeveloperApplicationDto(DeveloperApplicationDto):
"""Basic info for a developer application as well as the refresh token."""

Expand All @@ -19,3 +25,14 @@ class LoginDeveloperApplicationDto(BaseModel):

id: str
refresh_token: str


class CreateCallbackUriDto(BaseModel):
"""Data for creating a callback URI"""
developer_application_id: str
uri: str


class CallbackUriDto(CreateCallbackUriDto):
"""All data that should be displayed to a user for a callback URI"""
id: int
Loading

0 comments on commit 29c260d

Please sign in to comment.