Skip to content

Commit

Permalink
Improve typehints for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Dec 20, 2024
1 parent 8372b84 commit 8618983
Showing 1 changed file with 45 additions and 6 deletions.
51 changes: 45 additions & 6 deletions iceaxe/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from datetime import datetime
from enum import Enum
from typing import Any, Type, TypeVar, cast
from typing import Any, Literal, Type, TypeVar, cast

from iceaxe.base import (
DBFieldClassDefinition,
Expand All @@ -14,6 +14,45 @@

T = TypeVar("T")

DATE_PART_FIELD = Literal[
"century",
"day",
"decade",
"dow",
"doy",
"epoch",
"hour",
"isodow",
"isoyear",
"microseconds",
"millennium",
"milliseconds",
"minute",
"month",
"quarter",
"second",
"timezone",
"timezone_hour",
"timezone_minute",
"week",
"year",
]
DATE_PRECISION = Literal[
"microseconds",
"milliseconds",
"second",
"minute",
"hour",
"day",
"week",
"month",
"quarter",
"year",
"decade",
"century",
"millennium",
]


class FunctionMetadata(ComparisonBase):
"""
Expand Down Expand Up @@ -230,7 +269,7 @@ def abs(self, field: T) -> T:
metadata.literal = QueryLiteral(f"abs({metadata.literal})")
return cast(T, metadata)

def date_trunc(self, precision: str, field: T) -> T:
def date_trunc(self, precision: DATE_PRECISION, field: T) -> T:
"""
Truncates a timestamp or interval value to specified precision.
Expand All @@ -249,7 +288,7 @@ def date_trunc(self, precision: str, field: T) -> T:
)
return cast(T, metadata)

def date_part(self, field: str, source: Any) -> int:
def date_part(self, field: DATE_PART_FIELD, source: Any) -> float:
"""
Extracts a subfield from a date/time value.
Expand All @@ -264,9 +303,9 @@ def date_part(self, field: str, source: Any) -> int:
"""
metadata = self._column_to_metadata(source)
metadata.literal = QueryLiteral(f"date_part('{field}', {metadata.literal})")
return cast(int, metadata)
return cast(float, metadata)

def extract(self, field: str, source: Any) -> int:
def extract(self, field: DATE_PART_FIELD, source: Any) -> int:
"""
Extracts a subfield from a date/time value using SQL standard syntax.
Expand Down Expand Up @@ -539,7 +578,7 @@ def to_timestamp(self, field: Any, format: str) -> datetime:
"""
metadata = self._column_to_metadata(field)
metadata.literal = QueryLiteral(f"to_timestamp({metadata.literal}, '{format}')")
return cast(T, metadata)
return cast(datetime, metadata)

def _column_to_metadata(self, field: Any) -> FunctionMetadata:
"""
Expand Down

0 comments on commit 8618983

Please sign in to comment.