Skip to content

Commit

Permalink
refactor projects to use SQLmodel's select()
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 committed Aug 9, 2024
1 parent 3247917 commit 3c3f1a7
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 43 deletions.
95 changes: 52 additions & 43 deletions offsets_db_api/routers/projects.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import datetime
import typing
from collections import defaultdict

from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from fastapi_cache.decorator import cache
from sqlalchemy import or_
from sqlalchemy.orm import contains_eager
from sqlmodel import Session
from sqlmodel import Session, col, select

from ..cache import CACHE_NAMESPACE
from ..database import get_session
from ..logging import get_logger
from ..models import Clip, ClipProject, PaginatedProjects, Project, ProjectWithClips
from ..query_helpers import apply_filters, apply_sorting, handle_pagination
from ..schemas import Pagination, Registries
from ..security import check_api_key
from offsets_db_api.cache import CACHE_NAMESPACE
from offsets_db_api.database import get_session
from offsets_db_api.logging import get_logger
from offsets_db_api.models import Clip, ClipProject, PaginatedProjects, Project, ProjectWithClips
from offsets_db_api.schemas import Pagination, Registries
from offsets_db_api.security import check_api_key
from offsets_db_api.sql_helpers import apply_filters, apply_sorting, handle_pagination

router = APIRouter()
logger = get_logger()
Expand Down Expand Up @@ -55,18 +56,11 @@ async def get_projects(

logger.info(f'Getting projects: {request.url}')

query = (
session.query(Project, Clip)
.join(Project.clip_relationships, isouter=True)
.join(ClipProject.clip, isouter=True)
.options(contains_eager(Project.clip_relationships).contains_eager(ClipProject.clip))
)

filters = [
('registry', registry, 'ilike', Project),
('country', country, 'ilike', Project),
('protocol', protocol, 'ANY', Project),
('category', category, 'ANY', Project),
('protocol', protocol, 'ALL', Project),
('category', category, 'ALL', Project),
('is_compliance', is_compliance, '==', Project),
('listed_at', listed_at_from, '>=', Project),
('listed_at', listed_at_to, '<=', Project),
Expand All @@ -75,40 +69,56 @@ async def get_projects(
('retired', retired_min, '>=', Project),
('retired', retired_max, '<=', Project),
]
statement = (
select(Project, Clip)
.join(ClipProject, col(ClipProject.project_id) == col(Project.project_id))
.join(Clip, col(Clip.id) == col(ClipProject.clip_id))
.options(contains_eager(Project.clip_relationships).contains_eager(ClipProject.clip))
)

for attribute, values, operation, model in filters:
query = apply_filters(
query=query, model=model, attribute=attribute, values=values, operation=operation
statement = apply_filters(
statement=statement,
model=model,
attribute=attribute,
values=values,
operation=operation,
)

# Handle 'search' filter separately due to its unique logic
if search:
search_pattern = f'%{search}%'
query = query.filter(
or_(Project.project_id.ilike(search_pattern), Project.name.ilike(search_pattern))
statement = statement.where(
or_(
col(Project.project_id).ilike(search_pattern),
col(Project.name).ilike(search_pattern),
)
)

if sort:
query = apply_sorting(query=query, sort=sort, model=Project, primary_key='project_id')
statement = apply_sorting(
statement=statement, sort=sort, model=Project, primary_key='project_id'
)

total_entries, current_page, total_pages, next_page, results = handle_pagination(
query=query,
statement=statement,
primary_key=Project.project_id,
current_page=current_page,
per_page=per_page,
request=request,
session=session,
)

# Execute the query
project_clip_pairs = results
project_clip_pairs: typing.Iterable[Project, Clip] = results

# Group clips by project using a dictionary and project_id as the key
project_to_clips = defaultdict(list)
projects = {}
for project, clip in project_clip_pairs:
if project.project_id not in projects:
projects[project.project_id] = project
project_to_clips[project.project_id].append(clip)
p_id = project.project_id
if p_id not in projects:
projects[p_id] = project
project_to_clips[p_id].append(clip)

# Transform the dictionary into a list of projects with clips
projects_with_clips = []
Expand Down Expand Up @@ -145,25 +155,24 @@ async def get_project(
logger.info(f'Getting project: {request.url}')

# Start the query to get the project and related clips
project_with_clips = (
session.query(Project)
statement = (
select(Project)
.join(Project.clip_relationships, isouter=True)
.join(ClipProject.clip, isouter=True)
.join(col(ClipProject.clip), isouter=True)
.options(contains_eager(Project.clip_relationships).contains_eager(ClipProject.clip))
.filter(Project.project_id == project_id)
.one_or_none()
.where(col(Project.project_id) == project_id)
)

if project_with_clips:
# Extract the Project and related Clips from the query result
project_data = project_with_clips.model_dump()
project_data['clips'] = [
clip_project.clip.model_dump()
for clip_project in project_with_clips.clip_relationships
if clip_project.clip
]
return project_data
else:
if not (project_with_clips := session.exec(statement).unique().one_or_none()):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f'project {project_id} not found'
)
# Extract the Project and related Clips from the query result
project = project_with_clips
project_data = project.model_dump()
project_data['clips'] = [
clip_project.clip.model_dump()
for clip_project in project.clip_relationships
if clip_project.clip
]
return project_data
192 changes: 192 additions & 0 deletions offsets_db_api/sql_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import datetime
import typing

import sqlmodel
from fastapi import HTTPException, Request
from sqlalchemy.dialects.postgresql import ARRAY
from sqlmodel import Session, and_, asc, desc, distinct, func, nullslast, or_, select

from offsets_db_api.logging import get_logger
from offsets_db_api.models import Clip, ClipProject, Credit, Project
from offsets_db_api.query_helpers import _generate_next_page_url
from offsets_db_api.schemas import Registries

logger = get_logger()


def apply_sorting(*, statement, sort: list[str], model, primary_key: str):
    """
    Append ORDER BY clauses to the statement for each requested sort field.

    Parameters
    ----------
    statement: Select
        SQLAlchemy Select statement
    sort: list[str]
        Sort fields. A leading '-' requests descending order; a leading '+'
        (or no prefix) requests ascending order. The list is not modified.
    model
        Model class whose ``__table__`` columns define the valid sort fields
    primary_key: str
        Name of the primary key column; added as a final ascending sort key
        when it is not already present in ``sort``

    Returns
    -------
    statement: Select
        updated SQLAlchemy Select statement

    Raises
    ------
    HTTPException
        With status 400 when a sort field is not a column of ``model``
    """
    # Define valid column names
    columns = [c.name for c in model.__table__.columns]

    # Normalize into a new list: the caller's list must not be mutated, and
    # whitespace has to be stripped *before* looking for the primary key so
    # that e.g. ' project_id' is recognised and not appended a second time.
    sort_params = [sort_param.strip() for sort_param in sort]

    # Ensure that the primary key field is always included in the sort parameters
    # to ensure consistent pagination
    primary_key_forms = {primary_key, f'-{primary_key}', f'+{primary_key}'}
    if primary_key_forms.isdisjoint(sort_params):
        sort_params.append(primary_key)

    for sort_param in sort_params:
        if sort_param.startswith('-'):
            order, field = desc, sort_param[1:]
        elif sort_param.startswith('+'):
            order, field = asc, sort_param[1:]
        else:
            order, field = asc, sort_param

        # Check if field is a valid column name
        if field not in columns:
            raise HTTPException(
                status_code=400,
                detail=f'Invalid sort field: {field}. Must be one of {columns}',
            )

        # NULLs are always placed last, regardless of direction
        statement = statement.order_by(nullslast(order(getattr(model, field))))

    return statement


def apply_filters(
*,
statement,
model: type[Credit | Project | Clip | ClipProject],
attribute: str,
values: list[str] | None | int | datetime.date | list[Registries],
operation: str,
):
"""
Apply filters to the statement based on operation type.
Supports 'ilike', '==', '>=', and '<=' operations.
Parameters
----------
statement: Select
SQLAlchemy Select statement
model: Credit | Project | Clip | ClipProject
SQLAlchemy model class
attribute: str
model attribute to apply filter on
values: list
list of values to filter with
operation: str
operation type to apply to the filter ('ilike', '==', '>=', '<=')
Returns
-------
statement: Select
updated SQLAlchemy Select statement
"""

if values is not None:
attr_type = getattr(model, attribute).type
is_array = isinstance(attr_type, ARRAY)
is_list = isinstance(values, list | tuple | set)

if is_array and is_list:
if operation == 'ALL':
statement = statement.where(
and_(*[getattr(model, attribute).op('@>')(f'{{{v}}}') for v in values])
)
else:
statement = statement.where(
or_(*[getattr(model, attribute).op('@>')(f'{{{v}}}') for v in values])
)

if operation == 'ilike':
statement = (
statement.where(or_(*[getattr(model, attribute).ilike(v) for v in values]))
if is_list
else statement.where(getattr(model, attribute).ilike(values))
)
elif operation == '==':
statement = (
statement.where(or_(*[getattr(model, attribute) == v for v in values]))
if is_list
else statement.where(getattr(model, attribute) == values)
)
elif operation == '>=':
statement = (
statement.where(or_(*[getattr(model, attribute) >= v for v in values]))
if is_list
else statement.where(getattr(model, attribute) >= values)
)
elif operation == '<=':
statement = (
statement.where(or_(*[getattr(model, attribute) <= v for v in values]))
if is_list
else statement.where(getattr(model, attribute) <= values)
)

return statement


def handle_pagination(
    *,
    statement: sqlmodel.sql.expression.Select,
    primary_key,
    current_page: int,
    per_page: int,
    request: Request,
    session: Session,
) -> tuple[
    int,
    int,
    int,
    str | None,
    typing.Iterable[Project | Clip | ClipProject | Credit],
]:
    """
    Calculate total records, pages and next page URL for a given statement,
    and fetch the rows of the requested page.

    Parameters
    ----------
    statement: Select
        SQLAlchemy Select statement
    primary_key
        Primary key used for the distinct count; either the column name or a
        column attribute exposing ``.key``
    current_page: int
        Current page number (1-based)
    per_page: int
        Number of records per page
    request: Request
        FastAPI request instance
    session: Session
        SQLAlchemy session instance

    Returns
    -------
    total_entries: int
        Number of distinct primary key values matched by the statement
    current_page: int
        The page number that was requested
    total_pages: int
        Total pages in query
    next_page: Optional[str]
        URL of next page, or None when on the last page
    results: List[SQLModel]
        Results for the current page
    """

    pk_column = primary_key if isinstance(primary_key, str) else primary_key.key

    # Count over the filtered statement wrapped as a subquery, and take the
    # DISTINCT column from that subquery. Referencing the base table column
    # while selecting FROM the subquery would put both in the FROM clause (a
    # cartesian product) and count rows the filters were meant to exclude.
    # Ordering is irrelevant for a count, so it is dropped.
    filtered = statement.order_by(None).subquery()
    count_query = select(func.count(distinct(filtered.c[pk_column])))
    total_entries = session.exec(count_query).one()

    total_pages = (total_entries + per_page - 1) // per_page  # ceil(total / per_page)

    # Calculate the next page URL
    next_page = None
    if current_page < total_pages:
        next_page = _generate_next_page_url(
            request=request, current_page=current_page, per_page=per_page
        )

    # Get the results for the current page. OFFSET/LIMIT apply to the rows of
    # `statement` as given; unique() then collapses duplicate result rows.
    paginated_statement = statement.offset((current_page - 1) * per_page).limit(per_page)
    results = session.exec(paginated_statement).unique().all()

    return total_entries, current_page, total_pages, next_page, results

0 comments on commit 3c3f1a7

Please sign in to comment.