diff --git a/chatApp/config/auth.py b/chatApp/config/auth.py index 8839eb4..7c74608 100644 --- a/chatApp/config/auth.py +++ b/chatApp/config/auth.py @@ -2,22 +2,16 @@ from typing import Any from fastapi import Depends -from fastapi.security import OAuth2PasswordBearer from jose import JWTError, jwt -from passlib.context import CryptContext from chatApp.config.config import get_settings from chatApp.config.logs import logger from chatApp.models import user as user_model +from chatApp.utils import hasher from chatApp.utils.exceptions import credentials_exception settings = get_settings() -# Password hashing context -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -# OAuth2 scheme -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") # JWT settings SECRET_KEY = settings.jwt_secret_key.get_secret_value() @@ -26,27 +20,6 @@ REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days -def verify_password(plain_password: str, hashed_password: str) -> bool: - """ - Verify if the provided password matches the stored hashed password. - - :param plain_password: The plain text password. - :param hashed_password: The hashed password stored in the database. - :return: True if passwords match, otherwise False. - """ - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password: str) -> str: - """ - Hash the given password using the password hashing context. - - :param password: The plain text password to hash. - :return: The hashed password. - """ - return pwd_context.hash(password) - - def create_token( data: dict[str, Any], token_type: str, @@ -128,7 +101,7 @@ def validate_token(token: str) -> bool: async def get_current_user( - token: str = Depends(oauth2_scheme), + token: str = Depends(hasher.oauth2_scheme), ) -> user_model.UserInDB: """ Retrieve the current user from the database using the provided JWT token. @@ -165,7 +138,9 @@ async def authenticate_user( ) # Return None if no user was found or if password verification fails - if user is None or not verify_password(password, user.hashed_password): + if user is None or not hasher.verify_password( + password, user.hashed_password + ): return None return user diff --git a/chatApp/config/config.py b/chatApp/config/config.py index c62a495..d5d2227 100644 --- a/chatApp/config/config.py +++ b/chatApp/config/config.py @@ -17,6 +17,9 @@ class Settings(BaseSettings): database_name: str = Field(default="chat_app") max_pool_size: int = 10 min_pool_size: int = 1 + test_database_url: str = Field(default="mongodb://localhost:27017") + test_database_name: str = Field(default="test_chat_app") + test_mode: bool = Field(default=False) # jwt settings jwt_secret_key: SecretStr = Field(default="your-secret-key") diff --git a/chatApp/config/database.py b/chatApp/config/database.py index 2729bd9..dbec7ba 100644 --- a/chatApp/config/database.py +++ b/chatApp/config/database.py @@ -16,24 +16,36 @@ class MongoDB: - def __init__(self) -> None: + def __init__(self, test_db: bool = False) -> None: self.db_client: AsyncIOMotorClient | None = None self.db: AsyncIOMotorDatabase | None = None self.users_collection: AsyncIOMotorCollection | None = None self.messages_collection: AsyncIOMotorCollection | None = None self.public_rooms_collection: AsyncIOMotorCollection | None = None self.private_rooms_collection: AsyncIOMotorCollection | None = None + self.test_db: bool = test_db async def connect_to_mongodb(self) -> None: try: + db_url = ( + settings.test_database_url + if self.test_db + else settings.database_url + ) + db_name = ( + settings.test_database_name + if self.test_db + else settings.database_name + ) + self.db_client = AsyncIOMotorClient( - settings.database_url, + db_url, maxPoolSize=settings.max_pool_size, minPoolSize=settings.min_pool_size, ) assert self.db_client is not None - self.db = self.db_client[settings.database_name] + self.db = self.db_client[db_name] assert self.db is not None # Define collections and schema validations @@ -41,7 +53,9 @@ async def connect_to_mongodb(self) -> None: # Ping the server to validate the connection await self.db_client.admin.command("ismaster") - logger.info("Connected to MongoDB") + logger.info( + f"Connected to MongoDB {'test' if self.test_db else ''} database" + ) except Exception as e: logger.error(f"Could not connect to MongoDB: {e}") raise @@ -196,21 +210,23 @@ async def close_mongodb_connection(self) -> None: self.db_client.close() logger.info("Closed MongoDB connection") + async def drop_database(self) -> None: + if self.db_client and self.db is not None: + await self.db_client.drop_database(self.db.name) + logger.info(f"Dropped database {self.db.name}") + mongo_db = None -async def init_mongo_db(): +async def init_mongo_db(test_db: bool = False) -> MongoDB: global mongo_db - mongo_db = MongoDB() + mongo_db = MongoDB(test_db=test_db) await mongo_db.connect_to_mongodb() return mongo_db -async def shutdown_mongo_db(): - """ - Close the MongoDB connection. - """ +async def shutdown_mongo_db() -> None: global mongo_db if mongo_db is not None: await mongo_db.close_mongodb_connection() diff --git a/chatApp/main.py b/chatApp/main.py index 94dbd62..88ef31b 100644 --- a/chatApp/main.py +++ b/chatApp/main.py @@ -1,3 +1,6 @@ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -13,13 +16,22 @@ settings = get_settings() +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[None]: + # This function will be called on startup and shutdown + await init_mongo_db(test_db=settings.test_mode) + try: + yield + finally: + await shutdown_mongo_db() + + # Create a FastAPI app instance app = FastAPI( title="FastAPI Chat App", description="A chat application built with FastAPI and socket.io", version="1.0.0", - on_startup=[init_mongo_db], - on_shutdown=[shutdown_mongo_db], + lifespan=lifespan, ) ### Add middlewares ### diff --git a/chatApp/models/user.py b/chatApp/models/user.py index 4ec84ba..15b0ac0 100644 --- a/chatApp/models/user.py +++ b/chatApp/models/user.py @@ -5,8 +5,8 @@ from motor.motor_asyncio import AsyncIOMotorCursor from pydantic import BaseModel, Field -from chatApp.config import auth from chatApp.config.database import get_users_collection +from chatApp.utils import hasher from chatApp.utils.object_id import PydanticObjectId @@ -65,7 +65,7 @@ async def create_user(user_dict: dict[str, Any]) -> UserInDB: user_dict["created_at"] = datetime.now() user_dict["updated_at"] = datetime.now() user_dict["last_login"] = datetime.now() - user_dict["hashed_password"] = auth.get_password_hash( + user_dict["hashed_password"] = hasher.get_password_hash( user_dict["password"] ) diff --git a/chatApp/utils/hasher.py b/chatApp/utils/hasher.py new file mode 100644 index 0000000..0b7d018 --- /dev/null +++ b/chatApp/utils/hasher.py @@ -0,0 +1,33 @@ +from fastapi.security import OAuth2PasswordBearer +from passlib.context import CryptContext + +from chatApp.config.config import get_settings + +settings = get_settings() + +# Password hashing context +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# OAuth2 scheme +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """ + Verify if the provided password matches the stored hashed password. + + :param plain_password: The plain text password. + :param hashed_password: The hashed password stored in the database. + :return: True if passwords match, otherwise False. + """ + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password: str) -> str: + """ + Hash the given password using the password hashing context. + + :param password: The plain text password to hash. + :return: The hashed password. + """ + return pwd_context.hash(password) diff --git a/pytest.ini b/pytest.ini index 2f4c80e..a28ad03 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,3 @@ [pytest] asyncio_mode = auto +addopts = -p no:warnings -vv diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..8c069c2 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,64 @@ +import pytest + +from chatApp.config import database +from chatApp.config.database import mongo_db + + +@pytest.fixture(scope="session") +async def db(): + # Initialize the test database + global mongo_db + print(f"mongodb is {mongo_db}") + mongo_db = await database.init_mongo_db(test_db=True) + yield mongo_db + # Clean up the test database + await database.shutdown_mongo_db() + + +@pytest.fixture +async def users_collection(db): + return db.users_collection + + +@pytest.fixture +async def messages_collection(db): + return db.messages_collection + + +@pytest.fixture +async def public_rooms_collection(db): + return db.public_rooms_collection + + +@pytest.fixture +async def private_rooms_collection(db): + return db.private_rooms_collection + + +@pytest.fixture +async def test_user(): + return { + "username": "test_user", + "email": "test@test.com", + "password": "test_password", + } + + +@pytest.fixture +async def test_room(): + return {"name": "test_room"} + + +@pytest.fixture +async def test_message(): + return {"sender": "test_user", "text": "test_message"} + + +@pytest.fixture +async def test_private_room(): + return {"name": "test_private_room", "users": ["test_user"]} + + +@pytest.fixture +async def test_public_room(): + return {"name": "test_public_room"} diff --git a/tests/unit/test_user.py b/tests/unit/test_user.py new file mode 100644 index 0000000..e69de29