Fix #13954: Fix ParseException for older version of databricks (#14015)
ulixius9 committed Nov 20, 2023
1 parent 9fe4a2f commit c6e861b
Showing 4 changed files with 115 additions and 40 deletions.
@@ -17,6 +17,7 @@

from databricks.sdk import WorkspaceClient
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DatabaseError
from sqlalchemy.inspection import inspect

from metadata.generated.schema.entity.automations.workflow import (
@@ -33,7 +34,6 @@
from metadata.ingestion.connections.test_connections import (
test_connection_engine_step,
test_connection_steps,
test_query,
)
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.databricks.client import DatabricksClient
@@ -42,6 +42,9 @@
DATABRICKS_GET_CATALOGS,
)
from metadata.utils.db_utils import get_host_from_host_port
from metadata.utils.logger import ingestion_logger

logger = ingestion_logger()


def get_connection_url(connection: DatabricksConnection) -> str:
@@ -84,6 +87,18 @@ def test_connection(
"""
client = DatabricksClient(service_connection)

def test_database_query(engine: Engine, statement: str):
"""
Method used to execute the given query and fetch a result
to test if user has access to the tables specified
in the sql statement
"""
try:
connection = engine.connect()
connection.execute(statement).fetchone()
except DatabaseError as soe:
logger.debug(f"Failed to fetch catalogs due to: {soe}")

if service_connection.useUnityCatalog:
table_obj = DatabricksTable()

@@ -121,7 +136,7 @@ def get_tables(connection: WorkspaceClient, table_obj: DatabricksTable):
"GetTables": inspector.get_table_names,
"GetViews": inspector.get_view_names,
"GetDatabases": partial(
test_query,
test_database_query,
engine=connection,
statement=DATABRICKS_GET_CATALOGS,
),
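Note on this first file: the new local `test_database_query` replaces the shared `test_query` helper so that a `ParseException` raised by older Databricks workspaces (surfaced by SQLAlchemy as a `DatabaseError`) is logged instead of failing the whole connection test. A minimal standalone sketch of the same wiring; the value of `DATABRICKS_GET_CATALOGS` is an assumption here, the real statement lives in `queries.py`:

```python
import logging
from functools import partial

from sqlalchemy.engine import Engine
from sqlalchemy.exc import DatabaseError

logger = logging.getLogger(__name__)

DATABRICKS_GET_CATALOGS = "SHOW CATALOGS"  # assumed value of the imported constant


def test_database_query(engine: Engine, statement: str) -> None:
    """Run a probe query; log, rather than raise, when the engine cannot parse it."""
    try:
        engine.connect().execute(statement).fetchone()
    except DatabaseError as err:
        logger.debug("Failed to fetch catalogs due to: %s", err)


def build_get_databases_step(engine: Engine) -> partial:
    """Zero-argument step, invoked later by the connection-test runner,
    shaped like the partial(...) entry in the steps dict above."""
    return partial(test_database_query, engine=engine, statement=DATABRICKS_GET_CATALOGS)
```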
@@ -13,11 +13,12 @@
import re
import traceback
from copy import deepcopy
from typing import Iterable
from typing import Iterable, Optional

from pyhive.sqlalchemy_hive import _type_map
from sqlalchemy import types, util
from sqlalchemy.engine import reflection
from sqlalchemy.exc import DatabaseError
from sqlalchemy.inspection import inspect
from sqlalchemy.sql.sqltypes import String
from sqlalchemy_databricks._dialect import DatabricksDialect
@@ -35,10 +36,13 @@
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.ingestion.source.database.databricks.queries import (
DATABRICKS_GET_CATALOGS,
DATABRICKS_GET_TABLE_COMMENTS,
DATABRICKS_VIEW_DEFINITIONS,
)
from metadata.ingestion.source.database.multi_db_source import MultiDBSource
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.filters import filter_by_database
from metadata.utils.logger import ingestion_logger
from metadata.utils.sqlalchemy_utils import (
@@ -158,7 +162,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
@reflection.cache
def get_schema_names(self, connection, **kw): # pylint: disable=unused-argument
# Equivalent to SHOW DATABASES
if kw.get("database"):
if kw.get("database") and kw.get("is_old_version") is not True:
connection.execute(f"USE CATALOG '{kw.get('database')}'")
return [row[0] for row in connection.execute("SHOW SCHEMAS")]
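To make the new guard's effect concrete, here is a self-contained toy; the `StubConnection` is purely an illustration, not part of the codebase. With `is_old_version=True` the `USE CATALOG` statement is never emitted, which is exactly what spares legacy workspaces the `ParseException`:

```python
class StubConnection:
    """Records statements instead of executing them (illustration only)."""

    def __init__(self):
        self.statements = []

    def execute(self, statement):
        self.statements.append(statement)
        return [("default",)]


def get_schema_names(connection, **kw):
    # same guard as the diff: skip USE CATALOG on legacy workspaces
    if kw.get("database") and kw.get("is_old_version") is not True:
        connection.execute(f"USE CATALOG '{kw.get('database')}'")
    return [row[0] for row in connection.execute("SHOW SCHEMAS")]


stub = StubConnection()
print(get_schema_names(stub, database="main", is_old_version=True))
# -> ['default']; stub.statements == ['SHOW SCHEMAS'], no USE CATALOG emitted
```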

@@ -238,13 +242,26 @@ def get_view_definition(
reflection.Inspector.get_schema_names = get_schema_names_reflection


class DatabricksLegacySource(CommonDbSourceService):
class DatabricksLegacySource(CommonDbSourceService, MultiDBSource):
"""
Implements the necessary methods to extract
Database metadata from Databricks Source using
the legacy hive metastore method
"""

def __init__(self, config: WorkflowSource, metadata: OpenMetadata):
super().__init__(config, metadata)
self.is_older_version = False
self._init_version()

def _init_version(self):
try:
self.connection.execute(DATABRICKS_GET_CATALOGS).fetchone()
self.is_older_version = False
except DatabaseError as soe:
logger.debug(f"Failed to fetch catalogs due to: {soe}")
self.is_older_version = True

@classmethod
def create(cls, config_dict, metadata: OpenMetadata):
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
@@ -268,44 +285,55 @@ def set_inspector(self, database_name: str) -> None:
self.engine = get_connection(new_service_connection)
self.inspector = inspect(self.engine)

def get_configured_database(self) -> Optional[str]:
return self.service_connection.catalog

def get_database_names_raw(self) -> Iterable[str]:
if not self.is_older_version:
results = self.connection.execute(DATABRICKS_GET_CATALOGS)
for res in results:
if res:
row = list(res)
yield row[0]
else:
yield DEFAULT_DATABASE
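Together, `_init_version` and `get_database_names_raw` implement a graceful fallback: probe `SHOW CATALOGS` once at construction time, and if the workspace cannot parse it, behave as though there were a single default catalog. A condensed sketch; both constant values are assumptions, in the codebase they come from `queries.py` and `constants.py`:

```python
from typing import Iterable

from sqlalchemy.engine import Connection
from sqlalchemy.exc import DatabaseError

DATABRICKS_GET_CATALOGS = "SHOW CATALOGS"  # assumed value of the imported constant
DEFAULT_DATABASE = "default"               # assumed value of the imported constant


def detect_older_version(connection: Connection) -> bool:
    """One-time probe: a parse failure marks a legacy workspace."""
    try:
        connection.execute(DATABRICKS_GET_CATALOGS).fetchone()
        return False
    except DatabaseError:
        return True


def database_names_raw(connection: Connection, is_older_version: bool) -> Iterable[str]:
    """Yield real catalog names, or the default catalog on legacy workspaces."""
    if is_older_version:
        yield DEFAULT_DATABASE
    else:
        for row in connection.execute(DATABRICKS_GET_CATALOGS):
            if row:
                yield row[0]
```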

def get_database_names(self) -> Iterable[str]:
configured_catalog = self.service_connection.__dict__.get("catalog")
configured_catalog = self.service_connection.catalog
if configured_catalog:
self.set_inspector(database_name=configured_catalog)
yield configured_catalog
else:
results = self.connection.execute("SHOW CATALOGS")
for res in results:
if res:
new_catalog = res[0]
database_fqn = fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.database_service.name.__root__,
database_name=new_catalog,
for new_catalog in self.get_database_names_raw():
database_fqn = fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.database_service.name.__root__,
database_name=new_catalog,
)
if filter_by_database(
self.source_config.databaseFilterPattern,
database_fqn
if self.source_config.useFqnForFiltering
else new_catalog,
):
self.status.filter(database_fqn, "Database Filtered Out")
continue
try:
self.set_inspector(database_name=new_catalog)
yield new_catalog
except Exception as exc:
logger.error(traceback.format_exc())
logger.warning(
f"Error trying to process database {new_catalog}: {exc}"
)

def get_raw_database_schema_names(self) -> Iterable[str]:
if self.service_connection.__dict__.get("databaseSchema"):
yield self.service_connection.databaseSchema
else:
for schema_name in self.inspector.get_schema_names(
database=self.context.database.name.__root__
database=self.context.database.name.__root__,
is_old_version=self.is_older_version,
):
yield schema_name
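The rewritten loop above gives the legacy source the same shape the Unity Catalog source uses below: build an FQN per catalog, apply the filter pattern, and only then bind an inspector. A schematic of that control flow; the callables stand in for the real `fqn.build`, `filter_by_database`, and `set_inspector`:

```python
from typing import Callable, Iterable


def iter_databases(
    raw_names: Iterable[str],
    build_fqn: Callable[[str], str],
    is_filtered: Callable[[str], bool],
    set_inspector: Callable[[str], None],
) -> Iterable[str]:
    """Schematic of the shared loop: filter first, then bind an inspector."""
    for name in raw_names:
        database_fqn = build_fqn(name)
        if is_filtered(database_fqn):
            continue  # the real code also records status.filter(...)
        try:
            set_inspector(name)
            yield name
        except Exception as exc:  # the real code logs a traceback and a warning
            print(f"Error trying to process database {name}: {exc}")
```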
@@ -61,6 +61,7 @@
ForeignConstrains,
Type,
)
from metadata.ingestion.source.database.multi_db_source import MultiDBSource
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.ingestion.source.models import TableView
from metadata.utils import fqn
@@ -84,7 +85,7 @@ def from_dict(cls, dct: Dict[str, Any]) -> "TableConstraintList":
TableConstraintList.from_dict = from_dict


class DatabricksUnityCatalogSource(DatabaseServiceSource):
class DatabricksUnityCatalogSource(DatabaseServiceSource, MultiDBSource):
"""
Implements the necessary methods to extract
Database metadata from Databricks Source using
@@ -107,6 +108,13 @@ def __init__(self, config: WorkflowSource, metadata: OpenMetadata):
self.table_constraints = []
self.test_connection()

def get_configured_database(self) -> Optional[str]:
return self.service_connection.catalog

def get_database_names_raw(self) -> Iterable[str]:
for catalog in self.client.catalogs.list():
yield catalog.name
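On the Unity Catalog side, the raw names come from the Databricks SDK rather than from SQL. A minimal sketch of the same call with placeholder credentials:

```python
from typing import Iterable

from databricks.sdk import WorkspaceClient


def list_catalog_names(host: str, token: str) -> Iterable[str]:
    """Yield Unity Catalog names via the Databricks SDK."""
    client = WorkspaceClient(host=host, token=token)
    for catalog in client.catalogs.list():
        yield catalog.name


# usage (placeholders):
# for name in list_catalog_names("https://<workspace>.cloud.databricks.com", "<pat>"):
#     print(name)
```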

@classmethod
def create(cls, config_dict, metadata: OpenMetadata):
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
@@ -131,31 +139,31 @@ def get_database_names(self) -> Iterable[str]:
if self.service_connection.catalog:
yield self.service_connection.catalog
else:
for catalog in self.client.catalogs.list():
for catalog_name in self.get_database_names_raw():
try:
database_fqn = fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.database_service.name.__root__,
database_name=catalog.name,
database_name=catalog_name,
)
if filter_by_database(
self.config.sourceConfig.config.databaseFilterPattern,
database_fqn
if self.config.sourceConfig.config.useFqnForFiltering
else catalog.name,
else catalog_name,
):
self.status.filter(
database_fqn,
"Database (Catalog ID) Filtered Out",
)
continue
yield catalog.name
yield catalog_name
except Exception as exc:
self.status.failed(
StackTraceError(
name=catalog.name,
error=f"Unexpected exception to get database name [{catalog.name}]: {exc}",
name=catalog_name,
error=f"Unexpected exception to get database name [{catalog_name}]: {exc}",
stack_trace=traceback.format_exc(),
)
)
26 changes: 25 additions & 1 deletion ingestion/tests/unit/topology/database/test_databricks.py
@@ -1,3 +1,18 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test databricks using the topology
"""

from unittest import TestCase
from unittest.mock import patch

@@ -20,6 +35,7 @@
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.source.database.databricks.metadata import DatabricksSource

# pylint: disable=line-too-long
mock_databricks_config = {
"source": {
"type": "databricks",
@@ -230,12 +246,20 @@


class DatabricksUnitTest(TestCase):
"""
Databricks unit tests
"""

@patch(
"metadata.ingestion.source.database.common_db_source.CommonDbSourceService.test_connection"
)
def __init__(self, methodName, test_connection) -> None:
@patch(
"metadata.ingestion.source.database.databricks.legacy.metadata.DatabricksLegacySource._init_version"
)
def __init__(self, methodName, test_connection, db_init_version) -> None:
super().__init__(methodName)
test_connection.return_value = False
db_init_version.return_value = None

self.config = OpenMetadataWorkflowConfig.parse_obj(mock_databricks_config)
self.databricks_source = DatabricksSource.create(
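A subtlety worth remembering when reading the stacked decorators above: `unittest.mock` applies the bottom-most `@patch` first, so it supplies the first mock argument after `methodName`. A self-contained sketch of that ordering; the `Service` class is an illustration, not part of the codebase:

```python
from unittest import TestCase
from unittest.mock import patch


class Service:
    def connect(self):
        raise RuntimeError("real connection attempted")

    def probe_version(self):
        raise RuntimeError("real probe attempted")


class ServiceTest(TestCase):
    @patch.object(Service, "connect")        # top decorator -> second mock argument
    @patch.object(Service, "probe_version")  # bottom decorator -> first mock argument
    def test_patch_order(self, probe_mock, connect_mock):
        probe_mock.return_value = None
        connect_mock.return_value = False
        svc = Service()
        self.assertFalse(svc.connect())
        self.assertIsNone(svc.probe_version())
```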
