Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add func-based date modifiers #40

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions iceaxe/__tests__/conf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,31 @@ class EmployeeMetadata(TableBase):
additional_info: dict[str, Any] = Field(is_json=True)


class FunctionTestModel(TableBase):
id: int = Field(primary_key=True, default=None)
balance: float
created_at: str
birth_date: str
start_date: str
end_date: str
year: int
month: int
day: int
hour: int
minute: int
second: int
years: int
months: int
days: int
weeks: int
hours: int
minutes: int
seconds: int
name: str
balance_str: str
timestamp_str: str


@contextmanager
def run_profile(request):
TESTS_ROOT = Path.cwd()
Expand Down
181 changes: 180 additions & 1 deletion iceaxe/__tests__/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from enum import Enum, IntEnum, StrEnum
from typing import TYPE_CHECKING, Literal

import pytest

from iceaxe.__tests__.conf_models import ArtifactDemo, UserDemo
from iceaxe.__tests__.conf_models import ArtifactDemo, FunctionTestModel, UserDemo
from iceaxe.functions import func
from iceaxe.queries import QueryBuilder, and_, or_, select


class UserStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
PENDING = "pending"


def test_select():
new_query = QueryBuilder().select(UserDemo)
assert new_query.build() == (
Expand Down Expand Up @@ -176,6 +183,146 @@ def test_function_distinct():
)


def test_function_abs():
new_query = QueryBuilder().select(func.abs(FunctionTestModel.balance))
assert new_query.build() == (
'SELECT abs("functiontestmodel"."balance") AS aggregate_0 FROM "functiontestmodel"',
[],
)


def test_function_date_trunc():
new_query = QueryBuilder().select(
func.date_trunc("month", FunctionTestModel.created_at)
)
assert new_query.build() == (
'SELECT date_trunc(\'month\', "functiontestmodel"."created_at") AS aggregate_0 FROM "functiontestmodel"',
[],
)


def test_function_date_part():
new_query = QueryBuilder().select(
func.date_part("year", FunctionTestModel.created_at)
)
assert new_query.build() == (
'SELECT date_part(\'year\', "functiontestmodel"."created_at") AS aggregate_0 FROM "functiontestmodel"',
[],
)


def test_function_extract():
new_query = QueryBuilder().select(
func.extract("month", FunctionTestModel.created_at)
)
assert new_query.build() == (
'SELECT extract(month from "functiontestmodel"."created_at") AS aggregate_0 FROM "functiontestmodel"',
[],
)


def test_function_age():
# Test age with single argument
new_query = QueryBuilder().select(func.age(FunctionTestModel.birth_date))
assert new_query.build() == (
'SELECT age("functiontestmodel"."birth_date") AS aggregate_0 FROM "functiontestmodel"',
[],
)

# Test age with two arguments
new_query = QueryBuilder().select(
func.age(FunctionTestModel.end_date, FunctionTestModel.start_date)
)
assert new_query.build() == (
'SELECT age("functiontestmodel"."end_date", "functiontestmodel"."start_date") AS aggregate_0 FROM "functiontestmodel"',
[],
)


def test_function_date():
new_query = QueryBuilder().select(func.date(FunctionTestModel.created_at))
assert new_query.build() == (
'SELECT date("functiontestmodel"."created_at") AS aggregate_0 FROM "functiontestmodel"',
[],
)


def test_function_transformations():
# Test string functions
new_query = QueryBuilder().select(
(
func.lower(FunctionTestModel.name),
func.upper(FunctionTestModel.name),
func.length(FunctionTestModel.name),
func.trim(FunctionTestModel.name),
func.substring(FunctionTestModel.name, 1, 3),
)
)
assert new_query.build() == (
'SELECT lower("functiontestmodel"."name") AS aggregate_0, '
'upper("functiontestmodel"."name") AS aggregate_1, '
'length("functiontestmodel"."name") AS aggregate_2, '
'trim("functiontestmodel"."name") AS aggregate_3, '
'substring("functiontestmodel"."name" from 1 for 3) AS aggregate_4 '
'FROM "functiontestmodel"',
[],
)

# Test mathematical functions
new_query = QueryBuilder().select(
(
func.round(FunctionTestModel.balance),
func.ceil(FunctionTestModel.balance),
func.floor(FunctionTestModel.balance),
func.power(FunctionTestModel.balance, 2),
func.sqrt(FunctionTestModel.balance),
)
)
assert new_query.build() == (
'SELECT round("functiontestmodel"."balance") AS aggregate_0, '
'ceil("functiontestmodel"."balance") AS aggregate_1, '
'floor("functiontestmodel"."balance") AS aggregate_2, '
'power("functiontestmodel"."balance", 2) AS aggregate_3, '
'sqrt("functiontestmodel"."balance") AS aggregate_4 '
'FROM "functiontestmodel"',
[],
)

# Test aggregate functions
new_query = QueryBuilder().select(
(
func.array_agg(FunctionTestModel.name),
func.string_agg(FunctionTestModel.name, ","),
)
)
assert new_query.build() == (
'SELECT array_agg("functiontestmodel"."name") AS aggregate_0, '
'string_agg("functiontestmodel"."name", \',\') AS aggregate_1 '
'FROM "functiontestmodel"',
[],
)

# Test type conversion functions
new_query = QueryBuilder().select(
(
func.cast(FunctionTestModel.balance, int),
func.cast(FunctionTestModel.name, UserStatus),
func.to_char(FunctionTestModel.created_at, "YYYY-MM-DD"),
func.to_number(FunctionTestModel.balance_str, "999999.99"),
func.to_timestamp(FunctionTestModel.timestamp_str, "YYYY-MM-DD HH24:MI:SS"),
)
)
assert new_query.build() == (
'SELECT cast("functiontestmodel"."balance" as integer) AS aggregate_0, '
'cast("functiontestmodel"."name" as userstatus) AS aggregate_1, '
'to_char("functiontestmodel"."created_at", \'YYYY-MM-DD\') AS aggregate_2, '
'to_number("functiontestmodel"."balance_str", \'999999.99\') AS aggregate_3, '
'to_timestamp("functiontestmodel"."timestamp_str", \'YYYY-MM-DD HH24:MI:SS\') AS aggregate_4 '
'FROM "functiontestmodel"',
[],
)


def test_invalid_where_condition():
with pytest.raises(ValueError):
QueryBuilder().select(UserDemo.id).where("invalid condition") # type: ignore
Expand Down Expand Up @@ -370,3 +517,35 @@ def test_for_update_multiple_of():
"FOR UPDATE OF artifactdemo, userdemo",
[],
)


def test_function_cast_enum():
"""
Test casting to enum types.
"""

class UserStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
PENDING = "pending"

class UserLevel(IntEnum):
BASIC = 1
PREMIUM = 2
VIP = 3

# Test casting to StrEnum
new_query = QueryBuilder().select(func.cast(FunctionTestModel.name, UserStatus))
assert new_query.build() == (
'SELECT cast("functiontestmodel"."name" as userstatus) AS aggregate_0 '
'FROM "functiontestmodel"',
[],
)

# Test casting to IntEnum
new_query = QueryBuilder().select(func.cast(FunctionTestModel.balance, UserLevel))
assert new_query.build() == (
'SELECT cast("functiontestmodel"."balance" as userlevel) AS aggregate_0 '
'FROM "functiontestmodel"',
[],
)
Loading
Loading