Skip to content

Commit

Permalink
ref(auth): users models and authentication codes refactored.
Browse files Browse the repository at this point in the history
  • Loading branch information
sinasezza committed Aug 5, 2024
1 parent 6bc448c commit 0424038
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 75 deletions.
37 changes: 14 additions & 23 deletions chatApp/config/auth.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Any

from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from motor.motor_asyncio import AsyncIOMotorCollection
from passlib.context import CryptContext

from chatApp.config.config import get_settings
from chatApp.config.database import get_users_collection
from chatApp.config.logs import logger
from chatApp.models.user import UserInDB
from chatApp.models import user as user_model
from chatApp.utils.exceptions import credentials_exception

settings = get_settings()
Expand Down Expand Up @@ -130,7 +127,9 @@ def validate_token(token: str) -> bool:
return False


async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB:
async def get_current_user(
token: str = Depends(oauth2_scheme),
) -> user_model.UserInDB:
"""
Retrieve the current user from the database using the provided JWT token.
Expand All @@ -146,35 +145,27 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB:
logger.error("Username is missing in the token payload.")
raise credentials_exception

# Fetch the users_collection within the request scope
users_collection: AsyncIOMotorCollection = get_users_collection()

# Properly type the result of the find_one query
user: Mapping[str, Any] | None = await users_collection.find_one(
{"username": username}
user: user_model.UserInDB | None = await user_model.fetch_user_by_username(
username
)

# Raise an exception if no user was found
if user is None:
logger.error(f"User with username {username} not found in database.")
raise credentials_exception

# Construct and return a User instance from the found document
return UserInDB(**user)

return user

async def authenticate_user(username: str, password: str) -> UserInDB | None:
# Fetch the users_collection within the request scope
users_collection: AsyncIOMotorCollection = get_users_collection()

# Properly type the result of the find_one query
user: Mapping[str, Any] | None = await users_collection.find_one(
{"username": username}
async def authenticate_user(
username: str, password: str
) -> user_model.UserInDB | None:
user: user_model.UserInDB | None = await user_model.fetch_user_by_username(
username
)

# Return None if no user was found or if password verification fails
if user is None or not verify_password(password, user["hashed_password"]):
if user is None or not verify_password(password, user.hashed_password):
return None

# Construct and return a User instance from the found document
return UserInDB(**user)
return user
33 changes: 33 additions & 0 deletions chatApp/models/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any

from motor.motor_asyncio import AsyncIOMotorCursor
from pydantic import BaseModel, Field

from chatApp.config import auth
from chatApp.config.database import get_users_collection
from chatApp.utils.object_id import PydanticObjectId

Expand All @@ -21,6 +25,18 @@ class UserInDB(User):
id: PydanticObjectId = Field(alias="_id", serialization_alias="id")


async def get_all_users() -> list[Mapping[str, Any]]:
users_collection = get_users_collection()

# Query to get users with necessary fields projected
cursor: AsyncIOMotorCursor = users_collection.find(
{}, {"_id": 1, "username": 1, "created_at": 1}
)

# Collect all users into a list of UserInDB objects
return await cursor.to_list(length=None)


async def fetch_user_by_username(username: str) -> UserInDB | None:
"""Fetch a user from the database by username."""
users_collection = get_users_collection()
Expand All @@ -40,3 +56,20 @@ async def fetch_user_by_email(email: str) -> UserInDB | None:
users_collection = get_users_collection()
user = await users_collection.find_one({"email": email})
return UserInDB(**user) if user else None


async def create_user(user_dict: dict[str, Any]) -> UserInDB:
"""Create a new user in the database."""
users_collection = get_users_collection()

user_dict["created_at"] = datetime.now()
user_dict["updated_at"] = datetime.now()
user_dict["last_login"] = datetime.now()
user_dict["hashed_password"] = auth.get_password_hash(
user_dict["password"]
)

result = await users_collection.insert_one(user_dict)
user_dict["_id"] = str(result.inserted_id)

return UserInDB(**user_dict)
47 changes: 15 additions & 32 deletions chatApp/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,36 @@
from datetime import datetime, timedelta
from datetime import timedelta
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm

from chatApp.config import auth
from chatApp.config.database import get_users_collection
from chatApp.models.user import (
User,
UserInDB,
fetch_user_by_email,
fetch_user_by_id,
fetch_user_by_username,
)
from chatApp.models import user as user_model
from chatApp.schemas.user import UserCreateSchema
from chatApp.utils.exceptions import credentials_exception

router = APIRouter()


@router.post("/register", response_model=User)
async def register_user(user: UserCreateSchema) -> UserInDB:
users_collection = get_users_collection()

existing_user = await fetch_user_by_username(user.username)
@router.post("/register", response_model=user_model.UserInDB)
async def register_user(user_info: UserCreateSchema) -> user_model.UserInDB:
existing_user = await user_model.fetch_user_by_username(user_info.username)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already registered",
)
existing_user = await fetch_user_by_email(user.email)
existing_user = await user_model.fetch_user_by_email(user_info.email)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered",
)

hashed_password = auth.get_password_hash(user.password)
user_dict = user.model_dump(exclude={"password"})
user_dict["hashed_password"] = hashed_password
user_dict["created_at"] = datetime.now()

the_user = User(**user_dict)

result = await users_collection.insert_one(
the_user.model_dump(by_alias=True)
)
user_dict = user_info.model_dump()
user = await user_model.create_user(user_dict)

return UserInDB(
**the_user.model_dump(by_alias=True), _id=result.inserted_id
)
return user


@router.post("/token", response_model=dict)
Expand Down Expand Up @@ -95,7 +76,9 @@ async def refresh_token(token: str) -> dict[str, str]:
raise credentials_exception

user_id: str = payload["id"]
user: UserInDB | None = await fetch_user_by_id(user_id)
user: user_model.UserInDB | None = await user_model.fetch_user_by_id(
user_id
)
if user is None:
raise credentials_exception

Expand Down Expand Up @@ -131,8 +114,8 @@ async def refresh_token(token: str) -> dict[str, str]:
raise credentials_exception


@router.get("/users/me/", response_model=User)
@router.get("/users/me/", response_model=user_model.UserInDB)
async def read_users_me(
current_user: UserInDB = Depends(auth.get_current_user),
) -> UserInDB:
current_user: user_model.UserInDB = Depends(auth.get_current_user),
) -> user_model.UserInDB:
return current_user
25 changes: 9 additions & 16 deletions chatApp/routes/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,18 @@
from typing import Any

from fastapi import APIRouter
from motor.motor_asyncio import AsyncIOMotorCursor

from chatApp.config.database import get_users_collection
from chatApp.models.user import User
from chatApp.models import user as user_model
from chatApp.schemas.user import UserListSchema

router = APIRouter()


@router.get("/", response_model=list[User])
async def get_users() -> list[User]:
users_collection = get_users_collection()
@router.get("/", response_model=Mapping[str, Any])
async def get_all_users():
users: list[Mapping[str, Any]] = await user_model.get_all_users()

# Perform the query to get an async cursor
cursor: AsyncIOMotorCursor = users_collection.find()

# Collect all users into a list of dictionaries
users_dicts: list[Mapping[str, Any]] = await cursor.to_list(length=None)

# Convert each dictionary to a User object
users: list[User] = [User(**user_dict) for user_dict in users_dicts]

return users
return {
"users": [UserListSchema(**user) for user in users],
"count": len(users),
}
12 changes: 11 additions & 1 deletion chatApp/schemas/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from pydantic import BaseModel
from datetime import datetime

from pydantic import BaseModel, Field

from chatApp.utils.object_id import PydanticObjectId


class UserCreateSchema(BaseModel):
username: str
email: str
password: str


class UserListSchema(BaseModel):
id: PydanticObjectId = Field(alias="_id", serialization_alias="id")
username: str
created_at: datetime = datetime.now()
3 changes: 0 additions & 3 deletions chatApp/utils/object_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@


class _ObjectIdPydanticAnnotation:
# Based on https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types.

@classmethod
def __get_pydantic_core_schema__(
cls,
Expand All @@ -19,7 +17,6 @@ def validate_from_str(input_value: str) -> ObjectId:

return core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(ObjectId),
core_schema.no_info_plain_validator_function(
validate_from_str
Expand Down

0 comments on commit 0424038

Please sign in to comment.