Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store settings in request #26

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/climsoft_api/api/statistics/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

from climsoft_api.db import engine
from climsoft_api.utils.response import get_error_response
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from opencdms.models.climsoft.v4_1_1_core import TARGET_TABLES
from sqlalchemy import text
from sqlalchemy.orm import Session
from climsoft_api.api import deps


router = APIRouter()
Expand Down Expand Up @@ -47,16 +49,16 @@ class ClimsoftTables(str, enum.Enum):
"/statistics",
)
# async def get_statistics_for_all_climsoft_tables(table: ClimsoftTables):
async def get_statistics_for_all_climsoft_tables():
async def get_statistics_for_all_climsoft_tables(
db_session: Session = Depends(deps.get_session)
):
try:
stats = {}
for _table in TARGET_TABLES:
with engine.connect() as connection:
result = connection.execute(
text(f"SELECT COUNT(*) as count FROM {_table};")
)
count = [row['count'] for row in result][0]
stats[_table] = count
result = db_session.execute(
text(f"SELECT COUNT(*) as count FROM {_table};")
)
stats[_table] = [row['count'] for row in result][0]
return stats
except TypeError:
raise
Expand Down
13 changes: 9 additions & 4 deletions src/climsoft_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware
from climsoft_api.api import api_routers

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
# load controllers


def get_app():
def get_app(config=None):
app = FastAPI(docs_url="/")
app.add_middleware(BaseHTTPMiddleware, dispatch=LocalizationMiddleware())
if settings.MOUNT_STATIC:
Expand All @@ -28,6 +30,9 @@ def get_app():

@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
if config is not None and config.get("DATABASE_URI"):
engine: Engine = create_engine(config.get("DATABASE_URI"))
SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
try:
request.state.get_session = SessionLocal
response = await call_next(request)
Expand All @@ -38,8 +43,8 @@ async def db_session_middleware(request: Request, call_next):
return app


def get_app_with_routers():
app = get_app()
def get_app_with_routers(config=None):
app = get_app(config)
for router in api_routers:
app.include_router(**router.dict())

Expand Down