Skip to content

Commit

Permalink
🔧 Update (endpoint): Check health status of database
Browse files Browse the repository at this point in the history
  • Loading branch information
jpcadena committed Oct 6, 2024
1 parent 399b5db commit 3151351
Show file tree
Hide file tree
Showing 5 changed files with 1,929 additions and 195 deletions.
25 changes: 13 additions & 12 deletions app/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

from fastapi import FastAPI

from app.api.deps import RedisConnectionManager
# from app.api.deps import RedisConnectionManager
from app.config.config import get_auth_settings, get_init_settings, get_settings
from app.crud.user import get_user_repository
from app.db.init_db import init_db
from app.services.infrastructure.ip_blacklist import get_ip_blacklist_service

# from app.services.infrastructure.ip_blacklist import get_ip_blacklist_service

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,16 +44,16 @@ async def lifespan(application: FastAPI) -> AsyncGenerator[Any, None]:
)
logger.info("Database initialized.")

redis_manager: RedisConnectionManager = RedisConnectionManager(
application.state.auth_settings
)
async with redis_manager.connection() as connection:
application.state.redis_connection = connection
application.state.ip_blacklist_service = get_ip_blacklist_service(
connection, application.state.auth_settings
)
logger.info("Redis connection established.")
yield
# redis_manager: RedisConnectionManager = RedisConnectionManager(
# application.state.auth_settings
# )
# async with redis_manager.connection() as connection:
# application.state.redis_connection = connection
# application.state.ip_blacklist_service = get_ip_blacklist_service(
# connection, application.state.auth_settings
# )
# logger.info("Redis connection established.")
yield
except Exception as exc:
logger.error(f"Error during application startup: {exc}")
raise
Expand Down
58 changes: 57 additions & 1 deletion app/db/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
"""

import logging
from typing import AsyncGenerator

from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
create_async_engine,
async_sessionmaker, create_async_engine,
)

from app.config.config import sql_database_setting
Expand All @@ -18,6 +21,12 @@
async_engine: AsyncEngine = create_async_engine(
url, pool_pre_ping=True, future=True, echo=True
)
AsyncSessionLocal = async_sessionmaker(
bind=async_engine,
autoflush=False,
expire_on_commit=False,
class_=AsyncSession,
)


@with_logging
Expand All @@ -32,3 +41,50 @@ async def get_session() -> AsyncSession:
bind=async_engine, expire_on_commit=False
) as session:
return session


async def get_session_generator() -> AsyncGenerator[AsyncSession, None]:
"""
Get an asynchronous session to the database as a generator
:yield: Async session for database connection
:rtype: AsyncGenerator[AsyncSession, None]
"""
async with AsyncSessionLocal() as session:
try:
yield session
except Exception as e:
logger.error("Session rollback because of exception: %s", e)
await session.rollback()
raise
finally:
await session.close()


async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""
Get the database session as a context manager from generator
:return: The session generated
:rtype: AsyncGenerator[AsyncSession, None]
"""
async for session in get_session_generator():
yield session


async def check_db_health(session: AsyncSession) -> bool:
"""
Check the health of the database connection.
:param session: The SQLAlchemy asynchronous session object used to
interact with the database.
:type session: AsyncSession
:returns: True if the database connection is healthy, False otherwise.
:rtype: bool
"""
try:
await session.execute(text("SELECT 1"))
return True
except SQLAlchemyError as e:
logger.error(f"Database connection error: {e}")
return False
27 changes: 27 additions & 0 deletions app/schemas/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, cast
from uuid import UUID, uuid4

from pydantic import PositiveInt
from pydantic.config import JsonDict
from pydantic_extra_types.phone_numbers import PhoneNumber

Expand Down Expand Up @@ -243,3 +244,29 @@ def custom_serializer(my_dict: dict[str, Any]) -> dict[str, Any]:
elif isinstance(value, dict):
custom_serializer(value)
return my_dict


health_example: dict[PositiveInt | str, dict[str, Any]] | None = {
200: {
"content": {
"application/json": {
"example": [
{
"status": "healthy",
},
],
},
},
},
503: {
"content": {
"application/json": {
"example": [
{
"status": "unhealthy",
},
],
},
},
},
}
42 changes: 31 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,28 @@

import logging
from functools import partial
from typing import Annotated

import uvicorn
from fastapi import FastAPI, status
from fastapi import Depends, FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.responses import ORJSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from pydantic import PositiveInt
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.api_v1.api import api_router
from app.config.config import auth_setting, init_setting, setting
from app.core import logging_config
from app.core.lifecycle import lifespan
from app.db.session import check_db_health, get_db_session
from app.middlewares.blacklist_token import blacklist_middleware
from app.middlewares.ip_blacklist import IPBlacklistMiddleware
from app.middlewares.rate_limiter import RateLimiterMiddleware

# from app.middlewares.ip_blacklist import IPBlacklistMiddleware
# from app.middlewares.rate_limiter import RateLimiterMiddleware
from app.middlewares.security_headers import SecurityHeadersMiddleware
from app.schemas.schemas import health_example
from app.utils.files_utils.openapi_utils import (
custom_generate_unique_id,
custom_openapi,
Expand All @@ -39,8 +45,8 @@
)
app.openapi = partial(custom_openapi, app) # type: ignore
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(RateLimiterMiddleware) # type: ignore
app.add_middleware(IPBlacklistMiddleware) # type: ignore
# app.add_middleware(RateLimiterMiddleware) # type: ignore
# app.add_middleware(IPBlacklistMiddleware) # type: ignore
app.add_middleware(
CORSMiddleware,
allow_origins=setting.BACKEND_CORS_ORIGINS,
Expand Down Expand Up @@ -73,15 +79,29 @@ async def redirect_to_docs() -> RedirectResponse:
return RedirectResponse("/docs")


@app.get("/health", response_class=JSONResponse)
async def check_health() -> JSONResponse:
@app.get(
"/health",
responses=health_example,
)
async def check_health(
session: Annotated[AsyncSession, Depends(get_db_session)],
) -> ORJSONResponse:
"""
Check the health of the application backend.
## Response:
- `return:` **The JSON response**
- `rtype:` **JSONResponse**
- `return:` **The ORJSON response**
- `rtype:` **ORJSONResponse**
\f
"""
return JSONResponse({"status": "healthy"})
health_status: dict[str, str] = {
"status": "healthy",
}
status_code: PositiveInt = status.HTTP_200_OK
if not await check_db_health(session):
health_status["status"] = "unhealthy"
status_code = status.HTTP_503_SERVICE_UNAVAILABLE
return ORJSONResponse(health_status, status_code=status_code)


if __name__ == "__main__":
Expand Down
1,972 changes: 1,801 additions & 171 deletions openapi.json

Large diffs are not rendered by default.

0 comments on commit 3151351

Please sign in to comment.