Skip to content

Commit

Permalink
Sync missing merge query changes from sqllineage (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulixius9 authored Jun 30, 2023
1 parent 2c683c2 commit 488fd9b
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 27 deletions.
2 changes: 1 addition & 1 deletion sqllineage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


NAME = "openmetadata-sqllineage"
VERSION = "1.1.0.0"
VERSION = "1.1.0.1"
DEFAULT_LOGGING = {
"version": 1,
"disable_existing_loggers": False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def extract(
for current_handler in handlers:
current_handler.handle(segment, holder)

if segment.type == "select_statement":
if segment.type in {"select_statement", "set_expression"}:
holder |= DmlSelectExtractor(self.dialect).extract(
segment,
AnalyzerContext(prev_cte=holder.cte, prev_write=holder.write),
Expand Down
14 changes: 14 additions & 0 deletions sqllineage/core/parser/sqlfluff/extractors/dml_insert_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ def extract(
segment,
AnalyzerContext(prev_cte=holder.cte, prev_write=holder.write),
)
elif segment.type == "bracketed" and any(
s.type == "with_compound_statement" for s in segment.segments
):
for sgmt in segment.segments:
if sgmt.type == "with_compound_statement":
from .cte_extractor import DmlCteExtractor

holder |= DmlCteExtractor(self.dialect).extract(
sgmt,
AnalyzerContext(
prev_cte=holder.cte, prev_write=holder.write
),
)

elif segment.type == "bracketed" and (
self.parse_subquery(segment) or is_union(segment)
):
Expand Down
50 changes: 32 additions & 18 deletions sqllineage/core/parser/sqlfluff/extractors/dml_merge_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from sqllineage.core.holders import StatementLineageHolder, SubQueryLineageHolder
from sqllineage.core.models import AnalyzerContext, Column, SubQuery, Table
from sqllineage.core.parser.sqlfluff.extractors.cte_extractor import DmlCteExtractor
from sqllineage.core.parser.sqlfluff.extractors.dml_select_extractor import (
DmlSelectExtractor,
)
Expand Down Expand Up @@ -40,7 +41,7 @@ def extract(
direct_source: Optional[Union[Table, SubQuery]] = None
segments = retrieve_segments(statement)
for i, segment in enumerate(segments):
if segment.type == "keyword" and segment.raw_upper == "INTO":
if segment.type == "keyword" and segment.raw_upper in {"INTO", "MERGE"}:
tgt_flag = True
continue
if segment.type == "keyword" and segment.raw_upper == "USING":
Expand All @@ -64,9 +65,18 @@ def extract(
else None,
)
holder.add_read(direct_source)
holder |= DmlSelectExtractor(self.dialect).extract(
direct_source.query, AnalyzerContext(direct_source, holder.cte)
)
if direct_source.query.get_child("with_compound_statement"):
# in case the subquery is a CTE query
holder |= DmlCteExtractor(self.dialect).extract(
direct_source.query,
AnalyzerContext(direct_source, prev_cte=holder.cte),
)
else:
# in case the subquery is a select query
holder |= DmlSelectExtractor(self.dialect).extract(
direct_source.query,
AnalyzerContext(direct_source, holder.cte),
)
src_flag = False

for match in get_grandchildren(
Expand All @@ -89,20 +99,24 @@ def extract(
):
merge_insert = not_match.get_child("merge_insert_clause")
insert_columns = []
for c in merge_insert.get_child("bracketed").get_children(
merge_insert_bracketed = merge_insert.get_child("bracketed")
if merge_insert_bracketed and merge_insert_bracketed.get_children(
"column_reference"
):
tgt_col = Column(get_identifier(c))
tgt_col.parent = list(holder.write)[0]
insert_columns.append(tgt_col)
for j, e in enumerate(
merge_insert.get_child("values_clause")
.get_child("bracketed")
.get_children("expression")
):
col_ref = e.get_child("column_reference")
if col_ref:
src_col = Column(get_identifier(col_ref))
src_col.parent = direct_source
holder.add_column_lineage(src_col, insert_columns[j])
for c in merge_insert.get_child("bracketed").get_children(
"column_reference"
):
tgt_col = Column(get_identifier(c))
tgt_col.parent = list(holder.write)[0]
insert_columns.append(tgt_col)
for j, e in enumerate(
merge_insert.get_child("values_clause")
.get_child("bracketed")
.get_children("expression")
):
col_ref = e.get_child("column_reference")
if col_ref:
src_col = Column(get_identifier(col_ref))
src_col.parent = direct_source
holder.add_column_lineage(src_col, insert_columns[j])
return holder
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
get_multiple_identifiers,
get_subqueries,
is_subquery,
is_union,
)
from sqllineage.utils.entities import SubQueryTuple

Expand Down Expand Up @@ -74,7 +73,7 @@ def parse_subquery(cls, segment: BaseSegment) -> List[SubQuery]:
)
if segment.type in ["select_clause", "from_clause", "where_clause"]:
result = cls._parse_subquery(get_subqueries(segment))
elif is_subquery(segment) and not is_union(segment):
elif is_subquery(segment):
# Parenthesis for SubQuery without alias, this is valid syntax for certain SQL dialect
result = [SqlFluffSubQuery.of(segment, None)]
return result
Expand Down
8 changes: 5 additions & 3 deletions sqllineage/core/parser/sqlfluff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ def get_subqueries(segment: BaseSegment) -> List[SubQueryTuple]:
return [SubQueryTuple(get_innermost_bracketed(bracketed_segments[0]), None)]
return []
elif is_union(segment):
if segment.type != "set_expression":
segment = segment.get_child("set_expression")
for s in retrieve_segments(segment, check_bracketed=True):
if s.type == "bracketed" or s.type == "select_statement":
subquery.append(SubQueryTuple(s, None))
Expand Down Expand Up @@ -262,7 +260,11 @@ def retrieve_segments(
:return: a list of segments
"""
if segment.type == "bracketed" and is_union(segment):
return [segment]
result = []
for sgmt in segment.segments:
if sgmt.type == "set_expression":
result = [sgmt]
return result
elif segment.type == "bracketed" and check_bracketed:
segments = [
sg
Expand Down
4 changes: 2 additions & 2 deletions sqllineage/core/parser/sqlparse/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _extract_from_dml_merge(cls, stmt: Statement) -> StatementLineageHolder:
if is_token_negligible(token):
continue
if token.is_keyword:
if token.normalized == "INTO":
if token.normalized in {"INTO", "MERGE"}:
tgt_flag = True
elif token.normalized == "USING":
src_flag = True
Expand Down Expand Up @@ -162,7 +162,7 @@ def _extract_from_dml_merge(cls, stmt: Statement) -> StatementLineageHolder:
tgt_col = Column(identifier.get_real_name())
tgt_col.parent = list(holder.write)[0]
insert_columns.append(tgt_col)
elif isinstance(token, Values):
elif insert_columns and isinstance(token, Values):
for sub_token in token.tokens:
if isinstance(sub_token, Parenthesis):
t = sub_token.tokens[1]
Expand Down
18 changes: 18 additions & 0 deletions tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,3 +1082,21 @@ def test_merge_into_using_subquery():
(ColumnQualifierTuple("k", "src"), ColumnQualifierTuple("k", "target")),
],
)


def test_union_inside_cte():
sql = """INSERT INTO dataset.target WITH temp_cte AS (SELECT col1 FROM dataset.tab1 UNION ALL
SELECT col1 FROM dataset.tab2) SELECT col1 FROM temp_cte"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("col1", "dataset.tab1"),
ColumnQualifierTuple("col1", "dataset.target"),
),
(
ColumnQualifierTuple("col1", "dataset.tab2"),
ColumnQualifierTuple("col1", "dataset.target"),
),
],
)
12 changes: 12 additions & 0 deletions tests/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,15 @@ def test_with_insert_in_query():
{"tab2"},
{"tab3"},
)


def test_union_at_last_cte():
# issue #398
sql = """WITH cte_1 AS (select col1 from tab1)
SELECT col2 from tab2
UNION
SELECT col3 from tab3"""
assert_table_lineage_equal(
sql,
{"tab1", "tab2", "tab3", "tab3"},
)
30 changes: 30 additions & 0 deletions tests/test_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,33 @@ def test_split_statements_with_desc():
DESC tab1;"""
assert len(split(sql)) == 2


def test_merge_using_cte_subquery():
sql = """MERGE INTO tgt t
USING (
WITH base AS (
SELECT
id, max(value) AS value
FROM src
GROUP BY id
)
SELECT
id, value
FROM base
) s
ON t.id = s.id
WHEN MATCHED THEN
UPDATE SET t.value = s.value"""
assert_table_lineage_equal(
sql,
{"src"},
{"tgt"},
)


def test_merge_into_insert_one_column():
sql = """MERGE INTO target
USING src ON target.k = src.k
WHEN NOT MATCHED THEN INSERT VALUES (src.k)"""
assert_table_lineage_equal(sql, {"src"}, {"target"})
48 changes: 48 additions & 0 deletions tests/test_sqlfluff.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from sqllineage.utils.entities import ColumnQualifierTuple
from .helpers import assert_column_lineage_equal, assert_table_lineage_equal

Expand Down Expand Up @@ -104,6 +106,38 @@ def test_non_reserved_keyword_as_column_name():
)


@pytest.mark.parametrize("dialect", ["snowflake", "bigquery"])
def test_create_clone(dialect: str):
"""
Language manual:
https://cloud.google.com/bigquery/docs/table-clones-create
https://docs.snowflake.com/en/sql-reference/sql/create-clone
Note clone is not a keyword in sqlparse, we'll skip testing for it.
"""
assert_table_lineage_equal(
"create table tab2 CLONE tab1;",
{"tab1"},
{"tab2"},
dialect=dialect,
test_sqlparse=False,
)


@pytest.mark.parametrize("dialect", ["snowflake"])
def test_alter_table_swap_partition(dialect: str):
"""
See https://docs.snowflake.com/en/sql-reference/sql/alter-table for language manual
Note swap is not a keyword in sqlparse, we'll skip testing for it.
"""
assert_table_lineage_equal(
"alter table tab1 swap with tab2",
{"tab2"},
{"tab1"},
dialect=dialect,
test_sqlparse=False,
)


# For top-level query parenthesis in DML, we don't treat it as subquery.
# sqlparse has some problems identifying these subqueries.
# note the table-level lineage works, only column-level lineage breaks for sqlparse
Expand All @@ -117,6 +151,20 @@ def test_create_as_with_parenthesis_around_both():
assert_table_lineage_equal(sql, {"tab2"}, {"tab1"}, test_sqlparse=False)


def test_cte_inside_bracket_of_insert():
sql = """INSERT INTO tab3 (WITH tab1 AS (SELECT * FROM tab2) SELECT * FROM tab1)"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("*", "tab2"),
ColumnQualifierTuple("*", "tab3"),
),
],
test_sqlparse=False,
)


# specify columns in CREATE statement, sqlparse would parse my_view as function call
def test_view_with_subquery_custom_columns():
# select as subquery
Expand Down

0 comments on commit 488fd9b

Please sign in to comment.