Skip to content

Commit

Permalink
fix: extract column type (#11)
Browse files Browse the repository at this point in the history
* fix: extract column type

* version 043
  • Loading branch information
hantmac authored Dec 26, 2023
1 parent 6ee18da commit 20833a5
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 32 deletions.
2 changes: 1 addition & 1 deletion databend_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python


VERSION = (0, 4, 2)
VERSION = (0, 4, 3)
__version__ = ".".join(str(x) for x in VERSION)
33 changes: 19 additions & 14 deletions databend_sqlalchemy/databend_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def visit_notin_op_binary(self, binary, operator, **kw):
)

def visit_column(
self, column, add_to_result_map=None, include_table=True, **kwargs
self, column, add_to_result_map=None, include_table=True, **kwargs
):
# Columns prefixed with table name are not supported
return super(DatabendCompiler, self).visit_column(
Expand Down Expand Up @@ -233,7 +233,7 @@ class DatabendDialect(default.DefaultDialect):
_backslash_escapes = True

def __init__(
self, context: Optional[ExecutionContext] = None, *args: Any, **kwargs: Any
self, context: Optional[ExecutionContext] = None, *args: Any, **kwargs: Any
):
super(DatabendDialect, self).__init__(*args, **kwargs)
self.context: Union[ExecutionContext, Dict] = context or {}
Expand All @@ -257,7 +257,7 @@ def create_connect_args(self, url):
parameters = dict(url.query)
kwargs = {
"dsn": "databend://%s:%s@%s:%d/%s"
% (url.username, url.password, url.host, url.port or 8000, url.database),
% (url.username, url.password, url.host, url.port or 8000, url.database),
}
for k, v in parameters.items():
kwargs["dsn"] = kwargs["dsn"] + "?" + k + "=" + v
Expand Down Expand Up @@ -309,23 +309,13 @@ def get_columns(self, connection, table_name, schema=None, **kw):
return [
{
"name": row[0],
"type": ischema_names[self.extract_nullable_string(row[1]).lower()],
"type": ischema_names[extract_nullable_string(row[1]).lower()],
"nullable": get_is_nullable(row[2]),
"default": None,
}
for row in result
]

def extract_nullable_string(self, target):
if "Nullable" in target:
match = re.match(r"Nullable\(([^)]+)\)", target)
if match:
return match.group(1)
else:
return ""
else:
return target

@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# No support for foreign keys.
Expand Down Expand Up @@ -369,3 +359,18 @@ def _check_unicode_description(self, connection):

def get_is_nullable(column_is_nullable: str) -> bool:
return column_is_nullable == "YES"


def extract_nullable_string(target):
pattern = r'Nullable\((\w+)(?:\((.*?)\))?\)'
if "Nullable" in target:
match = re.match(pattern, target)
if match:
return match.group(1)
else:
return ""
else:
sl = target.split("(")
if len(sl) > 0:
return sl[0]
return target
46 changes: 29 additions & 17 deletions tests/unit/test_databend_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def test_create_dialect(self, dialect: DatabendDialect):
# assert result_dict["username"] == "user"

def test_do_execute(
self, dialect: DatabendDialect, cursor: mock.Mock(spec=MockCursor)
self, dialect: DatabendDialect, cursor: mock.Mock(spec=MockCursor)
):
dialect.do_execute(cursor, "SELECT *", None)
cursor.execute.assert_called_once_with("SELECT *", None)
cursor.execute.reset_mock()
dialect.do_execute(cursor, "SELECT *", (1, 22), None)

def test_table_names(
self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi)
self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi)
):
connection.execute.return_value = [
("table1",),
Expand All @@ -74,18 +74,18 @@ def test_table_names(
)

def test_view_names(
self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi)
self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi)
):
connection.execute.return_value = []
assert dialect.get_view_names(connection) == []

def test_indexes(
self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi)
self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi)
):
assert dialect.get_indexes(connection, "table") == []

def test_columns(
self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi)
self, dialect: DatabendDialect, connection: mock.Mock(spec=MockDBApi)
):
def multi_column_row(columns):
def getitem(self, idx):
Expand All @@ -109,11 +109,11 @@ def getitem(self, idx):
expected_query_schema = expected_query + " and table_schema = 'schema'"

for call, expected_query in (
(lambda: dialect.get_columns(connection, "table"), expected_query),
(
lambda: dialect.get_columns(connection, "table", "schema"),
expected_query_schema,
),
(lambda: dialect.get_columns(connection, "table"), expected_query),
(
lambda: dialect.get_columns(connection, "table", "schema"),
expected_query_schema,
),
):
assert call() == [
{
Expand Down Expand Up @@ -145,24 +145,36 @@ def test_types():
assert databend_sqlalchemy.databend_dialect.CHAR is sqlalchemy.sql.sqltypes.CHAR
assert databend_sqlalchemy.databend_dialect.DATE is sqlalchemy.sql.sqltypes.DATE
assert (
databend_sqlalchemy.databend_dialect.DATETIME
is sqlalchemy.sql.sqltypes.DATETIME
databend_sqlalchemy.databend_dialect.DATETIME
is sqlalchemy.sql.sqltypes.DATETIME
)
assert (
databend_sqlalchemy.databend_dialect.INTEGER is sqlalchemy.sql.sqltypes.INTEGER
databend_sqlalchemy.databend_dialect.INTEGER is sqlalchemy.sql.sqltypes.INTEGER
)
assert databend_sqlalchemy.databend_dialect.BIGINT is sqlalchemy.sql.sqltypes.BIGINT
assert (
databend_sqlalchemy.databend_dialect.TIMESTAMP
is sqlalchemy.sql.sqltypes.TIMESTAMP
databend_sqlalchemy.databend_dialect.TIMESTAMP
is sqlalchemy.sql.sqltypes.TIMESTAMP
)
assert (
databend_sqlalchemy.databend_dialect.VARCHAR is sqlalchemy.sql.sqltypes.VARCHAR
databend_sqlalchemy.databend_dialect.VARCHAR is sqlalchemy.sql.sqltypes.VARCHAR
)
assert (
databend_sqlalchemy.databend_dialect.BOOLEAN is sqlalchemy.sql.sqltypes.BOOLEAN
databend_sqlalchemy.databend_dialect.BOOLEAN is sqlalchemy.sql.sqltypes.BOOLEAN
)
assert databend_sqlalchemy.databend_dialect.FLOAT is sqlalchemy.sql.sqltypes.FLOAT
assert issubclass(
databend_sqlalchemy.databend_dialect.ARRAY, sqlalchemy.types.TypeEngine
)


def test_extract_nullable_string():
types = ["INT", "FLOAT", "Nullable(INT)", "Nullable(Decimal(2,4))", "Nullable(Array(INT))",
"Nullable(Map(String, String))", "Decimal(1,2)"]
expected_types = ["int", "float", "int", "decimal", "array", "map", "decimal"]
i = 0
for t in types:
true_type = databend_sqlalchemy.databend_dialect.extract_nullable_string(t).lower()
assert expected_types[i] == true_type
i += 1
print(true_type)

0 comments on commit 20833a5

Please sign in to comment.