Skip to content

Commit

Permalink
Revert "refactor routers.projects to use SQLmodel's select() (#115)"
Browse files Browse the repository at this point in the history
This reverts commit 88bf820.
  • Loading branch information
andersy005 committed Aug 9, 2024
1 parent 88bf820 commit 294322e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 242 deletions.
92 changes: 42 additions & 50 deletions offsets_db_api/routers/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from fastapi_cache.decorator import cache
from sqlalchemy import or_
from sqlalchemy.orm import contains_eager
from sqlmodel import Session, col, select
from sqlmodel import Session

from offsets_db_api.cache import CACHE_NAMESPACE
from offsets_db_api.database import get_session
from offsets_db_api.logging import get_logger
from offsets_db_api.models import Clip, ClipProject, PaginatedProjects, Project, ProjectWithClips
from offsets_db_api.schemas import Pagination, Registries
from offsets_db_api.security import check_api_key
from offsets_db_api.sql_helpers import apply_filters, apply_sorting, handle_pagination
from ..cache import CACHE_NAMESPACE
from ..database import get_session
from ..logging import get_logger
from ..models import Clip, ClipProject, PaginatedProjects, Project, ProjectWithClips
from ..query_helpers import apply_filters, apply_sorting, handle_pagination
from ..schemas import Pagination, Registries
from ..security import check_api_key

router = APIRouter()
logger = get_logger()
Expand Down Expand Up @@ -55,11 +55,18 @@ async def get_projects(

logger.info(f'Getting projects: {request.url}')

query = (
session.query(Project, Clip)
.join(Project.clip_relationships, isouter=True)
.join(ClipProject.clip, isouter=True)
.options(contains_eager(Project.clip_relationships).contains_eager(ClipProject.clip))
)

filters = [
('registry', registry, 'ilike', Project),
('country', country, 'ilike', Project),
('protocol', protocol, 'ALL', Project),
('category', category, 'ALL', Project),
('protocol', protocol, 'ANY', Project),
('category', category, 'ANY', Project),
('is_compliance', is_compliance, '==', Project),
('listed_at', listed_at_from, '>=', Project),
('listed_at', listed_at_to, '<=', Project),
Expand All @@ -68,43 +75,28 @@ async def get_projects(
('retired', retired_min, '>=', Project),
('retired', retired_max, '<=', Project),
]
statement = (
select(Project, Clip)
.join(ClipProject, col(ClipProject.project_id) == col(Project.project_id))
.join(Clip, col(Clip.id) == col(ClipProject.clip_id))
.options(contains_eager(Project.clip_relationships).contains_eager(ClipProject.clip))
)

for attribute, values, operation, model in filters:
statement = apply_filters(
statement=statement,
model=model,
attribute=attribute,
values=values,
operation=operation,
query = apply_filters(
query=query, model=model, attribute=attribute, values=values, operation=operation
)

# Handle 'search' filter separately due to its unique logic
if search:
search_pattern = f'%{search}%'
statement = statement.where(
or_(
col(Project.project_id).ilike(search_pattern),
col(Project.name).ilike(search_pattern),
)
query = query.filter(
or_(Project.project_id.ilike(search_pattern), Project.name.ilike(search_pattern))
)

if sort:
statement = apply_sorting(
statement=statement, sort=sort, model=Project, primary_key='project_id'
)
query = apply_sorting(query=query, sort=sort, model=Project, primary_key='project_id')

total_entries, current_page, total_pages, next_page, results = handle_pagination(
statement=statement,
query=query,
primary_key=Project.project_id,
current_page=current_page,
per_page=per_page,
request=request,
session=session,
)

# Execute the query
Expand All @@ -114,10 +106,9 @@ async def get_projects(
project_to_clips = defaultdict(list)
projects = {}
for project, clip in project_clip_pairs:
p_id = project.project_id
if p_id not in projects:
projects[p_id] = project
project_to_clips[p_id].append(clip)
if project.project_id not in projects:
projects[project.project_id] = project
project_to_clips[project.project_id].append(clip)

# Transform the dictionary into a list of projects with clips
projects_with_clips = []
Expand Down Expand Up @@ -154,24 +145,25 @@ async def get_project(
logger.info(f'Getting project: {request.url}')

# Start the query to get the project and related clips
statement = (
select(Project)
project_with_clips = (
session.query(Project)
.join(Project.clip_relationships, isouter=True)
.join(col(ClipProject.clip), isouter=True)
.join(ClipProject.clip, isouter=True)
.options(contains_eager(Project.clip_relationships).contains_eager(ClipProject.clip))
.where(col(Project.project_id) == project_id)
.filter(Project.project_id == project_id)
.one_or_none()
)

if not (project_with_clips := session.exec(statement).unique().one_or_none()):
if project_with_clips:
# Extract the Project and related Clips from the query result
project_data = project_with_clips.model_dump()
project_data['clips'] = [
clip_project.clip.model_dump()
for clip_project in project_with_clips.clip_relationships
if clip_project.clip
]
return project_data
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f'project {project_id} not found'
)
# Extract the Project and related Clips from the query result
project = project_with_clips
project_data = project.model_dump()
project_data['clips'] = [
clip_project.clip.model_dump()
for clip_project in project.clip_relationships
if clip_project.clip
]
return project_data
192 changes: 0 additions & 192 deletions offsets_db_api/sql_helpers.py

This file was deleted.

0 comments on commit 294322e

Please sign in to comment.