From c94fbf7703e5a53114e4eea26b890424033424bc Mon Sep 17 00:00:00 2001 From: Damian Owsianny Date: Wed, 11 Jan 2023 17:40:09 +0100 Subject: [PATCH] Add type hints for sqlalchemy --- trino/sqlalchemy/compiler.py | 41 ++++---- trino/sqlalchemy/datatype.py | 20 ++-- trino/sqlalchemy/dialect.py | 196 +++++++++++++++++++++++++++-------- trino/sqlalchemy/util.py | 6 +- 4 files changed, 192 insertions(+), 71 deletions(-) diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 8747d190..41aa4a19 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -9,8 +9,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Optional + from sqlalchemy.sql import compiler from sqlalchemy.sql.base import DialectKWArgs +from sqlalchemy.sql.schema import Table # https://trino.io/docs/current/language/reserved.html RESERVED_WORDS = { @@ -92,7 +95,7 @@ class TrinoSQLCompiler(compiler.SQLCompiler): - def limit_clause(self, select, **kw): + def limit_clause(self, select: Any, **kw: dict[str, Any]) -> str: """ Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax. """ @@ -103,15 +106,15 @@ def limit_clause(self, select, **kw): text += "\nLIMIT " + self.process(select._limit_clause, **kw) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, use_schema=True, **kwargs): + def visit_table(self, table: Table, asfrom: bool = False, iscrud: bool = False, ashint: bool = False, + fromhints: Optional[Any] = None, use_schema: bool = True, **kwargs: Any) -> str: sql = super(TrinoSQLCompiler, self).visit_table( table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs ) return self.add_catalog(sql, table) @staticmethod - def add_catalog(sql, table): + def add_catalog(sql: str, table: Table) -> str: if table is None or not isinstance(table, DialectKWArgs): return sql @@ -131,7 +134,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler): class TrinoTypeCompiler(compiler.GenericTypeCompiler): - def visit_FLOAT(self, type_, **kw): + def visit_FLOAT(self, type_: Any, **kw: dict[str, Any]) -> str: precision = type_.precision or 32 if 0 <= precision <= 32: return self.visit_REAL(type_, **kw) @@ -140,37 +143,37 @@ def visit_FLOAT(self, type_, **kw): else: raise ValueError(f"type.precision must be in range [0, 64], got {type_.precision}") - def visit_DOUBLE(self, type_, **kw): + def visit_DOUBLE(self, type_: Any, **kw: dict[str, Any]) -> str: return "DOUBLE" - def visit_NUMERIC(self, type_, **kw): + def visit_NUMERIC(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_DECIMAL(type_, **kw) - def visit_NCHAR(self, type_, **kw): + def visit_NCHAR(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_CHAR(type_, **kw) - def visit_NVARCHAR(self, type_, **kw): + def visit_NVARCHAR(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_TEXT(self, type_, **kw): + def visit_TEXT(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_BINARY(self, type_, **kw): + def visit_BINARY(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_VARBINARY(type_, **kw) - def visit_CLOB(self, type_, **kw): + def visit_CLOB(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_NCLOB(self, type_, **kw): + def visit_NCLOB(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_BLOB(self, type_, **kw): + def visit_BLOB(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_VARBINARY(type_, **kw) - def visit_DATETIME(self, type_, **kw): + def visit_DATETIME(self, type_: Any, **kw: dict[str, Any]) -> str: return self.visit_TIMESTAMP(type_, **kw) - def visit_TIMESTAMP(self, type_, **kw): + def visit_TIMESTAMP(self, type_: Any, **kw: dict[str, Any]) -> str: datatype = "TIMESTAMP" precision = getattr(type_, "precision", None) if precision not in range(0, 13) and precision is not None: @@ -182,7 +185,7 @@ def visit_TIMESTAMP(self, type_, **kw): return datatype - def visit_TIME(self, type_, **kw): + def visit_TIME(self, type_: Any, **kw: dict[str, Any]) -> str: datatype = "TIME" precision = getattr(type_, "precision", None) if precision not in range(0, 13) and precision is not None: @@ -193,13 +196,13 @@ def visit_TIME(self, type_, **kw): datatype += " WITH TIME ZONE" return datatype - def visit_JSON(self, type_, **kw): + def visit_JSON(self, type_: Any, **kw: dict[str, Any]) -> str: return 'JSON' class TrinoIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS - def format_table(self, table, use_schema=True, name=None): + def format_table(self, table: Table, use_schema: bool = True, name: Optional[str] = None) -> str: result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name) return TrinoSQLCompiler.add_catalog(result, table) diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index 224996e9..05bbd3ad 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -11,14 +11,18 @@ # limitations under the License. import json import re -from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterator, List, Optional +from typing import Text as typing_Text +from typing import Tuple, Type, TypeVar, Union from sqlalchemy import util +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import sqltypes from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine from sqlalchemy.types import String SQLType = Union[TypeEngine, Type[TypeEngine]] +_T = TypeVar('_T') class DOUBLE(sqltypes.Float): @@ -38,7 +42,7 @@ def __init__(self, key_type: SQLType, value_type: SQLType): self.value_type: TypeEngine = value_type @property - def python_type(self): + def python_type(self) -> type: return dict @@ -53,14 +57,14 @@ def __init__(self, attr_types: List[Tuple[Optional[str], SQLType]]): self.attr_types.append((attr_name, attr_type)) @property - def python_type(self): + def python_type(self) -> type: return list class TIME(sqltypes.TIME): __visit_name__ = "TIME" - def __init__(self, precision=None, timezone=False): + def __init__(self, precision: Optional[int] = None, timezone: bool = False): super(TIME, self).__init__(timezone=timezone) self.precision = precision @@ -68,7 +72,7 @@ def __init__(self, precision=None, timezone=False): class TIMESTAMP(sqltypes.TIMESTAMP): __visit_name__ = "TIMESTAMP" - def __init__(self, precision=None, timezone=False): + def __init__(self, precision: Optional[int] = None, timezone: bool = False): super(TIMESTAMP, self).__init__(timezone=timezone) self.precision = precision @@ -76,13 +80,13 @@ def __init__(self, precision=None, timezone=False): class JSON(TypeDecorator): impl = String - def process_bind_param(self, value, dialect): + def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]: return json.dumps(value) - def process_result_value(self, value, dialect): + def process_result_value(self, value: Union[str, bytes], dialect: Dialect) -> Optional[_T]: return json.loads(value) - def get_col_spec(self, **kw): + def get_col_spec(self, **kw: dict[str, Any]) -> str: return 'JSON' diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 7bc4603b..67a9b415 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -11,6 +11,7 @@ # limitations under the License. import json from textwrap import dedent +from types import ModuleType from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union from urllib.parse import unquote_plus @@ -67,7 +68,7 @@ class TrinoDialect(DefaultDialect): cte_follows_insert = True @classmethod - def dbapi(cls): + def dbapi(cls) -> ModuleType: """ ref: https://www.python.org/dev/peps/pep-0249/#module-interface """ @@ -94,9 +95,13 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if url.password: if not url.username: - raise ValueError("Username is required when specify password in connection URL") + raise ValueError( + "Username is required when specify password in connection URL" + ) kwargs["http_scheme"] = "https" - kwargs["auth"] = BasicAuthentication(unquote_plus(url.username), unquote_plus(url.password)) + kwargs["auth"] = BasicAuthentication( + unquote_plus(url.username), unquote_plus(url.password) + ) if "access_token" in url.query: kwargs["http_scheme"] = "https" @@ -104,7 +109,9 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if "cert" and "key" in url.query: kwargs["http_scheme"] = "https" - kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key'])) + kwargs["auth"] = CertificateAuthentication( + unquote_plus(url.query["cert"]), unquote_plus(url.query["key"]) + ) if "source" in url.query: kwargs["source"] = unquote_plus(url.query["source"]) @@ -112,21 +119,28 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any kwargs["source"] = "trino-sqlalchemy" if "session_properties" in url.query: - kwargs["session_properties"] = json.loads(unquote_plus(url.query["session_properties"])) + kwargs["session_properties"] = json.loads( + unquote_plus(url.query["session_properties"]) + ) if "http_headers" in url.query: kwargs["http_headers"] = json.loads(unquote_plus(url.query["http_headers"])) if "extra_credential" in url.query: kwargs["extra_credential"] = [ - tuple(extra_credential) for extra_credential in json.loads(unquote_plus(url.query["extra_credential"])) + tuple(extra_credential) + for extra_credential in json.loads( + unquote_plus(url.query["extra_credential"]) + ) ] if "client_tags" in url.query: kwargs["client_tags"] = json.loads(unquote_plus(url.query["client_tags"])) if "legacy_primitive_types" in url.query: - kwargs["legacy_primitive_types"] = json.loads(unquote_plus(url.query["legacy_primitive_types"])) + kwargs["legacy_primitive_types"] = json.loads( + unquote_plus(url.query["legacy_primitive_types"]) + ) if "verify" in url.query: kwargs["verify"] = json.loads(unquote_plus(url.query["verify"])) @@ -136,12 +150,24 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any return args, kwargs - def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + def get_columns( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> List[Dict[str, Any]]: if not self.has_table(connection, table_name, schema): raise exc.NoSuchTableError(f"schema={schema}, table={table_name}") return self._get_columns(connection, table_name, schema, **kw) - def _get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + def _get_columns( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> List[Dict[str, Any]]: schema = schema or self._get_default_schema_name(connection) query = dedent( """ @@ -156,7 +182,9 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No ORDER BY "ordinal_position" ASC """ ).strip() - res = connection.execute(sql.text(query), {"schema": schema, "table": table_name}) + res = connection.execute( + sql.text(query), {"schema": schema, "table": table_name} + ) columns = [] for record in res: column = dict( @@ -168,21 +196,37 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No columns.append(column) return columns - def get_pk_constraint(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + def get_pk_constraint( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> Dict[str, Any]: """Trino has no support for primary keys. Returns a dummy""" return dict(name=None, constrained_columns=[]) - def get_primary_keys(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[str]: + def get_primary_keys( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> List[str]: pk = self.get_pk_constraint(connection, table_name, schema) return pk.get("constrained_columns") # type: ignore def get_foreign_keys( - self, connection: Connection, table_name: str, schema: str = None, **kw + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], ) -> List[Dict[str, Any]]: """Trino has no support for foreign keys. Returns an empty list.""" return [] - def get_schema_names(self, connection: Connection, **kw) -> List[str]: + def get_schema_names(self, connection: Connection, **kw: dict[str, Any]) -> List[str]: query = dedent( """ SELECT "schema_name" @@ -192,7 +236,9 @@ def get_schema_names(self, connection: Connection, **kw) -> List[str]: res = connection.execute(sql.text(query)) return [row.schema_name for row in res] - def get_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + def get_table_names( + self, connection: Connection, schema: Optional[str] = None, **kw: dict[str, Any] + ) -> List[str]: schema = schema or self._get_default_schema_name(connection) if schema is None: raise exc.NoSuchTableError("schema is required") @@ -207,11 +253,15 @@ def get_table_names(self, connection: Connection, schema: str = None, **kw) -> L res = connection.execute(sql.text(query), {"schema": schema}) return [row.table_name for row in res] - def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + def get_temp_table_names( + self, connection: Connection, schema: Optional[str] = None, **kw: dict[str, Any] + ) -> List[str]: """Trino has no support for temporary tables. Returns an empty list.""" return [] - def get_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + def get_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: dict[str, Any] + ) -> List[str]: schema = schema or self._get_default_schema_name(connection) if schema is None: raise exc.NoSuchTableError("schema is required") @@ -228,11 +278,19 @@ def get_view_names(self, connection: Connection, schema: str = None, **kw) -> Li res = connection.execute(sql.text(query), {"schema": schema}) return [row.table_name for row in res] - def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + def get_temp_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: dict[str, Any] + ) -> List[str]: """Trino has no support for temporary views. Returns an empty list.""" return [] - def get_view_definition(self, connection: Connection, view_name: str, schema: str = None, **kw) -> str: + def get_view_definition( + self, + connection: Connection, + view_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> str: schema = schema or self._get_default_schema_name(connection) if schema is None: raise exc.NoSuchTableError("schema is required") @@ -247,37 +305,61 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st res = connection.execute(sql.text(query), {"schema": schema, "view": view_name}) return res.scalar() - def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + def get_indexes( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> List[Dict[str, Any]]: if not self.has_table(connection, table_name, schema): raise exc.NoSuchTableError(f"schema={schema}, table={table_name}") - partitioned_columns = self._get_columns(connection, f"{table_name}$partitions", schema, **kw) + partitioned_columns = self._get_columns( + connection, f"{table_name}$partitions", schema, **kw + ) if not partitioned_columns: return [] partition_index = dict( name="partition", column_names=[col["name"] for col in partitioned_columns], - unique=False + unique=False, ) return [partition_index] - def get_sequence_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + def get_sequence_names( + self, connection: Connection, schema: Optional[str] = None, **kw: dict[str, Any] + ) -> List[str]: """Trino has no support for sequences. Returns an empty list.""" return [] def get_unique_constraints( - self, connection: Connection, table_name: str, schema: str = None, **kw + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], ) -> List[Dict[str, Any]]: """Trino has no support for unique constraints. Returns an empty list.""" return [] def get_check_constraints( - self, connection: Connection, table_name: str, schema: str = None, **kw + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], ) -> List[Dict[str, Any]]: """Trino has no support for check constraints. Returns an empty list.""" return [] - def get_table_comment(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + def get_table_comment( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> Dict[str, Any]: catalog_name = self._get_default_catalog_name(connection) if catalog_name is None: raise exc.NoSuchTableError("catalog is required in connection") @@ -296,13 +378,15 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str try: res = connection.execute( sql.text(query), - {"catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name} + { + "catalog_name": catalog_name, + "schema_name": schema_name, + "table_name": table_name, + }, ) return dict(text=res.scalar()) except error.TrinoQueryError as e: - if e.error_name in ( - error.PERMISSION_DENIED, - ): + if e.error_name in (error.PERMISSION_DENIED,): return dict(text=None) raise @@ -317,7 +401,13 @@ def has_schema(self, connection: Connection, schema: str) -> bool: res = connection.execute(sql.text(query), {"schema": schema}) return res.first() is not None - def has_table(self, connection: Connection, table_name: str, schema: str = None, **kw) -> bool: + def has_table( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> bool: schema = schema or self._get_default_schema_name(connection) if schema is None: return False @@ -329,10 +419,18 @@ def has_table(self, connection: Connection, table_name: str, schema: str = None, AND "table_name" = :table """ ).strip() - res = connection.execute(sql.text(query), {"schema": schema, "table": table_name}) + res = connection.execute( + sql.text(query), {"schema": schema, "table": table_name} + ) return res.first() is not None - def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool: + def has_sequence( + self, + connection: Connection, + sequence_name: str, + schema: Optional[str] = None, + **kw: dict[str, Any], + ) -> bool: """Trino has no support for sequence. Returns False indicate that given sequence does not exists.""" return False @@ -346,7 +444,9 @@ def _get_server_version_info(self, connection: Connection) -> Any: logger.debug(f"Failed to get server version: {e.orig.message}") return None - def _raw_connection(self, connection: Union[Engine, Connection]) -> trino_dbapi.Connection: + def _raw_connection( + self, connection: Union[Engine, Connection] + ) -> trino_dbapi.Connection: if isinstance(connection, Engine): return connection.raw_connection() return connection.connection @@ -360,15 +460,21 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]: return dbapi_connection.schema def do_execute( - self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None - ): + self, + cursor: Cursor, + statement: str, + parameters: Tuple[Any, ...], + context: DefaultExecutionContext = None, + ) -> None: cursor.execute(statement, parameters) - def do_rollback(self, dbapi_connection: trino_dbapi.Connection): + def do_rollback(self, dbapi_connection: trino_dbapi.Connection) -> None: if dbapi_connection.transaction is not None: dbapi_connection.rollback() - def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level: str) -> None: + def set_isolation_level( + self, dbapi_conn: trino_dbapi.Connection, level: str + ) -> None: dbapi_conn._isolation_level = trino_dbapi.IsolationLevel[level] def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str: @@ -377,10 +483,18 @@ def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str: def get_default_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str: return trino_dbapi.IsolationLevel.AUTOCOMMIT.name - def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str: - table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name + def _get_full_table( + self, table_name: str, schema: Optional[str] = None, quote: bool = True + ) -> str: + table_part = ( + self.identifier_preparer.quote_identifier(table_name) + if quote + else table_name + ) if schema: - schema_part = self.identifier_preparer.quote_identifier(schema) if quote else schema + schema_part = ( + self.identifier_preparer.quote_identifier(schema) if quote else schema + ) return f"{schema_part}.{table_part}" return table_part diff --git a/trino/sqlalchemy/util.py b/trino/sqlalchemy/util.py index cfa9c6b1..eeac94be 100644 --- a/trino/sqlalchemy/util.py +++ b/trino/sqlalchemy/util.py @@ -6,7 +6,7 @@ from sqlalchemy import exc -def _rfc_1738_quote(text): +def _rfc_1738_quote(text: str) -> str: return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) @@ -18,8 +18,8 @@ def _url( catalog: Optional[str] = None, schema: Optional[str] = None, source: Optional[str] = "trino-sqlalchemy", - session_properties: Dict[str, str] = None, - http_headers: Dict[str, Union[str, int]] = None, + session_properties: Optional[Dict[str, str]] = None, + http_headers: Optional[Dict[str, Union[str, int]]] = None, extra_credential: Optional[List[Tuple[str, str]]] = None, client_tags: Optional[List[str]] = None, legacy_primitive_types: Optional[bool] = None,