diff --git a/setup.py b/setup.py index d1035797..3b9fee29 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ def run(self) -> None: install_requires=[ "sqlparse==0.4.3", "networkx>=2.4", - "sqlfluff==2.0.2", + "sqlfluff==2.1.2", ], entry_points={"console_scripts": ["sqllineage = sqllineage.cli:main"]}, extras_require={ diff --git a/sqllineage/core/parser/sqlfluff/utils.py b/sqllineage/core/parser/sqlfluff/utils.py index 10268e35..a665416a 100644 --- a/sqllineage/core/parser/sqlfluff/utils.py +++ b/sqllineage/core/parser/sqlfluff/utils.py @@ -32,7 +32,12 @@ def extract_as_and_target_segment( """ as_segment = segment.get_child("alias_expression") sublist = retrieve_segments(segment, False) - target = sublist[0] if is_subquery(sublist[0]) else sublist[0].segments[0] + if is_subquery(sublist[0]): + target = sublist[0] + elif sublist[0].type == "bracketed": + target = get_innermost_bracketed(sublist[0]) + else: + target = sublist[0].segments[0] return as_segment, target @@ -168,14 +173,20 @@ def get_inner_from_expression(segment: BaseSegment) -> BaseSegment: :param segment: segment to be processed :return: a list of segments from a 'from_expression' or 'from_expression_element' segment """ - if segment.get_child("from_expression") and segment.get_child( - "from_expression" - ).get_child("from_expression_element"): - return segment.get_child("from_expression").get_child("from_expression_element") + if segment.get_child("from_expression"): + if segment.get_child("from_expression").get_child("from_expression_element"): + return segment.get_child("from_expression").get_child( + "from_expression_element" + ) + if segment.get_child("from_expression").get_child("bracketed"): + innermost_bracketed = get_innermost_bracketed( + segment.get_child("from_expression").get_child("bracketed") + ) + if innermost_bracketed.get_child("from_expression_element"): + return innermost_bracketed.get_child("from_expression_element") elif segment.get_child("from_expression_element"): return segment.get_child("from_expression_element") - else: - return segment + return segment def filter_segments_by_keyword( diff --git a/tests/test_others_dialect_specific.py b/tests/test_others_dialect_specific.py index 0aac6d03..5727d832 100644 --- a/tests/test_others_dialect_specific.py +++ b/tests/test_others_dialect_specific.py @@ -148,8 +148,10 @@ def test_uncache_table_if_exists(dialect: str): assert_table_lineage_equal("uncache table if exists tab1", None, None, dialect) -@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +@pytest.mark.parametrize("dialect", ["hive"]) def test_lateral_view_using_json_tuple(dialect: str): + # disabling this method for dialect "databricks", "sparksql" + # as sqlfluff produces incorrect tree for those cases sql = """INSERT OVERWRITE TABLE foo SELECT sc.id, q.item0, q.item1 FROM bar sc