From 0424038ea2fbfeff38ef44257e8cd85ce5de7101 Mon Sep 17 00:00:00 2001 From: sinasezza Date: Mon, 5 Aug 2024 19:46:42 +0330 Subject: [PATCH] ref(auth): users models and authentication codes refactored. --- chatApp/config/auth.py | 37 ++++++++++++------------------ chatApp/models/user.py | 33 ++++++++++++++++++++++++++ chatApp/routes/auth.py | 47 ++++++++++++-------------------------- chatApp/routes/user.py | 25 ++++++++------------ chatApp/schemas/user.py | 12 +++++++++- chatApp/utils/object_id.py | 3 --- 6 files changed, 82 insertions(+), 75 deletions(-) diff --git a/chatApp/config/auth.py b/chatApp/config/auth.py index ce849d1..8839eb4 100644 --- a/chatApp/config/auth.py +++ b/chatApp/config/auth.py @@ -1,17 +1,14 @@ -from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any from fastapi import Depends from fastapi.security import OAuth2PasswordBearer from jose import JWTError, jwt -from motor.motor_asyncio import AsyncIOMotorCollection from passlib.context import CryptContext from chatApp.config.config import get_settings -from chatApp.config.database import get_users_collection from chatApp.config.logs import logger -from chatApp.models.user import UserInDB +from chatApp.models import user as user_model from chatApp.utils.exceptions import credentials_exception settings = get_settings() @@ -130,7 +127,9 @@ def validate_token(token: str) -> bool: return False -async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB: +async def get_current_user( + token: str = Depends(oauth2_scheme), +) -> user_model.UserInDB: """ Retrieve the current user from the database using the provided JWT token. @@ -146,12 +145,8 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB: logger.error("Username is missing in the token payload.") raise credentials_exception - # Fetch the users_collection within the request scope - users_collection: AsyncIOMotorCollection = get_users_collection() - - # Properly type the result of the find_one query - user: Mapping[str, Any] | None = await users_collection.find_one( - {"username": username} + user: user_model.UserInDB | None = await user_model.fetch_user_by_username( + username ) # Raise an exception if no user was found @@ -159,22 +154,18 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB: logger.error(f"User with username {username} not found in database.") raise credentials_exception - # Construct and return a User instance from the found document - return UserInDB(**user) - + return user -async def authenticate_user(username: str, password: str) -> UserInDB | None: - # Fetch the users_collection within the request scope - users_collection: AsyncIOMotorCollection = get_users_collection() - # Properly type the result of the find_one query - user: Mapping[str, Any] | None = await users_collection.find_one( - {"username": username} +async def authenticate_user( + username: str, password: str +) -> user_model.UserInDB | None: + user: user_model.UserInDB | None = await user_model.fetch_user_by_username( + username ) # Return None if no user was found or if password verification fails - if user is None or not verify_password(password, user["hashed_password"]): + if user is None or not verify_password(password, user.hashed_password): return None - # Construct and return a User instance from the found document - return UserInDB(**user) + return user diff --git a/chatApp/models/user.py b/chatApp/models/user.py index 0f6c826..4ec84ba 100644 --- a/chatApp/models/user.py +++ b/chatApp/models/user.py @@ -1,7 +1,11 @@ +from collections.abc import Mapping from datetime import datetime +from typing import Any +from motor.motor_asyncio import AsyncIOMotorCursor from pydantic import BaseModel, Field +from chatApp.config import auth from chatApp.config.database import get_users_collection from chatApp.utils.object_id import PydanticObjectId @@ -21,6 +25,18 @@ class UserInDB(User): id: PydanticObjectId = Field(alias="_id", serialization_alias="id") +async def get_all_users() -> list[Mapping[str, Any]]: + users_collection = get_users_collection() + + # Query to get users with necessary fields projected + cursor: AsyncIOMotorCursor = users_collection.find( + {}, {"_id": 1, "username": 1, "created_at": 1} + ) + + # Collect all users into a list of UserInDB objects + return await cursor.to_list(length=None) + + async def fetch_user_by_username(username: str) -> UserInDB | None: """Fetch a user from the database by username.""" users_collection = get_users_collection() @@ -40,3 +56,20 @@ async def fetch_user_by_email(email: str) -> UserInDB | None: users_collection = get_users_collection() user = await users_collection.find_one({"email": email}) return UserInDB(**user) if user else None + + +async def create_user(user_dict: dict[str, Any]) -> UserInDB: + """Create a new user in the database.""" + users_collection = get_users_collection() + + user_dict["created_at"] = datetime.now() + user_dict["updated_at"] = datetime.now() + user_dict["last_login"] = datetime.now() + user_dict["hashed_password"] = auth.get_password_hash( + user_dict["password"] + ) + + result = await users_collection.insert_one(user_dict) + user_dict["_id"] = str(result.inserted_id) + + return UserInDB(**user_dict) diff --git a/chatApp/routes/auth.py b/chatApp/routes/auth.py index 6017b3f..e00e937 100644 --- a/chatApp/routes/auth.py +++ b/chatApp/routes/auth.py @@ -1,55 +1,36 @@ -from datetime import datetime, timedelta +from datetime import timedelta from typing import Any from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from chatApp.config import auth -from chatApp.config.database import get_users_collection -from chatApp.models.user import ( - User, - UserInDB, - fetch_user_by_email, - fetch_user_by_id, - fetch_user_by_username, -) +from chatApp.models import user as user_model from chatApp.schemas.user import UserCreateSchema from chatApp.utils.exceptions import credentials_exception router = APIRouter() -@router.post("/register", response_model=User) -async def register_user(user: UserCreateSchema) -> UserInDB: - users_collection = get_users_collection() - - existing_user = await fetch_user_by_username(user.username) +@router.post("/register", response_model=user_model.UserInDB) +async def register_user(user_info: UserCreateSchema) -> user_model.UserInDB: + existing_user = await user_model.fetch_user_by_username(user_info.username) if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Username already registered", ) - existing_user = await fetch_user_by_email(user.email) + existing_user = await user_model.fetch_user_by_email(user_info.email) if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered", ) - hashed_password = auth.get_password_hash(user.password) - user_dict = user.model_dump(exclude={"password"}) - user_dict["hashed_password"] = hashed_password - user_dict["created_at"] = datetime.now() - - the_user = User(**user_dict) - - result = await users_collection.insert_one( - the_user.model_dump(by_alias=True) - ) + user_dict = user_info.model_dump() + user = await user_model.create_user(user_dict) - return UserInDB( - **the_user.model_dump(by_alias=True), _id=result.inserted_id - ) + return user @router.post("/token", response_model=dict) @@ -95,7 +76,9 @@ async def refresh_token(token: str) -> dict[str, str]: raise credentials_exception user_id: str = payload["id"] - user: UserInDB | None = await fetch_user_by_id(user_id) + user: user_model.UserInDB | None = await user_model.fetch_user_by_id( + user_id + ) if user is None: raise credentials_exception @@ -131,8 +114,8 @@ async def refresh_token(token: str) -> dict[str, str]: raise credentials_exception -@router.get("/users/me/", response_model=User) +@router.get("/users/me/", response_model=user_model.UserInDB) async def read_users_me( - current_user: UserInDB = Depends(auth.get_current_user), -) -> UserInDB: + current_user: user_model.UserInDB = Depends(auth.get_current_user), +) -> user_model.UserInDB: return current_user diff --git a/chatApp/routes/user.py b/chatApp/routes/user.py index a59fa86..ca78894 100644 --- a/chatApp/routes/user.py +++ b/chatApp/routes/user.py @@ -2,25 +2,18 @@ from typing import Any from fastapi import APIRouter -from motor.motor_asyncio import AsyncIOMotorCursor -from chatApp.config.database import get_users_collection -from chatApp.models.user import User +from chatApp.models import user as user_model +from chatApp.schemas.user import UserListSchema router = APIRouter() -@router.get("/", response_model=list[User]) -async def get_users() -> list[User]: - users_collection = get_users_collection() +@router.get("/", response_model=Mapping[str, Any]) +async def get_all_users(): + users: list[Mapping[str, Any]] = await user_model.get_all_users() - # Perform the query to get an async cursor - cursor: AsyncIOMotorCursor = users_collection.find() - - # Collect all users into a list of dictionaries - users_dicts: list[Mapping[str, Any]] = await cursor.to_list(length=None) - - # Convert each dictionary to a User object - users: list[User] = [User(**user_dict) for user_dict in users_dicts] - - return users + return { + "users": [UserListSchema(**user) for user in users], + "count": len(users), + } diff --git a/chatApp/schemas/user.py b/chatApp/schemas/user.py index 1a9c0c8..d55fc7c 100644 --- a/chatApp/schemas/user.py +++ b/chatApp/schemas/user.py @@ -1,7 +1,17 @@ -from pydantic import BaseModel +from datetime import datetime + +from pydantic import BaseModel, Field + +from chatApp.utils.object_id import PydanticObjectId class UserCreateSchema(BaseModel): username: str email: str password: str + + +class UserListSchema(BaseModel): + id: PydanticObjectId = Field(alias="_id", serialization_alias="id") + username: str + created_at: datetime = datetime.now() diff --git a/chatApp/utils/object_id.py b/chatApp/utils/object_id.py index 69a2037..b9831f4 100644 --- a/chatApp/utils/object_id.py +++ b/chatApp/utils/object_id.py @@ -6,8 +6,6 @@ class _ObjectIdPydanticAnnotation: - # Based on https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types. - @classmethod def __get_pydantic_core_schema__( cls, @@ -19,7 +17,6 @@ def validate_from_str(input_value: str) -> ObjectId: return core_schema.union_schema( [ - # check if it's an instance first before doing any further work core_schema.is_instance_schema(ObjectId), core_schema.no_info_plain_validator_function( validate_from_str