From 20833a5fd15c08b0544fd4dd444b000dc8dde32d Mon Sep 17 00:00:00 2001 From: Jeremy Date: Tue, 26 Dec 2023 17:29:31 +0800 Subject: [PATCH] fix: extract column type (#11) * fix: extract column type * version 043 --- databend_sqlalchemy/__init__.py | 2 +- databend_sqlalchemy/databend_dialect.py | 33 ++++++++++-------- tests/unit/test_databend_dialect.py | 46 ++++++++++++++++--------- 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/databend_sqlalchemy/__init__.py b/databend_sqlalchemy/__init__.py index 17767b1..f5f89f8 100644 --- a/databend_sqlalchemy/__init__.py +++ b/databend_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -VERSION = (0, 4, 2) +VERSION = (0, 4, 3) __version__ = ".".join(str(x) for x in VERSION) diff --git a/databend_sqlalchemy/databend_dialect.py b/databend_sqlalchemy/databend_dialect.py index 592876c..ecb0047 100644 --- a/databend_sqlalchemy/databend_dialect.py +++ b/databend_sqlalchemy/databend_dialect.py @@ -158,7 +158,7 @@ def visit_notin_op_binary(self, binary, operator, **kw): ) def visit_column( - self, column, add_to_result_map=None, include_table=True, **kwargs + self, column, add_to_result_map=None, include_table=True, **kwargs ): # Columns prefixed with table name are not supported return super(DatabendCompiler, self).visit_column( @@ -233,7 +233,7 @@ class DatabendDialect(default.DefaultDialect): _backslash_escapes = True def __init__( - self, context: Optional[ExecutionContext] = None, *args: Any, **kwargs: Any + self, context: Optional[ExecutionContext] = None, *args: Any, **kwargs: Any ): super(DatabendDialect, self).__init__(*args, **kwargs) self.context: Union[ExecutionContext, Dict] = context or {} @@ -257,7 +257,7 @@ def create_connect_args(self, url): parameters = dict(url.query) kwargs = { "dsn": "databend://%s:%s@%s:%d/%s" - % (url.username, url.password, url.host, url.port or 8000, url.database), + % (url.username, url.password, url.host, url.port or 8000, url.database), } for k, v in parameters.items(): kwargs["dsn"] = kwargs["dsn"] + "?" + k + "=" + v @@ -309,23 +309,13 @@ def get_columns(self, connection, table_name, schema=None, **kw): return [ { "name": row[0], - "type": ischema_names[self.extract_nullable_string(row[1]).lower()], + "type": ischema_names[extract_nullable_string(row[1]).lower()], "nullable": get_is_nullable(row[2]), "default": None, } for row in result ] - def extract_nullable_string(self, target): - if "Nullable" in target: - match = re.match(r"Nullable\(([^)]+)\)", target) - if match: - return match.group(1) - else: - return "" - else: - return target - @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): # No support for foreign keys. @@ -369,3 +359,18 @@ def _check_unicode_description(self, connection): def get_is_nullable(column_is_nullable: str) -> bool: return column_is_nullable == "YES" + + +def extract_nullable_string(target): + pattern = r'Nullable\((\w+)(?:\((.*?)\))?\)' + if "Nullable" in target: + match = re.match(pattern, target) + if match: + return match.group(1) + else: + return "" + else: + sl = target.split("(") + if len(sl) > 0: + return sl[0] + return target diff --git a/tests/unit/test_databend_dialect.py b/tests/unit/test_databend_dialect.py index a7800a8..7f0800d 100644 --- a/tests/unit/test_databend_dialect.py +++ b/tests/unit/test_databend_dialect.py @@ -41,7 +41,7 @@ def test_create_dialect(self, dialect: DatabendDialect): # assert result_dict["username"] == "user" def test_do_execute( - self, dialect: DatabendDialect, cursor: mock.Mock(spec=MockCursor) + self, dialect: DatabendDialect, cursor: mock.Mock(spec=MockCursor) ): dialect.do_execute(cursor, "SELECT *", None) cursor.execute.assert_called_once_with("SELECT *", None) @@ -49,7 +49,7 @@ def test_do_execute( dialect.do_execute(cursor, "SELECT *", (1, 22), None) def test_table_names( - self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi) + self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi) ): connection.execute.return_value = [ ("table1",), @@ -74,18 +74,18 @@ def test_table_names( ) def test_view_names( - self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi) + self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi) ): connection.execute.return_value = [] assert dialect.get_view_names(connection) == [] def test_indexes( - self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi) + self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi) ): assert dialect.get_indexes(connection, "table") == [] def test_columns( - self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi) + self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi) ): def multi_column_row(columns): def getitem(self, idx): @@ -109,11 +109,11 @@ def getitem(self, idx): expected_query_schema = expected_query + " and table_schema = 'schema'" for call, expected_query in ( - (lambda: dialect.get_columns(connection, "table"), expected_query), - ( - lambda: dialect.get_columns(connection, "table", "schema"), - expected_query_schema, - ), + (lambda: dialect.get_columns(connection, "table"), expected_query), + ( + lambda: dialect.get_columns(connection, "table", "schema"), + expected_query_schema, + ), ): assert call() == [ { @@ -145,24 +145,36 @@ def test_types(): assert databend_sqlalchemy.databend_dialect.CHAR is sqlalchemy.sql.sqltypes.CHAR assert databend_sqlalchemy.databend_dialect.DATE is sqlalchemy.sql.sqltypes.DATE assert ( - databend_sqlalchemy.databend_dialect.DATETIME - is sqlalchemy.sql.sqltypes.DATETIME + databend_sqlalchemy.databend_dialect.DATETIME + is sqlalchemy.sql.sqltypes.DATETIME ) assert ( - databend_sqlalchemy.databend_dialect.INTEGER is sqlalchemy.sql.sqltypes.INTEGER + databend_sqlalchemy.databend_dialect.INTEGER is sqlalchemy.sql.sqltypes.INTEGER ) assert databend_sqlalchemy.databend_dialect.BIGINT is sqlalchemy.sql.sqltypes.BIGINT assert ( - databend_sqlalchemy.databend_dialect.TIMESTAMP - is sqlalchemy.sql.sqltypes.TIMESTAMP + databend_sqlalchemy.databend_dialect.TIMESTAMP + is sqlalchemy.sql.sqltypes.TIMESTAMP ) assert ( - databend_sqlalchemy.databend_dialect.VARCHAR is sqlalchemy.sql.sqltypes.VARCHAR + databend_sqlalchemy.databend_dialect.VARCHAR is sqlalchemy.sql.sqltypes.VARCHAR ) assert ( - databend_sqlalchemy.databend_dialect.BOOLEAN is sqlalchemy.sql.sqltypes.BOOLEAN + databend_sqlalchemy.databend_dialect.BOOLEAN is sqlalchemy.sql.sqltypes.BOOLEAN ) assert databend_sqlalchemy.databend_dialect.FLOAT is sqlalchemy.sql.sqltypes.FLOAT assert issubclass( databend_sqlalchemy.databend_dialect.ARRAY, sqlalchemy.types.TypeEngine ) + + +def test_extract_nullable_string(): + types = ["INT", "FLOAT", "Nullable(INT)", "Nullable(Decimal(2,4))", "Nullable(Array(INT))", + "Nullable(Map(String, String))", "Decimal(1,2)"] + expected_types = ["int", "float", "int", "decimal", "array", "map", "decimal"] + i = 0 + for t in types: + true_type = databend_sqlalchemy.databend_dialect.extract_nullable_string(t).lower() + assert expected_types[i] == true_type + i += 1 + print(true_type)