Skip to content

Commit

Permalink
Merge pull request #7 from sinasezza/sinasezza
Browse files Browse the repository at this point in the history
feat(route/chat, route/auth): token refresh added to auth and fetchin…
  • Loading branch information
sinasezza authored Jul 29, 2024
2 parents e6f6850 + 8bd0371 commit d4a3e96
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 94 deletions.
66 changes: 54 additions & 12 deletions chatApp/config/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping
from datetime import UTC, datetime, timedelta
from datetime import datetime, timedelta
from typing import Any

from fastapi import Depends
Expand All @@ -26,6 +26,7 @@
SECRET_KEY = settings.jwt_secret_key.get_secret_value()
ALGORITHM = settings.jwt_algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days


def verify_password(plain_password: str, hashed_password: str) -> bool:
Expand All @@ -49,30 +50,39 @@ def get_password_hash(password: str) -> str:
return pwd_context.hash(password)


def create_access_token(
data: dict[str, Any], expires_delta: timedelta | None = None
def create_token(
data: dict[str, Any],
token_type: str,
expires_delta: timedelta | None = None,
) -> str:
"""
Create a JWT access token with a specified expiration.
Create a JWT token with a specified expiration.
:param data: The data to encode into the token.
:param token_type: The type of token to create
:param expires_delta: Optional expiration time delta for the token.
:return: The encoded JWT token as a string.
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(UTC) + expires_delta
expire = datetime.now() + expires_delta
else:
expire = datetime.now(UTC) + timedelta(
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
)
match token_type:
case "access":
expire = datetime.now() + timedelta(
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
)
case "refresh":
expire = datetime.now() + timedelta(
days=REFRESH_TOKEN_EXPIRE_DAYS
)
to_encode.update({"exp": expire})

encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


def parse_access_token(token: str) -> dict[str, Any]:
def parse_token(token: str) -> dict[str, Any]:
"""
Parse and validate the given JWT token, returning its payload.
Expand All @@ -81,13 +91,45 @@ def parse_access_token(token: str) -> dict[str, Any]:
:raises credentials_exception: If the token is invalid or cannot be decoded.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
payload = jwt.decode(
token,
SECRET_KEY,
algorithms=[ALGORITHM],
options={"verify_signature": False},
)
return payload
except JWTError as e:
logger.error(f"JWT error: {e}") # Log the error for debugging purposes
raise credentials_exception


def validate_token(token: str) -> bool:
"""
Validate the given JWT token by checking its expiration.
:param token: The JWT token to validate.
:return: True if the token is valid and not expired, otherwise False.
"""
try:
# Decode the token without validating the signature to check expiration
payload = jwt.decode(
token,
SECRET_KEY,
algorithms=[ALGORITHM],
options={"verify_signature": False},
)
# Check if the token is expired
if (
payload.get("exp")
and datetime.fromtimestamp(payload["exp"]) < datetime.now()
):
return False
return True
except JWTError as e:
logger.error(f"JWT error: {e}")
return False


async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB:
"""
Retrieve the current user from the database using the provided JWT token.
Expand All @@ -97,8 +139,8 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB:
:raises credentials_exception: If the user cannot be found or the token is invalid.
"""
# Parse the token to get the payload
payload = parse_access_token(token)
username: str | None = payload.get("sub")
payload = parse_token(token)
username: str | None = payload.get("username")

if username is None:
logger.error("Username is missing in the token payload.")
Expand Down
1 change: 1 addition & 0 deletions chatApp/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Settings(BaseSettings):
jwt_secret_key: SecretStr = Field(default="your-secret-key")
jwt_algorithm: str = Field(default="HS256")
access_token_expire_minutes: int = Field(default=1440)
refresh_token_expire_days: int = Field(default=14)

# CORS settings
cors_allow_origins: list[str] = Field(default=["*"])
Expand Down
2 changes: 1 addition & 1 deletion chatApp/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class Message(BaseModel):
user_id: PydanticObjectId
room_id: str | None = Field(default=None)
room_id: PydanticObjectId
content: str = Field(default=None)
media: str = Field(default=None)
created_at: datetime = Field(default_factory=lambda: datetime.now())
Expand Down
21 changes: 21 additions & 0 deletions chatApp/models/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field

from chatApp.config.database import get_users_collection
from chatApp.utils.object_id import PydanticObjectId


Expand All @@ -18,3 +21,21 @@ class User(BaseModel):

class UserInDB(User):
id: PydanticObjectId = Field(alias="_id", serialization_alias="id")


async def fetch_user_by_username(username: str) -> Mapping[str, Any] | None:
"""Fetch a user from the database by username."""
users_collection = get_users_collection()
return await users_collection.find_one({"username": username})


async def fetch_user_by_id(user_id: str) -> Mapping[str, Any] | None:
"""Fetch a user from the database by user ID."""
users_collection = get_users_collection()
return await users_collection.find_one({"_id": PydanticObjectId(user_id)})


async def fetch_user_by_email(email: str) -> Mapping[str, Any] | None:
"""Fetch a user from the database by email."""
users_collection = get_users_collection()
return await users_collection.find_one({"email": email})
108 changes: 88 additions & 20 deletions chatApp/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from chatApp.config import auth
from chatApp.config.database import get_users_collection
from chatApp.models.user import User, UserInDB
from chatApp.models.user import (
User,
UserInDB,
fetch_user_by_email,
fetch_user_by_id,
fetch_user_by_username,
)
from chatApp.schemas.user import UserCreateSchema
from chatApp.utils.exceptions import credentials_exception

Expand All @@ -16,52 +22,114 @@

@router.post("/register", response_model=User)
async def register_user(user: UserCreateSchema) -> UserInDB:
# Fetch the users_collection within the request scope
users_collection = get_users_collection()

# Check if the user already exists
existing_user: Mapping[str, Any] | None = await users_collection.find_one(
{"username": user.username}
)
existing_user = await fetch_user_by_username(user.username)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already registered",
)
existing_user = await fetch_user_by_email(user.email)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered",
)

