Merge branch 'main' into issue-17869
harshsoni2024 committed Sep 20, 2024
2 parents 9299257 + 043b18e commit ae6e983
Showing 11 changed files with 236 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/playwright-mysql-e2e.yml
@@ -95,7 +95,7 @@ jobs:

      - name: Install dependencies
        working-directory: openmetadata-ui/src/main/resources/ui/
-        run: yarn --frozen-lockfile
+        run: yarn --ignore-scripts --frozen-lockfile
      - name: Install Playwright Browsers
        run: npx playwright@1.44.1 install --with-deps
      - name: Run Playwright tests
2 changes: 1 addition & 1 deletion .github/workflows/playwright-postgresql-e2e.yml
@@ -95,7 +95,7 @@ jobs:

      - name: Install dependencies
        working-directory: openmetadata-ui/src/main/resources/ui/
-        run: yarn --frozen-lockfile
+        run: yarn --ignore-scripts --frozen-lockfile
      - name: Install Playwright Browsers
        run: npx playwright@1.44.1 install --with-deps
      - name: Run Playwright tests
@@ -27,6 +27,7 @@
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.source.connections import get_connection
+from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn


@@ -168,7 +169,7 @@ def get_data_diff_url(service_url: str, table_fqn) -> str:
        table_fqn
    )
    # path needs to include the database AND schema in some of the connectors
-    if kwargs["scheme"] in ["mssql"]:
+    if kwargs["scheme"] in {Dialects.MSSQL, Dialects.Snowflake}:
        kwargs["path"] = f"/{database}/{schema}"
    return url._replace(**kwargs).geturl()

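This hunk extends the MSSQL-only special case to Snowflake: for both engines, the data-diff URL must carry `/{database}/{schema}` in its path. Below is a minimal, self-contained sketch of that rewrite, with plain scheme strings standing in for the `Dialects` constants (an assumption; the real code compares against `Dialects.MSSQL` / `Dialects.Snowflake`):

from urllib.parse import urlparse

def data_diff_url(service_url: str, database: str, schema: str) -> str:
    url = urlparse(service_url)
    # MSSQL and Snowflake expect the database AND schema in the URL path
    if url.scheme in {"mssql", "snowflake"}:
        url = url._replace(path=f"/{database}/{schema}")
    return url.geturl()

print(data_diff_url("snowflake://user:pw@myaccount", "E2E_DB", "E2E_TEST"))
# snowflake://user:pw@myaccount/E2E_DB/E2E_TEST
print(data_diff_url("postgresql://user:pw@host:5432/db", "x", "y"))
# postgresql://user:pw@host:5432/db  (other schemes keep their path)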
@@ -11,8 +11,11 @@
"""
Generic source to build SQL connectors.
"""
+import math
+import time
import traceback
from abc import ABC
+from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Any, Iterable, List, Optional, Tuple, Union, cast

@@ -60,6 +63,7 @@
from metadata.ingestion.connections.session import create_and_bind_thread_safe_session
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.models.ometa_lineage import OMetaLineageRequest
+from metadata.ingestion.models.topology import Queue
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import (
get_connection,
@@ -73,6 +77,7 @@
from metadata.utils import fqn
from metadata.utils.db_utils import get_view_lineage
from metadata.utils.execution_time_tracker import (
+    ExecutionTimeTrackerContextMap,
    calculate_execution_time,
    calculate_execution_time_generator,
)
@@ -583,9 +588,78 @@ def yield_table(
                )
            )

-    @calculate_execution_time_generator()
-    def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]:
-        logger.info("Processing Lineage for Views")
+    def multithread_process_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]:
+        """Process view lineage in parallel across a pool of worker threads"""
+
+        views_list = list(self.context.get().table_views or [])
+        views_length = len(views_list)
+
+        if views_length != 0:
+            chunksize = int(math.ceil(views_length / self.source_config.threads))
+            chunks = [
+                views_list[i : i + chunksize] for i in range(0, views_length, chunksize)
+            ]
+
+            thread_pool = ThreadPoolExecutor(max_workers=self.source_config.threads)
+            queue = Queue()
+
+            futures = [
+                thread_pool.submit(
+                    self._process_view_def_chunk,
+                    chunk,
+                    queue,
+                    self.context.get_current_thread_id(),
+                )
+                for chunk in chunks
+            ]
+
+            while True:
+                if queue.has_tasks():
+                    yield from queue.process()
+
+                else:
+                    if not futures:
+                        break
+
+                    for i, future in enumerate(futures):
+                        if future.done():
+                            future.result()
+                            futures.pop(i)
+
+                    time.sleep(0.01)
+
+    def _process_view_def_chunk(
+        self, chunk: List[TableView], queue: Queue, thread_id: int
+    ) -> None:
+        """
+        Process a chunk of view definitions
+        """
+        self.context.copy_from(thread_id)
+        ExecutionTimeTrackerContextMap().copy_from_parent(thread_id)
+        for view in [v for v in chunk if v.view_definition is not None]:
+            for lineage in get_view_lineage(
+                view=view,
+                metadata=self.metadata,
+                service_name=self.context.get().database_service,
+                connection_type=self.service_connection.type.value,
+                timeout_seconds=self.source_config.queryParsingTimeoutLimit,
+            ):
+                if lineage.right is not None:
+                    queue.put(
+                        Either(
+                            right=OMetaLineageRequest(
+                                lineage_request=lineage.right,
+                                override_lineage=self.source_config.overrideViewLineage,
+                            )
+                        )
+                    )
+                else:
+                    queue.put(lineage)
+
+    def _process_view_def_serial(self) -> Iterable[Either[OMetaLineageRequest]]:
+        """
+        Process view definitions serially
+        """
        for view in [
            v for v in self.context.get().table_views if v.view_definition is not None
        ]:
@@ -606,6 +680,14 @@ def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]:
                else:
                    yield lineage

+    @calculate_execution_time_generator()
+    def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]:
+        logger.info("Processing Lineage for Views")
+        if self.source_config.threads > 1:
+            yield from self.multithread_process_view_lineage()
+        else:
+            yield from self._process_view_def_serial()
+
    def _get_foreign_constraints(self, foreign_columns) -> List[TableConstraint]:
        """
        Search the referred table for foreign constraints
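In outline, `multithread_process_view_lineage` splits the views into one chunk per configured thread, each worker parses its chunk and pushes `Either` results onto a shared queue, and the generator drains that queue while polling the futures. (Note that `futures.pop(i)` inside `enumerate` skips the element that slides into position `i`; that is harmless here only because the `while True` loop re-scans the list on every pass.) Here is a runnable sketch of the same fan-out/drain pattern, with the stdlib `queue.Queue` standing in for OpenMetadata's topology `Queue` (an assumption — the real class exposes `has_tasks()`/`process()` instead) and integer squaring standing in for `get_view_lineage`:

import math
import queue
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List

def _process_chunk(chunk: List[int], results: queue.Queue) -> None:
    # Worker: stand-in for _process_view_def_chunk
    for item in chunk:
        results.put(item * item)

def parallel_process(items: List[int], threads: int = 4) -> Iterable[int]:
    if not items:  # mirrors the views_length != 0 guard above
        return
    chunksize = math.ceil(len(items) / threads)
    chunks = [items[i : i + chunksize] for i in range(0, len(items), chunksize)]
    results: queue.Queue = queue.Queue()
    with ThreadPoolExecutor(max_workers=threads) as pool:
        futures = [pool.submit(_process_chunk, chunk, results) for chunk in chunks]
        # Drain while workers run, so the caller sees results incrementally
        while futures or not results.empty():
            try:
                yield results.get(timeout=0.01)
            except queue.Empty:
                for future in [f for f in futures if f.done()]:
                    future.result()  # surface any worker exception
                futures = [f for f in futures if not f.done()]

print(sorted(parallel_process(list(range(10)))))  # [0, 1, 4, ..., 81]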
@@ -65,9 +65,6 @@
from metadata.ingestion.source.database.incremental_metadata_extraction import (
    IncrementalConfig,
)
-from metadata.ingestion.source.database.life_cycle_query_mixin import (
-    LifeCycleQueryMixin,
-)
from metadata.ingestion.source.database.multi_db_source import MultiDBSource
from metadata.ingestion.source.database.snowflake.models import (
    STORED_PROC_LANGUAGE_MAP,
@@ -148,7 +145,6 @@


class SnowflakeSource(
-    LifeCycleQueryMixin,
    StoredProcedureMixin,
    ExternalTableLineageMixin,
    CommonDbSourceService,
33 changes: 33 additions & 0 deletions ingestion/tests/cli_e2e/base/config_builders/builders.py
@@ -16,6 +16,10 @@

from copy import deepcopy

+from metadata.generated.schema.metadataIngestion.testSuitePipeline import (
+    TestSuiteConfigType,
+)

from ..e2e_types import E2EType


@@ -70,6 +74,34 @@ def build(self) -> dict:
        return self.config


+class DataQualityConfigBuilder(BaseBuilder):
+    """Builder class for the data quality config"""
+
+    # pylint: disable=invalid-name
+    def __init__(self, config: dict, config_args: dict) -> None:
+        super().__init__(config, config_args)
+        self.test_case_definitions = self.config_args.get("test_case_definitions", [])
+        self.entity_fqn = self.config_args.get("entity_fqn", [])
+
+    # pylint: enable=invalid-name
+
+    def build(self) -> dict:
+        """build data quality config"""
+        del self.config["source"]["sourceConfig"]["config"]
+        self.config["source"]["sourceConfig"] = {
+            "config": {
+                "type": TestSuiteConfigType.TestSuite.value,
+                "entityFullyQualifiedName": self.entity_fqn,
+            },
+        }
+
+        self.config["processor"] = {
+            "type": "orm-test-runner",
+            "config": {"testCases": self.test_case_definitions},
+        }
+        return self.config

class SchemaConfigBuilder(BaseBuilder):
"""Builder for schema filter config"""

@@ -147,6 +179,7 @@ def builder_factory(builder, config: dict, config_args: dict):
"""Factory method to return the builder class"""
builder_classes = {
E2EType.PROFILER.value: ProfilerConfigBuilder,
E2EType.DATA_QUALITY.value: DataQualityConfigBuilder,
E2EType.INGEST_DB_FILTER_SCHEMA.value: SchemaConfigBuilder,
E2EType.INGEST_DB_FILTER_TABLE.value: TableConfigBuilder,
E2EType.INGEST_DB_FILTER_MIX.value: MixConfigBuilder,
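For orientation, a hypothetical use of the new builder (the input dict is illustrative, and `BaseBuilder` is assumed to simply store `config` and `config_args`, which is all `build()` relies on):

base_config = {
    "source": {
        "type": "snowflake",
        "sourceConfig": {"config": {"type": "DatabaseMetadata"}},
    }
}

builder = DataQualityConfigBuilder(
    base_config,
    {
        "entity_fqn": "e2e_snowflake.E2E_DB.E2E_TEST.PERSONS",
        "test_case_definitions": [
            {"name": "snowflake_data_diff", "testDefinitionName": "tableDiff"}
        ],
    },
)

# build() swaps the metadata sourceConfig for a TestSuite one and wires in
# the orm-test-runner processor:
#
# {"source": {"type": "snowflake",
#             "sourceConfig": {"config": {
#                 "type": "TestSuite",
#                 "entityFullyQualifiedName": "e2e_snowflake.E2E_DB.E2E_TEST.PERSONS"}}},
#  "processor": {"type": "orm-test-runner",
#                "config": {"testCases": [{"name": "snowflake_data_diff",
#                                          "testDefinitionName": "tableDiff"}]}}}
config = builder.build()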
1 change: 1 addition & 0 deletions ingestion/tests/cli_e2e/base/e2e_types.py
@@ -24,6 +24,7 @@ class E2EType(Enum):
    INGEST = "ingest"
    PROFILER = "profiler"
    PROFILER_PROCESSOR = "profiler-processor"
+    DATA_QUALITY = "test"
    INGEST_DB_FILTER_SCHEMA = "ingest-db-filter-schema"
    INGEST_DB_FILTER_TABLE = "ingest-db-filter-table"
    INGEST_DB_FILTER_MIX = "ingest-db-filter-mix"
61 changes: 61 additions & 0 deletions ingestion/tests/cli_e2e/base/test_cli_db.py
@@ -17,8 +17,13 @@
from unittest import TestCase

import pytest
+from pydantic import TypeAdapter

+from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects
+from metadata.data_quality.api.models import TestCaseDefinition
from metadata.generated.schema.entity.data.table import Table
+from metadata.generated.schema.tests.basic import TestCaseResult
+from metadata.generated.schema.tests.testCase import TestCase as OMTestCase
from metadata.ingestion.api.status import Status

from .e2e_types import E2EType
@@ -208,6 +213,50 @@ def test_profiler_with_time_partition(self) -> None:
            sink_status,
        )

+    @pytest.mark.order(12)
+    def test_data_quality(self) -> None:
+        """12. Test data quality for the connector"""
+        if self.get_data_quality_table() is None:
+            return
+        self.delete_table_and_view()
+        self.create_table_and_view()
+        table: Table = self.openmetadata.get_by_name(
+            Table, self.get_data_quality_table(), nullable=False
+        )
+        self.build_config_file()
+        self.run_command()
+        test_case_definitions = self.get_test_case_definitions()
+        self.build_config_file(
+            E2EType.DATA_QUALITY,
+            {
+                "entity_fqn": table.fullyQualifiedName.root,
+                "test_case_definitions": TypeAdapter(
+                    List[TestCaseDefinition]
+                ).dump_python(test_case_definitions),
+            },
+        )
+        result = self.run_command("test")
+        sink_status, source_status = self.retrieve_statuses(result)
+        self.assert_status_for_data_quality(source_status, sink_status)
+        test_case_entities = [
+            self.openmetadata.get_by_name(
+                OMTestCase,
+                ".".join([table.fullyQualifiedName.root, tcd.name]),
+                fields=["*"],
+                nullable=False,
+            )
+            for tcd in test_case_definitions
+        ]
+        expected = self.get_expected_test_case_results()
+        try:
+            for test_case, expected_result in zip(test_case_entities, expected):
+                assert_equal_pydantic_objects(expected_result, test_case.testCaseResult)
+        finally:
+            for tc in test_case_entities:
+                self.openmetadata.delete(
+                    OMTestCase, tc.id, recursive=True, hard_delete=True
+                )

    def retrieve_table(self, table_name_fqn: str) -> Table:
        return self.openmetadata.get_by_name(entity=Table, fqn=table_name_fqn)
@@ -346,3 +395,15 @@ def get_profiler_processor_config(self, config: dict) -> dict:
"config": {"tableConfig": [config]},
}
}

def get_data_quality_table(self):
return None

def get_test_case_definitions(self) -> List[TestCaseDefinition]:
pass

def get_expected_test_case_results(self) -> List[TestCaseResult]:
pass

def assert_status_for_data_quality(self, source_status, sink_status):
pass
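The base test serializes `TestCaseDefinition` objects through pydantic's `TypeAdapter` so they can be embedded in the generated workflow config as plain dicts. A small sketch of that round-trip, with a simplified stand-in model (the real `TestCaseDefinition` lives in `metadata.data_quality.api.models`):

from typing import List
from pydantic import BaseModel, TypeAdapter

class TestCaseDefinition(BaseModel):
    """Simplified stand-in for the real TestCaseDefinition model"""
    name: str
    testDefinitionName: str
    computePassedFailedRowCount: bool = False

adapter = TypeAdapter(List[TestCaseDefinition])
definitions = [TestCaseDefinition(name="diff", testDefinitionName="tableDiff")]

# dump_python turns the models into plain dicts, ready for a YAML/JSON config
as_dicts = adapter.dump_python(definitions)
assert as_dicts == [
    {"name": "diff", "testDefinitionName": "tableDiff", "computePassedFailedRowCount": False}
]

# validate_python rebuilds the typed objects from those dicts on the way back
assert adapter.validate_python(as_dicts) == definitions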
30 changes: 29 additions & 1 deletion ingestion/tests/cli_e2e/test_cli_snowflake.py
@@ -20,9 +20,12 @@

from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects
from metadata.generated.schema.entity.data.table import DmlOperationType, SystemProfile
+from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus
+from metadata.generated.schema.tests.testCase import TestCaseParameterValue
from metadata.generated.schema.type.basic import Timestamp
from metadata.ingestion.api.status import Status

+from ...src.metadata.data_quality.api.models import TestCaseDefinition
from .base.e2e_types import E2EType
from .common.test_cli_db import CliCommonDB
from .common_e2e_sqa_mixins import SQACommonMethods
@@ -65,7 +68,7 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):

    insert_data_queries: List[str] = [
        "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1,'Peter Parker');",
-        "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1, 'Clark Kent');",
+        "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (2, 'Clark Kent');",
        "INSERT INTO e2e_test.e2e_table (varchar_column, int_column) VALUES ('e2e_test.e2e_table', 1);",
        "INSERT INTO public.e2e_table (varchar_column, int_column) VALUES ('public.e2e_table', 1);",
        "INSERT INTO e2e_table (varchar_column, int_column) VALUES ('e2e_table', 1);",
@@ -316,3 +319,28 @@ def wait_for_query_log(cls, timeout=600):
            )
            if (datetime.now().timestamp() - start) > timeout:
                raise TimeoutError(f"Query log not updated for {timeout} seconds")
+
+    def get_data_quality_table(self):
+        return "e2e_snowflake.E2E_DB.E2E_TEST.PERSONS"
+
+    def get_test_case_definitions(self) -> List[TestCaseDefinition]:
+        return [
+            TestCaseDefinition(
+                name="snowflake_data_diff",
+                testDefinitionName="tableDiff",
+                computePassedFailedRowCount=True,
+                parameterValues=[
+                    TestCaseParameterValue(
+                        name="table2",
+                        value=self.get_data_quality_table(),
+                    ),
+                    TestCaseParameterValue(
+                        name="keyColumns",
+                        value='["PERSON_ID"]',
+                    ),
+                ],
+            )
+        ]
+
+    def get_expected_test_case_results(self):
+        return [TestCaseResult(testCaseStatus=TestCaseStatus.Success)]