Skip to content

Commit

Permalink
Add func-based date modifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Dec 20, 2024
1 parent c28eb5a commit 7931aad
Show file tree
Hide file tree
Showing 2 changed files with 778 additions and 0 deletions.
233 changes: 233 additions & 0 deletions iceaxe/__tests__/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,239 @@ def test_function_distinct():
)


def test_function_abs():
new_query = QueryBuilder().select(func.abs(UserDemo.balance))
assert new_query.build() == (
'SELECT abs("userdemo"."balance") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_date_trunc():
new_query = QueryBuilder().select(func.date_trunc('month', UserDemo.created_at))
assert new_query.build() == (
'SELECT date_trunc(\'month\', "userdemo"."created_at") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_date_part():
new_query = QueryBuilder().select(func.date_part('year', UserDemo.created_at))
assert new_query.build() == (
'SELECT date_part(\'year\', "userdemo"."created_at") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_extract():
new_query = QueryBuilder().select(func.extract('month', UserDemo.created_at))
assert new_query.build() == (
'SELECT extract(month from "userdemo"."created_at") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_age():
# Test age with single argument
new_query = QueryBuilder().select(func.age(UserDemo.birth_date))
assert new_query.build() == (
'SELECT age("userdemo"."birth_date") AS aggregate_0 FROM "userdemo"',
[],
)

# Test age with two arguments
new_query = QueryBuilder().select(func.age(UserDemo.end_date, UserDemo.start_date))
assert new_query.build() == (
'SELECT age("userdemo"."end_date", "userdemo"."start_date") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_current_date():
new_query = QueryBuilder().select(func.current_date())
assert new_query.build() == (
'SELECT current_date AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_current_time():
new_query = QueryBuilder().select(func.current_time())
assert new_query.build() == (
'SELECT current_time AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_current_timestamp():
new_query = QueryBuilder().select(func.current_timestamp())
assert new_query.build() == (
'SELECT current_timestamp AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_date():
new_query = QueryBuilder().select(func.date(UserDemo.created_at))
assert new_query.build() == (
'SELECT date("userdemo"."created_at") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_make_date():
new_query = QueryBuilder().select(func.make_date(UserDemo.year, UserDemo.month, UserDemo.day))
assert new_query.build() == (
'SELECT make_date("userdemo"."year", "userdemo"."month", "userdemo"."day") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_make_time():
new_query = QueryBuilder().select(func.make_time(UserDemo.hour, UserDemo.minute, UserDemo.second))
assert new_query.build() == (
'SELECT make_time("userdemo"."hour", "userdemo"."minute", "userdemo"."second") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_make_timestamp():
new_query = QueryBuilder().select(
func.make_timestamp(
UserDemo.year,
UserDemo.month,
UserDemo.day,
UserDemo.hour,
UserDemo.minute,
UserDemo.second
)
)
assert new_query.build() == (
'SELECT make_timestamp("userdemo"."year", "userdemo"."month", "userdemo"."day", '
'"userdemo"."hour", "userdemo"."minute", "userdemo"."second") AS aggregate_0 FROM "userdemo"',
[],
)


def test_function_make_interval():
# Test with some components
new_query = QueryBuilder().select(
func.make_interval(years=UserDemo.years, months=UserDemo.months, days=UserDemo.days)
)
assert new_query.build() == (
'SELECT make_interval(years => "userdemo"."years", months => "userdemo"."months", '
'days => "userdemo"."days") AS aggregate_0 FROM "userdemo"',
[],
)

# Test with all components
new_query = QueryBuilder().select(
func.make_interval(
years=UserDemo.years,
months=UserDemo.months,
weeks=UserDemo.weeks,
days=UserDemo.days,
hours=UserDemo.hours,
mins=UserDemo.minutes,
secs=UserDemo.seconds
)
)
assert new_query.build() == (
'SELECT make_interval(years => "userdemo"."years", months => "userdemo"."months", '
'weeks => "userdemo"."weeks", days => "userdemo"."days", hours => "userdemo"."hours", '
'mins => "userdemo"."minutes", secs => "userdemo"."seconds") AS aggregate_0 FROM "userdemo"',
[],
)

# Test with no components should raise ValueError
with pytest.raises(ValueError):
QueryBuilder().select(func.make_interval())


def test_function_transformations():
# Test string functions
new_query = QueryBuilder().select((
func.lower(UserDemo.name),
func.upper(UserDemo.name),
func.length(UserDemo.name),
func.trim(UserDemo.name),
func.substring(UserDemo.name, 1, 3),
))
assert new_query.build() == (
'SELECT lower("userdemo"."name") AS aggregate_0, '
'upper("userdemo"."name") AS aggregate_1, '
'length("userdemo"."name") AS aggregate_2, '
'trim("userdemo"."name") AS aggregate_3, '
'substring("userdemo"."name" from 1 for 3) AS aggregate_4 '
'FROM "userdemo"',
[],
)

# Test mathematical functions
new_query = QueryBuilder().select((
func.round(UserDemo.balance),
func.ceil(UserDemo.balance),
func.floor(UserDemo.balance),
func.power(UserDemo.balance, 2),
func.sqrt(UserDemo.balance),
))
assert new_query.build() == (
'SELECT round("userdemo"."balance") AS aggregate_0, '
'ceil("userdemo"."balance") AS aggregate_1, '
'floor("userdemo"."balance") AS aggregate_2, '
'power("userdemo"."balance", 2) AS aggregate_3, '
'sqrt("userdemo"."balance") AS aggregate_4 '
'FROM "userdemo"',
[],
)

# Test aggregate functions
new_query = QueryBuilder().select((
func.array_agg(UserDemo.name),
func.string_agg(UserDemo.name, ','),
))
assert new_query.build() == (
'SELECT array_agg("userdemo"."name") AS aggregate_0, '
'string_agg("userdemo"."name", \',\') AS aggregate_1 '
'FROM "userdemo"',
[],
)

# Test window functions
new_query = QueryBuilder().select((
func.row_number().over(),
func.rank().over(),
func.dense_rank().over(),
func.lag(UserDemo.balance).over(),
func.lead(UserDemo.balance).over(),
))
assert new_query.build() == (
'SELECT row_number() OVER () AS aggregate_0, '
'rank() OVER () AS aggregate_1, '
'dense_rank() OVER () AS aggregate_2, '
'lag("userdemo"."balance") OVER () AS aggregate_3, '
'lead("userdemo"."balance") OVER () AS aggregate_4 '
'FROM "userdemo"',
[],
)

# Test type conversion functions
new_query = QueryBuilder().select((
func.cast(UserDemo.balance, 'integer'),
func.to_char(UserDemo.created_at, 'YYYY-MM-DD'),
func.to_number(UserDemo.balance_str, '999999.99'),
func.to_timestamp(UserDemo.timestamp_str, 'YYYY-MM-DD HH24:MI:SS'),
))
assert new_query.build() == (
'SELECT cast("userdemo"."balance" as integer) AS aggregate_0, '
'to_char("userdemo"."created_at", \'YYYY-MM-DD\') AS aggregate_1, '
'to_number("userdemo"."balance_str", \'999999.99\') AS aggregate_2, '
'to_timestamp("userdemo"."timestamp_str", \'YYYY-MM-DD HH24:MI:SS\') AS aggregate_3 '
'FROM "userdemo"',
[],
)


def test_invalid_where_condition():
with pytest.raises(ValueError):
QueryBuilder().select(UserDemo.id).where("invalid condition") # type: ignore
Expand Down
Loading

0 comments on commit 7931aad

Please sign in to comment.