MINOR -- Fix DQ Partition Issue (open-metadata#18641)
* fix: renamed `random_sample` to `get_dataset` and changed dunder method access for the SQA Table object

* fix: removed handle_partition decorator

* fix: fixed DQ partition issue + moved to `tablesample` method

* style: ran python linting

* style: fix python format check issues

* feat: added postgres tablesample

* style: ran python linting

* fix: sampling delta

* fix: merge conflicts

* fix: resolved conflicts

* style: ran python linting

* fix: patch orm call in test case

* fix: mock build_table_orm call in tests

* fix: test case failures and errors

* fix: removed unused import

* fix: patch typo

* fix: trino table schema retrieval

* fix: remove tuple context manager for 3.8 test support
TeddyCr authored Nov 27, 2024
1 parent 9432cb5 commit 5869906
Showing 56 changed files with 1,417 additions and 676 deletions.
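
Several of the bullets above mention moving data-quality sampling to a `tablesample` approach. For context: SQLAlchemy ships a `tablesample()` construct that renders the SQL-standard `TABLESAMPLE` clause, which PostgreSQL supports natively. The snippet below is only a minimal, generic sketch of that construct — the table, the 10% rate, and the connection string are placeholders, not code from this PR — and assumes SQLAlchemy 1.4+:

```python
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select, tablesample

metadata = MetaData()
# Placeholder table for illustration only.
users = Table("users", metadata, Column("id", Integer, primary_key=True))

# Renders roughly: SELECT alias.id FROM users AS alias TABLESAMPLE system(:pct)
sampled = tablesample(users, 10, name="alias")
stmt = select(sampled.c.id)

engine = create_engine("postgresql+psycopg2://user:pass@localhost/demo")  # placeholder DSN
with engine.connect() as conn:
    rows = conn.execute(stmt).fetchall()
```

Dialects without `TABLESAMPLE` support would need a different sampling path, which is presumably why sampling stays dialect-specific in the sampler implementations.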
@@ -54,17 +54,14 @@ def __init__(
)

(
self.table_sample_query,
self.table_sample_config,
self.table_partition_config,
self.sample_query,
self.profile_sample_config,
self.partition_details,
) = self._get_table_config()

# add partition logic to test suite
self.dfs = self.sampler.table
if self.dfs and self.table_partition_config:
self.dfs = self.get_partitioned_df(self.dfs)
self.dataset = self.sampler.get_dataset()

def _get_validator_builder(
self, test_case: TestCase, entity_type: str
) -> IValidatorBuilder:
return PandasValidatorBuilder(self.dfs, test_case, entity_type)
return PandasValidatorBuilder(self.dataset, test_case, entity_type)
@@ -51,7 +51,6 @@ def __init__(
ometa_client: OpenMetadata,
sampler: SamplerInterface,
table_entity: Table = None,
orm_table=None,
):
super().__init__(
service_connection_config,
@@ -60,8 +59,6 @@
table_entity,
)
self.create_session()
self._table = orm_table

(
self.table_sample_query,
self.table_sample_config,
@@ -76,7 +73,7 @@ def create_session(self):
)

@property
def sample(self) -> Union[DeclarativeMeta, AliasedClass]:
def dataset(self) -> Union[DeclarativeMeta, AliasedClass]:
"""_summary_
Returns:
@@ -87,7 +84,7 @@ def sample(self) -> Union[DeclarativeMeta, AliasedClass]:
"You must create a sampler first `<instance>.create_sampler(...)`."
)

return self.sampler.random_sample()
return self.sampler.get_dataset()

@property
def runner(self) -> QueryRunner:
@@ -98,23 +95,13 @@ def runner(self) -> QueryRunner:
"""
return self._runner

@property
def table(self):
"""getter method for the table object
Returns:
Table: table object
"""
return self._table

def _create_runner(self) -> None:
def _create_runner(self) -> QueryRunner:
"""Create a QueryRunner Instance"""

return cls_timeout(TEN_MIN)(
QueryRunner(
session=self.session,
table=self.table,
sample=self.sample,
dataset=self.dataset,
partition_details=self.table_partition_config,
profile_sample_query=self.table_sample_query,
)
@@ -15,9 +15,6 @@
from copy import deepcopy
from typing import Optional, cast

from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeMeta

from metadata.data_quality.interface.test_suite_interface import TestSuiteInterface
from metadata.data_quality.runner.core import DataTestsRunner
from metadata.generated.schema.entity.data.table import Table
@@ -31,10 +28,8 @@
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.orm.converter.base import ometa_to_sqa_orm
from metadata.sampler.models import SampleConfig
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.constants import NON_SQA_DATABASE_CONNECTIONS
from metadata.utils.profiler_utils import get_context_entities
from metadata.utils.service_spec.service_spec import (
import_sampler_class,
@@ -96,12 +91,6 @@ def _copy_service_config(

return config_copy

def _build_table_orm(self, entity: Table) -> Optional[DeclarativeMeta]:
"""Build the ORM table if needed for the sampler and profiler interfaces"""
if self.service_conn_config.type.value not in NON_SQA_DATABASE_CONNECTIONS:
return ometa_to_sqa_orm(entity, self.ometa_client, MetaData())
return None

def create_data_quality_interface(self) -> TestSuiteInterface:
"""Create data quality interface
@@ -122,7 +111,6 @@ def create_data_quality_interface(self) -> TestSuiteInterface:
source_config_type=self.service_conn_config.type.value,
)
# This is shared between the sampler and DQ interfaces
_orm = self._build_table_orm(self.entity)
sampler_interface: SamplerInterface = sampler_class.create(
service_connection_config=self.service_conn_config,
ometa_client=self.ometa_client,
@@ -134,15 +122,13 @@ def create_data_quality_interface(self) -> TestSuiteInterface:
profile_sample_type=self.source_config.profileSampleType,
sampling_method_type=self.source_config.samplingMethodType,
),
orm_table=_orm,
)

self.interface: TestSuiteInterface = test_suite_class.create(
self.service_conn_config,
self.ometa_client,
sampler_interface,
self.entity,
orm_table=_orm,
)
return self.interface

@@ -48,6 +48,7 @@
TestCaseStatus,
TestResultValue,
)
from metadata.profiler.metrics.registry import Metrics
from metadata.profiler.orm.converter.base import build_orm_col
from metadata.profiler.orm.functions.md5 import MD5
from metadata.profiler.orm.functions.substr import Substr
@@ -430,7 +431,7 @@ def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
self.runtime_params.table_profile_config.profileSampleType
== ProfileSampleType.ROWS
):
row_count = self.get_row_count()
row_count = self.get_total_row_count()
if row_count is None:
raise ValueError("Row count is required for ROWS profile sample type")
return int(
@@ -634,5 +635,13 @@ def get_case_sensitive(self):
)

def get_row_count(self) -> Optional[int]:
self.runner._sample = None # pylint: disable=protected-access
return self._compute_row_count(self.runner, None)

def get_total_row_count(self) -> Optional[int]:
row_count = Metrics.ROW_COUNT()
try:
row = self.runner.select_first_from_table(row_count.fn())
return dict(row).get(Metrics.ROW_COUNT.name)
except Exception as e:
logger.error(f"Error getting row count: {e}")
return None
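
For context, the new `get_total_row_count` above resolves the `ROW_COUNT` metric against the full table rather than the sampled dataset, which the `ROWS`-based nounce calculation earlier in this file relies on. As a rough standalone illustration of what such a count boils down to — assuming the metric wraps a plain `COUNT(*)`, and using a throwaway in-memory table rather than anything from this PR:

```python
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, func, select

metadata = MetaData()
orders = Table("orders", metadata, Column("id", Integer, primary_key=True))  # placeholder table

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

with engine.connect() as conn:
    # Count over the full table, ignoring any sampling or partition filters.
    total_rows = conn.execute(select(func.count()).select_from(orders)).scalar()
    print(total_rows)
```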
@@ -52,7 +52,7 @@ def _run_results(self, column_name: str, range_type: str, range_interval: int):
date_or_datetime_fn = dispatch_to_date_or_datetime(
range_interval,
text(range_type),
get_partition_col_type(column_name.name, self.runner.table.__table__.c), # type: ignore
get_partition_col_type(column_name.name, self.runner.table.c), # type: ignore
)

return dict(
33 changes: 2 additions & 31 deletions ingestion/src/metadata/mixins/pandas/pandas_mixin.py
@@ -13,7 +13,6 @@
Interfaces with database for all database engine
supporting sqlalchemy abstraction layer
"""
import math
import random
from typing import cast

@@ -23,7 +22,6 @@
from metadata.generated.schema.entity.data.table import (
PartitionIntervalTypes,
PartitionProfilerConfig,
ProfileSampleType,
)
from metadata.readers.dataframe.models import DatalakeTableSchemaWrapper
from metadata.utils.datalake.datalake_utils import fetch_dataframe
@@ -81,9 +79,7 @@ def get_partitioned_df(self, dfs):
for df in dfs
]

def return_ometa_dataframes_sampled(
self, service_connection_config, client, table, profile_sample_config
):
def get_dataframes(self, service_connection_config, client, table):
"""
returns sampled ometa dataframes
"""
@@ -94,35 +90,10 @@ def return_ometa_dataframes_sampled(
key=table.name.root,
bucket_name=table.databaseSchema.name,
file_extension=table.fileFormat,
separator=None,
),
)
if data:
random.shuffle(data)
# sampling data based on profiler config (if any)
if hasattr(profile_sample_config, "profile_sample"):
if (
profile_sample_config.profile_sample_type
== ProfileSampleType.PERCENTAGE
):
return [
df.sample(
frac=profile_sample_config.profile_sample / 100,
random_state=random.randint(0, 100),
replace=True,
)
for df in data
]
if profile_sample_config.profile_sample_type == ProfileSampleType.ROWS:
sample_rows_per_chunk: int = math.floor(
profile_sample_config.profile_sample / len(data)
)
return [
df.sample(
n=sample_rows_per_chunk,
random_state=random.randint(0, 100),
replace=True,
)
for df in data
]
return data
raise TypeError(f"Couldn't fetch {table.name.root}")
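
The sampling branches deleted from this mixin above (percentage- or row-based `df.sample(...)` over each dataframe chunk) now live with the sampler rather than the pandas mixin. For reference, the general pandas pattern those branches implemented looks roughly like this — the chunk data, the 10% rate, and the 50-row budget are made up for illustration:

```python
import math
import random

import pandas as pd

# Placeholder chunked data standing in for the dataframes fetched from the datalake.
chunks = [pd.DataFrame({"x": range(100)}), pd.DataFrame({"x": range(100, 200)})]

# Percentage-based sampling: keep ~10% of each chunk.
pct_sampled = [df.sample(frac=10 / 100, random_state=random.randint(0, 100)) for df in chunks]

# Row-based sampling: spread an absolute budget of 50 rows across the chunks.
rows_per_chunk = math.floor(50 / len(chunks))
row_sampled = [df.sample(n=rows_per_chunk, random_state=random.randint(0, 100)) for df in chunks]
```

Splitting an absolute row budget across chunks keeps the combined sample size close to the requested number of rows, which is what the removed `math.floor(profile_sample / len(data))` computation was doing.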
23 changes: 19 additions & 4 deletions ingestion/src/metadata/mixins/sqalchemy/sqa_mixin.py
@@ -15,10 +15,12 @@
"""


from typing import List
from typing import List, Optional

from sqlalchemy import Column, inspect
from sqlalchemy import Column, MetaData, inspect
from sqlalchemy.orm import DeclarativeMeta

from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.connections.database.databricksConnection import (
DatabricksConnection,
)
@@ -28,11 +30,15 @@
from metadata.generated.schema.entity.services.connections.database.unityCatalogConnection import (
UnityCatalogConnection,
)
from metadata.generated.schema.tests.basic import BaseModel
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.snowflake.queries import (
SNOWFLAKE_SESSION_TAG_QUERY,
)
from metadata.profiler.orm.converter.base import ometa_to_sqa_orm
from metadata.utils.collaborative_super import Root
from metadata.utils.constants import NON_SQA_DATABASE_CONNECTIONS


class SQAInterfaceMixin(Root):
@@ -46,13 +52,13 @@ def _get_engine(self):
Returns:
sqlalchemy engine
"""
engine = get_connection(self.service_connection_config)
engine = get_connection(super().service_connection_config)

return engine

def get_columns(self) -> Column:
"""get columns from an orm object"""
return inspect(self.table).c
return inspect(super().table).c

def set_session_tag(self, session) -> None:
"""
@@ -100,3 +106,12 @@ def _get_sample_columns(self) -> List[str]:
for column in self.table.__table__.columns
if column.name in {col.name.root for col in self.table_entity.columns}
]

def build_table_orm(
self, table: Table, service_conn_config: BaseModel, ometa_client: OpenMetadata
) -> Optional[DeclarativeMeta]:
"""Build the ORM table if needed for the sampler and profiler interfaces"""
if service_conn_config.type.value not in NON_SQA_DATABASE_CONNECTIONS:
orm_obj = ometa_to_sqa_orm(table, ometa_client, MetaData())
return orm_obj
return None
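
The `build_table_orm` helper added above keeps the "only build an ORM class for SQLAlchemy-backed connections" guard that previously lived in the factory (`_build_table_orm`, removed earlier in this diff). A toy, self-contained sketch of that guard is shown below; the connection names and the returned string are stand-ins, not the real `NON_SQA_DATABASE_CONNECTIONS` values or `ometa_to_sqa_orm` output:

```python
from typing import Optional

# Stand-in values for illustration only.
NON_SQA_DATABASE_CONNECTIONS = ("Datalake", "Mongo")


def build_orm_if_needed(connection_type: str, table_name: str) -> Optional[str]:
    """Mirror the guard: only SQLAlchemy-backed sources get an ORM object, others get None."""
    if connection_type not in NON_SQA_DATABASE_CONNECTIONS:
        return f"<orm for {table_name}>"  # stand-in for ometa_to_sqa_orm(table, client, MetaData())
    return None


assert build_orm_if_needed("Postgres", "users") is not None
assert build_orm_if_needed("Datalake", "users") is None
```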
@@ -157,9 +157,7 @@ def get_composed_metrics(
):
return None

def get_hybrid_metrics(
self, column: Column, metric: Metrics, column_results: Dict, **kwargs
):
def get_hybrid_metrics(self, column: Column, metric: Metrics, column_results: Dict):
return None

def get_all_metrics(