Skip to content

Commit

Permalink
Merge branch 'main' into importreferenceentityupdation
Browse files: browse the repository at this point in the history
  • Loading branch information
RounakDhillon authored Nov 29, 2024
2 parents 4300392 + 3453c79 commit 93d3ed0
Show file tree
Hide file tree
Showing 59 changed files with 935 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _create_runner(self) -> QueryRunner:
QueryRunner(
session=self.session,
dataset=self.dataset,
raw_dataset=self.sampler.raw_dataset,
partition_details=self.table_partition_config,
profile_sample_query=self.table_sample_query,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column, **kwargs) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ColumnValuesToBeAtExpectedLocationValidator(
def _fetch_data(self, columns: List[str]) -> Iterator:
"""Fetch data from the runner object"""
self.runner = cast(QueryRunner, self.runner)
inspection = inspect(self.runner.table)
inspection = inspect(self.runner.dataset)
table_columns: List[Column] = inspection.c if inspection is not None else []
cols = [col for col in table_columns if col.name in columns]
for col in cols:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column, **kwargs) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column, **kwargs) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from sqlalchemy import Column, inspect
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.util import AliasedClass

from metadata.data_quality.validations.column.base.columnValuesToBeUnique import (
BaseColumnValuesToBeUniqueValidator,
Expand All @@ -41,7 +40,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
Expand All @@ -53,12 +52,7 @@ def _run_results(self, metric: Metrics, column: Column) -> Optional[int]:
"""
count = Metrics.COUNT.value(column).fn()
unique_count = Metrics.UNIQUE_COUNT.value(column).query(
sample=self.runner._sample # pylint: disable=protected-access
if isinstance(
self.runner._sample, # pylint: disable=protected-access
AliasedClass,
)
else self.runner.table,
sample=self.runner.dataset,
session=self.runner._session, # pylint: disable=protected-access
) # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_column_name(self) -> Column:
"""
return self.get_column_name(
self.test_case.entityLink.root,
inspect(self.runner.table).c,
inspect(self.runner.dataset).c,
)

def _run_results(self, metric: Metrics, column: Column, **kwargs) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Validator for table row inserted count to be between test case
"""

from sqlalchemy import Column, text
from sqlalchemy import Column, inspect, text

from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
Expand Down Expand Up @@ -52,7 +52,7 @@ def _run_results(self, column_name: str, range_type: str, range_interval: int):
date_or_datetime_fn = dispatch_to_date_or_datetime(
range_interval,
text(range_type),
get_partition_col_type(column_name.name, self.runner.table.c), # type: ignore
get_partition_col_type(column_name.name, inspect(self.runner.dataset).c), # type: ignore
)

return dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _compute_system_metrics(
**kwargs,
) -> List[SystemProfile]:
return self.system_metrics_computer.get_system_metrics(
table=runner.table,
table=runner.dataset,
usage_location=self.service_connection_config.usageLocation,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
DBT_CATALOG_FILE_NAME = "catalog.json"
DBT_MANIFEST_FILE_NAME = "manifest.json"
DBT_RUN_RESULTS_FILE_NAME = "run_results"
DBT_SOURCES_FILE_NAME = "sources.json"


class SkipResourceTypeEnum(Enum):
Expand All @@ -91,6 +92,7 @@ class SkipResourceTypeEnum(Enum):

ANALYSIS = "analysis"
TEST = "test"
SOURCE = "source"


class CompiledQueriesEnum(Enum):
Expand Down Expand Up @@ -127,6 +129,7 @@ class DbtTestFailureEnum(Enum):

FAILURE = "failure"
FAIL = "fail"
ERROR = "error"


class DbtCommonEnum(Enum):
Expand All @@ -137,6 +140,7 @@ class DbtCommonEnum(Enum):
OWNER = "owner"
NODES = "nodes"
SOURCES = "sources"
SOURCES_FILE = "sources_file"
SOURCE = "source"
RESOURCETYPE = "resource_type"
MANIFEST_NODE = "manifest_node"
Expand Down
18 changes: 18 additions & 0 deletions ingestion/src/metadata/ingestion/source/database/dbt/dbt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
DBT_CATALOG_FILE_NAME,
DBT_MANIFEST_FILE_NAME,
DBT_RUN_RESULTS_FILE_NAME,
DBT_SOURCES_FILE_NAME,
)
from metadata.ingestion.source.database.dbt.models import DbtFiles
from metadata.readers.file.config_source_factory import get_reader
Expand Down Expand Up @@ -85,6 +86,7 @@ def _(config: DbtLocalConfig):
config.dbtManifestFilePath,
config.dbtCatalogFilePath,
config.dbtRunResultsFilePath,
config.dbtSourcesFilePath,
]
yield from download_dbt_files(
blob_grouped_by_directory=blob_grouped_by_directory,
Expand Down Expand Up @@ -123,12 +125,22 @@ def _(config: DbtHttpConfig):
dbt_catalog = requests.get( # pylint: disable=missing-timeout
config.dbtCatalogHttpPath
)

dbt_sources = None
if config.dbtSourcesHttpPath:
logger.debug(
f"Requesting [dbtSourcesHttpPath] to: {config.dbtSourcesHttpPath}"
)
dbt_sources = requests.get( # pylint: disable=missing-timeout
config.dbtSourcesHttpPath
)
if not dbt_manifest:
raise DBTConfigException("Manifest file not found in file server")
yield DbtFiles(
dbt_catalog=dbt_catalog.json() if dbt_catalog else None,
dbt_manifest=dbt_manifest.json(),
dbt_run_results=[dbt_run_results.json()] if dbt_run_results else None,
dbt_sources=dbt_sources.json() if dbt_sources else None,
)
except DBTConfigException as exc:
raise exc
Expand Down Expand Up @@ -243,6 +255,7 @@ def get_blobs_grouped_by_dir(blobs: List[str]) -> Dict[str, List[str]]:
return blob_grouped_by_directory


# pylint: disable=too-many-locals, too-many-branches
def download_dbt_files(
blob_grouped_by_directory: Dict, config, client, bucket_name: Optional[str]
) -> Iterable[DbtFiles]:
Expand All @@ -255,6 +268,7 @@ def download_dbt_files(
) in blob_grouped_by_directory.items():
dbt_catalog = None
dbt_manifest = None
dbt_sources = None
dbt_run_results = []
kwargs = {}
if bucket_name:
Expand Down Expand Up @@ -285,12 +299,16 @@ def download_dbt_files(
logger.warning(
f"{DBT_RUN_RESULTS_FILE_NAME} not found in {key}: {exc}"
)
if DBT_SOURCES_FILE_NAME == blob_file_name.lower():
logger.debug(f"{DBT_SOURCES_FILE_NAME} found in {key}")
dbt_sources = reader.read(path=blob, **kwargs)
if not dbt_manifest:
raise DBTConfigException(f"Manifest file not found at: {key}")
yield DbtFiles(
dbt_catalog=json.loads(dbt_catalog) if dbt_catalog else None,
dbt_manifest=json.loads(dbt_manifest),
dbt_run_results=dbt_run_results if dbt_run_results else None,
dbt_sources=json.loads(dbt_sources) if dbt_sources else None,
)
except DBTConfigException as exc:
logger.warning(exc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from abc import ABC, abstractmethod
from typing import Iterable, List

from dbt_artifacts_parser.parser import parse_catalog, parse_manifest, parse_run_results
from dbt_artifacts_parser.parser import (
parse_catalog,
parse_manifest,
parse_run_results,
parse_sources,
)
from pydantic import Field
from typing_extensions import Annotated

Expand Down Expand Up @@ -209,11 +214,13 @@ def get_dbt_objects(self) -> Iterable[DbtObjects]:
self.remove_run_result_non_required_keys(
run_results=self.context.get().dbt_file.dbt_run_results
)

dbt_objects = DbtObjects(
dbt_catalog=parse_catalog(self.context.get().dbt_file.dbt_catalog)
if self.context.get().dbt_file.dbt_catalog
else None,
dbt_manifest=parse_manifest(self.context.get().dbt_file.dbt_manifest),
dbt_sources=parse_sources(self.context.get().dbt_file.dbt_sources),
dbt_run_results=[
parse_run_results(run_result_file)
for run_result_file in self.context.get().dbt_file.dbt_run_results
Expand Down
29 changes: 29 additions & 0 deletions ingestion/src/metadata/ingestion/source/database/dbt/dbt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ def create_test_case_parameter_definitions(dbt_test):
}
]
return test_case_param_definition
if hasattr(dbt_test, "freshness"):
test_case_param_definition = [
{
"name": "warn_after",
"displayName": "warn_after",
"required": False,
},
{
"name": "error_after",
"displayName": "error_after",
"required": False,
},
]
return test_case_param_definition
except Exception as err: # pylint: disable=broad-except
logger.debug(traceback.format_exc())
logger.error(
Expand All @@ -67,6 +81,21 @@ def create_test_case_parameter_values(dbt_test):
{"name": manifest_node.test_metadata.name, "value": dbt_test_values}
]
return test_case_param_values
if hasattr(manifest_node, "freshness"):
warn_after = manifest_node.freshness.warn_after
error_after = manifest_node.freshness.error_after

test_case_param_values = [
{
"name": "error_after",
"value": f"{error_after.count} {error_after.period.value}",
},
{
"name": "warn_after",
"value": f"{warn_after.count} {warn_after.period.value}",
},
]
return test_case_param_values
except Exception as err: # pylint: disable=broad-except
logger.debug(traceback.format_exc())
logger.error(
Expand Down
Loading

0 comments on commit 93d3ed0

Please sign in to comment.