From 4e8d55d27a15da47f0eb18a42e1ae0719e49a034 Mon Sep 17 00:00:00 2001
From: sushi30
Date: Tue, 6 Aug 2024 17:42:49 +0200
Subject: [PATCH] safe iterator for table diff

---
 .../validations/table/sqlalchemy/tableDiff.py | 40 ++++++++++++++-----
 1 file changed, 30 insertions(+), 10 deletions(-)

diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py
index 1792ca10d7d4..790d3349380d 100644
--- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py
+++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py
@@ -9,6 +9,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 # pylint: disable=missing-module-docstring
+import logging
 import traceback
 from itertools import islice
 from typing import Dict, Iterable, List, Optional, Tuple
@@ -59,6 +60,13 @@ def __init__(self, param: str, dialect: str):
         super().__init__(f"Unsupported dialect in param {param}: {dialect}")
 
 
+def masked(s: str, masked: bool = True) -> str:
+    """Mask a string if masked is True, otherwise return the string unchanged.
+    Only for development purposes, do not use in production.
+    Change it to False if you want to see the data in the logs."""
+    return "***" if masked else s
+
+
 class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
     """
     Compare two tables and fail if the number of differences exceeds a threshold
@@ -110,12 +118,12 @@ def _run(self) -> TestCaseResult:
             stats = table_diff_iter.get_stats_dict()
             if stats["total"] > 0:
                 logger.debug("Sample of failed rows:")
-                for s in islice(self.get_table_diff(), 10):
-                    # since the data can contiant sensitive information, we don't want to log it
-                    # we can uncomment this line if we must see the data in the logs
-                    # logger.debug(s)
-                    # by default we will log the data masked
-                    logger.debug([s[0], ["*" for _ in s[1]]])
+                # depending on the data, this requires scanning a lot of data
+                # so we only build the sample when debug logging is enabled.
+                # data can be sensitive so it is masked by default
+                if logger.isEnabledFor(logging.DEBUG):
+                    for s in islice(self.safe_iterator(self.get_table_diff()), 10):
+                        logger.debug("%s", str([s[0]] + [masked(st) for st in s[1]]))
             test_case_result = self.get_row_diff_test_case_result(
                 threshold,
                 stats["total"],
@@ -229,10 +237,6 @@ def get_table_diff(self) -> DiffResultWrapper:
                 ",".join(f"{k}={v}" for k, v in data_diff_kwargs.items()),
             )
         )
-        # this might produce an error message like:
-        # Exception ignored in Exception ignored in:
-        # this needs to be handled in the data_diff library
-        logger.debug("Ignore any 'Exception ignored in' log")
         return data_diff.diff_tables(table1, table2, **data_diff_kwargs)  # type: ignore
 
     def get_where(self) -> Optional[str]:
@@ -404,3 +408,19 @@ def calculate_diffs_with_limit(
             if len(key_set) > limit:
                 len(key_set)
         return len(key_set)
+
+    @staticmethod
+    def safe_iterator(gen: DiffResultWrapper) -> Iterable:
+        """A safe iterator which yields from gen and closes its diff object afterwards.
+        Otherwise the data_diff library will continue to hold the connection open
+        and eventually raise a KeyError."""
+        try:
+            yield from gen
+        finally:
+            try:
+                gen.diff.close()
+            except KeyError as ex:
+                # KeyError(2) is a known issue in data_diff where the diff object is
+                # closed; any other KeyError is unexpected, so surface it in the logs
+                if str(ex) != "2":
+                    logger.warning("Unexpected KeyError closing table diff: %s", ex)