Skip to content

Commit

Permalink
update credits endpoint to use select()
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 committed Aug 16, 2024
1 parent 4693b62 commit af714bb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
29 changes: 18 additions & 11 deletions offsets_db_api/routers/credits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from fastapi import APIRouter, Depends, Query, Request
from fastapi_cache.decorator import cache
from sqlmodel import Session, or_
from sqlmodel import Session, col, or_, select

from offsets_db_api.cache import CACHE_NAMESPACE
from offsets_db_api.database import get_session
from offsets_db_api.log import get_logger
from offsets_db_api.models import Credit, PaginatedCredits, Project
from offsets_db_api.query_helpers import apply_filters, apply_sorting, handle_pagination
from offsets_db_api.schemas import Pagination, Registries
from offsets_db_api.security import check_api_key
from offsets_db_api.sql_helpers import apply_filters, apply_sorting, handle_pagination

router = APIRouter()
logger = get_logger()
Expand Down Expand Up @@ -49,8 +49,8 @@ async def get_credits(
logger.info(f'Getting credits: {request.url}')

# Outer join to get all credits, even if they don't have a project
query = session.query(Credit, Project.category).join(
Project, Credit.project_id == Project.project_id, isouter=True
statement = select(Credit, Project.category).join(
Project, col(Credit.project_id) == col(Project.project_id), isouter=True
)

filters = [
Expand All @@ -65,30 +65,37 @@ async def get_credits(

# Filter for project_id
if project_id:
# insert at the beginning of the list to ensure that it is applied first
filters.insert(0, ('project_id', project_id, '==', Project))

for attribute, values, operation, model in filters:
query = apply_filters(
query=query, model=model, attribute=attribute, values=values, operation=operation
statement = apply_filters(
statement=statement,
model=model,
attribute=attribute,
values=values,
operation=operation,
)

# Handle 'search' filter separately due to its unique logic
if search:
search_pattern = f'%{search}%'
query = query.filter(
or_(Project.project_id.ilike(search_pattern), Project.name.ilike(search_pattern))
statement = statement.where(
or_(
col(Project.project_id).ilike(search_pattern),
col(Project.name).ilike(search_pattern),
)
)

if sort:
query = apply_sorting(query=query, sort=sort, model=Credit, primary_key='id')
statement = apply_sorting(statement=statement, sort=sort, model=Credit, primary_key='id')

total_entries, current_page, total_pages, next_page, results = handle_pagination(
query=query,
statement=statement,
primary_key=Credit.id,
current_page=current_page,
per_page=per_page,
request=request,
session=session,
)

credits_with_category = [
Expand Down
12 changes: 6 additions & 6 deletions offsets_db_api/sql_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from fastapi import HTTPException, Request
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.sql.expression import ScalarSelect, Select as _Select
from sqlmodel import Session, and_, asc, desc, distinct, func, nullslast, or_, select
from sqlmodel.sql.expression import Select as _Select

from offsets_db_api.models import Clip, ClipProject, Credit, Project
from offsets_db_api.query_helpers import _generate_next_page_url
Expand All @@ -13,11 +13,11 @@

def apply_sorting(
*,
statement: _Select[typing.Any] | ScalarSelect[typing.Any],
statement: _Select[typing.Any],
sort: list[str],
model,
primary_key: str,
) -> _Select[typing.Any] | ScalarSelect[typing.Any]:
) -> _Select[typing.Any]:
# Define valid column names
columns = [c.name for c in model.__table__.columns]

Expand Down Expand Up @@ -54,12 +54,12 @@ def apply_sorting(

def apply_filters(
*,
statement: _Select[typing.Any] | ScalarSelect[typing.Any],
statement: _Select[typing.Any],
model: type[Credit | Project | Clip | ClipProject],
attribute: str,
values: list[str] | None | int | datetime.date | list[Registries],
operation: str,
) -> _Select[typing.Any] | ScalarSelect[typing.Any]:
) -> _Select[typing.Any]:
"""
Apply filters to the statement based on operation type.
Supports 'ilike', '==', '>=', and '<=' operations.
Expand Down Expand Up @@ -129,7 +129,7 @@ def apply_filters(

def handle_pagination(
*,
statement: _Select[typing.Any] | ScalarSelect[typing.Any],
statement: _Select[typing.Any],
primary_key,
current_page: int,
per_page: int,
Expand Down

0 comments on commit af714bb

Please sign in to comment.