Skip to content

Commit

Permalink
feat(auth): authentication of user and user creation and access jwt a…
Browse files Browse the repository at this point in the history
…ccess token added to auth routes.
  • Loading branch information
sinasezza committed Jul 25, 2024
1 parent 1e3dec5 commit 4ef094e
Show file tree
Hide file tree
Showing 14 changed files with 444 additions and 91 deletions.
150 changes: 150 additions & 0 deletions chatApp/config/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# auth.py
from collections.abc import Mapping
from datetime import UTC, datetime, timedelta
from typing import Any, Optional

from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from motor.motor_asyncio import AsyncIOMotorCollection
from passlib.context import CryptContext

from chatApp.config.config import get_settings
from chatApp.config.database import mongo_db
from chatApp.config.logs import logger
from chatApp.models.user import User
from chatApp.utils.exceptions import credentials_exception

settings = get_settings()

# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 scheme
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")

# JWT settings
SECRET_KEY = settings.jwt_secret_key.get_secret_value()
ALGORITHM = settings.jwt_algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes


def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verify if the provided password matches the stored hashed password.
:param plain_password: The plain text password.
:param hashed_password: The hashed password stored in the database.
:return: True if passwords match, otherwise False.
"""
return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
"""
Hash the given password using the password hashing context.
:param password: The plain text password to hash.
:return: The hashed password.
"""
return pwd_context.hash(password)


def create_access_token(
data: dict[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
"""
Create a JWT access token with a specified expiration.
:param data: The data to encode into the token.
:param expires_delta: Optional expiration time delta for the token.
:return: The encoded JWT token as a string.
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(UTC) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})

encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


