diff --git a/docs/basic_concepts/lineage_analyzer.rst b/docs/basic_concepts/lineage_analyzer.rst index 8febd9ba..98fed3d4 100644 --- a/docs/basic_concepts/lineage_analyzer.rst +++ b/docs/basic_concepts/lineage_analyzer.rst @@ -2,31 +2,13 @@ LineageAnalyzer *************** -LineageAnalyzer contains the core processing logic for one-statement SQL analysis. +LineageAnalyzer is an abstract class that defines the core processing logic for one-statement SQL analysis. -At the core of analyzer is all kinds of ``sqllineage.core.handlers`` to handle the interested tokens and store the -result in ``sqllineage.core.holders``. +Each parser implementation inherits LineageAnalyzer and does parser-specific analysis based on the AST it generates, +storing the result in ``sqllineage.core.holders``. LineageAnalyzer ======================================== .. autoclass:: sqllineage.core.analyzer.LineageAnalyzer :members: - - -SourceHandler -======================================== - -.. autoclass:: sqllineage.core.handlers.source.SourceHandler - - -TargetHandler -======================================== - -.. autoclass:: sqllineage.core.handlers.target.TargetHandler - - -CTEHandler -======================================== - -.. autoclass:: sqllineage.core.handlers.cte.CTEHandler diff --git a/docs/basic_concepts/lineage_holder.rst b/docs/basic_concepts/lineage_holder.rst index ad0a433e..80965abb 100644 --- a/docs/basic_concepts/lineage_holder.rst +++ b/docs/basic_concepts/lineage_holder.rst @@ -5,7 +5,7 @@ LineageHolder LineageHolder is an abstraction to hold the lineage result analyzed by LineageAnalyzer at different level. At the bottom, we have :class:`sqllineage.core.holder.SubQueryLineageHolder` to hold lineage at subquery level. -This is used internally for :class:`sqllineage.core.analyzer.LineageAnalyzer`, which generate +This is used internally for :class:`sqllineage.core.analyzer.LineageAnalyzer`, which generates :class:`sqllineage.core.holder.StatementLineageHolder` as the result of lineage at SQL statement level. And to assemble multiple :class:`sqllineage.core.holder.StatementLineageHolder` into a DAG based data structure serving for the final output, we have :class:`sqllineage.core.holders.SQLLineageHolder` diff --git a/docs/basic_concepts/lineage_runner.rst b/docs/basic_concepts/lineage_runner.rst index 7426dc7a..9ec5622b 100644 --- a/docs/basic_concepts/lineage_runner.rst +++ b/docs/basic_concepts/lineage_runner.rst @@ -6,9 +6,9 @@ LineageRunner is the entry point for SQLLineage core processing logic. After par representation of SQL statements will be fed to LineageRunner for processing. From a bird's-eye view, it contains three steps: -1. Calling ``sqlparse.parse`` function to parse string-base SQL statements into a list of ``sqlparse.sql.Statement`` +1. Calling ``sqllineage.utils.helpers.split`` function to split string-based SQL statements into a list of ``str`` statement. -2. Calling :class:`sqllineage.core.analyzer.LineageAnalyzer` to analyze each ``sqlparse.sql.Statement`` and return a list of +2. Calling :class:`sqllineage.core.analyzer.LineageAnalyzer` to analyze each single-statement SQL string and return a list of :class:`sqllineage.core.holders.StatementLineageHolder` . 3. Calling :class:`sqllineage.core.holders.SQLLineageHolder.of` function to assemble the list of
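+
+As a minimal usage sketch of this entry point (the SQL text and the ``ansi`` dialect below are illustrative, not part of this patch)::
+
+    from sqllineage.runner import LineageRunner
+
+    runner = LineageRunner("INSERT INTO tab1 SELECT col1 FROM tab2", dialect="ansi")
+    print(runner.source_tables())  # [Table: <default>.tab2]
+    print(runner.target_tables())  # [Table: <default>.tab1]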
diff --git a/docs/conf.py b/docs/conf.py index 3dce1bd0..9be25627 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,9 +22,9 @@ sys.path.insert(0, os.path.abspath("..")) -from sqllineage import VERSION # noqa +from sqllineage import NAME, VERSION # noqa -project = "sqllineage" +project = NAME copyright = f"2019-{datetime.now().year}, Reata" # noqa author = "Reata" diff --git a/docs/index.rst b/docs/index.rst index 967eeeda..8d5265ed 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -74,7 +74,7 @@ Basic concepts LineageAnalyzer: The core functionality of analyze one SQL statement :doc:`basic_concepts/lineage_holder` - LineageCombiner: To hold lineage result at different level + LineageHolder: To hold lineage result at different levels :doc:`basic_concepts/lineage_model` The data classes for SQLLineage diff --git a/mypy.ini b/mypy.ini index ee764b05..a1b91c7e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -8,6 +8,6 @@ warn_no_return=True warn_redundant_casts=True warn_unused_ignores=True disallow_any_generics=True -[mypy-sqllineage.sqlfluff_core.utils.sqlfluff] +[mypy-sqllineage.core.parser.sqlfluff.utils.sqlfluff] disallow_untyped_calls=False -warn_return_any = False +warn_return_any = False \ No newline at end of file diff --git a/setup.py b/setup.py index 129c083f..722b20e5 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ def run(self) -> None: install_requires=[ "sqlparse>=0.3.1", "networkx>=2.4", - "sqlfluff>=1.4.5", + "sqlfluff==2.0.0a6", ], entry_points={"console_scripts": ["sqllineage = sqllineage.cli:main"]}, extras_require={ diff --git a/sqllineage/__init__.py b/sqllineage/__init__.py index 94877981..a2c4af10 100644 --- a/sqllineage/__init__.py +++ b/sqllineage/__init__.py @@ -1,47 +1,8 @@ import os -def _patch_adding_window_function_token() -> None: - from sqlparse.engine import grouping - from sqllineage.utils.sqlparse import group_function_with_window - - grouping.group_functions = group_function_with_window - - -def _patch_adding_builtin_type() -> None: - from sqlparse import tokens - from sqlparse.keywords import KEYWORDS - - KEYWORDS["STRING"] = tokens.Name.Builtin - KEYWORDS["DATETIME"] = tokens.Name.Builtin - - -def _patch_updating_lateral_view_lexeme() -> None: - import re - from sqlparse.keywords import SQL_REGEX - - for i, (regex, lexeme) in enumerate(SQL_REGEX): - if regex("LATERAL VIEW EXPLODE(col)"): - new_regex = r"(LATERAL\s+VIEW\s+)(OUTER\s+)?(EXPLODE|INLINE|PARSE_URL_TUPLE|POSEXPLODE|STACK|JSON_TUPLE)\b" - new_compile = re.compile(new_regex, re.IGNORECASE | re.UNICODE).match - SQL_REGEX[i] = (new_compile, lexeme) - break - - -def _monkey_patch() -> None: - try: - _patch_adding_window_function_token() - _patch_adding_builtin_type() - _patch_updating_lateral_view_lexeme() - except ImportError: - # when imported by setup.py for constant variables, dependency is not ready yet - pass - - -_monkey_patch() - -NAME = "openmetadata-sqllineage" -VERSION = "1.0.1" +NAME = "sqllineage" +VERSION = "1.3.7" DEFAULT_LOGGING = { "version": 1, "disable_existing_loggers": False, @@ -74,6 +35,6 @@ def _monkey_patch() -> None: "SQLLINEAGE_DIRECTORY", os.path.join(os.path.dirname(__file__), "data") ) DEFAULT_HOST = "localhost" -DEFAULT_PORT = 5001 -DEFAULT_DIALECT = "ansi" -DEFAULT_USE_SQLFLUFF = False +DEFAULT_PORT = 5000 +SQLPARSE_DIALECT = "non-validating" +DEFAULT_DIALECT = SQLPARSE_DIALECT
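A sketch of how the renamed dialect constants above are meant to be consumed (the routing helper below is hypothetical, for illustration only; the sqlparse analyzer's import path is an assumption, while SqlFluffLineageAnalyzer is introduced later in this patch):

    from sqllineage import DEFAULT_DIALECT, SQLPARSE_DIALECT

    def pick_analyzer(dialect: str = DEFAULT_DIALECT):
        # the special non-validating dialect keeps the legacy sqlparse-based
        # analyzer; any concrete dialect is routed to the sqlfluff-based one
        if dialect == SQLPARSE_DIALECT:
            # assumed path for the sqlparse implementation, not shown in this hunk
            from sqllineage.core.parser.sqlparse.analyzer import SqlParseLineageAnalyzer
            return SqlParseLineageAnalyzer()
        from sqllineage.core.parser.sqlfluff.analyzer import SqlFluffLineageAnalyzer
        return SqlFluffLineageAnalyzer(dialect)

diff --git a/sqllineage/cli.py b/sqllineage/cli.py index 5e274643..54b56d7c 100644 ---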
a/sqllineage/cli.py +++ b/sqllineage/cli.py @@ -68,13 +68,7 @@ def main(args=None) -> None: help="the dialect used to compute the lineage", type=str, default=DEFAULT_DIALECT, - metavar="ansi, mysql, snowflake, redshift, hive, etc. Chec supported dialects by sqlfluff.", - ) - parser.add_argument( - "-s", - "--sqlfluff", - help="use sqlfluff as the parser", - action="store_false", + metavar="ansi, mysql, snowflake, redshift, hive, etc. Check supported dialects by sqlfluff.", ) args = parser.parse_args(args) if args.e and args.f: @@ -92,10 +86,9 @@ def main(args=None) -> None: "f": args.f if args.f else None, }, dialect=args.dialect, - use_sqlparse=args.sqlfluff, ) if args.graph_visualization: - runner.draw(args.dialect, args.sqlfluff) + runner.draw(args.dialect) elif args.level == LineageLevel.COLUMN: runner.print_column_lineage() else: diff --git a/sqllineage/core/__init__.py b/sqllineage/core/__init__.py index 30989338..e69de29b 100644 --- a/sqllineage/core/__init__.py +++ b/sqllineage/core/__init__.py @@ -1,2 +0,0 @@ -# For backward compatibility, so people can still do `from sqllineage.core import LineageAnalyzer` -from sqllineage.core.analyzer import LineageAnalyzer # noqa diff --git a/sqllineage/core/analyzer.py b/sqllineage/core/analyzer.py index e42cadeb..49d097ae 100644 --- a/sqllineage/core/analyzer.py +++ b/sqllineage/core/analyzer.py @@ -1,177 +1,16 @@ -from functools import reduce -from operator import add -from typing import List, NamedTuple, Optional, Set, Union +from abc import abstractmethod -from sqlparse.sql import ( - Function, - Identifier, - IdentifierList, - Statement, - TokenList, - Where, -) - -from sqllineage.core.handlers.base import ( - CurrentTokenBaseHandler, - NextTokenBaseHandler, -) -from sqllineage.core.holders import StatementLineageHolder, SubQueryLineageHolder -from sqllineage.core.models import SubQuery, Table -from sqllineage.utils.sqlparse import ( - get_subquery_parentheses, - is_subquery, - is_token_negligible, -) - - -class AnalyzerContext(NamedTuple): - subquery: Optional[SubQuery] = None - prev_cte: Optional[Set[SubQuery]] = None +from sqllineage.core.holders import StatementLineageHolder class LineageAnalyzer: - """SQL Statement Level Lineage Analyzer.""" - - def analyze(self, stmt: Statement) -> StatementLineageHolder: - """ - to analyze the Statement and store the result into :class:`sqllineage.holders.StatementLineageHolder`. 
- - :param stmt: a SQL statement parsed by `sqlparse` - """ - if ( - stmt.get_type() == "DELETE" - or stmt.token_first(skip_cm=True).normalized == "TRUNCATE" - or stmt.token_first(skip_cm=True).normalized.upper() == "REFRESH" - or stmt.token_first(skip_cm=True).normalized == "CACHE" - or stmt.token_first(skip_cm=True).normalized.upper() == "UNCACHE" - or stmt.token_first(skip_cm=True).normalized == "SHOW" - ): - holder = StatementLineageHolder() - elif stmt.get_type() == "DROP": - holder = self._extract_from_ddl_drop(stmt) - elif ( - stmt.get_type() == "ALTER" - or stmt.token_first(skip_cm=True).normalized == "RENAME" - ): - holder = self._extract_from_ddl_alter(stmt) - else: - # DML parsing logic also applies to CREATE DDL - holder = StatementLineageHolder.of( - self._extract_from_dml(stmt, AnalyzerContext()) - ) - return holder - - @classmethod - def _extract_from_ddl_drop(cls, stmt: Statement) -> StatementLineageHolder: - holder = StatementLineageHolder() - for table in {Table.of(t) for t in stmt.tokens if isinstance(t, Identifier)}: - holder.add_drop(table) - return holder - - @classmethod - def _extract_from_ddl_alter(cls, stmt: Statement) -> StatementLineageHolder: - holder = StatementLineageHolder() - tables = [] - for t in stmt.tokens: - if isinstance(t, Identifier): - tables.append(Table.of(t)) - elif isinstance(t, IdentifierList): - for identifier in t.get_identifiers(): - tables.append(Table.of(identifier)) - keywords = [t for t in stmt.tokens if t.is_keyword] - if any(k.normalized == "RENAME" for k in keywords): - if stmt.get_type() == "ALTER" and len(tables) == 2: - holder.add_rename(tables[0], tables[1]) - elif ( - stmt.token_first(skip_cm=True).normalized == "RENAME" - and len(tables) % 2 == 0 - ): - for i in range(0, len(tables), 2): - holder.add_rename(tables[i], tables[i + 1]) - if any(k.normalized == "EXCHANGE" for k in keywords) and len(tables) == 2: - holder.add_write(tables[0]) - holder.add_read(tables[1]) - return holder - - @classmethod - def _extract_from_dml( - cls, token: TokenList, context: AnalyzerContext - ) -> SubQueryLineageHolder: - holder = SubQueryLineageHolder() - if context.prev_cte is not None: - # CTE can be referenced by subsequent CTEs - for cte in context.prev_cte: - holder.add_cte(cte) - if context.subquery is not None: - # If within subquery, then manually add subquery as target table - holder.add_write(context.subquery) - current_handlers = [ - handler_cls() for handler_cls in CurrentTokenBaseHandler.__subclasses__() - ] - next_handlers = [ - handler_cls() for handler_cls in NextTokenBaseHandler.__subclasses__() - ] - - subqueries = [] - for sub_token in token.tokens: - if is_token_negligible(sub_token): - continue - - for sq in cls.parse_subquery(sub_token): - # Collecting subquery on the way, hold on parsing until last - # so that each handler don't have to worry about what's inside subquery - subqueries.append(sq) - - for current_handler in current_handlers: - current_handler.handle(sub_token, holder) - - if sub_token.is_keyword: - for next_handler in next_handlers: - next_handler.indicate(sub_token) - continue - - for next_handler in next_handlers: - if next_handler.indicator: - next_handler.handle(sub_token, holder) - else: - # call end of query hook here as loop is over - for next_handler in next_handlers: - next_handler.end_of_query_cleanup(holder) - # By recursively extracting each subquery of the parent and merge, we're doing Depth-first search - for sq in subqueries: - holder |= cls._extract_from_dml(sq.token, AnalyzerContext(sq, 
holder.cte)) - return holder - - @classmethod - def parse_subquery(cls, token: TokenList) -> List[SubQuery]: - result = [] - if isinstance(token, (Identifier, Function, Where)): - # usually SubQuery is an Identifier, but not all Identifiers are SubQuery - # Function for CTE without AS keyword - result = cls._parse_subquery(token) - elif isinstance(token, IdentifierList): - # IdentifierList for SQL89 style of JOIN or multiple CTEs, this is actually SubQueries - result = reduce( - add, - [ - cls._parse_subquery(identifier) - for identifier in token.get_sublists() - ], - [], - ) - elif is_subquery(token): - # Parenthesis for SubQuery without alias, this is valid syntax for certain SQL dialect - result = [SubQuery.of(token, None)] - return result + """SQL Statement Level Lineage Analyzer. + Parser-specific implementations should inherit this class and implement the analyze method. + """ - @classmethod - def _parse_subquery( - cls, token: Union[Identifier, Function, Where] - ) -> List[SubQuery]: + @abstractmethod + def analyze(self, sql: str) -> StatementLineageHolder: """ - convert SubQueryTuple to sqllineage.core.models.SubQuery + to analyze a single-statement SQL string and store the result into + :class:`sqllineage.core.holders.StatementLineageHolder`. """ - return [ - SubQuery.of(parenthesis, alias) - for parenthesis, alias in get_subquery_parentheses(token) - ] diff --git a/sqllineage/core/holders.py b/sqllineage/core/holders.py index b264249e..7434522c 100644 --- a/sqllineage/core/holders.py +++ b/sqllineage/core/holders.py @@ -28,7 +28,7 @@ def get_column_lineage(self, exclude_subquery=True) -> Set[Tuple[Column, ...]]: node for node in target_columns if isinstance(node.parent, Table) } columns = set() - for (source, target) in itertools.product(source_columns, target_columns): + for source, target in itertools.product(source_columns, target_columns): simple_paths = list(nx.all_simple_paths(self.graph, source, target)) for path in simple_paths: columns.add(tuple(path)) @@ -48,6 +48,7 @@ class SubQueryLineageHolder(ColumnLineageMixin): def __init__(self) -> None: self.graph = nx.DiGraph() + self.extra_subqueries: Set[SubQuery] = set() def __or__(self, other): self.graph = nx.compose(self.graph, other.graph) @@ -139,7 +140,7 @@ def add_rename(self, src: Table, tgt: Table) -> None: self.graph.add_edge(src, tgt, type=EdgeType.RENAME) @staticmethod - def of(holder: SubQueryLineageHolder): + def of(holder: SubQueryLineageHolder) -> "StatementLineageHolder": stmt_holder = StatementLineageHolder() stmt_holder.graph = holder.graph return stmt_holder @@ -235,7 +236,7 @@ def _build_digraph(*args: StatementLineageHolder) -> DiGraph: if g.has_node(table) and g.degree[table] == 0: g.remove_node(table) elif holder.rename: - for (table_old, table_new) in holder.rename: + for table_old, table_new in holder.rename: g = nx.relabel_nodes(g, {table_old: table_new}) g.remove_edge(table_new, table_new) if g.degree[table_new] == 0: @@ -284,7 +285,7 @@ def _build_digraph(*args: StatementLineageHolder) -> DiGraph: return g @staticmethod - def of(*args: StatementLineageHolder): + def of(*args: StatementLineageHolder) -> "SQLLineageHolder": """ To assemble multiple :class:`sqllineage.holders.StatementLineageHolder` into :class:`sqllineage.holders.SQLLineageHolder` diff --git a/sqllineage/core/models.py b/sqllineage/core/models.py index 9f310af2..4df575c8 100644 --- a/sqllineage/core/models.py +++ b/sqllineage/core/models.py @@ -1,35 +1,19 @@ import warnings -from typing import Dict, List, Optional, Set, Union - -from sqlparse
import tokens as T -from sqlparse.engine import grouping -from sqlparse.keywords import is_keyword -from sqlparse.sql import ( - Case, - Comparison, - Function, - Identifier, - IdentifierList, - Operation, - Parenthesis, - Token, - TokenList, -) -from sqlparse.utils import imt +from typing import Any, Dict, List, Optional, Set, Union from sqllineage.exceptions import SQLLineageException -from sqllineage.utils.entities import ColumnExpression, ColumnQualifierTuple from sqllineage.utils.helpers import escape_identifier_name -from sqllineage.utils.sqlparse import get_parameters, is_subquery class Schema: + """ + Data Class for Schema + """ + unknown = "" def __init__(self, name: str = unknown): """ - Data Class for Schema - :param name: schema name """ self.raw_name = escape_identifier_name(name) @@ -41,7 +25,7 @@ def __repr__(self): return "Schema: " + str(self) def __eq__(self, other): - return type(self) is type(other) and str(self) == str(other) + return isinstance(other, Schema) and str(self) == str(other) def __hash__(self): return hash(str(self)) @@ -51,10 +35,12 @@ def __bool__(self): class Table: + """ + Data Class for Table + """ + def __init__(self, name: str, schema: Schema = Schema(), **kwargs): """ - Data Class for Table - :param name: table name :param schema: schema as defined by :class:`Schema` """ @@ -79,40 +65,25 @@ def __repr__(self): return "Table: " + str(self) def __eq__(self, other): - return type(self) is type(other) and str(self) == str(other) + return isinstance(other, Table) and str(self) == str(other) def __hash__(self): return hash(str(self)) @staticmethod - def of(identifier: Identifier) -> "Table": - # rewrite identifier's get_real_name method, by matching the last dot instead of the first dot, so that the - # real name for a.b.c will be c instead of b - dot_idx, _ = identifier._token_matching( - lambda token: imt(token, m=(T.Punctuation, ".")), - start=len(identifier.tokens), - reverse=True, - ) - real_name = identifier._get_first_name(dot_idx, real_name=True) - # rewrite identifier's get_parent_name accordingly - parent_name = ( - "".join( - [ - escape_identifier_name(token.value) - for token in identifier.tokens[:dot_idx] - ] - ) - if dot_idx - else None - ) - schema = Schema(parent_name) if parent_name is not None else Schema() - alias = identifier.get_alias() - kwargs = {"alias": alias} if alias else {} - return Table(real_name, schema, **kwargs) + def of(table: Any) -> "Table": + raise NotImplementedError class Path: + """ + Data Class for Path + """ + def __init__(self, uri: str): + """ + :param uri: uri of the path + """ self.uri = escape_identifier_name(uri) def __str__(self): @@ -122,26 +93,24 @@ def __repr__(self): return "Path: " + str(self) def __eq__(self, other): - return type(self) is type(other) and self.uri == other.uri + return isinstance(other, Path) and self.uri == other.uri def __hash__(self): return hash(self.uri) -class Partition: - pass - - class SubQuery: - def __init__(self, token: Parenthesis, alias: Optional[str]): - """ - Data Class for SubQuery + """ + Data Class for SubQuery + """ - :param token: subquery token - :param alias: subquery name + def __init__(self, subquery: Any, subquery_raw: str, alias: Optional[str]): """ - self.token = token - self._query = token.value + :param subquery: subquery + :param alias: subquery alias name + """ + self.query = subquery + self.query_raw = subquery_raw self.alias = alias if alias is not None else f"subquery_{hash(self)}" def __str__(self): @@ -151,29 +120,30 @@ def __repr__(self): return 
"SubQuery: " + str(self) def __eq__(self, other): - return type(self) is type(other) and self._query == other._query + return isinstance(other, SubQuery) and self.query_raw == other.query_raw def __hash__(self): - return hash(self._query) + return hash(self.query_raw) @staticmethod - def of(parenthesis: Parenthesis, alias: Optional[str]) -> "SubQuery": - return SubQuery(parenthesis, alias) + def of(subquery: Any, alias: Optional[str]) -> "SubQuery": + raise NotImplementedError class Column: + """ + Data Class for Column + """ + def __init__(self, name: str, **kwargs): """ - Data Class for Column - :param name: column name :param parent: :class:`Table` or :class:`SubQuery` :param kwargs: """ - self._parent: Set[Union[Table, SubQuery]] = set() + self._parent: Set[Union[Path, Table, SubQuery]] = set() self.raw_name = escape_identifier_name(name) self.source_columns = kwargs.pop("source_columns", ((self.raw_name, None),)) - self.expression = kwargs.pop("expression", ColumnExpression(True, None)) def __str__(self): return ( @@ -186,165 +156,47 @@ def __repr__(self): return "Column: " + str(self) def __eq__(self, other): - return type(self) is type(other) and str(self) == str(other) + return isinstance(other, Column) and str(self) == str(other) def __hash__(self): return hash(str(self)) @property - def parent(self) -> Optional[Union[Table, SubQuery]]: + def parent(self) -> Optional[Union[Path, Table, SubQuery]]: return list(self._parent)[0] if len(self._parent) == 1 else None @parent.setter - def parent(self, value: Union[Table, SubQuery]): + def parent(self, value: Union[Path, Table, SubQuery]): self._parent.add(value) @property - def parent_candidates(self) -> List[Union[Table, SubQuery]]: + def parent_candidates(self) -> List[Union[Path, Table, SubQuery]]: return sorted(self._parent, key=lambda p: str(p)) @staticmethod - def of(token: Token): - if isinstance(token, Identifier): - alias = token.get_alias() - if alias: - # handle column alias, including alias for column name or Case, Function - kw_idx, kw = token.token_next_by(m=(T.Keyword, "AS")) - if kw_idx is None: - # alias without AS - kw_idx, _ = token.token_next_by(i=Identifier) - if kw_idx is None: - # invalid syntax: col AS, without alias - return Column(alias) - else: - idx, _ = token.token_prev(kw_idx, skip_cm=True) - expr = grouping.group(TokenList(token.tokens[: idx + 1]))[0] - source_columns = Column._extract_source_columns(expr) - return Column( - alias, - source_columns=source_columns, - expression=ColumnExpression(False, expr), - ) - else: - # select column name directly without alias - return Column( - token.get_real_name(), - source_columns=((token.get_real_name(), token.get_parent_name()),), - expression=ColumnExpression(True, None), - ) - else: - # Wildcard, Case, Function without alias (thus not recognized as an Identifier) - source_columns = Column._extract_source_columns(token) - return Column( - token.value, - source_columns=source_columns, - expression=ColumnExpression(False, token), - ) - - @staticmethod - def _extract_source_columns(token: Token) -> List[ColumnQualifierTuple]: - if isinstance(token, Function): - # max(col1) AS col2 - source_columns = [ - cqt - for tk in get_parameters(token) - for cqt in Column._extract_source_columns(tk) - ] - elif isinstance(token, Parenthesis): - if is_subquery(token): - # This is to avoid circular import - from sqllineage.runner import LineageRunner - - # (SELECT avg(col1) AS col1 FROM tab3), used after WHEN or THEN in CASE clause - src_cols = [ - lineage[0] - for lineage in 
LineageRunner(token.value).get_column_lineage( - exclude_subquery=False - ) - ] - source_columns = [ - ColumnQualifierTuple(src_col.raw_name, src_col.parent.raw_name) - for src_col in src_cols - ] - else: - # (col1 + col2) AS col3 - source_columns = [ - cqt - for tk in token.tokens[1:-1] - for cqt in Column._extract_source_columns(tk) - ] - elif isinstance(token, Operation): - # col1 + col2 AS col3 - source_columns = [ - cqt - for tk in token.get_sublists() - for cqt in Column._extract_source_columns(tk) - ] - elif isinstance(token, Case): - # CASE WHEN col1 = 2 THEN "V1" WHEN col1 = "2" THEN "V2" END AS col2 - source_columns = [ - cqt - for tk in token.get_sublists() - for cqt in Column._extract_source_columns(tk) - ] - elif isinstance(token, Comparison): - source_columns = Column._extract_source_columns( - token.left - ) + Column._extract_source_columns(token.right) - elif isinstance(token, IdentifierList): - source_columns = [ - cqt - for tk in token.get_sublists() - for cqt in Column._extract_source_columns(tk) - ] - elif isinstance(token, Identifier): - real_name = token.get_real_name() - # ignore function dtypes that don't need to check for extract column - FUNC_DTYPE = ["decimal", "numeric"] - has_function = any( - isinstance(t, Function) and t.get_real_name() not in FUNC_DTYPE - for t in token.tokens - ) - is_kw = is_keyword(real_name) if real_name is not None else False - if ( - # real name is None: col1=1 AS int - real_name is None - # real_name is decimal: case when col1 > 0 then col2 else col3 end as decimal(18, 0) - or (real_name in FUNC_DTYPE and isinstance(token.tokens[-1], Function)) - or (is_kw and has_function) - ): - source_columns = [ - cqt - for tk in token.get_sublists() - for cqt in Column._extract_source_columns(tk) - ] - else: - # col1 AS col2 - source_columns = [ - ColumnQualifierTuple(token.get_real_name(), token.get_parent_name()) - ] - else: - if token.ttype == T.Wildcard: - # select * - source_columns = [ColumnQualifierTuple(token.value, None)] - else: - # typically, T.Literal here - source_columns = [] - return source_columns + def of(column: Any, **kwargs) -> "Column": + """ + Build a 'Column' object + :param column: column segment or token + :return: + """ + raise NotImplementedError - def to_source_columns(self, alias_mapping: Dict[str, Union[Table, SubQuery]]): + def to_source_columns(self, alias_mapping: Dict[str, Union[Path, Table, SubQuery]]): """ Best guess for source table given all the possible table/subquery and their alias. 
""" - def _to_src_col(name: str, parent: Optional[Union[Table, SubQuery]] = None): + def _to_src_col( + name: str, parent: Optional[Union[Path, Table, SubQuery]] = None + ): col = Column(name) if parent: col.parent = parent return col source_columns = set() - for (src_col, qualifier) in self.source_columns: + for src_col, qualifier in self.source_columns: if qualifier is None: if src_col == "*": # select * @@ -366,3 +218,28 @@ def _to_src_col(name: str, parent: Optional[Union[Table, SubQuery]] = None): else: source_columns.add(_to_src_col(src_col, Table(qualifier))) return source_columns + + +class AnalyzerContext: + """ + Data class to hold the analyzer context + """ + + subquery: Optional[SubQuery] + prev_cte: Optional[Set[SubQuery]] + prev_write: Optional[Set[Union[SubQuery, Table]]] + + def __init__( + self, + subquery: Optional[SubQuery] = None, + prev_cte: Optional[Set[SubQuery]] = None, + prev_write: Optional[Set[Union[SubQuery, Table]]] = None, + ): + """ + :param subquery: subquery + :param prev_cte: previous CTE queries + :param prev_write: previous written tables + """ + self.subquery = subquery + self.prev_cte = prev_cte + self.prev_write = prev_write diff --git a/sqllineage/core/parser/__init__.py b/sqllineage/core/parser/__init__.py new file mode 100644 index 00000000..3b2a1422 --- /dev/null +++ b/sqllineage/core/parser/__init__.py @@ -0,0 +1,59 @@ +from typing import Dict, List, Tuple, Union + +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import Column, Path, SubQuery, Table +from sqllineage.exceptions import SQLLineageException +from sqllineage.utils.constant import EdgeType + + +class SourceHandlerMixin: + tables: List[Union[Path, SubQuery, Table]] + columns: List[Column] + union_barriers: List[Tuple[int, int]] + + def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None: + for i, tbl in enumerate(self.tables): + holder.add_read(tbl) + self.union_barriers.append((len(self.columns), len(self.tables))) + for i, (col_barrier, tbl_barrier) in enumerate(self.union_barriers): + prev_col_barrier, prev_tbl_barrier = ( + (0, 0) if i == 0 else self.union_barriers[i - 1] + ) + col_grp = self.columns[prev_col_barrier:col_barrier] + tbl_grp = self.tables[prev_tbl_barrier:tbl_barrier] + tgt_tbl = None + if holder.write: + if len(holder.write) > 1: + raise SQLLineageException + tgt_tbl = list(holder.write)[0] + if tgt_tbl: + for tgt_col in col_grp: + tgt_col.parent = tgt_tbl + for src_col in tgt_col.to_source_columns( + self.get_alias_mapping_from_table_group(tbl_grp, holder) + ): + holder.add_column_lineage(src_col, tgt_col) + + @classmethod + def get_alias_mapping_from_table_group( + cls, + table_group: List[Union[Path, Table, SubQuery]], + holder: SubQueryLineageHolder, + ) -> Dict[str, Union[Path, Table, SubQuery]]: + """ + A table can be referred to as alias, table name, or database_name.table_name, create the mapping here. + For SubQuery, it's only alias then. 
+ """ + return { + **{ + tgt: src + for src, tgt, attr in holder.graph.edges(data=True) + if attr.get("type") == EdgeType.HAS_ALIAS and src in table_group + }, + **{ + table.raw_name: table + for table in table_group + if isinstance(table, Table) + }, + **{str(table): table for table in table_group if isinstance(table, Table)}, + } diff --git a/sqllineage/sqlfluff_core/__init__.py b/sqllineage/core/parser/sqlfluff/__init__.py similarity index 100% rename from sqllineage/sqlfluff_core/__init__.py rename to sqllineage/core/parser/sqlfluff/__init__.py diff --git a/sqllineage/core/parser/sqlfluff/analyzer.py b/sqllineage/core/parser/sqlfluff/analyzer.py new file mode 100644 index 00000000..44eba384 --- /dev/null +++ b/sqllineage/core/parser/sqlfluff/analyzer.py @@ -0,0 +1,63 @@ +from sqlfluff.core import Linter + +from sqllineage.core.analyzer import LineageAnalyzer +from sqllineage.core.holders import ( + StatementLineageHolder, + SubQueryLineageHolder, +) +from sqllineage.core.models import AnalyzerContext +from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import ( + LineageHolderExtractor, +) +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import ( + clean_parentheses, + get_statement_segment, + is_subquery_statement, + remove_statement_parentheses, +) +from sqllineage.exceptions import ( + InvalidSyntaxException, + UnsupportedStatementException, +) + + +class SqlFluffLineageAnalyzer(LineageAnalyzer): + """SQL Statement Level Lineage Analyzer for `sqlfluff`""" + + def __init__(self, dialect: str): + self._dialect = dialect + + def analyze(self, sql: str) -> StatementLineageHolder: + # remove nested parentheses that sqlfluff cannot parse + sql = clean_parentheses(sql) + is_sub_query = is_subquery_statement(sql) + if is_sub_query: + sql = remove_statement_parentheses(sql) + linter = Linter(dialect=self._dialect) + parsed_string = linter.parse_string(sql) + statement_segment = get_statement_segment(parsed_string) + extractors = [ + extractor_cls(self._dialect) + for extractor_cls in LineageHolderExtractor.__subclasses__() + ] + if statement_segment and any( + extractor.can_extract(statement_segment.type) for extractor in extractors + ): + if "unparsable" in statement_segment.descendant_type_set: + raise InvalidSyntaxException( + f"SQLLineage cannot parse the statement properly, please check potential syntax error for SQL:" + f"{sql}" + ) + else: + raise UnsupportedStatementException( + f"SQLLineage doesn't support analyzing statement type [{statement_segment.type}] for SQL:" + f"{sql}" + ) + lineage_holder = SubQueryLineageHolder() + for extractor in extractors: + if extractor.can_extract(statement_segment.type): + lineage_holder = extractor.extract( + statement_segment, AnalyzerContext(), is_sub_query + ) + break + return StatementLineageHolder.of(lineage_holder) diff --git a/sqllineage/core/parser/sqlfluff/extractors/__init__.py b/sqllineage/core/parser/sqlfluff/extractors/__init__.py new file mode 100644 index 00000000..c1ed1013 --- /dev/null +++ b/sqllineage/core/parser/sqlfluff/extractors/__init__.py @@ -0,0 +1,7 @@ +import importlib +import os +import pkgutil + +# import each module so that LineageHolderExtractor's __subclasses__ will work +for module in pkgutil.iter_modules([os.path.dirname(__file__)]): + importlib.import_module(__name__ + "." 
+ module.name) diff --git a/sqllineage/sqlfluff_core/subquery/cte_extractor.py b/sqllineage/core/parser/sqlfluff/extractors/cte_extractor.py similarity index 52% rename from sqllineage/sqlfluff_core/subquery/cte_extractor.py rename to sqllineage/core/parser/sqlfluff/extractors/cte_extractor.py index a4d20620..38751045 100644 --- a/sqllineage/sqlfluff_core/subquery/cte_extractor.py +++ b/sqllineage/core/parser/sqlfluff/extractors/cte_extractor.py @@ -1,14 +1,18 @@ from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import SqlFluffAnalyzerContext -from sqllineage.sqlfluff_core.models import SqlFluffSubQuery -from sqllineage.sqlfluff_core.subquery.dml_insert_extractor import DmlInsertExtractor -from sqllineage.sqlfluff_core.subquery.dml_select_extractor import DmlSelectExtractor -from sqllineage.sqlfluff_core.subquery.lineage_holder_extractor import ( +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import AnalyzerContext +from sqllineage.core.parser.sqlfluff.extractors.dml_insert_extractor import ( + DmlInsertExtractor, +) +from sqllineage.core.parser.sqlfluff.extractors.dml_select_extractor import ( + DmlSelectExtractor, +) +from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import ( LineageHolderExtractor, ) -from sqllineage.sqlfluff_core.utils.sqlfluff import has_alias, retrieve_segments +from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import has_alias, retrieve_segments class DmlCteExtractor(LineageHolderExtractor): @@ -16,32 +20,25 @@ class DmlCteExtractor(LineageHolderExtractor): DML CTE queries lineage extractor """ - CTE_STMT_TYPES = ["with_compound_statement"] + SUPPORTED_STMT_TYPES = ["with_compound_statement"] def __init__(self, dialect: str): super().__init__(dialect) - def can_extract(self, statement_type: str) -> bool: - """ - Determine if the current lineage holder extractor can process the statement - :param statement_type: a sqlfluff segment type - """ - return statement_type in self.CTE_STMT_TYPES - def extract( self, statement: BaseSegment, - context: SqlFluffAnalyzerContext, + context: AnalyzerContext, is_sub_query: bool = False, - ) -> SqlFluffSubQueryLineageHolder: + ) -> SubQueryLineageHolder: """ Extract lineage for a given statement. 
:param statement: a sqlfluff segment with a statement - :param context: 'SqlFluffAnalyzerContext' + :param context: 'AnalyzerContext' :param is_sub_query: determine if the statement is bracketed or not - :return 'SqlFluffSubQueryLineageHolder' object + :return 'SubQueryLineageHolder' object """ - handlers, conditional_handlers = self._init_handlers() + handlers, _ = self._init_handlers() holder = self._init_holder(context) @@ -49,26 +46,19 @@ def extract( segments = retrieve_segments(statement) for segment in segments: - for sq in self.parse_subquery(segment): - # Collecting subquery on the way, hold on parsing until last - # so that each handler don't have to worry about what's inside subquery - subqueries.append(sq) - for current_handler in handlers: current_handler.handle(segment, holder) if segment.type == "select_statement": holder |= DmlSelectExtractor(self.dialect).extract( segment, - SqlFluffAnalyzerContext( - prev_cte=holder.cte, prev_write=holder.write - ), + AnalyzerContext(prev_cte=holder.cte, prev_write=holder.write), ) if segment.type == "insert_statement": holder |= DmlInsertExtractor(self.dialect).extract( segment, - SqlFluffAnalyzerContext(prev_cte=holder.cte), + AnalyzerContext(prev_cte=holder.cte), ) identifier = None @@ -88,15 +78,11 @@ def extract( if segment_has_alias: holder.add_cte(SqlFluffSubQuery.of(sub_segment, identifier)) - for conditional_handler in conditional_handlers: - if conditional_handler.indicate(segment): - conditional_handler.handle(segment, holder) - - # By recursively extracting each subquery of the parent and merge, we're doing Depth-first search + # By recursively extracting each subquery of the parent and merging, we're doing depth-first search for sq in subqueries: holder |= DmlSelectExtractor(self.dialect).extract( - sq.segment, - SqlFluffAnalyzerContext(sq, prev_cte=holder.cte), + sq.query, + AnalyzerContext(sq, prev_cte=holder.cte), ) return holder diff --git a/sqllineage/sqlfluff_core/subquery/ddl_alter_extractor.py b/sqllineage/core/parser/sqlfluff/extractors/ddl_alter_extractor.py similarity index 59% rename from sqllineage/sqlfluff_core/subquery/ddl_alter_extractor.py rename to sqllineage/core/parser/sqlfluff/extractors/ddl_alter_extractor.py index f8664500..83a8a217 100644 --- a/sqllineage/sqlfluff_core/subquery/ddl_alter_extractor.py +++ b/sqllineage/core/parser/sqlfluff/extractors/ddl_alter_extractor.py @@ -1,14 +1,11 @@ from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.holders import ( - SqlFluffStatementLineageHolder, - SqlFluffSubQueryLineageHolder, -) -from sqllineage.sqlfluff_core.models import SqlFluffAnalyzerContext -from sqllineage.sqlfluff_core.models import SqlFluffTable -from sqllineage.sqlfluff_core.subquery.lineage_holder_extractor import ( +from sqllineage.core.holders import StatementLineageHolder, SubQueryLineageHolder +from sqllineage.core.models import AnalyzerContext +from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import ( LineageHolderExtractor, ) +from sqllineage.core.parser.sqlfluff.models import SqlFluffTable class DdlAlterExtractor(LineageHolderExtractor): @@ -16,7 +13,7 @@ DDL Alter queries lineage extractor """ - DDL_ALTER_STMT_TYPES = [ + SUPPORTED_STMT_TYPES = [ "alter_table_statement", "rename_statement", "rename_table_statement", ] @@ -25,27 +22,20 @@ def __init__(self, dialect: str): super().__init__(dialect) - def can_extract(self, statement_type: str) ->
bool: - """ - Determine if the current lineage holder extractor can process the statement - :param statement_type: a sqlfluff segment type - """ - return statement_type in self.DDL_ALTER_STMT_TYPES - def extract( self, statement: BaseSegment, - context: SqlFluffAnalyzerContext, + context: AnalyzerContext, is_sub_query: bool = False, - ) -> SqlFluffSubQueryLineageHolder: + ) -> SubQueryLineageHolder: """ Extract lineage for a given statement. :param statement: a sqlfluff segment with a statement - :param context: 'SqlFluffAnalyzerContext' + :param context: 'AnalyzerContext' :param is_sub_query: determine if the statement is bracketed or not - :return 'SqlFluffSubQueryLineageHolder' object + :return 'SubQueryLineageHolder' object """ - holder = SqlFluffStatementLineageHolder() + holder = StatementLineageHolder() tables = [] for t in statement.segments: if t.type == "table_reference": diff --git a/sqllineage/core/parser/sqlfluff/extractors/ddl_drop_extractor.py b/sqllineage/core/parser/sqlfluff/extractors/ddl_drop_extractor.py new file mode 100644 index 00000000..8cb4e6db --- /dev/null +++ b/sqllineage/core/parser/sqlfluff/extractors/ddl_drop_extractor.py @@ -0,0 +1,41 @@ +from sqlfluff.core.parser import BaseSegment + +from sqllineage.core.holders import StatementLineageHolder, SubQueryLineageHolder +from sqllineage.core.models import AnalyzerContext +from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import ( + LineageHolderExtractor, +) +from sqllineage.core.parser.sqlfluff.models import SqlFluffTable + + +class DdlDropExtractor(LineageHolderExtractor): + """ + DDL Drop queries lineage extractor + """ + + SUPPORTED_STMT_TYPES = ["drop_table_statement"] + + def __init__(self, dialect: str): + super().__init__(dialect) + + def extract( + self, + statement: BaseSegment, + context: AnalyzerContext, + is_sub_query: bool = False, + ) -> SubQueryLineageHolder: + """ + Extract lineage for a given statement. 
+ :param statement: a sqlfluff segment with a statement + :param context: 'AnalyzerContext' + :param is_sub_query: determine if the statement is bracketed or not + :return 'SubQueryLineageHolder' object + """ + holder = StatementLineageHolder() + for table in { + SqlFluffTable.of(t) + for t in statement.segments + if t.type == "table_reference" + }: + holder.add_drop(table) + return holder diff --git a/sqllineage/sqlfluff_core/subquery/dml_insert_extractor.py b/sqllineage/core/parser/sqlfluff/extractors/dml_insert_extractor.py similarity index 69% rename from sqllineage/sqlfluff_core/subquery/dml_insert_extractor.py rename to sqllineage/core/parser/sqlfluff/extractors/dml_insert_extractor.py index 03ea558a..6bea880e 100644 --- a/sqllineage/sqlfluff_core/subquery/dml_insert_extractor.py +++ b/sqllineage/core/parser/sqlfluff/extractors/dml_insert_extractor.py @@ -1,12 +1,15 @@ from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import SqlFluffAnalyzerContext, SqlFluffSubQuery -from sqllineage.sqlfluff_core.subquery.dml_select_extractor import DmlSelectExtractor -from sqllineage.sqlfluff_core.subquery.lineage_holder_extractor import ( +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import AnalyzerContext +from sqllineage.core.parser.sqlfluff.extractors.dml_select_extractor import ( + DmlSelectExtractor, +) +from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import ( LineageHolderExtractor, ) -from sqllineage.sqlfluff_core.utils.sqlfluff import retrieve_segments +from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import retrieve_segments class DmlInsertExtractor(LineageHolderExtractor): @@ -14,7 +17,7 @@ class DmlInsertExtractor(LineageHolderExtractor): DML Insert queries lineage extractor """ - DML_INSERT_STMT_TYPES = [ + SUPPORTED_STMT_TYPES = [ "insert_statement", "create_table_statement", "create_view_statement", @@ -28,25 +31,18 @@ class DmlInsertExtractor(LineageHolderExtractor): def __init__(self, dialect: str): super().__init__(dialect) - def can_extract(self, statement_type: str) -> bool: - """ - Determine if the current lineage holder extractor can process the statement - :param statement_type: a sqlfluff segment type - """ - return statement_type in self.DML_INSERT_STMT_TYPES - def extract( self, statement: BaseSegment, - context: SqlFluffAnalyzerContext, + context: AnalyzerContext, is_sub_query: bool = False, - ) -> SqlFluffSubQueryLineageHolder: + ) -> SubQueryLineageHolder: """ Extract lineage for a given statement. 
:param statement: a sqlfluff segment with a statement - :param context: 'SqlFluffAnalyzerContext' + :param context: 'AnalyzerContext' :param is_sub_query: determine if the statement is bracketed or not - :return 'SqlFluffSubQueryLineageHolder' object + :return 'SubQueryLineageHolder' object """ handlers, conditional_handlers = self._init_handlers() @@ -66,8 +62,8 @@ def extract( if segment.type == "select_statement": holder |= DmlSelectExtractor(self.dialect).extract( segment, - SqlFluffAnalyzerContext( - SqlFluffSubQuery(segment, None), + AnalyzerContext( + SqlFluffSubQuery.of(segment, None), prev_cte=holder.cte, prev_write=holder.write, ), @@ -80,8 +76,8 @@ def extract( if sub_segment.type == "select_statement": holder |= DmlSelectExtractor(self.dialect).extract( sub_segment, - SqlFluffAnalyzerContext( - SqlFluffSubQuery(segment, None), + AnalyzerContext( + SqlFluffSubQuery.of(segment, None), prev_cte=holder.cte, prev_write=holder.write, ), @@ -95,7 +91,7 @@ def extract( # By recursively extracting each subquery of the parent and merge, we're doing Depth-first search for sq in subqueries: holder |= DmlSelectExtractor(self.dialect).extract( - sq.segment, SqlFluffAnalyzerContext(sq, holder.cte) + sq.query, AnalyzerContext(sq, holder.cte) ) return holder diff --git a/sqllineage/sqlfluff_core/subquery/dml_select_extractor.py b/sqllineage/core/parser/sqlfluff/extractors/dml_select_extractor.py similarity index 62% rename from sqllineage/sqlfluff_core/subquery/dml_select_extractor.py rename to sqllineage/core/parser/sqlfluff/extractors/dml_select_extractor.py index 0ef197d5..0b8bd760 100644 --- a/sqllineage/sqlfluff_core/subquery/dml_select_extractor.py +++ b/sqllineage/core/parser/sqlfluff/extractors/dml_select_extractor.py @@ -1,11 +1,14 @@ from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import SqlFluffAnalyzerContext, SqlFluffSubQuery -from sqllineage.sqlfluff_core.subquery.lineage_holder_extractor import ( +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import AnalyzerContext +from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import ( LineageHolderExtractor, ) -from sqllineage.sqlfluff_core.utils.sqlfluff import retrieve_segments +from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import ( + retrieve_segments, +) class DmlSelectExtractor(LineageHolderExtractor): @@ -13,35 +16,32 @@ class DmlSelectExtractor(LineageHolderExtractor): DML Select queries lineage extractor """ - DML_SELECT_STMT_TYPES = ["select_statement"] + SUPPORTED_STMT_TYPES = ["select_statement", "set_expression"] def __init__(self, dialect: str): super().__init__(dialect) - def can_extract(self, statement_type: str) -> bool: - """ - Determine if the current lineage holder extractor can process the statement - :param statement_type: a sqlfluff segment type - """ - return statement_type in self.DML_SELECT_STMT_TYPES - def extract( self, statement: BaseSegment, - context: SqlFluffAnalyzerContext, + context: AnalyzerContext, is_sub_query: bool = False, - ) -> SqlFluffSubQueryLineageHolder: + ) -> SubQueryLineageHolder: """ Extract lineage for a given statement. 
:param statement: a sqlfluff segment with a statement - :param context: 'SqlFluffAnalyzerContext' + :param context: 'AnalyzerContext' :param is_sub_query: determine if the statement is bracketed or not - :return 'SqlFluffSubQueryLineageHolder' object + :return 'SubQueryLineageHolder' object """ handlers, conditional_handlers = self._init_handlers() holder = self._init_holder(context) subqueries = [SqlFluffSubQuery.of(statement, None)] if is_sub_query else [] - segments = retrieve_segments(statement) + segments = ( + [statement] + if statement.type == "set_expression" + else retrieve_segments(statement) + ) for segment in segments: for sq in self.parse_subquery(segment): # Collecting subquery on the way, hold on parsing until last @@ -61,9 +61,9 @@ def extract( # By recursively extracting each subquery of the parent and merge, we're doing Depth-first search for sq in subqueries: - holder |= self.extract(sq.segment, SqlFluffAnalyzerContext(sq, holder.cte)) + holder |= self.extract(sq.query, AnalyzerContext(sq, holder.cte)) for sq in holder.extra_subqueries: - holder |= self.extract(sq.segment, SqlFluffAnalyzerContext(sq, holder.cte)) + holder |= self.extract(sq.query, AnalyzerContext(sq, holder.cte)) return holder diff --git a/sqllineage/sqlfluff_core/subquery/lineage_holder_extractor.py b/sqllineage/core/parser/sqlfluff/extractors/lineage_holder_extractor.py similarity index 71% rename from sqllineage/sqlfluff_core/subquery/lineage_holder_extractor.py rename to sqllineage/core/parser/sqlfluff/extractors/lineage_holder_extractor.py index c4f53ab2..a9a5d78d 100644 --- a/sqllineage/sqlfluff_core/subquery/lineage_holder_extractor.py +++ b/sqllineage/core/parser/sqlfluff/extractors/lineage_holder_extractor.py @@ -1,66 +1,66 @@ -from abc import ABC, abstractmethod from functools import reduce from operator import add from typing import List, Tuple from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.handlers.base import ( +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import AnalyzerContext, SubQuery +from sqllineage.core.parser.sqlfluff.handlers.base import ( ConditionalSegmentBaseHandler, SegmentBaseHandler, ) -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import SqlFluffAnalyzerContext -from sqllineage.sqlfluff_core.models import SqlFluffSubQuery -from sqllineage.sqlfluff_core.utils.entities import SubSqlFluffQueryTuple -from sqllineage.sqlfluff_core.utils.sqlfluff import ( +from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import ( get_multiple_identifiers, get_subqueries, is_subquery, + is_union, ) +from sqllineage.utils.entities import SubQueryTuple -class LineageHolderExtractor(ABC): +class LineageHolderExtractor: """ - Abstract class implementation for extract 'SqlFluffSubQueryLineageHolder' from different statement types + Abstract class implementation for extract 'SubQueryLineageHolder' from different statement types """ + SUPPORTED_STMT_TYPES: List[str] = [] + def __init__(self, dialect: str): self.dialect = dialect - @abstractmethod def can_extract(self, statement_type: str) -> bool: """ Determine if the current lineage holder extractor can process the statement :param statement_type: a sqlfluff segment type """ - pass + return statement_type in self.SUPPORTED_STMT_TYPES - @abstractmethod def extract( self, statement: BaseSegment, - context: SqlFluffAnalyzerContext, + context: 
AnalyzerContext, is_sub_query: bool = False, - ) -> SqlFluffSubQueryLineageHolder: + ) -> SubQueryLineageHolder: """ Extract lineage for a given statement. :param statement: a sqlfluff segment with a statement - :param context: 'SqlFluffAnalyzerContext' + :param context: 'AnalyzerContext' :param is_sub_query: determine if the statement is bracketed or not - :return 'SqlFluffSubQueryLineageHolder' object + :return 'SubQueryLineageHolder' object """ - pass + raise NotImplementedError @classmethod - def parse_subquery(cls, segment: BaseSegment) -> List[SqlFluffSubQuery]: + def parse_subquery(cls, segment: BaseSegment) -> List[SubQuery]: """ The parse_subquery function takes a segment as an argument. :param segment: segment to determine if it is a subquery :return: A list of `SqlFluffSubQuery` objects, otherwise, if the segment is not matching any of the expected types it returns an empty list. """ - result: List[SqlFluffSubQuery] = [] + result: List[SubQuery] = [] identifiers = get_multiple_identifiers(segment) if identifiers and len(identifiers) > 1: # for SQL89 style of JOIN or multiple CTEs, this is actually SubQueries @@ -74,15 +74,13 @@ def parse_subquery(cls, segment: BaseSegment) -> List[SqlFluffSubQuery]: ) if segment.type in ["select_clause", "from_clause", "where_clause"]: result = cls._parse_subquery(get_subqueries(segment)) - elif is_subquery(segment): + elif is_subquery(segment) and not is_union(segment): # Parenthesis for SubQuery without alias, this is valid syntax for certain SQL dialect result = [SqlFluffSubQuery.of(segment, None)] return result @classmethod - def _parse_subquery( - cls, subqueries: List[SubSqlFluffQueryTuple] - ) -> List[SqlFluffSubQuery]: + def _parse_subquery(cls, subqueries: List[SubQueryTuple]) -> List[SubQuery]: """ Convert a list of 'SqlFluffSubQueryTuple' to 'SqlFluffSubQuery' :param subqueries: a list of 'SqlFluffSubQueryTuple' @@ -100,23 +98,23 @@ def _init_handlers( Initialize handlers used during the extraction of lineage :return: A tuple with a list of SegmentBaseHandler and ConditionalSegmentBaseHandler """ - handlers = [ + handlers: List[SegmentBaseHandler] = [ handler_cls() for handler_cls in SegmentBaseHandler.__subclasses__() ] conditional_handlers = [ - handler_cls(self.dialect) + handler_cls() for handler_cls in ConditionalSegmentBaseHandler.__subclasses__() ] return handlers, conditional_handlers @staticmethod - def _init_holder(context: SqlFluffAnalyzerContext) -> SqlFluffSubQueryLineageHolder: + def _init_holder(context: AnalyzerContext) -> SubQueryLineageHolder: """ - Initialize lineage holder for a given 'SqlFluffAnalyzerContext' + Initialize lineage holder for a given 'AnalyzerContext' :param context: a previous context that the lineage extractor must consider - :return: an initialized SqlFluffSubQueryLineageHolder + :return: an initialized SubQueryLineageHolder """ - holder = SqlFluffSubQueryLineageHolder() + holder = SubQueryLineageHolder() if context.prev_cte is not None: # CTE can be referenced by subsequent CTEs diff --git a/sqllineage/sqlfluff_core/subquery/noop_extractor.py b/sqllineage/core/parser/sqlfluff/extractors/noop_extractor.py similarity index 50% rename from sqllineage/sqlfluff_core/subquery/noop_extractor.py rename to sqllineage/core/parser/sqlfluff/extractors/noop_extractor.py index aa8d5e80..3d850248 100644 --- a/sqllineage/sqlfluff_core/subquery/noop_extractor.py +++ b/sqllineage/core/parser/sqlfluff/extractors/noop_extractor.py @@ -1,8 +1,8 @@ from sqlfluff.core.parser import BaseSegment -from 
sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import SqlFluffAnalyzerContext -from sqllineage.sqlfluff_core.subquery.lineage_holder_extractor import ( +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import AnalyzerContext +from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import ( LineageHolderExtractor, ) @@ -12,7 +12,7 @@ class NoopExtractor(LineageHolderExtractor): Extractor for queries which do not provide any lineage """ - NOOP_STMT_TYPES = [ + SUPPORTED_STMT_TYPES = [ "delete_statement", "truncate_table", "refresh_statement", @@ -25,24 +25,17 @@ class NoopExtractor(LineageHolderExtractor): def __init__(self, dialect: str): super().__init__(dialect) - def can_extract(self, statement_type: str) -> bool: - """ - Determine if the current lineage holder extractor can process the statement - :param statement_type: a sqlfluff segment type - """ - return statement_type in self.NOOP_STMT_TYPES - def extract( self, statement: BaseSegment, - context: SqlFluffAnalyzerContext, + context: AnalyzerContext, is_sub_query: bool = False, - ) -> SqlFluffSubQueryLineageHolder: + ) -> SubQueryLineageHolder: """ Extract lineage for a given statement. :param statement: a sqlfluff segment with a statement - :param context: 'SqlFluffAnalyzerContext' + :param context: 'AnalyzerContext' :param is_sub_query: determine if the statement is bracketed or not - :return 'SqlFluffSubQueryLineageHolder' object + :return 'SubQueryLineageHolder' object """ - return SqlFluffSubQueryLineageHolder() + return SubQueryLineageHolder() diff --git a/sqllineage/core/handlers/__init__.py b/sqllineage/core/parser/sqlfluff/handlers/__init__.py similarity index 100% rename from sqllineage/core/handlers/__init__.py rename to sqllineage/core/parser/sqlfluff/handlers/__init__.py diff --git a/sqllineage/sqlfluff_core/handlers/base.py b/sqllineage/core/parser/sqlfluff/handlers/base.py similarity index 55% rename from sqllineage/sqlfluff_core/handlers/base.py rename to sqllineage/core/parser/sqlfluff/handlers/base.py index 9d503444..2f172633 100644 --- a/sqllineage/sqlfluff_core/handlers/base.py +++ b/sqllineage/core/parser/sqlfluff/handlers/base.py @@ -1,6 +1,6 @@ from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder +from sqllineage.core.holders import SubQueryLineageHolder class ConditionalSegmentBaseHandler: @@ -8,20 +8,11 @@ class ConditionalSegmentBaseHandler: Extract lineage from a segment when the segment match the condition """ - def __init__(self, dialect: str) -> None: - """ - :param dialect: dialect of the handler - """ - self.indicator = False - self.dialect = dialect - - def handle( - self, segment: BaseSegment, holder: SqlFluffSubQueryLineageHolder - ) -> None: + def handle(self, segment: BaseSegment, holder: SubQueryLineageHolder) -> None: """ Handle the segment, and update the lineage result accordingly in the holder :param segment: segment to be handled - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage + :param holder: 'SubQueryLineageHolder' to hold lineage """ raise NotImplementedError @@ -31,12 +22,12 @@ def indicate(self, segment: BaseSegment) -> bool: :param segment: segment to be handled :return: True if it can be handled, by default return False """ - return False + raise NotImplementedError - def end_of_query_cleanup(self, holder: SqlFluffSubQueryLineageHolder) -> None: + def end_of_query_cleanup(self, holder: 
SubQueryLineageHolder) -> None: """ Optional method to be called at the end of statement or subquery - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage + :param holder: 'SubQueryLineageHolder' to hold lineage """ pass @@ -46,12 +37,9 @@ class SegmentBaseHandler: Extract lineage from a specific segment """ - def handle( - self, segment: BaseSegment, holder: SqlFluffSubQueryLineageHolder - ) -> None: + def handle(self, segment: BaseSegment, holder: SubQueryLineageHolder) -> None: """ - :param segment: segment to be handled - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage + :param segment: segment to be handled + :param holder: 'SubQueryLineageHolder' to hold lineage """ raise NotImplementedError diff --git a/sqllineage/core/parser/sqlfluff/handlers/source.py b/sqllineage/core/parser/sqlfluff/handlers/source.py new file mode 100644 index 00000000..da4e6549 --- /dev/null +++ b/sqllineage/core/parser/sqlfluff/handlers/source.py @@ -0,0 +1,170 @@ +from typing import Union + +from sqlfluff.core.parser import BaseSegment + +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import Path, SubQuery, Table +from sqllineage.core.parser import SourceHandlerMixin +from sqllineage.core.parser.sqlfluff.handlers.base import ConditionalSegmentBaseHandler +from sqllineage.core.parser.sqlfluff.models import ( + SqlFluffColumn, + SqlFluffSubQuery, + SqlFluffTable, +) +from sqllineage.core.parser.sqlfluff.utils.holder import retrieve_holder_data_from +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import ( + find_table_identifier, + get_grandchild, + get_inner_from_expression, + get_multiple_identifiers, + get_subqueries, + is_union, + is_values_clause, + retrieve_extra_segment, + retrieve_segments, +) + + +class SourceHandler(SourceHandlerMixin, ConditionalSegmentBaseHandler): + """ + Source table and column handler + """ + + def __init__(self): + self.columns = [] + self.tables = [] + self.union_barriers = [] + + def indicate(self, segment: BaseSegment) -> bool: + """ + Indicates if the handler can handle the segment + :param segment: segment to be processed + :return: True if it can be handled + """ + return ( + self._indicate_column(segment) + or self._indicate_table(segment) + or is_union(segment) + ) + + def handle(self, segment: BaseSegment, holder: SubQueryLineageHolder) -> None: + """ + Handle the segment, and update the lineage result accordingly in the holder + :param segment: segment to be handled + :param holder: 'SubQueryLineageHolder' to hold lineage + """ + if self._indicate_table(segment): + self._handle_table(segment, holder) + elif is_union(segment): + self._handle_union(segment) + if self._indicate_column(segment): + self._handle_column(segment) + + def _handle_table( + self, segment: BaseSegment, holder: SubQueryLineageHolder + ) -> None: + """ + Table handler method + :param segment: segment to be handled + :param holder: 'SubQueryLineageHolder' to hold lineage + """ + identifiers = get_multiple_identifiers(segment) + if identifiers and len(identifiers) > 1: + for identifier in identifiers: + self._add_dataset_from_expression_element(identifier, holder) + from_segment = get_inner_from_expression(segment) + if from_segment.type == "from_expression_element": + self._add_dataset_from_expression_element(from_segment, holder) + for extra_segment in retrieve_extra_segment(segment): + self._handle_table(extra_segment, holder) + + def _handle_column(self, segment: BaseSegment) -> None: + """ + Column handler
method + :param segment: segment to be handled + """ + sub_segments = retrieve_segments(segment) + for sub_segment in sub_segments: + if sub_segment.type == "select_clause_element": + self.columns.append(SqlFluffColumn.of(sub_segment)) + + def _handle_union(self, segment: BaseSegment) -> None: + """ + Union handler method + :param segment: segment to be handled + """ + subqueries = get_subqueries(segment) + if subqueries: + for idx, sq in enumerate(subqueries): + if idx != 0: + self.union_barriers.append((len(self.columns), len(self.tables))) + subquery, alias = sq + table_identifier = find_table_identifier(subquery) + if table_identifier: + read_sq = SqlFluffTable.of(table_identifier, alias) + segments = retrieve_segments(subquery) + for seg in segments: + if seg.type == "select_clause": + self._handle_column(seg) + self.tables.append(read_sq) + + def _add_dataset_from_expression_element( + self, segment: BaseSegment, holder: SubQueryLineageHolder + ) -> None: + """ + Append tables and subqueries identified in the 'from_expression_element' type segment to the table and + holder extra subqueries sets + :param segment: 'from_expression_element' type segment + :param holder: 'SubQueryLineageHolder' to hold lineage + """ + dataset: Union[Path, Table, SubQuery] + all_segments = [ + seg for seg in retrieve_segments(segment) if seg.type != "keyword" + ] + first_segment = all_segments[0] + function_as_table = get_grandchild(segment, "table_expression", "function") + if first_segment.type == "function" or function_as_table: + # function() as alias, no dataset involved + return + elif first_segment.type == "bracketed" and is_values_clause(first_segment): + # (VALUES ...) AS alias, no dataset involved + return + elif is_union(segment): + subqueries = get_subqueries(segment) + subquery, alias = subqueries[0] + self.tables.append(SqlFluffSubQuery.of(subquery, alias)) + else: + subqueries = get_subqueries(segment) + if subqueries: + for sq in subqueries: + bracketed, alias = sq + read_sq = SqlFluffSubQuery.of(bracketed, alias) + holder.extra_subqueries.add(read_sq) + self.tables.append(read_sq) + else: + table_identifier = find_table_identifier(segment) + if table_identifier: + dataset = retrieve_holder_data_from( + all_segments, holder, table_identifier + ) + self.tables.append(dataset) + + @staticmethod + def _indicate_column(segment: BaseSegment) -> bool: + """ + Check if it is a column + :param segment: segment to be checked + :return: True if type is 'select_clause' + """ + return bool(segment.type == "select_clause") + + @staticmethod + def _indicate_table(segment: BaseSegment) -> bool: + """ + Check if it is a table + :param segment: segment to be checked + :return: True if type is 'from_clause' + """ + return bool(segment.type == "from_clause") diff --git a/sqllineage/sqlfluff_core/handlers/swap_partition.py b/sqllineage/core/parser/sqlfluff/handlers/swap_partition.py similarity index 74% rename from sqllineage/sqlfluff_core/handlers/swap_partition.py rename to sqllineage/core/parser/sqlfluff/handlers/swap_partition.py index ddd9ba53..24d9423c 100644 --- a/sqllineage/sqlfluff_core/handlers/swap_partition.py +++ b/sqllineage/core/parser/sqlfluff/handlers/swap_partition.py @@ -1,9 +1,12 @@ from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.handlers.base import SegmentBaseHandler -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import SqlFluffTable -from sqllineage.sqlfluff_core.utils.sqlfluff import 
get_grandchild, get_grandchildren +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.parser.sqlfluff.handlers.base import SegmentBaseHandler +from sqllineage.core.parser.sqlfluff.models import SqlFluffTable +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import ( + get_grandchild, + get_grandchildren, +) from sqllineage.utils.helpers import escape_identifier_name @@ -12,9 +15,7 @@ class SwapPartitionHandler(SegmentBaseHandler): A handler for swap_partitions_between_tables function """ - def handle( - self, segment: BaseSegment, holder: SqlFluffSubQueryLineageHolder - ) -> None: + def handle(self, segment: BaseSegment, holder: SubQueryLineageHolder) -> None: """ Handle the segment, and update the lineage result accordingly in the holder :param segment: segment to be handled diff --git a/sqllineage/sqlfluff_core/handlers/target.py b/sqllineage/core/parser/sqlfluff/handlers/target.py similarity index 64% rename from sqllineage/sqlfluff_core/handlers/target.py rename to sqllineage/core/parser/sqlfluff/handlers/target.py index b8297705..19737c7b 100644 --- a/sqllineage/sqlfluff_core/handlers/target.py +++ b/sqllineage/core/parser/sqlfluff/handlers/target.py @@ -1,16 +1,13 @@ -from typing import Optional, Union - from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.handlers.base import ConditionalSegmentBaseHandler -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import ( - SqlFluffPath, - SqlFluffSubQuery, +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import Path +from sqllineage.core.parser.sqlfluff.handlers.base import ConditionalSegmentBaseHandler +from sqllineage.core.parser.sqlfluff.models import ( SqlFluffTable, ) -from sqllineage.sqlfluff_core.utils.holder import retrieve_holder_data_from -from sqllineage.sqlfluff_core.utils.sqlfluff import ( +from sqllineage.core.parser.sqlfluff.utils.holder import retrieve_holder_data_from +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import ( find_table_identifier, get_child, retrieve_segments, @@ -23,8 +20,8 @@ class TargetHandler(ConditionalSegmentBaseHandler): Target table handler """ - def __init__(self, dialect: str) -> None: - super().__init__(dialect) + def __init__(self) -> None: + self.indicator = False self.prev_token_like = False self.prev_token_from = False @@ -66,27 +63,22 @@ def indicate(self, segment: BaseSegment) -> bool: :param segment: segment to be processed :return: True if it can be handled """ - if ( - self.indicator is True - or segment.type == "keyword" - and segment.raw_upper in self.TARGET_KEYWORDS + if self.indicator is True or ( + segment.type == "keyword" and segment.raw_upper in self.TARGET_KEYWORDS ): self.indicator = True self._init_tokens(segment) return self.indicator return False - def handle( - self, segment: BaseSegment, holder: SqlFluffSubQueryLineageHolder - ) -> None: + def handle(self, segment: BaseSegment, holder: SubQueryLineageHolder) -> None: """ Handle the segment, and update the lineage result accordingly in the holder :param segment: segment to be handled :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage """ - object_segment = self._extract_table_reference(segment, holder) - if segment.type == "table_reference" or object_segment: - write_obj = object_segment if object_segment else SqlFluffTable.of(segment) + if segment.type == "table_reference": + write_obj = SqlFluffTable.of(segment) if self.prev_token_like: 
holder.add_read(write_obj) else: @@ -95,9 +87,9 @@ def handle( elif segment.type in {"literal", "storage_location"}: if self.prev_token_from: - holder.add_read(SqlFluffPath(escape_identifier_name(segment.raw))) + holder.add_read(Path(escape_identifier_name(segment.raw))) else: - holder.add_write(SqlFluffPath(escape_identifier_name(segment.raw))) + holder.add_write(Path(escape_identifier_name(segment.raw))) self._reset_tokens() elif segment.type == "from_expression": @@ -129,20 +121,3 @@ def handle( ) if read: holder.add_read(read) - - @staticmethod - def _extract_table_reference( - object_reference: BaseSegment, holder: SqlFluffSubQueryLineageHolder - ) -> Optional[Union[SqlFluffTable, SqlFluffSubQuery]]: - """ - :param object_reference: object reference segment - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage - :return: a 'SqlFluffTable' or 'SqlFluffSubQuery' from the object_reference - """ - if object_reference and object_reference.type == "object_reference": - return retrieve_holder_data_from( - object_reference.segments, - holder, - get_child(object_reference, "identifier"), - ) - return None diff --git a/sqllineage/core/parser/sqlfluff/models.py b/sqllineage/core/parser/sqlfluff/models.py new file mode 100644 index 00000000..5853fe7b --- /dev/null +++ b/sqllineage/core/parser/sqlfluff/models.py @@ -0,0 +1,225 @@ +from typing import List +from typing import Optional, Tuple + +from sqlfluff.core.parser import BaseSegment + +from sqllineage import SQLPARSE_DIALECT +from sqllineage.core.models import Column, Schema, SubQuery, Table +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import ( + get_identifier, + is_subquery, + is_wildcard, + retrieve_segments, + token_matching, +) +from sqllineage.utils.entities import ColumnQualifierTuple +from sqllineage.utils.helpers import escape_identifier_name + +NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE = [ + "function", + "over_clause", + "partitionby_clause", + "orderby_clause", + "expression", + "case_expression", + "when_clause", + "else_clause", + "select_clause_element", +] + +SOURCE_COLUMN_SEGMENT_TYPE = NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE + [ + "identifier", + "column_reference", +] + + +class SqlFluffTable(Table): + """ + Data Class for SqlFluffTable + """ + + @staticmethod + def of(table: BaseSegment, alias: Optional[str] = None) -> Table: + """ + Build an object of type 'Table' + :param table: table segment to be processed + :param alias: alias of the table segment + :return: 'Table' object + """ + # rewrite identifier's get_real_name method, by matching the last dot instead of the first dot, so that the + # real name for a.b.c will be c instead of b + dot_idx, _ = token_matching( + table, + (lambda s: bool(s.type == "symbol"),), + start=len(table.segments), + reverse=True, + ) + real_name = ( + table.segments[dot_idx + 1].raw + if dot_idx + else (table.raw if table.type == "identifier" else table.segments[0].raw) + ) + # rewrite identifier's get_parent_name accordingly + parent_name = ( + "".join( + [ + escape_identifier_name(segment.raw) + for segment in table.segments[:dot_idx] + ] + ) + if dot_idx + else None + ) + schema = Schema(parent_name) if parent_name is not None else Schema() + kwargs = {"alias": alias} if alias else {} + return Table(real_name, schema, **kwargs) + + +class SqlFluffSubQuery(SubQuery): + """ + Data Class for SqlFluffSubQuery + """ + + @staticmethod + def of(subquery: BaseSegment, alias: Optional[str]) -> SubQuery: + """ + Build a 'SubQuery' object + :param subquery: subquery segment + :param 
alias: subquery alias
+        :return: 'SubQuery' object
+        """
+        return SubQuery(subquery, subquery.raw, alias)
+
+
+class SqlFluffColumn(Column):
+    """
+    Data Class for SqlFluffColumn
+    """
+
+    @staticmethod
+    def of(column: BaseSegment, **kwargs) -> Column:
+        """
+        Build a 'Column' object
+        :param column: column segment
+        :return: 'Column' object
+        """
+        if column.type == "select_clause_element":
+            source_columns, alias = SqlFluffColumn._get_column_and_alias(column)
+            if alias:
+                return Column(
+                    alias,
+                    source_columns=source_columns,
+                )
+            if source_columns:
+                sub_segments = retrieve_segments(column)
+                column_name = None
+                for sub_segment in sub_segments:
+                    if sub_segment.type == "column_reference":
+                        column_name = get_identifier(sub_segment)
+
+                return Column(
+                    column.raw if column_name is None else column_name,
+                    source_columns=source_columns,
+                )
+
+        # Wildcard, Case, Function without alias (thus not recognized as an Identifier)
+        source_columns = SqlFluffColumn._extract_source_columns(column)
+        return Column(
+            column.raw,
+            source_columns=source_columns,
+        )
+
+    @staticmethod
+    def _extract_source_columns(segment: BaseSegment) -> List[ColumnQualifierTuple]:
+        """
+        :param segment: segment to be processed
+        :return: list of extracted source columns
+        """
+        if segment.type == "identifier" or is_wildcard(segment):
+            return [ColumnQualifierTuple(segment.raw, None)]
+        if segment.type == "column_reference":
+            parent, column = SqlFluffColumn._get_column_and_parent(segment)
+            return [ColumnQualifierTuple(column, parent)]
+        if segment.type in NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE:
+            sub_segments = retrieve_segments(segment)
+            col_list = []
+            for sub_segment in sub_segments:
+                if sub_segment.type == "bracketed":
+                    if is_subquery(sub_segment):
+                        col_list += SqlFluffColumn._get_column_from_subquery(
+                            sub_segment
+                        )
+                    else:
+                        col_list += SqlFluffColumn._get_column_from_parenthesis(
+                            sub_segment
+                        )
+                elif sub_segment.type in SOURCE_COLUMN_SEGMENT_TYPE or is_wildcard(
+                    sub_segment
+                ):
+                    res = SqlFluffColumn._extract_source_columns(sub_segment)
+                    col_list.extend(res)
+            return col_list
+        return []
+
+    @staticmethod
+    def _get_column_from_subquery(
+        sub_segment: BaseSegment,
+    ) -> List[ColumnQualifierTuple]:
+        """
+        :param sub_segment: segment to be processed
+        :return: A list of source columns from a segment
+        """
+        # This is to avoid circular import
+        from sqllineage.runner import LineageRunner
+
+        src_cols = [
+            lineage[0]
+            for lineage in LineageRunner(
+                sub_segment.raw,
+                dialect=SQLPARSE_DIALECT,
+            ).get_column_lineage(exclude_subquery=False)
+        ]
+        source_columns = [
+            ColumnQualifierTuple(src_col.raw_name, src_col.parent.raw_name)
+            for src_col in src_cols
+        ]
+        return source_columns
+
+    @staticmethod
+    def _get_column_from_parenthesis(
+        sub_segment: BaseSegment,
+    ) -> List[ColumnQualifierTuple]:
+        """
+        :param sub_segment: segment to be processed
+        :return: list of columns and alias from the segment
+        """
+        col, _ = SqlFluffColumn._get_column_and_alias(sub_segment)
+        if col:
+            return col
+        col, _ = SqlFluffColumn._get_column_and_alias(sub_segment, False)
+        return col if col else []
+
+    @staticmethod
+    def _get_column_and_alias(
+        segment: BaseSegment, check_bracketed: bool = True
+    ) -> Tuple[List[ColumnQualifierTuple], Optional[str]]:
+        alias = None
+        columns = []
+        sub_segments = retrieve_segments(segment, check_bracketed)
+        for sub_segment in sub_segments:
+            if sub_segment.type == "alias_expression":
+                alias = get_identifier(sub_segment)
+            elif sub_segment.type in
SOURCE_COLUMN_SEGMENT_TYPE or is_wildcard( + sub_segment + ): + res = SqlFluffColumn._extract_source_columns(sub_segment) + columns += res if res else [] + + return columns, alias + + @staticmethod + def _get_column_and_parent(col_segment: BaseSegment) -> Tuple[Optional[str], str]: + identifiers = retrieve_segments(col_segment) + if len(identifiers) > 1: + return identifiers[-2].raw, identifiers[-1].raw + return None, identifiers[-1].raw diff --git a/sqllineage/sqlfluff_core/subquery/__init__.py b/sqllineage/core/parser/sqlfluff/utils/__init__.py similarity index 100% rename from sqllineage/sqlfluff_core/subquery/__init__.py rename to sqllineage/core/parser/sqlfluff/utils/__init__.py diff --git a/sqllineage/sqlfluff_core/utils/holder.py b/sqllineage/core/parser/sqlfluff/utils/holder.py similarity index 59% rename from sqllineage/sqlfluff_core/utils/holder.py rename to sqllineage/core/parser/sqlfluff/utils/holder.py index a555dcd2..b232c103 100644 --- a/sqllineage/sqlfluff_core/utils/holder.py +++ b/sqllineage/core/parser/sqlfluff/utils/holder.py @@ -2,23 +2,25 @@ from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import SqlFluffSubQuery, SqlFluffTable -from sqllineage.sqlfluff_core.utils.sqlfluff import get_table_alias +from sqllineage.core.holders import SubQueryLineageHolder +from sqllineage.core.models import Path, SubQuery, Table +from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery, SqlFluffTable +from sqllineage.core.parser.sqlfluff.utils.sqlfluff import get_table_alias +from sqllineage.utils.helpers import escape_identifier_name def retrieve_holder_data_from( segments: List[BaseSegment], - holder: SqlFluffSubQueryLineageHolder, + holder: SubQueryLineageHolder, table_identifier: BaseSegment, -) -> Union[SqlFluffTable, SqlFluffSubQuery]: +) -> Union[Path, SubQuery, Table]: """ Build a 'SqlFluffSubquery' or 'SqlFluffTable' for a given list of segments and a table identifier segment. It will use the list of segments to find an alias and the holder CTE set of 'SqlFluffSubQuery'. 
:param segments: list of segments to search for an alias :param holder: 'SqlFluffSubQueryLineageHolder' to use the CTE set of 'SqlFluffSubQuery' :param table_identifier: a table identifier segment - :return: 'SqlFluffSubQuery' or 'SqlFluffTable' object + :return: 'Path' or 'SqlFluffSubQuery' or 'SqlFluffTable' object """ data = None alias = get_table_alias(segments) @@ -28,9 +30,12 @@ def retrieve_holder_data_from( if cte is not None: # could reference CTE with or without alias data = SqlFluffSubQuery.of( - cte.segment, + cte.query, alias or table_identifier.raw, ) if data is None: - return SqlFluffTable.of(table_identifier, alias=alias) + if table_identifier.type == "file_reference": + return Path(escape_identifier_name(table_identifier.segments[-1].raw)) + else: + return SqlFluffTable.of(table_identifier, alias=alias) return data diff --git a/sqllineage/sqlfluff_core/utils/sqlfluff.py b/sqllineage/core/parser/sqlfluff/utils/sqlfluff.py similarity index 77% rename from sqllineage/sqlfluff_core/utils/sqlfluff.py rename to sqllineage/core/parser/sqlfluff/utils/sqlfluff.py index 0d710b0a..e1658e01 100644 --- a/sqllineage/sqlfluff_core/utils/sqlfluff.py +++ b/sqllineage/core/parser/sqlfluff/utils/sqlfluff.py @@ -1,12 +1,13 @@ """ Utils class to deal with the sqlfluff segments manipulations """ -from typing import Callable, Iterable, List, Optional, Tuple, Union +import re +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union from sqlfluff.core.linter import ParsedString from sqlfluff.core.parser import BaseSegment -from sqllineage.sqlfluff_core.utils.entities import SubSqlFluffQueryTuple +from sqllineage.utils.entities import SubQueryTuple def is_segment_negligible(segment: BaseSegment) -> bool: @@ -24,106 +25,92 @@ def is_segment_negligible(segment: BaseSegment) -> bool: def get_bracketed_subqueries_select( segment: BaseSegment, -) -> List[SubSqlFluffQueryTuple]: +) -> List[SubQueryTuple]: """ - Retrieve a list of 'SubSqlFluffQueryTuple' for a given segment of type "select_clause" + Retrieve a list of 'SubQueryTuple' for a given segment of type "select_clause" :param segment: segment to be processed segment to be processed - :return: list is either empty when no subquery parsed or list of 'SubSqlFluffQueryTuple' + :return: list is either empty when no subquery parsed or list of 'SubQueryTuple' """ subquery = [] as_segment = segment.get_child("select_clause_element").get_child( "alias_expression" ) select_clause = segment.get_child("select_clause_element") - sublist = list( - [ - seg - for seg in select_clause.segments - if not is_segment_negligible(seg) and seg.type != "table_expression" - ] - ) - if as_segment is not None and len(sublist) == 1: - # CTE: tbl AS (SELECT 1) - target = sublist[0] - else: - case_expression = select_clause.get_child( - "expression" - ) and select_clause.get_child("expression").get_child("case_expression") - target = case_expression or select_clause.get_child("column_reference") + case_expression = select_clause.get_child("expression") and select_clause.get_child( + "expression" + ).get_child("case_expression") + target = case_expression or select_clause.get_child("column_reference") if target and target.type == "case_expression": for when_clause in target.get_children("when_clause"): for bracketed_segment in get_bracketed_from( when_clause, to_keyword="THEN", children_segments="expression" ): - subquery.append(SubSqlFluffQueryTuple(bracketed_segment, None)) + subquery.append(SubQueryTuple(bracketed_segment, None)) for bracketed_segment in 
get_bracketed_from( when_clause, from_keyword="THEN", children_segments="expression" ): subquery.append( - SubSqlFluffQueryTuple(bracketed_segment, get_identifier(as_segment)) - ) - for else_clause in target.get_children("else_clause"): - for bracketed_segment in get_bracketed_from( - else_clause, children_segments="expression" - ): - subquery.append( - SubSqlFluffQueryTuple(bracketed_segment, get_identifier(as_segment)) + SubQueryTuple(bracketed_segment, get_identifier(as_segment)) ) - if target and is_subquery(target): - subquery = [ - SubSqlFluffQueryTuple(get_innermost_bracketed(target), as_segment.raw) - ] return subquery -def get_bracketed_subqueries_from( - segment: BaseSegment, skip_union: bool = True -) -> List[SubSqlFluffQueryTuple]: +def get_bracketed_subqueries_from(segment: BaseSegment) -> List[SubQueryTuple]: """ - Retrieve a list of 'SubSqlFluffQueryTuple' for a given segment of type "from_" + Retrieve a list of 'SubQueryTuple' for a given segment of type "from_" :param segment: segment to be processed - :param skip_union: do not search for subqueries if the segment is part or contains a UNION query - :return: a list of 'SubSqlFluffQueryTuple' + :return: a list of 'SubQueryTuple' """ subquery = [] as_segment, target = extract_as_and_target_segment( get_inner_from_expression(segment) ) - if not skip_union and is_union(target): - for sq in get_union_subqueries(target): - subquery.append( - SubSqlFluffQueryTuple( - sq, - get_identifier(as_segment) if as_segment else None, - ) - ) - elif is_subquery(target): + if is_subquery(target): as_segment, target = extract_as_and_target_segment( get_inner_from_expression(segment) ) subquery = [ - SubSqlFluffQueryTuple( - get_innermost_bracketed(target), + SubQueryTuple( + get_innermost_bracketed(target) if not is_union(target) else target, get_identifier(as_segment) if as_segment else None, ) ] return subquery +def get_subqueries_union(segment: BaseSegment) -> List[SubQueryTuple]: + """ + Retrieve a list of 'SubQueryTuple' for a given segment of type "union" + :param segment: segment to be processed + :return: a list of 'SubQueryTuple' + """ + subquery = [] + for sq in get_union_subqueries( + segment + if segment.type == "set_expression" + else segment.get_child("set_expression") + ): + subquery.append( + SubQueryTuple( + sq, + None, + ) + ) + return subquery + + def get_bracketed_subqueries_where( segment: BaseSegment, -) -> List[SubSqlFluffQueryTuple]: +) -> List[SubQueryTuple]: """ - Retrieve a list of 'SubSqlFluffQueryTuple' for a given segment of type "where_clause" + Retrieve a list of 'SubQueryTuple' for a given segment of type "where_clause" :param segment: segment to be processed - :return: a list of 'SubSqlFluffQueryTuple' + :return: a list of 'SubQueryTuple' """ expression_segments = segment.get_child("expression").segments or [] bracketed_segments = [seg for seg in expression_segments if seg.type == "bracketed"] if bracketed_segments and is_subquery(bracketed_segments[0]): - return [ - SubSqlFluffQueryTuple(get_innermost_bracketed(bracketed_segments[0]), None) - ] + return [SubQueryTuple(get_innermost_bracketed(bracketed_segments[0]), None)] return [] @@ -136,30 +123,24 @@ def extract_as_and_target_segment( """ as_segment = segment.get_child("alias_expression") sublist = list([seg for seg in segment.segments if not is_segment_negligible(seg)]) - if as_segment is not None and len(sublist) == 1: - target = sublist[0] - else: - target = sublist[0] if is_subquery(sublist[0]) else sublist[0].segments[0] + target = sublist[0] if 
is_subquery(sublist[0]) else sublist[0].segments[0] return as_segment, target -def get_subqueries( - segment: BaseSegment, skip_union: bool = True -) -> List[SubSqlFluffQueryTuple]: +def get_subqueries(segment: BaseSegment) -> List[SubQueryTuple]: """ - Retrieve a list of 'SubSqlFluffQueryTuple' based on the type of the segment. + Retrieve a list of 'SubQueryTuple' based on the type of the segment. :param segment: segment to be processed - :param skip_union: do not search for subqueries if the segment is part or contains a UNION query - :return: a list of 'SubSqlFluffQueryTuple' + :return: a list of 'SubQueryTuple' """ if segment.type in ["select_clause"]: return get_bracketed_subqueries_select(segment) elif segment.type in ["from_clause", "from_expression", "from_expression_element"]: - return get_bracketed_subqueries_from(segment, skip_union) - elif segment.type in ["join_clause"]: - return [] + return get_bracketed_subqueries_from(segment) elif segment.type in ["where_clause"]: return get_bracketed_subqueries_where(segment) + elif is_union(segment): + return get_subqueries_union(segment) else: raise NotImplementedError() @@ -219,11 +200,6 @@ def is_values_clause(segment: BaseSegment) -> bool: "table_expression" ).get_child("values_clause"): return True - for s in segment.segments: - if is_segment_negligible(s): - continue - if s.type == "values_clause": - return True return False @@ -310,12 +286,12 @@ def get_bracketed_from( def find_table_identifier(identifier: BaseSegment) -> Optional[BaseSegment]: """ :param identifier: segment to be processed - :return: a "table_reference" type segment if it exists in the identifier's children list, otherwise the identifier + :return: a table_reference or file_reference type segment if it exists in children list, otherwise the identifier """ table_identifier = None if identifier.segments: for segment in identifier.segments: - if segment.type == "table_reference": + if segment.type in ("table_reference", "file_reference"): return segment else: table_identifier = find_table_identifier(segment) @@ -333,7 +309,9 @@ def retrieve_segments( :param check_bracketed: process segment if it is of type "bracketed" :return: a list of segments """ - if segment.type == "bracketed" and check_bracketed: + if segment.type == "bracketed" and is_union(segment): + return [segment] + elif segment.type == "bracketed" and check_bracketed: segments = [ sg for sg in segment.iter_segments(expanding=["expression"], pass_through=True) @@ -400,36 +378,6 @@ def has_alias(segment: BaseSegment) -> bool: return len([s for s in segment.get_children("keyword") if s.raw_upper == "AS"]) > 0 -def is_union(segment: BaseSegment) -> bool: - """ - :param segment: segment to be processed - :return: True if the segment contains 'UNION' or 'UNION ALL' keyword - """ - sub_segments = retrieve_segments(segment, check_bracketed=True) - return ( - len( - [ - s - for s in sub_segments - if s.type == "set_operator" - and (s.raw_upper == "UNION" or s.raw_upper == "UNION ALL") - ] - ) - > 0 - ) - - -def get_union_subqueries(segment: BaseSegment) -> List[BaseSegment]: - """ - :param segment: segment to be processed - :return: a list of subqueries or select statements from a UNION segment - """ - sub_segments = retrieve_segments(segment, check_bracketed=True) - return [ - s for s in sub_segments if s.type == "bracketed" or s.type == "select_statement" - ] - - def token_matching( segment: BaseSegment, funcs: Tuple[Callable[[BaseSegment], bool]], @@ -511,9 +459,7 @@ def get_child(segment: BaseSegment, child: 
str) -> BaseSegment: return segment.get_child(child) -def get_grandchildren( - segment: BaseSegment, child: str, grandchildren: str -) -> List[BaseSegment]: +def get_grandchildren(segment: BaseSegment, child: str, grandchildren: str) -> Any: """ :param segment: segment to be processed :param child: child segment @@ -527,23 +473,71 @@ def get_grandchildren( ) -def get_statement_segment(parsed_string: ParsedString) -> Optional[BaseSegment]: +def get_statement_segment(parsed_string: ParsedString) -> BaseSegment: """ :param parsed_string: parsed string :return: first segment from the statement segment of the segments of parsed_string """ - try: - if parsed_string.tree: - return next( - ( - x.segments[0] - if x.type == "statement" - else x.get_child("statement").segments[0] - for x in parsed_string.tree.segments - if x.type == "statement" or x.type == "batch" - ), - None, - ) - except AttributeError: - return None - return None + return next( + ( + x.segments[0] + if x.type == "statement" + else x.get_child("statement").segments[0] + for x in getattr(parsed_string.tree, "segments") + if x.type == "statement" or x.type == "batch" + ) + ) + + +def is_union(segment: BaseSegment) -> bool: + """ + :param segment: segment to be processed + :return: True if the segment contains 'UNION' or 'UNION ALL' keyword + """ + return ( + len( + [ + s + for s in segment.raw_segments + if (s.raw_upper == "UNION" or s.raw_upper == "UNION ALL") + ] + ) + > 0 + ) + + +def get_union_subqueries(segment: BaseSegment) -> List[BaseSegment]: + """ + :param segment: segment to be processed + :return: a list of subqueries or select statements from a UNION segment + """ + sub_segments = retrieve_segments(segment, check_bracketed=True) + return [ + s for s in sub_segments if s.type == "bracketed" or s.type == "select_statement" + ] + + +def is_subquery_statement(stmt: str) -> bool: + parentheses_regex = r"^\(.*\)" + return bool(re.match(parentheses_regex, stmt)) + + +def remove_statement_parentheses(stmt: str) -> str: + parentheses_regex = r"^\((.*)\)" + return re.sub(parentheses_regex, r"\1", stmt) + + +def clean_parentheses(stmt: str) -> str: + """ + Clean redundant parentheses from a SQL statement e.g: + `SELECT col1 FROM (((((((SELECT col1 FROM tab1))))))) dt` + will be: + `SELECT col1 FROM (SELECT col1 FROM tab1) dt` + + :param stmt: a SQL str to be cleaned + """ + redundant_parentheses = r"\(\(([^()]+)\)\)" + if re.findall(redundant_parentheses, stmt): + stmt = re.sub(redundant_parentheses, r"(\1)", stmt) + stmt = clean_parentheses(stmt) + return stmt diff --git a/sqllineage/core/parser/sqlparse/__init__.py b/sqllineage/core/parser/sqlparse/__init__.py new file mode 100644 index 00000000..a988eb73 --- /dev/null +++ b/sqllineage/core/parser/sqlparse/__init__.py @@ -0,0 +1,34 @@ +import re + +from sqlparse import tokens +from sqlparse.engine import grouping +from sqlparse.keywords import KEYWORDS, SQL_REGEX + +from sqllineage.core.parser.sqlparse.utils.sqlparse import group_function_with_window + + +def _patch_adding_window_function_token() -> None: + grouping.group_functions = group_function_with_window + + +def _patch_adding_builtin_type() -> None: + KEYWORDS["STRING"] = tokens.Name.Builtin + KEYWORDS["DATETIME"] = tokens.Name.Builtin + + +def _patch_updating_lateral_view_lexeme() -> None: + for i, (regex, lexeme) in enumerate(SQL_REGEX): + if regex("LATERAL VIEW EXPLODE(col)"): + new_regex = r"(LATERAL\s+VIEW\s+)(OUTER\s+)?(EXPLODE|INLINE|PARSE_URL_TUPLE|POSEXPLODE|STACK|JSON_TUPLE)\b" + new_compile = 
re.compile(new_regex, re.IGNORECASE | re.UNICODE).match + SQL_REGEX[i] = (new_compile, lexeme) + break + + +def _monkey_patch() -> None: + _patch_adding_window_function_token() + _patch_adding_builtin_type() + _patch_updating_lateral_view_lexeme() + + +_monkey_patch() diff --git a/sqllineage/core/parser/sqlparse/analyzer.py b/sqllineage/core/parser/sqlparse/analyzer.py new file mode 100644 index 00000000..9f054101 --- /dev/null +++ b/sqllineage/core/parser/sqlparse/analyzer.py @@ -0,0 +1,175 @@ +from functools import reduce +from operator import add +from typing import List, Union + +import sqlparse +from sqlparse.sql import ( + Function, + Identifier, + IdentifierList, + Statement, + TokenList, + Where, +) + +from sqllineage.core.analyzer import LineageAnalyzer +from sqllineage.core.holders import StatementLineageHolder, SubQueryLineageHolder +from sqllineage.core.models import AnalyzerContext, SubQuery +from sqllineage.core.parser.sqlparse.handlers.base import ( + CurrentTokenBaseHandler, + NextTokenBaseHandler, +) +from sqllineage.core.parser.sqlparse.models import SqlParseSubQuery, SqlParseTable +from sqllineage.core.parser.sqlparse.utils.sqlparse import ( + get_subquery_parentheses, + is_subquery, + is_token_negligible, +) +from sqllineage.utils.helpers import trim_comment + + +class SqlParseLineageAnalyzer(LineageAnalyzer): + """SQL Statement Level Lineage Analyzer.""" + + def analyze(self, sql: str) -> StatementLineageHolder: + # get rid of comments, which cause inconsistencies in sqlparse output + stmt = sqlparse.parse(trim_comment(sql))[0] + if ( + stmt.get_type() == "DELETE" + or stmt.token_first(skip_cm=True).normalized == "TRUNCATE" + or stmt.token_first(skip_cm=True).normalized.upper() == "REFRESH" + or stmt.token_first(skip_cm=True).normalized == "CACHE" + or stmt.token_first(skip_cm=True).normalized.upper() == "UNCACHE" + or stmt.token_first(skip_cm=True).normalized == "SHOW" + ): + holder = StatementLineageHolder() + elif stmt.get_type() == "DROP": + holder = self._extract_from_ddl_drop(stmt) + elif ( + stmt.get_type() == "ALTER" + or stmt.token_first(skip_cm=True).normalized == "RENAME" + ): + holder = self._extract_from_ddl_alter(stmt) + else: + # DML parsing logic also applies to CREATE DDL + holder = StatementLineageHolder.of( + self._extract_from_dml(stmt, AnalyzerContext()) + ) + return holder + + @classmethod + def _extract_from_ddl_drop(cls, stmt: Statement) -> StatementLineageHolder: + holder = StatementLineageHolder() + for table in { + SqlParseTable.of(t) for t in stmt.tokens if isinstance(t, Identifier) + }: + holder.add_drop(table) + return holder + + @classmethod + def _extract_from_ddl_alter(cls, stmt: Statement) -> StatementLineageHolder: + holder = StatementLineageHolder() + tables = [] + for t in stmt.tokens: + if isinstance(t, Identifier): + tables.append(SqlParseTable.of(t)) + elif isinstance(t, IdentifierList): + for identifier in t.get_identifiers(): + tables.append(SqlParseTable.of(identifier)) + keywords = [t for t in stmt.tokens if t.is_keyword] + if any(k.normalized == "RENAME" for k in keywords): + if stmt.get_type() == "ALTER" and len(tables) == 2: + holder.add_rename(tables[0], tables[1]) + elif ( + stmt.token_first(skip_cm=True).normalized == "RENAME" + and len(tables) % 2 == 0 + ): + for i in range(0, len(tables), 2): + holder.add_rename(tables[i], tables[i + 1]) + if any(k.normalized == "EXCHANGE" for k in keywords) and len(tables) == 2: + holder.add_write(tables[0]) + holder.add_read(tables[1]) + return holder + + @classmethod + def 
_extract_from_dml(
+        cls, token: TokenList, context: AnalyzerContext
+    ) -> SubQueryLineageHolder:
+        holder = SubQueryLineageHolder()
+        if context.prev_cte is not None:
+            # CTE can be referenced by subsequent CTEs
+            for cte in context.prev_cte:
+                holder.add_cte(cte)
+        if context.subquery is not None:
+            # If within subquery, then manually add subquery as target table
+            holder.add_write(context.subquery)
+        current_handlers = [
+            handler_cls() for handler_cls in CurrentTokenBaseHandler.__subclasses__()
+        ]
+        next_handlers = [
+            handler_cls() for handler_cls in NextTokenBaseHandler.__subclasses__()
+        ]
+
+        subqueries = []
+        for sub_token in token.tokens:
+            if is_token_negligible(sub_token):
+                continue
+
+            for sq in cls.parse_subquery(sub_token):
+                # Collecting subquery on the way, hold on parsing until last
+                # so that each handler doesn't have to worry about what's inside subquery
+                subqueries.append(sq)
+
+            for current_handler in current_handlers:
+                current_handler.handle(sub_token, holder)
+
+            if sub_token.is_keyword:
+                for next_handler in next_handlers:
+                    next_handler.indicate(sub_token)
+                continue
+
+            for next_handler in next_handlers:
+                if next_handler.indicator:
+                    next_handler.handle(sub_token, holder)
+        else:
+            # call end of query hook here as loop is over
+            for next_handler in next_handlers:
+                next_handler.end_of_query_cleanup(holder)
+        # By recursively extracting each subquery of the parent and merging, we're doing depth-first search
+        for sq in subqueries:
+            holder |= cls._extract_from_dml(sq.query, AnalyzerContext(sq, holder.cte))
+        return holder
+
+    @classmethod
+    def parse_subquery(cls, token: TokenList) -> List[SubQuery]:
+        result = []
+        if isinstance(token, (Identifier, Function, Where)):
+            # usually SubQuery is an Identifier, but not all Identifiers are SubQuery
+            # Function for CTE without AS keyword
+            result = cls._parse_subquery(token)
+        elif isinstance(token, IdentifierList):
+            # IdentifierList for SQL89 style of JOIN or multiple CTEs, this is actually SubQueries
+            result = reduce(
+                add,
+                [
+                    cls._parse_subquery(identifier)
+                    for identifier in token.get_sublists()
+                ],
+                [],
+            )
+        elif is_subquery(token):
+            # Parenthesis for SubQuery without alias, this is valid syntax for certain SQL dialects
+            result = [SqlParseSubQuery.of(token, None)]
+        return result
+
+    @classmethod
+    def _parse_subquery(
+        cls, token: Union[Identifier, Function, Where]
+    ) -> List[SubQuery]:
+        """
+        Convert SubQueryTuple to sqllineage.core.models.SubQuery
+        """
+        return [
+            SqlParseSubQuery.of(parenthesis, alias)
+            for parenthesis, alias in get_subquery_parentheses(token)
+        ]
diff --git a/sqllineage/sqlfluff_core/handlers/__init__.py b/sqllineage/core/parser/sqlparse/handlers/__init__.py
similarity index 100%
rename from sqllineage/sqlfluff_core/handlers/__init__.py
rename to sqllineage/core/parser/sqlparse/handlers/__init__.py
diff --git a/sqllineage/core/handlers/base.py b/sqllineage/core/parser/sqlparse/handlers/base.py
similarity index 100%
rename from sqllineage/core/handlers/base.py
rename to sqllineage/core/parser/sqlparse/handlers/base.py
diff --git a/sqllineage/core/handlers/cte.py b/sqllineage/core/parser/sqlparse/handlers/cte.py
similarity index 84%
rename from sqllineage/core/handlers/cte.py
rename to sqllineage/core/parser/sqlparse/handlers/cte.py
index 9aa90478..22785cc3 100644
--- a/sqllineage/core/handlers/cte.py
+++ b/sqllineage/core/parser/sqlparse/handlers/cte.py
@@ -1,8 +1,8 @@
 from sqlparse.sql import Function, Identifier, IdentifierList, Token

-from sqllineage.core.handlers.base
import NextTokenBaseHandler from sqllineage.core.holders import SubQueryLineageHolder -from sqllineage.core.models import SubQuery +from sqllineage.core.parser.sqlparse.handlers.base import NextTokenBaseHandler +from sqllineage.core.parser.sqlparse.models import SqlParseSubQuery class CTEHandler(NextTokenBaseHandler): @@ -30,4 +30,4 @@ def _handle(self, token: Token, holder: SubQueryLineageHolder) -> None: sublist = list(token.get_sublists()) if sublist: # CTE: tbl AS (SELECT 1), tbl is alias and (SELECT 1) is subquery Parenthesis - holder.add_cte(SubQuery.of(sublist[0], token.get_real_name())) + holder.add_cte(SqlParseSubQuery.of(sublist[0], token.get_real_name())) diff --git a/sqllineage/core/handlers/source.py b/sqllineage/core/parser/sqlparse/handlers/source.py similarity index 66% rename from sqllineage/core/handlers/source.py rename to sqllineage/core/parser/sqlparse/handlers/source.py index 25cb7860..1d4f54d6 100644 --- a/sqllineage/core/handlers/source.py +++ b/sqllineage/core/parser/sqlparse/handlers/source.py @@ -1,5 +1,5 @@ import re -from typing import Dict, List, Union +from typing import Union from sqlparse.sql import ( Case, @@ -12,19 +12,24 @@ ) from sqlparse.tokens import Literal, Wildcard -from sqllineage.core.handlers.base import NextTokenBaseHandler from sqllineage.core.holders import SubQueryLineageHolder -from sqllineage.core.models import Column, Path, SubQuery, Table -from sqllineage.exceptions import SQLLineageException -from sqllineage.utils.constant import EdgeType -from sqllineage.utils.sqlparse import ( +from sqllineage.core.models import Path, SubQuery, Table +from sqllineage.core.parser import SourceHandlerMixin +from sqllineage.core.parser.sqlparse.handlers.base import NextTokenBaseHandler +from sqllineage.core.parser.sqlparse.models import ( + SqlParseColumn, + SqlParseSubQuery, + SqlParseTable, +) +from sqllineage.core.parser.sqlparse.utils.sqlparse import ( get_subquery_parentheses, is_subquery, is_values_clause, ) +from sqllineage.exceptions import SQLLineageException -class SourceHandler(NextTokenBaseHandler): +class SourceHandler(SourceHandlerMixin, NextTokenBaseHandler): """Source Table & Column Handler.""" SOURCE_TABLE_TOKENS = ( @@ -70,7 +75,7 @@ def _handle_table(self, token: Token, holder: SubQueryLineageHolder) -> None: if is_subquery(token): # SELECT col1 FROM (SELECT col2 FROM tab1), the subquery will be parsed as Parenthesis # This syntax without alias for subquery is invalid in MySQL, while valid for SparkSQL - self.tables.append(SubQuery.of(token, None)) + self.tables.append(SqlParseSubQuery.of(token, None)) elif is_values_clause(token): # SELECT * FROM (VALUES ...), no operation needed pass @@ -102,30 +107,7 @@ def _handle_column(self, token: Token) -> None: # SELECT constant value will end up here column_tokens = [] for token in column_tokens: - self.columns.append(Column.of(token)) - - def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None: - for i, tbl in enumerate(self.tables): - holder.add_read(tbl) - self.union_barriers.append((len(self.columns), len(self.tables))) - for i, (col_barrier, tbl_barrier) in enumerate(self.union_barriers): - prev_col_barrier, prev_tbl_barrier = ( - (0, 0) if i == 0 else self.union_barriers[i - 1] - ) - col_grp = self.columns[prev_col_barrier:col_barrier] - tbl_grp = self.tables[prev_tbl_barrier:tbl_barrier] - tgt_tbl = None - if holder.write: - if len(holder.write) > 1: - raise SQLLineageException - tgt_tbl = list(holder.write)[0] - if tgt_tbl: - for tgt_col in col_grp: - tgt_col.parent = 
tgt_tbl - for src_col in tgt_col.to_source_columns( - self._get_alias_mapping_from_table_group(tbl_grp, holder) - ): - holder.add_column_lineage(src_col, tgt_col) + self.columns.append(SqlParseColumn.of(token)) def _add_dataset_from_identifier( self, identifier: Identifier, holder: SubQueryLineageHolder @@ -148,42 +130,18 @@ def _add_dataset_from_identifier( # SELECT col1 FROM (SELECT col2 FROM tab1) dt, the subquery will be parsed as Identifier # referring https://github.com/andialbrecht/sqlparse/issues/218 for further information parenthesis, alias = subqueries[0] - read = SubQuery.of(parenthesis, alias) + read = SqlParseSubQuery.of(parenthesis, alias) else: cte_dict = {s.alias: s for s in holder.cte} if "." not in identifier.value: cte = cte_dict.get(identifier.get_real_name()) if cte is not None: # could reference CTE with or without alias - read = SubQuery.of( - cte.token, + read = SqlParseSubQuery.of( + cte.query, identifier.get_alias() or identifier.get_real_name(), ) if read is None: - read = Table.of(identifier) + read = SqlParseTable.of(identifier) dataset = read self.tables.append(dataset) - - @classmethod - def _get_alias_mapping_from_table_group( - cls, - table_group: List[Union[Path, Table, SubQuery]], - holder: SubQueryLineageHolder, - ) -> Dict[str, Union[Path, Table, SubQuery]]: - """ - A table can be referred to as alias, table name, or database_name.table_name, create the mapping here. - For SubQuery, it's only alias then. - """ - return { - **{ - tgt: src - for src, tgt, attr in holder.graph.edges(data=True) - if attr.get("type") == EdgeType.HAS_ALIAS and src in table_group - }, - **{ - table.raw_name: table - for table in table_group - if isinstance(table, Table) - }, - **{str(table): table for table in table_group if isinstance(table, Table)}, - } diff --git a/sqllineage/core/handlers/swap_partition.py b/sqllineage/core/parser/sqlparse/handlers/swap_partition.py similarity index 63% rename from sqllineage/core/handlers/swap_partition.py rename to sqllineage/core/parser/sqlparse/handlers/swap_partition.py index 10d8da9e..93b9ce3a 100644 --- a/sqllineage/core/handlers/swap_partition.py +++ b/sqllineage/core/parser/sqlparse/handlers/swap_partition.py @@ -1,8 +1,8 @@ from sqlparse.sql import Function, Token -from sqllineage.core.handlers.base import CurrentTokenBaseHandler from sqllineage.core.holders import SubQueryLineageHolder -from sqllineage.core.models import Table +from sqllineage.core.parser.sqlparse.handlers.base import CurrentTokenBaseHandler +from sqllineage.core.parser.sqlparse.models import SqlParseTable from sqllineage.utils.helpers import escape_identifier_name @@ -19,5 +19,9 @@ def handle(self, token: Token, holder: SubQueryLineageHolder) -> None: _, parenthesis = token.tokens _, identifier_list, _ = parenthesis.tokens identifiers = list(identifier_list.get_identifiers()) - holder.add_read(Table(escape_identifier_name(identifiers[0].normalized))) - holder.add_write(Table(escape_identifier_name(identifiers[3].normalized))) + holder.add_read( + SqlParseTable(escape_identifier_name(identifiers[0].normalized)) + ) + holder.add_write( + SqlParseTable(escape_identifier_name(identifiers[3].normalized)) + ) diff --git a/sqllineage/core/handlers/target.py b/sqllineage/core/parser/sqlparse/handlers/target.py similarity index 83% rename from sqllineage/core/handlers/target.py rename to sqllineage/core/parser/sqlparse/handlers/target.py index 412bcd4f..e8597a1c 100644 --- a/sqllineage/core/handlers/target.py +++ b/sqllineage/core/parser/sqlparse/handlers/target.py @@ 
-1,9 +1,10 @@ from sqlparse.sql import Comparison, Function, Identifier, Token from sqlparse.tokens import Literal, Number -from sqllineage.core.handlers.base import NextTokenBaseHandler from sqllineage.core.holders import SubQueryLineageHolder -from sqllineage.core.models import Path, Table +from sqllineage.core.models import Path +from sqllineage.core.parser.sqlparse.handlers.base import NextTokenBaseHandler +from sqllineage.core.parser.sqlparse.models import SqlParseTable from sqllineage.exceptions import SQLLineageException @@ -40,7 +41,7 @@ def _handle(self, token: Token, holder: SubQueryLineageHolder) -> None: "An Identifier is expected, got %s[value: %s] instead." % (type(token).__name__, token) ) - holder.add_write(Table.of(token.token_first(skip_cm=True))) + holder.add_write(SqlParseTable.of(token.token_first(skip_cm=True))) elif isinstance(token, Comparison): # create table tab1 like tab2, tab1 like tab2 will be parsed as Comparison # referring https://github.com/andialbrecht/sqlparse/issues/543 for further information @@ -52,8 +53,8 @@ def _handle(self, token: Token, holder: SubQueryLineageHolder) -> None: "An Identifier is expected, got %s[value: %s] instead." % (type(token).__name__, token) ) - holder.add_write(Table.of(token.left)) - holder.add_read(Table.of(token.right)) + holder.add_write(SqlParseTable.of(token.left)) + holder.add_read(SqlParseTable.of(token.right)) elif token.ttype == Literal.String.Single: holder.add_write(Path(token.value)) elif isinstance(token, Identifier): @@ -61,4 +62,4 @@ def _handle(self, token: Token, holder: SubQueryLineageHolder) -> None: # Special Handling for Spark Bucket Table DDL pass else: - holder.add_write(Table.of(token)) + holder.add_write(SqlParseTable.of(token)) diff --git a/sqllineage/core/parser/sqlparse/models.py b/sqllineage/core/parser/sqlparse/models.py new file mode 100644 index 00000000..ed4b0db6 --- /dev/null +++ b/sqllineage/core/parser/sqlparse/models.py @@ -0,0 +1,186 @@ +from typing import List, Optional + +from sqlparse import tokens as T +from sqlparse.engine import grouping +from sqlparse.keywords import is_keyword +from sqlparse.sql import ( + Case, + Comparison, + Function, + Identifier, + IdentifierList, + Operation, + Parenthesis, + Token, + TokenList, +) +from sqlparse.utils import imt + +from sqllineage.core.models import Column, Schema, SubQuery, Table +from sqllineage.core.parser.sqlparse.utils.sqlparse import get_parameters, is_subquery +from sqllineage.utils.entities import ColumnQualifierTuple +from sqllineage.utils.helpers import escape_identifier_name + + +class SqlParseTable(Table): + @staticmethod + def of(table: Identifier) -> Table: + # rewrite identifier's get_real_name method, by matching the last dot instead of the first dot, so that the + # real name for a.b.c will be c instead of b + dot_idx, _ = table._token_matching( + lambda token: imt(token, m=(T.Punctuation, ".")), + start=len(table.tokens), + reverse=True, + ) + real_name = table._get_first_name(dot_idx, real_name=True) + # rewrite identifier's get_parent_name accordingly + parent_name = ( + "".join( + [ + escape_identifier_name(token.value) + for token in table.tokens[:dot_idx] + ] + ) + if dot_idx + else None + ) + schema = Schema(parent_name) if parent_name is not None else Schema() + alias = table.get_alias() + kwargs = {"alias": alias} if alias else {} + return Table(real_name, schema, **kwargs) + + +class SqlParseSubQuery(SubQuery): + @staticmethod + def of(subquery: Parenthesis, alias: Optional[str]) -> SubQuery: + return 
SubQuery(subquery, subquery.value, alias) + + +class SqlParseColumn(Column): + @staticmethod + def of(column: Token, **kwargs) -> Column: + if isinstance(column, Identifier): + alias = column.get_alias() + if alias: + # handle column alias, including alias for column name or Case, Function + kw_idx, kw = column.token_next_by(m=(T.Keyword, "AS")) + if kw_idx is None: + # alias without AS + kw_idx, _ = column.token_next_by(i=Identifier) + if kw_idx is None: + # invalid syntax: col AS, without alias + return Column(alias) + else: + idx, _ = column.token_prev(kw_idx, skip_cm=True) + expr = grouping.group(TokenList(column.tokens[: idx + 1]))[0] + source_columns = SqlParseColumn._extract_source_columns(expr) + return Column( + alias, + source_columns=source_columns, + ) + else: + # select column name directly without alias + return Column( + column.get_real_name(), + source_columns=( + (column.get_real_name(), column.get_parent_name()), + ), + ) + else: + # Wildcard, Case, Function without alias (thus not recognized as an Identifier) + source_columns = SqlParseColumn._extract_source_columns(column) + return Column( + column.value, + source_columns=source_columns, + ) + + @staticmethod + def _extract_source_columns(token: Token) -> List[ColumnQualifierTuple]: + if isinstance(token, Function): + # max(col1) AS col2 + source_columns = [ + cqt + for tk in get_parameters(token) + for cqt in SqlParseColumn._extract_source_columns(tk) + ] + elif isinstance(token, Parenthesis): + if is_subquery(token): + # This is to avoid circular import + from sqllineage.runner import LineageRunner + + # (SELECT avg(col1) AS col1 FROM tab3), used after WHEN or THEN in CASE clause + src_cols = [ + lineage[0] + for lineage in LineageRunner(token.value).get_column_lineage( + exclude_subquery=False + ) + ] + source_columns = [ + ColumnQualifierTuple(src_col.raw_name, src_col.parent.raw_name) + for src_col in src_cols + ] + else: + # (col1 + col2) AS col3 + source_columns = [ + cqt + for tk in token.tokens[1:-1] + for cqt in SqlParseColumn._extract_source_columns(tk) + ] + elif isinstance(token, Operation): + # col1 + col2 AS col3 + source_columns = [ + cqt + for tk in token.get_sublists() + for cqt in SqlParseColumn._extract_source_columns(tk) + ] + elif isinstance(token, Case): + # CASE WHEN col1 = 2 THEN "V1" WHEN col1 = "2" THEN "V2" END AS col2 + source_columns = [ + cqt + for tk in token.get_sublists() + for cqt in SqlParseColumn._extract_source_columns(tk) + ] + elif isinstance(token, Comparison): + source_columns = SqlParseColumn._extract_source_columns( + token.left + ) + SqlParseColumn._extract_source_columns(token.right) + elif isinstance(token, IdentifierList): + source_columns = [ + cqt + for tk in token.get_sublists() + for cqt in SqlParseColumn._extract_source_columns(tk) + ] + elif isinstance(token, Identifier): + real_name = token.get_real_name() + # ignore function dtypes that don't need to check for extract column + FUNC_DTYPE = ["decimal", "numeric"] + has_function = any( + isinstance(t, Function) and t.get_real_name() not in FUNC_DTYPE + for t in token.tokens + ) + is_kw = is_keyword(real_name) if real_name is not None else False + if ( + # real name is None: col1=1 AS int + real_name is None + # real_name is decimal: case when col1 > 0 then col2 else col3 end as decimal(18, 0) + or (real_name in FUNC_DTYPE and isinstance(token.tokens[-1], Function)) + or (is_kw and has_function) + ): + source_columns = [ + cqt + for tk in token.get_sublists() + for cqt in SqlParseColumn._extract_source_columns(tk) + ] 
+ else: + # col1 AS col2 + source_columns = [ + ColumnQualifierTuple(token.get_real_name(), token.get_parent_name()) + ] + else: + if token.ttype == T.Wildcard: + # select * + source_columns = [ColumnQualifierTuple(token.value, None)] + else: + # typically, T.Literal here + source_columns = [] + return source_columns diff --git a/sqllineage/sqlfluff_core/utils/__init__.py b/sqllineage/core/parser/sqlparse/utils/__init__.py similarity index 100% rename from sqllineage/sqlfluff_core/utils/__init__.py rename to sqllineage/core/parser/sqlparse/utils/__init__.py diff --git a/sqllineage/utils/sqlparse.py b/sqllineage/core/parser/sqlparse/utils/sqlparse.py similarity index 100% rename from sqllineage/utils/sqlparse.py rename to sqllineage/core/parser/sqlparse/utils/sqlparse.py diff --git a/sqllineage/data/tpcds/query01.sql b/sqllineage/data/tpcds/query01.sql index 74899c43..c59a0007 100644 --- a/sqllineage/data/tpcds/query01.sql +++ b/sqllineage/data/tpcds/query01.sql @@ -8,7 +8,7 @@ with customer_total_return as and d_year = 2000 group by sr_customer_sk , sr_store_sk) -insert overwrite table query01 +insert into query01 select c_customer_id from customer_total_return ctr1 , store diff --git a/sqllineage/data/tpcds/query02.sql b/sqllineage/data/tpcds/query02.sql index db4715b0..bc1113a6 100644 --- a/sqllineage/data/tpcds/query02.sql +++ b/sqllineage/data/tpcds/query02.sql @@ -21,7 +21,7 @@ with wscs as , date_dim where d_date_sk = sold_date_sk group by d_week_seq) -insert overwrite table query02 +insert into query02 select d_week_seq1 , round(sun_sales1 / sun_sales2, 2) , round(mon_sales1 / mon_sales2, 2) diff --git a/sqllineage/data/tpcds/query03.sql b/sqllineage/data/tpcds/query03.sql index b9cf5221..c0aa08d0 100644 --- a/sqllineage/data/tpcds/query03.sql +++ b/sqllineage/data/tpcds/query03.sql @@ -1,4 +1,4 @@ -insert overwrite table query03 +insert into query03 select dt.d_year , item.i_brand_id brand_id , item.i_brand brand diff --git a/sqllineage/data/tpcds/query04.sql b/sqllineage/data/tpcds/query04.sql index eb1ad55b..a03ac94d 100644 --- a/sqllineage/data/tpcds/query04.sql +++ b/sqllineage/data/tpcds/query04.sql @@ -73,7 +73,7 @@ with year_total as ( , c_email_address , d_year ) -insert overwrite table query04 +insert into query04 select t_s_secyear.customer_id , t_s_secyear.customer_first_name , t_s_secyear.customer_last_name diff --git a/sqllineage/data/tpcds/query05.sql b/sqllineage/data/tpcds/query05.sql index b8e553c7..dba1a23b 100644 --- a/sqllineage/data/tpcds/query05.sql +++ b/sqllineage/data/tpcds/query05.sql @@ -90,7 +90,7 @@ with ssr as and date_add(cast('2000-08-23' as date), 14) and wsr_web_site_sk = web_site_sk group by web_site_id) -insert overwrite table query05 +insert into query05 select channel , id , sum(sales) as sales diff --git a/sqllineage/data/tpcds/query06.sql b/sqllineage/data/tpcds/query06.sql index e0433022..3baaff14 100644 --- a/sqllineage/data/tpcds/query06.sql +++ b/sqllineage/data/tpcds/query06.sql @@ -1,4 +1,4 @@ -insert overwrite table query06 +insert into query06 select a.ca_state state, count(*) cnt from customer_address a , customer c diff --git a/sqllineage/data/tpcds/query07.sql b/sqllineage/data/tpcds/query07.sql index 12fb219e..843db2ba 100644 --- a/sqllineage/data/tpcds/query07.sql +++ b/sqllineage/data/tpcds/query07.sql @@ -1,4 +1,4 @@ -insert overwrite table query07 +insert into query07 select i_item_id, avg(ss_quantity) agg1, avg(ss_list_price) agg2, diff --git a/sqllineage/data/tpcds/query08.sql b/sqllineage/data/tpcds/query08.sql index 
cf2668f5..fe099724 100644 --- a/sqllineage/data/tpcds/query08.sql +++ b/sqllineage/data/tpcds/query08.sql @@ -1,4 +1,4 @@ -insert overwrite table query08 +insert into query08 select s_store_name , sum(ss_net_profit) from store_sales diff --git a/sqllineage/data/tpcds/query09.sql b/sqllineage/data/tpcds/query09.sql index 5ea35b42..f5f7ea95 100644 --- a/sqllineage/data/tpcds/query09.sql +++ b/sqllineage/data/tpcds/query09.sql @@ -1,4 +1,4 @@ -insert overwrite table query09 +insert into query09 select case when (select count(*) from store_sales diff --git a/sqllineage/data/tpcds/query10.sql b/sqllineage/data/tpcds/query10.sql index fd2a2419..f4d0c99c 100644 --- a/sqllineage/data/tpcds/query10.sql +++ b/sqllineage/data/tpcds/query10.sql @@ -1,4 +1,4 @@ -insert overwrite table query10 +insert into query10 select cd_gender, cd_marital_status, cd_education_status, diff --git a/sqllineage/data/tpcds/query11.sql b/sqllineage/data/tpcds/query11.sql index 2c861117..8e082ddd 100644 --- a/sqllineage/data/tpcds/query11.sql +++ b/sqllineage/data/tpcds/query11.sql @@ -47,7 +47,7 @@ with year_total as ( , c_email_address , d_year ) -insert overwrite table query11 +insert into query11 select t_s_secyear.customer_id , t_s_secyear.customer_first_name , t_s_secyear.customer_last_name diff --git a/sqllineage/data/tpcds/query12.sql b/sqllineage/data/tpcds/query12.sql index 0b4d48bd..6f5db38b 100644 --- a/sqllineage/data/tpcds/query12.sql +++ b/sqllineage/data/tpcds/query12.sql @@ -1,4 +1,4 @@ -insert overwrite table query12 +insert into query12 select i_item_id , i_item_desc , i_category diff --git a/sqllineage/data/tpcds/query13.sql b/sqllineage/data/tpcds/query13.sql index 778f073e..90defd37 100644 --- a/sqllineage/data/tpcds/query13.sql +++ b/sqllineage/data/tpcds/query13.sql @@ -1,4 +1,4 @@ -insert overwrite table query13 +insert into query13 select avg(ss_quantity) , avg(ss_ext_sales_price) , avg(ss_ext_wholesale_cost) diff --git a/sqllineage/data/tpcds/query14.sql b/sqllineage/data/tpcds/query14.sql index 04acfa50..db33761a 100644 --- a/sqllineage/data/tpcds/query14.sql +++ b/sqllineage/data/tpcds/query14.sql @@ -56,7 +56,7 @@ with cross_items as , date_dim where ws_sold_date_sk = d_date_sk and d_year between 1999 and 1999 + 2) x) -insert overwrite table query14 +insert into query14 select channel, i_brand_id, i_class_id, i_category_id, sum(sales), sum(number_sales) from ( select 'store' channel diff --git a/sqllineage/data/tpcds/query15.sql b/sqllineage/data/tpcds/query15.sql index 6961594b..abd01118 100644 --- a/sqllineage/data/tpcds/query15.sql +++ b/sqllineage/data/tpcds/query15.sql @@ -1,4 +1,4 @@ -insert overwrite table query15 +insert into query15 select ca_zip , sum(cs_sales_price) from catalog_sales diff --git a/sqllineage/data/tpcds/query16.sql b/sqllineage/data/tpcds/query16.sql index 160ca93d..16697d2b 100644 --- a/sqllineage/data/tpcds/query16.sql +++ b/sqllineage/data/tpcds/query16.sql @@ -1,4 +1,4 @@ -insert overwrite table query16 +insert into query16 select count(distinct cs_order_number) as order_count , sum(cs_ext_ship_cost) as total_shipping_cost , sum(cs_net_profit) as total_net_profit diff --git a/sqllineage/data/tpcds/query17.sql b/sqllineage/data/tpcds/query17.sql index fd66efaa..1684d872 100644 --- a/sqllineage/data/tpcds/query17.sql +++ b/sqllineage/data/tpcds/query17.sql @@ -1,4 +1,4 @@ -insert overwrite table query17 +insert into query17 select i_item_id , i_item_desc , s_state diff --git a/sqllineage/data/tpcds/query18.sql b/sqllineage/data/tpcds/query18.sql index 
6f0f658a..2cf750ae 100644 --- a/sqllineage/data/tpcds/query18.sql +++ b/sqllineage/data/tpcds/query18.sql @@ -1,4 +1,4 @@ -insert overwrite table query18 +insert into query18 select i_item_id, ca_country, ca_state, diff --git a/sqllineage/data/tpcds/query19.sql b/sqllineage/data/tpcds/query19.sql index 65cf1961..0851f66f 100644 --- a/sqllineage/data/tpcds/query19.sql +++ b/sqllineage/data/tpcds/query19.sql @@ -1,4 +1,4 @@ -insert overwrite table query19 +insert into query19 select i_brand_id brand_id, i_brand brand, i_manufact_id, diff --git a/sqllineage/data/tpcds/query20.sql b/sqllineage/data/tpcds/query20.sql index 69161ef9..f7b41170 100644 --- a/sqllineage/data/tpcds/query20.sql +++ b/sqllineage/data/tpcds/query20.sql @@ -1,4 +1,4 @@ -insert overwrite table query20 +insert into query20 select i_item_id , i_item_desc , i_category diff --git a/sqllineage/data/tpcds/query21.sql b/sqllineage/data/tpcds/query21.sql index 8758f9c5..4b41bf89 100644 --- a/sqllineage/data/tpcds/query21.sql +++ b/sqllineage/data/tpcds/query21.sql @@ -1,4 +1,4 @@ -insert overwrite table query21 +insert into query21 select * from (select w_warehouse_name , i_item_id diff --git a/sqllineage/data/tpcds/query22.sql b/sqllineage/data/tpcds/query22.sql index 188c81d7..85a7313e 100644 --- a/sqllineage/data/tpcds/query22.sql +++ b/sqllineage/data/tpcds/query22.sql @@ -1,4 +1,4 @@ -insert overwrite table query22 +insert into query22 select i_product_name , i_brand , i_class diff --git a/sqllineage/data/tpcds/query23.sql b/sqllineage/data/tpcds/query23.sql index 2f8f5893..5e863f5c 100644 --- a/sqllineage/data/tpcds/query23.sql +++ b/sqllineage/data/tpcds/query23.sql @@ -1,4 +1,4 @@ -insert overwrite table query23 +insert into query23 with frequent_ss_items as (select substr(i_item_desc, 1, 30) itemdesc, i_item_sk item_sk, d_date solddate, count(*) cnt from store_sales diff --git a/sqllineage/data/tpcds/query24.sql b/sqllineage/data/tpcds/query24.sql index bccf52e8..ed61e4e5 100644 --- a/sqllineage/data/tpcds/query24.sql +++ b/sqllineage/data/tpcds/query24.sql @@ -35,7 +35,7 @@ with ssales as , i_manager_id , i_units , i_size) -insert overwrite table query24 +insert into query24 select c_last_name , c_first_name , s_store_name diff --git a/sqllineage/data/tpcds/query25.sql b/sqllineage/data/tpcds/query25.sql index cfc7b7f3..72f7bba4 100644 --- a/sqllineage/data/tpcds/query25.sql +++ b/sqllineage/data/tpcds/query25.sql @@ -1,4 +1,4 @@ -insert overwrite table query25 +insert into query25 select i_item_id , i_item_desc , s_store_id diff --git a/sqllineage/data/tpcds/query26.sql b/sqllineage/data/tpcds/query26.sql index b05068e0..d13d4b76 100644 --- a/sqllineage/data/tpcds/query26.sql +++ b/sqllineage/data/tpcds/query26.sql @@ -1,4 +1,4 @@ -insert overwrite table query26 +insert into query26 select i_item_id, avg(cs_quantity) agg1, avg(cs_list_price) agg2, diff --git a/sqllineage/data/tpcds/query27.sql b/sqllineage/data/tpcds/query27.sql index e8aab5fa..9aec2c6c 100644 --- a/sqllineage/data/tpcds/query27.sql +++ b/sqllineage/data/tpcds/query27.sql @@ -1,4 +1,4 @@ -insert overwrite table query27 +insert into query27 select i_item_id, s_state, grouping(s_state) g_state, diff --git a/sqllineage/data/tpcds/query28.sql b/sqllineage/data/tpcds/query28.sql index 5c1f32af..1515fa41 100644 --- a/sqllineage/data/tpcds/query28.sql +++ b/sqllineage/data/tpcds/query28.sql @@ -1,4 +1,4 @@ -insert overwrite table query28 +insert into query28 select * from (select avg(ss_list_price) B1_LP , count(ss_list_price) B1_CNT diff --git 
a/sqllineage/data/tpcds/query29.sql b/sqllineage/data/tpcds/query29.sql index abc0baa2..59e03e70 100644 --- a/sqllineage/data/tpcds/query29.sql +++ b/sqllineage/data/tpcds/query29.sql @@ -1,4 +1,4 @@ -insert overwrite table query29 +insert into query29 select i_item_id , i_item_desc , s_store_id diff --git a/sqllineage/data/tpcds/query30.sql b/sqllineage/data/tpcds/query30.sql index f6cee075..2d5b804b 100644 --- a/sqllineage/data/tpcds/query30.sql +++ b/sqllineage/data/tpcds/query30.sql @@ -11,7 +11,7 @@ with customer_total_return as and wr_returning_addr_sk = ca_address_sk group by wr_returning_customer_sk , ca_state) -insert overwrite table query30 +insert into query30 select c_customer_id , c_salutation , c_first_name diff --git a/sqllineage/data/tpcds/query31.sql b/sqllineage/data/tpcds/query31.sql index 5fea06cb..75dc9232 100644 --- a/sqllineage/data/tpcds/query31.sql +++ b/sqllineage/data/tpcds/query31.sql @@ -14,7 +14,7 @@ with ss as where ws_sold_date_sk = d_date_sk and ws_bill_addr_sk = ca_address_sk group by ca_county, d_qoy, d_year) -insert overwrite table query31 +insert into query31 select ss1.ca_county , ss1.d_year , ws2.web_sales / ws1.web_sales web_q1_q2_increase diff --git a/sqllineage/data/tpcds/query32.sql b/sqllineage/data/tpcds/query32.sql index b75ea814..c19e7a49 100644 --- a/sqllineage/data/tpcds/query32.sql +++ b/sqllineage/data/tpcds/query32.sql @@ -1,4 +1,4 @@ -insert overwrite table query32 +insert into query32 select sum(cs_ext_discount_amt) as excess_discount_amount from catalog_sales , item diff --git a/sqllineage/data/tpcds/query33.sql b/sqllineage/data/tpcds/query33.sql index efa0d8f7..f2ca4b74 100644 --- a/sqllineage/data/tpcds/query33.sql +++ b/sqllineage/data/tpcds/query33.sql @@ -49,7 +49,7 @@ with ss as ( and ws_bill_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_manufact_id) -insert overwrite table query33 +insert into query33 select i_manufact_id, sum(total_sales) total_sales from (select * from ss diff --git a/sqllineage/data/tpcds/query34.sql b/sqllineage/data/tpcds/query34.sql index eac71a2f..693166fc 100644 --- a/sqllineage/data/tpcds/query34.sql +++ b/sqllineage/data/tpcds/query34.sql @@ -1,4 +1,4 @@ -insert overwrite table query34 +insert into query34 select c_last_name , c_first_name , c_salutation diff --git a/sqllineage/data/tpcds/query35.sql b/sqllineage/data/tpcds/query35.sql index b0479b0e..47c0793b 100644 --- a/sqllineage/data/tpcds/query35.sql +++ b/sqllineage/data/tpcds/query35.sql @@ -1,4 +1,4 @@ -insert overwrite table query35 +insert into query35 select ca_state, cd_gender, cd_marital_status, diff --git a/sqllineage/data/tpcds/query36.sql b/sqllineage/data/tpcds/query36.sql index c47bae08..cad026bd 100644 --- a/sqllineage/data/tpcds/query36.sql +++ b/sqllineage/data/tpcds/query36.sql @@ -1,4 +1,4 @@ -insert overwrite table query36 +insert into query36 select sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin , i_category , i_class diff --git a/sqllineage/data/tpcds/query37.sql b/sqllineage/data/tpcds/query37.sql index 372ffb30..896c3f8f 100644 --- a/sqllineage/data/tpcds/query37.sql +++ b/sqllineage/data/tpcds/query37.sql @@ -1,4 +1,4 @@ -insert overwrite table query37 +insert into query37 select i_item_id , i_item_desc , i_current_price diff --git a/sqllineage/data/tpcds/query38.sql b/sqllineage/data/tpcds/query38.sql index d5e893ef..7fefbbff 100644 --- a/sqllineage/data/tpcds/query38.sql +++ b/sqllineage/data/tpcds/query38.sql @@ -1,4 +1,4 @@ -insert overwrite table query38 +insert into query38 select count(*) 
from ( select distinct c_last_name, c_first_name, d_date diff --git a/sqllineage/data/tpcds/query39.sql b/sqllineage/data/tpcds/query39.sql index 808b7a9e..f6327f32 100644 --- a/sqllineage/data/tpcds/query39.sql +++ b/sqllineage/data/tpcds/query39.sql @@ -22,7 +22,7 @@ with inv as and d_year = 2001 group by w_warehouse_name, w_warehouse_sk, i_item_sk, d_moy) foo where case mean when 0 then 0 else stdev / mean end > 1) -insert overwrite table query39_1 +insert into query39_1 select inv1.w_warehouse_sk , inv1.i_item_sk , inv1.d_moy @@ -66,7 +66,7 @@ with inv as and d_year = 2001 group by w_warehouse_name, w_warehouse_sk, i_item_sk, d_moy) foo where case mean when 0 then 0 else stdev / mean end > 1) -insert overwrite table query39_2 +insert into query39_2 select inv1.w_warehouse_sk , inv1.i_item_sk , inv1.d_moy diff --git a/sqllineage/data/tpcds/query40.sql b/sqllineage/data/tpcds/query40.sql index 760ec903..bee41e7e 100644 --- a/sqllineage/data/tpcds/query40.sql +++ b/sqllineage/data/tpcds/query40.sql @@ -1,4 +1,4 @@ -insert overwrite table query40 +insert into query40 select w_state , i_item_id , sum(case diff --git a/sqllineage/data/tpcds/query41.sql b/sqllineage/data/tpcds/query41.sql index f2869358..3660550c 100644 --- a/sqllineage/data/tpcds/query41.sql +++ b/sqllineage/data/tpcds/query41.sql @@ -1,4 +1,4 @@ -insert overwrite table query41 +insert into query41 select distinct(i_product_name) from item i1 where i_manufact_id between 738 and 738 + 40 diff --git a/sqllineage/data/tpcds/query42.sql b/sqllineage/data/tpcds/query42.sql index e9a27b58..5db2c6c0 100644 --- a/sqllineage/data/tpcds/query42.sql +++ b/sqllineage/data/tpcds/query42.sql @@ -1,4 +1,4 @@ -insert overwrite table query42 +insert into query42 select dt.d_year , item.i_category_id , item.i_category diff --git a/sqllineage/data/tpcds/query43.sql b/sqllineage/data/tpcds/query43.sql index 43480759..efe76739 100644 --- a/sqllineage/data/tpcds/query43.sql +++ b/sqllineage/data/tpcds/query43.sql @@ -1,4 +1,4 @@ -insert overwrite table query43 +insert into query43 select s_store_name, s_store_id, sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, diff --git a/sqllineage/data/tpcds/query44.sql b/sqllineage/data/tpcds/query44.sql index a2b6c8a1..3248626b 100644 --- a/sqllineage/data/tpcds/query44.sql +++ b/sqllineage/data/tpcds/query44.sql @@ -1,4 +1,4 @@ -insert overwrite table query44 +insert into query44 select asceding.rnk, i1.i_product_name best_performing, i2.i_product_name worst_performing from (select * from (select item_sk, rank() over (order by rank_col asc) rnk diff --git a/sqllineage/data/tpcds/query45.sql b/sqllineage/data/tpcds/query45.sql index aec22576..4460b957 100644 --- a/sqllineage/data/tpcds/query45.sql +++ b/sqllineage/data/tpcds/query45.sql @@ -1,4 +1,4 @@ -insert overwrite table query45 +insert into query45 select ca_zip, ca_city, sum(ws_sales_price) from web_sales, customer, diff --git a/sqllineage/data/tpcds/query46.sql b/sqllineage/data/tpcds/query46.sql index d16d3d9d..a89e95c9 100644 --- a/sqllineage/data/tpcds/query46.sql +++ b/sqllineage/data/tpcds/query46.sql @@ -1,4 +1,4 @@ -insert overwrite table query46 +insert into query46 select c_last_name , c_first_name , ca_city diff --git a/sqllineage/data/tpcds/query47.sql b/sqllineage/data/tpcds/query47.sql index 7d2f0b15..c5aedadc 100644 --- a/sqllineage/data/tpcds/query47.sql +++ b/sqllineage/data/tpcds/query47.sql @@ -53,7 +53,7 @@ with v1 as ( and v1.s_company_name = v1_lead.s_company_name and v1.rn = v1_lag.rn + 1 and 
v1.rn = v1_lead.rn - 1) -insert overwrite table query47 +insert into query47 select * from v2 where d_year = 1999 diff --git a/sqllineage/data/tpcds/query48.sql b/sqllineage/data/tpcds/query48.sql index 5629b140..94bfda4b 100644 --- a/sqllineage/data/tpcds/query48.sql +++ b/sqllineage/data/tpcds/query48.sql @@ -1,4 +1,4 @@ -insert overwrite table query48 +insert into query48 select sum(ss_quantity) from store_sales, store, diff --git a/sqllineage/data/tpcds/query49.sql b/sqllineage/data/tpcds/query49.sql index 8b9e0b34..8a28ffd0 100644 --- a/sqllineage/data/tpcds/query49.sql +++ b/sqllineage/data/tpcds/query49.sql @@ -1,4 +1,4 @@ -insert overwrite table query49 +insert into query49 select 'web' as channel , web.item , web.return_ratio diff --git a/sqllineage/data/tpcds/query50.sql b/sqllineage/data/tpcds/query50.sql index 57c2fe7b..7fb8e62d 100644 --- a/sqllineage/data/tpcds/query50.sql +++ b/sqllineage/data/tpcds/query50.sql @@ -1,4 +1,4 @@ -insert overwrite table query50 +insert into query50 select s_store_name , s_company_id , s_street_number diff --git a/sqllineage/data/tpcds/query51.sql b/sqllineage/data/tpcds/query51.sql index b7c7fa99..63341ede 100644 --- a/sqllineage/data/tpcds/query51.sql +++ b/sqllineage/data/tpcds/query51.sql @@ -20,10 +20,7 @@ WITH web_v1 as ( and d_month_seq between 1200 and 1200 + 11 and ss_item_sk is not NULL group by ss_item_sk, d_date) -insert -overwrite -table -query51 +insert into query51 select * from (select item_sk , d_date diff --git a/sqllineage/data/tpcds/query52.sql b/sqllineage/data/tpcds/query52.sql index 138830c8..aa5a2814 100644 --- a/sqllineage/data/tpcds/query52.sql +++ b/sqllineage/data/tpcds/query52.sql @@ -1,4 +1,4 @@ -insert overwrite table query52 +insert into query52 select dt.d_year , item.i_brand_id brand_id , item.i_brand brand diff --git a/sqllineage/data/tpcds/query53.sql b/sqllineage/data/tpcds/query53.sql index 3e49bc79..d13e8353 100644 --- a/sqllineage/data/tpcds/query53.sql +++ b/sqllineage/data/tpcds/query53.sql @@ -1,4 +1,4 @@ -insert overwrite table query53 +insert into query53 select * from (select i_manufact_id, sum(ss_sales_price) sum_sales, diff --git a/sqllineage/data/tpcds/query54.sql b/sqllineage/data/tpcds/query54.sql index 42029e8d..cf287b08 100644 --- a/sqllineage/data/tpcds/query54.sql +++ b/sqllineage/data/tpcds/query54.sql @@ -49,7 +49,7 @@ with my_customers as ( (select cast((revenue / 50) as int) as segment from my_revenue ) -insert overwrite table query54 +insert into query54 select segment, count(*) as num_customers, segment * 50 as segment_base from segments group by segment diff --git a/sqllineage/data/tpcds/query55.sql b/sqllineage/data/tpcds/query55.sql index da949cc1..4b08c8f4 100644 --- a/sqllineage/data/tpcds/query55.sql +++ b/sqllineage/data/tpcds/query55.sql @@ -1,4 +1,4 @@ -insert overwrite table query55 +insert into query55 select i_brand_id brand_id, i_brand brand, sum(ss_ext_sales_price) ext_price diff --git a/sqllineage/data/tpcds/query56.sql b/sqllineage/data/tpcds/query56.sql index 226dca81..606cdb24 100644 --- a/sqllineage/data/tpcds/query56.sql +++ b/sqllineage/data/tpcds/query56.sql @@ -46,7 +46,7 @@ with ss as ( and ws_bill_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_item_id) -insert overwrite table query56 +insert into query56 select i_item_id, sum(total_sales) total_sales from (select * from ss diff --git a/sqllineage/data/tpcds/query57.sql b/sqllineage/data/tpcds/query57.sql index 4a830abc..8b5ecc9f 100644 --- a/sqllineage/data/tpcds/query57.sql +++ 
b/sqllineage/data/tpcds/query57.sql @@ -48,7 +48,7 @@ with v1 as ( and v1.cc_name = v1_lead.cc_name and v1.rn = v1_lag.rn + 1 and v1.rn = v1_lead.rn - 1) -insert overwrite table query57 +insert into query57 select * from v2 where d_year = 1999 diff --git a/sqllineage/data/tpcds/query58.sql b/sqllineage/data/tpcds/query58.sql index aa013ef1..f90b839c 100644 --- a/sqllineage/data/tpcds/query58.sql +++ b/sqllineage/data/tpcds/query58.sql @@ -40,7 +40,7 @@ with ss_items as where d_date = '2000-01-03')) and ws_sold_date_sk = d_date_sk group by i_item_id) -insert overwrite table query58 +insert into query58 select ss_items.item_id , ss_item_rev , ss_item_rev / ((ss_item_rev + cs_item_rev + ws_item_rev) / 3) * 100 ss_dev diff --git a/sqllineage/data/tpcds/query59.sql b/sqllineage/data/tpcds/query59.sql index 4113df70..e188d998 100644 --- a/sqllineage/data/tpcds/query59.sql +++ b/sqllineage/data/tpcds/query59.sql @@ -13,7 +13,7 @@ with wss as where d_date_sk = ss_sold_date_sk group by d_week_seq, ss_store_sk ) -insert overwrite table query59 +insert into query59 select s_store_name1 , s_store_id1 , d_week_seq1 diff --git a/sqllineage/data/tpcds/query60.sql b/sqllineage/data/tpcds/query60.sql index 72300fdf..a5ae4f8d 100644 --- a/sqllineage/data/tpcds/query60.sql +++ b/sqllineage/data/tpcds/query60.sql @@ -49,7 +49,7 @@ with ss as ( and ws_bill_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_item_id) -insert overwrite table query60 +insert into query60 select i_item_id , sum(total_sales) total_sales from (select * diff --git a/sqllineage/data/tpcds/query61.sql b/sqllineage/data/tpcds/query61.sql index a38d19b9..7c7b59f9 100644 --- a/sqllineage/data/tpcds/query61.sql +++ b/sqllineage/data/tpcds/query61.sql @@ -1,4 +1,4 @@ -insert overwrite table query61 +insert into query61 select promotions, total, cast(promotions as decimal(15, 4)) / cast(total as decimal(15, 4)) * 100 from (select sum(ss_ext_sales_price) promotions from store_sales diff --git a/sqllineage/data/tpcds/query62.sql b/sqllineage/data/tpcds/query62.sql index b3eddb8d..c8758454 100644 --- a/sqllineage/data/tpcds/query62.sql +++ b/sqllineage/data/tpcds/query62.sql @@ -1,4 +1,4 @@ -insert overwrite table query62 +insert into query62 select substr(w_warehouse_name, 1, 20) , sm_type , web_name diff --git a/sqllineage/data/tpcds/query63.sql b/sqllineage/data/tpcds/query63.sql index 650c9417..f9c0a0d6 100644 --- a/sqllineage/data/tpcds/query63.sql +++ b/sqllineage/data/tpcds/query63.sql @@ -1,4 +1,4 @@ -insert overwrite table query63 +insert into query63 select * from (select i_manager_id , sum(ss_sales_price) sum_sales diff --git a/sqllineage/data/tpcds/query64.sql b/sqllineage/data/tpcds/query64.sql index d6880120..9308b071 100644 --- a/sqllineage/data/tpcds/query64.sql +++ b/sqllineage/data/tpcds/query64.sql @@ -84,7 +84,7 @@ with cs_ui as , d2.d_year , d3.d_year ) -insert overwrite table query64 +insert into query64 select cs1.product_name , cs1.store_name , cs1.store_zip diff --git a/sqllineage/data/tpcds/query65.sql b/sqllineage/data/tpcds/query65.sql index 18fa1e71..f42b9262 100644 --- a/sqllineage/data/tpcds/query65.sql +++ b/sqllineage/data/tpcds/query65.sql @@ -1,4 +1,4 @@ -insert overwrite table query65 +insert into query65 select s_store_name, i_item_desc, sc.revenue, diff --git a/sqllineage/data/tpcds/query66.sql b/sqllineage/data/tpcds/query66.sql index 3302649a..b9609024 100644 --- a/sqllineage/data/tpcds/query66.sql +++ b/sqllineage/data/tpcds/query66.sql @@ -1,4 +1,4 @@ -insert overwrite table query66 +insert into 
query66 select w_warehouse_name , w_warehouse_sq_ft , w_city diff --git a/sqllineage/data/tpcds/query67.sql b/sqllineage/data/tpcds/query67.sql index cbf0d3d5..c65b420f 100644 --- a/sqllineage/data/tpcds/query67.sql +++ b/sqllineage/data/tpcds/query67.sql @@ -1,4 +1,4 @@ -insert overwrite table query67 +insert into query67 select * from (select i_category , i_class diff --git a/sqllineage/data/tpcds/query68.sql b/sqllineage/data/tpcds/query68.sql index 297523d9..df858cc7 100644 --- a/sqllineage/data/tpcds/query68.sql +++ b/sqllineage/data/tpcds/query68.sql @@ -1,4 +1,4 @@ -insert overwrite table query68 +insert into query68 select c_last_name , c_first_name , ca_city diff --git a/sqllineage/data/tpcds/query69.sql b/sqllineage/data/tpcds/query69.sql index b4b7f0f5..7c8d7726 100644 --- a/sqllineage/data/tpcds/query69.sql +++ b/sqllineage/data/tpcds/query69.sql @@ -1,4 +1,4 @@ -insert overwrite table query69 +insert into query69 select cd_gender, cd_marital_status, cd_education_status, diff --git a/sqllineage/data/tpcds/query70.sql b/sqllineage/data/tpcds/query70.sql index 2050edca..186999ab 100644 --- a/sqllineage/data/tpcds/query70.sql +++ b/sqllineage/data/tpcds/query70.sql @@ -1,4 +1,4 @@ -insert overwrite table query70 +insert into query70 select sum(ss_net_profit) as total_sum , s_state , s_county diff --git a/sqllineage/data/tpcds/query71.sql b/sqllineage/data/tpcds/query71.sql index 0a170a74..9dbf84e4 100644 --- a/sqllineage/data/tpcds/query71.sql +++ b/sqllineage/data/tpcds/query71.sql @@ -1,4 +1,4 @@ -insert overwrite table query71 +insert into query71 select i_brand_id brand_id, i_brand brand, t_hour, diff --git a/sqllineage/data/tpcds/query72.sql b/sqllineage/data/tpcds/query72.sql index 446eaac8..3916eed8 100644 --- a/sqllineage/data/tpcds/query72.sql +++ b/sqllineage/data/tpcds/query72.sql @@ -1,4 +1,4 @@ -insert overwrite table query72 +insert into query72 select i_item_desc , w_warehouse_name , d1.d_week_seq diff --git a/sqllineage/data/tpcds/query73.sql b/sqllineage/data/tpcds/query73.sql index 6cfc0623..18eb95b2 100644 --- a/sqllineage/data/tpcds/query73.sql +++ b/sqllineage/data/tpcds/query73.sql @@ -1,4 +1,4 @@ -insert overwrite table query73 +insert into query73 select c_last_name , c_first_name , c_salutation diff --git a/sqllineage/data/tpcds/query74.sql b/sqllineage/data/tpcds/query74.sql index 298fb723..6056d9c6 100644 --- a/sqllineage/data/tpcds/query74.sql +++ b/sqllineage/data/tpcds/query74.sql @@ -33,7 +33,7 @@ with year_total as ( , c_last_name , d_year ) -insert overwrite table query74 +insert into query74 select t_s_secyear.customer_id, t_s_secyear.customer_first_name, t_s_secyear.customer_last_name diff --git a/sqllineage/data/tpcds/query75.sql b/sqllineage/data/tpcds/query75.sql index 72deb478..e0dd5728 100644 --- a/sqllineage/data/tpcds/query75.sql +++ b/sqllineage/data/tpcds/query75.sql @@ -48,7 +48,7 @@ WITH all_sales AS ( AND ws_item_sk = wr_item_sk) WHERE i_category = 'Books') sales_detail GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) -insert overwrite table query75 +insert into query75 SELECT prev_yr.d_year AS prev_year , curr_yr.d_year AS year , curr_yr.i_brand_id diff --git a/sqllineage/data/tpcds/query76.sql b/sqllineage/data/tpcds/query76.sql index 757c7ea0..2775c5e1 100644 --- a/sqllineage/data/tpcds/query76.sql +++ b/sqllineage/data/tpcds/query76.sql @@ -1,4 +1,4 @@ -insert overwrite table query76 +insert into query76 select channel, col_name, d_year, d_qoy, i_category, COUNT(*) sales_cnt, SUM(ext_sales_price) sales_amt 
FROM ( SELECT 'store' as channel, diff --git a/sqllineage/data/tpcds/query77.sql b/sqllineage/data/tpcds/query77.sql index d748b1d8..b75c68fe 100644 --- a/sqllineage/data/tpcds/query77.sql +++ b/sqllineage/data/tpcds/query77.sql @@ -69,7 +69,7 @@ with ss as and date_add(cast('2000-08-23' as date), 30) and wr_web_page_sk = wp_web_page_sk group by wp_web_page_sk) -insert overwrite table query77 +insert into query77 select channel , id , sum(sales) as sales diff --git a/sqllineage/data/tpcds/query78.sql b/sqllineage/data/tpcds/query78.sql index d7f309c9..5c2e7f6c 100644 --- a/sqllineage/data/tpcds/query78.sql +++ b/sqllineage/data/tpcds/query78.sql @@ -37,7 +37,7 @@ with ws as where sr_ticket_number is null group by d_year, ss_item_sk, ss_customer_sk ) -insert overwrite table query78 +insert into query78 select ss_sold_year, ss_item_sk, ss_customer_sk, diff --git a/sqllineage/data/tpcds/query79.sql b/sqllineage/data/tpcds/query79.sql index fc50a64b..a5992fe8 100644 --- a/sqllineage/data/tpcds/query79.sql +++ b/sqllineage/data/tpcds/query79.sql @@ -1,4 +1,4 @@ -insert overwrite table query79 +insert into query79 select c_last_name,c_first_name,substr(s_city,1,30),ss_ticket_number,amt,profit from diff --git a/sqllineage/data/tpcds/query80.sql b/sqllineage/data/tpcds/query80.sql index 55f6866a..94ab27eb 100644 --- a/sqllineage/data/tpcds/query80.sql +++ b/sqllineage/data/tpcds/query80.sql @@ -63,7 +63,7 @@ with ssr as and ws_promo_sk = p_promo_sk and p_channel_tv = 'N' group by web_site_id) -insert overwrite table query80 +insert into query80 select channel , id , sum(sales) as sales diff --git a/sqllineage/data/tpcds/query81.sql b/sqllineage/data/tpcds/query81.sql index f0b57fd3..47df4ebf 100644 --- a/sqllineage/data/tpcds/query81.sql +++ b/sqllineage/data/tpcds/query81.sql @@ -11,7 +11,7 @@ with customer_total_return as and cr_returning_addr_sk = ca_address_sk group by cr_returning_customer_sk , ca_state) -insert overwrite table query81 +insert into query81 select c_customer_id , c_salutation , c_first_name diff --git a/sqllineage/data/tpcds/query82.sql b/sqllineage/data/tpcds/query82.sql index 5e8c97e9..aa7f93d4 100644 --- a/sqllineage/data/tpcds/query82.sql +++ b/sqllineage/data/tpcds/query82.sql @@ -1,4 +1,4 @@ -insert overwrite table query82 +insert into query82 select i_item_id , i_item_desc , i_current_price diff --git a/sqllineage/data/tpcds/query83.sql b/sqllineage/data/tpcds/query83.sql index 2dcb3d7c..c4de4218 100644 --- a/sqllineage/data/tpcds/query83.sql +++ b/sqllineage/data/tpcds/query83.sql @@ -46,7 +46,7 @@ with sr_items as where d_date in ('2000-06-30', '2000-09-27', '2000-11-17'))) and wr_returned_date_sk = d_date_sk group by i_item_id) -insert overwrite table query83 +insert into query83 select sr_items.item_id , sr_item_qty , sr_item_qty / (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 * 100 sr_dev diff --git a/sqllineage/data/tpcds/query84.sql b/sqllineage/data/tpcds/query84.sql index 77e4a041..5d2ab000 100644 --- a/sqllineage/data/tpcds/query84.sql +++ b/sqllineage/data/tpcds/query84.sql @@ -1,4 +1,4 @@ -insert overwrite table query84 +insert into query84 select c_customer_id as customer_id , concat(c_last_name, ', ', coalesce(c_first_name, '')) as customername from customer diff --git a/sqllineage/data/tpcds/query85.sql b/sqllineage/data/tpcds/query85.sql index 557abb0a..52fbdb3a 100644 --- a/sqllineage/data/tpcds/query85.sql +++ b/sqllineage/data/tpcds/query85.sql @@ -1,4 +1,4 @@ -insert overwrite table query85 +insert into query85 select substr(r_reason_desc, 1, 20) , 
avg(ws_quantity) , avg(wr_refunded_cash) diff --git a/sqllineage/data/tpcds/query86.sql b/sqllineage/data/tpcds/query86.sql index 66561bdd..73ccc70e 100644 --- a/sqllineage/data/tpcds/query86.sql +++ b/sqllineage/data/tpcds/query86.sql @@ -1,4 +1,4 @@ -insert overwrite table query86 +insert into query86 select sum(ws_net_paid) as total_sum , i_category , i_class diff --git a/sqllineage/data/tpcds/query87.sql b/sqllineage/data/tpcds/query87.sql index 0430b6f4..7bd9b7b3 100644 --- a/sqllineage/data/tpcds/query87.sql +++ b/sqllineage/data/tpcds/query87.sql @@ -1,4 +1,4 @@ -insert overwrite table query87 +insert into query87 select count(*) from ((select distinct c_last_name, c_first_name, d_date from store_sales, diff --git a/sqllineage/data/tpcds/query88.sql b/sqllineage/data/tpcds/query88.sql index fcd94c17..49d5a5a1 100644 --- a/sqllineage/data/tpcds/query88.sql +++ b/sqllineage/data/tpcds/query88.sql @@ -1,4 +1,4 @@ -insert overwrite table query88 +insert into query88 select * from (select count(*) h8_30_to_9 from store_sales, diff --git a/sqllineage/data/tpcds/query89.sql b/sqllineage/data/tpcds/query89.sql index 1182561b..3a802399 100644 --- a/sqllineage/data/tpcds/query89.sql +++ b/sqllineage/data/tpcds/query89.sql @@ -1,4 +1,4 @@ -insert overwrite table query89 +insert into query89 select * from ( select i_category, diff --git a/sqllineage/data/tpcds/query90.sql b/sqllineage/data/tpcds/query90.sql index 5ad39edc..9009bcb5 100644 --- a/sqllineage/data/tpcds/query90.sql +++ b/sqllineage/data/tpcds/query90.sql @@ -1,4 +1,4 @@ -insert overwrite table query90 +insert into query90 select cast(amc as decimal(15, 4)) / cast(pmc as decimal(15, 4)) am_pm_ratio from (select count(*) amc from web_sales, diff --git a/sqllineage/data/tpcds/query91.sql b/sqllineage/data/tpcds/query91.sql index 46bf1585..a7b48e75 100644 --- a/sqllineage/data/tpcds/query91.sql +++ b/sqllineage/data/tpcds/query91.sql @@ -1,4 +1,4 @@ -insert overwrite table query91 +insert into query91 select cc_call_center_id Call_Center, cc_name Call_Center_Name, cc_manager Manager, diff --git a/sqllineage/data/tpcds/query92.sql b/sqllineage/data/tpcds/query92.sql index b0a74f7e..c260fc3d 100644 --- a/sqllineage/data/tpcds/query92.sql +++ b/sqllineage/data/tpcds/query92.sql @@ -1,4 +1,4 @@ -insert overwrite table query92 +insert into query92 select sum(ws_ext_discount_amt) as Excess_Discount_Amount from web_sales , item diff --git a/sqllineage/data/tpcds/query93.sql b/sqllineage/data/tpcds/query93.sql index e8b1b7ba..28257326 100644 --- a/sqllineage/data/tpcds/query93.sql +++ b/sqllineage/data/tpcds/query93.sql @@ -1,4 +1,4 @@ -insert overwrite table query93 +insert into query93 select ss_customer_sk , sum(act_sales) sumsales from (select ss_item_sk diff --git a/sqllineage/data/tpcds/query94.sql b/sqllineage/data/tpcds/query94.sql index 9da7a36e..292e6bb1 100644 --- a/sqllineage/data/tpcds/query94.sql +++ b/sqllineage/data/tpcds/query94.sql @@ -1,4 +1,4 @@ -insert overwrite table query94 +insert into query94 select count(distinct ws_order_number) as order_count , sum(ws_ext_ship_cost) as total_shipping_cost , sum(ws_net_profit) as total_net_profit diff --git a/sqllineage/data/tpcds/query95.sql b/sqllineage/data/tpcds/query95.sql index ae54aec6..eba538e0 100644 --- a/sqllineage/data/tpcds/query95.sql +++ b/sqllineage/data/tpcds/query95.sql @@ -4,7 +4,7 @@ with ws_wh as web_sales ws2 where ws1.ws_order_number = ws2.ws_order_number and ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk) -insert overwrite table query95 +insert into query95 select 
count(distinct ws_order_number) as order_count , sum(ws_ext_ship_cost) as total_shipping_cost , sum(ws_net_profit) as total_net_profit diff --git a/sqllineage/data/tpcds/query96.sql b/sqllineage/data/tpcds/query96.sql index 83121130..51601baa 100644 --- a/sqllineage/data/tpcds/query96.sql +++ b/sqllineage/data/tpcds/query96.sql @@ -1,4 +1,4 @@ -insert overwrite table query96 +insert into query96 select count(*) from store_sales , household_demographics diff --git a/sqllineage/data/tpcds/query97.sql b/sqllineage/data/tpcds/query97.sql index 6665b8e9..30cdd806 100644 --- a/sqllineage/data/tpcds/query97.sql +++ b/sqllineage/data/tpcds/query97.sql @@ -16,7 +16,7 @@ with ssci as ( and d_month_seq between 1200 and 1200 + 11 group by cs_bill_customer_sk , cs_item_sk) -insert overwrite table query97 +insert into query97 select sum(case when ssci.customer_sk is not null and csci.customer_sk is null then 1 else 0 end) store_only , sum(case when ssci.customer_sk is null and csci.customer_sk is not null then 1 else 0 end) catalog_only , sum(case when ssci.customer_sk is not null and csci.customer_sk is not null then 1 else 0 end) store_and_catalog diff --git a/sqllineage/data/tpcds/query98.sql b/sqllineage/data/tpcds/query98.sql index 151794e4..19370fe6 100644 --- a/sqllineage/data/tpcds/query98.sql +++ b/sqllineage/data/tpcds/query98.sql @@ -1,4 +1,4 @@ -insert overwrite table query98 +insert into query98 select i_item_id , i_item_desc , i_category diff --git a/sqllineage/data/tpcds/query99.sql b/sqllineage/data/tpcds/query99.sql index ba57b6e4..a9fc5373 100644 --- a/sqllineage/data/tpcds/query99.sql +++ b/sqllineage/data/tpcds/query99.sql @@ -1,4 +1,4 @@ -insert overwrite table query99 +insert into query99 select substr(w_warehouse_name, 1, 20) , sm_type , cc_name diff --git a/sqllineage/drawing.py b/sqllineage/drawing.py index 5c03c7d1..af8dec15 100644 --- a/sqllineage/drawing.py +++ b/sqllineage/drawing.py @@ -22,7 +22,6 @@ DEFAULT_DIALECT, DEFAULT_HOST, DEFAULT_PORT, - DEFAULT_USE_SQLFLUFF, ) from sqllineage import STATIC_FOLDER from sqllineage.exceptions import SQLLineageException @@ -159,8 +158,7 @@ def lineage(payload): req_args = Namespace(**payload) sql = extract_sql_from_args(req_args) dialect = getattr(req_args, "dialect", DEFAULT_DIALECT) - use_sql_parse = bool(getattr(req_args, "use_sqlfluff", DEFAULT_USE_SQLFLUFF)) - lr = LineageRunner(sql, verbose=True, dialect=dialect, use_sqlparse=use_sql_parse) + lr = LineageRunner(sql, verbose=True, dialect=dialect) data = { "verbose": str(lr), "dag": lr.to_cytoscape(), diff --git a/sqllineage/exceptions.py b/sqllineage/exceptions.py index 8b9cc171..aa634b08 100644 --- a/sqllineage/exceptions.py +++ b/sqllineage/exceptions.py @@ -1,2 +1,10 @@ class SQLLineageException(Exception): """Base Exception for SQLLineage""" + + +class UnsupportedStatementException(SQLLineageException): + """Raised for SQL statement that SQLLineage doesn't support analyzing""" + + +class InvalidSyntaxException(SQLLineageException): + """Raised for SQL statement that parser cannot parse""" diff --git a/sqllineage/runner.py b/sqllineage/runner.py index 83018f15..db44bff0 100644 --- a/sqllineage/runner.py +++ b/sqllineage/runner.py @@ -1,29 +1,15 @@ import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple -import sqlparse -from sqlfluff.api.simple import get_simple_config -from sqlfluff.core import Linter -from sqlparse.sql import Statement - -from sqllineage.core import LineageAnalyzer -from sqllineage.core.holders 
import SQLLineageHolder, StatementLineageHolder +from sqllineage import SQLPARSE_DIALECT +from sqllineage.core.holders import SQLLineageHolder from sqllineage.core.models import Column, Table +from sqllineage.core.parser.sqlfluff.analyzer import SqlFluffLineageAnalyzer +from sqllineage.core.parser.sqlparse.analyzer import SqlParseLineageAnalyzer from sqllineage.drawing import draw_lineage_graph -from sqllineage.exceptions import SQLLineageException from sqllineage.io import to_cytoscape -from sqllineage.sqlfluff_core.analyzer import SqlFluffLineageAnalyzer -from sqllineage.sqlfluff_core.holders import ( - SqlFluffLineageHolder, - SqlFluffStatementLineageHolder, -) -from sqllineage.sqlfluff_core.utils.sqlfluff import get_statement_segment from sqllineage.utils.constant import LineageLevel -from sqllineage.utils.helpers import ( - clean_parentheses, - is_subquery_statement, - remove_statement_parentheses, -) +from sqllineage.utils.helpers import split, trim_comment logger = logging.getLogger(__name__) @@ -46,11 +32,10 @@ class LineageRunner(object): def __init__( self, sql: str, + dialect: str = "ansi", encoding: Optional[str] = None, verbose: bool = False, draw_options: Optional[Dict[str, str]] = None, - dialect: Optional[str] = "ansi", - use_sqlparse: bool = True, ): """ The entry point of SQLLineage after command line options are parsed. @@ -64,20 +49,15 @@ def __init__( self._verbose = verbose self._draw_options = draw_options if draw_options else {} self._evaluated = False - self._stmt: List[Statement] = [] - self._use_sqlparse = use_sqlparse - if not self._use_sqlparse: - self._sqlfluff_linter = Linter( - config=get_simple_config(dialect=dialect, config_path=None) - ) - self._dialect = dialect + self._stmt: List[str] = [] + self._dialect = dialect @lazy_method def __str__(self): """ print out the Lineage Summary. """ - statements = self.statements(strip_comments=True) + statements = self.statements() source_tables = "\n ".join(str(t) for t in self.source_tables) target_tables = "\n ".join(str(t) for t in self.target_tables) combined = f"""Statements(#): {len(statements)} @@ -115,7 +95,7 @@ def to_cytoscape(self, level=LineageLevel.TABLE) -> List[Dict[str, Dict[str, str else: return to_cytoscape(self._sql_holder.table_lineage_graph) - def draw(self, dialect: str, use_sqlfluff: bool) -> None: + def draw(self, dialect: str) -> None: """ to draw the lineage directed graph """ @@ -124,24 +104,14 @@ def draw(self, dialect: str, use_sqlfluff: bool) -> None: draw_options.pop("f", None) draw_options["e"] = self._sql draw_options["dialect"] = dialect - draw_options["use_sqlfluff"] = str(use_sqlfluff) return draw_lineage_graph(**draw_options) @lazy_method - def statements(self, **kwargs) -> List[str]: + def statements(self) -> List[str]: """ - a list of statements. - - :param kwargs: the key arguments that will be passed to `sqlparse.format` + a list of SQL statements. 
""" - return [sqlparse.format(s.value, **kwargs) for s in self.statements_parsed] - - @lazy_property - def statements_parsed(self) -> List[Statement]: - """ - a list of :class:`sqlparse.sql.Statement` - """ - return self._stmt + return [trim_comment(s) for s in self._stmt] @lazy_property def source_tables(self) -> List[Table]: @@ -189,51 +159,12 @@ def print_table_lineage(self) -> None: print(str(self)) def _eval(self): - self._stmt = [ - s - for s in sqlparse.parse( - # first apply sqlparser formatting just to get rid of comments, which cause - # inconsistencies in parsing output - clean_parentheses( - sqlparse.format( - self._sql.strip(), self._encoding, strip_comments=True - ) - ), - self._encoding, - ) - if s.token_first(skip_cm=True) - ] - - self._stmt_holders = [self.run_lineage_analyzer(stmt) for stmt in self._stmt] - self._sql_holder = ( - SQLLineageHolder.of(*self._stmt_holders) - if self._use_sqlparse - else SqlFluffLineageHolder.of(self._stmt_holders) + self._stmt = split(self._sql.strip()) + analyzer = ( + SqlParseLineageAnalyzer() + if self._dialect == SQLPARSE_DIALECT + else SqlFluffLineageAnalyzer(self._dialect) ) + self._stmt_holders = [analyzer.analyze(stmt) for stmt in self._stmt] + self._sql_holder = SQLLineageHolder.of(*self._stmt_holders) self._evaluated = True - - def run_lineage_analyzer( - self, stmt: Statement - ) -> Union[StatementLineageHolder, SqlFluffStatementLineageHolder]: - stmt_value = stmt.value.strip() - if not self._use_sqlparse: - is_sub_query = is_subquery_statement(stmt_value) - if is_sub_query: - stmt_value = remove_statement_parentheses(stmt_value) - parsed_string = self._sqlfluff_linter.parse_string(stmt_value) - statement_segment = get_statement_segment(parsed_string) - if statement_segment and SqlFluffLineageAnalyzer.can_analyze( - statement_segment - ): - if "unparsable" in statement_segment.descendant_type_set: - raise SQLLineageException( - f"The query [\n{stmt_value}\n] contains an unparsable segment." - ) - return SqlFluffLineageAnalyzer().analyze( - statement_segment, self._dialect or "", is_sub_query - ) - else: - raise SQLLineageException( - f"The query [\n{stmt_value}\n] contains can not be analyzed." 
- ) - return LineageAnalyzer().analyze(stmt) diff --git a/sqllineage/sqlfluff_core/analyzer.py b/sqllineage/sqlfluff_core/analyzer.py deleted file mode 100644 index 2727b4a8..00000000 --- a/sqllineage/sqlfluff_core/analyzer.py +++ /dev/null @@ -1,61 +0,0 @@ -from sqlfluff.core.parser import BaseSegment - -from sqllineage.sqlfluff_core.holders import ( - SqlFluffStatementLineageHolder, -) -from sqllineage.sqlfluff_core.models import SqlFluffAnalyzerContext -from sqllineage.sqlfluff_core.subquery.cte_extractor import DmlCteExtractor -from sqllineage.sqlfluff_core.subquery.ddl_alter_extractor import DdlAlterExtractor -from sqllineage.sqlfluff_core.subquery.ddl_drop_extractor import DdlDropExtractor -from sqllineage.sqlfluff_core.subquery.dml_insert_extractor import DmlInsertExtractor -from sqllineage.sqlfluff_core.subquery.dml_select_extractor import DmlSelectExtractor -from sqllineage.sqlfluff_core.subquery.noop_extractor import NoopExtractor - -SUPPORTED_STMT_TYPES = ( - DmlSelectExtractor.DML_SELECT_STMT_TYPES - + DmlInsertExtractor.DML_INSERT_STMT_TYPES - + DmlCteExtractor.CTE_STMT_TYPES - + DdlDropExtractor.DDL_DROP_STMT_TYPES - + DdlAlterExtractor.DDL_ALTER_STMT_TYPES - + NoopExtractor.NOOP_STMT_TYPES -) - - -class SqlFluffLineageAnalyzer: - """SQL Statement Level Lineage Analyzer for `sqlfluff`""" - - def analyze( - self, statement: BaseSegment, dialect: str, is_sub_query: bool = False - ) -> SqlFluffStatementLineageHolder: - """ - Analyze the base segment and store the result into `sqllineage.holders.StatementLineageHolder` class. - :param statement: a SQL base segment parsed by `sqlfluff` - :param dialect: dialect used to parse the statement - :param is_sub_query: the original query contained parentheses - :return: 'SqlFluffStatementLineageHolder' object - """ - subquery_extractors = [ - DmlSelectExtractor(dialect), - DmlInsertExtractor(dialect), - DmlCteExtractor(dialect), - DdlDropExtractor(dialect), - DdlAlterExtractor(dialect), - NoopExtractor(dialect), - ] - for subquery_extractor in subquery_extractors: - if subquery_extractor.can_extract(statement.type): - lineage_holder = subquery_extractor.extract( - statement, SqlFluffAnalyzerContext(), is_sub_query - ) - return SqlFluffStatementLineageHolder.of(lineage_holder) - raise NotImplementedError( - f"Can not extract lineage for dialect [{dialect}] from query: [{statement.raw}]" - ) - - @staticmethod - def can_analyze(statement: BaseSegment): - """ - Check if the current lineage analyzer can analyze the statement - :param statement: a SQL base segment parsed by `sqlfluff` - """ - return statement.type in SUPPORTED_STMT_TYPES diff --git a/sqllineage/sqlfluff_core/handlers/source.py b/sqllineage/sqlfluff_core/handlers/source.py deleted file mode 100644 index 2fe2ffb1..00000000 --- a/sqllineage/sqlfluff_core/handlers/source.py +++ /dev/null @@ -1,222 +0,0 @@ -import re -from typing import Dict, List, Tuple, Union - -from sqlfluff.core.parser import BaseSegment - -from sqllineage.exceptions import SQLLineageException -from sqllineage.sqlfluff_core.handlers.base import ConditionalSegmentBaseHandler -from sqllineage.sqlfluff_core.holders import SqlFluffSubQueryLineageHolder -from sqllineage.sqlfluff_core.models import ( - SqlFluffColumn, - SqlFluffSubQuery, -) -from sqllineage.sqlfluff_core.models import ( - SqlFluffPath, - SqlFluffTable, -) -from sqllineage.sqlfluff_core.utils.holder import retrieve_holder_data_from -from sqllineage.sqlfluff_core.utils.sqlfluff import ( - find_table_identifier, - get_grandchild, - 
get_inner_from_expression, - get_multiple_identifiers, - get_subqueries, - is_subquery, - is_values_clause, - retrieve_extra_segment, - retrieve_segments, -) -from sqllineage.utils.constant import EdgeType - - -class SourceHandler(ConditionalSegmentBaseHandler): - """ - Source table and column handler - """ - - def __init__(self, dialect: str): - super().__init__(dialect) - self.columns: List[SqlFluffColumn] = [] - self.tables: List[ - Union[SqlFluffPath, SqlFluffTable, SqlFluffSubQuery, SqlFluffSubQuery] - ] = [] - self.union_barriers: List[Tuple[int, int]] = [] - - def indicate(self, segment: BaseSegment) -> bool: - """ - Indicates if the handler can handle the segment - :param segment: segment to be processed - :return: True if it can be handled - """ - return self._indicate_column(segment) or self._indicate_table(segment) - - def handle( - self, segment: BaseSegment, holder: SqlFluffSubQueryLineageHolder - ) -> None: - """ - Handle the segment, and update the lineage result accordingly in the holder - :param segment: segment to be handled - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage - """ - if self._indicate_table(segment): - self._handle_table(segment, holder) - if self._indicate_column(segment): - self._handle_column(segment) - - def _handle_table( - self, segment: BaseSegment, holder: SqlFluffSubQueryLineageHolder - ) -> None: - """ - Table handler method - :param segment: segment to be handled - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage - """ - identifiers = get_multiple_identifiers(segment) - if identifiers and len(identifiers) > 1: - for identifier in identifiers: - self._add_dataset_from_expression_element(identifier, holder) - from_segment = get_inner_from_expression(segment) - if from_segment.type == "from_expression_element": - self._add_dataset_from_expression_element(from_segment, holder) - elif from_segment.type == "bracketed": - if is_subquery(from_segment): - self.tables.append(SqlFluffSubQuery.of(from_segment, None)) - else: - raise SQLLineageException( - "An 'from_expression_element' or 'bracketed' segment is expected, got %s instead." 
- % from_segment.type - ) - for extra_segment in retrieve_extra_segment(segment): - self._handle_table(extra_segment, holder) - - def _handle_column(self, segment: BaseSegment) -> None: - """ - Column handler method - :param segment: segment to be handled - """ - sub_segments = retrieve_segments(segment) - for sub_segment in sub_segments: - if sub_segment.type == "select_clause_element": - self.columns.append( - SqlFluffColumn.of(sub_segment, dialect=self.dialect) - ) - - def end_of_query_cleanup(self, holder: SqlFluffSubQueryLineageHolder) -> None: - """ - Optional method to be called at the end of statement or subquery - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage - """ - for i, tbl in enumerate(self.tables): - holder.add_read(tbl) - self.union_barriers.append((len(self.columns), len(self.tables))) - for i, (col_barrier, tbl_barrier) in enumerate(self.union_barriers): - prev_col_barrier, prev_tbl_barrier = ( - (0, 0) if i == 0 else self.union_barriers[i - 1] - ) - col_grp = self.columns[prev_col_barrier:col_barrier] - tbl_grp = self.tables[prev_tbl_barrier:tbl_barrier] - tgt_tbl = None - if holder.write: - if len(holder.write) > 1: - raise SQLLineageException - tgt_tbl = list(holder.write)[0] - if tgt_tbl: - for tgt_col in col_grp: - tgt_col.parent = tgt_tbl - for src_col in tgt_col.to_source_columns( - self._get_alias_mapping_from_table_group(tbl_grp, holder) - ): - holder.add_column_lineage(src_col, tgt_col) - - def _add_dataset_from_expression_element( - self, segment: BaseSegment, holder: SqlFluffSubQueryLineageHolder - ) -> None: - """ - Append tables and subqueries identified in the 'from_expression_element' type segment to the table and - holder extra subqueries sets - :param segment: 'from_expression_element' type segment - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage - """ - dataset: Union[SqlFluffPath, SqlFluffTable, SqlFluffSubQuery] - all_segments = [ - seg for seg in retrieve_segments(segment) if seg.type != "keyword" - ] - first_segment = all_segments[0] - function_as_table = get_grandchild(segment, "table_expression", "function") - if first_segment.type == "function" or function_as_table: - # function() as alias, no dataset involved - return - elif first_segment.type == "bracketed" and is_values_clause(first_segment): - # (VALUES ...) AS alias, no dataset involved - return - path_match = re.match(r"(parquet|csv|json)\.`(.*)`", segment.raw_upper) - if path_match is not None: - dataset = SqlFluffPath(path_match.groups()[1]) - else: - subqueries = get_subqueries(segment, skip_union=False) - if subqueries: - for sq in subqueries: - bracketed, alias = sq - read_sq = SqlFluffSubQuery.of(bracketed, alias) - holder.extra_subqueries.add(read_sq) - self.tables.append(read_sq) - return - else: - table_identifier = find_table_identifier(segment) - if table_identifier: - dataset = retrieve_holder_data_from( - all_segments, holder, table_identifier - ) - else: - return - self.tables.append(dataset) - - def _get_alias_mapping_from_table_group( - self, - table_group: List[ - Union[SqlFluffPath, SqlFluffTable, SqlFluffSubQuery, SqlFluffSubQuery] - ], - holder: SqlFluffSubQueryLineageHolder, - ) -> Dict[str, Union[SqlFluffTable, SqlFluffSubQuery]]: - """ - A table can be referred to as alias, table name, or database_name.table_name, create the mapping here. - For SubQuery, it's only alias then. 
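The `_add_dataset_from_expression_element` method being deleted above recognized Spark-style data paths such as ``parquet.`/path` `` and turned them into `SqlFluffPath` datasets rather than tables. A small sketch of that pattern match; `re.IGNORECASE` is an assumption for this sketch, standing in for the original's matching against `segment.raw_upper`:

```python
import re

PATH_PATTERN = re.compile(r"(parquet|csv|json)\.`(.*)`", re.IGNORECASE)

match = PATH_PATTERN.match("parquet.`/warehouse/events/2023-01`")
if match:
    file_format, uri = match.groups()  # the second group is the path uri
    print(file_format, uri)  # parquet /warehouse/events/2023-01
```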
- :param table_group: a list of objects from the table list - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage - :return: A map of tables and references - """ - return { - **{ - tgt: src - for src, tgt, attr in holder.graph.edges(data=True) - if attr.get("type") == EdgeType.HAS_ALIAS and src in table_group - }, - **{ - table.raw_name: table - for table in table_group - if isinstance(table, SqlFluffTable) - }, - **{ - str(table): table - for table in table_group - if isinstance(table, SqlFluffTable) - }, - } - - @staticmethod - def _indicate_column(segment: BaseSegment) -> bool: - """ - Check if it is a column - :param segment: segment to be checked - :return: True if type is 'select_clause' - """ - return bool(segment.type == "select_clause") - - @staticmethod - def _indicate_table(segment: BaseSegment) -> bool: - """ - Check if it is a table - :param segment: segment to be checked - :return: True if type is 'from_clause' - """ - return bool(segment.type == "from_clause") diff --git a/sqllineage/sqlfluff_core/holders.py b/sqllineage/sqlfluff_core/holders.py deleted file mode 100644 index f52d1afb..00000000 --- a/sqllineage/sqlfluff_core/holders.py +++ /dev/null @@ -1,346 +0,0 @@ -import itertools -from typing import List, Optional, Set, Tuple, Union - -import networkx as nx -from networkx import DiGraph - -from sqllineage.sqlfluff_core.models import ( - SqlFluffColumn, - SqlFluffPath, - SqlFluffSubQuery, - SqlFluffTable, -) -from sqllineage.utils.constant import EdgeType, NodeTag - -DATASET_CLASSES = (SqlFluffPath, SqlFluffTable) - - -class SqlFluffColumnLineageMixin: - """ - Mixin class with 'get_column_lineage' method - """ - - def get_column_lineage( - self, exclude_subquery=True - ) -> Set[Tuple[SqlFluffColumn, ...]]: - """ - Calculate the column lineage of a holder's graph - :param exclude_subquery: if only 'SqlFluffTable' are considered - :return: column lineage into a list of tuples - """ - self.graph: DiGraph # For mypy attribute checking - # filter all the column node in the graph - column_nodes = [n for n in self.graph.nodes if isinstance(n, SqlFluffColumn)] - column_graph = self.graph.subgraph(column_nodes) - source_columns = {column for column, deg in column_graph.in_degree if deg == 0} - # if a column lineage path ends at SubQuery, then it should be pruned - target_columns = { - node - for node, deg in column_graph.out_degree - if isinstance(node, SqlFluffColumn) and deg == 0 - } - if exclude_subquery: - target_columns = { - node - for node in target_columns - if isinstance(node.parent, SqlFluffTable) - } - columns = set() - for (source, target) in itertools.product(source_columns, target_columns): - simple_paths = list(nx.all_simple_paths(self.graph, source, target)) - for path in simple_paths: - columns.add(tuple(path)) - return columns - - -class SqlFluffSubQueryLineageHolder(SqlFluffColumnLineageMixin): - """ - SubQuery/Query Level Lineage Result. - - SqlFluffSubQueryLineageHolder will hold attributes like read, write, cte - - Each of them is a set of 'SqlFluffTable' or 'SqlFluffSubQuery'. - - This is the most atomic representation of lineage result. 
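The `get_column_lineage` mixin removed above treats column lineage as path enumeration over the graph: sources are column nodes with in-degree 0, targets are column nodes with out-degree 0, and each lineage tuple is one simple path between them. A toy illustration with plain strings in place of `SqlFluffColumn` nodes:

```python
import itertools
import networkx as nx

g = nx.DiGraph()
g.add_edge("tab1.col1", "subquery.col1")  # column-to-column lineage edges
g.add_edge("subquery.col1", "tab2.col1")

sources = {n for n, deg in g.in_degree if deg == 0}
targets = {n for n, deg in g.out_degree if deg == 0}
for src, tgt in itertools.product(sources, targets):
    for path in nx.all_simple_paths(g, src, tgt):
        print(tuple(path))  # ('tab1.col1', 'subquery.col1', 'tab2.col1')
```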
- """ - - def __init__(self) -> None: - self.graph = nx.DiGraph() - self.extra_subqueries: Set[SqlFluffSubQuery] = set() - - def __or__(self, other): - self.graph = nx.compose(self.graph, other.graph) - return self - - def _property_getter( - self, prop - ) -> Union[Set[SqlFluffSubQuery], Set[SqlFluffTable]]: - return {t for t, attr in self.graph.nodes(data=True) if attr.get(prop) is True} - - def _property_setter(self, value, prop) -> None: - self.graph.add_node(value, **{prop: True}) - - @property - def read(self) -> Set[Union[SqlFluffSubQuery, SqlFluffTable]]: - return self._property_getter(NodeTag.READ) # type: ignore - - def add_read(self, value) -> None: - self._property_setter(value, NodeTag.READ) - # the same table can be added (in SQL: joined) multiple times with different alias - if hasattr(value, "alias"): - self.graph.add_edge(value, value.alias, type=EdgeType.HAS_ALIAS) - - @property - def write(self) -> Set[Union[SqlFluffSubQuery, SqlFluffTable]]: - return self._property_getter(NodeTag.WRITE) # type: ignore - - def add_write(self, value) -> None: - self._property_setter(value, NodeTag.WRITE) - - @property - def cte(self) -> Set[SqlFluffSubQuery]: - return self._property_getter(NodeTag.CTE) # type: ignore - - def add_cte(self, value) -> None: - self._property_setter(value, NodeTag.CTE) - - def add_column_lineage(self, src: SqlFluffColumn, tgt: SqlFluffColumn) -> None: - """ - Add column lineage between to given 'SqlFluffColumn' - :param src: source 'SqlFluffColumn' - :param tgt: target 'SqlFluffColumn' - """ - self.graph.add_edge(src, tgt, type=EdgeType.LINEAGE) - self.graph.add_edge(tgt.parent, tgt, type=EdgeType.HAS_COLUMN) - if src.parent is not None: - self.graph.add_edge(src.parent, src, type=EdgeType.HAS_COLUMN) - - -class SqlFluffStatementLineageHolder( - SqlFluffSubQueryLineageHolder, SqlFluffColumnLineageMixin -): - """ - Statement Level Lineage Result. - - Based on 'SqlFluffSubQueryLineageHolder' and 'StatementLineageHolder' holds extra attributes like drop and rename - - For drop, it is a set of 'SqlFluffTable'. - - For rename, it is a set of tuples of 'SqlFluffTable', with the first table being original table before renaming and - the latter after renaming. 
- """ - - def __str__(self): - return "\n".join( - f"table {attr}: {sorted(getattr(self, attr), key=lambda x: str(x)) if getattr(self, attr) else '[]'}" - for attr in ["read", "write", "cte", "drop", "rename"] - ) - - def __repr__(self): - return str(self) - - @property - def read(self) -> Set[SqlFluffTable]: # type: ignore - return {t for t in super().read if isinstance(t, DATASET_CLASSES)} - - @property - def write(self) -> Set[SqlFluffTable]: # type: ignore - return {t for t in super().write if isinstance(t, DATASET_CLASSES)} - - @property - def drop(self) -> Set[SqlFluffTable]: - return self._property_getter(NodeTag.DROP) # type: ignore - - def add_drop(self, value) -> None: - self._property_setter(value, NodeTag.DROP) - - @property - def rename(self) -> Set[Tuple[SqlFluffTable, SqlFluffTable]]: - return { - (src, tgt) - for src, tgt, attr in self.graph.edges(data=True) - if attr.get("type") == EdgeType.RENAME - } - - def add_rename(self, src: SqlFluffTable, tgt: SqlFluffTable) -> None: - """ - Add rename of a source 'SqlFluffColumn' into a target 'SqlFluffColumn' - :param src: source 'SqlFluffTable' - :param tgt: target 'SqlFluffTable' - """ - self.graph.add_edge(src, tgt, type=EdgeType.RENAME) - - @staticmethod - def of(holder: SqlFluffSubQueryLineageHolder) -> "SqlFluffStatementLineageHolder": - """ - Build a 'SqlFluffStatementLineageHolder' object - :param holder: 'SqlFluffSubQueryLineageHolder' to hold lineage - :return: 'SqlFluffStatementLineageHolder' object - """ - stmt_holder = SqlFluffStatementLineageHolder() - stmt_holder.graph = holder.graph - return stmt_holder - - -class SqlFluffLineageHolder(SqlFluffColumnLineageMixin): - """ - Lineage Result - """ - - def __init__(self, graph: DiGraph): - """ - The combined lineage result in representation of Directed Acyclic Graph. - :param graph: the Directed Acyclic Graph holding all the combined lineage result. 
- """ - self.graph = graph - self._selfloop_tables = self.__retrieve_tag_tables(NodeTag.SELFLOOP) - self._sourceonly_tables = self.__retrieve_tag_tables(NodeTag.SOURCE_ONLY) - self._targetonly_tables = self.__retrieve_tag_tables(NodeTag.TARGET_ONLY) - - @property - def table_lineage_graph(self) -> DiGraph: - """ - :return the table level DiGraph held by 'SqlFluffLineageHolder' - """ - table_nodes = [n for n in self.graph.nodes if isinstance(n, DATASET_CLASSES)] - return self.graph.subgraph(table_nodes) - - @property - def column_lineage_graph(self) -> DiGraph: - """ - :return the column level DiGraph held by 'SqlFluffLineageHolder' - """ - column_nodes = [n for n in self.graph.nodes if isinstance(n, SqlFluffColumn)] - return self.graph.subgraph(column_nodes) - - @property - def source_tables(self) -> Set[SqlFluffTable]: - """ - :return a list of source 'SqlFluffTable' - """ - source_tables = { - table for table, deg in self.table_lineage_graph.in_degree if deg == 0 - }.intersection( - {table for table, deg in self.table_lineage_graph.out_degree if deg > 0} - ) - source_tables |= self._selfloop_tables - source_tables |= self._sourceonly_tables - return source_tables - - @property - def target_tables(self) -> Set[SqlFluffTable]: - """ - :return a list of target 'SqlFluffTable' - """ - target_tables = { - table for table, deg in self.table_lineage_graph.out_degree if deg == 0 - }.intersection( - {table for table, deg in self.table_lineage_graph.in_degree if deg > 0} - ) - target_tables |= self._selfloop_tables - target_tables |= self._targetonly_tables - return target_tables - - @property - def intermediate_tables(self) -> Set[SqlFluffTable]: - """ - :return a list of intermediate 'SqlFluffTable' - """ - intermediate_tables = { - table for table, deg in self.table_lineage_graph.in_degree if deg > 0 - }.intersection( - {table for table, deg in self.table_lineage_graph.out_degree if deg > 0} - ) - intermediate_tables -= self.__retrieve_tag_tables(NodeTag.SELFLOOP) - return intermediate_tables - - def __retrieve_tag_tables(self, tag) -> Set[Union[SqlFluffPath, SqlFluffTable]]: - return { - table - for table, attr in self.graph.nodes(data=True) - if attr.get(tag) is True and isinstance(table, DATASET_CLASSES) - } - - @staticmethod - def _get_column_if_related_to_parent( - g: DiGraph, - raw_name: str, - parent: Union[SqlFluffTable, SqlFluffSubQuery], - ) -> Optional[SqlFluffColumn]: - src_col = SqlFluffColumn(raw_name) - src_col.parent = parent - return src_col if g.has_edge(parent, src_col) else None - - @classmethod - def _build_digraph(cls, holders: List[SqlFluffStatementLineageHolder]) -> DiGraph: - """ - To assemble multiple 'SqlFluffStatementLineageHolder' into 'SqlFluffLineageHolder' - :param holders: a list of 'SqlFluffStatementLineageHolder' - :return: the DiGraph held - """ - g = DiGraph() - for holder in holders: - g = nx.compose(g, holder.graph) - if holder.drop: - for table in holder.drop: - if g.has_node(table) and g.degree[table] == 0: - g.remove_node(table) - elif holder.rename: - for (table_old, table_new) in holder.rename: - g = nx.relabel_nodes(g, {table_old: table_new}) - g.remove_edge(table_new, table_new) - if g.degree[table_new] == 0: - g.remove_node(table_new) - else: - read, write = holder.read, holder.write - if len(read) > 0 and len(write) == 0: - # source only table comes from SELECT statement - nx.set_node_attributes( - g, {table: True for table in read}, NodeTag.SOURCE_ONLY - ) - elif len(read) == 0 and len(write) > 0: - # target only table comes from case like: 1) 
INSERT/UPDATE constant values; 2) CREATE TABLE - nx.set_node_attributes( - g, {table: True for table in write}, NodeTag.TARGET_ONLY - ) - else: - for source, target in itertools.product(read, write): - g.add_edge(source, target, type=EdgeType.LINEAGE) - nx.set_node_attributes( - g, - {table: True for table in {e[0] for e in nx.selfloop_edges(g)}}, - NodeTag.SELFLOOP, - ) - # find all the columns that we can't assign accurately to a parent table (with multiple parent candidates) - unresolved_cols = [ - (s, t) - for s, t in g.edges - if isinstance(s, SqlFluffColumn) and len(s.parent_candidates) > 1 - ] - for unresolved_col, tgt_col in unresolved_cols: - # check if there's only one parent candidate contains the column with same name - src_cols = [] - for parent in unresolved_col.parent_candidates: - src_col = cls._get_column_if_related_to_parent( - g, unresolved_col.raw_name, parent - ) - if src_col: - src_cols.append(src_col) - if len(src_cols) == 1: - g.add_edge(src_cols[0], tgt_col, type=EdgeType.LINEAGE) - g.remove_edge(unresolved_col, tgt_col) - # when unresolved column got resolved, it will be orphan node, and we can remove it - for node in [n for n, deg in g.degree if deg == 0]: - if isinstance(node, SqlFluffColumn) and len(node.parent_candidates) > 1: - g.remove_node(node) - return g - - @staticmethod - def of(holders: List[SqlFluffStatementLineageHolder]): - """ - To assemble multiple 'SqlFluffStatementLineageHolder' into 'SqlFluffLineageHolder' - :param holders: a list of 'SqlFluffStatementLineageHolder' - :return: a 'SqlFluffLineageHolder' object - """ - g = SqlFluffLineageHolder._build_digraph(holders) - return SqlFluffLineageHolder(g) diff --git a/sqllineage/sqlfluff_core/models.py b/sqllineage/sqlfluff_core/models.py deleted file mode 100644 index ed8d648f..00000000 --- a/sqllineage/sqlfluff_core/models.py +++ /dev/null @@ -1,450 +0,0 @@ -import warnings -from typing import Dict, List, Set, Union -from typing import Optional, Tuple - -from sqlfluff.core.parser import BaseSegment - -from sqllineage.exceptions import SQLLineageException -from sqllineage.sqlfluff_core.utils.entities import SqlFluffColumnQualifierTuple -from sqllineage.sqlfluff_core.utils.sqlfluff import ( - get_identifier, - is_subquery, - is_wildcard, - retrieve_segments, - token_matching, -) -from sqllineage.utils.helpers import escape_identifier_name - -NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE = [ - "function", - "over_clause", - "partitionby_clause", - "orderby_clause", - "expression", - "case_expression", - "when_clause", - "else_clause", - "select_clause_element", -] - -SOURCE_COLUMN_SEGMENT_TYPE = NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE + [ - "identifier", - "column_reference", -] - - -class SqlFluffSchema: - """ - Data Class for Schema - """ - - unknown = "" - - def __init__(self, name: str = unknown): - """ - :param name: schema name - """ - self.raw_name = escape_identifier_name(name) - - def __str__(self): - return self.raw_name.lower() - - def __repr__(self): - return "Schema: " + str(self) - - def __eq__(self, other): - return type(self) is type(other) and str(self) == str(other) - - def __hash__(self): - return hash(str(self)) - - def __bool__(self): - return str(self) != self.unknown - - -class SqlFluffPath: - """ - Data Class for SqlFluffPath - """ - - def __init__(self, uri: str): - """ - :param uri: uri of the path - """ - self.uri = escape_identifier_name(uri) - - def __str__(self): - return self.uri - - def __repr__(self): - return "Path: " + str(self) - - def __eq__(self, other): - return type(self) 
is type(other) and self.uri == other.uri - - def __hash__(self): - return hash(self.uri) - - -class SqlFluffTable: - """ - Data Class for SqlFluffTable - """ - - def __init__(self, name: str, schema: SqlFluffSchema = SqlFluffSchema(), **kwargs): - """ - :param name: table name - :param schema: schema as defined by 'SqlFluffTable' - """ - if "." not in name: - self.schema = schema - self.raw_name = escape_identifier_name(name) - else: - schema_name, table_name = name.rsplit(".", 1) - if len(schema_name.split(".")) > 2: - # allow db.schema as schema_name, but a.b.c as schema_name is forbidden - raise SQLLineageException("Invalid format for table name: %s.", name) - self.schema = SqlFluffSchema(schema_name) - self.raw_name = escape_identifier_name(table_name) - if schema: - warnings.warn("Name is in schema.table format, schema param is ignored") - self.alias = kwargs.pop("alias", self.raw_name) - - def __str__(self): - return f"{self.schema}.{self.raw_name.lower()}" - - def __repr__(self): - return "Table: " + str(self) - - def __eq__(self, other): - return type(self) is type(other) and str(self) == str(other) - - def __hash__(self): - return hash(str(self)) - - @staticmethod - def of(table_segment: BaseSegment, alias: Optional[str] = None) -> "SqlFluffTable": - """ - Build an object of type 'SqlFluffTable' - :param table_segment: table segment to be processed - :param alias: alias of the table segment - :return: 'SqlFluffTable' object - """ - # rewrite identifier's get_real_name method, by matching the last dot instead of the first dot, so that the - # real name for a.b.c will be c instead of b - dot_idx, _ = token_matching( - table_segment, - (lambda s: bool(s.type == "symbol"),), - start=len(table_segment.segments), - reverse=True, - ) - real_name = ( - table_segment.segments[dot_idx + 1].raw - if dot_idx - else ( - table_segment.raw - if table_segment.type == "identifier" - else table_segment.segments[0].raw - ) - ) - # rewrite identifier's get_parent_name accordingly - parent_name = ( - "".join( - [ - escape_identifier_name(segment.raw) - for segment in table_segment.segments[:dot_idx] - ] - ) - if dot_idx - else None - ) - schema = ( - SqlFluffSchema(parent_name) if parent_name is not None else SqlFluffSchema() - ) - kwargs = {"alias": alias} if alias else {} - return SqlFluffTable(real_name, schema, **kwargs) - - -class SqlFluffSubQuery: - """ - Data Class for SqlFluffSubQuery - """ - - def __init__(self, subquery: BaseSegment, alias: Optional[str]): - """ - :param subquery: subquery segment - :param alias: subquery alias - """ - self.segment = subquery - self._query = subquery.raw - self.alias = alias if alias is not None else f"subquery_{hash(self)}" - - def __str__(self): - return self.alias - - def __repr__(self): - return "SubQuery: " + str(self) - - def __eq__(self, other): - return type(self) is type(other) and self._query == other._query - - def __hash__(self): - return hash(self._query) - - @staticmethod - def of(subquery: BaseSegment, alias: Optional[str]) -> "SqlFluffSubQuery": - """ - Build a 'SqlFluffSubQuery' object - :param subquery: subquery segment - :param alias: subquery alias - :return: 'SqlFluffSubQuery' object - """ - return SqlFluffSubQuery(subquery, alias) - - -class SqlFluffColumn: - """ - Data Class for SqlFluffColumn - """ - - def __init__(self, name: str, **kwargs): - """ - :param name: column name - :param parent: 'SqlFluffSubQuery' or 'SqlFluffTable' object - :param kwargs: - """ - self._parent: Set[Union[SqlFluffTable, SqlFluffSubQuery]] = set() - 
self.raw_name = escape_identifier_name(name) - self.source_columns = kwargs.pop("source_columns", ((self.raw_name, None),)) - - def __str__(self): - return ( - f"{self.parent}.{self.raw_name.lower()}" - if self.parent is not None and not isinstance(self.parent, SqlFluffPath) - else f"{self.raw_name.lower()}" - ) - - def __repr__(self): - return "Column: " + str(self) - - def __eq__(self, other): - return type(self) is type(other) and str(self) == str(other) - - def __hash__(self): - return hash(str(self)) - - @property - def parent(self) -> Optional[Union[SqlFluffTable, SqlFluffSubQuery]]: - """ - :return: parent of the table - """ - return list(self._parent)[0] if len(self._parent) == 1 else None - - @parent.setter - def parent(self, value: Union[SqlFluffTable, SqlFluffSubQuery]): - self._parent.add(value) - - @property - def parent_candidates(self) -> List[Union[SqlFluffTable, SqlFluffSubQuery]]: - """ - :return: parent candidate list - """ - return sorted(self._parent, key=lambda p: str(p)) - - @staticmethod - def of(column_segment: BaseSegment, dialect: str) -> "SqlFluffColumn": - """ - Build a 'SqlFluffSubQuery' object - :param column_segment: column segment - :param dialect: dialect to be used in case of running the 'LineageRunner' - :return: - """ - if column_segment.type == "select_clause_element": - source_columns, alias = SqlFluffColumn._get_column_and_alias( - column_segment, dialect - ) - if alias: - return SqlFluffColumn( - alias, - source_columns=source_columns, - ) - if source_columns: - sub_segments = retrieve_segments(column_segment) - column_name = None - for sub_segment in sub_segments: - if sub_segment.type == "column_reference": - column_name = get_identifier(sub_segment) - - return SqlFluffColumn( - column_segment.raw if column_name is None else column_name, - source_columns=source_columns, - ) - - # Wildcard, Case, Function without alias (thus not recognized as an Identifier) - source_columns = SqlFluffColumn._extract_source_columns(column_segment, dialect) - return SqlFluffColumn( - column_segment.raw, - source_columns=source_columns, - ) - - @staticmethod - def _extract_source_columns( - segment: BaseSegment, dialect: str - ) -> List[SqlFluffColumnQualifierTuple]: - """ - - :param segment: - :param dialect: - :return: - """ - if segment.type == "identifier" or is_wildcard(segment): - return [SqlFluffColumnQualifierTuple(segment.raw, None)] - if segment.type == "column_reference": - parent, column = SqlFluffColumn._get_column_and_parent(segment) - return [SqlFluffColumnQualifierTuple(column, parent)] - if segment.type in NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE: - sub_segments = retrieve_segments(segment) - col_list = [] - for sub_segment in sub_segments: - if sub_segment.type == "bracketed": - if is_subquery(sub_segment): - col_list += SqlFluffColumn._get_column_from_subquery( - sub_segment, dialect - ) - else: - col_list += SqlFluffColumn._get_column_from_parenthesis( - sub_segment, dialect - ) - elif sub_segment.type in SOURCE_COLUMN_SEGMENT_TYPE or is_wildcard( - sub_segment - ): - res = SqlFluffColumn._extract_source_columns(sub_segment, dialect) - col_list.extend(res) - return col_list - return [] - - @staticmethod - def _get_column_from_subquery( - sub_segment: BaseSegment, dialect: str - ) -> List[SqlFluffColumnQualifierTuple]: - """ - - :param sub_segment: - :param dialect: - :return: - """ - # This is to avoid circular import - from sqllineage.runner import LineageRunner - - src_cols = [ - lineage[0] - for lineage in LineageRunner( - sub_segment.raw, 
dialect=dialect, use_sqlparse=False - ).get_column_lineage(exclude_subquery=False) - ] - source_columns = [ - SqlFluffColumnQualifierTuple(src_col.raw_name, src_col.parent.raw_name) - for src_col in src_cols - ] - return source_columns - - @staticmethod - def _get_column_from_parenthesis( - sub_segment: BaseSegment, - dialect: str, - ) -> List[SqlFluffColumnQualifierTuple]: - """ - - :param sub_segment: - :param dialect: - :return: - """ - col, _ = SqlFluffColumn._get_column_and_alias(sub_segment, dialect) - if col: - return col - col, _ = SqlFluffColumn._get_column_and_alias(sub_segment, dialect, False) - return col if col else [] - - @staticmethod - def _get_column_and_alias( - segment: BaseSegment, dialect: str, check_bracketed: bool = True - ) -> Tuple[List[SqlFluffColumnQualifierTuple], Optional[str]]: - alias = None - columns = [] - sub_segments = retrieve_segments(segment, check_bracketed) - for sub_segment in sub_segments: - if sub_segment.type == "alias_expression": - alias = get_identifier(sub_segment) - elif sub_segment.type in SOURCE_COLUMN_SEGMENT_TYPE or is_wildcard( - sub_segment - ): - res = SqlFluffColumn._extract_source_columns(sub_segment, dialect) - columns += res if res else [] - - return columns, alias - - @staticmethod - def _get_column_and_parent(col_segment: BaseSegment) -> Tuple[Optional[str], str]: - identifiers = retrieve_segments(col_segment) - if len(identifiers) > 1: - return identifiers[-2].raw, identifiers[-1].raw - return None, identifiers[-1].raw - - def to_source_columns( - self, alias_mapping: Dict[str, Union[SqlFluffTable, SqlFluffSubQuery]] - ): - """ - Best guess for source table given all the possible table/subquery and their alias. - """ - - def _to_src_col( - name: str, parent: Optional[Union[SqlFluffTable, SqlFluffSubQuery]] = None - ): - col = SqlFluffColumn(name) - if parent: - col.parent = parent - return col - - source_columns = set() - for (src_col, qualifier) in self.source_columns: - if qualifier is None: - if src_col == "*": - # select * - for table in set(alias_mapping.values()): - source_columns.add(_to_src_col(src_col, table)) - else: - # select unqualified column - src_col = _to_src_col(src_col, None) - for table in set(alias_mapping.values()): - # in case of only one table, we get the right answer - # in case of multiple tables, a bunch of possible tables are set - src_col.parent = table - source_columns.add(src_col) - else: - if alias_mapping.get(qualifier): - source_columns.add( - _to_src_col(src_col, alias_mapping.get(qualifier)) - ) - else: - source_columns.add(_to_src_col(src_col, SqlFluffTable(qualifier))) - return source_columns - - -class SqlFluffAnalyzerContext: - """ - Data class to hold the analyzer context - """ - - subquery: Optional[SqlFluffSubQuery] - prev_cte: Optional[Set[SqlFluffSubQuery]] - prev_write: Optional[Set[Union[SqlFluffSubQuery, SqlFluffTable]]] - - def __init__( - self, - subquery: Optional[SqlFluffSubQuery] = None, - prev_cte: Optional[Set[SqlFluffSubQuery]] = None, - prev_write: Optional[Set[Union[SqlFluffSubQuery, SqlFluffTable]]] = None, - ): - self.subquery = subquery - self.prev_cte = prev_cte - self.prev_write = prev_write diff --git a/sqllineage/sqlfluff_core/subquery/ddl_drop_extractor.py b/sqllineage/sqlfluff_core/subquery/ddl_drop_extractor.py deleted file mode 100644 index af5d07ae..00000000 --- a/sqllineage/sqlfluff_core/subquery/ddl_drop_extractor.py +++ /dev/null @@ -1,51 +0,0 @@ -from sqlfluff.core.parser import BaseSegment - -from sqllineage.sqlfluff_core.holders import ( - 
SqlFluffStatementLineageHolder, - SqlFluffSubQueryLineageHolder, -) -from sqllineage.sqlfluff_core.models import SqlFluffAnalyzerContext -from sqllineage.sqlfluff_core.models import SqlFluffTable -from sqllineage.sqlfluff_core.subquery.lineage_holder_extractor import ( - LineageHolderExtractor, -) - - -class DdlDropExtractor(LineageHolderExtractor): - """ - DDL Drop queries lineage extractor - """ - - DDL_DROP_STMT_TYPES = ["drop_table_statement"] - - def __init__(self, dialect: str): - super().__init__(dialect) - - def can_extract(self, statement_type: str) -> bool: - """ - Determine if the current lineage holder extractor can process the statement - :param statement_type: a sqlfluff segment type - """ - return statement_type in self.DDL_DROP_STMT_TYPES - - def extract( - self, - statement: BaseSegment, - context: SqlFluffAnalyzerContext, - is_sub_query: bool = False, - ) -> SqlFluffSubQueryLineageHolder: - """ - Extract lineage for a given statement. - :param statement: a sqlfluff segment with a statement - :param context: 'SqlFluffAnalyzerContext' - :param is_sub_query: determine if the statement is bracketed or not - :return 'SqlFluffSubQueryLineageHolder' object - """ - holder = SqlFluffStatementLineageHolder() - for table in { - SqlFluffTable.of(t) - for t in statement.segments - if t.type == "table_reference" - }: - holder.add_drop(table) - return holder diff --git a/sqllineage/sqlfluff_core/utils/entities.py b/sqllineage/sqlfluff_core/utils/entities.py deleted file mode 100644 index 43efa824..00000000 --- a/sqllineage/sqlfluff_core/utils/entities.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import NamedTuple, Optional, Union - -from sqlfluff.core.parser.segments import BaseSegment - - -class SubSqlFluffQueryTuple(NamedTuple): - """ - Tuple of segment and optional alias - """ - - bracketed: Union[BaseSegment] - alias: Optional[str] - - -class SqlFluffColumnQualifierTuple(NamedTuple): - """ - Tuple of column name and qualifier - """ - - column: str - qualifier: Optional[str] diff --git a/sqllineage/utils/entities.py b/sqllineage/utils/entities.py index 41202b5a..15363906 100644 --- a/sqllineage/utils/entities.py +++ b/sqllineage/utils/entities.py @@ -1,18 +1,11 @@ -from typing import NamedTuple, Optional - -from sqlparse.sql import Parenthesis, Token +from typing import Any, NamedTuple, Optional class SubQueryTuple(NamedTuple): - parenthesis: Parenthesis - alias: str + parenthesis: Any + alias: Optional[str] class ColumnQualifierTuple(NamedTuple): column: str qualifier: Optional[str] - - -class ColumnExpression(NamedTuple): - is_identity: bool - token: Token diff --git a/sqllineage/utils/helpers.py b/sqllineage/utils/helpers.py index 2a6f8a68..9f1f7c9d 100644 --- a/sqllineage/utils/helpers.py +++ b/sqllineage/utils/helpers.py @@ -1,6 +1,6 @@ import logging -import re from argparse import Namespace +from typing import List logger = logging.getLogger(__name__) @@ -30,27 +30,16 @@ def extract_sql_from_args(args: Namespace) -> str: return sql -def clean_parentheses(stmt: str) -> str: - """ - Clean redundant parentheses from a SQL statement e.g: - `SELECT col1 FROM (((((((SELECT col1 FROM tab1))))))) dt` - will be: - `SELECT col1 FROM (SELECT col1 FROM tab1) dt` +def split(sql: str) -> List[str]: + # TODO: we need a parser independent split function + import sqlparse - :param stmt: a SQL str to be cleaned - """ - redundant_parentheses = r"\(\(([^()]+)\)\)" - if re.findall(redundant_parentheses, stmt): - stmt = re.sub(redundant_parentheses, r"(\1)", stmt) - stmt = clean_parentheses(stmt) 
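-        # each re.sub pass strips one layer of doubled wrapping, e.g.
-        # '(((SELECT 1)))' -> '((SELECT 1))' -> '(SELECT 1)'; the recursion
-        # stops once no '((...))' pair with a parenthesis-free interior remains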
- return stmt + # sometimes sqlparse split out a statement that is comment only, we want to exclude that + return [s.value for s in sqlparse.parse(sql) if s.token_first(skip_cm=True)] -def is_subquery_statement(stmt: str) -> bool: - parentheses_regex = r"^\(.*\)" - return bool(re.match(parentheses_regex, stmt)) +def trim_comment(sql: str) -> str: + # TODO: we need a parser independent trim_comment function + import sqlparse - -def remove_statement_parentheses(stmt: str) -> str: - parentheses_regex = r"^\((.*)\)" - return re.sub(parentheses_regex, r"\1", stmt) + return str(sqlparse.format(sql, strip_comments=True)) diff --git a/sqllineagejs/src/App.js b/sqllineagejs/src/App.js index 413c6fe8..70c62208 100644 --- a/sqllineagejs/src/App.js +++ b/sqllineagejs/src/App.js @@ -1,15 +1,32 @@ import React, {useMemo} from 'react'; -import {Box, Drawer, FormControl, FormControlLabel, Grid, Paper, Radio, RadioGroup, Tooltip} from "@material-ui/core"; +import { + AppBar, + Box, + Button, + Drawer, + Fade, + FormControl, + FormControlLabel, + Grid, + ListSubheader, + Menu, + MenuItem, + Paper, + Radio, + RadioGroup, + Toolbar, + Tooltip, + Typography +} from "@material-ui/core"; import {DAG} from "./features/editor/DAG"; import {Editor} from "./features/editor/Editor"; import {makeStyles} from "@material-ui/core/styles"; -import AppBar from "@material-ui/core/AppBar"; -import Toolbar from "@material-ui/core/Toolbar"; +import ChevronLeftIcon from '@material-ui/icons/ChevronLeft'; +import CreateIcon from "@material-ui/icons/Create"; +import ExpandMoreIcon from '@material-ui/icons/ExpandMore'; import IconButton from "@material-ui/core/IconButton"; +import LanguageIcon from '@material-ui/icons/Language'; import MenuIcon from "@material-ui/icons/Menu"; -import CreateIcon from "@material-ui/icons/Create"; -import Typography from "@material-ui/core/Typography"; -import ChevronLeftIcon from '@material-ui/icons/ChevronLeft'; import clsx from "clsx"; import {Directory} from "./features/directory/Directory"; import {BrowserRouter as Router, Link} from "react-router-dom"; @@ -55,25 +72,60 @@ const useStyles = makeStyles((theme) => ({ left: ({drawerWidth}) => drawerWidth + "vw", backgroundColor: "transparent", zIndex: 999 + }, + dialect: { + margin: theme.spacing(0, 0.5, 0, 1), + display: 'none', + [theme.breakpoints.up('md')]: { + display: 'block', + }, } })); let isResizing = null; +const dialects = { + "sqlparse": [ + "non-validating" + ], + "sqlfluff": [ + "ansi", + "athena", + "bigquery", + "clickhouse", + "db2", + "exasol", + "hive", + "materialize", + "mysql", + "oracle", + "postgres", + "redshift", + "snowflake", + "soql", + "sparksql", + "sqlite", + "teradata", + "tsql" + ] +} + export default function App() { const editorState = useSelector(selectEditor); - const [selectedValue, setSelectedValue] = React.useState('dag'); - const [open, setOpen] = React.useState(true); + const [viewSelected, setViewSelected] = React.useState('dag'); + const [drawerOpen, setDrawerOpen] = React.useState(true); const [drawerWidth, setDrawerWidth] = React.useState(18); + const [dialectMenuAnchor, setDialectMenuAnchor] = React.useState(null); + const [dialectSelected, setDialectSelected] = React.useState("ansi"); const classes = useStyles({drawerWidth: drawerWidth}); const height = "90vh"; const width = useMemo(() => { let full_width = 100; - return (open ? full_width - drawerWidth : full_width) + "vw" - }, [open, drawerWidth]) + return (drawerOpen ? 
full_width - drawerWidth : full_width) + "vw" + }, [drawerOpen, drawerWidth]) const handleMouseDown = e => { e.stopPropagation(); @@ -103,6 +155,14 @@ export default function App() { document.removeEventListener("mouseup", handleMouseUp); } + const handleDialectMenuClose = (e) => { + let dialect = e.currentTarget.outerText; + if (dialect !== "") { + setDialectSelected(dialect) + } + setDialectMenuAnchor(null); + } + return (
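(Context for the frontend change above: the dialect chosen through this menu is sent to the backend as the "dialect" parameter of fetchDAG and ultimately drives LineageRunner. A minimal Python-side sketch of that API, using only names visible elsewhere in this diff -- illustrative, not part of App.js:

    from sqllineage import SQLPARSE_DIALECT
    from sqllineage.runner import LineageRunner

    sql = "INSERT INTO tab1 SELECT col1 FROM tab2"

    # "non-validating" selects the legacy sqlparse-based parser ...
    LineageRunner(sql, dialect=SQLPARSE_DIALECT).print_table_lineage()

    # ... while any dialect from the sqlfluff list, e.g. "ansi", selects sqlfluff
    LineageRunner(sql, dialect="ansi").print_column_lineage()
)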
@@ -115,14 +175,51 @@ export default function App() {
             color="inherit"
             aria-label="menu"
             onClick={() => {
-              setOpen(!open)
+              setDrawerOpen(!drawerOpen)
             }}
           >
-            {open ? <ChevronLeftIcon/> : <MenuIcon/>}
+            {drawerOpen ? <ChevronLeftIcon/> : <MenuIcon/>}
           </IconButton>
  [Remainder of this hunk garbled in extraction (JSX tags stripped). Recoverable
  structure: after the "SQLLineage" title, the hunk adds a dialect selector to the
  AppBar -- a Button (LanguageIcon, the current dialectSelected label styled with
  classes.dialect, ExpandMoreIcon) opening a Fade-transitioned Menu anchored at
  dialectMenuAnchor; the Menu body maps Object.entries(dialects), rendering each
  parser name entry[0] as a ListSubheader and each dialect in entry[1] as a
  MenuItem whose click is handled by handleDialectMenuClose. The pre-existing
  "Composing Mode" indicator, shown when editorState.editable, follows unchanged.]
@@ -132,8 +229,8 @@ export default function App() {
             onClick={() => {
-              setSelectedValue("script");
-              setOpen(false);
+              setViewSelected("script");
+              setDrawerOpen(false);
             }}
           >
@@ -145,7 +242,7 @@
- + - + - - + + setSelectedValue(event.target.value)}> + value={viewSelected} + onChange={(event) => setViewSelected(event.target.value)}> } diff --git a/sqllineagejs/src/features/editor/Editor.js b/sqllineagejs/src/features/editor/Editor.js index 7f89f0c7..8c4ee8ab 100644 --- a/sqllineagejs/src/features/editor/Editor.js +++ b/sqllineagejs/src/features/editor/Editor.js @@ -7,7 +7,8 @@ import { setContentComposed, setDagLevel, setEditable, - setFile + setFile, + setDialect } from "./editorSlice"; import MonacoEditor from "react-monaco-editor"; import {Loading} from "../widget/Loading"; @@ -19,6 +20,7 @@ const useQueryParam = () => { }; export function Editor(props) { + const { height, width, dialect } = props; const dispatch = useDispatch(); const editorState = useSelector(selectEditor); const queryParam = useQueryParam(); @@ -31,21 +33,20 @@ export function Editor(props) { history.push("/"); } else { let file = queryParam.get("f"); - if (editorState.file !== file) { + if (editorState.file !== file || editorState.dialect !== dialect) { dispatch(setFile(file)); + dispatch(setDialect(dialect)); dispatch(setDagLevel("table")); if (file === null) { dispatch(setEditable(true)); - dispatch(fetchDAG({"e": editorState.contentComposed})) + dispatch(fetchDAG({"e": editorState.contentComposed, "dialect": dialect})) } else { dispatch(setEditable(false)); dispatch(fetchContent({"f": file})); - dispatch(fetchDAG({"f": file})); + dispatch(fetchDAG({"f": file, "dialect": dialect})); } } } - - }) const handleEditorDidMount = (editor, monaco) => { @@ -53,7 +54,7 @@ export function Editor(props) { editor.onDidBlurEditorText(() => { if (!editor.getOption(readOnly)) { dispatch(setContentComposed(editor.getValue())); - dispatch(fetchDAG({"e": editor.getValue()})); + dispatch(fetchDAG({"e": editor.getValue(), "dialect": dialect})); } }) editor.onKeyDown(() => { @@ -66,9 +67,9 @@ export function Editor(props) { } if (editorState.editorStatus === "loading") { - return + return } else if (editorState.editorStatus === "failed") { - return + return } else { const options = { minimap: {enabled: false}, @@ -77,8 +78,8 @@ export function Editor(props) { automaticLayout: true } return state.editor; -export const {setContentComposed, setDagLevel, setEditable, setFile} = editorSlice.actions; +export const {setContentComposed, setDagLevel, setEditable, setFile, setDialect} = editorSlice.actions; export default editorSlice.reducer; diff --git a/tests/helpers.py b/tests/helpers.py index 95f95107..063acfe3 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,19 +1,12 @@ -from typing import Union - import networkx as nx +from sqllineage import SQLPARSE_DIALECT from sqllineage.core.models import Column, Table from sqllineage.runner import LineageRunner -from sqllineage.sqlfluff_core.models import SqlFluffColumn, SqlFluffTable -def assert_table_lineage( - lr: LineageRunner, - source_tables=None, - target_tables=None, - test_sqlfluff: bool = False, -): - for (_type, actual, expected) in zip( +def assert_table_lineage(lr: LineageRunner, source_tables=None, target_tables=None): + for _type, actual, expected in zip( ["Source", "Target"], [lr.source_tables, lr.target_tables], [source_tables, target_tables], @@ -22,39 +15,22 @@ def assert_table_lineage( expected = ( set() if expected is None - else { - (SqlFluffTable(t) if test_sqlfluff else Table(t)) - if isinstance(t, str) - else t - for t in expected - } + else {Table(t) if isinstance(t, str) else t for t in expected} ) assert ( actual == expected ), f"\n\tExpected {_type} 
Table: {expected}\n\tActual {_type} Table: {actual}" -def assert_column_lineage( - lr: LineageRunner, column_lineages=None, test_sqlfluff: bool = False -): +def assert_column_lineage(lr: LineageRunner, column_lineages=None): expected = set() if column_lineages: for src, tgt in column_lineages: - src_col: Union[SqlFluffColumn, Column] = ( - SqlFluffColumn(src.column) if test_sqlfluff else Column(src.column) - ) + src_col: Column = Column(src.column) if src.qualifier is not None: - src_col.parent = ( - SqlFluffTable(src.qualifier) - if test_sqlfluff - else Table(src.qualifier) - ) - tgt_col: Union[SqlFluffColumn, Column] = ( - SqlFluffColumn(tgt.column) if test_sqlfluff else Column(tgt.column) - ) - tgt_col.parent = ( - SqlFluffTable(tgt.qualifier) if test_sqlfluff else Table(tgt.qualifier) - ) + src_col.parent = Table(src.qualifier) + tgt_col: Column = Column(tgt.column) + tgt_col.parent = Table(tgt.qualifier) expected.add((src_col, tgt_col)) actual = {(lineage[0], lineage[-1]) for lineage in set(lr.get_column_lineage())} @@ -71,15 +47,14 @@ def assert_table_lineage_equal( test_sqlfluff: bool = True, test_sqlparse: bool = True, ): + lr = LineageRunner(sql, dialect=SQLPARSE_DIALECT) + lr_sqlfluff = LineageRunner(sql, dialect=dialect) if test_sqlparse: - lr = LineageRunner(sql) assert_table_lineage(lr, source_tables, target_tables) - if test_sqlfluff: - lr_sqlfluff = LineageRunner(sql, dialect=dialect, use_sqlparse=False) - assert_table_lineage(lr_sqlfluff, source_tables, target_tables, test_sqlfluff) - if test_sqlparse: - assert_lr_graphs_match(lr, lr_sqlfluff) + assert_table_lineage(lr_sqlfluff, source_tables, target_tables) + if test_sqlparse and test_sqlfluff: + assert_lr_graphs_match(lr, lr_sqlfluff) def assert_column_lineage_equal( @@ -89,15 +64,14 @@ def assert_column_lineage_equal( test_sqlfluff: bool = True, test_sqlparse: bool = True, ): + lr = LineageRunner(sql, dialect=SQLPARSE_DIALECT) + lr_sqlfluff = LineageRunner(sql, dialect=dialect) if test_sqlparse: - lr = LineageRunner(sql) assert_column_lineage(lr, column_lineages) - if test_sqlfluff: - lr_sqlfluff = LineageRunner(sql, dialect=dialect, use_sqlparse=False) - assert_column_lineage(lr_sqlfluff, column_lineages, test_sqlfluff) - if test_sqlparse: - assert_lr_graphs_match(lr, lr_sqlfluff) + assert_column_lineage(lr_sqlfluff, column_lineages) + if test_sqlparse and test_sqlfluff: + assert_lr_graphs_match(lr, lr_sqlfluff) def assert_lr_graphs_match(lr: LineageRunner, lr_sqlfluff: LineageRunner) -> None: diff --git a/tests/test_cli.py b/tests/test_cli.py index de9b79d1..85ab66f6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,8 +1,10 @@ -import pathlib +import os +from pathlib import Path from unittest.mock import patch import pytest +from sqllineage import DATA_FOLDER from sqllineage.cli import main @@ -10,10 +12,14 @@ def test_cli_dummy(_): main([]) main(["-e", "select * from dual"]) - main(["-f", __file__]) main(["-e", "insert into foo select * from dual", "-l", "column"]) - main(["-e", "select * from dual", "-f", __file__]) - main(["-f", __file__, "-g"]) + for dirname, _, files in os.walk(DATA_FOLDER): + if len(files) > 0: + sql_file = str(Path(dirname).joinpath(Path(files[0]))) + main(["-f", sql_file]) + main(["-e", "select * from dual", "-f", sql_file]) + main(["-f", sql_file, "-g"]) + break main(["-g"]) main( [ @@ -25,7 +31,7 @@ def test_cli_dummy(_): def test_file_exception(): - for args in (["-f", str(pathlib.Path().absolute())], ["-f", "nonexist_file"]): + for args in (["-f", str(Path().absolute())], ["-f", 
"nonexist_file"]): with pytest.raises(SystemExit) as e: main(args) assert e.value.code == 1 diff --git a/tests/test_columns.py b/tests/test_columns.py index 09fa8c5a..29e83f25 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -1,47 +1,44 @@ import pytest +from sqllineage import SQLPARSE_DIALECT from sqllineage.runner import LineageRunner from sqllineage.utils.entities import ColumnQualifierTuple from .helpers import assert_column_lineage_equal def test_select_column(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 AS col2 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col2", "tab1"))], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT tab2.col1 AS col2 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col2", "tab1"))], - "sparksql", ) def test_select_column_wildcard(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT * FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("*", "tab2"), ColumnQualifierTuple("*", "tab1"))], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT * FROM tab2 a INNER JOIN tab3 b @@ -52,12 +49,11 @@ def test_select_column_wildcard(): (ColumnQualifierTuple("*", "tab2"), ColumnQualifierTuple("*", "tab1")), (ColumnQualifierTuple("*", "tab3"), ColumnQualifierTuple("*", "tab1")), ], - "sparksql", ) def test_select_column_using_function(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT max(col1), count(*) FROM tab2""" @@ -73,9 +69,8 @@ def test_select_column_using_function(): ColumnQualifierTuple("count(*)", "tab1"), ), ], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT max(col1) AS col2, count(*) AS cnt FROM tab2""" @@ -88,9 +83,8 @@ def test_select_column_using_function(): ), (ColumnQualifierTuple("*", "tab2"), ColumnQualifierTuple("cnt", "tab1")), ], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT cast(col1 as timestamp) FROM tab2""" assert_column_lineage_equal( @@ -101,20 +95,18 @@ def test_select_column_using_function(): ColumnQualifierTuple("cast(col1 as timestamp)", "tab1"), ) ], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT cast(col1 as timestamp) as col2 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col2", "tab1"))], - "sparksql", ) def test_select_column_using_function_with_complex_parameter(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT if(col1 = 'foo' AND col2 = 'bar', 1, 0) AS flag FROM tab2""" assert_column_lineage_equal( @@ -129,12 +121,11 @@ def test_select_column_using_function_with_complex_parameter(): ColumnQualifierTuple("flag", "tab1"), ), ], - "sparksql", ) def test_select_column_using_window_function(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT row_number() OVER (PARTITION BY col1 ORDER BY col2 DESC) AS rnum FROM tab2""" assert_column_lineage_equal( @@ -149,12 +140,11 @@ def test_select_column_using_window_function(): ColumnQualifierTuple("rnum", "tab1"), ), ], - 
"sparksql", ) def test_select_column_using_window_function_with_parameters(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col0, max(col3) OVER (PARTITION BY col1 ORDER BY col2 DESC) AS rnum, col4 @@ -183,12 +173,11 @@ def test_select_column_using_window_function_with_parameters(): ColumnQualifierTuple("col4", "tab1"), ), ], - "sparksql", ) def test_select_column_using_expression(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 + col2 FROM tab2""" assert_column_lineage_equal( @@ -203,9 +192,8 @@ def test_select_column_using_expression(): ColumnQualifierTuple("col1 + col2", "tab1"), ), ], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 + col2 AS col3 FROM tab2""" assert_column_lineage_equal( @@ -220,12 +208,11 @@ def test_select_column_using_expression(): ColumnQualifierTuple("col3", "tab1"), ), ], - "sparksql", ) def test_select_column_using_expression_in_parenthesis(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT (col1 + col2) AS col3 FROM tab2""" assert_column_lineage_equal( @@ -240,12 +227,11 @@ def test_select_column_using_expression_in_parenthesis(): ColumnQualifierTuple("col3", "tab1"), ), ], - "sparksql", ) def test_select_column_using_boolean_expression_in_parenthesis(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT (col1 > 0 AND col2 > 0) AS col3 FROM tab2""" assert_column_lineage_equal( @@ -260,12 +246,11 @@ def test_select_column_using_boolean_expression_in_parenthesis(): ColumnQualifierTuple("col3", "tab1"), ), ], - "sparksql", ) def test_select_column_using_expression_with_table_qualifier_without_column_alias(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT a.col1 + a.col2 + a.col3 + a.col4 FROM tab2 a""" assert_column_lineage_equal( @@ -288,12 +273,11 @@ def test_select_column_using_expression_with_table_qualifier_without_column_alia ColumnQualifierTuple("a.col1 + a.col2 + a.col3 + a.col4", "tab1"), ), ], - "sparksql", ) def test_select_column_using_case_when(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT CASE WHEN col1 = 1 THEN 'V1' WHEN col1 = 2 THEN 'V2' END FROM tab2""" assert_column_lineage_equal( @@ -306,20 +290,18 @@ def test_select_column_using_case_when(): ), ), ], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT CASE WHEN col1 = 1 THEN 'V1' WHEN col1 = 2 THEN 'V2' END AS col2 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col2", "tab1"))], - "sparksql", ) def test_select_column_using_case_when_with_subquery(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT CASE WHEN (SELECT avg(col1) FROM tab3) > 0 AND col2 = 1 THEN (SELECT avg(col1) FROM tab3) ELSE 0 END AS col1 FROM tab4""" assert_column_lineage_equal( @@ -334,12 +316,11 @@ def test_select_column_using_case_when_with_subquery(): ColumnQualifierTuple("col1", "tab1"), ), ], - "sparksql", ) def test_select_column_using_multiple_case_when_with_subquery(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT CASE WHEN (SELECT avg(col1) FROM tab3) > 0 AND col2 = 1 THEN (SELECT avg(col1) FROM tab3) WHEN (SELECT avg(col1) FROM tab3) > 0 AND col2 = 1 THEN (SELECT avg(col1) FROM tab5) ELSE 0 END AS col1 @@ -360,31 +341,28 @@ def test_select_column_using_multiple_case_when_with_subquery(): ColumnQualifierTuple("col1", "tab1"), ), ], - "sparksql", ) def 
test_select_column_with_table_qualifier(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT tab2.col1 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT t.col1 FROM tab2 AS t""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) def test_select_columns(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1, col2 FROM tab2""" @@ -400,9 +378,8 @@ def test_select_columns(): ColumnQualifierTuple("col2", "tab1"), ), ], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT max(col1), max(col2) FROM tab2""" @@ -418,50 +395,45 @@ def test_select_columns(): ColumnQualifierTuple("max(col2)", "tab1"), ), ], - "sparksql", ) def test_select_column_in_subquery(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM (SELECT col1 FROM tab2) dt""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM (SELECT col1, col2 FROM tab2) dt""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM (SELECT col1 FROM tab2)""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) def test_select_column_in_subquery_with_two_parenthesis(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM ((SELECT col1 FROM tab2)) dt""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) def test_select_column_in_subquery_with_two_parenthesis_and_blank_in_between(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM ( (SELECT col1 FROM tab2) @@ -469,62 +441,57 @@ def test_select_column_in_subquery_with_two_parenthesis_and_blank_in_between(): assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) def test_select_column_in_subquery_with_two_parenthesis_and_union(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM ( (SELECT col1 FROM tab2) UNION ALL (SELECT col1 FROM tab3) ) dt""" - expected_columns_lineage = [ - ( - ColumnQualifierTuple("col1", "tab2"), - ColumnQualifierTuple("col1", "tab1"), - ), - ( - ColumnQualifierTuple("col1", "tab3"), - ColumnQualifierTuple("col1", "tab1"), - ), - ] - assert_column_lineage_equal(sql, expected_columns_lineage, test_sqlfluff=False) - # graph are not compared because UNION/UNION ALL is handled different in FROM clause assert_column_lineage_equal( - sql, expected_columns_lineage, dialect="sparksql", test_sqlparse=False + sql, + [ + ( + ColumnQualifierTuple("col1", "tab2"), + ColumnQualifierTuple("col1", "tab1"), + ), + ( + ColumnQualifierTuple("col1", "tab3"), + ColumnQualifierTuple("col1", "tab1"), + ), + ], ) def test_select_column_in_subquery_with_two_parenthesis_and_union_v2(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM ( SELECT col1 FROM tab2 UNION ALL SELECT 
col1 FROM tab3 ) dt""" - expected_columns_lineage = [ - ( - ColumnQualifierTuple("col1", "tab2"), - ColumnQualifierTuple("col1", "tab1"), - ), - ( - ColumnQualifierTuple("col1", "tab3"), - ColumnQualifierTuple("col1", "tab1"), - ), - ] - assert_column_lineage_equal(sql, expected_columns_lineage, test_sqlfluff=False) - # graph are not compared because UNION/UNION ALL is handled different in FROM clause assert_column_lineage_equal( - sql, expected_columns_lineage, dialect="sparksql", test_sqlparse=False + sql, + [ + ( + ColumnQualifierTuple("col1", "tab2"), + ColumnQualifierTuple("col1", "tab1"), + ), + ( + ColumnQualifierTuple("col1", "tab3"), + ColumnQualifierTuple("col1", "tab1"), + ), + ], ) def test_select_column_from_table_join(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT tab2.col1, tab3.col2 FROM tab2 @@ -542,9 +509,8 @@ def test_select_column_from_table_join(): ColumnQualifierTuple("col2", "tab1"), ), ], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT tab2.col1 AS col3, tab3.col2 AS col4 FROM tab2 @@ -562,9 +528,8 @@ def test_select_column_from_table_join(): ColumnQualifierTuple("col4", "tab1"), ), ], - "sparksql", ) - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT a.col1 AS col3, b.col2 AS col4 FROM tab2 a @@ -582,12 +547,11 @@ def test_select_column_from_table_join(): ColumnQualifierTuple("col4", "tab1"), ), ], - "sparksql", ) def test_select_column_without_table_qualifier_from_table_join(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1 FROM tab2 a INNER JOIN tab3 b @@ -595,12 +559,11 @@ def test_select_column_without_table_qualifier_from_table_join(): assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", None), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) def test_select_column_from_same_table_multiple_time_using_different_alias(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT a.col1 AS col2, b.col1 AS col3 FROM tab2 a @@ -618,12 +581,11 @@ def test_select_column_from_same_table_multiple_time_using_different_alias(): ColumnQualifierTuple("col3", "tab1"), ), ], - "sparksql", ) def test_comment_after_column_comma_first(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT a.col1 --, a.col2 , a.col3 @@ -640,12 +602,11 @@ def test_comment_after_column_comma_first(): ColumnQualifierTuple("col3", "tab1"), ), ], - "sparksql", ) def test_comment_after_column_comma_last(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT a.col1, -- a.col2, a.col3 @@ -662,12 +623,11 @@ def test_comment_after_column_comma_last(): ColumnQualifierTuple("col3", "tab1"), ), ], - "sparksql", ) def test_cast_with_comparison(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT cast(col1 = 1 AS int) col1, col2 = col3 col2 FROM tab2""" assert_column_lineage_equal( @@ -686,45 +646,41 @@ def test_cast_with_comparison(): ColumnQualifierTuple("col2", "tab1"), ), ], - "sparksql", ) @pytest.mark.parametrize("dtype", ["string", "timestamp", "date", "decimal(18, 0)"]) -def test_cast_to_data_type(dtype): - sql = f"""INSERT OVERWRITE TABLE tab1 +def test_cast_to_data_type(dtype: str): + sql = f"""INSERT INTO tab1 SELECT cast(col1 as {dtype}) AS col1 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) @pytest.mark.parametrize("dtype", ["string", "timestamp", "date", "decimal(18, 0)"]) -def 
test_nested_cast_to_data_type(dtype): - sql = f"""INSERT OVERWRITE TABLE tab1 +def test_nested_cast_to_data_type(dtype: str): + sql = f"""INSERT INTO tab1 SELECT cast(cast(col1 AS {dtype}) AS {dtype}) AS col1 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) - sql = f"""INSERT OVERWRITE TABLE tab1 + sql = f"""INSERT INTO tab1 SELECT cast(cast(cast(cast(cast(col1 AS {dtype}) AS {dtype}) AS {dtype}) AS {dtype}) AS {dtype}) AS col1 FROM tab2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - "sparksql", ) @pytest.mark.parametrize("dtype", ["string", "timestamp", "date", "decimal(18, 0)"]) -def test_cast_to_data_type_with_case_when(dtype): - sql = f"""INSERT OVERWRITE TABLE tab1 +def test_cast_to_data_type_with_case_when(dtype: str): + sql = f"""INSERT INTO tab1 SELECT cast(case when col1 > 0 then col2 else col3 end as {dtype}) AS col1 FROM tab2""" assert_column_lineage_equal( @@ -743,14 +699,13 @@ def test_cast_to_data_type_with_case_when(dtype): ColumnQualifierTuple("col1", "tab1"), ), ], - "sparksql", ) def test_cast_using_constant(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT cast('2012-12-21' as date) AS col2""" - assert_column_lineage_equal(sql, dialect="sparksql") + assert_column_lineage_equal(sql) def test_window_function_in_subquery(): @@ -767,39 +722,62 @@ def test_window_function_in_subquery(): (ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("rn", "tab1")), (ColumnQualifierTuple("col2", "tab2"), ColumnQualifierTuple("rn", "tab1")), ], - dialect="sparksql", ) def test_invalid_syntax_as_without_alias(): - sql = """INSERT OVERWRITE TABLE tab1 + sql = """INSERT INTO tab1 SELECT col1, col2 as, col3 FROM tab2""" # just assure no exception, don't guarantee the result - LineageRunner(sql).print_column_lineage() + LineageRunner(sql, dialect=SQLPARSE_DIALECT).print_column_lineage() -def test_column_reference_from_cte_using_alias(): - sql = """WITH wtab1 AS (SELECT col1 FROM tab2) -INSERT OVERWRITE TABLE tab1 -SELECT wt.col1 FROM wtab1 wt""" +def test_column_with_ctas_and_func(): + sql = """CREATE TABLE tab2 AS +SELECT + coalesce(col1, 0) AS col1, + IF( + col1 IS NOT NULL, + 1, + NULL + ) AS col2 +FROM + tab1""" assert_column_lineage_equal( sql, - [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - dialect="sparksql", + [ + ( + ColumnQualifierTuple("col1", "tab1"), + ColumnQualifierTuple("col1", "tab2"), + ), + ( + ColumnQualifierTuple("col1", "tab1"), + ColumnQualifierTuple("col2", "tab2"), + ), + ], ) def test_column_reference_from_cte_using_qualifier(): sql = """WITH wtab1 AS (SELECT col1 FROM tab2) -INSERT OVERWRITE TABLE tab1 +INSERT INTO tab1 SELECT wtab1.col1 FROM wtab1""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], - dialect="sparksql", + ) + + +def test_column_reference_from_cte_using_alias(): + sql = """WITH wtab1 AS (SELECT col1 FROM tab2) +INSERT INTO tab1 +SELECT wt.col1 FROM wtab1 wt""" + assert_column_lineage_equal( + sql, + [(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))], ) @@ -807,12 +785,11 @@ def test_column_reference_from_previous_defined_cte(): sql = """WITH cte1 AS (SELECT a FROM tab1), cte2 AS (SELECT a FROM cte1) -INSERT OVERWRITE TABLE tab2 +INSERT INTO tab2 SELECT a FROM cte2""" assert_column_lineage_equal( sql, [(ColumnQualifierTuple("a", 
"tab1"), ColumnQualifierTuple("a", "tab2"))], - dialect="sparksql", ) @@ -820,7 +797,7 @@ def test_multiple_column_references_from_previous_defined_cte(): sql = """WITH cte1 AS (SELECT a, b FROM tab1), cte2 AS (SELECT a, max(b) AS b_max, count(b) AS b_cnt FROM cte1 GROUP BY a) -INSERT OVERWRITE TABLE tab2 +INSERT INTO tab2 SELECT cte1.a, cte2.b_max, cte2.b_cnt FROM cte1 JOIN cte2 WHERE cte1.a = cte2.a""" assert_column_lineage_equal( @@ -830,12 +807,11 @@ def test_multiple_column_references_from_previous_defined_cte(): (ColumnQualifierTuple("b", "tab1"), ColumnQualifierTuple("b_max", "tab2")), (ColumnQualifierTuple("b", "tab1"), ColumnQualifierTuple("b_cnt", "tab2")), ], - dialect="sparksql", ) def test_column_reference_with_ansi89_join(): - sql = """INSERT OVERWRITE TABLE tab3 + sql = """INSERT INTO tab3 SELECT a.id, a.name AS name1, b.name AS name2 @@ -857,7 +833,6 @@ def test_column_reference_with_ansi89_join(): ColumnQualifierTuple("name2", "tab3"), ), ], - dialect="sparksql", ) @@ -865,7 +840,7 @@ def test_smarter_column_resolution_using_query_context(): sql = """WITH cte1 AS (SELECT a, b FROM tab1), cte2 AS (SELECT c, d FROM tab2) -INSERT OVERWRITE TABLE tab3 +INSERT INTO tab3 SELECT b, d FROM cte1 JOIN cte2 WHERE cte1.a = cte2.c""" assert_column_lineage_equal( @@ -874,12 +849,11 @@ def test_smarter_column_resolution_using_query_context(): (ColumnQualifierTuple("b", "tab1"), ColumnQualifierTuple("b", "tab3")), (ColumnQualifierTuple("d", "tab2"), ColumnQualifierTuple("d", "tab3")), ], - dialect="sparksql", ) def test_column_reference_using_union(): - sql = """INSERT OVERWRITE TABLE tab3 + sql = """INSERT INTO tab3 SELECT col1 FROM tab1 UNION ALL @@ -897,9 +871,8 @@ def test_column_reference_using_union(): ColumnQualifierTuple("col1", "tab3"), ), ], - dialect="sparksql", ) - sql = """INSERT OVERWRITE TABLE tab3 + sql = """INSERT INTO tab3 SELECT col1 FROM tab1 UNION @@ -917,12 +890,11 @@ def test_column_reference_using_union(): ColumnQualifierTuple("col1", "tab3"), ), ], - dialect="sparksql", ) def test_column_lineage_multiple_paths_for_same_column(): - sql = """INSERT OVERWRITE TABLE tab2 + sql = """INSERT INTO tab2 SELECT tab1.id, coalesce(join_table_1.col1, join_table_2.col1, join_table_3.col1) AS col1 FROM tab1 @@ -944,7 +916,6 @@ def test_column_lineage_multiple_paths_for_same_column(): ColumnQualifierTuple("col1", "tab2"), ), ], - dialect="sparksql", ) @@ -959,8 +930,8 @@ def test_column_lineage_multiple_paths_for_same_column(): "coalesce(col1, 0) as decimal(10, 6)", ], ) -def test_column_try_cast_with_func(func): - sql = f"""INSERT OVERWRITE TABLE tab2 +def test_column_try_cast_with_func(func: str): + sql = f"""INSERT INTO tab2 SELECT try_cast({func}) AS col2 FROM tab1""" assert_column_lineage_equal( @@ -971,31 +942,4 @@ def test_column_try_cast_with_func(func): ColumnQualifierTuple("col2", "tab2"), ), ], - dialect="sparksql", - ) - - -def test_column_with_ctas_and_func(): - sql = """CREATE TABLE tab2 AS -SELECT - coalesce(col1, 0) AS col1, - IF( - col1 IS NOT NULL, - 1, - NULL - ) AS col2 -FROM - tab1""" - assert_column_lineage_equal( - sql, - [ - ( - ColumnQualifierTuple("col1", "tab1"), - ColumnQualifierTuple("col1", "tab2"), - ), - ( - ColumnQualifierTuple("col1", "tab1"), - ColumnQualifierTuple("col2", "tab2"), - ), - ], ) diff --git a/tests/test_cte.py b/tests/test_cte.py index 6c4ea539..1fd979f5 100644 --- a/tests/test_cte.py +++ b/tests/test_cte.py @@ -12,17 +12,6 @@ def test_with_select_one(): ) -def test_with_select_one_without_as(): - # AS in CTE is negligible in 
SparkSQL, however it is required in MySQL. See below reference - # https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-cte.html - # https://dev.mysql.com/doc/refman/8.0/en/with.html - assert_table_lineage_equal( - "WITH wtab1 (SELECT * FROM schema1.tab1) SELECT * FROM wtab1", - {"schema1.tab1"}, - dialect="mysql", - ) - - def test_with_select_many(): assert_table_lineage_equal( """WITH @@ -65,30 +54,3 @@ def test_with_insert(): {"tab2"}, {"tab3"}, ) - - -def test_with_insert_overwrite(): - assert_table_lineage_equal( - "WITH tab1 AS (SELECT * FROM tab2) INSERT OVERWRITE tab3 SELECT * FROM tab1", - {"tab2"}, - {"tab3"}, - dialect="sparksql", - ) - - -def test_with_insert_plus_keyword_table(): - assert_table_lineage_equal( - "WITH tab1 AS (SELECT * FROM tab2) INSERT INTO TABLE tab3 SELECT * FROM tab1", - {"tab2"}, - {"tab3"}, - dialect="sparksql", - ) - - -def test_with_insert_overwrite_plus_keyword_table(): - assert_table_lineage_equal( - "WITH tab1 AS (SELECT * FROM tab2) INSERT OVERWRITE TABLE tab3 SELECT * FROM tab1", - {"tab2"}, - {"tab3"}, - dialect="sparksql", - ) diff --git a/tests/test_cte_dialect_specific.py b/tests/test_cte_dialect_specific.py new file mode 100644 index 00000000..9af9675c --- /dev/null +++ b/tests/test_cte_dialect_specific.py @@ -0,0 +1,50 @@ +import pytest + +from .helpers import assert_table_lineage_equal + + +""" +This test class will contain all the tests for testing 'CTE Queries' where the dialect is not ANSI. +""" + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_with_insert_plus_table_keyword(dialect: str): + assert_table_lineage_equal( + "WITH tab1 AS (SELECT * FROM tab2) INSERT INTO TABLE tab3 SELECT * FROM tab1", + {"tab2"}, + {"tab3"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_with_insert_overwrite(dialect: str): + assert_table_lineage_equal( + "WITH tab1 AS (SELECT * FROM tab2) INSERT OVERWRITE TABLE tab3 SELECT * FROM tab1", + {"tab2"}, + {"tab3"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "sparksql"]) +def test_with_insert_overwrite_without_table_keyword(dialect: str): + assert_table_lineage_equal( + "WITH tab1 AS (SELECT * FROM tab2) INSERT OVERWRITE tab3 SELECT * FROM tab1", + {"tab2"}, + {"tab3"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "sparksql"]) +def test_with_select_one_without_as(dialect: str): + # AS in CTE is negligible in SparkSQL, however it is required in most other dialects + # https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-cte.html + # https://dev.mysql.com/doc/refman/8.0/en/with.html + assert_table_lineage_equal( + "WITH wtab1 (SELECT * FROM schema1.tab1) SELECT * FROM wtab1", + {"schema1.tab1"}, + dialect=dialect, + ) diff --git a/tests/test_exception.py b/tests/test_exception.py index 732beb23..09cb7e7f 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -1,9 +1,23 @@ import pytest -from sqllineage.exceptions import SQLLineageException +from sqllineage.exceptions import ( + InvalidSyntaxException, + SQLLineageException, + UnsupportedStatementException, +) from sqllineage.runner import LineageRunner def test_select_without_table(): with pytest.raises(SQLLineageException): LineageRunner("select * from where foo='bar'")._eval() + + +def test_unsupported_query_type_in_sqlfluff(): + with pytest.raises(UnsupportedStatementException): + LineageRunner("WRONG SELECT FROM tab1")._eval() + + +def 
test_partial_unparsable_query_in_sqlfluff(): + with pytest.raises(InvalidSyntaxException): + LineageRunner("SELECT * FROM tab1 AS FULL FULL OUTER JOIN tab2")._eval() diff --git a/tests/test_insert.py b/tests/test_insert.py index 7e5bda98..d94864ce 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -5,12 +5,6 @@ def test_insert_into(): assert_table_lineage_equal("INSERT INTO tab1 VALUES (1, 2)", set(), {"tab1"}) -def test_insert_into_with_keyword_table(): - assert_table_lineage_equal( - "INSERT INTO TABLE tab1 VALUES (1, 2)", set(), {"tab1"}, dialect="sparksql" - ) - - def test_insert_into_with_columns(): assert_table_lineage_equal( "INSERT INTO tab1 (col1, col2) SELECT * FROM tab2;", {"tab2"}, {"tab1"} @@ -29,78 +23,3 @@ def test_insert_into_with_columns_and_select_union(): {"tab2", "tab3"}, {"tab1"}, ) - - -def test_insert_into_partitions(): - assert_table_lineage_equal( - "INSERT INTO TABLE tab1 PARTITION (par1=1) SELECT * FROM tab2", - {"tab2"}, - {"tab1"}, - dialect="sparksql", - ) - - -def test_insert_overwrite(): - assert_table_lineage_equal( - "INSERT OVERWRITE tab1 SELECT * FROM tab2", - {"tab2"}, - {"tab1"}, - dialect="sparksql", - ) - - -def test_insert_overwrite_with_keyword_table(): - assert_table_lineage_equal( - "INSERT OVERWRITE TABLE tab1 SELECT col1 FROM tab2", - {"tab2"}, - {"tab1"}, - dialect="sparksql", - ) - - -def test_insert_overwrite_values(): - assert_table_lineage_equal( - "INSERT OVERWRITE tab1 VALUES ('val1', 'val2'), ('val3', 'val4')", - {}, - {"tab1"}, - dialect="sparksql", - ) - - -def test_insert_overwrite_from_self(): - assert_table_lineage_equal( - """INSERT OVERWRITE TABLE foo -SELECT col FROM foo -WHERE flag IS NOT NULL""", - {"foo"}, - {"foo"}, - dialect="sparksql", - ) - - -def test_insert_overwrite_from_self_with_join(): - assert_table_lineage_equal( - """INSERT OVERWRITE TABLE tab_1 -SELECT tab_2.col_a from tab_2 -JOIN tab_1 -ON tab_1.col_a = tab_2.cola""", - {"tab_1", "tab_2"}, - {"tab_1"}, - dialect="sparksql", - ) - - -def test_create_view(): - assert_table_lineage_equal( - """CREATE VIEW view1 -as -SELECT - col1, - col2 -FROM tab1 -GROUP BY -col1""", - {"tab1"}, - {"view1"}, - "tsql", - ) diff --git a/tests/test_insert_dialect_specific.py b/tests/test_insert_dialect_specific.py new file mode 100644 index 00000000..439f3acd --- /dev/null +++ b/tests/test_insert_dialect_specific.py @@ -0,0 +1,80 @@ +import pytest + +from .helpers import assert_table_lineage_equal + + +""" +This test class will contain all the tests for testing 'Insert Queries' where the dialect is not ANSI. 
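+Covered: INSERT OVERWRITE [TABLE], the optional TABLE keyword after INSERT INTO,
+and static PARTITION specs -- Hive/SparkSQL-family syntax that the ANSI dialect
+does not accept.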
+""" + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_insert_overwrite(dialect: str): + assert_table_lineage_equal( + "INSERT OVERWRITE TABLE tab1 SELECT col1 FROM tab2", + {"tab2"}, + {"tab1"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_insert_overwrite_from_self(dialect: str): + assert_table_lineage_equal( + """INSERT OVERWRITE TABLE foo +SELECT col FROM foo +WHERE flag IS NOT NULL""", + {"foo"}, + {"foo"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_insert_overwrite_from_self_with_join(dialect: str): + assert_table_lineage_equal( + """INSERT OVERWRITE TABLE tab_1 +SELECT tab_2.col_a from tab_2 +JOIN tab_1 +ON tab_1.col_a = tab_2.cola""", + {"tab_1", "tab_2"}, + {"tab_1"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_insert_overwrite_values(dialect: str): + assert_table_lineage_equal( + "INSERT OVERWRITE TABLE tab1 VALUES ('val1', 'val2'), ('val3', 'val4')", + {}, + {"tab1"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_insert_into_with_keyword_table(dialect: str): + assert_table_lineage_equal( + "INSERT INTO TABLE tab1 VALUES (1, 2)", set(), {"tab1"}, dialect=dialect + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_insert_into_partitions(dialect: str): + assert_table_lineage_equal( + "INSERT INTO TABLE tab1 PARTITION (par1=1) SELECT * FROM tab2", + {"tab2"}, + {"tab1"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "sparksql"]) +def test_insert_overwrite_without_table_keyword(dialect: str): + assert_table_lineage_equal( + "INSERT OVERWRITE tab1 SELECT * FROM tab2", + {"tab2"}, + {"tab1"}, + dialect=dialect, + ) diff --git a/tests/test_models.py b/tests/test_models.py index cfbcddef..4ecc50ec 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ def test_repr_dummy(): assert repr(Schema()) assert repr(Table("")) assert repr(Table("a.b.c")) - assert repr(SubQuery(Parenthesis(), "")) + assert repr(SubQuery(Parenthesis(), Parenthesis().value, "")) assert repr(Column("a.b")) assert repr(Path("")) with pytest.raises(SQLLineageException): @@ -23,3 +23,12 @@ def test_hash_eq(): assert len({Schema("a"), Schema("a")}) == 1 assert Table("a") == Table("a") assert len({Table("a"), Table("a")}) == 1 + + +def test_of_dummy(): + with pytest.raises(NotImplementedError): + Column.of("") + with pytest.raises(NotImplementedError): + Table.of("") + with pytest.raises(NotImplementedError): + SubQuery.of("", None) diff --git a/tests/test_others.py b/tests/test_others.py index 750bbb44..e74696fd 100644 --- a/tests/test_others.py +++ b/tests/test_others.py @@ -1,6 +1,5 @@ -from sqllineage.core.models import Path from sqllineage.runner import LineageRunner -from sqllineage.sqlfluff_core.models import SqlFluffPath +from sqllineage.utils.helpers import split from .helpers import assert_table_lineage_equal @@ -10,16 +9,19 @@ def test_use(): def test_table_name_case(): assert_table_lineage_equal( - """insert overwrite table tab_a + """insert into tab_a select * from tab_b union all select * from TAB_B""", {"tab_b"}, {"tab_a"}, - "sparksql", ) +def test_parenthesis(): + assert_table_lineage_equal("(SELECT * FROM tab1)", {"tab1"}, None) + + def test_create(): assert_table_lineage_equal("CREATE TABLE tab1 (col1 STRING)", None, 
{"tab1"}) @@ -30,15 +32,6 @@ def test_create_if_not_exist(): ) -def test_create_bucket_table(): - assert_table_lineage_equal( - "CREATE TABLE tab1 USING parquet CLUSTERED BY (col1) INTO 500 BUCKETS", - None, - {"tab1"}, - "bigquery", - ) - - def test_create_as(): assert_table_lineage_equal( "CREATE TABLE tab1 AS SELECT * FROM tab2", {"tab2"}, {"tab1"} @@ -47,9 +40,7 @@ def test_create_as(): def test_create_as_with_parenthesis_around_select_statement(): sql = "CREATE TABLE tab1 AS (SELECT * FROM tab2)" - assert_table_lineage_equal(sql, {"tab2"}, {"tab1"}, test_sqlfluff=False) - # Graph generated differ but it is correct - assert_table_lineage_equal(sql, {"tab2"}, {"tab1"}, test_sqlparse=False) + assert_table_lineage_equal(sql, {"tab2"}, {"tab1"}) def test_create_as_with_parenthesis_around_table_name(): @@ -60,18 +51,25 @@ def test_create_as_with_parenthesis_around_table_name(): def test_create_as_with_parenthesis_around_both(): sql = "CREATE TABLE tab1 AS (SELECT * FROM (tab2))" - assert_table_lineage_equal(sql, {"tab2"}, {"tab1"}, test_sqlfluff=False) - # Graph generated differ but it is correct - assert_table_lineage_equal(sql, {"tab2"}, {"tab1"}, test_sqlparse=False) + assert_table_lineage_equal(sql, {"tab2"}, {"tab1"}) def test_create_like(): assert_table_lineage_equal("CREATE TABLE tab1 LIKE tab2", {"tab2"}, {"tab1"}) -def test_create_select(): +def test_create_view(): assert_table_lineage_equal( - "CREATE TABLE tab1 SELECT * FROM tab2", {"tab2"}, {"tab1"}, "sparksql" + """CREATE VIEW view1 +as +SELECT + col1, + col2 +FROM tab1 +GROUP BY +col1""", + {"tab1"}, + {"view1"}, ) @@ -83,32 +81,6 @@ def test_create_after_drop(): ) -# deactivated for sqlfluff since it can not be parsed properly -def test_create_using_serde(): - # Check https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-RowFormats&SerDe - # here with is not an indicator for CTE - assert_table_lineage_equal( - """CREATE TABLE apachelog ( - host STRING, - identity STRING, - user STRING, - time STRING, - request STRING, - status STRING, - size STRING, - referer STRING, - agent STRING) -ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.RegexSerDe' -WITH SERDEPROPERTIES ( - "input.regex" = "([^]*) ([^]*) ([^]*) (-|\\[^\\]*\\]) ([^ \"]*|\"[^\"]*\") (-|[0-9]*) (-|[0-9]*)(?: ([^ \"]*|\".*\") ([^ \"]*|\".*\"))?" 
-)
-STORED AS TEXTFILE""",  # noqa
-        None,
-        {"apachelog"},
-        test_sqlfluff=False,
-    )
-
-
 def test_bucket_with_using_parenthesis():
     assert_table_lineage_equal(
         """CREATE TABLE tbl1 (col1 VARCHAR)
@@ -124,35 +96,6 @@ def test_update():
     )
 
 
-def test_update_with_join():
-    assert_table_lineage_equal(
-        "UPDATE tab1 a INNER JOIN tab2 b ON a.col1=b.col1 SET a.col2=b.col2",
-        {"tab2"},
-        {"tab1"},
-        "mysql",
-    )
-
-
-# the previous query "COPY tab1 FROM tab2" was wrong
-# Reference:
-# https://www.postgresql.org/docs/current/sql-copy.html (Postgres)
-# https://docs.aws.amazon.com/es_es/redshift/latest/dg/r_COPY.html (Redshift)
-def test_copy_from_table():
-    assert_table_lineage_equal(
-        "COPY tab1 FROM 's3://mybucket/mypath'",
-        {Path("s3://mybucket/mypath")},
-        {"tab1"},
-        test_sqlfluff=False,
-    )
-    assert_table_lineage_equal(
-        "COPY tab1 FROM 's3://mybucket/mypath'",
-        {SqlFluffPath("s3://mybucket/mypath")},
-        {"tab1"},
-        "redshift",
-        test_sqlparse=False,
-    )
-
-
 def test_drop():
     assert_table_lineage_equal("DROP TABLE IF EXISTS tab1", None, None)
@@ -176,9 +119,9 @@ def test_drop_after_create():
 
 
 def test_drop_tmp_tab_after_create():
     sql = """create table tab_a as select * from tab_b;
-insert overwrite table tab_c select * from tab_a;
+insert into tab_c select * from tab_a;
 drop table tab_a;"""
-    assert_table_lineage_equal(sql, {"tab_b"}, {"tab_c"}, "sparksql")
+    assert_table_lineage_equal(sql, {"tab_b"}, {"tab_c"})
 
 
 def test_new_create_tab_as_tmp_table():
@@ -191,32 +134,6 @@ def test_alter_table_rename():
     assert_table_lineage_equal("alter table tab1 rename to tab2;", None, None)
 
 
-def test_rename_table():
-    """
-    This syntax is MySQL specific:
-    https://dev.mysql.com/doc/refman/8.0/en/rename-table.html
-    """
-    assert_table_lineage_equal("rename table tab1 to tab2", None, None, "teradata")
-
-
-def test_rename_tables():
-    assert_table_lineage_equal(
-        "rename table tab1 to tab2, tab3 to tab4", None, None, "mysql"
-    )
-
-
-def test_alter_table_exchange_partition():
-    """
-    See https://cwiki.apache.org/confluence/display/Hive/Exchange+Partition for language manual
-    """
-    assert_table_lineage_equal(
-        "alter table tab1 exchange partition(pt='part1') with table tab2",
-        {"tab2"},
-        {"tab1"},
-        "hive",
-    )
-
-
 def test_swapping_partitions():
     """
     See https://www.vertica.com/docs/10.0.x/HTML/Content/Authoring/AdministratorsGuide/Partitions/SwappingPartitions.htm
@@ -231,31 +148,12 @@ def test_alter_target_table_name():
     assert_table_lineage_equal(
-        "insert overwrite tab1 select * from tab2; alter table tab1 rename to tab3;",
+        "insert into tab1 select * from tab2; alter table tab1 rename to tab3;",
         {"tab2"},
         {"tab3"},
-        "sparksql",
-    )
-
-
-def test_refresh_table():
-    assert_table_lineage_equal("refresh table tab1", None, None, "sparksql")
-
-
-def test_cache_table():
-    assert_table_lineage_equal(
-        "cache table tab1 select * from tab2", None, None, "sparksql"
     )
 
 
-def test_uncache_table():
-    assert_table_lineage_equal("uncache table tab1", None, None, "sparksql")
-
-
-def test_uncache_table_if_exists():
-    assert_table_lineage_equal("uncache table if exists tab1", None, None, "sparksql")
-
-
 def test_truncate_table():
     assert_table_lineage_equal("truncate table tab1", None, None)
@@ -264,58 +162,38 @@ def test_delete_from_table():
     assert_table_lineage_equal("delete from table tab1", None, None)
 
 
-def test_lateral_view_using_json_tuple():
-    sql = """INSERT OVERWRITE TABLE foo
-SELECT sc.id, q.item0, q.item1
-FROM bar sc
-LATERAL VIEW json_tuple(sc.json, 'key1', 'key2') q AS item0, item1"""
-    assert_table_lineage_equal(sql, {"bar"}, {"foo"}, "sparksql")
-
-
-def test_lateral_view_outer():
-    sql = """INSERT OVERWRITE TABLE foo
-SELECT sc.id, q.col1
-FROM bar sc
-LATERAL VIEW OUTER explode(sc.json_array) q AS col1"""
-    assert_table_lineage_equal(sql, {"bar"}, {"foo"}, "sparksql")
-
-
-def test_show_create_table():
-    assert_table_lineage_equal("show create table tab1", None, None, "sparksql")
+def test_statements_trim_comment():
+    comment = "------------------\n"
+    sql = "select * from dual;"
+    assert LineageRunner(comment + sql).statements()[0] == sql
 
 
 def test_split_statements():
     sql = "SELECT * FROM tab1; SELECT * FROM tab2;"
-    assert len(LineageRunner(sql).statements()) == 2
+    assert len(split(sql)) == 2
 
 
 def test_split_statements_with_heading_and_ending_new_line():
     sql = "\nSELECT * FROM tab1;\nSELECT * FROM tab2;\n"
-    assert len(LineageRunner(sql).statements()) == 2
+    assert len(split(sql)) == 2
 
 
 def test_split_statements_with_comment():
     sql = """SELECT 1;
 -- SELECT 2;"""
-    assert len(LineageRunner(sql).statements()) == 1
-
-
-def test_statements_trim_comment():
-    comment = "------------------\n"
-    sql = "select * from dual;"
-    assert LineageRunner(comment + sql).statements(strip_comments=True)[0] == sql
+    assert len(split(sql)) == 1
 
 
 def test_split_statements_with_show_create_table():
     sql = """SELECT 1;
 SHOW CREATE TABLE tab1;"""
-    assert len(LineageRunner(sql).statements()) == 2
+    assert len(split(sql)) == 2
 
 
 def test_split_statements_with_desc():
     sql = """SELECT 1;
 DESC tab1;"""
-    assert len(LineageRunner(sql).statements()) == 2
+    assert len(split(sql)) == 2
diff --git a/tests/test_others_dialect_specific.py b/tests/test_others_dialect_specific.py
new file mode 100644
index 00000000..3ca5c9b7
--- /dev/null
+++ b/tests/test_others_dialect_specific.py
@@ -0,0 +1,139 @@
+import pytest
+
+from .helpers import assert_table_lineage_equal
+
+"""
+This module contains tests for 'Other Queries' where the dialect is not ANSI.
+"""
+
+
+@pytest.mark.parametrize("dialect", ["bigquery", "snowflake"])
+def test_create_bucket_table(dialect: str):
+    assert_table_lineage_equal(
+        "CREATE TABLE tab1 USING parquet CLUSTERED BY (col1) INTO 500 BUCKETS",
+        None,
+        {"tab1"},
+        dialect,
+    )
+
+
+@pytest.mark.parametrize("dialect", ["databricks", "sparksql"])
+def test_create_select_without_as(dialect: str):
+    assert_table_lineage_equal(
+        "CREATE TABLE tab1 SELECT * FROM tab2", {"tab2"}, {"tab1"}, dialect
+    )
+
+
+def test_create_using_serde():
+    """
+    https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-RowFormats&SerDe
+    here WITH is not an indicator for CTE
+    FIXME: sqlfluff hive dialect doesn't support parsing this yet
+    """
+    assert_table_lineage_equal(
+        """CREATE TABLE apachelog (
+  host STRING,
+  identity STRING,
+  user STRING,
+  time STRING,
+  request STRING,
+  status STRING,
+  size STRING,
+  referer STRING,
+  agent STRING)
+ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.RegexSerDe'
+WITH SERDEPROPERTIES (
+  "input.regex" = "([^]*) ([^]*) ([^]*) (-|\\[^\\]*\\]) ([^ \"]*|\"[^\"]*\") (-|[0-9]*) (-|[0-9]*)(?: ([^ \"]*|\".*\") ([^ \"]*|\".*\"))?"
+)
+STORED AS TEXTFILE""",  # noqa
+        None,
+        {"apachelog"},
+        test_sqlfluff=False,
+    )
+
+
+@pytest.mark.parametrize("dialect", ["mysql"])
+def test_update_with_join(dialect: str):
+    assert_table_lineage_equal(
+        "UPDATE tab1 a INNER JOIN tab2 b ON a.col1=b.col1 SET a.col2=b.col2",
+        {"tab2"},
+        {"tab1"},
+        dialect=dialect,
+    )
+
+
+@pytest.mark.parametrize("dialect", ["exasol", "mysql", "teradata"])
+def test_rename_table(dialect: str):
+    """
+    https://docs.exasol.com/db/latest/sql/rename.htm
+    https://dev.mysql.com/doc/refman/8.0/en/rename-table.html
+    https://docs.teradata.com/r/Teradata-Database-SQL-Data-Definition-Language-Syntax-and-Examples/December-2015/Table-Statements/RENAME-TABLE
+    """
+    assert_table_lineage_equal("rename table tab1 to tab2", None, None, dialect)
+
+
+@pytest.mark.parametrize("dialect", ["exasol", "mysql", "teradata"])
+def test_rename_tables(dialect: str):
+    assert_table_lineage_equal(
+        "rename table tab1 to tab2, tab3 to tab4", None, None, dialect
+    )
+
+
+@pytest.mark.parametrize("dialect", ["hive"])
+def test_alter_table_exchange_partition(dialect: str):
+    """
+    See https://cwiki.apache.org/confluence/display/Hive/Exchange+Partition for language manual
+    """
+    assert_table_lineage_equal(
+        "alter table tab1 exchange partition(pt='part1') with table tab2",
+        {"tab2"},
+        {"tab1"},
+        dialect=dialect,
+    )
+
+
+@pytest.mark.parametrize("dialect", ["databricks", "sparksql"])
+def test_refresh_table(dialect: str):
+    assert_table_lineage_equal("refresh table tab1", None, None, dialect)
+
+
+@pytest.mark.parametrize("dialect", ["databricks", "sparksql"])
+def test_cache_table(dialect: str):
+    assert_table_lineage_equal(
+        "cache table tab1 select * from tab2", None, None, dialect
+    )
+
+
+@pytest.mark.parametrize("dialect", ["databricks", "sparksql"])
+def test_uncache_table(dialect: str):
+    assert_table_lineage_equal("uncache table tab1", None, None, dialect)
+
+
+@pytest.mark.parametrize("dialect", ["databricks", "sparksql"])
+def test_uncache_table_if_exists(dialect: str):
+    assert_table_lineage_equal("uncache table if exists tab1", None, None, dialect)
+
+
+@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"])
+def test_lateral_view_using_json_tuple(dialect: str):
+    sql = """INSERT OVERWRITE TABLE foo
+SELECT sc.id, q.item0, q.item1
+FROM bar sc
+LATERAL VIEW json_tuple(sc.json, 'key1', 'key2') q AS item0, item1"""
+    assert_table_lineage_equal(sql, {"bar"}, {"foo"}, dialect)
+
+
+@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"])
+def test_lateral_view_outer(dialect: str):
+    sql = """INSERT OVERWRITE TABLE foo
+SELECT sc.id, q.col1
+FROM bar sc
+LATERAL VIEW OUTER explode(sc.json_array) q AS col1"""
+    assert_table_lineage_equal(sql, {"bar"}, {"foo"}, dialect)
+
+
+@pytest.mark.parametrize("dialect", ["databricks", "sparksql"])
+def test_show_create_table(dialect: str):
+    assert_table_lineage_equal("show create table tab1", None, None, dialect)
diff --git a/tests/test_parser.py b/tests/test_parser.py
new file mode 100644
index 00000000..9c1c7702
--- /dev/null
+++ b/tests/test_parser.py
@@ -0,0 +1,29 @@
+from unittest.mock import Mock
+
+import pytest
+
+from sqllineage.core.holders import SubQueryLineageHolder
+from sqllineage.core.parser.sqlfluff.handlers.base import (
+    ConditionalSegmentBaseHandler,
+    SegmentBaseHandler,
+)
+from sqllineage.core.parser.sqlfluff.models import SqlFluffColumn
+
+
+def test_column_extract_source_columns():
+    segment_mock = Mock()
+    segment_mock.type = ""
+    assert [] == SqlFluffColumn._extract_source_columns(segment_mock)
+
+
+def test_handler_dummy():
+    segment_mock = Mock()
+    holder = SubQueryLineageHolder()
+    c_handler = ConditionalSegmentBaseHandler()
+    with pytest.raises(NotImplementedError):
+        c_handler.handle(segment_mock, holder)
+    with pytest.raises(NotImplementedError):
+        c_handler.indicate(segment_mock)
+    s_handler = SegmentBaseHandler()
+    with pytest.raises(NotImplementedError):
+        s_handler.handle(segment_mock, holder)
diff --git a/tests/test_path.py b/tests/test_path_dialect_specific.py
similarity index 53%
rename from tests/test_path.py
rename to tests/test_path_dialect_specific.py
index d8437c91..3e4da232 100644
--- a/tests/test_path.py
+++ b/tests/test_path_dialect_specific.py
@@ -1,54 +1,42 @@
 import pytest
 
 from sqllineage.core.models import Path
-from sqllineage.sqlfluff_core.models import SqlFluffPath
 
 from .helpers import assert_table_lineage_equal
 
 
-def test_copy_from_path():
+@pytest.mark.parametrize("dialect", ["postgres", "redshift"])
+def test_copy_from_path(dialect: str):
     """
-    check following link for syntax specs:
-    Redshift: https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html
+    https://www.postgresql.org/docs/current/sql-copy.html (Postgres)
+    https://docs.aws.amazon.com/es_es/redshift/latest/dg/r_COPY.html (Redshift)
     """
     assert_table_lineage_equal(
         "COPY tab1 FROM 's3://mybucket/mypath'",
         {Path("s3://mybucket/mypath")},
         {"tab1"},
-        test_sqlfluff=False,
-    )
-    assert_table_lineage_equal(
-        "COPY tab1 FROM 's3://mybucket/mypath'",
-        {SqlFluffPath("s3://mybucket/mypath")},
-        {"tab1"},
-        "redshift",
-        test_sqlparse=False,
+        dialect=dialect,
     )
 
 
-def test_copy_into_path():
+@pytest.mark.parametrize("dialect", ["snowflake"])
+def test_copy_into_path(dialect: str):
     """
     check following link for syntax reference:
     Snowflake: https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html
-    Microsoft T-SQL: https://docs.microsoft.com/en-us/sql/t-sql/statements/copy-into-transact-sql?view=azure-sqldw-latest  # noqa
+    Microsoft T-SQL: https://docs.microsoft.com/en-us/sql/t-sql/statements/copy-into-transact-sql?view=azure-sqldw-latest  # noqa
+    FIXME: sqlfluff tsql dialect doesn't support parsing this yet
     """
     assert_table_lineage_equal(
         "COPY INTO tab1 FROM 's3://mybucket/mypath'",
         {Path("s3://mybucket/mypath")},
         {"tab1"},
-        test_sqlfluff=False,
-    )
-    assert_table_lineage_equal(
-        "COPY INTO tab1 FROM 's3://mybucket/mypath'",
-        {SqlFluffPath("s3://mybucket/mypath")},
-        {"tab1"},
-        "snowflake",
-        test_sqlparse=False,
+        dialect=dialect,
     )
 
 
-# deactivated for sqlfluff since it can not be parsed properly
 @pytest.mark.parametrize("data_source", ["parquet", "json", "csv"])
-def test_select_from_files(data_source):
+@pytest.mark.parametrize("dialect", ["databricks", "sparksql"])
+def test_select_from_files(data_source: str, dialect: str):
     """
     check following link for syntax reference:
     https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html#run-sql-on-files-directly
@@ -56,11 +44,12 @@ def test_select_from_files(data_source):
     """
     assert_table_lineage_equal(
         f"SELECT * FROM {data_source}.`examples/src/main/resources/`",
         {Path("examples/src/main/resources/")},
-        test_sqlfluff=False,
+        dialect=dialect,
     )
 
 
-def test_insert_overwrite_directory():
+@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"])
+def test_insert_overwrite_directory(dialect: str):
     """
     check following link for syntax reference:
     https://spark.apache.org/docs/latest/sql-ref-syntax-dml-insert-overwrite-directory.html
@@ -70,13 +59,5 @@ def test_insert_overwrite_directory(dialect: str):
 SELECT * FROM tab1""",
         {"tab1"},
         {Path("hdfs://path/to/folder")},
-        test_sqlfluff=False,
-    )
-    assert_table_lineage_equal(
-        """INSERT OVERWRITE DIRECTORY 'hdfs://path/to/folder'
-SELECT * FROM tab1""",
-        {"tab1"},
-        {SqlFluffPath("hdfs://path/to/folder")},
-        "sparksql",
-        test_sqlparse=False,
+        dialect=dialect,
     )
diff --git a/tests/test_runner.py b/tests/test_runner.py
index 392ba827..97a8c215 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -5,9 +5,8 @@ def test_runner_dummy():
     runner = LineageRunner(
         """insert into tab2 select col1, col2, col3, col4, col5, col6 from tab1;
-insert overwrite table tab3 select * from tab2""",
+insert into tab3 select * from tab2""",
         verbose=True,
-        dialect="sparksql",
     )
     assert str(runner)
     assert runner.to_cytoscape() is not None
diff --git a/tests/test_select.py b/tests/test_select.py
index e95116c7..398248ad 100644
--- a/tests/test_select.py
+++ b/tests/test_select.py
@@ -15,16 +15,6 @@ def test_select_with_schema_and_database():
     )
 
 
-def test_select_with_table_name_in_backtick():
-    assert_table_lineage_equal("SELECT * FROM `tab1`", {"tab1"}, dialect="bigquery")
-
-
-def test_select_with_schema_in_backtick():
-    assert_table_lineage_equal(
-        "SELECT col1 FROM `schema1`.`tab1`", {"schema1.tab1"}, dialect="bigquery"
-    )
-
-
 def test_select_multi_line():
     assert_table_lineage_equal(
         """SELECT col1 FROM
@@ -74,10 +64,10 @@ def test_select_with_comment_after_join():
 
 
 def test_select_keyword_as_column_alias():
-    # here `as` is the column alias
-    assert_table_lineage_equal("SELECT 1 `as` FROM tab1", {"tab1"}, dialect="mysql")
+    # here "as" is the column alias
+    assert_table_lineage_equal('SELECT 1 "as" FROM tab1', {"tab1"})
     # the following is hive specific, MySQL doesn't allow this syntax. As of now, we don't test against it
-    # helper("SELECT 1 as FROM tab1", {"tab1"})
+    # assert_table_lineage_equal("SELECT 1 as FROM tab1", {"tab1"})
 
 
 def test_select_with_table_alias():
@@ -149,22 +139,6 @@ def test_select_left_join_with_extra_space_in_middle():
     assert_table_lineage_equal("SELECT * FROM tab1 LEFT JOIN tab2", {"tab1", "tab2"})
 
 
-# deactivated for sqlfluff since it can not be parsed properly
-def test_select_left_semi_join():
-    assert_table_lineage_equal(
-        "SELECT * FROM tab1 LEFT SEMI JOIN tab2", {"tab1", "tab2"}, test_sqlfluff=False
-    )
-
-
-# deactivated for sqlfluff since it can not be parsed properly
-def test_select_left_semi_join_with_on():
-    assert_table_lineage_equal(
-        "SELECT * FROM tab1 LEFT SEMI JOIN tab2 ON (tab1.col1 = tab2.col2)",
-        {"tab1", "tab2"},
-        test_sqlfluff=False,
-    )
-
-
 def test_select_right_join():
     assert_table_lineage_equal("SELECT * FROM tab1 RIGHT JOIN tab2", {"tab1", "tab2"})
@@ -175,15 +149,6 @@ def test_select_full_outer_join():
     )
 
 
-# deactivated for sqlfluff since it can not be parsed properly
-def test_select_full_outer_join_with_full_as_alias():
-    assert_table_lineage_equal(
-        "SELECT * FROM tab1 AS FULL FULL OUTER JOIN tab2",
-        {"tab1", "tab2"},
-        test_sqlfluff=False,
-    )
-
-
 def test_select_cross_join():
     assert_table_lineage_equal("SELECT * FROM tab1 CROSS JOIN tab2", {"tab1", "tab2"})
@@ -250,8 +215,11 @@ def test_select_from_unnest_parsed_as_keyword():
     )
 
 
-# deactivated for sqlfluff since it can not be parsed properly
 def test_select_from_unnest_with_ordinality():
+    """
+    https://prestodb.io/docs/current/sql/select.html#unnest
+    FIXME: sqlfluff athena dialect doesn't support parsing this yet
+    """
     sql = """
 SELECT numbers, n, a
 FROM (
@@ -264,9 +232,17 @@ def test_select_from_unnest_with_ordinality():
     assert_table_lineage_equal(sql, test_sqlfluff=False)
 
 
-def test_select_from_generator():
-    # generator is Snowflake specific
-    sql = """SELECT seq4(), uniform(1, 10, random(12))
-FROM table(generator()) v
-ORDER BY 1;"""
-    assert_table_lineage_equal(sql, dialect="snowflake")
+def test_select_union_all():
+    sql = """SELECT col1
+FROM tab1
+UNION ALL
+SELECT col1
+FROM tab2
+UNION ALL
+SELECT col1
+FROM tab3
+ORDER BY col1"""
+    assert_table_lineage_equal(
+        sql,
+        {"tab1", "tab2", "tab3"},
+    )
diff --git a/tests/test_select_dialect_specific.py b/tests/test_select_dialect_specific.py
new file mode 100644
index 00000000..db5ac11a
--- /dev/null
+++ b/tests/test_select_dialect_specific.py
@@ -0,0 +1,49 @@
+import pytest
+
+from .helpers import assert_table_lineage_equal
+
+
+"""
+This module contains tests for 'Select Queries' where the dialect is not ANSI.
+""" + + +@pytest.mark.parametrize( + "dialect", ["athena", "bigquery", "databricks", "hive", "mysql", "sparksql"] +) +def test_select_with_table_name_in_backtick(dialect: str): + assert_table_lineage_equal("SELECT * FROM `tab1`", {"tab1"}, dialect=dialect) + + +@pytest.mark.parametrize( + "dialect", ["athena", "bigquery", "databricks", "hive", "mysql", "sparksql"] +) +def test_select_with_schema_in_backtick(dialect: str): + assert_table_lineage_equal( + "SELECT col1 FROM `schema1`.`tab1`", {"schema1.tab1"}, dialect=dialect + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_select_left_semi_join(dialect: str): + assert_table_lineage_equal( + "SELECT * FROM tab1 LEFT SEMI JOIN tab2", {"tab1", "tab2"}, dialect=dialect + ) + + +@pytest.mark.parametrize("dialect", ["databricks", "hive", "sparksql"]) +def test_select_left_semi_join_with_on(dialect: str): + assert_table_lineage_equal( + "SELECT * FROM tab1 LEFT SEMI JOIN tab2 ON (tab1.col1 = tab2.col2)", + {"tab1", "tab2"}, + dialect=dialect, + ) + + +@pytest.mark.parametrize("dialect", ["snowflake"]) +def test_select_from_generator(dialect: str): + # generator is Snowflake specific + sql = """SELECT seq4(), uniform(1, 10, random(12)) +FROM table(generator()) v +ORDER BY 1;""" + assert_table_lineage_equal(sql, dialect=dialect) diff --git a/tox.ini b/tox.ini index 98fdb6fa..45078503 100644 --- a/tox.ini +++ b/tox.ini @@ -12,10 +12,10 @@ commands = pytest --cov [flake8] -exclude = .tox,.git,__pycache__,build,sqllineagejs,venv +exclude = .tox,.git,__pycache__,build,sqllineagejs,venv,env max-line-length = 120 # ignore = D100,D101 show-source = true enable-extensions=G application-import-names = sqllineage -import-order-style = pycharm +import-order-style = pycharm \ No newline at end of file