Skip to content

Commit

Permalink
Add type hints for sqlalchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
damian3031 committed Jan 13, 2023
1 parent ea55fc0 commit 87e70bb
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 71 deletions.
40 changes: 21 additions & 19 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

from sqlalchemy.sql import compiler
from sqlalchemy.sql.base import DialectKWArgs

Expand Down Expand Up @@ -92,7 +94,7 @@


class TrinoSQLCompiler(compiler.SQLCompiler):
def limit_clause(self, select, **kw):
def limit_clause(self, select: Any, **kw: Any) -> str:
"""
Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax.
"""
Expand All @@ -103,15 +105,15 @@ def limit_clause(self, select, **kw):
text += "\nLIMIT " + self.process(select._limit_clause, **kw)
return text

def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
def visit_table(self, table: Any, asfrom: bool = False, iscrud: bool = False, ashint: bool = False,
fromhints: Optional[Any] = None, use_schema: bool = True, **kwargs: Any) -> str:
sql = super(TrinoSQLCompiler, self).visit_table(
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
)
return self.add_catalog(sql, table)

@staticmethod
def add_catalog(sql, table):
def add_catalog(sql: str, table: Any) -> str:
if table is None or not isinstance(table, DialectKWArgs):
return sql

Expand All @@ -131,7 +133,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler):


class TrinoTypeCompiler(compiler.GenericTypeCompiler):
def visit_FLOAT(self, type_, **kw):
def visit_FLOAT(self, type_: Any, **kw: Any) -> str:
precision = type_.precision or 32
if 0 <= precision <= 32:
return self.visit_REAL(type_, **kw)
Expand All @@ -140,37 +142,37 @@ def visit_FLOAT(self, type_, **kw):
else:
raise ValueError(f"type.precision must be in range [0, 64], got {type_.precision}")

def visit_DOUBLE(self, type_, **kw):
def visit_DOUBLE(self, type_: Any, **kw: Any) -> str:
return "DOUBLE"

def visit_NUMERIC(self, type_, **kw):
def visit_NUMERIC(self, type_: Any, **kw: Any) -> str:
return self.visit_DECIMAL(type_, **kw)

def visit_NCHAR(self, type_, **kw):
def visit_NCHAR(self, type_: Any, **kw: Any) -> str:
return self.visit_CHAR(type_, **kw)

def visit_NVARCHAR(self, type_, **kw):
def visit_NVARCHAR(self, type_: Any, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)

def visit_TEXT(self, type_, **kw):
def visit_TEXT(self, type_: Any, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)

def visit_BINARY(self, type_, **kw):
def visit_BINARY(self, type_: Any, **kw: Any) -> str:
return self.visit_VARBINARY(type_, **kw)

def visit_CLOB(self, type_, **kw):
def visit_CLOB(self, type_: Any, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)

def visit_NCLOB(self, type_, **kw):
def visit_NCLOB(self, type_: Any, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)

def visit_BLOB(self, type_, **kw):
def visit_BLOB(self, type_: Any, **kw: Any) -> str:
return self.visit_VARBINARY(type_, **kw)

def visit_DATETIME(self, type_, **kw):
def visit_DATETIME(self, type_: Any, **kw: Any) -> str:
return self.visit_TIMESTAMP(type_, **kw)

def visit_TIMESTAMP(self, type_, **kw):
def visit_TIMESTAMP(self, type_: Any, **kw: Any) -> str:
datatype = "TIMESTAMP"
precision = getattr(type_, "precision", None)
if precision not in range(0, 13) and precision is not None:
Expand All @@ -182,7 +184,7 @@ def visit_TIMESTAMP(self, type_, **kw):

return datatype

def visit_TIME(self, type_, **kw):
def visit_TIME(self, type_: Any, **kw: Any) -> str:
datatype = "TIME"
precision = getattr(type_, "precision", None)
if precision not in range(0, 13) and precision is not None:
Expand All @@ -193,13 +195,13 @@ def visit_TIME(self, type_, **kw):
datatype += " WITH TIME ZONE"
return datatype

def visit_JSON(self, type_, **kw):
def visit_JSON(self, type_: Any, **kw: Any) -> str:
return 'JSON'


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS

def format_table(self, table, use_schema=True, name=None):
def format_table(self, table: Any, use_schema: bool = True, name: Optional[str] = None) -> str:
result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name)
return TrinoSQLCompiler.add_catalog(result, table)
20 changes: 12 additions & 8 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@
# limitations under the License.
import json
import re
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional
from typing import Text as typing_Text
from typing import Tuple, Type, TypeVar, Union

from sqlalchemy import util
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
from sqlalchemy.types import String

SQLType = Union[TypeEngine, Type[TypeEngine]]
_T = TypeVar('_T')


class DOUBLE(sqltypes.Float):
Expand All @@ -38,7 +42,7 @@ def __init__(self, key_type: SQLType, value_type: SQLType):
self.value_type: TypeEngine = value_type

@property
def python_type(self):
def python_type(self) -> type:
return dict


Expand All @@ -53,36 +57,36 @@ def __init__(self, attr_types: List[Tuple[Optional[str], SQLType]]):
self.attr_types.append((attr_name, attr_type))

@property
def python_type(self):
def python_type(self) -> type:
return list


class TIME(sqltypes.TIME):
__visit_name__ = "TIME"

def __init__(self, precision=None, timezone=False):
def __init__(self, precision: Optional[int] = None, timezone: bool = False):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision


class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"

def __init__(self, precision=None, timezone=False):
def __init__(self, precision: Optional[int] = None, timezone: bool = False):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision


class JSON(TypeDecorator):
impl = String

def process_bind_param(self, value, dialect):
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]:
return json.dumps(value)

def process_result_value(self, value, dialect):
def process_result_value(self, value: Union[str, bytes], dialect: Dialect) -> Optional[_T]:
return json.loads(value)

def get_col_spec(self, **kw):
def get_col_spec(self, **kw: Any) -> str:
return 'JSON'


Expand Down
Loading

0 comments on commit 87e70bb

Please sign in to comment.