def parse_access_token(token: str) -> dict[str, Any]:
"""
Parse and validate the given JWT token, returning its payload.
:param token: The JWT token to parse.
:return: The payload data from the token.
:raises credentials_exception: If the token is invalid or cannot be decoded.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except JWTError as e:
logger.error(f"JWT error: {e}") # Log the error for debugging purposes
raise credentials_exception


def get_users_collection() -> AsyncIOMotorCollection:
"""
Retrieve the users collection from the MongoDB database.
:return: The users collection instance.
:raises RuntimeError: If the users collection is not initialized.
"""
users_collection = mongo_db.users_collection
if users_collection is None:
raise RuntimeError("Users collection is not initialized.")
return users_collection


async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
"""
Retrieve the current user from the database using the provided JWT token.
:param token: The JWT token used for authentication.
:return: The User object representing the authenticated user.
:raises credentials_exception: If the user cannot be found or the token is invalid.
"""
# Parse the token to get the payload
payload = parse_access_token(token)
username: Optional[str] = payload.get("sub")

if username is None:
logger.error("Username is missing in the token payload.")
raise credentials_exception

# Fetch the users_collection within the request scope
users_collection = get_users_collection()

# Properly type the result of the find_one query
user: Optional[Mapping[str, Any]] = await users_collection.find_one(
{"username": username}
)

# Raise an exception if no user was found
if user is None:
logger.error(f"User with username {username} not found in database.")
raise credentials_exception

# Construct and return a User instance from the found document
return User(**user)


async def authenticate_user(username: str, password: str) -> Optional[User]:
# Fetch the users_collection within the request scope
users_collection = get_users_collection()

# Properly type the result of the find_one query
user: Optional[Mapping[str, Any]] = await users_collection.find_one(
{"username": username}
)

# Return None if no user was found or if password verification fails
if user is None or not verify_password(password, user["hashed_password"]):
return None

# Construct and return a User instance from the found document
return User(**user)
20 changes: 14 additions & 6 deletions chatApp/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path

from dotenv import load_dotenv
from pydantic import Field
from pydantic import Field, SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict

BASE_DIR = Path(__file__).resolve().parent.parent.parent
Expand All @@ -13,12 +13,15 @@ class Settings(BaseSettings):
debug: bool = Field(default=False)

# database settings
database_url: str = "mongodb://localhost:27017"
database_name: str = "chat_app"
database_url: str = Field(default="mongodb://localhost:27017")
database_name: str = Field(default="chat_app")
max_pool_size: int = 10
min_pool_size: int = 1

# jwt settings
jwt_secret_key: str = "your-secret-key"
jwt_secret_key: SecretStr = Field(default="your-secret-key")
jwt_algorithm: str = Field(default="HS256")
access_token_expire_minutes: int = Field(default=1440)

# CORS settings
cors_allow_origins: list[str] = Field(default=["*"])
Expand All @@ -28,16 +31,21 @@ class Settings(BaseSettings):

# logs settings
log_level: str = Field(default="INFO")
log_file_path: str = Field(default=str(BASE_DIR / "logs/app.log"))
log_file_path: Path = Field(default=BASE_DIR / "logs/app.log")
log_max_bytes: int = Field(default=1048576) # 1 MB
log_backup_count: int = Field(default=3)

# upload settings
upload_dir: str = Field(default=str(BASE_DIR / "uploads"))
upload_dir: Path = Field(default=BASE_DIR / "uploads")
max_upload_size: int = Field(default=(5 * 1024 * 1024))

model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
self.upload_dir.mkdir(parents=True, exist_ok=True)


@lru_cache
def get_settings() -> Settings:
Expand Down
67 changes: 39 additions & 28 deletions chatApp/config/database.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,54 @@
import logging
from typing import Optional

from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.collection import Collection
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorCollection,
AsyncIOMotorDatabase,
)

from .config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()


# Singleton Pattern for connecting to the database
class MongoDB:
def __init__(self):
def __init__(self) -> None:
self.db_client: Optional[AsyncIOMotorClient] = None
self.db = None
self.users_collection: Optional[Collection] = None
self.messages_collection: Optional[Collection] = None
self.rooms_collection: Optional[Collection] = None

async def connect_to_mongodb(self):
self.db_client = AsyncIOMotorClient(
settings.database_url, maxPoolSize=10, minPoolSize=1
)
self.db = self.db_client[settings.database_name]
print("Connected to MongoDB")

# Initialize collections and create indexes
self.users_collection = self.db.get_collection("users")
self.messages_collection = self.db.get_collection("messages")
self.rooms_collection = self.db.get_collection("rooms")

# # Create indexes
# self.users_collection.create_index([("_id", 1)], unique=True)
# self.messages_collection.create_index([("_id", 1)], unique=True)
# self.rooms_collection.create_index([("_id", 1)], unique=True)

async def close_mongodb_connection(self):
self.db: Optional[AsyncIOMotorDatabase] = None
self.users_collection: Optional[AsyncIOMotorCollection] = None
self.messages_collection: Optional[AsyncIOMotorCollection] = None
self.rooms_collection: Optional[AsyncIOMotorCollection] = None

async def connect_to_mongodb(self) -> None:
try:
self.db_client = AsyncIOMotorClient(
settings.database_url,
maxPoolSize=settings.max_pool_size,
minPoolSize=settings.min_pool_size,
)

assert self.db_client is not None

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
self.db = self.db_client[settings.database_name]
assert self.db is not None

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.

# Initialize collections
self.users_collection = self.db.get_collection("users")
self.messages_collection = self.db.get_collection("messages")
self.rooms_collection = self.db.get_collection("rooms")

# Ping the server to validate the connection
await self.db_client.admin.command("ismaster")
logger.info("Connected to MongoDB")
except Exception as e:
logger.error(f"Could not connect to MongoDB: {e}")
raise

async def close_mongodb_connection(self) -> None:
if self.db_client:
self.db_client.close()
print("Closed MongoDB connection")
logger.info("Closed MongoDB connection")


# Create a global instance of MongoDB
Expand Down
39 changes: 25 additions & 14 deletions chatApp/config/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,41 @@

settings = get_settings()


# Ensure the log directory exists
log_path = Path(BASE_DIR / settings.log_file_path)
log_path.parent.mkdir(parents=True, exist_ok=True)
try:
log_path.parent.mkdir(parents=True, exist_ok=True)
except Exception as e:
print(f"Error creating log directory: {e}")
raise

# Define logging format
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Configure the root logger
logging.basicConfig(
level=settings.log_level, # Set the logging level from env
format=LOG_FORMAT,
handlers=[
logging.StreamHandler(), # Log to console
RotatingFileHandler(
Path(BASE_DIR / settings.log_file_path),
maxBytes=settings.log_max_bytes,
backupCount=settings.log_backup_count,
), # Log to file with rotation
],
)
try:
logging.basicConfig(
level=settings.log_level.upper(), # Ensure it's uppercase
format=LOG_FORMAT,
handlers=[
logging.StreamHandler(), # Log to console
RotatingFileHandler(
log_path,
maxBytes=settings.log_max_bytes,
backupCount=settings.log_backup_count,
), # Log to file with rotation
],
)
except Exception as e:
print(f"Error setting up logging: {e}")
raise

# Get a logger instance
logger: Logger = logging.getLogger(__name__)

# Example log message to test configuration
logger.info("Logging configuration is set up.")


def get_logger(name: str) -> Logger:
return logging.getLogger(name)
20 changes: 9 additions & 11 deletions chatApp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,19 @@
settings = get_settings()


# Define lifespan event handlers
async def lifespan(app: FastAPI):
# On startup
# Define startup and shutdown event handlers
async def startup_event():
await mongo_db.connect_to_mongodb() # Use mongo_db instance

yield

# On shutdown
async def shutdown_event():
await mongo_db.close_mongodb_connection() # Use mongo_db instance


# Create a FastAPI app instance with lifespan events
app = FastAPI(lifespan=lifespan) # Pass lifespan as a parameter
# Create a FastAPI app instance
app = FastAPI(on_startup=[startup_event], on_shutdown=[shutdown_event])

# Configure CORS using settings
# Configure CORS using settings with explicit type annotations
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_allow_origins,
Expand All @@ -42,7 +40,7 @@ async def lifespan(app: FastAPI):


@app.get("/")
async def root():
async def root() -> dict[str, str]:
return {"message": "Welcome to the FastAPI Chat App"}


Expand All @@ -53,12 +51,12 @@ async def root():


@sio.event
async def connect(sid, environ):
async def connect(sid: str, environ: dict) -> None:
print(f"Client connected: {sid}")


@sio.event
async def disconnect(sid):
async def disconnect(sid: str) -> None:
print(f"Client disconnected: {sid}")


Expand Down
Loading

0 comments on commit 4ef094e

Please sign in to comment.