diff --git a/chatApp/config/auth.py b/chatApp/config/auth.py index ccd6440..ce849d1 100644 --- a/chatApp/config/auth.py +++ b/chatApp/config/auth.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta from typing import Any from fastapi import Depends @@ -26,6 +26,7 @@ SECRET_KEY = settings.jwt_secret_key.get_secret_value() ALGORITHM = settings.jwt_algorithm ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes +REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days def verify_password(plain_password: str, hashed_password: str) -> bool: @@ -49,30 +50,39 @@ def get_password_hash(password: str) -> str: return pwd_context.hash(password) -def create_access_token( - data: dict[str, Any], expires_delta: timedelta | None = None +def create_token( + data: dict[str, Any], + token_type: str, + expires_delta: timedelta | None = None, ) -> str: """ - Create a JWT access token with a specified expiration. + Create a JWT token with a specified expiration. :param data: The data to encode into the token. + :param token_type: The type of token to create :param expires_delta: Optional expiration time delta for the token. :return: The encoded JWT token as a string. """ to_encode = data.copy() if expires_delta: - expire = datetime.now(UTC) + expires_delta + expire = datetime.now() + expires_delta else: - expire = datetime.now(UTC) + timedelta( - minutes=ACCESS_TOKEN_EXPIRE_MINUTES - ) + match token_type: + case "access": + expire = datetime.now() + timedelta( + minutes=ACCESS_TOKEN_EXPIRE_MINUTES + ) + case "refresh": + expire = datetime.now() + timedelta( + days=REFRESH_TOKEN_EXPIRE_DAYS + ) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt -def parse_access_token(token: str) -> dict[str, Any]: +def parse_token(token: str) -> dict[str, Any]: """ Parse and validate the given JWT token, returning its payload. @@ -81,13 +91,45 @@ def parse_access_token(token: str) -> dict[str, Any]: :raises credentials_exception: If the token is invalid or cannot be decoded. """ try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + payload = jwt.decode( + token, + SECRET_KEY, + algorithms=[ALGORITHM], + options={"verify_signature": False}, + ) return payload except JWTError as e: logger.error(f"JWT error: {e}") # Log the error for debugging purposes raise credentials_exception +def validate_token(token: str) -> bool: + """ + Validate the given JWT token by checking its expiration. + + :param token: The JWT token to validate. + :return: True if the token is valid and not expired, otherwise False. + """ + try: + # Decode the token without validating the signature to check expiration + payload = jwt.decode( + token, + SECRET_KEY, + algorithms=[ALGORITHM], + options={"verify_signature": False}, + ) + # Check if the token is expired + if ( + payload.get("exp") + and datetime.fromtimestamp(payload["exp"]) < datetime.now() + ): + return False + return True + except JWTError as e: + logger.error(f"JWT error: {e}") + return False + + async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB: """ Retrieve the current user from the database using the provided JWT token. @@ -97,8 +139,8 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB: :raises credentials_exception: If the user cannot be found or the token is invalid. """ # Parse the token to get the payload - payload = parse_access_token(token) - username: str | None = payload.get("sub") + payload = parse_token(token) + username: str | None = payload.get("username") if username is None: logger.error("Username is missing in the token payload.") diff --git a/chatApp/config/config.py b/chatApp/config/config.py index 069c30b..4891c7e 100644 --- a/chatApp/config/config.py +++ b/chatApp/config/config.py @@ -22,6 +22,7 @@ class Settings(BaseSettings): jwt_secret_key: SecretStr = Field(default="your-secret-key") jwt_algorithm: str = Field(default="HS256") access_token_expire_minutes: int = Field(default=1440) + refresh_token_expire_days: int = Field(default=14) # CORS settings cors_allow_origins: list[str] = Field(default=["*"]) diff --git a/chatApp/models/message.py b/chatApp/models/message.py index 88f7acc..6691fdc 100644 --- a/chatApp/models/message.py +++ b/chatApp/models/message.py @@ -7,7 +7,7 @@ class Message(BaseModel): user_id: PydanticObjectId - room_id: str | None = Field(default=None) + room_id: PydanticObjectId content: str = Field(default=None) media: str = Field(default=None) created_at: datetime = Field(default_factory=lambda: datetime.now()) diff --git a/chatApp/models/user.py b/chatApp/models/user.py index 01cda67..e26f046 100644 --- a/chatApp/models/user.py +++ b/chatApp/models/user.py @@ -1,7 +1,10 @@ +from collections.abc import Mapping from datetime import datetime +from typing import Any from pydantic import BaseModel, Field +from chatApp.config.database import get_users_collection from chatApp.utils.object_id import PydanticObjectId @@ -18,3 +21,21 @@ class User(BaseModel): class UserInDB(User): id: PydanticObjectId = Field(alias="_id", serialization_alias="id") + + +async def fetch_user_by_username(username: str) -> Mapping[str, Any] | None: + """Fetch a user from the database by username.""" + users_collection = get_users_collection() + return await users_collection.find_one({"username": username}) + + +async def fetch_user_by_id(user_id: str) -> Mapping[str, Any] | None: + """Fetch a user from the database by user ID.""" + users_collection = get_users_collection() + return await users_collection.find_one({"_id": PydanticObjectId(user_id)}) + + +async def fetch_user_by_email(email: str) -> Mapping[str, Any] | None: + """Fetch a user from the database by email.""" + users_collection = get_users_collection() + return await users_collection.find_one({"email": email}) diff --git a/chatApp/routes/auth.py b/chatApp/routes/auth.py index b9dde51..1d107b3 100644 --- a/chatApp/routes/auth.py +++ b/chatApp/routes/auth.py @@ -7,7 +7,13 @@ from chatApp.config import auth from chatApp.config.database import get_users_collection -from chatApp.models.user import User, UserInDB +from chatApp.models.user import ( + User, + UserInDB, + fetch_user_by_email, + fetch_user_by_id, + fetch_user_by_username, +) from chatApp.schemas.user import UserCreateSchema from chatApp.utils.exceptions import credentials_exception @@ -16,52 +22,114 @@ @router.post("/register", response_model=User) async def register_user(user: UserCreateSchema) -> UserInDB: - # Fetch the users_collection within the request scope users_collection = get_users_collection() - # Check if the user already exists - existing_user: Mapping[str, Any] | None = await users_collection.find_one( - {"username": user.username} - ) + existing_user = await fetch_user_by_username(user.username) if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Username already registered", ) + existing_user = await fetch_user_by_email(user.email) + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already registered", + ) - # Hash the password before saving hashed_password = auth.get_password_hash(user.password) user_dict = user.model_dump(exclude={"password"}) user_dict["hashed_password"] = hashed_password - user_dict["created_at"] = datetime.now() - # Insert user into the database - await users_collection.insert_one(user_dict) + the_user = User(**user_dict) - # Construct and return a User instance from the inserted document - return UserInDB(**user_dict) + result = await users_collection.insert_one( + the_user.model_dump(by_alias=True) + ) + + return UserInDB( + **the_user.model_dump(by_alias=True), _id=result.inserted_id + ) @router.post("/token", response_model=dict) async def login_for_access_token( form_data: OAuth2PasswordRequestForm = Depends(), ) -> dict[str, str]: - # Attempt to authenticate the user using provided credentials user = await auth.authenticate_user(form_data.username, form_data.password) - - # Raise an exception if authentication fails if not user: raise credentials_exception - # Create an access token with a specific expiration time access_token_expires = timedelta(minutes=auth.ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = auth.create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires + refresh_token_expires = timedelta(days=auth.REFRESH_TOKEN_EXPIRE_DAYS) + data_to_encode = { + "username": user.username, + "email": user.email, + "id": str(user.id), + } + + access_token = auth.create_token( + data=data_to_encode, + token_type="access", + expires_delta=access_token_expires, + ) + + refresh_token = auth.create_token( + data=data_to_encode, + token_type="refresh", + expires_delta=refresh_token_expires, ) - # Return the generated token and its type - return {"access_token": access_token, "token_type": "bearer"} + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + } + + +@router.post("/token/refresh", response_model=dict) +async def refresh_token(token: str) -> dict[str, str]: + try: + payload: dict[str, Any] = auth.parse_token(token) + if not auth.validate_token(token): + raise credentials_exception + + user_id: str = payload["id"] + user: Mapping[str, Any] | None = await fetch_user_by_id(user_id) + if user is None: + raise credentials_exception + + access_token_expires = timedelta( + minutes=auth.ACCESS_TOKEN_EXPIRE_MINUTES + ) + refresh_token_expires = timedelta(days=auth.REFRESH_TOKEN_EXPIRE_DAYS) + data_to_encode = { + "username": user["username"], + "email": user["email"], + "id": str(user["id"]), + } + + new_access_token = auth.create_token( + data=data_to_encode, + token_type="access", + expires_delta=access_token_expires, + ) + + new_refresh_token = auth.create_token( + data=data_to_encode, + token_type="refresh", + expires_delta=refresh_token_expires, + ) + + return { + "access_token": new_access_token, + "refresh_token": new_refresh_token, + "token_type": "bearer", + } + + except HTTPException: + raise credentials_exception @router.get("/users/me/", response_model=User) diff --git a/chatApp/routes/chat.py b/chatApp/routes/chat.py index 44cc164..5831f4b 100644 --- a/chatApp/routes/chat.py +++ b/chatApp/routes/chat.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import Any -from bson import ObjectId from fastapi import APIRouter, Depends, HTTPException, Path from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorCursor @@ -16,7 +15,6 @@ from chatApp.models.private_room import PrivateRoom, PrivateRoomInDB from chatApp.models.public_room import PublicRoom, PublicRoomInDB from chatApp.models.user import UserInDB -from chatApp.schemas.message import MessageCreateSchema from chatApp.schemas.private_room import CreatePrivateRoom from chatApp.schemas.public_room import CreatePublicRoom from chatApp.utils.object_id import PydanticObjectId, is_valid_object_id @@ -24,65 +22,6 @@ router = APIRouter() -@router.get("/all-messages", response_model=list[MessageInDB]) -async def get_all_messages(): - messages_collection: AsyncIOMotorCollection = get_messages_collection() - - cursor: AsyncIOMotorCursor = messages_collection.find() - messages_dicts: list[Mapping[str, Any]] = await cursor.to_list(length=None) - messages: list[MessageInDB] = [ - MessageInDB(**message_dict) for message_dict in messages_dicts - ] - - return messages - - -@router.get("/messages", response_model=list[MessageInDB]) -async def get_messages(user: UserInDB = Depends(auth.get_current_user)): - messages_collection: AsyncIOMotorCollection = get_messages_collection() - - cursor: AsyncIOMotorCursor = messages_collection.find( - {"user_id": ObjectId(user.id)} - ) - messages_dicts: list[Mapping[str, Any]] = await cursor.to_list(length=None) - messages: list[MessageInDB] = [ - MessageInDB(**message_dict) for message_dict in messages_dicts - ] - - return messages - - -@router.get("/message/{message_id}", response_model=MessageInDB) -async def get_message( - message_id: str, user: UserInDB = Depends(auth.get_current_user) -): - messages_collection: AsyncIOMotorCollection = get_messages_collection() - - message = await messages_collection.find_one( - {"_id": ObjectId(message_id), "user_id": ObjectId(user.id)} - ) - - if message is None: - raise HTTPException(status_code=404, detail="Message not found") - - return MessageInDB(**message) - - -@router.post("/message", response_model=MessageInDB) -async def create_message( - message: MessageCreateSchema, - user: UserInDB = Depends(auth.get_current_user), -): - messages_collection = get_messages_collection() - - message_dict = message.model_dump() - message_dict["user_id"] = user.id - - await messages_collection.insert_one(message_dict) - - return MessageInDB(**message_dict) - - @router.post("/create-public-room", response_model=PublicRoomInDB) async def create_public_room( room_info: CreatePublicRoom, @@ -216,3 +155,74 @@ async def get_private_room( ) return PrivateRoomInDB(**room_data) + + +@router.get("/messages/public/{room_id}", response_model=list[MessageInDB]) +async def get_messages_of_public_room( + room_id: str = Path(..., description="id of the public room"), + user: UserInDB = Depends(auth.get_current_user), +): + messages_collection: AsyncIOMotorCollection = get_messages_collection() + public_room_collection: AsyncIOMotorCollection = ( + get_public_rooms_collection() + ) + + if not is_valid_object_id(room_id): + raise HTTPException(status_code=400, detail="Invalid room id") + + room = await public_room_collection.find_one( + {"_id": PydanticObjectId(room_id)} + ) + if room is None: + raise HTTPException(status_code=404, detail="Public room not found") + + if user.id not in room["members"]: + raise HTTPException( + status_code=403, detail="User not a member of the room" + ) + + cursor: AsyncIOMotorCursor = messages_collection.find( + {"room_id": PydanticObjectId(room_id)} + ) + messages_dicts: list[Mapping[str, Any]] = await cursor.to_list(length=None) + messages: list[MessageInDB] = [ + MessageInDB(**message_dict) for message_dict in messages_dicts + ] + + return messages + + +@router.get("/messages/private/{room_id}", response_model=MessageInDB) +async def get_message_of_private_room( + room_id: str = Path(..., description="id of the private room"), + user: UserInDB = Depends(auth.get_current_user), +): + messages_collection: AsyncIOMotorCollection = get_messages_collection() + private_rooms_collection: AsyncIOMotorCollection = ( + get_private_rooms_collection() + ) + + if not is_valid_object_id(room_id): + raise HTTPException(status_code=400, detail="Invalid room id") + + room = await private_rooms_collection.find_one( + {"_id": PydanticObjectId(room_id)} + ) + if room is None: + raise HTTPException(status_code=404, detail="Private room not found") + + if user.id not in [room["member1"], room["member2"]]: + raise HTTPException( + status_code=403, detail="User not a member of the room" + ) + + # Get the last message from the room + cursor: AsyncIOMotorCursor = messages_collection.find( + {"room_id": PydanticObjectId(room_id)} + ) + messages_dicts: list[Mapping[str, Any]] = await cursor.to_list(length=None) + messages: list[MessageInDB] = [ + MessageInDB(**message_dict) for message_dict in messages_dicts + ] + + return messages