Skip to content

Commit

Permalink
Support branching of query builder
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Nov 15, 2024
1 parent d056add commit 3b27031
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
10 changes: 10 additions & 0 deletions iceaxe/__tests__/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,13 @@ def test_select_multiple_typehints():
query = select((UserDemo, UserDemo.id, UserDemo.name))
if TYPE_CHECKING:
_: QueryBuilder[tuple[UserDemo, int, str], Literal["SELECT"]] = query


def test_allow_branching():
base_query = select(UserDemo)

query_1 = base_query.limit(1)
query_2 = base_query.limit(2)

assert query_1.limit_value == 1
assert query_2.limit_value == 2
30 changes: 30 additions & 0 deletions iceaxe/queries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from copy import copy
from functools import wraps
from typing import Any, Generic, Literal, Type, TypeVar, TypeVarTuple, cast, overload

from iceaxe.base import (
Expand Down Expand Up @@ -53,6 +55,22 @@
OrderDirection = Literal["ASC", "DESC"]


def allow_branching(fn):
"""
Allows query method modifiers to implement their logic as if `self` is being
modified, but in the background we'll actually return a new instance of the
query builder to allow for branching of the same underlying query.
"""

@wraps(fn)
def new_fn(self, *args, **kwargs):
self = copy(self)
return fn(self, *args, **kwargs)

return new_fn


class QueryBuilder(Generic[P, QueryType]):
"""
The QueryBuilder owns all construction of the SQL string given
Expand Down Expand Up @@ -118,6 +136,7 @@ def select(
self, fields: tuple[T | Type[T], T2 | Type[T2], T3 | Type[T3], *Ts]
) -> QueryBuilder[tuple[T, T2, T3, *Ts], Literal["SELECT"]]: ...

@allow_branching
def select(
self,
fields: (
Expand Down Expand Up @@ -212,6 +231,7 @@ def _select_inner(
self.select_raw.append(field)
self.select_aggregate_count += 1

@allow_branching
def update(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["UPDATE"]]:
"""
Creates a new update query for the given model. Returns the same
Expand All @@ -222,6 +242,7 @@ def update(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["UPDATE"]
self.main_model = model
return self # type: ignore

@allow_branching
def delete(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["DELETE"]]:
"""
Creates a new delete query for the given model. Returns the same
Expand All @@ -232,6 +253,7 @@ def delete(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["DELETE"]
self.main_model = model
return self # type: ignore

@allow_branching
def where(self, *conditions: bool):
"""
Adds a where condition to the query. The conditions are combined with
Expand All @@ -250,6 +272,7 @@ def where(self, *conditions: bool):
self.where_conditions += validated_comparisons
return self

@allow_branching
def order_by(self, field: Any, direction: OrderDirection = "ASC"):
"""
Adds an order by clause to the query. The field must be a column.
Expand All @@ -265,6 +288,7 @@ def order_by(self, field: Any, direction: OrderDirection = "ASC"):
self.order_by_clauses.append(f"{field_token} {direction}")
return self

@allow_branching
def join(self, table: Type[TableBase], on: bool, join_type: JoinType = "INNER"):
"""
Adds a join clause to the query. The `on` parameter should be a comparison
Expand All @@ -289,6 +313,7 @@ def join(self, table: Type[TableBase], on: bool, join_type: JoinType = "INNER"):
self.join_clauses.append(join_sql)
return self

@allow_branching
def set(self, column: T, value: T | None):
"""
Sets a column to a specific value in an update query.
Expand All @@ -300,6 +325,7 @@ def set(self, column: T, value: T | None):
self.update_values.append((column, value))
return self

@allow_branching
def limit(self, value: int):
"""
Limit the number of rows returned by the query. Useful in pagination
Expand All @@ -309,6 +335,7 @@ def limit(self, value: int):
self.limit_value = value
return self

@allow_branching
def offset(self, value: int):
"""
Offset the number of rows returned by the query.
Expand All @@ -317,6 +344,7 @@ def offset(self, value: int):
self.offset_value = value
return self

@allow_branching
def group_by(self, *fields: Any):
"""
Groups the results of the query by the given fields. This allows
Expand All @@ -334,6 +362,7 @@ def group_by(self, *fields: Any):
self.group_by_fields = valid_fields
return self

@allow_branching
def having(self, *conditions: bool):
"""
Require the result of an aggregation query like func.sum(MyTable.column)
Expand All @@ -351,6 +380,7 @@ def having(self, *conditions: bool):
self.having_conditions += valid_conditions
return self

@allow_branching
def text(self, query: str, *variables: Any):
"""
Override the ORM builder and use a raw SQL query instead.
Expand Down

0 comments on commit 3b27031

Please sign in to comment.