diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index b29ea919..7398fdac 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -194,6 +194,8 @@ class BigQueryCompiler(_struct.SQLCompiler, vendored_postgresql.PGCompiler): compound_keywords = SQLCompiler.compound_keywords.copy() compound_keywords[selectable.CompoundSelect.UNION] = "UNION DISTINCT" compound_keywords[selectable.CompoundSelect.UNION_ALL] = "UNION ALL" + compound_keywords[selectable.CompoundSelect.EXCEPT] = "EXCEPT DISTINCT" + compound_keywords[selectable.CompoundSelect.INTERSECT] = "INTERSECT DISTINCT" def __init__(self, dialect, statement, *args, **kwargs): if isinstance(statement, Column): diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index 07f21443..9d2ba21e 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -168,6 +168,94 @@ def test_typed_parameters(faux_conn, type_, val, btype, vrep): ) +def test_except(faux_conn): + table = setup_table( + faux_conn, + "table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("foo", sqlalchemy.Integer), + ) + + s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2) + s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4) + + s3 = s1.except_(s2) + + result = s3.compile(faux_conn).string + + expected = ( + "SELECT `table`.`foo` \n" + "FROM `table` \n" + "WHERE `table`.`id` >= %(id_1:INT64)s EXCEPT DISTINCT SELECT `table`.`foo` \n" + "FROM `table` \n" + "WHERE `table`.`id` >= %(id_2:INT64)s" + ) + assert result == expected + + +def test_intersect(faux_conn): + table = setup_table( + faux_conn, + "table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("foo", sqlalchemy.Integer), + ) + + s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2) + s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4) + + s3 = s1.intersect(s2) + + result = s3.compile(faux_conn).string + + expected = ( + "SELECT `table`.`foo` \n" + "FROM `table` \n" + "WHERE `table`.`id` >= %(id_1:INT64)s INTERSECT DISTINCT SELECT `table`.`foo` \n" + "FROM `table` \n" + "WHERE `table`.`id` >= %(id_2:INT64)s" + ) + assert result == expected + + +def test_union(faux_conn): + table = setup_table( + faux_conn, + "table", + sqlalchemy.Column("id", sqlalchemy.Integer), + sqlalchemy.Column("foo", sqlalchemy.Integer), + ) + + s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2) + s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4) + + s3 = s1.union(s2) + + result = s3.compile(faux_conn).string + + expected = ( + "SELECT `table`.`foo` \n" + "FROM `table` \n" + "WHERE `table`.`id` >= %(id_1:INT64)s UNION DISTINCT SELECT `table`.`foo` \n" + "FROM `table` \n" + "WHERE `table`.`id` >= %(id_2:INT64)s" + ) + assert result == expected + + s4 = s1.union_all(s2) + + result = s4.compile(faux_conn).string + + expected = ( + "SELECT `table`.`foo` \n" + "FROM `table` \n" + "WHERE `table`.`id` >= %(id_1:INT64)s UNION ALL SELECT `table`.`foo` \n" + "FROM `table` \n" + "WHERE `table`.`id` >= %(id_2:INT64)s" + ) + assert result == expected + + def test_select_struct(faux_conn, metadata): from sqlalchemy_bigquery import STRUCT