Skip to content

Commit

Permalink
safe iterator for table diff
Browse files Browse the repository at this point in the history
  • Loading branch information
sushi30 committed Aug 6, 2024
1 parent 12bcd29 commit 4e8d55d
Showing 1 changed file with 31 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-module-docstring
import logging
import traceback
from itertools import islice
from typing import Dict, Iterable, List, Optional, Tuple
Expand Down Expand Up @@ -59,6 +60,13 @@ def __init__(self, param: str, dialect: str):
super().__init__(f"Unsupported dialect in param {param}: {dialect}")


def masked(s: str, masked: bool = True) -> str:
"""Mask a string if masked is True otherwise return the string.
Only for development purposes, do not use in production.
Change it False if you want to see the data in the logs."""
return "***" if masked else s


class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
"""
Compare two tables and fail if the number of differences exceeds a threshold
Expand Down Expand Up @@ -110,12 +118,14 @@ def _run(self) -> TestCaseResult:
stats = table_diff_iter.get_stats_dict()
if stats["total"] > 0:
logger.debug("Sample of failed rows:")
for s in islice(self.get_table_diff(), 10):
# since the data can contiant sensitive information, we don't want to log it
# we can uncomment this line if we must see the data in the logs
# logger.debug(s)
# by default we will log the data masked
logger.debug([s[0], ["*" for _ in s[1]]])
gen = self.get_table_diff()
# depending on the data, this require scanning a lot of data
# so we only log the sample in debug mode. data can be sensitive
# so it is masked by default
for s in islice(
self.safe_iterator(gen), 10 if logger.level <= logging.DEBUG else 0
):
logger.debug("%s", str([s[0]] + [masked(st) for st in s[1]]))
test_case_result = self.get_row_diff_test_case_result(
threshold,
stats["total"],
Expand Down Expand Up @@ -229,10 +239,6 @@ def get_table_diff(self) -> DiffResultWrapper:
",".join(f"{k}={v}" for k, v in data_diff_kwargs.items()),
)
)
# this might produce an error message like:
# Exception ignored in Exception ignored in: <generator object TableDiffer._diff_tables_wrapper at 0x00000000>
# this needs to be handled in the data_diff library
logger.debug("Ignore any 'Exception ignored in' log")
return data_diff.diff_tables(table1, table2, **data_diff_kwargs) # type: ignore

def get_where(self) -> Optional[str]:
Expand Down Expand Up @@ -404,3 +410,18 @@ def calculate_diffs_with_limit(
if len(key_set) > limit:
len(key_set)
return len(key_set)

@staticmethod
def safe_iterator(gen: DiffResultWrapper) -> DiffResultWrapper:
"""A safe iterator object which properly closes the diff object when the generator is exhausted.
Otherwise the data_diff library will continue to hold the connection open and eventually
raise a KeyError."""
try:
yield from gen
finally:
try:
gen.diff.close()
except KeyError as ex:
if str(ex) == "2":
# This is a known issue in data_diff where the diff object is closed
pass

0 comments on commit 4e8d55d

Please sign in to comment.