# Hash the password before saving
hashed_password = auth.get_password_hash(user.password)
user_dict = user.model_dump(exclude={"password"})
user_dict["hashed_password"] = hashed_password

user_dict["created_at"] = datetime.now()

# Insert user into the database
await users_collection.insert_one(user_dict)
the_user = User(**user_dict)

# Construct and return a User instance from the inserted document
return UserInDB(**user_dict)
result = await users_collection.insert_one(
the_user.model_dump(by_alias=True)
)

return UserInDB(
**the_user.model_dump(by_alias=True), _id=result.inserted_id
)


@router.post("/token", response_model=dict)
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(),
) -> dict[str, str]:
# Attempt to authenticate the user using provided credentials
user = await auth.authenticate_user(form_data.username, form_data.password)

# Raise an exception if authentication fails
if not user:
raise credentials_exception

# Create an access token with a specific expiration time
access_token_expires = timedelta(minutes=auth.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = auth.create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
refresh_token_expires = timedelta(days=auth.REFRESH_TOKEN_EXPIRE_DAYS)
data_to_encode = {
"username": user.username,
"email": user.email,
"id": str(user.id),
}

access_token = auth.create_token(
data=data_to_encode,
token_type="access",
expires_delta=access_token_expires,
)

refresh_token = auth.create_token(
data=data_to_encode,
token_type="refresh",
expires_delta=refresh_token_expires,
)

# Return the generated token and its type
return {"access_token": access_token, "token_type": "bearer"}
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
}


@router.post("/token/refresh", response_model=dict)
async def refresh_token(token: str) -> dict[str, str]:
try:
payload: dict[str, Any] = auth.parse_token(token)
if not auth.validate_token(token):
raise credentials_exception

user_id: str = payload["id"]
user: Mapping[str, Any] | None = await fetch_user_by_id(user_id)
if user is None:
raise credentials_exception

access_token_expires = timedelta(
minutes=auth.ACCESS_TOKEN_EXPIRE_MINUTES
)
refresh_token_expires = timedelta(days=auth.REFRESH_TOKEN_EXPIRE_DAYS)
data_to_encode = {
"username": user["username"],
"email": user["email"],
"id": str(user["id"]),
}

new_access_token = auth.create_token(
data=data_to_encode,
token_type="access",
expires_delta=access_token_expires,
)

new_refresh_token = auth.create_token(
data=data_to_encode,
token_type="refresh",
expires_delta=refresh_token_expires,
)

return {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
"token_type": "bearer",
}

except HTTPException:
raise credentials_exception


@router.get("/users/me/", response_model=User)
Expand Down
Loading

0 comments on commit d4a3e96

Please sign in to comment.