Skip to content

Commit

Permalink
Add support for ROW and ARRAY in TrinoTypeCompiler
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco authored and hashhar committed Jun 21, 2024
1 parent bb35f1c commit 670d5f7
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 1 deletion.
106 changes: 105 additions & 1 deletion tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import pytest
import sqlalchemy as sqla
from sqlalchemy.sql import and_, not_, or_
from sqlalchemy.types import ARRAY

from tests.integration.conftest import trino_version
from tests.unit.conftest import sqlalchemy_version
from trino.sqlalchemy.datatype import JSON, MAP
from trino.sqlalchemy.datatype import JSON, MAP, ROW


@pytest.fixture
Expand Down Expand Up @@ -528,6 +529,109 @@ def test_map_column(trino_connection, map_object, sqla_type):
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize(
'trino_connection,array_object,sqla_type',
[
('memory', None, ARRAY(sqla.sql.sqltypes.String)),
('memory', [], ARRAY(sqla.sql.sqltypes.String)),
('memory', [True, False, True], ARRAY(sqla.sql.sqltypes.Boolean)),
('memory', [1, 2, None], ARRAY(sqla.sql.sqltypes.Integer)),
('memory', [1.4, 2.3, math.inf], ARRAY(sqla.sql.sqltypes.Float)),
('memory', [Decimal("1.2"), Decimal("2.3")], ARRAY(sqla.sql.sqltypes.DECIMAL(2, 1))),
('memory', ["hello", "world"], ARRAY(sqla.sql.sqltypes.String)),
('memory', ["a ", "null"], ARRAY(sqla.sql.sqltypes.CHAR(4))),
('memory', [b'eh?', None, b'\x00'], ARRAY(sqla.sql.sqltypes.BINARY)),
],
indirect=['trino_connection']
)
def test_array_column(trino_connection, array_object, sqla_type):
engine, conn = trino_connection

if not engine.dialect.has_schema(conn, "test"):
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()

try:
table_with_array = sqla.Table(
'table_with_array',
metadata,
sqla.Column('id', sqla.Integer),
sqla.Column('array_column', sqla_type),
schema="test"
)
metadata.create_all(engine)
ins = table_with_array.insert()
conn.execute(ins, {"id": 1, "array_column": array_object})
query = sqla.select(table_with_array)
result = conn.execute(query)
rows = result.fetchall()
assert len(rows) == 1
assert rows[0] == (1, array_object)
finally:
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize(
'trino_connection,row_object,sqla_type',
[
('memory', None, ROW([('field1', sqla.sql.sqltypes.String),
('field2', sqla.sql.sqltypes.String)])),
('memory', ('hello', 'world'), ROW([('field1', sqla.sql.sqltypes.String),
('field2', sqla.sql.sqltypes.String)])),
('memory', (True, False), ROW([('field1', sqla.sql.sqltypes.Boolean),
('field2', sqla.sql.sqltypes.Boolean)])),
('memory', (1, 2), ROW([('field1', sqla.sql.sqltypes.Integer),
('field2', sqla.sql.sqltypes.Integer)])),
('memory', (1.4, float('inf')), ROW([('field1', sqla.sql.sqltypes.Float),
('field2', sqla.sql.sqltypes.Float)])),
('memory', (Decimal("1.2"), Decimal("2.3")), ROW([('field1', sqla.sql.sqltypes.DECIMAL(2, 1)),
('field2', sqla.sql.sqltypes.DECIMAL(3, 1))])),
('memory', ("hello", "world"), ROW([('field1', sqla.sql.sqltypes.String),
('field2', sqla.sql.sqltypes.String)])),
('memory', ("a ", "null"), ROW([('field1', sqla.sql.sqltypes.CHAR(4)),
('field2', sqla.sql.sqltypes.CHAR(4))])),
('memory', (b'eh?', b'oh?'), ROW([('field1', sqla.sql.sqltypes.BINARY),
('field2', sqla.sql.sqltypes.BINARY)])),
],
indirect=['trino_connection']
)
def test_row_column(trino_connection, row_object, sqla_type):
engine, conn = trino_connection

if not engine.dialect.has_schema(conn, "test"):
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()

try:
table_with_row = sqla.Table(
'table_with_row',
metadata,
sqla.Column('id', sqla.Integer),
sqla.Column('row_column', sqla_type),
schema="test"
)
metadata.create_all(engine)
ins = table_with_row.insert()
conn.execute(ins, {"id": 1, "row_column": row_object})
query = sqla.select(table_with_row)
result = conn.execute(query)
rows = result.fetchall()
assert len(rows) == 1
assert rows[0] == (1, row_object)
finally:
metadata.drop_all(engine)


@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
def test_get_catalog_names(trino_connection):
engine, conn = trino_connection
Expand Down
6 changes: 6 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ def visit_MAP(self, type_, **kw):
value_type = self.process(type_.value_type, **kw)
return f'MAP({key_type}, {value_type})'

def visit_ARRAY(self, type_, **kw):
return f'ARRAY({self.process(type_.item_type, **kw)})'

def visit_ROW(self, type_, **kw):
return f'ROW({", ".join(f"{name} {self.process(attr_type, **kw)}" for name, attr_type in type_.attr_types)})'


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
Expand Down

0 comments on commit 670d5f7

Please sign in to comment.