-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from sinasezza/sinasezza
Sinasezza
- Loading branch information
Showing
11 changed files
with
885 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,64 +1,66 @@ | ||
import socketio | ||
import uvicorn | ||
from fastapi import FastAPI | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from fastapi.middleware.trustedhost import TrustedHostMiddleware | ||
|
||
from chatApp.config.config import get_settings | ||
from chatApp.config.database import mongo_db | ||
from chatApp.middlewares.request_limit import RequestLimitMiddleware | ||
from chatApp.routes import auth, chat, user | ||
from chatApp.sockets import sio_app | ||
|
||
# Fetch settings | ||
settings = get_settings() | ||
|
||
|
||
# Define startup and shutdown event handlers | ||
async def startup_event(): | ||
await mongo_db.connect_to_mongodb() # Use mongo_db instance | ||
await mongo_db.connect_to_mongodb() | ||
|
||
|
||
async def shutdown_event(): | ||
await mongo_db.close_mongodb_connection() # Use mongo_db instance | ||
await mongo_db.close_mongodb_connection() | ||
|
||
|
||
# Create a FastAPI app instance | ||
app = FastAPI(on_startup=[startup_event], on_shutdown=[shutdown_event]) | ||
app = FastAPI( | ||
title="FastAPI Chat App", | ||
description="A chat application built with FastAPI and socket.io", | ||
version="1.0.0", | ||
on_startup=[startup_event], | ||
on_shutdown=[shutdown_event], | ||
) | ||
|
||
### Add middlewares ### | ||
|
||
# Configure CORS using settings with explicit type annotations | ||
# Configure CORS using settings | ||
app.add_middleware( | ||
CORSMiddleware, | ||
allow_origins=settings.cors_allow_origins, | ||
allow_credentials=settings.cors_allow_credentials, | ||
allow_methods=settings.cors_allow_methods, | ||
allow_headers=settings.cors_allow_headers, | ||
) | ||
app.add_middleware(RequestLimitMiddleware, max_requests=10, window_seconds=1) | ||
app.add_middleware( | ||
TrustedHostMiddleware, | ||
allowed_hosts=settings.trusted_hosts, | ||
) | ||
|
||
# Create a Socket.IO server | ||
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") | ||
|
||
# Wrap with ASGI application | ||
socket_app = socketio.ASGIApp(sio, app) | ||
|
||
|
||
@app.get("/") | ||
async def root() -> dict[str, str]: | ||
return {"message": "Welcome to the FastAPI Chat App"} | ||
|
||
|
||
# Include routers | ||
# Include your routers for API endpoints | ||
app.include_router(auth.router, prefix="/auth") | ||
app.include_router(chat.router, prefix="/chat") | ||
app.include_router(user.router, prefix="/user") | ||
|
||
|
||
@sio.event | ||
async def connect(sid: str, environ: dict) -> None: | ||
print(f"Client connected: {sid}") | ||
@app.get("/") | ||
async def root() -> dict[str, str]: | ||
return {"message": "Welcome to the FastAPI Chat App"} | ||
|
||
|
||
@sio.event | ||
async def disconnect(sid: str) -> None: | ||
print(f"Client disconnected: {sid}") | ||
# Mount socket.io app | ||
app.mount("/", app=sio_app) | ||
|
||
|
||
if __name__ == "__main__": | ||
uvicorn.run("main:socket_app", host="0.0.0.0", port=8000, reload=True) | ||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import time | ||
from collections import defaultdict | ||
|
||
from fastapi import Request, Response | ||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint | ||
from starlette.types import ASGIApp | ||
|
||
from chatApp.config.logs import logger # Import your custom logger | ||
|
||
|
||
class RequestLimitMiddleware(BaseHTTPMiddleware): | ||
def __init__(self, app: ASGIApp, max_requests: int = 4, window_seconds: int = 1): | ||
super().__init__(app) | ||
self.max_requests = max_requests | ||
self.window_seconds = window_seconds | ||
self.request_history: dict[str, tuple[int, float]] = defaultdict( | ||
lambda: (0, 0.0) | ||
) | ||
|
||
async def dispatch( | ||
self, request: Request, call_next: RequestResponseEndpoint | ||
) -> Response: | ||
client_ip = request.client.host if request.client else "unknown" | ||
current_time = time.time() | ||
|
||
# Log request start time | ||
logger.info(f"Received request from {client_ip} at {current_time}") | ||
|
||
# Get the request count and last request time for this IP | ||
count, last_request_time = self.request_history[client_ip] | ||
|
||
# If it's been longer than the window, reset the count | ||
if current_time - last_request_time > self.window_seconds: | ||
count = 0 | ||
|
||
# Increment the count | ||
count += 1 | ||
|
||
# Update the request history | ||
self.request_history[client_ip] = (count, current_time) | ||
|
||
# If the count exceeds the limit, return a 429 Too Many Requests response | ||
if count > self.max_requests: | ||
logger.warning(f"Too many requests from {client_ip} - Count: {count}") | ||
return Response("Too many requests", status_code=429) | ||
|
||
# Measure start time of request processing | ||
start_time = time.time() | ||
|
||
# Process the request | ||
response = await call_next(request) | ||
|
||
# Calculate process time | ||
process_time = time.time() - start_time | ||
|
||
# Log the request processing time | ||
logger.info(f"Processed request from {client_ip} in {process_time:.4f} seconds") | ||
|
||
# Add X-Process-Time header to the response | ||
response.headers["X-Process-Time"] = str(process_time) | ||
|
||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import socketio | ||
|
||
from chatApp.config.config import get_settings | ||
|
||
settings = get_settings() | ||
|
||
# Define the Socket.IO server | ||
sio_server = socketio.AsyncServer( | ||
async_mode="asgi", | ||
cors_allowed_origins=settings.cors_allow_origins, | ||
) | ||
|
||
# Create the ASGI app using the defined server | ||
sio_app = socketio.ASGIApp( | ||
socketio_server=sio_server, socketio_path="/", other_asgi_app="main:app" | ||
) | ||
|
||
|
||
# Event handlers | ||
@sio_server.event | ||
async def connect(sid: str, environ: dict, auth: dict) -> None: | ||
print(f"Client connected: {sid}") | ||
|
||
|
||
@sio_server.event | ||
async def disconnect(sid: str) -> None: | ||
print(f"Client disconnected: {sid}") |
Oops, something went wrong.