From c6e861bbe184d7af9df80a4842e31404a9206341 Mon Sep 17 00:00:00 2001
From: Mayur Singal <39544459+ulixius9@users.noreply.github.com>
Date: Mon, 20 Nov 2023 13:14:40 +0530
Subject: [PATCH] Fix #13954: Fix ParseException for older version of
 databricks (#14015)

---
 .../source/database/databricks/connection.py  | 19 +++-
 .../database/databricks/legacy/metadata.py    | 88 ++++++++++++-------
 .../databricks/unity_catalog/metadata.py      | 22 +++--
 .../unit/topology/database/test_databricks.py | 26 +++++-
 4 files changed, 115 insertions(+), 40 deletions(-)

diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/connection.py b/ingestion/src/metadata/ingestion/source/database/databricks/connection.py
index 1953e1876827..02f90c19caa0 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/connection.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/connection.py
@@ -17,6 +17,7 @@
 
 from databricks.sdk import WorkspaceClient
 from sqlalchemy.engine import Engine
+from sqlalchemy.exc import DatabaseError
 from sqlalchemy.inspection import inspect
 
 from metadata.generated.schema.entity.automations.workflow import (
@@ -33,7 +34,6 @@
 from metadata.ingestion.connections.test_connections import (
     test_connection_engine_step,
     test_connection_steps,
-    test_query,
 )
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.source.database.databricks.client import DatabricksClient
@@ -42,6 +42,9 @@
     DATABRICKS_GET_CATALOGS,
 )
 from metadata.utils.db_utils import get_host_from_host_port
+from metadata.utils.logger import ingestion_logger
+
+logger = ingestion_logger()
 
 
 def get_connection_url(connection: DatabricksConnection) -> str:
@@ -84,6 +87,18 @@ def test_connection(
     """
     client = DatabricksClient(service_connection)
 
+    def test_database_query(engine: Engine, statement: str):
+        """
+        Method used to execute the given query and fetch a result
+        to test if user has access to the tables specified
+        in the sql statement
+        """
+        try:
+            connection = engine.connect()
+            connection.execute(statement).fetchone()
+        except DatabaseError as soe:
+            logger.debug(f"Failed to fetch catalogs due to: {soe}")
+
     if service_connection.useUnityCatalog:
         table_obj = DatabricksTable()
 
@@ -121,7 +136,7 @@ def get_tables(connection: WorkspaceClient, table_obj: DatabricksTable):
             "GetTables": inspector.get_table_names,
             "GetViews": inspector.get_view_names,
             "GetDatabases": partial(
-                test_query,
+                test_database_query,
                 engine=connection,
                 statement=DATABRICKS_GET_CATALOGS,
            ),
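The hunks above swap the shared test_query helper out of the "GetDatabases" step: test_query lets DatabaseError propagate, so on Databricks runtimes that cannot parse the catalogs query the whole connection test failed with a ParseException, while the local test_database_query downgrades it to a debug log. A minimal sketch of that behaviour, assuming DATABRICKS_GET_CATALOGS is "SHOW CATALOGS" and using an illustrative engine URL (real values come from DatabricksConnection):

    from functools import partial

    from sqlalchemy import create_engine
    from sqlalchemy.exc import DatabaseError

    DATABRICKS_GET_CATALOGS = "SHOW CATALOGS"  # assumed value of the query constant

    def test_database_query(engine, statement: str):
        # Swallow DatabaseError so an unparseable statement on older
        # runtimes logs instead of failing the whole test step.
        try:
            engine.connect().execute(statement).fetchone()
        except DatabaseError as err:
            print(f"Failed to fetch catalogs due to: {err}")

    # Illustrative URL; the dialect and token are placeholders.
    engine = create_engine("databricks+connector://token:<token>@<host>:443/default")
    check = partial(test_database_query, engine=engine, statement=DATABRICKS_GET_CATALOGS)
    check()  # logs on older runtimes, succeeds silently on newer ones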
diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/legacy/metadata.py b/ingestion/src/metadata/ingestion/source/database/databricks/legacy/metadata.py
index 682de21ca40a..6d50bc9a5b71 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/legacy/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/legacy/metadata.py
@@ -13,11 +13,12 @@
 import re
 import traceback
 from copy import deepcopy
-from typing import Iterable
+from typing import Iterable, Optional
 
 from pyhive.sqlalchemy_hive import _type_map
 from sqlalchemy import types, util
 from sqlalchemy.engine import reflection
+from sqlalchemy.exc import DatabaseError
 from sqlalchemy.inspection import inspect
 from sqlalchemy.sql.sqltypes import String
 from sqlalchemy_databricks._dialect import DatabricksDialect
@@ -35,10 +36,13 @@
 from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
 from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
 from metadata.ingestion.source.database.databricks.queries import (
+    DATABRICKS_GET_CATALOGS,
     DATABRICKS_GET_TABLE_COMMENTS,
     DATABRICKS_VIEW_DEFINITIONS,
 )
+from metadata.ingestion.source.database.multi_db_source import MultiDBSource
 from metadata.utils import fqn
+from metadata.utils.constants import DEFAULT_DATABASE
 from metadata.utils.filters import filter_by_database
 from metadata.utils.logger import ingestion_logger
 from metadata.utils.sqlalchemy_utils import (
@@ -158,7 +162,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
 
 @reflection.cache
 def get_schema_names(self, connection, **kw):  # pylint: disable=unused-argument
     # Equivalent to SHOW DATABASES
-    if kw.get("database"):
+    if kw.get("database") and kw.get("is_old_version") is not True:
         connection.execute(f"USE CATALOG '{kw.get('database')}'")
     return [row[0] for row in connection.execute("SHOW SCHEMAS")]
@@ -238,13 +242,26 @@ def get_view_definition(
 reflection.Inspector.get_schema_names = get_schema_names_reflection
 
 
-class DatabricksLegacySource(CommonDbSourceService):
+class DatabricksLegacySource(CommonDbSourceService, MultiDBSource):
     """
     Implements the necessary methods to extract
     Database metadata from Databricks Source using
     the legacy hive metastore method
     """
 
+    def __init__(self, config: WorkflowSource, metadata: OpenMetadata):
+        super().__init__(config, metadata)
+        self.is_older_version = False
+        self._init_version()
+
+    def _init_version(self):
+        try:
+            self.connection.execute(DATABRICKS_GET_CATALOGS).fetchone()
+            self.is_older_version = False
+        except DatabaseError as soe:
+            logger.debug(f"Failed to fetch catalogs due to: {soe}")
+            self.is_older_version = True
+
     @classmethod
     def create(cls, config_dict, metadata: OpenMetadata):
         config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
@@ -268,44 +285,55 @@ def set_inspector(self, database_name: str) -> None:
         self.engine = get_connection(new_service_connection)
         self.inspector = inspect(self.engine)
 
+    def get_configured_database(self) -> Optional[str]:
+        return self.service_connection.catalog
+
+    def get_database_names_raw(self) -> Iterable[str]:
+        if not self.is_older_version:
+            results = self.connection.execute(DATABRICKS_GET_CATALOGS)
+            for res in results:
+                if res:
+                    row = list(res)
+                    yield row[0]
+        else:
+            yield DEFAULT_DATABASE
+
     def get_database_names(self) -> Iterable[str]:
-        configured_catalog = self.service_connection.__dict__.get("catalog")
+        configured_catalog = self.service_connection.catalog
         if configured_catalog:
             self.set_inspector(database_name=configured_catalog)
             yield configured_catalog
         else:
-            results = self.connection.execute("SHOW CATALOGS")
-            for res in results:
-                if res:
-                    new_catalog = res[0]
-                    database_fqn = fqn.build(
-                        self.metadata,
-                        entity_type=Database,
-                        service_name=self.context.database_service.name.__root__,
-                        database_name=new_catalog,
+            for new_catalog in self.get_database_names_raw():
+                database_fqn = fqn.build(
+                    self.metadata,
+                    entity_type=Database,
+                    service_name=self.context.database_service.name.__root__,
+                    database_name=new_catalog,
+                )
+                if filter_by_database(
+                    self.source_config.databaseFilterPattern,
+                    database_fqn
+                    if self.source_config.useFqnForFiltering
+                    else new_catalog,
+                ):
+                    self.status.filter(database_fqn, "Database Filtered Out")
+                    continue
+                try:
+                    self.set_inspector(database_name=new_catalog)
+                    yield new_catalog
+                except Exception as exc:
+                    logger.error(traceback.format_exc())
+                    logger.warning(
+                        f"Error trying to process database {new_catalog}: {exc}"
                     )
-                    if filter_by_database(
-                        self.source_config.databaseFilterPattern,
-                        database_fqn
-                        if self.source_config.useFqnForFiltering
-                        else new_catalog,
-                    ):
-                        self.status.filter(database_fqn, "Database Filtered Out")
-                        continue
-                    try:
-                        self.set_inspector(database_name=new_catalog)
-                        yield new_catalog
-                    except Exception as exc:
-                        logger.error(traceback.format_exc())
-                        logger.warning(
-                            f"Error trying to process database {new_catalog}: {exc}"
-                        )
 
     def get_raw_database_schema_names(self) -> Iterable[str]:
         if self.service_connection.__dict__.get("databaseSchema"):
             yield self.service_connection.databaseSchema
         else:
             for schema_name in self.inspector.get_schema_names(
-                database=self.context.database.name.__root__
+                database=self.context.database.name.__root__,
+                is_old_version=self.is_older_version,
             ):
                 yield schema_name
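The legacy source now probes catalog support once, at construction time: _init_version runs DATABRICKS_GET_CATALOGS and flips is_older_version when the runtime raises DatabaseError (the surfaced ParseException). Catalog listing then falls back to a single default database, and the same flag reaches get_schema_names as is_old_version to skip the USE CATALOG statement older runtimes equally cannot parse. A standalone sketch of the probe-and-fallback flow, assuming a SQLAlchemy-style connection object and the assumed values of the two constants:

    from typing import Iterable

    from sqlalchemy.exc import DatabaseError

    DATABRICKS_GET_CATALOGS = "SHOW CATALOGS"  # assumed value of the query constant
    DEFAULT_DATABASE = "default"  # assumed value of metadata.utils.constants

    class CatalogProbe:
        """Illustrative stand-in for DatabricksLegacySource's version handling."""

        def __init__(self, connection):
            self.connection = connection
            self.is_older_version = False
            try:
                # Older runtimes raise ParseException (a DatabaseError) here,
                # because SHOW CATALOGS predates their SQL dialect.
                self.connection.execute(DATABRICKS_GET_CATALOGS).fetchone()
            except DatabaseError:
                self.is_older_version = True

        def get_database_names_raw(self) -> Iterable[str]:
            if self.is_older_version:
                # No catalog support: expose one default database instead
                # of failing the whole ingestion run.
                yield DEFAULT_DATABASE
            else:
                for row in self.connection.execute(DATABRICKS_GET_CATALOGS):
                    yield list(row)[0]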
diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/unity_catalog/metadata.py b/ingestion/src/metadata/ingestion/source/database/databricks/unity_catalog/metadata.py
index 898fe2ec9ba2..fa8d010a95be 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/unity_catalog/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/unity_catalog/metadata.py
@@ -61,6 +61,7 @@
     ForeignConstrains,
     Type,
 )
+from metadata.ingestion.source.database.multi_db_source import MultiDBSource
 from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
 from metadata.ingestion.source.models import TableView
 from metadata.utils import fqn
@@ -84,7 +85,7 @@ def from_dict(cls, dct: Dict[str, Any]) -> "TableConstraintList":
 TableConstraintList.from_dict = from_dict
 
 
-class DatabricksUnityCatalogSource(DatabaseServiceSource):
+class DatabricksUnityCatalogSource(DatabaseServiceSource, MultiDBSource):
     """
     Implements the necessary methods to extract
     Database metadata from Databricks Source using
@@ -107,6 +108,13 @@ def __init__(self, config: WorkflowSource, metadata: OpenMetadata):
         self.table_constraints = []
         self.test_connection()
 
+    def get_configured_database(self) -> Optional[str]:
+        return self.service_connection.catalog
+
+    def get_database_names_raw(self) -> Iterable[str]:
+        for catalog in self.client.catalogs.list():
+            yield catalog.name
+
     @classmethod
     def create(cls, config_dict, metadata: OpenMetadata):
         config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
@@ -131,31 +139,31 @@ def get_database_names(self) -> Iterable[str]:
         if self.service_connection.catalog:
             yield self.service_connection.catalog
         else:
-            for catalog in self.client.catalogs.list():
+            for catalog_name in self.get_database_names_raw():
                 try:
                     database_fqn = fqn.build(
                         self.metadata,
                         entity_type=Database,
                         service_name=self.context.database_service.name.__root__,
-                        database_name=catalog.name,
+                        database_name=catalog_name,
                     )
                     if filter_by_database(
                         self.config.sourceConfig.config.databaseFilterPattern,
                         database_fqn
                         if self.config.sourceConfig.config.useFqnForFiltering
-                        else catalog.name,
+                        else catalog_name,
                     ):
                         self.status.filter(
                             database_fqn,
                             "Database (Catalog ID) Filtered Out",
                         )
                         continue
-                    yield catalog.name
+                    yield catalog_name
                 except Exception as exc:
                     self.status.failed(
                         StackTraceError(
-                            name=catalog.name,
-                            error=f"Unexpected exception to get database name [{catalog.name}]: {exc}",
+                            name=catalog_name,
+                            error=f"Unexpected exception to get database name [{catalog_name}]: {exc}",
                             stack_trace=traceback.format_exc(),
                         )
                     )
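Both sources now subclass MultiDBSource and expose the same two hooks, so catalog discovery is shared between the legacy and Unity Catalog paths. The mixin itself is not part of this patch; inferred from the two implementations, its contract is roughly:

    from abc import ABC, abstractmethod
    from typing import Iterable, Optional

    class MultiDBSource(ABC):
        """Plausible minimal shape of the mixin; not shown in this patch."""

        @abstractmethod
        def get_configured_database(self) -> Optional[str]:
            """Return the catalog pinned in the service connection, if any."""

        @abstractmethod
        def get_database_names_raw(self) -> Iterable[str]:
            """Yield every reachable catalog name, before FQN filtering."""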
diff --git a/ingestion/tests/unit/topology/database/test_databricks.py b/ingestion/tests/unit/topology/database/test_databricks.py
index 412937912889..fb5be497ac57 100644
--- a/ingestion/tests/unit/topology/database/test_databricks.py
+++ b/ingestion/tests/unit/topology/database/test_databricks.py
@@ -1,3 +1,18 @@
+# Copyright 2021 Collate
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Test databricks using the topology
+"""
+
 from unittest import TestCase
 from unittest.mock import patch
 
@@ -20,6 +35,7 @@
 from metadata.generated.schema.type.entityReference import EntityReference
 from metadata.ingestion.source.database.databricks.metadata import DatabricksSource
 
+# pylint: disable=line-too-long
 mock_databricks_config = {
     "source": {
         "type": "databricks",
@@ -230,12 +246,20 @@
 
 
 class DatabricksUnitTest(TestCase):
+    """
+    Databricks unit tests
+    """
+
     @patch(
         "metadata.ingestion.source.database.common_db_source.CommonDbSourceService.test_connection"
     )
-    def __init__(self, methodName, test_connection) -> None:
+    @patch(
+        "metadata.ingestion.source.database.databricks.legacy.metadata.DatabricksLegacySource._init_version"
+    )
+    def __init__(self, methodName, test_connection, db_init_version) -> None:
         super().__init__(methodName)
         test_connection.return_value = False
+        db_init_version.return_value = None
         self.config = OpenMetadataWorkflowConfig.parse_obj(mock_databricks_config)
         self.databricks_source = DatabricksSource.create(
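Because _init_version now issues a live query from the constructor, the unit test stubs it out alongside test_connection, as the hunk above shows. The same mechanics in isolation, with a simplified, hypothetical class standing in for the source:

    from unittest.mock import patch

    class Source:
        def __init__(self):
            self._init_version()  # would hit the database in the real source

        def _init_version(self):
            raise RuntimeError("needs a live Databricks connection")

    # Stub the probe exactly like the test's @patch decorator does.
    with patch.object(Source, "_init_version", return_value=None):
        source = Source()  # constructs cleanly, no database required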