From 6a1920dbca89772f2d73684a1c408938a6425957 Mon Sep 17 00:00:00 2001
From: Faisal
Date: Mon, 25 Mar 2024 15:56:42 -0300
Subject: [PATCH 1/5] Pandas on Spark refactor (#275)

* refactor SparkCompare
* tweaking SparkCompare and adding back Legacy
* conditional import
* cleaning up tests and using pytest-spark for legacy
* adding docs
* caching and some typo fixes
* adding in doc and pandas 2 changes
* adding pandas to testing matrix
* drop 3.8
* drop 3.8
* refactoring ^
* rebase fix for #277
* fixing legacy uncode column names
* unicode fix for legacy
* unicode test for new spark logic
* typo fix
* changes from PR review
---
 .github/workflows/test-package.yml |   13 +-
 README.md                          |   48 +-
 datacompy/__init__.py              |    4 +-
 datacompy/legacy.py                |  928 ++++++++
 datacompy/spark.py                 | 1639 ++++++++-------
 docs/source/index.rst              |    2 +-
 docs/source/spark_usage.rst        |  449 ++--
 pyproject.toml                     |    3 +-
 tests/test_legacy_spark.py         | 2109 +++++++++++++++++++
 tests/test_spark.py                | 3134 +++++++++++-----------------
 10 files changed, 5322 insertions(+), 3007 deletions(-)
 create mode 100644 datacompy/legacy.py
 create mode 100644 tests/test_legacy_spark.py

diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml
index a8390f0e..2c06b2d2 100644
--- a/.github/workflows/test-package.yml
+++ b/.github/workflows/test-package.yml
@@ -19,11 +19,10 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.8, 3.9, '3.10', '3.11']
-        spark-version: [3.1.3, 3.2.4, 3.3.4, 3.4.2, 3.5.0]
+        python-version: [3.9, '3.10', '3.11']
+        spark-version: [3.2.4, 3.3.4, 3.4.2, 3.5.1]
+        pandas-version: [2.2.1, 1.5.3]
         exclude:
-          - python-version: '3.11'
-            spark-version: 3.1.3
           - python-version: '3.11'
             spark-version: 3.2.4
           - python-version: '3.11'
@@ -51,6 +50,7 @@ jobs:
         python -m pip install --upgrade pip
         python -m pip install pytest pytest-spark pypandoc
         python -m pip install pyspark==${{ matrix.spark-version }}
+        python -m pip install pandas==${{ matrix.pandas-version }}
         python -m pip install .[dev]
     - name: Test with pytest
       run: |
@@ -62,7 +62,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.8, 3.9, '3.10', '3.11']
+        python-version: [3.9, '3.10', '3.11']
+
     env:
       PYTHON_VERSION: ${{ matrix.python-version }}

@@ -88,7 +89,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.8, 3.9, '3.10', '3.11']
+        python-version: [3.9, '3.10', '3.11']
     env:
       PYTHON_VERSION: ${{ matrix.python-version }}

diff --git a/README.md b/README.md
index dc518c94..b9abfee5 100644
--- a/README.md
+++ b/README.md
@@ -38,16 +38,44 @@ pip install datacompy[ray]
 ```
 
-### In-scope Spark versions
-Different versions of Spark play nicely with only certain versions of Python below is a matrix of what we test with
+### Legacy Spark Deprecation
+
+#### Starting with version 0.12.0
+
+The original ``SparkCompare`` implementation differs from all the other native implementations. To better align the API and keep behaviour consistent, we are deprecating the original ``SparkCompare`` and moving it into a new module as ``LegacySparkCompare``.
+
+If you wish to keep using the old ``SparkCompare`` moving forward, you can import it with:
+
+```python
+from datacompy.legacy import LegacySparkCompare
+```
+
+#### Supported versions and dependencies
+
+Different versions of Spark, Pandas, and Python interact differently. Below is a matrix of what we test with.
+With the move to the Pandas on Spark API and compatibility issues with Pandas 2+, we will not, for the time being, support Pandas 2
+with the Pandas on Spark implementation. 
Spark plans to support Pandas 2 in [Spark 4](https://issues.apache.org/jira/browse/SPARK-44101) + +With version ``0.12.0``: +- Not support Pandas ``2.0.0`` For the native Spark implemention +- Spark ``3.1`` support will be dropped +- Python ``3.8`` support is dropped + + +| | Spark 3.2.4 | Spark 3.3.4 | Spark 3.4.2 | Spark 3.5.1 | +|-------------|-------------|-------------|-------------|-------------| +| Python 3.9 | ✅ | ✅ | ✅ | ✅ | +| Python 3.10 | ✅ | ✅ | ✅ | ✅ | +| Python 3.11 | ❌ | ❌ | ✅ | ✅ | +| Python 3.12 | ❌ | ❌ | ❌ | ❌ | + + +| | Pandas < 1.5.3 | Pandas >=2.0.0 | +|---------------|----------------|----------------| +| Native Pandas | ✅ | ✅ | +| Native Spark | ✅ | ❌ | +| Fugue | ✅ | ✅ | -| | Spark 3.1.3 | Spark 3.2.3 | Spark 3.3.4 | Spark 3.4.2 | Spark 3.5.0 | -|-------------|--------------|-------------|-------------|-------------|-------------| -| Python 3.8 | ✅ | ✅ | ✅ | ✅ | ✅ | -| Python 3.9 | ✅ | ✅ | ✅ | ✅ | ✅ | -| Python 3.10 | ✅ | ✅ | ✅ | ✅ | ✅ | -| Python 3.11 | ❌ | ❌ | ❌ | ✅ | ✅ | -| Python 3.12 | ❌ | ❌ | ❌ | ❌ | ❌ | > [!NOTE] @@ -56,7 +84,7 @@ Different versions of Spark play nicely with only certain versions of Python bel ## Supported backends - Pandas: ([See documentation](https://capitalone.github.io/datacompy/pandas_usage.html)) -- Spark: ([See documentation](https://capitalone.github.io/datacompy/spark_usage.html)) +- Spark (Pandas on Spark API): ([See documentation](https://capitalone.github.io/datacompy/spark_usage.html)) - Polars (Experimental): ([See documentation](https://capitalone.github.io/datacompy/polars_usage.html)) - Fugue is a Python library that provides a unified interface for data processing on Pandas, DuckDB, Polars, Arrow, Spark, Dask, Ray, and many other backends. DataComPy integrates with Fugue to provide a simple way to compare data diff --git a/datacompy/__init__.py b/datacompy/__init__.py index 6b1aab24..b47d27c3 100644 --- a/datacompy/__init__.py +++ b/datacompy/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.11.3" +__version__ = "0.12.0" from datacompy.core import * from datacompy.fugue import ( @@ -25,4 +25,4 @@ unq_columns, ) from datacompy.polars import PolarsCompare -from datacompy.spark import NUMERIC_SPARK_TYPES, SparkCompare +from datacompy.spark import SparkCompare diff --git a/datacompy/legacy.py b/datacompy/legacy.py new file mode 100644 index 00000000..b23b9cb2 --- /dev/null +++ b/datacompy/legacy.py @@ -0,0 +1,928 @@ +# +# Copyright 2024 Capital One Services, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from enum import Enum +from itertools import chain +from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union +from warnings import warn + +try: + import pyspark + from pyspark.sql import functions as F +except ImportError: + pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality + + +warn( + f"The module {__name__} is deprecated. 
In future versions LegacySparkCompare will be completely removed.", + DeprecationWarning, + stacklevel=2, +) + + +class MatchType(Enum): + MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3) + + +# Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal". +def decimal_comparator() -> str: + class DecimalComparator(str): + def __eq__(self, other: str) -> bool: # type: ignore[override] + return len(other) >= 7 and other[0:7] == "decimal" + + return DecimalComparator("decimal") + + +NUMERIC_SPARK_TYPES = [ + "tinyint", + "smallint", + "int", + "bigint", + "float", + "double", + decimal_comparator(), +] + + +def _is_comparable(type1: str, type2: str) -> bool: + """Checks if two Spark data types can be safely compared. + Two data types are considered comparable if any of the following apply: + 1. Both data types are the same + 2. Both data types are numeric + + Parameters + ---------- + type1 : str + A string representation of a Spark data type + type2 : str + A string representation of a Spark data type + + Returns + ------- + bool + True if both data types are comparable + """ + + return type1 == type2 or ( + type1 in NUMERIC_SPARK_TYPES and type2 in NUMERIC_SPARK_TYPES + ) + + +class LegacySparkCompare: + """Comparison class used to compare two Spark Dataframes. + + Extends the ``Compare`` functionality to the wide world of Spark and + out-of-memory data. + + Parameters + ---------- + spark_session : ``pyspark.sql.SparkSession`` + A ``SparkSession`` to be used to execute Spark commands in the + comparison. + base_df : ``pyspark.sql.DataFrame`` + The dataframe to serve as a basis for comparison. While you will + ultimately get the same results comparing A to B as you will comparing + B to A, by convention ``base_df`` should be the canonical, gold + standard reference dataframe in the comparison. + compare_df : ``pyspark.sql.DataFrame`` + The dataframe to be compared against ``base_df``. + join_columns : list + A list of columns comprising the join key(s) of the two dataframes. + If the column names are the same in the two dataframes, the names of + the columns can be given as strings. If the names differ, the + ``join_columns`` list should include tuples of the form + (base_column_name, compare_column_name). + column_mapping : list[tuple], optional + If columns to be compared have different names in the base and compare + dataframes, a list should be provided in ``columns_mapping`` consisting + of tuples of the form (base_column_name, compare_column_name) for each + set of differently-named columns to be compared against each other. + cache_intermediates : bool, optional + Whether or not ``SparkCompare`` will cache intermediate dataframes + (such as the deduplicated version of dataframes, or the joined + comparison). This will take a large amount of cache, proportional to + the size of your dataframes, but will significantly speed up + performance, as multiple steps will not have to recompute + transformations. False by default. + known_differences : list[dict], optional + A list of dictionaries that define transformations to apply to the + compare dataframe to match values when there are known differences + between base and compare. The dictionaries should contain: + + * name: A name that describes the transformation + * types: The types that the transformation should be applied to. + This prevents certain transformations from being applied to + types that don't make sense and would cause exceptions. 
+ * transformation: A Spark SQL statement to apply to the column + in the compare dataset. The string "{input}" will be replaced + by the variable in question. + abs_tol : float, optional + Absolute tolerance between two values. + rel_tol : float, optional + Relative tolerance between two values. + show_all_columns : bool, optional + If true, all columns will be shown in the report including columns + with a 100% match rate. + match_rates : bool, optional + If true, match rates by column will be shown in the column summary. + + Returns + ------- + SparkCompare + Instance of a ``SparkCompare`` object, ready to do some comparin'. + Note that if ``cache_intermediates=True``, this instance will already + have done some work deduping the input dataframes. If + ``cache_intermediates=False``, the instantiation of this object is lazy. + """ + + def __init__( + self, + spark_session: "pyspark.sql.SparkSession", + base_df: "pyspark.sql.DataFrame", + compare_df: "pyspark.sql.DataFrame", + join_columns: List[Union[str, Tuple[str, str]]], + column_mapping: Optional[List[Tuple[str, str]]] = None, + cache_intermediates: bool = False, + known_differences: Optional[List[Dict[str, Any]]] = None, + rel_tol: float = 0, + abs_tol: float = 0, + show_all_columns: bool = False, + match_rates: bool = False, + ): + self.rel_tol = rel_tol + self.abs_tol = abs_tol + if self.rel_tol < 0 or self.abs_tol < 0: + raise ValueError("Please enter positive valued tolerances") + self.show_all_columns = show_all_columns + self.match_rates = match_rates + + self._original_base_df = base_df + self._original_compare_df = compare_df + self.cache_intermediates = cache_intermediates + + self.join_columns = self._tuplizer(input_list=join_columns) + self._join_column_names = [name[0] for name in self.join_columns] + + self._known_differences = known_differences + + if column_mapping: + for mapping in column_mapping: + compare_df = compare_df.withColumnRenamed(mapping[1], mapping[0]) + self.column_mapping = dict(column_mapping) + else: + self.column_mapping = {} + + for mapping in self.join_columns: + if mapping[1] != mapping[0]: + compare_df = compare_df.withColumnRenamed(mapping[1], mapping[0]) + + self.spark = spark_session + self.base_unq_rows = self.compare_unq_rows = None + self._base_row_count: Optional[int] = None + self._compare_row_count: Optional[int] = None + self._common_row_count: Optional[int] = None + self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None + self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None + self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None + self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None + self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None + self.columns_match_dict: Dict[str, Any] = {} + + # drop the duplicates before actual comparison made. 
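+        # Note (added commentary): dropDuplicates() keeps one arbitrary row per
+        # join key, so duplicated keys are not compared row-by-row; the number of
+        # dropped duplicates is surfaced later in the report's row summary.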
+ self.base_df = base_df.dropDuplicates(self._join_column_names) + self.compare_df = compare_df.dropDuplicates(self._join_column_names) + + if cache_intermediates: + self.base_df.cache() + self._base_row_count = self.base_df.count() + self.compare_df.cache() + self._compare_row_count = self.compare_df.count() + + def _tuplizer( + self, input_list: List[Union[str, Tuple[str, str]]] + ) -> List[Tuple[str, str]]: + join_columns: List[Tuple[str, str]] = [] + for val in input_list: + if isinstance(val, str): + join_columns.append((val, val)) + else: + join_columns.append(val) + + return join_columns + + @property + def columns_in_both(self) -> Set[str]: + """set[str]: Get columns in both dataframes""" + return set(self.base_df.columns) & set(self.compare_df.columns) + + @property + def columns_compared(self) -> List[str]: + """list[str]: Get columns to be compared in both dataframes (all + columns in both excluding the join key(s)""" + return [ + column + for column in list(self.columns_in_both) + if column not in self._join_column_names + ] + + @property + def columns_only_base(self) -> Set[str]: + """set[str]: Get columns that are unique to the base dataframe""" + return set(self.base_df.columns) - set(self.compare_df.columns) + + @property + def columns_only_compare(self) -> Set[str]: + """set[str]: Get columns that are unique to the compare dataframe""" + return set(self.compare_df.columns) - set(self.base_df.columns) + + @property + def base_row_count(self) -> int: + """int: Get the count of rows in the de-duped base dataframe""" + if self._base_row_count is None: + self._base_row_count = self.base_df.count() + + return self._base_row_count + + @property + def compare_row_count(self) -> int: + """int: Get the count of rows in the de-duped compare dataframe""" + if self._compare_row_count is None: + self._compare_row_count = self.compare_df.count() + + return self._compare_row_count + + @property + def common_row_count(self) -> int: + """int: Get the count of rows in common between base and compare dataframes""" + if self._common_row_count is None: + common_rows = self._get_or_create_joined_dataframe() + self._common_row_count = common_rows.count() + + return self._common_row_count + + def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame": + """Get the rows only from base data frame""" + return self.base_df.select(self._join_column_names).subtract( + self.compare_df.select(self._join_column_names) + ) + + def _get_compare_rows(self) -> "pyspark.sql.DataFrame": + """Get the rows only from compare data frame""" + return self.compare_df.select(self._join_column_names).subtract( + self.base_df.select(self._join_column_names) + ) + + def _print_columns_summary(self, myfile: TextIO) -> None: + """Prints the column summary details""" + print("\n****** Column Summary ******", file=myfile) + print( + f"Number of columns in common with matching schemas: {len(self._columns_with_matching_schema())}", + file=myfile, + ) + print( + f"Number of columns in common with schema differences: {len(self._columns_with_schemadiff())}", + file=myfile, + ) + print( + f"Number of columns in base but not compare: {len(self.columns_only_base)}", + file=myfile, + ) + print( + f"Number of columns in compare but not base: {len(self.columns_only_compare)}", + file=myfile, + ) + + def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None: + """Prints the columns and data types only in either the base or compare datasets""" + + if base_or_compare.upper() == "BASE": + columns = self.columns_only_base + df 
= self.base_df + elif base_or_compare.upper() == "COMPARE": + columns = self.columns_only_compare + df = self.compare_df + else: + raise ValueError( + f'base_or_compare must be BASE or COMPARE, but was "{base_or_compare}"' + ) + + # If there are no columns only in this dataframe, don't display this section + if not columns: + return + + max_length = max([len(col) for col in columns] + [11]) + format_pattern = f"{{:{max_length}s}}" + + print(f"\n****** Columns In {base_or_compare.title()} Only ******", file=myfile) + print((format_pattern + " Dtype").format("Column Name"), file=myfile) + print("-" * max_length + " -------------", file=myfile) + + for column in columns: + col_type = df.select(column).dtypes[0][1] + print((format_pattern + " {:13s}").format(column, col_type), file=myfile) + + def _columns_with_matching_schema(self) -> Dict[str, str]: + """This function will identify the columns which has matching schema""" + col_schema_match = {} + base_columns_dict = dict(self.base_df.dtypes) + compare_columns_dict = dict(self.compare_df.dtypes) + + for base_row, base_type in base_columns_dict.items(): + if base_row in compare_columns_dict: + compare_column_type = compare_columns_dict.get(base_row) + if compare_column_type is not None and base_type in compare_column_type: + col_schema_match[base_row] = compare_column_type + + return col_schema_match + + def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]: + """This function will identify the columns which has different schema""" + col_schema_diff = {} + base_columns_dict = dict(self.base_df.dtypes) + compare_columns_dict = dict(self.compare_df.dtypes) + + for base_row, base_type in base_columns_dict.items(): + if base_row in compare_columns_dict: + compare_column_type = compare_columns_dict.get(base_row) + if ( + compare_column_type is not None + and base_type not in compare_column_type + ): + col_schema_diff[base_row] = dict( + base_type=base_type, + compare_type=compare_column_type, + ) + return col_schema_diff + + @property + def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]: + """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches""" + if self._all_rows_mismatched is None: + self._merge_dataframes() + + return self._all_rows_mismatched + + @property + def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]: + """pyspark.sql.DataFrame: Returns all rows in both dataframes""" + if self._all_matched_rows is None: + self._merge_dataframes() + + return self._all_matched_rows + + @property + def rows_only_base(self) -> "pyspark.sql.DataFrame": + """pyspark.sql.DataFrame: Returns rows only in the base dataframe""" + if not self._rows_only_base: + base_rows = self._get_unq_base_rows() + base_rows.createOrReplaceTempView("baseRows") + self.base_df.createOrReplaceTempView("baseTable") + join_condition = " AND ".join( + [ + "A.`" + name + "`<=>B.`" + name + "`" + for name in self._join_column_names + ] + ) + sql_query = "select A.* from baseTable as A, baseRows as B where {}".format( + join_condition + ) + self._rows_only_base = self.spark.sql(sql_query) + + if self.cache_intermediates: + self._rows_only_base.cache().count() + + return self._rows_only_base + + @property + def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]: + """pyspark.sql.DataFrame: Returns rows only in the compare dataframe""" + if not self._rows_only_compare: + compare_rows = self._get_compare_rows() + compare_rows.createOrReplaceTempView("compareRows") + 
self.compare_df.createOrReplaceTempView("compareTable") + where_condition = " AND ".join( + [ + "A.`" + name + "`<=>B.`" + name + "`" + for name in self._join_column_names + ] + ) + sql_query = ( + "select A.* from compareTable as A, compareRows as B where {}".format( + where_condition + ) + ) + self._rows_only_compare = self.spark.sql(sql_query) + + if self.cache_intermediates: + self._rows_only_compare.cache().count() + + return self._rows_only_compare + + def _generate_select_statement(self, match_data: bool = True) -> str: + """This function is to generate the select statement to be used later in the query.""" + base_only = list(set(self.base_df.columns) - set(self.compare_df.columns)) + compare_only = list(set(self.compare_df.columns) - set(self.base_df.columns)) + sorted_list = sorted(list(chain(base_only, compare_only, self.columns_in_both))) + select_statement = "" + + for column_name in sorted_list: + if column_name in self.columns_compared: + if match_data: + select_statement = select_statement + ",".join( + [self._create_case_statement(name=column_name)] + ) + else: + select_statement = select_statement + ",".join( + [self._create_select_statement(name=column_name)] + ) + elif column_name in base_only: + select_statement = select_statement + ",".join( + ["A.`" + column_name + "`"] + ) + + elif column_name in compare_only: + if match_data: + select_statement = select_statement + ",".join( + ["B.`" + column_name + "`"] + ) + else: + select_statement = select_statement + ",".join( + ["A.`" + column_name + "`"] + ) + elif column_name in self._join_column_names: + select_statement = select_statement + ",".join( + ["A.`" + column_name + "`"] + ) + + if column_name != sorted_list[-1]: + select_statement = select_statement + " , " + + return select_statement + + def _merge_dataframes(self) -> None: + """Merges the two dataframes and creates self._all_matched_rows and self._all_rows_mismatched.""" + full_joined_dataframe = self._get_or_create_joined_dataframe() + full_joined_dataframe.createOrReplaceTempView("full_matched_table") + + select_statement = self._generate_select_statement(False) + select_query = """SELECT {} FROM full_matched_table A""".format( + select_statement + ) + self._all_matched_rows = self.spark.sql(select_query).orderBy( + self._join_column_names # type: ignore[arg-type] + ) + self._all_matched_rows.createOrReplaceTempView("matched_table") + + where_cond = " OR ".join( + ["A.`" + name + "_match`= False" for name in self.columns_compared] + ) + mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond) + self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy( + self._join_column_names # type: ignore[arg-type] + ) + + def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame": + if self._joined_dataframe is None: + join_condition = " AND ".join( + [ + "A.`" + name + "`<=>B.`" + name + "`" + for name in self._join_column_names + ] + ) + select_statement = self._generate_select_statement(match_data=True) + + self.base_df.createOrReplaceTempView("base_table") + self.compare_df.createOrReplaceTempView("compare_table") + + join_query = r""" + SELECT {} + FROM base_table A + JOIN compare_table B + ON {}""".format( + select_statement, join_condition + ) + + self._joined_dataframe = self.spark.sql(join_query) + if self.cache_intermediates: + self._joined_dataframe.cache() + self._common_row_count = self._joined_dataframe.count() + + return self._joined_dataframe + + def _print_num_of_rows_with_column_equality(self, myfile: TextIO) 
-> None: + # match_dataframe contains columns from both dataframes with flag to indicate if columns matched + match_dataframe = self._get_or_create_joined_dataframe().select( + *self.columns_compared + ) + match_dataframe.createOrReplaceTempView("matched_df") + + where_cond = " AND ".join( + [ + "A.`" + name + "`=" + str(MatchType.MATCH.value) + for name in self.columns_compared + ] + ) + match_query = ( + r"""SELECT count(*) AS row_count FROM matched_df A WHERE {}""".format( + where_cond + ) + ) + all_rows_matched = self.spark.sql(match_query) + all_rows_matched_head = all_rows_matched.head() + matched_rows = ( + all_rows_matched_head[0] if all_rows_matched_head is not None else 0 + ) + + print("\n****** Row Comparison ******", file=myfile) + print( + f"Number of rows with some columns unequal: {self.common_row_count - matched_rows}", + file=myfile, + ) + print(f"Number of rows with all columns equal: {matched_rows}", file=myfile) + + def _populate_columns_match_dict(self) -> None: + """ + side effects: + columns_match_dict assigned to { column -> match_type_counts } + where: + column (string): Name of a column that exists in both the base and comparison columns + match_type_counts (list of int with size = len(MatchType)): The number of each match type seen for this column (in order of the MatchType enum values) + + returns: None + """ + + match_dataframe = self._get_or_create_joined_dataframe().select( + *self.columns_compared + ) + + def helper(c: str) -> "pyspark.sql.Column": + # Create a predicate for each match type, comparing column values to the match type value + predicates = [F.col(c) == k.value for k in MatchType] + # Create a tuple(number of match types found for each match type in this column) + return F.struct( + [F.lit(F.sum(pred.cast("integer"))) for pred in predicates] + ).alias(c) + + # For each column, create a single tuple. 
This tuple's values correspond to the number of times + # each match type appears in that column + match_data_agg = match_dataframe.agg( + *[helper(col) for col in self.columns_compared] + ).collect() + match_data = match_data_agg[0] + + for c in self.columns_compared: + self.columns_match_dict[c] = match_data[c] + + def _create_select_statement(self, name: str) -> str: + if self._known_differences: + match_type_comparison = "" + for k in MatchType: + match_type_comparison += ( + " WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format( + name=name, match_value=str(k.value), match_name=k.name + ) + ) + return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format( + name=name, + match_failure=MatchType.MISMATCH.value, + match_type_comparison=match_type_comparison, + ) + else: + return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format( + name=name, match_failure=MatchType.MISMATCH.value + ) + + def _create_case_statement(self, name: str) -> str: + equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"] + known_diff_comparisons = ["(FALSE)"] + + base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0] + compare_dtype = [d[1] for d in self.compare_df.dtypes if d[0] == name][0] + + if _is_comparable(base_dtype, compare_dtype): + if (base_dtype in NUMERIC_SPARK_TYPES) and ( + compare_dtype in NUMERIC_SPARK_TYPES + ): # numeric tolerance comparison + equal_comparisons.append( + "((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=(" + + str(self.abs_tol) + + "+(" + + str(self.rel_tol) + + "*abs(A.`{name}`)))))" + ) + else: # non-numeric comparison + equal_comparisons.append("((A.`{name}`=B.`{name}`))") + + if self._known_differences: + new_input = "B.`{name}`" + for kd in self._known_differences: + if compare_dtype in kd["types"]: + if "flags" in kd and "nullcheck" in kd["flags"]: + known_diff_comparisons.append( + "((" + + kd["transformation"].format(new_input, input=new_input) + + ") is null AND A.`{name}` is null)" + ) + else: + known_diff_comparisons.append( + "((" + + kd["transformation"].format(new_input, input=new_input) + + ") = A.`{name}`)" + ) + + case_string = ( + "( CASE WHEN (" + + " OR ".join(equal_comparisons) + + ") THEN {match_success} WHEN (" + + " OR ".join(known_diff_comparisons) + + ") THEN {match_known_difference} ELSE {match_failure} END) " + + "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`" + ) + + return case_string.format( + name=name, + match_success=MatchType.MATCH.value, + match_known_difference=MatchType.KNOWN_DIFFERENCE.value, + match_failure=MatchType.MISMATCH.value, + ) + + def _print_row_summary(self, myfile: TextIO) -> None: + base_df_cnt = self.base_df.count() + compare_df_cnt = self.compare_df.count() + base_df_with_dup_cnt = self._original_base_df.count() + compare_df_with_dup_cnt = self._original_compare_df.count() + + print("\n****** Row Summary ******", file=myfile) + print(f"Number of rows in common: {self.common_row_count}", file=myfile) + print( + f"Number of rows in base but not compare: {base_df_cnt - self.common_row_count}", + file=myfile, + ) + print( + f"Number of rows in compare but not base: {compare_df_cnt - self.common_row_count}", + file=myfile, + ) + print( + f"Number of duplicate rows found in base: {base_df_with_dup_cnt - base_df_cnt}", + file=myfile, 
+ ) + print( + f"Number of duplicate rows found in compare: {compare_df_with_dup_cnt - compare_df_cnt}", + file=myfile, + ) + + def _print_schema_diff_details(self, myfile: TextIO) -> None: + schema_diff_dict = self._columns_with_schemadiff() + + if not schema_diff_dict: # If there are no differences, don't print the section + return + + # For columns with mismatches, what are the longest base and compare column name lengths (with minimums)? + base_name_max = max([len(key) for key in schema_diff_dict] + [16]) + compare_name_max = max( + [len(self._base_to_compare_name(key)) for key in schema_diff_dict] + [19] + ) + + format_pattern = "{{:{base}s}} {{:{compare}s}}".format( + base=base_name_max, compare=compare_name_max + ) + + print("\n****** Schema Differences ******", file=myfile) + print( + (format_pattern + " Base Dtype Compare Dtype").format( + "Base Column Name", "Compare Column Name" + ), + file=myfile, + ) + print( + "-" * base_name_max + + " " + + "-" * compare_name_max + + " ------------- -------------", + file=myfile, + ) + + for base_column, types in schema_diff_dict.items(): + compare_column = self._base_to_compare_name(base_column) + + print( + (format_pattern + " {:13s} {:13s}").format( + base_column, + compare_column, + types["base_type"], + types["compare_type"], + ), + file=myfile, + ) + + def _base_to_compare_name(self, base_name: str) -> str: + """Translates a column name in the base dataframe to its counterpart in the + compare dataframe, if they are different.""" + + if base_name in self.column_mapping: + return self.column_mapping[base_name] + else: + for name in self.join_columns: + if base_name == name[0]: + return name[1] + return base_name + + def _print_row_matches_by_column(self, myfile: TextIO) -> None: + self._populate_columns_match_dict() + columns_with_mismatches = { + key: self.columns_match_dict[key] + for key in self.columns_match_dict + if self.columns_match_dict[key][MatchType.MISMATCH.value] + } + + # corner case: when all columns match but no rows match + # issue: #276 + try: + columns_fully_matching = { + key: self.columns_match_dict[key] + for key in self.columns_match_dict + if sum(self.columns_match_dict[key]) + == self.columns_match_dict[key][MatchType.MATCH.value] + } + except TypeError: + columns_fully_matching = {} + + try: + columns_with_any_diffs = { + key: self.columns_match_dict[key] + for key in self.columns_match_dict + if sum(self.columns_match_dict[key]) + != self.columns_match_dict[key][MatchType.MATCH.value] + } + except TypeError: + columns_with_any_diffs = {} + # + + base_types = {x[0]: x[1] for x in self.base_df.dtypes} + compare_types = {x[0]: x[1] for x in self.compare_df.dtypes} + + print("\n****** Column Comparison ******", file=myfile) + + if self._known_differences: + print( + f"Number of columns compared with unexpected differences in some values: {len(columns_with_mismatches)}", + file=myfile, + ) + print( + f"Number of columns compared with all values equal but known differences found: {len(self.columns_compared) - len(columns_with_mismatches) - len(columns_fully_matching)}", + file=myfile, + ) + print( + f"Number of columns compared with all values completely equal: {len(columns_fully_matching)}", + file=myfile, + ) + else: + print( + f"Number of columns compared with some values unequal: {len(columns_with_mismatches)}", + file=myfile, + ) + print( + f"Number of columns compared with all values equal: {len(columns_fully_matching)}", + file=myfile, + ) + + # If all columns matched, don't print columns with unequal values 
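+        # Note (added commentary): the early return below means the per-column
+        # table is omitted entirely when show_all_columns is False and every
+        # compared column matched in full.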
+ if (not self.show_all_columns) and ( + len(columns_fully_matching) == len(self.columns_compared) + ): + return + + # if show_all_columns is set, set column name length maximum to max of ALL columns(with minimum) + if self.show_all_columns: + base_name_max = max([len(key) for key in self.columns_match_dict] + [16]) + compare_name_max = max( + [ + len(self._base_to_compare_name(key)) + for key in self.columns_match_dict + ] + + [19] + ) + + # For columns with any differences, what are the longest base and compare column name lengths (with minimums)? + else: + base_name_max = max([len(key) for key in columns_with_any_diffs] + [16]) + compare_name_max = max( + [len(self._base_to_compare_name(key)) for key in columns_with_any_diffs] + + [19] + ) + + """ list of (header, condition, width, align) + where + header (String) : output header for a column + condition (Bool): true if this header should be displayed + width (Int) : width of the column + align (Bool) : true if right-aligned + """ + headers_columns_unequal = [ + ("Base Column Name", True, base_name_max, False), + ("Compare Column Name", True, compare_name_max, False), + ("Base Dtype ", True, 13, False), + ("Compare Dtype", True, 13, False), + ("# Matches", True, 9, True), + ("# Known Diffs", self._known_differences is not None, 13, True), + ("# Mismatches", True, 12, True), + ] + if self.match_rates: + headers_columns_unequal.append(("Match Rate %", True, 12, True)) + headers_columns_unequal_valid = [h for h in headers_columns_unequal if h[1]] + padding = 2 # spaces to add to left and right of each column + + if self.show_all_columns: + print("\n****** Columns with Equal/Unequal Values ******", file=myfile) + else: + print("\n****** Columns with Unequal Values ******", file=myfile) + + format_pattern = (" " * padding).join( + [ + ("{:" + (">" if h[3] else "") + str(h[2]) + "}") + for h in headers_columns_unequal_valid + ] + ) + print( + format_pattern.format(*[h[0] for h in headers_columns_unequal_valid]), + file=myfile, + ) + print( + format_pattern.format( + *["-" * len(h[0]) for h in headers_columns_unequal_valid] + ), + file=myfile, + ) + + for column_name, column_values in sorted( + self.columns_match_dict.items(), key=lambda i: i[0] + ): + num_matches = column_values[MatchType.MATCH.value] + num_known_diffs = ( + None + if self._known_differences is None + else column_values[MatchType.KNOWN_DIFFERENCE.value] + ) + num_mismatches = column_values[MatchType.MISMATCH.value] + compare_column = self._base_to_compare_name(column_name) + + if num_mismatches or num_known_diffs or self.show_all_columns: + output_row = [ + column_name, + compare_column, + base_types.get(column_name), + compare_types.get(column_name), + str(num_matches), + str(num_mismatches), + ] + if self.match_rates: + match_rate = 100 * ( + 1 + - (column_values[MatchType.MISMATCH.value] + 0.0) + / self.common_row_count + + 0.0 + ) + output_row.append("{:02.5f}".format(match_rate)) + if num_known_diffs is not None: + output_row.insert(len(output_row) - 1, str(num_known_diffs)) + print(format_pattern.format(*output_row), file=myfile) + + # noinspection PyUnresolvedReferences + def report(self, file: TextIO = sys.stdout) -> None: + """Creates a comparison report and prints it to the file specified + (stdout by default). + + Parameters + ---------- + file : ``file``, optional + A filehandle to write the report to. By default, this is + sys.stdout, printing the report to stdout. You can also redirect + this to an output file, as in the example. 
+ + Examples + -------- + >>> with open('my_report.txt', 'w') as report_file: + ... comparison.report(file=report_file) + """ + + self._print_columns_summary(file) + self._print_schema_diff_details(file) + self._print_only_columns("BASE", file) + self._print_only_columns("COMPARE", file) + self._print_row_summary(file) + self._merge_dataframes() + self._print_num_of_rows_with_column_equality(file) + self._print_row_matches_by_column(file) diff --git a/datacompy/spark.py b/datacompy/spark.py index 9fdc2093..cfa90397 100644 --- a/datacompy/spark.py +++ b/datacompy/spark.py @@ -13,916 +13,949 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys -from enum import Enum -from itertools import chain -from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union -from warnings import warn - -try: - import pyspark - from pyspark.sql import functions as F -except ImportError: - pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality - - -warn( - f"The module {__name__} is deprecated. In future versions (0.12.0 and above) SparkCompare will be refactored and the legacy logic will move to LegacySparkCompare ", - DeprecationWarning, - stacklevel=2, -) - - -class MatchType(Enum): - MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3) +""" +Compare two Pandas on Spark DataFrames +Originally this package was meant to provide similar functionality to +PROC COMPARE in SAS - i.e. human-readable reporting on the difference between +two dataframes. +""" -# Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal". -def decimal_comparator() -> str: - class DecimalComparator(str): - def __eq__(self, other: str) -> bool: # type: ignore[override] - return len(other) >= 7 and other[0:7] == "decimal" +import logging +import os - return DecimalComparator("decimal") +import pandas as pd +from ordered_set import OrderedSet +from datacompy.base import BaseCompare -NUMERIC_SPARK_TYPES = [ - "tinyint", - "smallint", - "int", - "bigint", - "float", - "double", - decimal_comparator(), -] - - -def _is_comparable(type1: str, type2: str) -> bool: - """Checks if two Spark data types can be safely compared. - Two data types are considered comparable if any of the following apply: - 1. Both data types are the same - 2. Both data types are numeric - - Parameters - ---------- - type1 : str - A string representation of a Spark data type - type2 : str - A string representation of a Spark data type +try: + import pyspark.pandas as ps + from pandas.api.types import is_numeric_dtype +except ImportError: + pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality - Returns - ------- - bool - True if both data types are comparable - """ - return type1 == type2 or ( - type1 in NUMERIC_SPARK_TYPES and type2 in NUMERIC_SPARK_TYPES - ) +LOG = logging.getLogger(__name__) -class SparkCompare: - """Comparison class used to compare two Spark Dataframes. +class SparkCompare(BaseCompare): + """Comparison class to be used to compare whether two Pandas on Spark dataframes are equal. - Extends the ``Compare`` functionality to the wide world of Spark and - out-of-memory data. + Both df1 and df2 should be dataframes containing all of the join_columns, + with unique column names. Differences between values are compared to + abs_tol + rel_tol * abs(df2['value']). 
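+    For example (illustrative values only): with ``abs_tol=0.1`` and
+    ``rel_tol=0.01``, a df2 value of 100.0 tolerates an absolute difference of
+    up to 0.1 + 0.01 * 100.0 = 1.1, so 99.0 in df1 and 100.0 in df2 would be
+    treated as equal.
+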
Parameters ---------- - spark_session : ``pyspark.sql.SparkSession`` - A ``SparkSession`` to be used to execute Spark commands in the - comparison. - base_df : ``pyspark.sql.DataFrame`` - The dataframe to serve as a basis for comparison. While you will - ultimately get the same results comparing A to B as you will comparing - B to A, by convention ``base_df`` should be the canonical, gold - standard reference dataframe in the comparison. - compare_df : ``pyspark.sql.DataFrame`` - The dataframe to be compared against ``base_df``. - join_columns : list - A list of columns comprising the join key(s) of the two dataframes. - If the column names are the same in the two dataframes, the names of - the columns can be given as strings. If the names differ, the - ``join_columns`` list should include tuples of the form - (base_column_name, compare_column_name). - column_mapping : list[tuple], optional - If columns to be compared have different names in the base and compare - dataframes, a list should be provided in ``columns_mapping`` consisting - of tuples of the form (base_column_name, compare_column_name) for each - set of differently-named columns to be compared against each other. - cache_intermediates : bool, optional - Whether or not ``SparkCompare`` will cache intermediate dataframes - (such as the deduplicated version of dataframes, or the joined - comparison). This will take a large amount of cache, proportional to - the size of your dataframes, but will significantly speed up - performance, as multiple steps will not have to recompute - transformations. False by default. - known_differences : list[dict], optional - A list of dictionaries that define transformations to apply to the - compare dataframe to match values when there are known differences - between base and compare. The dictionaries should contain: - - * name: A name that describes the transformation - * types: The types that the transformation should be applied to. - This prevents certain transformations from being applied to - types that don't make sense and would cause exceptions. - * transformation: A Spark SQL statement to apply to the column - in the compare dataset. The string "{input}" will be replaced - by the variable in question. + df1 : pyspark.pandas.frame.DataFrame + First dataframe to check + df2 : pyspark.pandas.frame.DataFrame + Second dataframe to check + join_columns : list or str, optional + Column(s) to join dataframes on. If a string is passed in, that one + column will be used. abs_tol : float, optional Absolute tolerance between two values. rel_tol : float, optional Relative tolerance between two values. - show_all_columns : bool, optional - If true, all columns will be shown in the report including columns - with a 100% match rate. - match_rates : bool, optional - If true, match rates by column will be shown in the column summary. - - Returns - ------- - SparkCompare - Instance of a ``SparkCompare`` object, ready to do some comparin'. - Note that if ``cache_intermediates=True``, this instance will already - have done some work deduping the input dataframes. If - ``cache_intermediates=False``, the instantiation of this object is lazy. + df1_name : str, optional + A string name for the first dataframe. This allows the reporting to + print out an actual name instead of "df1", and allows human users to + more easily track the dataframes. 
+ df2_name : str, optional + A string name for the second dataframe + ignore_spaces : bool, optional + Flag to strip whitespace (including newlines) from string columns (including any join + columns) + ignore_case : bool, optional + Flag to ignore the case of string columns + cast_column_names_lower: bool, optional + Boolean indicator that controls of column names will be cast into lower case + + Attributes + ---------- + df1_unq_rows : pyspark.pandas.frame.DataFrame + All records that are only in df1 (based on a join on join_columns) + df2_unq_rows : pyspark.pandas.frame.DataFrame + All records that are only in df2 (based on a join on join_columns) """ def __init__( self, - spark_session: "pyspark.sql.SparkSession", - base_df: "pyspark.sql.DataFrame", - compare_df: "pyspark.sql.DataFrame", - join_columns: List[Union[str, Tuple[str, str]]], - column_mapping: Optional[List[Tuple[str, str]]] = None, - cache_intermediates: bool = False, - known_differences: Optional[List[Dict[str, Any]]] = None, - rel_tol: float = 0, - abs_tol: float = 0, - show_all_columns: bool = False, - match_rates: bool = False, + df1, + df2, + join_columns, + abs_tol=0, + rel_tol=0, + df1_name="df1", + df2_name="df2", + ignore_spaces=False, + ignore_case=False, + cast_column_names_lower=True, ): - self.rel_tol = rel_tol + if pd.__version__ >= "2.0.0": + raise Exception( + "It seems like you are running Pandas 2+. Please note that Pandas 2+ will only be supported in Spark 4+. See: https://issues.apache.org/jira/browse/SPARK-44101. If you need to use Spark DataFrame with Pandas 2+ then consider using Fugue otherwise downgrade to Pandas 1.5.3" + ) + + ps.set_option("compute.ops_on_diff_frames", True) + self.cast_column_names_lower = cast_column_names_lower + if isinstance(join_columns, (str, int, float)): + self.join_columns = [ + ( + str(join_columns).lower() + if self.cast_column_names_lower + else str(join_columns) + ) + ] + else: + self.join_columns = [ + str(col).lower() if self.cast_column_names_lower else str(col) + for col in join_columns + ] + + self._any_dupes = False + self.df1 = df1 + self.df2 = df2 + self.df1_name = df1_name + self.df2_name = df2_name self.abs_tol = abs_tol - if self.rel_tol < 0 or self.abs_tol < 0: - raise ValueError("Please enter positive valued tolerances") - self.show_all_columns = show_all_columns - self.match_rates = match_rates + self.rel_tol = rel_tol + self.ignore_spaces = ignore_spaces + self.ignore_case = ignore_case + self.df1_unq_rows = self.df2_unq_rows = self.intersect_rows = None + self.column_stats = [] + self._compare(ignore_spaces, ignore_case) - self._original_base_df = base_df - self._original_compare_df = compare_df - self.cache_intermediates = cache_intermediates + @property + def df1(self): + return self._df1 + + @df1.setter + def df1(self, df1): + """Check that it is a dataframe and has the join columns""" + self._df1 = df1 + self._validate_dataframe( + "df1", cast_column_names_lower=self.cast_column_names_lower + ) - self.join_columns = self._tuplizer(input_list=join_columns) - self._join_column_names = [name[0] for name in self.join_columns] + @property + def df2(self): + return self._df2 + + @df2.setter + def df2(self, df2): + """Check that it is a dataframe and has the join columns""" + self._df2 = df2 + self._validate_dataframe( + "df2", cast_column_names_lower=self.cast_column_names_lower + ) - self._known_differences = known_differences + def _validate_dataframe(self, index, cast_column_names_lower=True): + """Check that it is a dataframe and has the join 
columns - if column_mapping: - for mapping in column_mapping: - compare_df = compare_df.withColumnRenamed(mapping[1], mapping[0]) - self.column_mapping = dict(column_mapping) + Parameters + ---------- + index : str + The "index" of the dataframe - df1 or df2. + cast_column_names_lower: bool, optional + Boolean indicator that controls of column names will be cast into lower case + """ + dataframe = getattr(self, index) + if not isinstance(dataframe, (ps.DataFrame)): + raise TypeError(f"{index} must be a pyspark.pandas.frame.DataFrame") + + if cast_column_names_lower: + dataframe.columns = [str(col).lower() for col in dataframe.columns] else: - self.column_mapping = {} - - for mapping in self.join_columns: - if mapping[1] != mapping[0]: - compare_df = compare_df.withColumnRenamed(mapping[1], mapping[0]) - - self.spark = spark_session - self.base_unq_rows = self.compare_unq_rows = None - self._base_row_count: Optional[int] = None - self._compare_row_count: Optional[int] = None - self._common_row_count: Optional[int] = None - self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None - self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None - self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None - self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None - self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None - self.columns_match_dict: Dict[str, Any] = {} - - # drop the duplicates before actual comparison made. - self.base_df = base_df.dropDuplicates(self._join_column_names) - self.compare_df = compare_df.dropDuplicates(self._join_column_names) - - if cache_intermediates: - self.base_df.cache() - self._base_row_count = self.base_df.count() - self.compare_df.cache() - self._compare_row_count = self.compare_df.count() - - def _tuplizer( - self, input_list: List[Union[str, Tuple[str, str]]] - ) -> List[Tuple[str, str]]: - join_columns: List[Tuple[str, str]] = [] - for val in input_list: - if isinstance(val, str): - join_columns.append((val, val)) - else: - join_columns.append(val) + dataframe.columns = [str(col) for col in dataframe.columns] + # Check if join_columns are present in the dataframe + if not set(self.join_columns).issubset(set(dataframe.columns)): + raise ValueError(f"{index} must have all columns from join_columns") - return join_columns + if len(set(dataframe.columns)) < len(dataframe.columns): + raise ValueError(f"{index} must have unique column names") - @property - def columns_in_both(self) -> Set[str]: - """set[str]: Get columns in both dataframes""" - return set(self.base_df.columns) & set(self.compare_df.columns) + if len(dataframe.drop_duplicates(subset=self.join_columns)) < len(dataframe): + self._any_dupes = True - @property - def columns_compared(self) -> List[str]: - """list[str]: Get columns to be compared in both dataframes (all - columns in both excluding the join key(s)""" - return [ - column - for column in list(self.columns_in_both) - if column not in self._join_column_names - ] + def _compare(self, ignore_spaces, ignore_case): + """Actually run the comparison. This tries to run df1.equals(df2) + first so that if they're truly equal we can tell. - @property - def columns_only_base(self) -> Set[str]: - """set[str]: Get columns that are unique to the base dataframe""" - return set(self.base_df.columns) - set(self.compare_df.columns) + This method will log out information about what is different between + the two dataframes, and will also return a boolean. 
+ """ + LOG.debug("Checking equality") + if self.df1.equals(self.df2).all().all(): + LOG.info("df1 pyspark.pandas.frame.DataFrame.equals df2") + else: + LOG.info("df1 does not pyspark.pandas.frame.DataFrame.equals df2") + LOG.info(f"Number of columns in common: {len(self.intersect_columns())}") + LOG.debug("Checking column overlap") + for col in self.df1_unq_columns(): + LOG.info(f"Column in df1 and not in df2: {col}") + LOG.info( + f"Number of columns in df1 and not in df2: {len(self.df1_unq_columns())}" + ) + for col in self.df2_unq_columns(): + LOG.info(f"Column in df2 and not in df1: {col}") + LOG.info( + f"Number of columns in df2 and not in df1: {len(self.df2_unq_columns())}" + ) + # cache + self.df1.spark.cache() + self.df2.spark.cache() + + LOG.debug("Merging dataframes") + self._dataframe_merge(ignore_spaces) + self._intersect_compare(ignore_spaces, ignore_case) + if self.matches(): + LOG.info("df1 matches df2") + else: + LOG.info("df1 does not match df2") - @property - def columns_only_compare(self) -> Set[str]: - """set[str]: Get columns that are unique to the compare dataframe""" - return set(self.compare_df.columns) - set(self.base_df.columns) + def df1_unq_columns(self): + """Get columns that are unique to df1""" + return OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns) - @property - def base_row_count(self) -> int: - """int: Get the count of rows in the de-duped base dataframe""" - if self._base_row_count is None: - self._base_row_count = self.base_df.count() + def df2_unq_columns(self): + """Get columns that are unique to df2""" + return OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns) - return self._base_row_count + def intersect_columns(self): + """Get columns that are shared between the two dataframes""" + return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns) - @property - def compare_row_count(self) -> int: - """int: Get the count of rows in the de-duped compare dataframe""" - if self._compare_row_count is None: - self._compare_row_count = self.compare_df.count() + def _dataframe_merge(self, ignore_spaces): + """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1 + and df1 & df2 + """ - return self._compare_row_count + LOG.debug("Outer joining") - @property - def common_row_count(self) -> int: - """int: Get the count of rows in common between base and compare dataframes""" - if self._common_row_count is None: - common_rows = self._get_or_create_joined_dataframe() - self._common_row_count = common_rows.count() - - return self._common_row_count - - def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame": - """Get the rows only from base data frame""" - return self.base_df.select(self._join_column_names).subtract( - self.compare_df.select(self._join_column_names) - ) + df1 = self.df1.copy() + df2 = self.df2.copy() - def _get_compare_rows(self) -> "pyspark.sql.DataFrame": - """Get the rows only from compare data frame""" - return self.compare_df.select(self._join_column_names).subtract( - self.base_df.select(self._join_column_names) - ) + if self._any_dupes: + LOG.debug("Duplicate rows found, deduping by order of remaining fields") + temp_join_columns = list(self.join_columns) + + # Create order column for uniqueness of match + order_column = temp_column_name(df1, df2) + df1[order_column] = generate_id_within_group(df1, temp_join_columns) + df2[order_column] = generate_id_within_group(df2, temp_join_columns) + temp_join_columns.append(order_column) - def _print_columns_summary(self, myfile: TextIO) -> None: - """Prints the 
column summary details""" - print("\n****** Column Summary ******", file=myfile) - print( - f"Number of columns in common with matching schemas: {len(self._columns_with_matching_schema())}", - file=myfile, + params = {"on": temp_join_columns} + else: + params = {"on": self.join_columns} + + if ignore_spaces: + for column in self.join_columns: + if df1[column].dtype.kind == "O": + df1[column] = df1[column].str.strip() + if df2[column].dtype.kind == "O": + df2[column] = df2[column].str.strip() + + non_join_columns = ( + OrderedSet(df1.columns) | OrderedSet(df2.columns) + ) - OrderedSet(self.join_columns) + + for c in non_join_columns: + df1.rename(columns={c: c + "_df1"}, inplace=True) + df2.rename(columns={c: c + "_df2"}, inplace=True) + + # generate merge indicator + df1["_merge_left"] = True + df2["_merge_right"] = True + + for c in self.join_columns: + df1.rename(columns={c: c + "_df1"}, inplace=True) + df2.rename(columns={c: c + "_df2"}, inplace=True) + + # cache + df1.spark.cache() + df2.spark.cache() + + # NULL SAFE Outer join using ON + on = " and ".join([f"df1.`{c}_df1` <=> df2.`{c}_df2`" for c in params["on"]]) + outer_join = ps.sql( + """ + SELECT * FROM + {df1} df1 FULL OUTER JOIN {df2} df2 + ON + """ + + on, + df1=df1, + df2=df2, ) - print( - f"Number of columns in common with schema differences: {len(self._columns_with_schemadiff())}", - file=myfile, + + outer_join["_merge"] = None # initialize col + + # process merge indicator + outer_join["_merge"] = outer_join._merge.mask( + (outer_join["_merge_left"] == True) & (outer_join["_merge_right"] == True), + "both", ) - print( - f"Number of columns in base but not compare: {len(self.columns_only_base)}", - file=myfile, + outer_join["_merge"] = outer_join._merge.mask( + (outer_join["_merge_left"] == True) & (outer_join["_merge_right"] != True), + "left_only", ) - print( - f"Number of columns in compare but not base: {len(self.columns_only_compare)}", - file=myfile, + outer_join["_merge"] = outer_join._merge.mask( + (outer_join["_merge_left"] != True) & (outer_join["_merge_right"] == True), + "right_only", ) - def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None: - """Prints the columns and data types only in either the base or compare datasets""" - - if base_or_compare.upper() == "BASE": - columns = self.columns_only_base - df = self.base_df - elif base_or_compare.upper() == "COMPARE": - columns = self.columns_only_compare - df = self.compare_df - else: - raise ValueError( - f'base_or_compare must be BASE or COMPARE, but was "{base_or_compare}"' + # Clean up temp columns for duplicate row matching + if self._any_dupes: + outer_join = outer_join.drop( + [order_column + "_df1", order_column + "_df2"], axis=1 ) + df1 = df1.drop([order_column + "_df1", order_column + "_df2"], axis=1) + df2 = df2.drop([order_column + "_df1", order_column + "_df2"], axis=1) - # If there are no columns only in this dataframe, don't display this section - if not columns: - return - - max_length = max([len(col) for col in columns] + [11]) - format_pattern = f"{{:{max_length}s}}" - - print(f"\n****** Columns In {base_or_compare.title()} Only ******", file=myfile) - print((format_pattern + " Dtype").format("Column Name"), file=myfile) - print("-" * max_length + " -------------", file=myfile) - - for column in columns: - col_type = df.select(column).dtypes[0][1] - print((format_pattern + " {:13s}").format(column, col_type), file=myfile) - - def _columns_with_matching_schema(self) -> Dict[str, str]: - """This function will identify the 
columns which has matching schema""" - col_schema_match = {} - base_columns_dict = dict(self.base_df.dtypes) - compare_columns_dict = dict(self.compare_df.dtypes) - - for base_row, base_type in base_columns_dict.items(): - if base_row in compare_columns_dict: - compare_column_type = compare_columns_dict.get(base_row) - if compare_column_type is not None and base_type in compare_column_type: - col_schema_match[base_row] = compare_column_type - - return col_schema_match - - def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]: - """This function will identify the columns which has different schema""" - col_schema_diff = {} - base_columns_dict = dict(self.base_df.dtypes) - compare_columns_dict = dict(self.compare_df.dtypes) - - for base_row, base_type in base_columns_dict.items(): - if base_row in compare_columns_dict: - compare_column_type = compare_columns_dict.get(base_row) - if ( - compare_column_type is not None - and base_type not in compare_column_type - ): - col_schema_diff[base_row] = dict( - base_type=base_type, - compare_type=compare_column_type, - ) - return col_schema_diff + df1_cols = get_merged_columns(df1, outer_join, "_df1") + df2_cols = get_merged_columns(df2, outer_join, "_df2") - @property - def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]: - """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches""" - if self._all_rows_mismatched is None: - self._merge_dataframes() + LOG.debug("Selecting df1 unique rows") + self.df1_unq_rows = outer_join[outer_join["_merge"] == "left_only"][ + df1_cols + ].copy() - return self._all_rows_mismatched + LOG.debug("Selecting df2 unique rows") + self.df2_unq_rows = outer_join[outer_join["_merge"] == "right_only"][ + df2_cols + ].copy() - @property - def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]: - """pyspark.sql.DataFrame: Returns all rows in both dataframes""" - if self._all_matched_rows is None: - self._merge_dataframes() + LOG.info(f"Number of rows in df1 and not in df2: {len(self.df1_unq_rows)}") + LOG.info(f"Number of rows in df2 and not in df1: {len(self.df2_unq_rows)}") - return self._all_matched_rows + LOG.debug("Selecting intersecting rows") + self.intersect_rows = outer_join[outer_join["_merge"] == "both"].copy() + LOG.info( + "Number of rows in df1 and df2 (not necessarily equal): {len(self.intersect_rows)}" + ) + # cache + self.intersect_rows.spark.cache() - @property - def rows_only_base(self) -> "pyspark.sql.DataFrame": - """pyspark.sql.DataFrame: Returns rows only in the base dataframe""" - if not self._rows_only_base: - base_rows = self._get_unq_base_rows() - base_rows.createOrReplaceTempView("baseRows") - self.base_df.createOrReplaceTempView("baseTable") - join_condition = " AND ".join( - [ - "A.`" + name + "`<=>B.`" + name + "`" - for name in self._join_column_names - ] - ) - sql_query = "select A.* from baseTable as A, baseRows as B where {}".format( - join_condition + def _intersect_compare(self, ignore_spaces, ignore_case): + """Run the comparison on the intersect dataframe + + This loops through all columns that are shared between df1 and df2, and + creates a column column_match which is True for matches, False + otherwise. 
+ """ + LOG.debug("Comparing intersection") + row_cnt = len(self.intersect_rows) + for column in self.intersect_columns(): + if column in self.join_columns: + match_cnt = row_cnt + col_match = "" + max_diff = 0 + null_diff = 0 + else: + col_1 = column + "_df1" + col_2 = column + "_df2" + col_match = column + "_match" + self.intersect_rows[col_match] = columns_equal( + self.intersect_rows[col_1], + self.intersect_rows[col_2], + self.rel_tol, + self.abs_tol, + ignore_spaces, + ignore_case, + ) + match_cnt = self.intersect_rows[col_match].sum() + max_diff = calculate_max_diff( + self.intersect_rows[col_1], self.intersect_rows[col_2] + ) + + try: + null_diff = ( + (self.intersect_rows[col_1].isnull()) + ^ (self.intersect_rows[col_2].isnull()) + ).sum() + except TypeError: # older pyspark compatibility + temp_null_diff = self.intersect_rows[[col_1, col_2]].isnull() + null_diff = (temp_null_diff[col_1] != temp_null_diff[col_2]).sum() + + if row_cnt > 0: + match_rate = float(match_cnt) / row_cnt + else: + match_rate = 0 + LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match") + + self.column_stats.append( + { + "column": column, + "match_column": col_match, + "match_cnt": match_cnt, + "unequal_cnt": row_cnt - match_cnt, + "dtype1": str(self.df1[column].dtype), + "dtype2": str(self.df2[column].dtype), + "all_match": all( + ( + self.df1[column].dtype == self.df2[column].dtype, + row_cnt == match_cnt, + ) + ), + "max_diff": max_diff, + "null_diff": null_diff, + } ) - self._rows_only_base = self.spark.sql(sql_query) - if self.cache_intermediates: - self._rows_only_base.cache().count() + def all_columns_match(self): + """Whether the columns all match in the dataframes""" + return self.df1_unq_columns() == self.df2_unq_columns() == set() - return self._rows_only_base + def all_rows_overlap(self): + """Whether the rows are all present in both dataframes - @property - def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]: - """pyspark.sql.DataFrame: Returns rows only in the compare dataframe""" - if not self._rows_only_compare: - compare_rows = self._get_compare_rows() - compare_rows.createOrReplaceTempView("compareRows") - self.compare_df.createOrReplaceTempView("compareTable") - where_condition = " AND ".join( - [ - "A.`" + name + "`<=>B.`" + name + "`" - for name in self._join_column_names - ] - ) - sql_query = ( - "select A.* from compareTable as A, compareRows as B where {}".format( - where_condition - ) + Returns + ------- + bool + True if all rows in df1 are in df2 and vice versa (based on + existence for join option) + """ + return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0 + + def count_matching_rows(self): + """Count the number of rows match (on overlapping fields) + + Returns + ------- + int + Number of matching rows + """ + conditions = [] + match_columns = [] + for column in self.intersect_columns(): + if column not in self.join_columns: + match_columns.append(column + "_match") + conditions.append(f"`{column}_match` == True") + if len(conditions) > 0: + match_columns_count = ( + self.intersect_rows[match_columns] + .query(" and ".join(conditions)) + .shape[0] ) - self._rows_only_compare = self.spark.sql(sql_query) + else: + match_columns_count = 0 + return match_columns_count - if self.cache_intermediates: - self._rows_only_compare.cache().count() + def intersect_rows_match(self): + """Check whether the intersect rows all match""" + actual_length = self.intersect_rows.shape[0] + return self.count_matching_rows() == actual_length - return 
self._rows_only_compare + def matches(self, ignore_extra_columns=False): + """Return True or False if the dataframes match. - def _generate_select_statement(self, match_data: bool = True) -> str: - """This function is to generate the select statement to be used later in the query.""" - base_only = list(set(self.base_df.columns) - set(self.compare_df.columns)) - compare_only = list(set(self.compare_df.columns) - set(self.base_df.columns)) - sorted_list = sorted(list(chain(base_only, compare_only, self.columns_in_both))) - select_statement = "" + Parameters + ---------- + ignore_extra_columns : bool + Ignores any columns in one dataframe and not in the other. + """ + if not ignore_extra_columns and not self.all_columns_match(): + return False + elif not self.all_rows_overlap(): + return False + elif not self.intersect_rows_match(): + return False + else: + return True - for column_name in sorted_list: - if column_name in self.columns_compared: - if match_data: - select_statement = select_statement + ",".join( - [self._create_case_statement(name=column_name)] - ) - else: - select_statement = select_statement + ",".join( - [self._create_select_statement(name=column_name)] - ) - elif column_name in base_only: - select_statement = select_statement + ",".join( - ["A.`" + column_name + "`"] - ) + def subset(self): + """Return True if dataframe 2 is a subset of dataframe 1. - elif column_name in compare_only: - if match_data: - select_statement = select_statement + ",".join( - ["B.`" + column_name + "`"] - ) - else: - select_statement = select_statement + ",".join( - ["A.`" + column_name + "`"] - ) - elif column_name in self._join_column_names: - select_statement = select_statement + ",".join( - ["A.`" + column_name + "`"] - ) + Dataframe 2 is considered a subset if all of its columns are in + dataframe 1, and all of its rows match rows in dataframe 1 for the + shared columns. + """ + if not self.df2_unq_columns() == set(): + return False + elif not len(self.df2_unq_rows) == 0: + return False + elif not self.intersect_rows_match(): + return False + else: + return True - if column_name != sorted_list[-1]: - select_statement = select_statement + " , " + def sample_mismatch(self, column, sample_count=10, for_display=False): + """Returns a sample sub-dataframe which contains the identifying + columns, and df1 and df2 versions of the column. - return select_statement + Parameters + ---------- + column : str + The raw column name (i.e. without ``_df1`` appended) + sample_count : int, optional + The number of sample records to return. Defaults to 10. + for_display : bool, optional + Whether this is just going to be used for display (overwrite the + column names) + + Returns + ------- + pyspark.pandas.frame.DataFrame + A sample of the intersection dataframe, containing only the + "pertinent" columns, for rows that don't match on the provided + column. 
+ """ + row_cnt = self.intersect_rows.shape[0] + col_match = self.intersect_rows[column + "_match"] + match_cnt = col_match.sum() + sample_count = min(sample_count, row_cnt - match_cnt) + sample = self.intersect_rows[~col_match].head(sample_count) + + for c in self.join_columns: + sample[c] = sample[c + "_df1"] + + return_cols = self.join_columns + [column + "_df1", column + "_df2"] + to_return = sample[return_cols] + if for_display: + to_return.columns = self.join_columns + [ + column + " (" + self.df1_name + ")", + column + " (" + self.df2_name + ")", + ] + return to_return - def _merge_dataframes(self) -> None: - """Merges the two dataframes and creates self._all_matched_rows and self._all_rows_mismatched.""" - full_joined_dataframe = self._get_or_create_joined_dataframe() - full_joined_dataframe.createOrReplaceTempView("full_matched_table") + def all_mismatch(self, ignore_matching_cols=False): + """All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join + columns. - select_statement = self._generate_select_statement(False) - select_query = """SELECT {} FROM full_matched_table A""".format( - select_statement - ) - self._all_matched_rows = self.spark.sql(select_query).orderBy( - self._join_column_names # type: ignore[arg-type] - ) - self._all_matched_rows.createOrReplaceTempView("matched_table") + Parameters + ---------- + ignore_matching_cols : bool, optional + Whether showing the matching columns in the output or not. The default is False. - where_cond = " OR ".join( - ["A.`" + name + "_match`= False" for name in self.columns_compared] - ) - mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond) - self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy( - self._join_column_names # type: ignore[arg-type] - ) + Returns + ------- + pyspark.pandas.frame.DataFrame + All rows of the intersection dataframe, containing any columns, that don't match. + """ + match_list = [] + return_list = [] + for col in self.intersect_rows.columns: + if col.endswith("_match"): + orig_col_name = col[:-6] + + col_comparison = columns_equal( + self.intersect_rows[orig_col_name + "_df1"], + self.intersect_rows[orig_col_name + "_df2"], + self.rel_tol, + self.abs_tol, + self.ignore_spaces, + self.ignore_case, + ) - def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame": - if self._joined_dataframe is None: - join_condition = " AND ".join( - [ - "A.`" + name + "`<=>B.`" + name + "`" - for name in self._join_column_names - ] - ) - select_statement = self._generate_select_statement(match_data=True) + if not ignore_matching_cols or ( + ignore_matching_cols and not col_comparison.all() + ): + LOG.debug(f"Adding column {orig_col_name} to the result.") + match_list.append(col) + return_list.extend([orig_col_name + "_df1", orig_col_name + "_df2"]) + elif ignore_matching_cols: + LOG.debug( + f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result." 
+ ) - self.base_df.createOrReplaceTempView("base_table") - self.compare_df.createOrReplaceTempView("compare_table") + mm_bool = self.intersect_rows[match_list].T.all() - join_query = r""" - SELECT {} - FROM base_table A - JOIN compare_table B - ON {}""".format( - select_statement, join_condition - ) + updated_join_columns = [] + for c in self.join_columns: + updated_join_columns.append(c + "_df1") + updated_join_columns.append(c + "_df2") - self._joined_dataframe = self.spark.sql(join_query) - if self.cache_intermediates: - self._joined_dataframe.cache() - self._common_row_count = self._joined_dataframe.count() + return self.intersect_rows[~mm_bool][updated_join_columns + return_list] - return self._joined_dataframe + def report(self, sample_count=10, column_count=10, html_file=None): + """Returns a string representation of a report. The representation can + then be printed or saved to a file. - def _print_num_of_rows_with_column_equality(self, myfile: TextIO) -> None: - # match_dataframe contains columns from both dataframes with flag to indicate if columns matched - match_dataframe = self._get_or_create_joined_dataframe().select( - *self.columns_compared - ) - match_dataframe.createOrReplaceTempView("matched_df") + Parameters + ---------- + sample_count : int, optional + The number of sample records to return. Defaults to 10. - where_cond = " AND ".join( - [ - "A.`" + name + "`=" + str(MatchType.MATCH.value) - for name in self.columns_compared - ] - ) - match_query = ( - r"""SELECT count(*) AS row_count FROM matched_df A WHERE {}""".format( - where_cond - ) - ) - all_rows_matched = self.spark.sql(match_query) - all_rows_matched_head = all_rows_matched.head() - matched_rows = ( - all_rows_matched_head[0] if all_rows_matched_head is not None else 0 - ) + column_count : int, optional + The number of columns to display in the sample records output. Defaults to 10. - print("\n****** Row Comparison ******", file=myfile) - print( - f"Number of rows with some columns unequal: {self.common_row_count - matched_rows}", - file=myfile, - ) - print(f"Number of rows with all columns equal: {matched_rows}", file=myfile) + html_file : str, optional + HTML file name to save report output to. If ``None`` the file creation will be skipped. - def _populate_columns_match_dict(self) -> None: + Returns + ------- + str + The report, formatted kinda nicely. 
""" - side effects: - columns_match_dict assigned to { column -> match_type_counts } - where: - column (string): Name of a column that exists in both the base and comparison columns - match_type_counts (list of int with size = len(MatchType)): The number of each match type seen for this column (in order of the MatchType enum values) + # Header + report = render("header.txt") + df_header = ps.DataFrame( + { + "DataFrame": [self.df1_name, self.df2_name], + "Columns": [self.df1.shape[1], self.df2.shape[1]], + "Rows": [self.df1.shape[0], self.df2.shape[0]], + } + ) + report += df_header[["DataFrame", "Columns", "Rows"]].to_string() + report += "\n\n" + + # Column Summary + report += render( + "column_summary.txt", + len(self.intersect_columns()), + len(self.df1_unq_columns()), + len(self.df2_unq_columns()), + self.df1_name, + self.df2_name, + ) - returns: None - """ + # Row Summary + match_on = ", ".join(self.join_columns) + report += render( + "row_summary.txt", + match_on, + self.abs_tol, + self.rel_tol, + self.intersect_rows.shape[0], + self.df1_unq_rows.shape[0], + self.df2_unq_rows.shape[0], + self.intersect_rows.shape[0] - self.count_matching_rows(), + self.count_matching_rows(), + self.df1_name, + self.df2_name, + "Yes" if self._any_dupes else "No", + ) - match_dataframe = self._get_or_create_joined_dataframe().select( - *self.columns_compared + # Column Matching + report += render( + "column_comparison.txt", + len([col for col in self.column_stats if col["unequal_cnt"] > 0]), + len([col for col in self.column_stats if col["unequal_cnt"] == 0]), + sum([col["unequal_cnt"] for col in self.column_stats]), ) - def helper(c: str) -> "pyspark.sql.Column": - # Create a predicate for each match type, comparing column values to the match type value - predicates = [F.col(c) == k.value for k in MatchType] - # Create a tuple(number of match types found for each match type in this column) - return F.struct( - [F.lit(F.sum(pred.cast("integer"))) for pred in predicates] - ).alias(c) - - # For each column, create a single tuple. 
This tuple's values correspond to the number of times - # each match type appears in that column - match_data_agg = match_dataframe.agg( - *[helper(col) for col in self.columns_compared] - ).collect() - match_data = match_data_agg[0] - - for c in self.columns_compared: - self.columns_match_dict[c] = match_data[c] - - def _create_select_statement(self, name: str) -> str: - if self._known_differences: - match_type_comparison = "" - for k in MatchType: - match_type_comparison += ( - " WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format( - name=name, match_value=str(k.value), match_name=k.name - ) + match_stats = [] + match_sample = [] + any_mismatch = False + for column in self.column_stats: + if not column["all_match"]: + any_mismatch = True + match_stats.append( + { + "Column": column["column"], + f"{self.df1_name} dtype": column["dtype1"], + f"{self.df2_name} dtype": column["dtype2"], + "# Unequal": column["unequal_cnt"], + "Max Diff": column["max_diff"], + "# Null Diff": column["null_diff"], + } ) - return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format( - name=name, - match_failure=MatchType.MISMATCH.value, - match_type_comparison=match_type_comparison, + if column["unequal_cnt"] > 0: + match_sample.append( + self.sample_mismatch( + column["column"], sample_count, for_display=True + ) + ) + + if any_mismatch: + report += "Columns with Unequal Values or Types\n" + report += "------------------------------------\n" + report += "\n" + df_match_stats = ps.DataFrame(match_stats) + df_match_stats.sort_values("Column", inplace=True) + # Have to specify again for sorting + report += df_match_stats[ + [ + "Column", + f"{self.df1_name} dtype", + f"{self.df2_name} dtype", + "# Unequal", + "Max Diff", + "# Null Diff", + ] + ].to_string() + report += "\n\n" + + if sample_count > 0: + report += "Sample Rows with Unequal Values\n" + report += "-------------------------------\n" + report += "\n" + for sample in match_sample: + report += sample.to_string() + report += "\n\n" + + if min(sample_count, self.df1_unq_rows.shape[0]) > 0: + report += ( + f"Sample Rows Only in {self.df1_name} (First {column_count} Columns)\n" ) - else: - return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format( - name=name, match_failure=MatchType.MISMATCH.value + report += ( + f"---------------------------------------{'-' * len(self.df1_name)}\n" + ) + report += "\n" + columns = self.df1_unq_rows.columns[:column_count] + unq_count = min(sample_count, self.df1_unq_rows.shape[0]) + report += self.df1_unq_rows.head(unq_count)[columns].to_string() + report += "\n\n" + + if min(sample_count, self.df2_unq_rows.shape[0]) > 0: + report += ( + f"Sample Rows Only in {self.df2_name} (First {column_count} Columns)\n" ) + report += ( + f"---------------------------------------{'-' * len(self.df2_name)}\n" + ) + report += "\n" + columns = self.df2_unq_rows.columns[:column_count] + unq_count = min(sample_count, self.df2_unq_rows.shape[0]) + report += self.df2_unq_rows.head(unq_count)[columns].to_string() + report += "\n\n" - def _create_case_statement(self, name: str) -> str: - equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"] - known_diff_comparisons = ["(FALSE)"] - - base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0] - compare_dtype = [d[1] for d in 
self.compare_df.dtypes if d[0] == name][0] - - if _is_comparable(base_dtype, compare_dtype): - if (base_dtype in NUMERIC_SPARK_TYPES) and ( - compare_dtype in NUMERIC_SPARK_TYPES - ): # numeric tolerance comparison - equal_comparisons.append( - "((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=(" - + str(self.abs_tol) - + "+(" - + str(self.rel_tol) - + "*abs(A.`{name}`)))))" - ) - else: # non-numeric comparison - equal_comparisons.append("((A.`{name}`=B.`{name}`))") - - if self._known_differences: - new_input = "B.`{name}`" - for kd in self._known_differences: - if compare_dtype in kd["types"]: - if "flags" in kd and "nullcheck" in kd["flags"]: - known_diff_comparisons.append( - "((" - + kd["transformation"].format(new_input, input=new_input) - + ") is null AND A.`{name}` is null)" - ) - else: - known_diff_comparisons.append( - "((" - + kd["transformation"].format(new_input, input=new_input) - + ") = A.`{name}`)" - ) + if html_file: + html_report = report.replace("\n", "
").replace(" ", " ") + html_report = f"
{html_report}
" + with open(html_file, "w") as f: + f.write(html_report) - case_string = ( - "( CASE WHEN (" - + " OR ".join(equal_comparisons) - + ") THEN {match_success} WHEN (" - + " OR ".join(known_diff_comparisons) - + ") THEN {match_known_difference} ELSE {match_failure} END) " - + "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`" - ) + return report - return case_string.format( - name=name, - match_success=MatchType.MATCH.value, - match_known_difference=MatchType.KNOWN_DIFFERENCE.value, - match_failure=MatchType.MISMATCH.value, - ) - def _print_row_summary(self, myfile: TextIO) -> None: - base_df_cnt = self.base_df.count() - compare_df_cnt = self.compare_df.count() - base_df_with_dup_cnt = self._original_base_df.count() - compare_df_with_dup_cnt = self._original_compare_df.count() - - print("\n****** Row Summary ******", file=myfile) - print(f"Number of rows in common: {self.common_row_count}", file=myfile) - print( - f"Number of rows in base but not compare: {base_df_cnt - self.common_row_count}", - file=myfile, - ) - print( - f"Number of rows in compare but not base: {compare_df_cnt - self.common_row_count}", - file=myfile, - ) - print( - f"Number of duplicate rows found in base: {base_df_with_dup_cnt - base_df_cnt}", - file=myfile, - ) - print( - f"Number of duplicate rows found in compare: {compare_df_with_dup_cnt - compare_df_cnt}", - file=myfile, - ) +def render(filename, *fields): + """Renders out an individual template. This basically just reads in a + template file, and applies ``.format()`` on the fields. - def _print_schema_diff_details(self, myfile: TextIO) -> None: - schema_diff_dict = self._columns_with_schemadiff() + Parameters + ---------- + filename : str + The file that contains the template. Will automagically prepend the + templates directory before opening + fields : list + Fields to be rendered out in the template - if not schema_diff_dict: # If there are no differences, don't print the section - return + Returns + ------- + str + The fully rendered out file. + """ + this_dir = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(this_dir, "templates", filename)) as file_open: + return file_open.read().format(*fields) - # For columns with mismatches, what are the longest base and compare column name lengths (with minimums)? - base_name_max = max([len(key) for key in schema_diff_dict] + [16]) - compare_name_max = max( - [len(self._base_to_compare_name(key)) for key in schema_diff_dict] + [19] - ) - format_pattern = "{{:{base}s}} {{:{compare}s}}".format( - base=base_name_max, compare=compare_name_max - ) +def columns_equal( + col_1, col_2, rel_tol=0, abs_tol=0, ignore_spaces=False, ignore_case=False +): + """Compares two columns from a dataframe, returning a True/False series, + with the same index as column 1. - print("\n****** Schema Differences ******", file=myfile) - print( - (format_pattern + " Base Dtype Compare Dtype").format( - "Base Column Name", "Compare Column Name" - ), - file=myfile, - ) - print( - "-" * base_name_max - + " " - + "-" * compare_name_max - + " ------------- -------------", - file=myfile, + - Two nulls (np.nan) will evaluate to True. + - A null and a non-null value will evaluate to False. + - Numeric values will use the relative and absolute tolerances. + - Decimal values (decimal.Decimal) will attempt to be converted to floats + before comparing + - Non-numeric values (i.e. where np.isclose can't be used) will just + trigger True on two nulls or exact matches. 
+ + Parameters + ---------- + col_1 : pyspark.pandas.series.Series + The first column to look at + col_2 : pyspark.pandas.series.Series + The second column + rel_tol : float, optional + Relative tolerance + abs_tol : float, optional + Absolute tolerance + ignore_spaces : bool, optional + Flag to strip whitespace (including newlines) from string columns + ignore_case : bool, optional + Flag to ignore the case of string columns + + Returns + ------- + pyspark.pandas.series.Series + A series of Boolean values. True == the values match, False == the + values don't match. + """ + try: + compare = ((col_1 - col_2).abs() <= abs_tol + (rel_tol * col_2.abs())) | ( + col_1.isnull() & col_2.isnull() ) + except TypeError: + if ( + is_numeric_dtype(col_1.dtype.kind) and is_numeric_dtype(col_2.dtype.kind) + ) or ( + col_1.spark.data_type.typeName() == "decimal" + and col_2.spark.data_type.typeName() == "decimal" + ): + compare = ( + (col_1.astype(float) - col_2.astype(float)).abs() + <= abs_tol + (rel_tol * col_2.astype(float).abs()) + ) | (col_1.astype(float).isnull() & col_2.astype(float).isnull()) + else: + try: + col_1_temp = col_1.copy() + col_2_temp = col_2.copy() + if ignore_spaces: + if col_1.dtype.kind == "O": + col_1_temp = col_1_temp.str.strip() + if col_2.dtype.kind == "O": + col_2_temp = col_2_temp.str.strip() + + if ignore_case: + if col_1.dtype.kind == "O": + col_1_temp = col_1_temp.str.upper() + if col_2.dtype.kind == "O": + col_2_temp = col_2_temp.str.upper() + + if {col_1.dtype.kind, col_2.dtype.kind} == {"M", "O"}: + compare = compare_string_and_date_columns(col_1_temp, col_2_temp) + else: + compare = (col_1_temp == col_2_temp) | ( + col_1_temp.isnull() & col_2_temp.isnull() + ) - for base_column, types in schema_diff_dict.items(): - compare_column = self._base_to_compare_name(base_column) - - print( - (format_pattern + " {:13s} {:13s}").format( - base_column, - compare_column, - types["base_type"], - types["compare_type"], - ), - file=myfile, - ) + except: + # Blanket exception should just return all False + compare = ps.Series(False, index=col_1.index.to_numpy()) + return compare - def _base_to_compare_name(self, base_name: str) -> str: - """Translates a column name in the base dataframe to its counterpart in the - compare dataframe, if they are different.""" - if base_name in self.column_mapping: - return self.column_mapping[base_name] - else: - for name in self.join_columns: - if base_name == name[0]: - return name[1] - return base_name - - def _print_row_matches_by_column(self, myfile: TextIO) -> None: - self._populate_columns_match_dict() - columns_with_mismatches = { - key: self.columns_match_dict[key] - for key in self.columns_match_dict - if self.columns_match_dict[key][MatchType.MISMATCH.value] - } - - # corner case: when all columns match but no rows match - # issue: #276 - try: - columns_fully_matching = { - key: self.columns_match_dict[key] - for key in self.columns_match_dict - if sum(self.columns_match_dict[key]) - == self.columns_match_dict[key][MatchType.MATCH.value] - } - except TypeError: - columns_fully_matching = {} - - try: - columns_with_any_diffs = { - key: self.columns_match_dict[key] - for key in self.columns_match_dict - if sum(self.columns_match_dict[key]) - != self.columns_match_dict[key][MatchType.MATCH.value] - } - except TypeError: - columns_with_any_diffs = {} - # +def compare_string_and_date_columns(col_1, col_2): + """Compare a string column and date column, value-wise. This tries to + convert a string column to a date column and compare that way. 
- base_types = {x[0]: x[1] for x in self.base_df.dtypes} - compare_types = {x[0]: x[1] for x in self.compare_df.dtypes} + Parameters + ---------- + col_1 : pyspark.pandas.series.Series + The first column to look at + col_2 : pyspark.pandas.series.Series + The second column - print("\n****** Column Comparison ******", file=myfile) + Returns + ------- + pyspark.pandas.series.Series + A series of Boolean values. True == the values match, False == the + values don't match. + """ + if col_1.dtype.kind == "O": + obj_column = col_1 + date_column = col_2 + else: + obj_column = col_2 + date_column = col_1 + + try: + compare = ps.Series( + ( + (ps.to_datetime(obj_column) == date_column) + | (obj_column.isnull() & date_column.isnull()) + ).to_numpy() + ) # force compute + except: + compare = ps.Series(False, index=col_1.index.to_numpy()) + return compare + + +def get_merged_columns(original_df, merged_df, suffix): + """Gets the columns from an original dataframe, in the new merged dataframe - if self._known_differences: - print( - f"Number of columns compared with unexpected differences in some values: {len(columns_with_mismatches)}", - file=myfile, - ) - print( - f"Number of columns compared with all values equal but known differences found: {len(self.columns_compared) - len(columns_with_mismatches) - len(columns_fully_matching)}", - file=myfile, - ) - print( - f"Number of columns compared with all values completely equal: {len(columns_fully_matching)}", - file=myfile, - ) + Parameters + ---------- + original_df : pyspark.pandas.frame.DataFrame + The original, pre-merge dataframe + merged_df : pyspark.pandas.frame.DataFrame + Post-merge with another dataframe, with suffixes added in. + suffix : str + What suffix was used to distinguish when the original dataframe was + overlapping with the other merged dataframe. + """ + columns = [] + for col in original_df.columns: + if col in merged_df.columns: + columns.append(col) + elif col + suffix in merged_df.columns: + columns.append(col + suffix) else: - print( - f"Number of columns compared with some values unequal: {len(columns_with_mismatches)}", - file=myfile, - ) - print( - f"Number of columns compared with all values equal: {len(columns_fully_matching)}", - file=myfile, - ) + raise ValueError("Column not found: %s", col) + return columns - # If all columns matched, don't print columns with unequal values - if (not self.show_all_columns) and ( - len(columns_fully_matching) == len(self.columns_compared) - ): - return - # if show_all_columns is set, set column name length maximum to max of ALL columns(with minimum) - if self.show_all_columns: - base_name_max = max([len(key) for key in self.columns_match_dict] + [16]) - compare_name_max = max( - [ - len(self._base_to_compare_name(key)) - for key in self.columns_match_dict - ] - + [19] - ) +def temp_column_name(*dataframes): + """Gets a temp column name that isn't included in columns of any dataframes - # For columns with any differences, what are the longest base and compare column name lengths (with minimums)? 
- else: - base_name_max = max([len(key) for key in columns_with_any_diffs] + [16]) - compare_name_max = max( - [len(self._base_to_compare_name(key)) for key in columns_with_any_diffs] - + [19] - ) + Parameters + ---------- + dataframes : list of pyspark.pandas.frame.DataFrame + The DataFrames to create a temporary column name for - """ list of (header, condition, width, align) - where - header (String) : output header for a column - condition (Bool): true if this header should be displayed - width (Int) : width of the column - align (Bool) : true if right-aligned - """ - headers_columns_unequal = [ - ("Base Column Name", True, base_name_max, False), - ("Compare Column Name", True, compare_name_max, False), - ("Base Dtype ", True, 13, False), - ("Compare Dtype", True, 13, False), - ("# Matches", True, 9, True), - ("# Known Diffs", self._known_differences is not None, 13, True), - ("# Mismatches", True, 12, True), - ] - if self.match_rates: - headers_columns_unequal.append(("Match Rate %", True, 12, True)) - headers_columns_unequal_valid = [h for h in headers_columns_unequal if h[1]] - padding = 2 # spaces to add to left and right of each column - - if self.show_all_columns: - print("\n****** Columns with Equal/Unequal Values ******", file=myfile) - else: - print("\n****** Columns with Unequal Values ******", file=myfile) + Returns + ------- + str + String column name that looks like '_temp_x' for some integer x + """ + i = 0 + while True: + temp_column = f"_temp_{i}" + unique = True + for dataframe in dataframes: + if temp_column in dataframe.columns: + i += 1 + unique = False + if unique: + return temp_column - format_pattern = (" " * padding).join( - [ - ("{:" + (">" if h[3] else "") + str(h[2]) + "}") - for h in headers_columns_unequal_valid - ] - ) - print( - format_pattern.format(*[h[0] for h in headers_columns_unequal_valid]), - file=myfile, - ) - print( - format_pattern.format( - *["-" * len(h[0]) for h in headers_columns_unequal_valid] - ), - file=myfile, - ) - for column_name, column_values in sorted( - self.columns_match_dict.items(), key=lambda i: i[0] - ): - num_matches = column_values[MatchType.MATCH.value] - num_known_diffs = ( - None - if self._known_differences is None - else column_values[MatchType.KNOWN_DIFFERENCE.value] - ) - num_mismatches = column_values[MatchType.MISMATCH.value] - compare_column = self._base_to_compare_name(column_name) - - if num_mismatches or num_known_diffs or self.show_all_columns: - output_row = [ - column_name, - compare_column, - base_types.get(column_name), - compare_types.get(column_name), - str(num_matches), - str(num_mismatches), - ] - if self.match_rates: - match_rate = 100 * ( - 1 - - (column_values[MatchType.MISMATCH.value] + 0.0) - / self.common_row_count - + 0.0 - ) - output_row.append("{:02.5f}".format(match_rate)) - if num_known_diffs is not None: - output_row.insert(len(output_row) - 1, str(num_known_diffs)) - print(format_pattern.format(*output_row), file=myfile) +def calculate_max_diff(col_1, col_2): + """Get a maximum difference between two columns - # noinspection PyUnresolvedReferences - def report(self, file: TextIO = sys.stdout) -> None: - """Creates a comparison report and prints it to the file specified - (stdout by default). + Parameters + ---------- + col_1 : pyspark.pandas.series.Series + The first column + col_2 : pyspark.pandas.series.Series + The second column - Parameters - ---------- - file : ``file``, optional - A filehandle to write the report to. By default, this is - sys.stdout, printing the report to stdout. 
You can also redirect - this to an output file, as in the example. - - Examples - -------- - >>> with open('my_report.txt', 'w') as report_file: - ... comparison.report(file=report_file) - """ + Returns + ------- + Numeric + Numeric field, or zero. + """ + try: + return (col_1.astype(float) - col_2.astype(float)).abs().max() + except: + return 0 + + +def generate_id_within_group(dataframe, join_columns): + """Generate an ID column that can be used to deduplicate identical rows. The series generated + is the order within a unique group, and it handles nulls. + + Parameters + ---------- + dataframe : pyspark.pandas.frame.DataFrame + The dataframe to operate on + join_columns : list + List of strings which are the join columns - self._print_columns_summary(file) - self._print_schema_diff_details(file) - self._print_only_columns("BASE", file) - self._print_only_columns("COMPARE", file) - self._print_row_summary(file) - self._merge_dataframes() - self._print_num_of_rows_with_column_equality(file) - self._print_row_matches_by_column(file) + Returns + ------- + pyspark.pandas.series.Series + The ID column that's unique in each group. + """ + default_value = "DATACOMPY_NULL" + if dataframe[join_columns].isnull().any().any(): + if (dataframe[join_columns] == default_value).any().any(): + raise ValueError(f"{default_value} was found in your join columns") + return ( + dataframe[join_columns] + .astype(str) + .fillna(default_value) + .groupby(join_columns) + .cumcount() + ) + else: + return dataframe[join_columns].groupby(join_columns).cumcount() diff --git a/docs/source/index.rst b/docs/source/index.rst index 0ac03d6d..1d25d11c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,7 +10,7 @@ Contents Installation Pandas Usage - Spark Usage + Spark (Pandas on Spark) Usage Polars Usage Fugue Usage Developer Instructions diff --git a/docs/source/spark_usage.rst b/docs/source/spark_usage.rst index 82c62722..a532316e 100644 --- a/docs/source/spark_usage.rst +++ b/docs/source/spark_usage.rst @@ -1,243 +1,252 @@ -Spark Usage -=========== +Spark (Pandas on Spark) Usage +============================= .. important:: - With version ``v0.9.0`` SparkCompare now uses Null Safe (``<=>``) comparisons + With version ``v0.12.0`` the original ``SparkCompare`` is now replaced with a + Pandas on Spark implementation. The original ``SparkCompare`` implementation + differs from all the other native implementations. To align the API better, + and keep behaviour consistent we are deprecating the original ``SparkCompare`` + into a new module ``LegacySparkCompare`` + If you wish to use the old SparkCompare moving forward you can -DataComPy's ``SparkCompare`` class will join two dataframes either on a list of join -columns. It has the capability to map column names that may be different in each -dataframe, including in the join columns. You are responsible for creating the -dataframes from any source which Spark can handle and specifying a unique join -key. If there are duplicates in either dataframe by join key, the match process -will remove the duplicates before joining (and tell you how many duplicates were -found). + .. code-block:: python -As with the Pandas-based ``Compare`` class, comparisons will be attempted even -if dtypes don't match. Any schema differences will be reported in the output -as well as in any mismatch reports, so that you can assess whether or not a -type mismatch is a problem or not. 
+ import datacompy.legacy.LegacySparkCompare + -The main reasons why you would choose to use ``SparkCompare`` over ``Compare`` -are that your data is too large to fit into memory, or you're comparing data -that works well in a Spark environment, like partitioned Parquet, CSV, or JSON -files, or Cerebro tables. +DataComPy's Pandas on Spark implementation ``SparkCompare`` (new in ``v0.12.0``) +is a very similar port of the Pandas version -Basic Usage ------------ +- ``on_index`` is NOT supported like in ``PandasCompare`` +- Joining is done using ``<=>`` which is the equality test that is safe for null values. +- In the backend we are using the Pandas on Spark API. This might be less optimal than + native Spark code but allows for better maintainability and readability. -.. code-block:: python - import datetime - import datacompy - from pyspark.sql import Row - - # This example assumes you have a SparkSession named "spark" in your environment, as you - # do when running `pyspark` from the terminal or in a Databricks notebook (Spark v2.0 and higher) - - data1 = [ - Row(acct_id=10000001234, dollar_amt=123.45, name='George Maharis', float_fld=14530.1555, - date_fld=datetime.date(2017, 1, 1)), - Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=1.0, - date_fld=datetime.date(2017, 1, 1)), - Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=None, - date_fld=datetime.date(2017, 1, 1)), - Row(acct_id=10000001237, dollar_amt=123456.0, name='Bob Loblaw', float_fld=345.12, - date_fld=datetime.date(2017, 1, 1)), - Row(acct_id=10000001239, dollar_amt=1.05, name='Lucille Bluth', float_fld=None, - date_fld=datetime.date(2017, 1, 1)) - ] - - data2 = [ - Row(acct_id=10000001234, dollar_amt=123.4, name='George Michael Bluth', float_fld=14530.155), - Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=None), - Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=1.0), - Row(acct_id=10000001237, dollar_amt=123456.0, name='Robert Loblaw', float_fld=345.12), - Row(acct_id=10000001238, dollar_amt=1.05, name='Loose Seal Bluth', float_fld=111.0) - ] - - base_df = spark.createDataFrame(data1) - compare_df = spark.createDataFrame(data2) - - comparison = datacompy.SparkCompare(spark, base_df, compare_df, join_columns=['acct_id']) - - # This prints out a human-readable report summarizing differences - comparison.report() - - -Using SparkCompare on EMR or standalone Spark ---------------------------------------------- - -1. Set proxy variables -2. Create a virtual environment, if desired (``virtualenv venv; source venv/bin/activate``) -3. Pip install datacompy and requirements -4. Ensure your SPARK_HOME environment variable is set (this is probably ``/usr/lib/spark`` but may - differ based on your installation) -5. Augment your PYTHONPATH environment variable with - ``export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.10.4-src.zip:$SPARK_HOME/python:$PYTHONPATH`` - (note that your version of py4j may differ depending on the version of Spark you're using) - - -Using SparkCompare on Databricks --------------------------------- - -1. Clone this repository locally -2. Create a datacompy egg by running ``python setup.py bdist_egg`` from the repo root directory. -3. From the Databricks front page, click the "Library" link under the "New" section. -4. On the New library page: - a. Change source to "Upload Python Egg or PyPi" - b. Under "Upload Egg", Library Name should be "datacompy" - c. 
Drag the egg file in datacompy/dist/ to the "Drop library egg here to upload" box - d. Click the "Create Library" button -5. Once the library has been created, from the library page (which you can find in your /Users/{login} workspace), - you can choose clusters to attach the library to. -6. ``import datacompy`` in a notebook attached to the cluster that the library is attached to and enjoy! - - -Performance Implications ------------------------- - -Spark scales incredibly well, so you can use ``SparkCompare`` to compare -billions of rows of data, provided you spin up a big enough cluster. Still, -joining billions of rows of data is an inherently large task, so there are a -couple of things you may want to take into consideration when getting into the -cliched realm of "big data": - -* ``SparkCompare`` will compare all columns in common in the dataframes and - report on the rest. If there are columns in the data that you don't care to - compare, use a ``select`` statement/method on the dataframe(s) to filter - those out. Particularly when reading from wide Parquet files, this can make - a huge difference when the columns you don't care about don't have to be - read into memory and included in the joined dataframe. -* For large datasets, adding ``cache_intermediates=True`` to the ``SparkCompare`` - call can help optimize performance by caching certain intermediate dataframes - in memory, like the de-duped version of each input dataset, or the joined - dataframe. Otherwise, Spark's lazy evaluation will recompute those each time - it needs the data in a report or as you access instance attributes. This may - be fine for smaller dataframes, but will be costly for larger ones. You do - need to ensure that you have enough free cache memory before you do this, so - this parameter is set to False by default. - - -Known Differences ------------------ - -For cases when two dataframes are expected to differ, it can be helpful to cluster detected -differences into three categories: matches, known differences, and true mismatches. Known -differences can be specified through an optional parameter: +Supported Version +------------------ + +.. important:: + + Spark will not offically support Pandas 2 until Spark 4: https://issues.apache.org/jira/browse/SPARK-44101 + + +Until then we will not be supporting Pandas 2 for the Pandas on Spark API implementaion. +For Fugue and the Native Pandas implementation Pandas 2 is supported. If you need to use Spark DataFrame with +Pandas 2+ then consider using Fugue otherwise downgrade to Pandas 1.5.3 + + +SparkCompare Object Setup +------------------------- + +There is currently only one supported method for joining your dataframes - by +join column(s). .. code-block:: python - SparkCompare(spark, base_df, compare_df, join_columns=[...], column_mapping=[...], - known_differences = [ - { - 'name': "My Known Difference Name", - 'types': ['int', 'bigint'], - 'flags': ['nullcheck'], - 'transformation': "case when {input}=0 then null else {input} end" - }, - ... 
- ] + from io import StringIO + import pandas as pd + import pyspark.pandas as ps + from datacompy import SparkCompare + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + data1 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + """ + + data2 = """acct_id,dollar_amt,name,float_fld + 10000001234,123.4,George Michael Bluth,14530.155 + 10000001235,0.45,Michael Bluth, + 10000001236,1345,George Bluth,1 + 10000001237,123456,Robert Loblaw,345.12 + 10000001238,1.05,Loose Seal Bluth,111 + """ + + df1 = ps.from_pandas(pd.read_csv(StringIO(data1))) + df2 = ps.from_pandas(pd.read_csv(StringIO(data2))) + + compare = SparkCompare( + df1, + df2, + join_columns='acct_id', # You can also specify a list of columns + abs_tol=0, # Optional, defaults to 0 + rel_tol=0, # Optional, defaults to 0 + df1_name='Original', # Optional, defaults to 'df1' + df2_name='New' # Optional, defaults to 'df2' ) + compare.matches(ignore_extra_columns=False) + # False + # This method prints out a human-readable report summarizing and sampling differences + print(compare.report()) + + +Reports +------- + +A report is generated by calling ``SparkCompare.report()``, which returns a string. +Here is a sample report generated by ``datacompy`` for the two tables above, +joined on ``acct_id`` (Note: if you don't specify ``df1_name`` and/or ``df2_name``, +then any instance of "original" or "new" in the report is replaced with "df1" +and/or "df2".):: + + DataComPy Comparison + -------------------- + + DataFrame Summary + ----------------- + + DataFrame Columns Rows + 0 Original 5 5 + 1 New 4 5 + + Column Summary + -------------- + + Number of columns in common: 4 + Number of columns in Original but not in New: 1 + Number of columns in New but not in Original: 0 + + Row Summary + ----------- -The 'known_differences' parameter is a list of Python dicts with the following fields: + Matched on: acct_id + Any duplicates on match values: No + Absolute Tolerance: 0 + Relative Tolerance: 0 + Number of rows in common: 4 + Number of rows in Original but not in New: 1 + Number of rows in New but not in Original: 1 -============== ========= ====================================================================== -Field Required? 
Description -============== ========= ====================================================================== -name yes A user-readable title for this known difference -types yes A list of Spark data types on which this transformation can be applied -flags no Special flags used for computing known differences -transformation yes Spark SQL function to apply, where {input} is a cell in the comparison -============== ========= ====================================================================== + Number of rows with some compared columns unequal: 4 + Number of rows with all compared columns equal: 0 -Valid flags are: + Column Comparison + ----------------- -========= ============================================================= -Flag Description -========= ============================================================= -nullcheck Must be set when the output of the transformation can be null -========= ============================================================= + Number of columns compared with some values unequal: 3 + Number of columns compared with all values equal: 1 + Total number of values which compare unequal: 6 -Transformations are applied to the compare side only. A known difference is found when transformation(compare.cell) equals base.cell. An example comparison is shown below. + Columns with Unequal Values or Types + ------------------------------------ + + Column Original dtype New dtype # Unequal Max Diff # Null Diff + 0 dollar_amt float64 float64 1 0.0500 0 + 2 float_fld float64 float64 3 0.0005 2 + 1 name object object 2 NaN 0 + + Sample Rows with Unequal Values + ------------------------------- + + acct_id dollar_amt (Original) dollar_amt (New) + 0 10000001234 123.45 123.4 + + acct_id name (Original) name (New) + 0 10000001234 George Maharis George Michael Bluth + 3 10000001237 Bob Loblaw Robert Loblaw + + acct_id float_fld (Original) float_fld (New) + 0 10000001234 14530.1555 14530.155 + 1 10000001235 1.0000 NaN + 2 10000001236 NaN 1.000 + + Sample Rows Only in Original (First 10 Columns) + ----------------------------------------------- + + acct_id_df1 dollar_amt_df1 name_df1 float_fld_df1 date_fld_df1 _merge_left + 5 10000001239 1.05 Lucille Bluth NaN 2017-01-01 True + + Sample Rows Only in New (First 10 Columns) + ------------------------------------------ + + acct_id_df2 dollar_amt_df2 name_df2 float_fld_df2 _merge_right + 4 10000001238 1.05 Loose Seal Bluth 111.0 True + + +Convenience Methods +------------------- + +There are a few convenience methods available after the comparison has been run: .. 
code-block:: python - import datetime - import datacompy - from pyspark.sql import Row - - base_data = [ - Row(acct_id=10000001234, acct_sfx_num=0, clsd_reas_cd='*2', open_dt=datetime.date(2017, 5, 1), tbal_cd='0001'), - Row(acct_id=10000001235, acct_sfx_num=0, clsd_reas_cd='V1', open_dt=datetime.date(2017, 5, 2), tbal_cd='0002'), - Row(acct_id=10000001236, acct_sfx_num=0, clsd_reas_cd='V2', open_dt=datetime.date(2017, 5, 3), tbal_cd='0003'), - Row(acct_id=10000001237, acct_sfx_num=0, clsd_reas_cd='*2', open_dt=datetime.date(2017, 5, 4), tbal_cd='0004'), - Row(acct_id=10000001238, acct_sfx_num=0, clsd_reas_cd='*2', open_dt=datetime.date(2017, 5, 5), tbal_cd='0005') - ] - base_df = spark.createDataFrame(base_data) - - compare_data = [ - Row(ACCOUNT_IDENTIFIER=10000001234, SUFFIX_NUMBER=0, AM00_STATC_CLOSED=None, AM00_DATE_ACCOUNT_OPEN=2017121, AM0B_FC_TBAL=1.0), - Row(ACCOUNT_IDENTIFIER=10000001235, SUFFIX_NUMBER=0, AM00_STATC_CLOSED='V1', AM00_DATE_ACCOUNT_OPEN=2017122, AM0B_FC_TBAL=2.0), - Row(ACCOUNT_IDENTIFIER=10000001236, SUFFIX_NUMBER=0, AM00_STATC_CLOSED='V2', AM00_DATE_ACCOUNT_OPEN=2017123, AM0B_FC_TBAL=3.0), - Row(ACCOUNT_IDENTIFIER=10000001237, SUFFIX_NUMBER=0, AM00_STATC_CLOSED='V3', AM00_DATE_ACCOUNT_OPEN=2017124, AM0B_FC_TBAL=4.0), - Row(ACCOUNT_IDENTIFIER=10000001238, SUFFIX_NUMBER=0, AM00_STATC_CLOSED=None, AM00_DATE_ACCOUNT_OPEN=2017125, AM0B_FC_TBAL=5.0) - ] - compare_df = spark.createDataFrame(compare_data) - - comparison = datacompy.SparkCompare(spark, base_df, compare_df, - join_columns = [('acct_id', 'ACCOUNT_IDENTIFIER'), ('acct_sfx_num', 'SUFFIX_NUMBER')], - column_mapping = [('clsd_reas_cd', 'AM00_STATC_CLOSED'), - ('open_dt', 'AM00_DATE_ACCOUNT_OPEN'), - ('tbal_cd', 'AM0B_FC_TBAL')], - known_differences= [ - {'name': 'Left-padded, four-digit numeric code', - 'types': ['tinyint', 'smallint', 'int', 'bigint', 'float', 'double', 'decimal'], - 'transformation': "lpad(cast({input} AS bigint), 4, '0')"}, - {'name': 'Null to *2', - 'types': ['string'], - 'transformation': "case when {input} is null then '*2' else {input} end"}, - {'name': 'Julian date -> date', - 'types': ['bigint'], - 'transformation': "to_date(cast(unix_timestamp(cast({input} AS string), 'yyyyDDD') AS timestamp))"} - ]) - comparison.report() - -Corresponding output:: - - ****** Column Summary ****** - Number of columns in common with matching schemas: 3 - Number of columns in common with schema differences: 2 - Number of columns in base but not compare: 0 - Number of columns in compare but not base: 0 - - ****** Schema Differences ****** - Base Column Name Compare Column Name Base Dtype Compare Dtype - ---------------- ---------------------- ------------- ------------- - open_dt AM00_DATE_ACCOUNT_OPEN date bigint - tbal_cd AM0B_FC_TBAL string double - - ****** Row Summary ****** - Number of rows in common: 5 - Number of rows in base but not compare: 0 - Number of rows in compare but not base: 0 - Number of duplicate rows found in base: 0 - Number of duplicate rows found in compare: 0 - - ****** Row Comparison ****** - Number of rows with some columns unequal: 5 - Number of rows with all columns equal: 0 - - ****** Column Comparison ****** - Number of columns compared with unexpected differences in some values: 1 - Number of columns compared with all values equal but known differences found: 2 - Number of columns compared with all values completely equal: 0 - - ****** Columns with Unequal Values ****** - Base Column Name Compare Column Name Base Dtype Compare Dtype # Matches # Known Diffs # Mismatches - 
---------------- ------------------- ------------- ------------- --------- ------------- ------------ - clsd_reas_cd AM00_STATC_CLOSED string string 2 2 1 - open_dt AM00_DATE_ACCOUNT_OPEN date bigint 0 5 0 - tbal_cd AM0B_FC_TBAL string double 0 5 0 \ No newline at end of file + print(compare.intersect_rows[['name_df1', 'name_df2', 'name_match']]) + # name_df1 name_df2 name_match + # 0 George Maharis George Michael Bluth False + # 1 Michael Bluth Michael Bluth True + # 2 George Bluth George Bluth True + # 3 Bob Loblaw Robert Loblaw False + + print(compare.df1_unq_rows) + # acct_id_df1 dollar_amt_df1 name_df1 float_fld_df1 date_fld_df1 _merge_left + # 5 10000001239 1.05 Lucille Bluth NaN 2017-01-01 True + + print(compare.df2_unq_rows) + # acct_id_df2 dollar_amt_df2 name_df2 float_fld_df2 _merge_right + # 4 10000001238 1.05 Loose Seal Bluth 111.0 True + + print(compare.intersect_columns()) + # OrderedSet(['acct_id', 'dollar_amt', 'name', 'float_fld']) + + print(compare.df1_unq_columns()) + # OrderedSet(['date_fld']) + + print(compare.df2_unq_columns()) + # OrderedSet() + +Duplicate rows +-------------- + +Datacompy will try to handle rows that are duplicate in the join columns. It does this behind the +scenes by generating a unique ID within each unique group of the join columns. For example, if you +have two dataframes you're trying to join on acct_id: + +=========== ================ +acct_id name +=========== ================ +1 George Maharis +1 Michael Bluth +2 George Bluth +=========== ================ + +=========== ================ +acct_id name +=========== ================ +1 George Maharis +1 Michael Bluth +1 Tony Wonder +2 George Bluth +=========== ================ + +Datacompy will generate a unique temporary ID for joining: + +=========== ================ ======== +acct_id name temp_id +=========== ================ ======== +1 George Maharis 0 +1 Michael Bluth 1 +2 George Bluth 0 +=========== ================ ======== + +=========== ================ ======== +acct_id name temp_id +=========== ================ ======== +1 George Maharis 0 +1 Michael Bluth 1 +1 Tony Wonder 2 +2 George Bluth 0 +=========== ================ ======== + +And then merge the two dataframes on a combination of the join_columns you specified and the temporary +ID, before dropping the temp_id again. So the first two rows in the first dataframe will match the +first two rows in the second dataframe, and the third row in the second dataframe will be recognized +as uniquely in the second. 
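As a minimal illustrative sketch of the mechanism described above (assuming a small in-memory pandas-on-Spark frame built inline here; the actual helper added in this patch is ``generate_id_within_group``), the temporary ID can be reproduced with a plain ``groupby``/``cumcount`` over the join columns:

.. code-block:: python

    import pyspark.pandas as ps

    # Hypothetical example frame; acct_id is the join column
    df = ps.DataFrame(
        {
            "acct_id": [1, 1, 1, 2],
            "name": ["George Maharis", "Michael Bluth", "Tony Wonder", "George Bluth"],
        }
    )

    # Each row gets its rank within its acct_id group, in order of appearance:
    # 0, 1, 2 for the acct_id=1 rows and 0 for the single acct_id=2 row.
    df["temp_id"] = df.groupby(["acct_id"]).cumcount()

The helper in this patch additionally fills nulls in the join columns with a sentinel value before grouping, so rows with null join keys are paired up in the same way.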
diff --git a/pyproject.toml b/pyproject.toml index 8a6f73ea..8ece29a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ maintainers = [ ] license = {text = "Apache Software License"} dependencies = ["pandas<=2.2.1,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.8.7,>=0.8.7"] -requires-python = ">=3.8.0" +requires-python = ">=3.9.0" classifiers = [ "Intended Audience :: Developers", "Natural Language :: English", @@ -20,7 +20,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/test_legacy_spark.py b/tests/test_legacy_spark.py new file mode 100644 index 00000000..30ec1500 --- /dev/null +++ b/tests/test_legacy_spark.py @@ -0,0 +1,2109 @@ +# +# Copyright 2024 Capital One Services, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import io +import logging +import re +from decimal import Decimal + +import pytest + +pytest.importorskip("pyspark") + +from pyspark.sql import Row # noqa: E402 +from pyspark.sql.types import ( # noqa: E402 + DateType, + DecimalType, + DoubleType, + LongType, + StringType, + StructField, + StructType, +) + +from datacompy.legacy import ( # noqa: E402 + NUMERIC_SPARK_TYPES, + LegacySparkCompare, + _is_comparable, +) + +# Turn off py4j debug messages for all tests in this module +logging.getLogger("py4j").setLevel(logging.INFO) + +CACHE_INTERMEDIATES = True + + +@pytest.fixture(scope="module", name="base_df1") +def base_df1_fixture(spark_session): + mock_data = [ + Row( + acct=10000001234, + dollar_amt=123, + name="George Maharis", + float_fld=14530.1555, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001235, + dollar_amt=0, + name="Michael Bluth", + float_fld=1.0, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001236, + dollar_amt=1345, + name="George Bluth", + float_fld=None, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001237, + dollar_amt=123456, + name="Bob Loblaw", + float_fld=345.12, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001239, + dollar_amt=1, + name="Lucille Bluth", + float_fld=None, + date_fld=datetime.date(2017, 1, 1), + ), + ] + + return spark_session.createDataFrame(mock_data) + + +@pytest.fixture(scope="module", name="base_df2") +def base_df2_fixture(spark_session): + mock_data = [ + Row( + acct=10000001234, + dollar_amt=123, + super_duper_big_long_name="George Maharis", + float_fld=14530.1555, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001235, + dollar_amt=0, + super_duper_big_long_name="Michael Bluth", + float_fld=1.0, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001236, + dollar_amt=1345, + super_duper_big_long_name="George Bluth", + float_fld=None, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + 
acct=10000001237, + dollar_amt=123456, + super_duper_big_long_name="Bob Loblaw", + float_fld=345.12, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001239, + dollar_amt=1, + super_duper_big_long_name="Lucille Bluth", + float_fld=None, + date_fld=datetime.date(2017, 1, 1), + ), + ] + + return spark_session.createDataFrame(mock_data) + + +@pytest.fixture(scope="module", name="compare_df1") +def compare_df1_fixture(spark_session): + mock_data2 = [ + Row( + acct=10000001234, + dollar_amt=123.4, + name="George Michael Bluth", + float_fld=14530.155, + accnt_purge=False, + ), + Row( + acct=10000001235, + dollar_amt=0.45, + name="Michael Bluth", + float_fld=None, + accnt_purge=False, + ), + Row( + acct=10000001236, + dollar_amt=1345.0, + name="George Bluth", + float_fld=1.0, + accnt_purge=False, + ), + Row( + acct=10000001237, + dollar_amt=123456.0, + name="Bob Loblaw", + float_fld=345.12, + accnt_purge=False, + ), + Row( + acct=10000001238, + dollar_amt=1.05, + name="Loose Seal Bluth", + float_fld=111.0, + accnt_purge=True, + ), + Row( + acct=10000001238, + dollar_amt=1.05, + name="Loose Seal Bluth", + float_fld=111.0, + accnt_purge=True, + ), + ] + + return spark_session.createDataFrame(mock_data2) + + +@pytest.fixture(scope="module", name="compare_df2") +def compare_df2_fixture(spark_session): + mock_data = [ + Row( + acct=10000001234, + dollar_amt=123, + name="George Maharis", + float_fld=14530.1555, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001235, + dollar_amt=0, + name="Michael Bluth", + float_fld=1.0, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001236, + dollar_amt=1345, + name="George Bluth", + float_fld=None, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001237, + dollar_amt=123456, + name="Bob Loblaw", + float_fld=345.12, + date_fld=datetime.date(2017, 1, 1), + ), + Row( + acct=10000001239, + dollar_amt=1, + name="Lucille Bluth", + float_fld=None, + date_fld=datetime.date(2017, 1, 1), + ), + ] + + return spark_session.createDataFrame(mock_data) + + +@pytest.fixture(scope="module", name="compare_df3") +def compare_df3_fixture(spark_session): + mock_data2 = [ + Row( + account_identifier=10000001234, + dollar_amount=123.4, + name="George Michael Bluth", + float_field=14530.155, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001235, + dollar_amount=0.45, + name="Michael Bluth", + float_field=1.0, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001236, + dollar_amount=1345.0, + name="George Bluth", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001237, + dollar_amount=123456.0, + name="Bob Loblaw", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001239, + dollar_amount=1.05, + name="Lucille Bluth", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), + ] + + return spark_session.createDataFrame(mock_data2) + + +@pytest.fixture(scope="module", name="base_tol") +def base_tol_fixture(spark_session): + tol_data1 = [ + Row( + account_identifier=10000001234, + dollar_amount=123.4, + name="Franklin Delano Bluth", + float_field=14530.155, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001235, + dollar_amount=500.0, + name="Surely Funke", + float_field=None, + date_field=datetime.date(2017, 1, 1), + 
accnt_purge=True, + ), + Row( + account_identifier=10000001236, + dollar_amount=-1100.0, + name="Nichael Bluth", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), + Row( + account_identifier=10000001237, + dollar_amount=0.45, + name="Mr. F", + float_field=1.0, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001238, + dollar_amount=1345.0, + name="Steve Holt!", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001239, + dollar_amount=123456.0, + name="Blue Man Group", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001240, + dollar_amount=1.1, + name="Her?", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), + Row( + account_identifier=10000001241, + dollar_amount=0.0, + name="Mrs. Featherbottom", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), + Row( + account_identifier=10000001242, + dollar_amount=0.0, + name="Ice", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), + Row( + account_identifier=10000001243, + dollar_amount=-10.0, + name="Frank Wrench", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), + Row( + account_identifier=10000001244, + dollar_amount=None, + name="Lucille 2", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), + Row( + account_identifier=10000001245, + dollar_amount=0.009999, + name="Gene Parmesan", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), + Row( + account_identifier=10000001246, + dollar_amount=None, + name="Motherboy", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), + ] + + return spark_session.createDataFrame(tol_data1) + + +@pytest.fixture(scope="module", name="compare_abs_tol") +def compare_tol2_fixture(spark_session): + tol_data2 = [ + Row( + account_identifier=10000001234, + dollar_amount=123.4, + name="Franklin Delano Bluth", + float_field=14530.155, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # full match + Row( + account_identifier=10000001235, + dollar_amount=500.01, + name="Surely Funke", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by 0.01 + Row( + account_identifier=10000001236, + dollar_amount=-1100.01, + name="Nichael Bluth", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by -0.01 + Row( + account_identifier=10000001237, + dollar_amount=0.46000000001, + name="Mr. F", + float_field=1.0, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by 0.01000000001 + Row( + account_identifier=10000001238, + dollar_amount=1344.8999999999, + name="Steve Holt!", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by -0.01000000001 + Row( + account_identifier=10000001239, + dollar_amount=123456.0099999999, + name="Blue Man Group", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by 0.00999999999 + Row( + account_identifier=10000001240, + dollar_amount=1.090000001, + name="Her?", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by -0.00999999999 + Row( + account_identifier=10000001241, + dollar_amount=0.0, + name="Mrs. 
Featherbottom", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # both zero + Row( + account_identifier=10000001242, + dollar_amount=1.0, + name="Ice", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # base 0, compare 1 + Row( + account_identifier=10000001243, + dollar_amount=0.0, + name="Frank Wrench", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base -10, compare 0 + Row( + account_identifier=10000001244, + dollar_amount=-1.0, + name="Lucille 2", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base NULL, compare -1 + Row( + account_identifier=10000001245, + dollar_amount=None, + name="Gene Parmesan", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base 0.009999, compare NULL + Row( + account_identifier=10000001246, + dollar_amount=None, + name="Motherboy", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # both NULL + ] + + return spark_session.createDataFrame(tol_data2) + + +@pytest.fixture(scope="module", name="compare_rel_tol") +def compare_tol3_fixture(spark_session): + tol_data3 = [ + Row( + account_identifier=10000001234, + dollar_amount=123.4, + name="Franklin Delano Bluth", + float_field=14530.155, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # full match #MATCH + Row( + account_identifier=10000001235, + dollar_amount=550.0, + name="Surely Funke", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by 10% #MATCH + Row( + account_identifier=10000001236, + dollar_amount=-1000.0, + name="Nichael Bluth", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by -10% #MATCH + Row( + account_identifier=10000001237, + dollar_amount=0.49501, + name="Mr. F", + float_field=1.0, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by greater than 10% + Row( + account_identifier=10000001238, + dollar_amount=1210.001, + name="Steve Holt!", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by greater than -10% + Row( + account_identifier=10000001239, + dollar_amount=135801.59999, + name="Blue Man Group", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by just under 10% #MATCH + Row( + account_identifier=10000001240, + dollar_amount=1.000001, + name="Her?", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by just under -10% #MATCH + Row( + account_identifier=10000001241, + dollar_amount=0.0, + name="Mrs. 
Featherbottom", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # both zero #MATCH + Row( + account_identifier=10000001242, + dollar_amount=1.0, + name="Ice", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # base 0, compare 1 + Row( + account_identifier=10000001243, + dollar_amount=0.0, + name="Frank Wrench", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base -10, compare 0 + Row( + account_identifier=10000001244, + dollar_amount=-1.0, + name="Lucille 2", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base NULL, compare -1 + Row( + account_identifier=10000001245, + dollar_amount=None, + name="Gene Parmesan", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base 0.009999, compare NULL + Row( + account_identifier=10000001246, + dollar_amount=None, + name="Motherboy", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # both NULL #MATCH + ] + + return spark_session.createDataFrame(tol_data3) + + +@pytest.fixture(scope="module", name="compare_both_tol") +def compare_tol4_fixture(spark_session): + tol_data4 = [ + Row( + account_identifier=10000001234, + dollar_amount=123.4, + name="Franklin Delano Bluth", + float_field=14530.155, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # full match + Row( + account_identifier=10000001235, + dollar_amount=550.01, + name="Surely Funke", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by 10% and +0.01 + Row( + account_identifier=10000001236, + dollar_amount=-1000.01, + name="Nichael Bluth", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by -10% and -0.01 + Row( + account_identifier=10000001237, + dollar_amount=0.505000000001, + name="Mr. F", + float_field=1.0, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by greater than 10% and +0.01 + Row( + account_identifier=10000001238, + dollar_amount=1209.98999, + name="Steve Holt!", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by greater than -10% and -0.01 + Row( + account_identifier=10000001239, + dollar_amount=135801.609999, + name="Blue Man Group", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # off by just under 10% and just under +0.01 + Row( + account_identifier=10000001240, + dollar_amount=0.99000001, + name="Her?", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # off by just under -10% and just under -0.01 + Row( + account_identifier=10000001241, + dollar_amount=0.0, + name="Mrs. 
Featherbottom", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # both zero + Row( + account_identifier=10000001242, + dollar_amount=1.0, + name="Ice", + float_field=345.12, + date_field=datetime.date(2017, 1, 1), + accnt_purge=False, + ), # base 0, compare 1 + Row( + account_identifier=10000001243, + dollar_amount=0.0, + name="Frank Wrench", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base -10, compare 0 + Row( + account_identifier=10000001244, + dollar_amount=-1.0, + name="Lucille 2", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base NULL, compare -1 + Row( + account_identifier=10000001245, + dollar_amount=None, + name="Gene Parmesan", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # base 0.009999, compare NULL + Row( + account_identifier=10000001246, + dollar_amount=None, + name="Motherboy", + float_field=None, + date_field=datetime.date(2017, 1, 1), + accnt_purge=True, + ), # both NULL + ] + + return spark_session.createDataFrame(tol_data4) + + +@pytest.fixture(scope="module", name="base_td") +def base_td_fixture(spark_session): + mock_data = [ + Row( + acct=10000001234, + acct_seq=0, + stat_cd="*2", + open_dt=datetime.date(2017, 5, 1), + cd="0001", + ), + Row( + acct=10000001235, + acct_seq=0, + stat_cd="V1", + open_dt=datetime.date(2017, 5, 2), + cd="0002", + ), + Row( + acct=10000001236, + acct_seq=0, + stat_cd="V2", + open_dt=datetime.date(2017, 5, 3), + cd="0003", + ), + Row( + acct=10000001237, + acct_seq=0, + stat_cd="*2", + open_dt=datetime.date(2017, 5, 4), + cd="0004", + ), + Row( + acct=10000001238, + acct_seq=0, + stat_cd="*2", + open_dt=datetime.date(2017, 5, 5), + cd="0005", + ), + ] + + return spark_session.createDataFrame(mock_data) + + +@pytest.fixture(scope="module", name="compare_source") +def compare_source_fixture(spark_session): + mock_data = [ + Row( + ACCOUNT_IDENTIFIER=10000001234, + SEQ_NUMBER=0, + STATC=None, + ACCOUNT_OPEN=2017121, + CODE=1.0, + ), + Row( + ACCOUNT_IDENTIFIER=10000001235, + SEQ_NUMBER=0, + STATC="V1", + ACCOUNT_OPEN=2017122, + CODE=2.0, + ), + Row( + ACCOUNT_IDENTIFIER=10000001236, + SEQ_NUMBER=0, + STATC="V2", + ACCOUNT_OPEN=2017123, + CODE=3.0, + ), + Row( + ACCOUNT_IDENTIFIER=10000001237, + SEQ_NUMBER=0, + STATC="V3", + ACCOUNT_OPEN=2017124, + CODE=4.0, + ), + Row( + ACCOUNT_IDENTIFIER=10000001238, + SEQ_NUMBER=0, + STATC=None, + ACCOUNT_OPEN=2017125, + CODE=5.0, + ), + ] + + return spark_session.createDataFrame(mock_data) + + +@pytest.fixture(scope="module", name="base_decimal") +def base_decimal_fixture(spark_session): + mock_data = [ + Row(acct=10000001234, dollar_amt=Decimal(123.4)), + Row(acct=10000001235, dollar_amt=Decimal(0.45)), + ] + + return spark_session.createDataFrame( + mock_data, + schema=StructType( + [ + StructField("acct", LongType(), True), + StructField("dollar_amt", DecimalType(8, 2), True), + ] + ), + ) + + +@pytest.fixture(scope="module", name="compare_decimal") +def compare_decimal_fixture(spark_session): + mock_data = [ + Row(acct=10000001234, dollar_amt=123.4), + Row(acct=10000001235, dollar_amt=0.456), + ] + + return spark_session.createDataFrame(mock_data) + + +@pytest.fixture(scope="module", name="comparison_abs_tol") +def comparison_abs_tol_fixture(base_tol, compare_abs_tol, spark_session): + return LegacySparkCompare( + spark_session, + base_tol, + compare_abs_tol, + join_columns=["account_identifier"], + abs_tol=0.01, + ) + + 
+@pytest.fixture(scope="module", name="comparison_rel_tol") +def comparison_rel_tol_fixture(base_tol, compare_rel_tol, spark_session): + return LegacySparkCompare( + spark_session, + base_tol, + compare_rel_tol, + join_columns=["account_identifier"], + rel_tol=0.1, + ) + + +@pytest.fixture(scope="module", name="comparison_both_tol") +def comparison_both_tol_fixture(base_tol, compare_both_tol, spark_session): + return LegacySparkCompare( + spark_session, + base_tol, + compare_both_tol, + join_columns=["account_identifier"], + rel_tol=0.1, + abs_tol=0.01, + ) + + +@pytest.fixture(scope="module", name="comparison_neg_tol") +def comparison_neg_tol_fixture(base_tol, compare_both_tol, spark_session): + return LegacySparkCompare( + spark_session, + base_tol, + compare_both_tol, + join_columns=["account_identifier"], + rel_tol=-0.2, + abs_tol=0.01, + ) + + +@pytest.fixture(scope="module", name="show_all_columns_and_match_rate") +def show_all_columns_and_match_rate_fixture(base_tol, compare_both_tol, spark_session): + return LegacySparkCompare( + spark_session, + base_tol, + compare_both_tol, + join_columns=["account_identifier"], + show_all_columns=True, + match_rates=True, + ) + + +@pytest.fixture(scope="module", name="comparison_kd1") +def comparison_known_diffs1(base_td, compare_source, spark_session): + return LegacySparkCompare( + spark_session, + base_td, + compare_source, + join_columns=[("acct", "ACCOUNT_IDENTIFIER"), ("acct_seq", "SEQ_NUMBER")], + column_mapping=[ + ("stat_cd", "STATC"), + ("open_dt", "ACCOUNT_OPEN"), + ("cd", "CODE"), + ], + known_differences=[ + { + "name": "Left-padded, four-digit numeric code", + "types": NUMERIC_SPARK_TYPES, + "transformation": "lpad(cast({input} AS bigint), 4, '0')", + }, + { + "name": "Null to *2", + "types": ["string"], + "transformation": "case when {input} is null then '*2' else {input} end", + }, + { + "name": "Julian date -> date", + "types": ["bigint"], + "transformation": "to_date(cast(unix_timestamp(cast({input} AS string), 'yyyyDDD') AS timestamp))", + }, + ], + ) + + +@pytest.fixture(scope="module", name="comparison_kd2") +def comparison_known_diffs2(base_td, compare_source, spark_session): + return LegacySparkCompare( + spark_session, + base_td, + compare_source, + join_columns=[("acct", "ACCOUNT_IDENTIFIER"), ("acct_seq", "SEQ_NUMBER")], + column_mapping=[ + ("stat_cd", "STATC"), + ("open_dt", "ACCOUNT_OPEN"), + ("cd", "CODE"), + ], + known_differences=[ + { + "name": "Left-padded, four-digit numeric code", + "types": NUMERIC_SPARK_TYPES, + "transformation": "lpad(cast({input} AS bigint), 4, '0')", + }, + { + "name": "Null to *2", + "types": ["string"], + "transformation": "case when {input} is null then '*2' else {input} end", + }, + ], + ) + + +@pytest.fixture(scope="module", name="comparison1") +def comparison1_fixture(base_df1, compare_df1, spark_session): + return LegacySparkCompare( + spark_session, + base_df1, + compare_df1, + join_columns=["acct"], + cache_intermediates=CACHE_INTERMEDIATES, + ) + + +@pytest.fixture(scope="module", name="comparison2") +def comparison2_fixture(base_df1, compare_df2, spark_session): + return LegacySparkCompare( + spark_session, base_df1, compare_df2, join_columns=["acct"] + ) + + +@pytest.fixture(scope="module", name="comparison3") +def comparison3_fixture(base_df1, compare_df3, spark_session): + return LegacySparkCompare( + spark_session, + base_df1, + compare_df3, + join_columns=[("acct", "account_identifier")], + column_mapping=[ + ("dollar_amt", "dollar_amount"), + ("float_fld", "float_field"), 
+ ("date_fld", "date_field"), + ], + cache_intermediates=CACHE_INTERMEDIATES, + ) + + +@pytest.fixture(scope="module", name="comparison4") +def comparison4_fixture(base_df2, compare_df1, spark_session): + return LegacySparkCompare( + spark_session, + base_df2, + compare_df1, + join_columns=["acct"], + column_mapping=[("super_duper_big_long_name", "name")], + ) + + +@pytest.fixture(scope="module", name="comparison_decimal") +def comparison_decimal_fixture(base_decimal, compare_decimal, spark_session): + return LegacySparkCompare( + spark_session, base_decimal, compare_decimal, join_columns=["acct"] + ) + + +def test_absolute_tolerances(comparison_abs_tol): + stdout = io.StringIO() + + comparison_abs_tol.report(file=stdout) + stdout.seek(0) + assert "****** Row Comparison ******" in stdout.getvalue() + assert "Number of rows with some columns unequal: 6" in stdout.getvalue() + assert "Number of rows with all columns equal: 7" in stdout.getvalue() + assert "Number of columns compared with some values unequal: 1" in stdout.getvalue() + assert "Number of columns compared with all values equal: 4" in stdout.getvalue() + + +def test_relative_tolerances(comparison_rel_tol): + stdout = io.StringIO() + + comparison_rel_tol.report(file=stdout) + stdout.seek(0) + assert "****** Row Comparison ******" in stdout.getvalue() + assert "Number of rows with some columns unequal: 6" in stdout.getvalue() + assert "Number of rows with all columns equal: 7" in stdout.getvalue() + assert "Number of columns compared with some values unequal: 1" in stdout.getvalue() + assert "Number of columns compared with all values equal: 4" in stdout.getvalue() + + +def test_both_tolerances(comparison_both_tol): + stdout = io.StringIO() + + comparison_both_tol.report(file=stdout) + stdout.seek(0) + assert "****** Row Comparison ******" in stdout.getvalue() + assert "Number of rows with some columns unequal: 6" in stdout.getvalue() + assert "Number of rows with all columns equal: 7" in stdout.getvalue() + assert "Number of columns compared with some values unequal: 1" in stdout.getvalue() + assert "Number of columns compared with all values equal: 4" in stdout.getvalue() + + +def test_negative_tolerances(spark_session, base_tol, compare_both_tol): + with pytest.raises(ValueError, match="Please enter positive valued tolerances"): + comp = LegacySparkCompare( + spark_session, + base_tol, + compare_both_tol, + join_columns=["account_identifier"], + rel_tol=-0.2, + abs_tol=0.01, + ) + comp.report() + pass + + +def test_show_all_columns_and_match_rate(show_all_columns_and_match_rate): + stdout = io.StringIO() + + show_all_columns_and_match_rate.report(file=stdout) + + assert "****** Columns with Equal/Unequal Values ******" in stdout.getvalue() + assert ( + "accnt_purge accnt_purge boolean boolean 13 0 100.00000" + in stdout.getvalue() + ) + assert ( + "date_field date_field date date 13 0 100.00000" + in stdout.getvalue() + ) + assert ( + "dollar_amount dollar_amount double double 3 10 23.07692" + in stdout.getvalue() + ) + assert ( + "float_field float_field double double 13 0 100.00000" + in stdout.getvalue() + ) + assert ( + "name name string string 13 0 100.00000" + in stdout.getvalue() + ) + + +def test_decimal_comparisons(): + true_decimals = ["decimal", "decimal()", "decimal(20, 10)"] + assert all(v in NUMERIC_SPARK_TYPES for v in true_decimals) + + +def test_decimal_comparator_acts_like_string(): + acc = False + for t in NUMERIC_SPARK_TYPES: + acc = acc or (len(t) > 2 and t[0:3] == "dec") + assert acc + + +def 
test_decimals_and_doubles_are_comparable(): + assert _is_comparable("double", "decimal(10, 2)") + + +def test_report_outputs_the_column_summary(comparison1): + stdout = io.StringIO() + + comparison1.report(file=stdout) + + assert "****** Column Summary ******" in stdout.getvalue() + assert "Number of columns in common with matching schemas: 3" in stdout.getvalue() + assert "Number of columns in common with schema differences: 1" in stdout.getvalue() + assert "Number of columns in base but not compare: 1" in stdout.getvalue() + assert "Number of columns in compare but not base: 1" in stdout.getvalue() + + +def test_report_outputs_the_column_summary_for_identical_schemas(comparison2): + stdout = io.StringIO() + + comparison2.report(file=stdout) + + assert "****** Column Summary ******" in stdout.getvalue() + assert "Number of columns in common with matching schemas: 5" in stdout.getvalue() + assert "Number of columns in common with schema differences: 0" in stdout.getvalue() + assert "Number of columns in base but not compare: 0" in stdout.getvalue() + assert "Number of columns in compare but not base: 0" in stdout.getvalue() + + +def test_report_outputs_the_column_summary_for_differently_named_columns(comparison3): + stdout = io.StringIO() + + comparison3.report(file=stdout) + + assert "****** Column Summary ******" in stdout.getvalue() + assert "Number of columns in common with matching schemas: 4" in stdout.getvalue() + assert "Number of columns in common with schema differences: 1" in stdout.getvalue() + assert "Number of columns in base but not compare: 0" in stdout.getvalue() + assert "Number of columns in compare but not base: 1" in stdout.getvalue() + + +def test_report_outputs_the_row_summary(comparison1): + stdout = io.StringIO() + + comparison1.report(file=stdout) + + assert "****** Row Summary ******" in stdout.getvalue() + assert "Number of rows in common: 4" in stdout.getvalue() + assert "Number of rows in base but not compare: 1" in stdout.getvalue() + assert "Number of rows in compare but not base: 1" in stdout.getvalue() + assert "Number of duplicate rows found in base: 0" in stdout.getvalue() + assert "Number of duplicate rows found in compare: 1" in stdout.getvalue() + + +def test_report_outputs_the_row_equality_comparison(comparison1): + stdout = io.StringIO() + + comparison1.report(file=stdout) + + assert "****** Row Comparison ******" in stdout.getvalue() + assert "Number of rows with some columns unequal: 3" in stdout.getvalue() + assert "Number of rows with all columns equal: 1" in stdout.getvalue() + + +def test_report_outputs_the_row_summary_for_differently_named_columns(comparison3): + stdout = io.StringIO() + + comparison3.report(file=stdout) + + assert "****** Row Summary ******" in stdout.getvalue() + assert "Number of rows in common: 5" in stdout.getvalue() + assert "Number of rows in base but not compare: 0" in stdout.getvalue() + assert "Number of rows in compare but not base: 0" in stdout.getvalue() + assert "Number of duplicate rows found in base: 0" in stdout.getvalue() + assert "Number of duplicate rows found in compare: 0" in stdout.getvalue() + + +def test_report_outputs_the_row_equality_comparison_for_differently_named_columns( + comparison3, +): + stdout = io.StringIO() + + comparison3.report(file=stdout) + + assert "****** Row Comparison ******" in stdout.getvalue() + assert "Number of rows with some columns unequal: 3" in stdout.getvalue() + assert "Number of rows with all columns equal: 2" in stdout.getvalue() + + +def 
test_report_outputs_column_detail_for_columns_in_only_one_dataframe(comparison1): + stdout = io.StringIO() + + comparison1.report(file=stdout) + comparison1.report() + assert "****** Columns In Base Only ******" in stdout.getvalue() + r2 = r"""Column\s*Name \s* Dtype \n -+ \s+ -+ \ndate_fld \s+ date""" + assert re.search(r2, str(stdout.getvalue()), re.X) is not None + + +def test_report_outputs_column_detail_for_columns_in_only_compare_dataframe( + comparison1, +): + stdout = io.StringIO() + + comparison1.report(file=stdout) + comparison1.report() + assert "****** Columns In Compare Only ******" in stdout.getvalue() + r2 = r"""Column\s*Name \s* Dtype \n -+ \s+ -+ \n accnt_purge \s+ boolean""" + assert re.search(r2, str(stdout.getvalue()), re.X) is not None + + +def test_report_outputs_schema_difference_details(comparison1): + stdout = io.StringIO() + + comparison1.report(file=stdout) + + assert "****** Schema Differences ******" in stdout.getvalue() + assert re.search( + r"""Base\sColumn\sName \s+ Compare\sColumn\sName \s+ Base\sDtype \s+ Compare\sDtype \n + -+ \s+ -+ \s+ -+ \s+ -+ \n + dollar_amt \s+ dollar_amt \s+ bigint \s+ double""", + stdout.getvalue(), + re.X, + ) + + +def test_report_outputs_schema_difference_details_for_differently_named_columns( + comparison3, +): + stdout = io.StringIO() + + comparison3.report(file=stdout) + + assert "****** Schema Differences ******" in stdout.getvalue() + assert re.search( + r"""Base\sColumn\sName \s+ Compare\sColumn\sName \s+ Base\sDtype \s+ Compare\sDtype \n + -+ \s+ -+ \s+ -+ \s+ -+ \n + dollar_amt \s+ dollar_amount \s+ bigint \s+ double""", + stdout.getvalue(), + re.X, + ) + + +def test_column_comparison_outputs_number_of_columns_with_differences(comparison1): + stdout = io.StringIO() + + comparison1.report(file=stdout) + + assert "****** Column Comparison ******" in stdout.getvalue() + assert "Number of columns compared with some values unequal: 3" in stdout.getvalue() + assert "Number of columns compared with all values equal: 0" in stdout.getvalue() + + +def test_column_comparison_outputs_all_columns_equal_for_identical_dataframes( + comparison2, +): + stdout = io.StringIO() + + comparison2.report(file=stdout) + + assert "****** Column Comparison ******" in stdout.getvalue() + assert "Number of columns compared with some values unequal: 0" in stdout.getvalue() + assert "Number of columns compared with all values equal: 4" in stdout.getvalue() + + +def test_column_comparison_outputs_number_of_columns_with_differences_for_differently_named_columns( + comparison3, +): + stdout = io.StringIO() + + comparison3.report(file=stdout) + + assert "****** Column Comparison ******" in stdout.getvalue() + assert "Number of columns compared with some values unequal: 3" in stdout.getvalue() + assert "Number of columns compared with all values equal: 1" in stdout.getvalue() + + +def test_column_comparison_outputs_number_of_columns_with_differences_for_known_diffs( + comparison_kd1, +): + stdout = io.StringIO() + + comparison_kd1.report(file=stdout) + + assert "****** Column Comparison ******" in stdout.getvalue() + assert ( + "Number of columns compared with unexpected differences in some values: 1" + in stdout.getvalue() + ) + assert ( + "Number of columns compared with all values equal but known differences found: 2" + in stdout.getvalue() + ) + assert ( + "Number of columns compared with all values completely equal: 0" + in stdout.getvalue() + ) + + +def test_column_comparison_outputs_number_of_columns_with_differences_for_custom_known_diffs( + 
comparison_kd2, +): + stdout = io.StringIO() + + comparison_kd2.report(file=stdout) + + assert "****** Column Comparison ******" in stdout.getvalue() + assert ( + "Number of columns compared with unexpected differences in some values: 2" + in stdout.getvalue() + ) + assert ( + "Number of columns compared with all values equal but known differences found: 1" + in stdout.getvalue() + ) + assert ( + "Number of columns compared with all values completely equal: 0" + in stdout.getvalue() + ) + + +def test_columns_with_unequal_values_show_mismatch_counts(comparison1): + stdout = io.StringIO() + + comparison1.report(file=stdout) + + assert "****** Columns with Unequal Values ******" in stdout.getvalue() + assert re.search( + r"""Base\s*Column\s*Name \s+ Compare\s*Column\s*Name \s+ Base\s*Dtype \s+ Compare\sDtype \s* + \#\sMatches \s* \#\sMismatches \n + -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+""", + stdout.getvalue(), + re.X, + ) + assert re.search( + r"""dollar_amt \s+ dollar_amt \s+ bigint \s+ double \s+ 2 \s+ 2""", + stdout.getvalue(), + re.X, + ) + assert re.search( + r"""float_fld \s+ float_fld \s+ double \s+ double \s+ 1 \s+ 3""", + stdout.getvalue(), + re.X, + ) + assert re.search( + r"""name \s+ name \s+ string \s+ string \s+ 3 \s+ 1""", stdout.getvalue(), re.X + ) + + +def test_columns_with_different_names_with_unequal_values_show_mismatch_counts( + comparison3, +): + stdout = io.StringIO() + + comparison3.report(file=stdout) + + assert "****** Columns with Unequal Values ******" in stdout.getvalue() + assert re.search( + r"""Base\s*Column\s*Name \s+ Compare\s*Column\s*Name \s+ Base\s*Dtype \s+ Compare\sDtype \s* + \#\sMatches \s* \#\sMismatches \n + -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+""", + stdout.getvalue(), + re.X, + ) + assert re.search( + r"""dollar_amt \s+ dollar_amount \s+ bigint \s+ double \s+ 2 \s+ 3""", + stdout.getvalue(), + re.X, + ) + assert re.search( + r"""float_fld \s+ float_field \s+ double \s+ double \s+ 4 \s+ 1""", + stdout.getvalue(), + re.X, + ) + assert re.search( + r"""name \s+ name \s+ string \s+ string \s+ 4 \s+ 1""", stdout.getvalue(), re.X + ) + + +def test_rows_only_base_returns_a_dataframe_with_rows_only_in_base( + spark_session, comparison1 +): + # require schema if contains only 1 row and contain field value as None + schema = StructType( + [ + StructField("acct", LongType(), True), + StructField("date_fld", DateType(), True), + StructField("dollar_amt", LongType(), True), + StructField("float_fld", DoubleType(), True), + StructField("name", StringType(), True), + ] + ) + expected_df = spark_session.createDataFrame( + [ + Row( + acct=10000001239, + date_fld=datetime.date(2017, 1, 1), + dollar_amt=1, + float_fld=None, + name="Lucille Bluth", + ) + ], + schema, + ) + assert comparison1.rows_only_base.count() == 1 + assert ( + expected_df.union( + comparison1.rows_only_base.select( + "acct", "date_fld", "dollar_amt", "float_fld", "name" + ) + ) + .distinct() + .count() + == 1 + ) + + +def test_rows_only_compare_returns_a_dataframe_with_rows_only_in_compare( + spark_session, comparison1 +): + expected_df = spark_session.createDataFrame( + [ + Row( + acct=10000001238, + dollar_amt=1.05, + name="Loose Seal Bluth", + float_fld=111.0, + accnt_purge=True, + ) + ] + ) + + assert comparison1.rows_only_compare.count() == 1 + assert expected_df.union(comparison1.rows_only_compare).distinct().count() == 1 + + +def test_rows_both_mismatch_returns_a_dataframe_with_rows_where_variables_mismatched( + spark_session, comparison1 +): + expected_df = 
spark_session.createDataFrame( + [ + Row( + accnt_purge=False, + acct=10000001234, + date_fld=datetime.date(2017, 1, 1), + dollar_amt_base=123, + dollar_amt_compare=123.4, + dollar_amt_match=False, + float_fld_base=14530.1555, + float_fld_compare=14530.155, + float_fld_match=False, + name_base="George Maharis", + name_compare="George Michael Bluth", + name_match=False, + ), + Row( + accnt_purge=False, + acct=10000001235, + date_fld=datetime.date(2017, 1, 1), + dollar_amt_base=0, + dollar_amt_compare=0.45, + dollar_amt_match=False, + float_fld_base=1.0, + float_fld_compare=None, + float_fld_match=False, + name_base="Michael Bluth", + name_compare="Michael Bluth", + name_match=True, + ), + Row( + accnt_purge=False, + acct=10000001236, + date_fld=datetime.date(2017, 1, 1), + dollar_amt_base=1345, + dollar_amt_compare=1345.0, + dollar_amt_match=True, + float_fld_base=None, + float_fld_compare=1.0, + float_fld_match=False, + name_base="George Bluth", + name_compare="George Bluth", + name_match=True, + ), + ] + ) + + assert comparison1.rows_both_mismatch.count() == 3 + assert expected_df.union(comparison1.rows_both_mismatch).distinct().count() == 3 + + +def test_rows_both_mismatch_only_includes_rows_with_true_mismatches_when_known_diffs_are_present( + spark_session, comparison_kd1 +): + expected_df = spark_session.createDataFrame( + [ + Row( + acct=10000001237, + acct_seq=0, + cd_base="0004", + cd_compare=4.0, + cd_match=True, + cd_match_type="KNOWN_DIFFERENCE", + open_dt_base=datetime.date(2017, 5, 4), + open_dt_compare=2017124, + open_dt_match=True, + open_dt_match_type="KNOWN_DIFFERENCE", + stat_cd_base="*2", + stat_cd_compare="V3", + stat_cd_match=False, + stat_cd_match_type="MISMATCH", + ) + ] + ) + assert comparison_kd1.rows_both_mismatch.count() == 1 + assert expected_df.union(comparison_kd1.rows_both_mismatch).distinct().count() == 1 + + +def test_rows_both_all_returns_a_dataframe_with_all_rows_in_both_dataframes( + spark_session, comparison1 +): + expected_df = spark_session.createDataFrame( + [ + Row( + accnt_purge=False, + acct=10000001234, + date_fld=datetime.date(2017, 1, 1), + dollar_amt_base=123, + dollar_amt_compare=123.4, + dollar_amt_match=False, + float_fld_base=14530.1555, + float_fld_compare=14530.155, + float_fld_match=False, + name_base="George Maharis", + name_compare="George Michael Bluth", + name_match=False, + ), + Row( + accnt_purge=False, + acct=10000001235, + date_fld=datetime.date(2017, 1, 1), + dollar_amt_base=0, + dollar_amt_compare=0.45, + dollar_amt_match=False, + float_fld_base=1.0, + float_fld_compare=None, + float_fld_match=False, + name_base="Michael Bluth", + name_compare="Michael Bluth", + name_match=True, + ), + Row( + accnt_purge=False, + acct=10000001236, + date_fld=datetime.date(2017, 1, 1), + dollar_amt_base=1345, + dollar_amt_compare=1345.0, + dollar_amt_match=True, + float_fld_base=None, + float_fld_compare=1.0, + float_fld_match=False, + name_base="George Bluth", + name_compare="George Bluth", + name_match=True, + ), + Row( + accnt_purge=False, + acct=10000001237, + date_fld=datetime.date(2017, 1, 1), + dollar_amt_base=123456, + dollar_amt_compare=123456.0, + dollar_amt_match=True, + float_fld_base=345.12, + float_fld_compare=345.12, + float_fld_match=True, + name_base="Bob Loblaw", + name_compare="Bob Loblaw", + name_match=True, + ), + ] + ) + + assert comparison1.rows_both_all.count() == 4 + assert expected_df.union(comparison1.rows_both_all).distinct().count() == 4 + + +def 
test_rows_both_all_shows_known_diffs_flag_and_known_diffs_count_as_matches( + spark_session, comparison_kd1 +): + expected_df = spark_session.createDataFrame( + [ + Row( + acct=10000001234, + acct_seq=0, + cd_base="0001", + cd_compare=1.0, + cd_match=True, + cd_match_type="KNOWN_DIFFERENCE", + open_dt_base=datetime.date(2017, 5, 1), + open_dt_compare=2017121, + open_dt_match=True, + open_dt_match_type="KNOWN_DIFFERENCE", + stat_cd_base="*2", + stat_cd_compare=None, + stat_cd_match=True, + stat_cd_match_type="KNOWN_DIFFERENCE", + ), + Row( + acct=10000001235, + acct_seq=0, + cd_base="0002", + cd_compare=2.0, + cd_match=True, + cd_match_type="KNOWN_DIFFERENCE", + open_dt_base=datetime.date(2017, 5, 2), + open_dt_compare=2017122, + open_dt_match=True, + open_dt_match_type="KNOWN_DIFFERENCE", + stat_cd_base="V1", + stat_cd_compare="V1", + stat_cd_match=True, + stat_cd_match_type="MATCH", + ), + Row( + acct=10000001236, + acct_seq=0, + cd_base="0003", + cd_compare=3.0, + cd_match=True, + cd_match_type="KNOWN_DIFFERENCE", + open_dt_base=datetime.date(2017, 5, 3), + open_dt_compare=2017123, + open_dt_match=True, + open_dt_match_type="KNOWN_DIFFERENCE", + stat_cd_base="V2", + stat_cd_compare="V2", + stat_cd_match=True, + stat_cd_match_type="MATCH", + ), + Row( + acct=10000001237, + acct_seq=0, + cd_base="0004", + cd_compare=4.0, + cd_match=True, + cd_match_type="KNOWN_DIFFERENCE", + open_dt_base=datetime.date(2017, 5, 4), + open_dt_compare=2017124, + open_dt_match=True, + open_dt_match_type="KNOWN_DIFFERENCE", + stat_cd_base="*2", + stat_cd_compare="V3", + stat_cd_match=False, + stat_cd_match_type="MISMATCH", + ), + Row( + acct=10000001238, + acct_seq=0, + cd_base="0005", + cd_compare=5.0, + cd_match=True, + cd_match_type="KNOWN_DIFFERENCE", + open_dt_base=datetime.date(2017, 5, 5), + open_dt_compare=2017125, + open_dt_match=True, + open_dt_match_type="KNOWN_DIFFERENCE", + stat_cd_base="*2", + stat_cd_compare=None, + stat_cd_match=True, + stat_cd_match_type="KNOWN_DIFFERENCE", + ), + ] + ) + + assert comparison_kd1.rows_both_all.count() == 5 + assert expected_df.union(comparison_kd1.rows_both_all).distinct().count() == 5 + + +def test_rows_both_all_returns_a_dataframe_with_all_rows_in_identical_dataframes( + spark_session, comparison2 +): + expected_df = spark_session.createDataFrame( + [ + Row( + acct=10000001234, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=123, + dollar_amt_compare=123, + dollar_amt_match=True, + float_fld_base=14530.1555, + float_fld_compare=14530.1555, + float_fld_match=True, + name_base="George Maharis", + name_compare="George Maharis", + name_match=True, + ), + Row( + acct=10000001235, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=0, + dollar_amt_compare=0, + dollar_amt_match=True, + float_fld_base=1.0, + float_fld_compare=1.0, + float_fld_match=True, + name_base="Michael Bluth", + name_compare="Michael Bluth", + name_match=True, + ), + Row( + acct=10000001236, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=1345, + dollar_amt_compare=1345, + dollar_amt_match=True, + float_fld_base=None, + float_fld_compare=None, + float_fld_match=True, + name_base="George Bluth", + name_compare="George Bluth", + name_match=True, + ), + Row( + acct=10000001237, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 
1), + date_fld_match=True, + dollar_amt_base=123456, + dollar_amt_compare=123456, + dollar_amt_match=True, + float_fld_base=345.12, + float_fld_compare=345.12, + float_fld_match=True, + name_base="Bob Loblaw", + name_compare="Bob Loblaw", + name_match=True, + ), + Row( + acct=10000001239, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=1, + dollar_amt_compare=1, + dollar_amt_match=True, + float_fld_base=None, + float_fld_compare=None, + float_fld_match=True, + name_base="Lucille Bluth", + name_compare="Lucille Bluth", + name_match=True, + ), + ] + ) + + assert comparison2.rows_both_all.count() == 5 + assert expected_df.union(comparison2.rows_both_all).distinct().count() == 5 + + +def test_rows_both_all_returns_all_rows_in_both_dataframes_for_differently_named_columns( + spark_session, comparison3 +): + expected_df = spark_session.createDataFrame( + [ + Row( + accnt_purge=False, + acct=10000001234, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=123, + dollar_amt_compare=123.4, + dollar_amt_match=False, + float_fld_base=14530.1555, + float_fld_compare=14530.155, + float_fld_match=False, + name_base="George Maharis", + name_compare="George Michael Bluth", + name_match=False, + ), + Row( + accnt_purge=False, + acct=10000001235, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=0, + dollar_amt_compare=0.45, + dollar_amt_match=False, + float_fld_base=1.0, + float_fld_compare=1.0, + float_fld_match=True, + name_base="Michael Bluth", + name_compare="Michael Bluth", + name_match=True, + ), + Row( + accnt_purge=False, + acct=10000001236, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=1345, + dollar_amt_compare=1345.0, + dollar_amt_match=True, + float_fld_base=None, + float_fld_compare=None, + float_fld_match=True, + name_base="George Bluth", + name_compare="George Bluth", + name_match=True, + ), + Row( + accnt_purge=False, + acct=10000001237, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=123456, + dollar_amt_compare=123456.0, + dollar_amt_match=True, + float_fld_base=345.12, + float_fld_compare=345.12, + float_fld_match=True, + name_base="Bob Loblaw", + name_compare="Bob Loblaw", + name_match=True, + ), + Row( + accnt_purge=True, + acct=10000001239, + date_fld_base=datetime.date(2017, 1, 1), + date_fld_compare=datetime.date(2017, 1, 1), + date_fld_match=True, + dollar_amt_base=1, + dollar_amt_compare=1.05, + dollar_amt_match=False, + float_fld_base=None, + float_fld_compare=None, + float_fld_match=True, + name_base="Lucille Bluth", + name_compare="Lucille Bluth", + name_match=True, + ), + ] + ) + + assert comparison3.rows_both_all.count() == 5 + assert expected_df.union(comparison3.rows_both_all).distinct().count() == 5 + + +def test_columns_with_unequal_values_text_is_aligned(comparison4): + stdout = io.StringIO() + + comparison4.report(file=stdout) + stdout.seek(0) # Back up to the beginning of the stream + + text_alignment_validator( + report=stdout, + section_start="****** Columns with Unequal Values ******", + section_end="\n", + left_indices=(1, 2, 3, 4), + right_indices=(5, 6), + column_regexes=[ + r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) 
\s+ + (\#\sMatches) \s+ (\#\sMismatches)""", + r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""", + r"""(dollar_amt) \s+ (dollar_amt) \s+ (bigint) \s+ (double) \s+ (2) \s+ (2)""", + r"""(float_fld) \s+ (float_fld) \s+ (double) \s+ (double) \s+ (1) \s+ (3)""", + r"""(super_duper_big_long_name) \s+ (name) \s+ (string) \s+ (string) \s+ (3) \s+ (1)\s*""", + ], + ) + + +def test_columns_with_unequal_values_text_is_aligned_with_known_differences( + comparison_kd1, +): + stdout = io.StringIO() + + comparison_kd1.report(file=stdout) + stdout.seek(0) # Back up to the beginning of the stream + + text_alignment_validator( + report=stdout, + section_start="****** Columns with Unequal Values ******", + section_end="\n", + left_indices=(1, 2, 3, 4), + right_indices=(5, 6, 7), + column_regexes=[ + r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+ + (\#\sMatches) \s+ (\#\sKnown\sDiffs) \s+ (\#\sMismatches)""", + r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""", + r"""(stat_cd) \s+ (STATC) \s+ (string) \s+ (string) \s+ (2) \s+ (2) \s+ (1)""", + r"""(open_dt) \s+ (ACCOUNT_OPEN) \s+ (date) \s+ (bigint) \s+ (0) \s+ (5) \s+ (0)""", + r"""(cd) \s+ (CODE) \s+ (string) \s+ (double) \s+ (0) \s+ (5) \s+ (0)\s*""", + ], + ) + + +def test_columns_with_unequal_values_text_is_aligned_with_custom_known_differences( + comparison_kd2, +): + stdout = io.StringIO() + + comparison_kd2.report(file=stdout) + stdout.seek(0) # Back up to the beginning of the stream + + text_alignment_validator( + report=stdout, + section_start="****** Columns with Unequal Values ******", + section_end="\n", + left_indices=(1, 2, 3, 4), + right_indices=(5, 6, 7), + column_regexes=[ + r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+ + (\#\sMatches) \s+ (\#\sKnown\sDiffs) \s+ (\#\sMismatches)""", + r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""", + r"""(stat_cd) \s+ (STATC) \s+ (string) \s+ (string) \s+ (2) \s+ (2) \s+ (1)""", + r"""(open_dt) \s+ (ACCOUNT_OPEN) \s+ (date) \s+ (bigint) \s+ (0) \s+ (0) \s+ (5)""", + r"""(cd) \s+ (CODE) \s+ (string) \s+ (double) \s+ (0) \s+ (5) \s+ (0)\s*""", + ], + ) + + +def test_columns_with_unequal_values_text_is_aligned_for_decimals(comparison_decimal): + stdout = io.StringIO() + + comparison_decimal.report(file=stdout) + stdout.seek(0) # Back up to the beginning of the stream + + text_alignment_validator( + report=stdout, + section_start="****** Columns with Unequal Values ******", + section_end="\n", + left_indices=(1, 2, 3, 4), + right_indices=(5, 6), + column_regexes=[ + r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+ + (\#\sMatches) \s+ (\#\sMismatches)""", + r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""", + r"""(dollar_amt) \s+ (dollar_amt) \s+ (decimal\(8,2\)) \s+ (double) \s+ (1) \s+ (1)""", + ], + ) + + +def test_schema_differences_text_is_aligned(comparison4): + stdout = io.StringIO() + + comparison4.report(file=stdout) + comparison4.report() + stdout.seek(0) # Back up to the beginning of the stream + + text_alignment_validator( + report=stdout, + section_start="****** Schema Differences ******", + section_end="\n", + left_indices=(1, 2, 3, 4), + right_indices=(), + column_regexes=[ + r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype)""", + r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+)""", + r"""(dollar_amt) \s+ (dollar_amt) \s+ (bigint) \s+ (double)""", + ], + ) + + +def 
test_schema_differences_text_is_aligned_for_decimals(comparison_decimal): + stdout = io.StringIO() + + comparison_decimal.report(file=stdout) + stdout.seek(0) # Back up to the beginning of the stream + + text_alignment_validator( + report=stdout, + section_start="****** Schema Differences ******", + section_end="\n", + left_indices=(1, 2, 3, 4), + right_indices=(), + column_regexes=[ + r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype)""", + r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+)""", + r"""(dollar_amt) \s+ (dollar_amt) \s+ (decimal\(8,2\)) \s+ (double)""", + ], + ) + + +def test_base_only_columns_text_is_aligned(comparison4): + stdout = io.StringIO() + + comparison4.report(file=stdout) + stdout.seek(0) # Back up to the beginning of the stream + + text_alignment_validator( + report=stdout, + section_start="****** Columns In Base Only ******", + section_end="\n", + left_indices=(1, 2), + right_indices=(), + column_regexes=[ + r"""(Column\sName) \s+ (Dtype)""", + r"""(-+) \s+ (-+)""", + r"""(date_fld) \s+ (date)""", + ], + ) + + +def test_compare_only_columns_text_is_aligned(comparison4): + stdout = io.StringIO() + + comparison4.report(file=stdout) + stdout.seek(0) # Back up to the beginning of the stream + + text_alignment_validator( + report=stdout, + section_start="****** Columns In Compare Only ******", + section_end="\n", + left_indices=(1, 2), + right_indices=(), + column_regexes=[ + r"""(Column\sName) \s+ (Dtype)""", + r"""(-+) \s+ (-+)""", + r"""(accnt_purge) \s+ (boolean)""", + ], + ) + + +def text_alignment_validator( + report, section_start, section_end, left_indices, right_indices, column_regexes +): + r"""Check to make sure that report output columns are vertically aligned. + + Parameters + ---------- + report: An iterable returning lines of report output to be validated. + section_start: A string that represents the beginning of the section to be validated. + section_end: A string that represents the end of the section to be validated. + left_indices: The match group indexes (starting with 1) that should be left-aligned + in the output column. + right_indices: The match group indexes (starting with 1) that should be right-aligned + in the output column. + column_regexes: A list of regular expressions representing the expected output, with + each column enclosed with parentheses to return a match. The regular expression will + use the "X" flag, so it may contain whitespace, and any whitespace to be matched + should be explicitly given with \s. The first line will represent the alignments + that are expected in the following lines. The number of match groups should cover + all of the indices given in left/right_indices. + + Runs assertions for every match group specified by left/right_indices to ensure that + all lines past the first are either left- or right-aligned with the same match group + on the first line. + """ + + at_column_section = False + processed_first_line = False + match_positions = [None] * (len(left_indices + right_indices) + 1) + + for line in report: + if at_column_section: + if line == section_end: # Detect end of section and stop + break + + if ( + not processed_first_line + ): # First line in section - capture text start/end positions + matches = re.search(column_regexes[0], line, re.X) + assert matches is not None # Make sure we found at least this... 
+ + for n in left_indices: + match_positions[n] = matches.start(n) + for n in right_indices: + match_positions[n] = matches.end(n) + processed_first_line = True + else: # Match the stuff after the header text + match = None + for regex in column_regexes[1:]: + match = re.search(regex, line, re.X) + if match: + break + + if not match: + raise AssertionError(f'Did not find a match for line: "{line}"') + + for n in left_indices: + assert match_positions[n] == match.start(n) + for n in right_indices: + assert match_positions[n] == match.end(n) + + if not at_column_section and section_start in line: + at_column_section = True + + +def test_unicode_columns(spark_session): + df1 = spark_session.createDataFrame( + [ + (1, "foo", "test"), + (2, "bar", "test"), + ], + ["id", "例", "予測対象日"], + ) + df2 = spark_session.createDataFrame( + [ + (1, "foo", "test"), + (2, "baz", "test"), + ], + ["id", "例", "予測対象日"], + ) + compare = LegacySparkCompare(spark_session, df1, df2, join_columns=["例"]) + # Just render the report to make sure it renders. + compare.report() diff --git a/tests/test_spark.py b/tests/test_spark.py index af8aa8f3..88acc9a0 100644 --- a/tests/test_spark.py +++ b/tests/test_spark.py @@ -13,2103 +13,1311 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime -import io +""" +Testing out the datacompy functionality +""" + import logging import re +import sys +from datetime import datetime from decimal import Decimal +from io import StringIO +from unittest import mock +import numpy as np +import pandas as pd import pytest +from pytest import raises pytest.importorskip("pyspark") -from pyspark.sql import Row, SparkSession -from pyspark.sql.types import ( - DateType, - DecimalType, - DoubleType, - LongType, - StringType, - StructField, - StructType, -) +import pyspark.pandas as ps # noqa: E402 +from pandas.testing import assert_series_equal # noqa: E402 -import datacompy -from datacompy import SparkCompare -from datacompy.spark import _is_comparable +from datacompy.spark import ( # noqa: E402 + SparkCompare, + calculate_max_diff, + columns_equal, + generate_id_within_group, + temp_column_name, +) -# Turn off py4j debug messages for all tests in this module -logging.getLogger("py4j").setLevel(logging.INFO) +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) -CACHE_INTERMEDIATES = True +pandas_version = pytest.mark.skipif( + pd.__version__ >= "2.0.0", reason="Pandas 2 is currently not supported" +) -# Declare fixtures -# (if we need to use these in other modules, move to conftest.py) -@pytest.fixture(scope="module", name="spark") -def spark_fixture(): - spark = ( - SparkSession.builder.master("local[2]") - .config("spark.driver.bindAddress", "127.0.0.1") - .appName("pytest") - .getOrCreate() +pd.DataFrame.iteritems = pd.DataFrame.items # Pandas 2+ compatability +np.bool = np.bool_ # Numpy 1.24.3+ comptability + + +@pandas_version +def test_numeric_columns_equal_abs(): + data = """a|b|expected +1|1|True +2|2.1|True +3|4|False +4|NULL|False +NULL|4|False +NULL|NULL|True""" + + df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal(df.a, df.b, abs_tol=0.2) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) - yield spark - spark.stop() - - -@pytest.fixture(scope="module", name="base_df1") -def base_df1_fixture(spark): - mock_data = [ - Row( - acct=10000001234, - dollar_amt=123, - name="George Maharis", - 
float_fld=14530.1555, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001235, - dollar_amt=0, - name="Michael Bluth", - float_fld=1.0, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001236, - dollar_amt=1345, - name="George Bluth", - float_fld=None, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001237, - dollar_amt=123456, - name="Bob Loblaw", - float_fld=345.12, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001239, - dollar_amt=1, - name="Lucille Bluth", - float_fld=None, - date_fld=datetime.date(2017, 1, 1), - ), - ] - - return spark.createDataFrame(mock_data) - - -@pytest.fixture(scope="module", name="base_df2") -def base_df2_fixture(spark): - mock_data = [ - Row( - acct=10000001234, - dollar_amt=123, - super_duper_big_long_name="George Maharis", - float_fld=14530.1555, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001235, - dollar_amt=0, - super_duper_big_long_name="Michael Bluth", - float_fld=1.0, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001236, - dollar_amt=1345, - super_duper_big_long_name="George Bluth", - float_fld=None, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001237, - dollar_amt=123456, - super_duper_big_long_name="Bob Loblaw", - float_fld=345.12, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001239, - dollar_amt=1, - super_duper_big_long_name="Lucille Bluth", - float_fld=None, - date_fld=datetime.date(2017, 1, 1), - ), - ] - - return spark.createDataFrame(mock_data) - - -@pytest.fixture(scope="module", name="compare_df1") -def compare_df1_fixture(spark): - mock_data2 = [ - Row( - acct=10000001234, - dollar_amt=123.4, - name="George Michael Bluth", - float_fld=14530.155, - accnt_purge=False, - ), - Row( - acct=10000001235, - dollar_amt=0.45, - name="Michael Bluth", - float_fld=None, - accnt_purge=False, - ), - Row( - acct=10000001236, - dollar_amt=1345.0, - name="George Bluth", - float_fld=1.0, - accnt_purge=False, - ), - Row( - acct=10000001237, - dollar_amt=123456.0, - name="Bob Loblaw", - float_fld=345.12, - accnt_purge=False, - ), - Row( - acct=10000001238, - dollar_amt=1.05, - name="Loose Seal Bluth", - float_fld=111.0, - accnt_purge=True, - ), - Row( - acct=10000001238, - dollar_amt=1.05, - name="Loose Seal Bluth", - float_fld=111.0, - accnt_purge=True, - ), - ] - - return spark.createDataFrame(mock_data2) - - -@pytest.fixture(scope="module", name="compare_df2") -def compare_df2_fixture(spark): - mock_data = [ - Row( - acct=10000001234, - dollar_amt=123, - name="George Maharis", - float_fld=14530.1555, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001235, - dollar_amt=0, - name="Michael Bluth", - float_fld=1.0, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001236, - dollar_amt=1345, - name="George Bluth", - float_fld=None, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001237, - dollar_amt=123456, - name="Bob Loblaw", - float_fld=345.12, - date_fld=datetime.date(2017, 1, 1), - ), - Row( - acct=10000001239, - dollar_amt=1, - name="Lucille Bluth", - float_fld=None, - date_fld=datetime.date(2017, 1, 1), - ), - ] - - return spark.createDataFrame(mock_data) - - -@pytest.fixture(scope="module", name="compare_df3") -def compare_df3_fixture(spark): - mock_data2 = [ - Row( - account_identifier=10000001234, - dollar_amount=123.4, - name="George Michael Bluth", - float_field=14530.155, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001235, - 
dollar_amount=0.45, - name="Michael Bluth", - float_field=1.0, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001236, - dollar_amount=1345.0, - name="George Bluth", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001237, - dollar_amount=123456.0, - name="Bob Loblaw", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001239, - dollar_amount=1.05, - name="Lucille Bluth", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - ] - return spark.createDataFrame(mock_data2) - -@pytest.fixture(scope="module", name="base_tol") -def base_tol_fixture(spark): - tol_data1 = [ - Row( - account_identifier=10000001234, - dollar_amount=123.4, - name="Franklin Delano Bluth", - float_field=14530.155, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001235, - dollar_amount=500.0, - name="Surely Funke", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - Row( - account_identifier=10000001236, - dollar_amount=-1100.0, - name="Nichael Bluth", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - Row( - account_identifier=10000001237, - dollar_amount=0.45, - name="Mr. F", - float_field=1.0, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001238, - dollar_amount=1345.0, - name="Steve Holt!", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001239, - dollar_amount=123456.0, - name="Blue Man Group", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001240, - dollar_amount=1.1, - name="Her?", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - Row( - account_identifier=10000001241, - dollar_amount=0.0, - name="Mrs. 
Featherbottom", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - Row( - account_identifier=10000001242, - dollar_amount=0.0, - name="Ice", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), - Row( - account_identifier=10000001243, - dollar_amount=-10.0, - name="Frank Wrench", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - Row( - account_identifier=10000001244, - dollar_amount=None, - name="Lucille 2", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - Row( - account_identifier=10000001245, - dollar_amount=0.009999, - name="Gene Parmesan", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - Row( - account_identifier=10000001246, - dollar_amount=None, - name="Motherboy", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), - ] - - return spark.createDataFrame(tol_data1) - - -@pytest.fixture(scope="module", name="compare_abs_tol") -def compare_tol2_fixture(spark): - tol_data2 = [ - Row( - account_identifier=10000001234, - dollar_amount=123.4, - name="Franklin Delano Bluth", - float_field=14530.155, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # full match - Row( - account_identifier=10000001235, - dollar_amount=500.01, - name="Surely Funke", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by 0.01 - Row( - account_identifier=10000001236, - dollar_amount=-1100.01, - name="Nichael Bluth", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by -0.01 - Row( - account_identifier=10000001237, - dollar_amount=0.46000000001, - name="Mr. F", - float_field=1.0, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by 0.01000000001 - Row( - account_identifier=10000001238, - dollar_amount=1344.8999999999, - name="Steve Holt!", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by -0.01000000001 - Row( - account_identifier=10000001239, - dollar_amount=123456.0099999999, - name="Blue Man Group", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by 0.00999999999 - Row( - account_identifier=10000001240, - dollar_amount=1.090000001, - name="Her?", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by -0.00999999999 - Row( - account_identifier=10000001241, - dollar_amount=0.0, - name="Mrs. 
Featherbottom", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # both zero - Row( - account_identifier=10000001242, - dollar_amount=1.0, - name="Ice", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # base 0, compare 1 - Row( - account_identifier=10000001243, - dollar_amount=0.0, - name="Frank Wrench", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base -10, compare 0 - Row( - account_identifier=10000001244, - dollar_amount=-1.0, - name="Lucille 2", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base NULL, compare -1 - Row( - account_identifier=10000001245, - dollar_amount=None, - name="Gene Parmesan", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base 0.009999, compare NULL - Row( - account_identifier=10000001246, - dollar_amount=None, - name="Motherboy", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # both NULL - ] - - return spark.createDataFrame(tol_data2) - - -@pytest.fixture(scope="module", name="compare_rel_tol") -def compare_tol3_fixture(spark): - tol_data3 = [ - Row( - account_identifier=10000001234, - dollar_amount=123.4, - name="Franklin Delano Bluth", - float_field=14530.155, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # full match #MATCH - Row( - account_identifier=10000001235, - dollar_amount=550.0, - name="Surely Funke", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by 10% #MATCH - Row( - account_identifier=10000001236, - dollar_amount=-1000.0, - name="Nichael Bluth", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by -10% #MATCH - Row( - account_identifier=10000001237, - dollar_amount=0.49501, - name="Mr. F", - float_field=1.0, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by greater than 10% - Row( - account_identifier=10000001238, - dollar_amount=1210.001, - name="Steve Holt!", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by greater than -10% - Row( - account_identifier=10000001239, - dollar_amount=135801.59999, - name="Blue Man Group", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by just under 10% #MATCH - Row( - account_identifier=10000001240, - dollar_amount=1.000001, - name="Her?", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by just under -10% #MATCH - Row( - account_identifier=10000001241, - dollar_amount=0.0, - name="Mrs. 
Featherbottom", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # both zero #MATCH - Row( - account_identifier=10000001242, - dollar_amount=1.0, - name="Ice", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # base 0, compare 1 - Row( - account_identifier=10000001243, - dollar_amount=0.0, - name="Frank Wrench", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base -10, compare 0 - Row( - account_identifier=10000001244, - dollar_amount=-1.0, - name="Lucille 2", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base NULL, compare -1 - Row( - account_identifier=10000001245, - dollar_amount=None, - name="Gene Parmesan", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base 0.009999, compare NULL - Row( - account_identifier=10000001246, - dollar_amount=None, - name="Motherboy", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # both NULL #MATCH - ] - - return spark.createDataFrame(tol_data3) - - -@pytest.fixture(scope="module", name="compare_both_tol") -def compare_tol4_fixture(spark): - tol_data4 = [ - Row( - account_identifier=10000001234, - dollar_amount=123.4, - name="Franklin Delano Bluth", - float_field=14530.155, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # full match - Row( - account_identifier=10000001235, - dollar_amount=550.01, - name="Surely Funke", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by 10% and +0.01 - Row( - account_identifier=10000001236, - dollar_amount=-1000.01, - name="Nichael Bluth", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by -10% and -0.01 - Row( - account_identifier=10000001237, - dollar_amount=0.505000000001, - name="Mr. F", - float_field=1.0, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by greater than 10% and +0.01 - Row( - account_identifier=10000001238, - dollar_amount=1209.98999, - name="Steve Holt!", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by greater than -10% and -0.01 - Row( - account_identifier=10000001239, - dollar_amount=135801.609999, - name="Blue Man Group", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # off by just under 10% and just under +0.01 - Row( - account_identifier=10000001240, - dollar_amount=0.99000001, - name="Her?", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # off by just under -10% and just under -0.01 - Row( - account_identifier=10000001241, - dollar_amount=0.0, - name="Mrs. 
Featherbottom", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # both zero - Row( - account_identifier=10000001242, - dollar_amount=1.0, - name="Ice", - float_field=345.12, - date_field=datetime.date(2017, 1, 1), - accnt_purge=False, - ), # base 0, compare 1 - Row( - account_identifier=10000001243, - dollar_amount=0.0, - name="Frank Wrench", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base -10, compare 0 - Row( - account_identifier=10000001244, - dollar_amount=-1.0, - name="Lucille 2", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base NULL, compare -1 - Row( - account_identifier=10000001245, - dollar_amount=None, - name="Gene Parmesan", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # base 0.009999, compare NULL - Row( - account_identifier=10000001246, - dollar_amount=None, - name="Motherboy", - float_field=None, - date_field=datetime.date(2017, 1, 1), - accnt_purge=True, - ), # both NULL - ] - - return spark.createDataFrame(tol_data4) - - -@pytest.fixture(scope="module", name="base_td") -def base_td_fixture(spark): - mock_data = [ - Row( - acct=10000001234, - acct_seq=0, - stat_cd="*2", - open_dt=datetime.date(2017, 5, 1), - cd="0001", - ), - Row( - acct=10000001235, - acct_seq=0, - stat_cd="V1", - open_dt=datetime.date(2017, 5, 2), - cd="0002", - ), - Row( - acct=10000001236, - acct_seq=0, - stat_cd="V2", - open_dt=datetime.date(2017, 5, 3), - cd="0003", - ), - Row( - acct=10000001237, - acct_seq=0, - stat_cd="*2", - open_dt=datetime.date(2017, 5, 4), - cd="0004", - ), - Row( - acct=10000001238, - acct_seq=0, - stat_cd="*2", - open_dt=datetime.date(2017, 5, 5), - cd="0005", - ), - ] - - return spark.createDataFrame(mock_data) - - -@pytest.fixture(scope="module", name="compare_source") -def compare_source_fixture(spark): - mock_data = [ - Row( - ACCOUNT_IDENTIFIER=10000001234, - SEQ_NUMBER=0, - STATC=None, - ACCOUNT_OPEN=2017121, - CODE=1.0, - ), - Row( - ACCOUNT_IDENTIFIER=10000001235, - SEQ_NUMBER=0, - STATC="V1", - ACCOUNT_OPEN=2017122, - CODE=2.0, - ), - Row( - ACCOUNT_IDENTIFIER=10000001236, - SEQ_NUMBER=0, - STATC="V2", - ACCOUNT_OPEN=2017123, - CODE=3.0, - ), - Row( - ACCOUNT_IDENTIFIER=10000001237, - SEQ_NUMBER=0, - STATC="V3", - ACCOUNT_OPEN=2017124, - CODE=4.0, - ), - Row( - ACCOUNT_IDENTIFIER=10000001238, - SEQ_NUMBER=0, - STATC=None, - ACCOUNT_OPEN=2017125, - CODE=5.0, - ), - ] - - return spark.createDataFrame(mock_data) - - -@pytest.fixture(scope="module", name="base_decimal") -def base_decimal_fixture(spark): - mock_data = [ - Row(acct=10000001234, dollar_amt=Decimal(123.4)), - Row(acct=10000001235, dollar_amt=Decimal(0.45)), - ] - - return spark.createDataFrame( - mock_data, - schema=StructType( - [ - StructField("acct", LongType(), True), - StructField("dollar_amt", DecimalType(8, 2), True), - ] - ), +@pandas_version +def test_numeric_columns_equal_rel(): + data = """a|b|expected +1|1|True +2|2.1|True +3|4|False +4|NULL|False +NULL|4|False +NULL|NULL|True""" + df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal(df.a, df.b, rel_tol=0.2) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) -@pytest.fixture(scope="module", name="compare_decimal") -def compare_decimal_fixture(spark): - mock_data = [ - Row(acct=10000001234, dollar_amt=123.4), - Row(acct=10000001235, dollar_amt=0.456), - ] - - return 
spark.createDataFrame(mock_data) - - -@pytest.fixture(scope="module", name="comparison_abs_tol") -def comparison_abs_tol_fixture(base_tol, compare_abs_tol, spark): - return SparkCompare( - spark, - base_tol, - compare_abs_tol, - join_columns=["account_identifier"], - abs_tol=0.01, +@pandas_version +def test_string_columns_equal(): + data = """a|b|expected +Hi|Hi|True +Yo|Yo|True +Hey|Hey |False +résumé|resume|False +résumé|résumé|True +💩|💩|True +💩|🤔|False + | |True + | |False +datacompy|DataComPy|False +something||False +|something|False +||True""" + df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal(df.a, df.b, rel_tol=0.2) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) -@pytest.fixture(scope="module", name="comparison_rel_tol") -def comparison_rel_tol_fixture(base_tol, compare_rel_tol, spark): - return SparkCompare( - spark, - base_tol, - compare_rel_tol, - join_columns=["account_identifier"], - rel_tol=0.1, +@pandas_version +def test_string_columns_equal_with_ignore_spaces(): + data = """a|b|expected +Hi|Hi|True +Yo|Yo|True +Hey|Hey |True +résumé|resume|False +résumé|résumé|True +💩|💩|True +💩|🤔|False + | |True + | |True +datacompy|DataComPy|False +something||False +|something|False +||True""" + df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) -@pytest.fixture(scope="module", name="comparison_both_tol") -def comparison_both_tol_fixture(base_tol, compare_both_tol, spark): - return SparkCompare( - spark, - base_tol, - compare_both_tol, - join_columns=["account_identifier"], - rel_tol=0.1, - abs_tol=0.01, +@pandas_version +def test_string_columns_equal_with_ignore_spaces_and_case(): + data = """a|b|expected +Hi|Hi|True +Yo|Yo|True +Hey|Hey |True +résumé|resume|False +résumé|résumé|True +💩|💩|True +💩|🤔|False + | |True + | |True +datacompy|DataComPy|True +something||False +|something|False +||True""" + df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|")) + actual_out = columns_equal( + df.a, df.b, rel_tol=0.2, ignore_spaces=True, ignore_case=True ) - - -@pytest.fixture(scope="module", name="comparison_neg_tol") -def comparison_neg_tol_fixture(base_tol, compare_both_tol, spark): - return SparkCompare( - spark, - base_tol, - compare_both_tol, - join_columns=["account_identifier"], - rel_tol=-0.2, - abs_tol=0.01, + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) -@pytest.fixture(scope="module", name="show_all_columns_and_match_rate") -def show_all_columns_and_match_rate_fixture(base_tol, compare_both_tol, spark): - return SparkCompare( - spark, - base_tol, - compare_both_tol, - join_columns=["account_identifier"], - show_all_columns=True, - match_rates=True, +@pandas_version +def test_date_columns_equal(tmp_path): + data = """a|b|expected +2017-01-01|2017-01-01|True +2017-01-02|2017-01-02|True +2017-10-01|2017-10-10|False +2017-01-01||False +|2017-01-01|False +||True""" + df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|")) + # First compare just the strings + actual_out = columns_equal(df.a, df.b, rel_tol=0.2) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) - -@pytest.fixture(scope="module", name="comparison_kd1") -def 
comparison_known_diffs1(base_td, compare_source, spark): - return SparkCompare( - spark, - base_td, - compare_source, - join_columns=[("acct", "ACCOUNT_IDENTIFIER"), ("acct_seq", "SEQ_NUMBER")], - column_mapping=[ - ("stat_cd", "STATC"), - ("open_dt", "ACCOUNT_OPEN"), - ("cd", "CODE"), - ], - known_differences=[ - { - "name": "Left-padded, four-digit numeric code", - "types": datacompy.NUMERIC_SPARK_TYPES, - "transformation": "lpad(cast({input} AS bigint), 4, '0')", - }, - { - "name": "Null to *2", - "types": ["string"], - "transformation": "case when {input} is null then '*2' else {input} end", - }, - { - "name": "Julian date -> date", - "types": ["bigint"], - "transformation": "to_date(cast(unix_timestamp(cast({input} AS string), 'yyyyDDD') AS timestamp))", - }, - ], + # Then compare converted to datetime objects + df["a"] = ps.to_datetime(df["a"]) + df["b"] = ps.to_datetime(df["b"]) + actual_out = columns_equal(df.a, df.b, rel_tol=0.2) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) - - -@pytest.fixture(scope="module", name="comparison_kd2") -def comparison_known_diffs2(base_td, compare_source, spark): - return SparkCompare( - spark, - base_td, - compare_source, - join_columns=[("acct", "ACCOUNT_IDENTIFIER"), ("acct_seq", "SEQ_NUMBER")], - column_mapping=[ - ("stat_cd", "STATC"), - ("open_dt", "ACCOUNT_OPEN"), - ("cd", "CODE"), - ], - known_differences=[ - { - "name": "Left-padded, four-digit numeric code", - "types": datacompy.NUMERIC_SPARK_TYPES, - "transformation": "lpad(cast({input} AS bigint), 4, '0')", - }, - { - "name": "Null to *2", - "types": ["string"], - "transformation": "case when {input} is null then '*2' else {input} end", - }, - ], + # and reverse + actual_out_rev = columns_equal(df.b, df.a, rel_tol=0.2) + assert_series_equal( + expect_out.to_pandas(), actual_out_rev.to_pandas(), check_names=False ) -@pytest.fixture(scope="module", name="comparison1") -def comparison1_fixture(base_df1, compare_df1, spark): - return SparkCompare( - spark, - base_df1, - compare_df1, - join_columns=["acct"], - cache_intermediates=CACHE_INTERMEDIATES, +@pandas_version +def test_date_columns_equal_with_ignore_spaces(tmp_path): + data = """a|b|expected +2017-01-01|2017-01-01 |True +2017-01-02 |2017-01-02|True +2017-10-01 |2017-10-10 |False +2017-01-01||False +|2017-01-01|False +||True""" + df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|")) + # First compare just the strings + actual_out = columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) - -@pytest.fixture(scope="module", name="comparison2") -def comparison2_fixture(base_df1, compare_df2, spark): - return SparkCompare(spark, base_df1, compare_df2, join_columns=["acct"]) - - -@pytest.fixture(scope="module", name="comparison3") -def comparison3_fixture(base_df1, compare_df3, spark): - return SparkCompare( - spark, - base_df1, - compare_df3, - join_columns=[("acct", "account_identifier")], - column_mapping=[ - ("dollar_amt", "dollar_amount"), - ("float_fld", "float_field"), - ("date_fld", "date_field"), - ], - cache_intermediates=CACHE_INTERMEDIATES, + # Then compare converted to datetime objects + df["a"] = ps.to_datetime(df["a"], errors="coerce") + df["b"] = ps.to_datetime(df["b"], errors="coerce") + actual_out = columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True) + expect_out = df["expected"] + assert_series_equal( + 
expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) - - -@pytest.fixture(scope="module", name="comparison4") -def comparison4_fixture(base_df2, compare_df1, spark): - return SparkCompare( - spark, - base_df2, - compare_df1, - join_columns=["acct"], - column_mapping=[("super_duper_big_long_name", "name")], + # and reverse + actual_out_rev = columns_equal(df.b, df.a, rel_tol=0.2, ignore_spaces=True) + assert_series_equal( + expect_out.to_pandas(), actual_out_rev.to_pandas(), check_names=False ) -@pytest.fixture(scope="module", name="comparison_decimal") -def comparison_decimal_fixture(base_decimal, compare_decimal, spark): - return SparkCompare(spark, base_decimal, compare_decimal, join_columns=["acct"]) - - -def test_absolute_tolerances(comparison_abs_tol): - stdout = io.StringIO() - - comparison_abs_tol.report(file=stdout) - stdout.seek(0) - assert "****** Row Comparison ******" in stdout.getvalue() - assert "Number of rows with some columns unequal: 6" in stdout.getvalue() - assert "Number of rows with all columns equal: 7" in stdout.getvalue() - assert "Number of columns compared with some values unequal: 1" in stdout.getvalue() - assert "Number of columns compared with all values equal: 4" in stdout.getvalue() - - -def test_relative_tolerances(comparison_rel_tol): - stdout = io.StringIO() - - comparison_rel_tol.report(file=stdout) - stdout.seek(0) - assert "****** Row Comparison ******" in stdout.getvalue() - assert "Number of rows with some columns unequal: 6" in stdout.getvalue() - assert "Number of rows with all columns equal: 7" in stdout.getvalue() - assert "Number of columns compared with some values unequal: 1" in stdout.getvalue() - assert "Number of columns compared with all values equal: 4" in stdout.getvalue() - - -def test_both_tolerances(comparison_both_tol): - stdout = io.StringIO() - - comparison_both_tol.report(file=stdout) - stdout.seek(0) - assert "****** Row Comparison ******" in stdout.getvalue() - assert "Number of rows with some columns unequal: 6" in stdout.getvalue() - assert "Number of rows with all columns equal: 7" in stdout.getvalue() - assert "Number of columns compared with some values unequal: 1" in stdout.getvalue() - assert "Number of columns compared with all values equal: 4" in stdout.getvalue() - - -def test_negative_tolerances(spark, base_tol, compare_both_tol): - with pytest.raises(ValueError, match="Please enter positive valued tolerances"): - comp = SparkCompare( - spark, - base_tol, - compare_both_tol, - join_columns=["account_identifier"], - rel_tol=-0.2, - abs_tol=0.01, - ) - comp.report() - pass - - -def test_show_all_columns_and_match_rate(show_all_columns_and_match_rate): - stdout = io.StringIO() - - show_all_columns_and_match_rate.report(file=stdout) - - assert "****** Columns with Equal/Unequal Values ******" in stdout.getvalue() - assert ( - "accnt_purge accnt_purge boolean boolean 13 0 100.00000" - in stdout.getvalue() +@pandas_version +def test_date_columns_equal_with_ignore_spaces_and_case(tmp_path): + data = """a|b|expected +2017-01-01|2017-01-01 |True +2017-01-02 |2017-01-02|True +2017-10-01 |2017-10-10 |False +2017-01-01||False +|2017-01-01|False +||True""" + df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|")) + # First compare just the strings + actual_out = columns_equal( + df.a, df.b, rel_tol=0.2, ignore_spaces=True, ignore_case=True ) - assert ( - "date_field date_field date date 13 0 100.00000" - in stdout.getvalue() - ) - assert ( - "dollar_amount dollar_amount double double 3 10 23.07692" - in 
stdout.getvalue() + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) - assert ( - "float_field float_field double double 13 0 100.00000" - in stdout.getvalue() + + # Then compare converted to datetime objects + df["a"] = ps.to_datetime(df["a"], errors="coerce") + df["b"] = ps.to_datetime(df["b"], errors="coerce") + actual_out = columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) - assert ( - "name name string string 13 0 100.00000" - in stdout.getvalue() + # and reverse + actual_out_rev = columns_equal(df.b, df.a, rel_tol=0.2, ignore_spaces=True) + assert_series_equal( + expect_out.to_pandas(), actual_out_rev.to_pandas(), check_names=False ) -def test_decimal_comparisons(): - true_decimals = ["decimal", "decimal()", "decimal(20, 10)"] - assert all(v in datacompy.NUMERIC_SPARK_TYPES for v in true_decimals) - - -def test_decimal_comparator_acts_like_string(): - acc = False - for t in datacompy.NUMERIC_SPARK_TYPES: - acc = acc or (len(t) > 2 and t[0:3] == "dec") - assert acc - - -def test_decimals_and_doubles_are_comparable(): - assert _is_comparable("double", "decimal(10, 2)") - - -def test_report_outputs_the_column_summary(comparison1): - stdout = io.StringIO() - - comparison1.report(file=stdout) - - assert "****** Column Summary ******" in stdout.getvalue() - assert "Number of columns in common with matching schemas: 3" in stdout.getvalue() - assert "Number of columns in common with schema differences: 1" in stdout.getvalue() - assert "Number of columns in base but not compare: 1" in stdout.getvalue() - assert "Number of columns in compare but not base: 1" in stdout.getvalue() - - -def test_report_outputs_the_column_summary_for_identical_schemas(comparison2): - stdout = io.StringIO() - - comparison2.report(file=stdout) - - assert "****** Column Summary ******" in stdout.getvalue() - assert "Number of columns in common with matching schemas: 5" in stdout.getvalue() - assert "Number of columns in common with schema differences: 0" in stdout.getvalue() - assert "Number of columns in base but not compare: 0" in stdout.getvalue() - assert "Number of columns in compare but not base: 0" in stdout.getvalue() - - -def test_report_outputs_the_column_summary_for_differently_named_columns(comparison3): - stdout = io.StringIO() - - comparison3.report(file=stdout) - - assert "****** Column Summary ******" in stdout.getvalue() - assert "Number of columns in common with matching schemas: 4" in stdout.getvalue() - assert "Number of columns in common with schema differences: 1" in stdout.getvalue() - assert "Number of columns in base but not compare: 0" in stdout.getvalue() - assert "Number of columns in compare but not base: 1" in stdout.getvalue() - - -def test_report_outputs_the_row_summary(comparison1): - stdout = io.StringIO() - - comparison1.report(file=stdout) - - assert "****** Row Summary ******" in stdout.getvalue() - assert "Number of rows in common: 4" in stdout.getvalue() - assert "Number of rows in base but not compare: 1" in stdout.getvalue() - assert "Number of rows in compare but not base: 1" in stdout.getvalue() - assert "Number of duplicate rows found in base: 0" in stdout.getvalue() - assert "Number of duplicate rows found in compare: 1" in stdout.getvalue() - - -def test_report_outputs_the_row_equality_comparison(comparison1): - stdout = io.StringIO() - - comparison1.report(file=stdout) - 
- assert "****** Row Comparison ******" in stdout.getvalue() - assert "Number of rows with some columns unequal: 3" in stdout.getvalue() - assert "Number of rows with all columns equal: 1" in stdout.getvalue() - - -def test_report_outputs_the_row_summary_for_differently_named_columns(comparison3): - stdout = io.StringIO() - - comparison3.report(file=stdout) - - assert "****** Row Summary ******" in stdout.getvalue() - assert "Number of rows in common: 5" in stdout.getvalue() - assert "Number of rows in base but not compare: 0" in stdout.getvalue() - assert "Number of rows in compare but not base: 0" in stdout.getvalue() - assert "Number of duplicate rows found in base: 0" in stdout.getvalue() - assert "Number of duplicate rows found in compare: 0" in stdout.getvalue() - - -def test_report_outputs_the_row_equality_comparison_for_differently_named_columns( - comparison3, -): - stdout = io.StringIO() - - comparison3.report(file=stdout) - - assert "****** Row Comparison ******" in stdout.getvalue() - assert "Number of rows with some columns unequal: 3" in stdout.getvalue() - assert "Number of rows with all columns equal: 2" in stdout.getvalue() - - -def test_report_outputs_column_detail_for_columns_in_only_one_dataframe(comparison1): - stdout = io.StringIO() - - comparison1.report(file=stdout) - comparison1.report() - assert "****** Columns In Base Only ******" in stdout.getvalue() - r2 = r"""Column\s*Name \s* Dtype \n -+ \s+ -+ \ndate_fld \s+ date""" - assert re.search(r2, str(stdout.getvalue()), re.X) is not None - - -def test_report_outputs_column_detail_for_columns_in_only_compare_dataframe( - comparison1, -): - stdout = io.StringIO() - - comparison1.report(file=stdout) - comparison1.report() - assert "****** Columns In Compare Only ******" in stdout.getvalue() - r2 = r"""Column\s*Name \s* Dtype \n -+ \s+ -+ \n accnt_purge \s+ boolean""" - assert re.search(r2, str(stdout.getvalue()), re.X) is not None - - -def test_report_outputs_schema_difference_details(comparison1): - stdout = io.StringIO() - - comparison1.report(file=stdout) - - assert "****** Schema Differences ******" in stdout.getvalue() - assert re.search( - r"""Base\sColumn\sName \s+ Compare\sColumn\sName \s+ Base\sDtype \s+ Compare\sDtype \n - -+ \s+ -+ \s+ -+ \s+ -+ \n - dollar_amt \s+ dollar_amt \s+ bigint \s+ double""", - stdout.getvalue(), - re.X, +@pandas_version +def test_date_columns_unequal(): + """I want datetime fields to match with dates stored as strings""" + df = ps.DataFrame([{"a": "2017-01-01", "b": "2017-01-02"}, {"a": "2017-01-01"}]) + df["a_dt"] = ps.to_datetime(df["a"]) + df["b_dt"] = ps.to_datetime(df["b"]) + assert columns_equal(df.a, df.a_dt).all() + assert columns_equal(df.b, df.b_dt).all() + assert columns_equal(df.a_dt, df.a).all() + assert columns_equal(df.b_dt, df.b).all() + assert not columns_equal(df.b_dt, df.a).any() + assert not columns_equal(df.a_dt, df.b).any() + assert not columns_equal(df.a, df.b_dt).any() + assert not columns_equal(df.b, df.a_dt).any() + + +@pandas_version +def test_bad_date_columns(): + """If strings can't be coerced into dates then it should be false for the + whole column. 
+ """ + df = ps.DataFrame( + [{"a": "2017-01-01", "b": "2017-01-01"}, {"a": "2017-01-01", "b": "217-01-01"}] ) + df["a_dt"] = ps.to_datetime(df["a"]) + assert not columns_equal(df.a_dt, df.b).any() -def test_report_outputs_schema_difference_details_for_differently_named_columns( - comparison3, -): - stdout = io.StringIO() - - comparison3.report(file=stdout) - - assert "****** Schema Differences ******" in stdout.getvalue() - assert re.search( - r"""Base\sColumn\sName \s+ Compare\sColumn\sName \s+ Base\sDtype \s+ Compare\sDtype \n - -+ \s+ -+ \s+ -+ \s+ -+ \n - dollar_amt \s+ dollar_amount \s+ bigint \s+ double""", - stdout.getvalue(), - re.X, +@pandas_version +def test_rounded_date_columns(): + """If strings can't be coerced into dates then it should be false for the + whole column. + """ + df = ps.DataFrame( + [ + {"a": "2017-01-01", "b": "2017-01-01 00:00:00.000000", "exp": True}, + {"a": "2017-01-01", "b": "2017-01-01 00:00:00.123456", "exp": False}, + {"a": "2017-01-01", "b": "2017-01-01 00:00:01.000000", "exp": False}, + {"a": "2017-01-01", "b": "2017-01-01 00:00:00", "exp": True}, + ] ) + df["a_dt"] = ps.to_datetime(df["a"]) + actual = columns_equal(df.a_dt, df.b) + expected = df["exp"] + assert_series_equal(actual.to_pandas(), expected.to_pandas(), check_names=False) -def test_column_comparison_outputs_number_of_columns_with_differences(comparison1): - stdout = io.StringIO() - - comparison1.report(file=stdout) - - assert "****** Column Comparison ******" in stdout.getvalue() - assert "Number of columns compared with some values unequal: 3" in stdout.getvalue() - assert "Number of columns compared with all values equal: 0" in stdout.getvalue() - - -def test_column_comparison_outputs_all_columns_equal_for_identical_dataframes( - comparison2, -): - stdout = io.StringIO() - - comparison2.report(file=stdout) - - assert "****** Column Comparison ******" in stdout.getvalue() - assert "Number of columns compared with some values unequal: 0" in stdout.getvalue() - assert "Number of columns compared with all values equal: 4" in stdout.getvalue() - - -def test_column_comparison_outputs_number_of_columns_with_differences_for_differently_named_columns( - comparison3, -): - stdout = io.StringIO() - - comparison3.report(file=stdout) - - assert "****** Column Comparison ******" in stdout.getvalue() - assert "Number of columns compared with some values unequal: 3" in stdout.getvalue() - assert "Number of columns compared with all values equal: 1" in stdout.getvalue() - - -def test_column_comparison_outputs_number_of_columns_with_differences_for_known_diffs( - comparison_kd1, -): - stdout = io.StringIO() - - comparison_kd1.report(file=stdout) - - assert "****** Column Comparison ******" in stdout.getvalue() - assert ( - "Number of columns compared with unexpected differences in some values: 1" - in stdout.getvalue() - ) - assert ( - "Number of columns compared with all values equal but known differences found: 2" - in stdout.getvalue() +@pandas_version +def test_decimal_float_columns_equal(): + df = ps.DataFrame( + [ + {"a": Decimal("1"), "b": 1, "expected": True}, + {"a": Decimal("1.3"), "b": 1.3, "expected": True}, + {"a": Decimal("1.000003"), "b": 1.000003, "expected": True}, + {"a": Decimal("1.000000004"), "b": 1.000000003, "expected": False}, + {"a": Decimal("1.3"), "b": 1.2, "expected": False}, + {"a": np.nan, "b": np.nan, "expected": True}, + {"a": np.nan, "b": 1, "expected": False}, + {"a": Decimal("1"), "b": np.nan, "expected": False}, + ] ) - assert ( - "Number of columns compared with all 
values completely equal: 0" - in stdout.getvalue() + actual_out = columns_equal(df.a, df.b) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) -def test_column_comparison_outputs_number_of_columns_with_differences_for_custom_known_diffs( - comparison_kd2, -): - stdout = io.StringIO() - - comparison_kd2.report(file=stdout) - - assert "****** Column Comparison ******" in stdout.getvalue() - assert ( - "Number of columns compared with unexpected differences in some values: 2" - in stdout.getvalue() - ) - assert ( - "Number of columns compared with all values equal but known differences found: 1" - in stdout.getvalue() +@pandas_version +def test_decimal_float_columns_equal_rel(): + df = ps.DataFrame( + [ + {"a": Decimal("1"), "b": 1, "expected": True}, + {"a": Decimal("1.3"), "b": 1.3, "expected": True}, + {"a": Decimal("1.000003"), "b": 1.000003, "expected": True}, + {"a": Decimal("1.000000004"), "b": 1.000000003, "expected": True}, + {"a": Decimal("1.3"), "b": 1.2, "expected": False}, + {"a": np.nan, "b": np.nan, "expected": True}, + {"a": np.nan, "b": 1, "expected": False}, + {"a": Decimal("1"), "b": np.nan, "expected": False}, + ] ) - assert ( - "Number of columns compared with all values completely equal: 0" - in stdout.getvalue() + actual_out = columns_equal(df.a, df.b, abs_tol=0.001) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) -def test_columns_with_unequal_values_show_mismatch_counts(comparison1): - stdout = io.StringIO() - - comparison1.report(file=stdout) - - assert "****** Columns with Unequal Values ******" in stdout.getvalue() - assert re.search( - r"""Base\s*Column\s*Name \s+ Compare\s*Column\s*Name \s+ Base\s*Dtype \s+ Compare\sDtype \s* - \#\sMatches \s* \#\sMismatches \n - -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+""", - stdout.getvalue(), - re.X, - ) - assert re.search( - r"""dollar_amt \s+ dollar_amt \s+ bigint \s+ double \s+ 2 \s+ 2""", - stdout.getvalue(), - re.X, - ) - assert re.search( - r"""float_fld \s+ float_fld \s+ double \s+ double \s+ 1 \s+ 3""", - stdout.getvalue(), - re.X, +@pandas_version +def test_decimal_columns_equal(): + df = ps.DataFrame( + [ + {"a": Decimal("1"), "b": Decimal("1"), "expected": True}, + {"a": Decimal("1.3"), "b": Decimal("1.3"), "expected": True}, + {"a": Decimal("1.000003"), "b": Decimal("1.000003"), "expected": True}, + { + "a": Decimal("1.000000004"), + "b": Decimal("1.000000003"), + "expected": False, + }, + {"a": Decimal("1.3"), "b": Decimal("1.2"), "expected": False}, + {"a": np.nan, "b": np.nan, "expected": True}, + {"a": np.nan, "b": Decimal("1"), "expected": False}, + {"a": Decimal("1"), "b": np.nan, "expected": False}, + ] ) - assert re.search( - r"""name \s+ name \s+ string \s+ string \s+ 3 \s+ 1""", stdout.getvalue(), re.X + actual_out = columns_equal(df.a, df.b) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) -def test_columns_with_different_names_with_unequal_values_show_mismatch_counts( - comparison3, -): - stdout = io.StringIO() +@pandas_version +def test_decimal_columns_equal_rel(): + df = ps.DataFrame( + [ + {"a": Decimal("1"), "b": Decimal("1"), "expected": True}, + {"a": Decimal("1.3"), "b": Decimal("1.3"), "expected": True}, + {"a": Decimal("1.000003"), "b": Decimal("1.000003"), "expected": True}, + { + "a": Decimal("1.000000004"), + "b": Decimal("1.000000003"), + "expected": True, + }, + {"a": 
Decimal("1.3"), "b": Decimal("1.2"), "expected": False}, + {"a": np.nan, "b": np.nan, "expected": True}, + {"a": np.nan, "b": Decimal("1"), "expected": False}, + {"a": Decimal("1"), "b": np.nan, "expected": False}, + ] + ) + actual_out = columns_equal(df.a, df.b, abs_tol=0.001) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False + ) - comparison3.report(file=stdout) - assert "****** Columns with Unequal Values ******" in stdout.getvalue() - assert re.search( - r"""Base\s*Column\s*Name \s+ Compare\s*Column\s*Name \s+ Base\s*Dtype \s+ Compare\sDtype \s* - \#\sMatches \s* \#\sMismatches \n - -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+""", - stdout.getvalue(), - re.X, - ) - assert re.search( - r"""dollar_amt \s+ dollar_amount \s+ bigint \s+ double \s+ 2 \s+ 3""", - stdout.getvalue(), - re.X, +@pandas_version +def test_infinity_and_beyond(): + # https://spark.apache.org/docs/latest/sql-ref-datatypes.html#positivenegative-infinity-semantics + # Positive/negative infinity multiplied by 0 returns NaN. + # Positive infinity sorts lower than NaN and higher than any other values. + # Negative infinity sorts lower than any other values. + df = ps.DataFrame( + [ + {"a": np.inf, "b": np.inf, "expected": True}, + {"a": -np.inf, "b": -np.inf, "expected": True}, + {"a": -np.inf, "b": np.inf, "expected": True}, + {"a": np.inf, "b": -np.inf, "expected": True}, + {"a": 1, "b": 1, "expected": True}, + {"a": 1, "b": 0, "expected": False}, + ] ) - assert re.search( - r"""float_fld \s+ float_field \s+ double \s+ double \s+ 4 \s+ 1""", - stdout.getvalue(), - re.X, + actual_out = columns_equal(df.a, df.b) + expect_out = df["expected"] + assert_series_equal( + expect_out.to_pandas(), actual_out.to_pandas(), check_names=False ) - assert re.search( - r"""name \s+ name \s+ string \s+ string \s+ 4 \s+ 1""", stdout.getvalue(), re.X + + +@pandas_version +def test_compare_df_setter_bad(): + df = ps.DataFrame([{"a": 1, "c": 2}, {"a": 2, "c": 2}]) + with raises(TypeError, match="df1 must be a pyspark.pandas.frame.DataFrame"): + compare = SparkCompare("a", "a", ["a"]) + with raises(ValueError, match="df1 must have all columns from join_columns"): + compare = SparkCompare(df, df.copy(), ["b"]) + df_dupe = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 3}]) + assert ( + SparkCompare(df_dupe, df_dupe.copy(), ["a", "b"]) + .df1.equals(df_dupe) + .all() + .all() ) -def test_rows_only_base_returns_a_dataframe_with_rows_only_in_base(spark, comparison1): - # require schema if contains only 1 row and contain field value as None - schema = StructType( +@pandas_version +def test_compare_df_setter_good(): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}]) + df2 = ps.DataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}]) + compare = SparkCompare(df1, df2, ["a"]) + assert compare.df1.equals(df1).all().all() + assert compare.df2.equals(df2).all().all() + assert compare.join_columns == ["a"] + compare = SparkCompare(df1, df2, ["A", "b"]) + assert compare.df1.equals(df1).all().all() + assert compare.df2.equals(df2).all().all() + assert compare.join_columns == ["a", "b"] + + +@pandas_version +def test_compare_df_setter_different_cases(): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}]) + df2 = ps.DataFrame([{"A": 1, "b": 2}, {"A": 2, "b": 3}]) + compare = SparkCompare(df1, df2, ["a"]) + assert compare.df1.equals(df1).all().all() + assert compare.df2.equals(df2).all().all() + + +@pandas_version +def test_columns_overlap(): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, 
"b": 2}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 3}]) + compare = SparkCompare(df1, df2, ["a"]) + assert compare.df1_unq_columns() == set() + assert compare.df2_unq_columns() == set() + assert compare.intersect_columns() == {"a", "b"} + + +@pandas_version +def test_columns_no_overlap(): + df1 = ps.DataFrame([{"a": 1, "b": 2, "c": "hi"}, {"a": 2, "b": 2, "c": "yo"}]) + df2 = ps.DataFrame([{"a": 1, "b": 2, "d": "oh"}, {"a": 2, "b": 3, "d": "ya"}]) + compare = SparkCompare(df1, df2, ["a"]) + assert compare.df1_unq_columns() == {"c"} + assert compare.df2_unq_columns() == {"d"} + assert compare.intersect_columns() == {"a", "b"} + + +@pandas_version +def test_columns_maintain_order_through_set_operations(): + df1 = ps.DataFrame( [ - StructField("acct", LongType(), True), - StructField("date_fld", DateType(), True), - StructField("dollar_amt", LongType(), True), - StructField("float_fld", DoubleType(), True), - StructField("name", StringType(), True), - ] + (("A"), (0), (1), (2), (3), (4), (-2)), + (("B"), (0), (2), (2), (3), (4), (-3)), + ], + columns=["join", "f", "g", "b", "h", "a", "c"], ) - expected_df = spark.createDataFrame( + df2 = ps.DataFrame( [ - Row( - acct=10000001239, - date_fld=datetime.date(2017, 1, 1), - dollar_amt=1, - float_fld=None, - name="Lucille Bluth", - ) + (("A"), (0), (1), (2), (-1), (4), (-3)), + (("B"), (1), (2), (3), (-1), (4), (-2)), ], - schema, + columns=["join", "e", "h", "b", "a", "g", "d"], ) - assert comparison1.rows_only_base.count() == 1 - assert ( - expected_df.union( - comparison1.rows_only_base.select( - "acct", "date_fld", "dollar_amt", "float_fld", "name" - ) - ) - .distinct() - .count() - == 1 + compare = SparkCompare(df1, df2, ["join"]) + assert list(compare.df1_unq_columns()) == ["f", "c"] + assert list(compare.df2_unq_columns()) == ["e", "d"] + assert list(compare.intersect_columns()) == ["join", "g", "b", "h", "a"] + + +@pandas_version +def test_10k_rows(): + df1 = ps.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"]) + df1.reset_index(inplace=True) + df1.columns = ["a", "b", "c"] + df2 = df1.copy() + df2["b"] = df2["b"] + 0.1 + compare_tol = SparkCompare(df1, df2, ["a"], abs_tol=0.2) + assert compare_tol.matches() + assert len(compare_tol.df1_unq_rows) == 0 + assert len(compare_tol.df2_unq_rows) == 0 + assert compare_tol.intersect_columns() == {"a", "b", "c"} + assert compare_tol.all_columns_match() + assert compare_tol.all_rows_overlap() + assert compare_tol.intersect_rows_match() + + compare_no_tol = SparkCompare(df1, df2, ["a"]) + assert not compare_no_tol.matches() + assert len(compare_no_tol.df1_unq_rows) == 0 + assert len(compare_no_tol.df2_unq_rows) == 0 + assert compare_no_tol.intersect_columns() == {"a", "b", "c"} + assert compare_no_tol.all_columns_match() + assert compare_no_tol.all_rows_overlap() + assert not compare_no_tol.intersect_rows_match() + + +@pandas_version +def test_subset(caplog): + caplog.set_level(logging.DEBUG) + df1 = ps.DataFrame([{"a": 1, "b": 2, "c": "hi"}, {"a": 2, "b": 2, "c": "yo"}]) + df2 = ps.DataFrame([{"a": 1, "c": "hi"}]) + comp = SparkCompare(df1, df2, ["a"]) + assert comp.subset() + assert "Checking equality" in caplog.text + + +@pandas_version +def test_not_subset(caplog): + caplog.set_level(logging.INFO) + df1 = ps.DataFrame([{"a": 1, "b": 2, "c": "hi"}, {"a": 2, "b": 2, "c": "yo"}]) + df2 = ps.DataFrame([{"a": 1, "b": 2, "c": "hi"}, {"a": 2, "b": 2, "c": "great"}]) + comp = SparkCompare(df1, df2, ["a"]) + assert not comp.subset() + assert "c: 1 / 2 (50.00%) match" in 
caplog.text + + +@pandas_version +def test_large_subset(): + df1 = ps.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"]) + df1.reset_index(inplace=True) + df1.columns = ["a", "b", "c"] + df2 = df1[["a", "b"]].head(50).copy() + comp = SparkCompare(df1, df2, ["a"]) + assert not comp.matches() + assert comp.subset() + + +@pandas_version +def test_string_joiner(): + df1 = ps.DataFrame([{"ab": 1, "bc": 2}, {"ab": 2, "bc": 2}]) + df2 = ps.DataFrame([{"ab": 1, "bc": 2}, {"ab": 2, "bc": 2}]) + compare = SparkCompare(df1, df2, "ab") + assert compare.matches() + + +@pandas_version +def test_decimal_with_joins(): + df1 = ps.DataFrame([{"a": Decimal("1"), "b": 2}, {"a": Decimal("2"), "b": 2}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}]) + compare = SparkCompare(df1, df2, "a") + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_decimal_with_nulls(): + df1 = ps.DataFrame([{"a": 1, "b": Decimal("2")}, {"a": 2, "b": Decimal("2")}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}, {"a": 3, "b": 2}]) + compare = SparkCompare(df1, df2, "a") + assert not compare.matches() + assert compare.all_columns_match() + assert not compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_strings_with_joins(): + df1 = ps.DataFrame([{"a": "hi", "b": 2}, {"a": "bye", "b": 2}]) + df2 = ps.DataFrame([{"a": "hi", "b": 2}, {"a": "bye", "b": 2}]) + compare = SparkCompare(df1, df2, "a") + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_temp_column_name(): + df1 = ps.DataFrame([{"a": "hi", "b": 2}, {"a": "bye", "b": 2}]) + df2 = ps.DataFrame( + [{"a": "hi", "b": 2}, {"a": "bye", "b": 2}, {"a": "back fo mo", "b": 3}] ) + actual = temp_column_name(df1, df2) + assert actual == "_temp_0" -def test_rows_only_compare_returns_a_dataframe_with_rows_only_in_compare( - spark, comparison1 -): - expected_df = spark.createDataFrame( - [ - Row( - acct=10000001238, - dollar_amt=1.05, - name="Loose Seal Bluth", - float_fld=111.0, - accnt_purge=True, - ) - ] +@pandas_version +def test_temp_column_name_one_has(): + df1 = ps.DataFrame([{"_temp_0": "hi", "b": 2}, {"_temp_0": "bye", "b": 2}]) + df2 = ps.DataFrame( + [{"a": "hi", "b": 2}, {"a": "bye", "b": 2}, {"a": "back fo mo", "b": 3}] ) - - assert comparison1.rows_only_compare.count() == 1 - assert expected_df.union(comparison1.rows_only_compare).distinct().count() == 1 + actual = temp_column_name(df1, df2) + assert actual == "_temp_1" -def test_rows_both_mismatch_returns_a_dataframe_with_rows_where_variables_mismatched( - spark, comparison1 -): - expected_df = spark.createDataFrame( +@pandas_version +def test_temp_column_name_both_have(): + df1 = ps.DataFrame([{"_temp_0": "hi", "b": 2}, {"_temp_0": "bye", "b": 2}]) + df2 = ps.DataFrame( [ - Row( - accnt_purge=False, - acct=10000001234, - date_fld=datetime.date(2017, 1, 1), - dollar_amt_base=123, - dollar_amt_compare=123.4, - dollar_amt_match=False, - float_fld_base=14530.1555, - float_fld_compare=14530.155, - float_fld_match=False, - name_base="George Maharis", - name_compare="George Michael Bluth", - name_match=False, - ), - Row( - accnt_purge=False, - acct=10000001235, - date_fld=datetime.date(2017, 1, 1), - dollar_amt_base=0, - dollar_amt_compare=0.45, - dollar_amt_match=False, - float_fld_base=1.0, - 
float_fld_compare=None, - float_fld_match=False, - name_base="Michael Bluth", - name_compare="Michael Bluth", - name_match=True, - ), - Row( - accnt_purge=False, - acct=10000001236, - date_fld=datetime.date(2017, 1, 1), - dollar_amt_base=1345, - dollar_amt_compare=1345.0, - dollar_amt_match=True, - float_fld_base=None, - float_fld_compare=1.0, - float_fld_match=False, - name_base="George Bluth", - name_compare="George Bluth", - name_match=True, - ), + {"_temp_0": "hi", "b": 2}, + {"_temp_0": "bye", "b": 2}, + {"a": "back fo mo", "b": 3}, ] ) - - assert comparison1.rows_both_mismatch.count() == 3 - assert expected_df.union(comparison1.rows_both_mismatch).distinct().count() == 3 + actual = temp_column_name(df1, df2) + assert actual == "_temp_1" -def test_rows_both_mismatch_only_includes_rows_with_true_mismatches_when_known_diffs_are_present( - spark, comparison_kd1 -): - expected_df = spark.createDataFrame( +@pandas_version +def test_temp_column_name_both_have(): + df1 = ps.DataFrame([{"_temp_0": "hi", "b": 2}, {"_temp_0": "bye", "b": 2}]) + df2 = ps.DataFrame( [ - Row( - acct=10000001237, - acct_seq=0, - cd_base="0004", - cd_compare=4.0, - cd_match=True, - cd_match_type="KNOWN_DIFFERENCE", - open_dt_base=datetime.date(2017, 5, 4), - open_dt_compare=2017124, - open_dt_match=True, - open_dt_match_type="KNOWN_DIFFERENCE", - stat_cd_base="*2", - stat_cd_compare="V3", - stat_cd_match=False, - stat_cd_match_type="MISMATCH", - ) + {"_temp_0": "hi", "b": 2}, + {"_temp_1": "bye", "b": 2}, + {"a": "back fo mo", "b": 3}, ] ) - assert comparison_kd1.rows_both_mismatch.count() == 1 - assert expected_df.union(comparison_kd1.rows_both_mismatch).distinct().count() == 1 + actual = temp_column_name(df1, df2) + assert actual == "_temp_2" -def test_rows_both_all_returns_a_dataframe_with_all_rows_in_both_dataframes( - spark, comparison1 -): - expected_df = spark.createDataFrame( +@pandas_version +def test_temp_column_name_one_already(): + df1 = ps.DataFrame([{"_temp_1": "hi", "b": 2}, {"_temp_1": "bye", "b": 2}]) + df2 = ps.DataFrame( [ - Row( - accnt_purge=False, - acct=10000001234, - date_fld=datetime.date(2017, 1, 1), - dollar_amt_base=123, - dollar_amt_compare=123.4, - dollar_amt_match=False, - float_fld_base=14530.1555, - float_fld_compare=14530.155, - float_fld_match=False, - name_base="George Maharis", - name_compare="George Michael Bluth", - name_match=False, - ), - Row( - accnt_purge=False, - acct=10000001235, - date_fld=datetime.date(2017, 1, 1), - dollar_amt_base=0, - dollar_amt_compare=0.45, - dollar_amt_match=False, - float_fld_base=1.0, - float_fld_compare=None, - float_fld_match=False, - name_base="Michael Bluth", - name_compare="Michael Bluth", - name_match=True, - ), - Row( - accnt_purge=False, - acct=10000001236, - date_fld=datetime.date(2017, 1, 1), - dollar_amt_base=1345, - dollar_amt_compare=1345.0, - dollar_amt_match=True, - float_fld_base=None, - float_fld_compare=1.0, - float_fld_match=False, - name_base="George Bluth", - name_compare="George Bluth", - name_match=True, - ), - Row( - accnt_purge=False, - acct=10000001237, - date_fld=datetime.date(2017, 1, 1), - dollar_amt_base=123456, - dollar_amt_compare=123456.0, - dollar_amt_match=True, - float_fld_base=345.12, - float_fld_compare=345.12, - float_fld_match=True, - name_base="Bob Loblaw", - name_compare="Bob Loblaw", - name_match=True, - ), + {"_temp_1": "hi", "b": 2}, + {"_temp_1": "bye", "b": 2}, + {"a": "back fo mo", "b": 3}, ] ) + actual = temp_column_name(df1, df2) + assert actual == "_temp_0" - assert 
comparison1.rows_both_all.count() == 4 - assert expected_df.union(comparison1.rows_both_all).distinct().count() == 4 +### Duplicate testing! +@pandas_version +def test_simple_dupes_one_field(): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) + compare = SparkCompare(df1, df2, join_columns=["a"]) + assert compare.matches() + # Just render the report to make sure it renders. + t = compare.report() -def test_rows_both_all_shows_known_diffs_flag_and_known_diffs_count_as_matches( - spark, comparison_kd1 -): - expected_df = spark.createDataFrame( - [ - Row( - acct=10000001234, - acct_seq=0, - cd_base="0001", - cd_compare=1.0, - cd_match=True, - cd_match_type="KNOWN_DIFFERENCE", - open_dt_base=datetime.date(2017, 5, 1), - open_dt_compare=2017121, - open_dt_match=True, - open_dt_match_type="KNOWN_DIFFERENCE", - stat_cd_base="*2", - stat_cd_compare=None, - stat_cd_match=True, - stat_cd_match_type="KNOWN_DIFFERENCE", - ), - Row( - acct=10000001235, - acct_seq=0, - cd_base="0002", - cd_compare=2.0, - cd_match=True, - cd_match_type="KNOWN_DIFFERENCE", - open_dt_base=datetime.date(2017, 5, 2), - open_dt_compare=2017122, - open_dt_match=True, - open_dt_match_type="KNOWN_DIFFERENCE", - stat_cd_base="V1", - stat_cd_compare="V1", - stat_cd_match=True, - stat_cd_match_type="MATCH", - ), - Row( - acct=10000001236, - acct_seq=0, - cd_base="0003", - cd_compare=3.0, - cd_match=True, - cd_match_type="KNOWN_DIFFERENCE", - open_dt_base=datetime.date(2017, 5, 3), - open_dt_compare=2017123, - open_dt_match=True, - open_dt_match_type="KNOWN_DIFFERENCE", - stat_cd_base="V2", - stat_cd_compare="V2", - stat_cd_match=True, - stat_cd_match_type="MATCH", - ), - Row( - acct=10000001237, - acct_seq=0, - cd_base="0004", - cd_compare=4.0, - cd_match=True, - cd_match_type="KNOWN_DIFFERENCE", - open_dt_base=datetime.date(2017, 5, 4), - open_dt_compare=2017124, - open_dt_match=True, - open_dt_match_type="KNOWN_DIFFERENCE", - stat_cd_base="*2", - stat_cd_compare="V3", - stat_cd_match=False, - stat_cd_match_type="MISMATCH", - ), - Row( - acct=10000001238, - acct_seq=0, - cd_base="0005", - cd_compare=5.0, - cd_match=True, - cd_match_type="KNOWN_DIFFERENCE", - open_dt_base=datetime.date(2017, 5, 5), - open_dt_compare=2017125, - open_dt_match=True, - open_dt_match_type="KNOWN_DIFFERENCE", - stat_cd_base="*2", - stat_cd_compare=None, - stat_cd_match=True, - stat_cd_match_type="KNOWN_DIFFERENCE", - ), - ] - ) - assert comparison_kd1.rows_both_all.count() == 5 - assert expected_df.union(comparison_kd1.rows_both_all).distinct().count() == 5 +@pandas_version +def test_simple_dupes_two_fields(): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 2}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 2}]) + compare = SparkCompare(df1, df2, join_columns=["a", "b"]) + assert compare.matches() + # Just render the report to make sure it renders. 
+ t = compare.report() -def test_rows_both_all_returns_a_dataframe_with_all_rows_in_identical_dataframes( - spark, comparison2 -): - expected_df = spark.createDataFrame( - [ - Row( - acct=10000001234, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=123, - dollar_amt_compare=123, - dollar_amt_match=True, - float_fld_base=14530.1555, - float_fld_compare=14530.1555, - float_fld_match=True, - name_base="George Maharis", - name_compare="George Maharis", - name_match=True, - ), - Row( - acct=10000001235, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=0, - dollar_amt_compare=0, - dollar_amt_match=True, - float_fld_base=1.0, - float_fld_compare=1.0, - float_fld_match=True, - name_base="Michael Bluth", - name_compare="Michael Bluth", - name_match=True, - ), - Row( - acct=10000001236, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=1345, - dollar_amt_compare=1345, - dollar_amt_match=True, - float_fld_base=None, - float_fld_compare=None, - float_fld_match=True, - name_base="George Bluth", - name_compare="George Bluth", - name_match=True, - ), - Row( - acct=10000001237, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=123456, - dollar_amt_compare=123456, - dollar_amt_match=True, - float_fld_base=345.12, - float_fld_compare=345.12, - float_fld_match=True, - name_base="Bob Loblaw", - name_compare="Bob Loblaw", - name_match=True, - ), - Row( - acct=10000001239, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=1, - dollar_amt_compare=1, - dollar_amt_match=True, - float_fld_base=None, - float_fld_compare=None, - float_fld_match=True, - name_base="Lucille Bluth", - name_compare="Lucille Bluth", - name_match=True, - ), - ] +@pandas_version +def test_simple_dupes_one_field_two_vals_1(): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}]) + compare = SparkCompare(df1, df2, join_columns=["a"]) + assert compare.matches() + # Just render the report to make sure it renders. + t = compare.report() + + +@pandas_version +def test_simple_dupes_one_field_two_vals_2(): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 0}]) + compare = SparkCompare(df1, df2, join_columns=["a"]) + assert not compare.matches() + assert len(compare.df1_unq_rows) == 1 + assert len(compare.df2_unq_rows) == 1 + assert len(compare.intersect_rows) == 1 + # Just render the report to make sure it renders. + t = compare.report() + + +@pandas_version +def test_simple_dupes_one_field_three_to_two_vals(): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}, {"a": 1, "b": 0}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}]) + compare = SparkCompare(df1, df2, join_columns=["a"]) + assert not compare.matches() + assert len(compare.df1_unq_rows) == 1 + assert len(compare.df2_unq_rows) == 0 + assert len(compare.intersect_rows) == 2 + # Just render the report to make sure it renders. 
+ t = compare.report() + + assert "(First 1 Columns)" in compare.report(column_count=1) + assert "(First 2 Columns)" in compare.report(column_count=2) + + +@pandas_version +def test_dupes_from_real_data(): + data = """acct_id,acct_sfx_num,trxn_post_dt,trxn_post_seq_num,trxn_amt,trxn_dt,debit_cr_cd,cash_adv_trxn_comn_cntry_cd,mrch_catg_cd,mrch_pstl_cd,visa_mail_phn_cd,visa_rqstd_pmt_svc_cd,mc_pmt_facilitator_idn_num +100,0,2017-06-17,1537019,30.64,2017-06-15,D,CAN,5812,M2N5P5,,,0.0 +200,0,2017-06-24,1022477,485.32,2017-06-22,D,USA,4511,7114,7.0,1, +100,0,2017-06-17,1537039,2.73,2017-06-16,D,CAN,5812,M4J 1M9,,,0.0 +200,0,2017-06-29,1049223,22.41,2017-06-28,D,USA,4789,21211,,A, +100,0,2017-06-17,1537029,34.05,2017-06-16,D,CAN,5812,M4E 2C7,,,0.0 +200,0,2017-06-29,1049213,9.12,2017-06-28,D,CAN,5814,0,,, +100,0,2017-06-19,1646426,165.21,2017-06-17,D,CAN,5411,M4M 3H9,,,0.0 +200,0,2017-06-30,1233082,28.54,2017-06-29,D,USA,4121,94105,7.0,G, +100,0,2017-06-19,1646436,17.87,2017-06-18,D,CAN,5812,M4J 1M9,,,0.0 +200,0,2017-06-30,1233092,24.39,2017-06-29,D,USA,4121,94105,7.0,G, +100,0,2017-06-19,1646446,5.27,2017-06-17,D,CAN,5200,M4M 3G6,,,0.0 +200,0,2017-06-30,1233102,61.8,2017-06-30,D,CAN,4121,0,,, +100,0,2017-06-20,1607573,41.99,2017-06-19,D,CAN,5661,M4C1M9,,,0.0 +200,0,2017-07-01,1009403,2.31,2017-06-29,D,USA,5814,22102,,F, +100,0,2017-06-20,1607553,86.88,2017-06-19,D,CAN,4812,H2R3A8,,,0.0 +200,0,2017-07-01,1009423,5.5,2017-06-29,D,USA,5812,2903,,F, +100,0,2017-06-20,1607563,25.17,2017-06-19,D,CAN,5641,M4C 1M9,,,0.0 +200,0,2017-07-01,1009433,214.12,2017-06-29,D,USA,3640,20170,,A, +100,0,2017-06-20,1607593,1.67,2017-06-19,D,CAN,5814,M2N 6L7,,,0.0 +200,0,2017-07-01,1009393,2.01,2017-06-29,D,USA,5814,22102,,F,""" + df1 = ps.from_pandas(pd.read_csv(StringIO(data), sep=",")) + df2 = df1.copy() + compare_acct = SparkCompare(df1, df2, join_columns=["acct_id"]) + assert compare_acct.matches() + compare_unq = SparkCompare( + df1, + df2, + join_columns=["acct_id", "acct_sfx_num", "trxn_post_dt", "trxn_post_seq_num"], ) + assert compare_unq.matches() + # Just render the report to make sure it renders. 
+ t = compare_acct.report() + r = compare_unq.report() + + +@pandas_version +def test_strings_with_joins_with_ignore_spaces(): + df1 = ps.DataFrame([{"a": "hi", "b": " A"}, {"a": "bye", "b": "A"}]) + df2 = ps.DataFrame([{"a": "hi", "b": "A"}, {"a": "bye", "b": "A "}]) + compare = SparkCompare(df1, df2, "a", ignore_spaces=False) + assert not compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert not compare.intersect_rows_match() + + compare = SparkCompare(df1, df2, "a", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_strings_with_joins_with_ignore_case(): + df1 = ps.DataFrame([{"a": "hi", "b": "a"}, {"a": "bye", "b": "A"}]) + df2 = ps.DataFrame([{"a": "hi", "b": "A"}, {"a": "bye", "b": "a"}]) + compare = SparkCompare(df1, df2, "a", ignore_case=False) + assert not compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert not compare.intersect_rows_match() + + compare = SparkCompare(df1, df2, "a", ignore_case=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_decimal_with_joins_with_ignore_spaces(): + df1 = ps.DataFrame([{"a": 1, "b": " A"}, {"a": 2, "b": "A"}]) + df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "A "}]) + compare = SparkCompare(df1, df2, "a", ignore_spaces=False) + assert not compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert not compare.intersect_rows_match() + + compare = SparkCompare(df1, df2, "a", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_decimal_with_joins_with_ignore_case(): + df1 = ps.DataFrame([{"a": 1, "b": "a"}, {"a": 2, "b": "A"}]) + df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "a"}]) + compare = SparkCompare(df1, df2, "a", ignore_case=False) + assert not compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert not compare.intersect_rows_match() + + compare = SparkCompare(df1, df2, "a", ignore_case=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_joins_with_ignore_spaces(): + df1 = ps.DataFrame([{"a": 1, "b": " A"}, {"a": 2, "b": "A"}]) + df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "A "}]) + + compare = SparkCompare(df1, df2, "a", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_joins_with_ignore_case(): + df1 = ps.DataFrame([{"a": 1, "b": "a"}, {"a": 2, "b": "A"}]) + df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "a"}]) + + compare = SparkCompare(df1, df2, "a", ignore_case=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + + +@pandas_version +def test_strings_with_ignore_spaces_and_join_columns(): + df1 = ps.DataFrame([{"a": "hi", "b": "A"}, {"a": "bye", "b": "A"}]) + df2 = ps.DataFrame([{"a": " hi ", "b": "A"}, {"a": " bye ", "b": "A"}]) + compare = SparkCompare(df1, df2, "a", 
ignore_spaces=False) + assert not compare.matches() + assert compare.all_columns_match() + assert not compare.all_rows_overlap() + assert compare.count_matching_rows() == 0 + + compare = SparkCompare(df1, df2, "a", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + assert compare.count_matching_rows() == 2 + + +@pandas_version +def test_integers_with_ignore_spaces_and_join_columns(): + df1 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "A"}]) + df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "A"}]) + compare = SparkCompare(df1, df2, "a", ignore_spaces=False) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + assert compare.count_matching_rows() == 2 + + compare = SparkCompare(df1, df2, "a", ignore_spaces=True) + assert compare.matches() + assert compare.all_columns_match() + assert compare.all_rows_overlap() + assert compare.intersect_rows_match() + assert compare.count_matching_rows() == 2 + + +@pandas_version +def test_sample_mismatch(): + data1 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ - assert comparison2.rows_both_all.count() == 5 - assert expected_df.union(comparison2.rows_both_all).distinct().count() == 5 + data2 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.4,George Michael Bluth,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Robert Loblaw,345.12, + 10000001238,1.05,Loose Seal Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=",")) + df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=",")) -def test_rows_both_all_returns_all_rows_in_both_dataframes_for_differently_named_columns( - spark, comparison3 -): - expected_df = spark.createDataFrame( - [ - Row( - accnt_purge=False, - acct=10000001234, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=123, - dollar_amt_compare=123.4, - dollar_amt_match=False, - float_fld_base=14530.1555, - float_fld_compare=14530.155, - float_fld_match=False, - name_base="George Maharis", - name_compare="George Michael Bluth", - name_match=False, - ), - Row( - accnt_purge=False, - acct=10000001235, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=0, - dollar_amt_compare=0.45, - dollar_amt_match=False, - float_fld_base=1.0, - float_fld_compare=1.0, - float_fld_match=True, - name_base="Michael Bluth", - name_compare="Michael Bluth", - name_match=True, - ), - Row( - accnt_purge=False, - acct=10000001236, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=1345, - dollar_amt_compare=1345.0, - dollar_amt_match=True, - float_fld_base=None, - float_fld_compare=None, - float_fld_match=True, - name_base="George Bluth", - name_compare="George Bluth", - name_match=True, - ), - Row( - accnt_purge=False, - acct=10000001237, - date_fld_base=datetime.date(2017, 1, 1), 
- date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=123456, - dollar_amt_compare=123456.0, - dollar_amt_match=True, - float_fld_base=345.12, - float_fld_compare=345.12, - float_fld_match=True, - name_base="Bob Loblaw", - name_compare="Bob Loblaw", - name_match=True, - ), - Row( - accnt_purge=True, - acct=10000001239, - date_fld_base=datetime.date(2017, 1, 1), - date_fld_compare=datetime.date(2017, 1, 1), - date_fld_match=True, - dollar_amt_base=1, - dollar_amt_compare=1.05, - dollar_amt_match=False, - float_fld_base=None, - float_fld_compare=None, - float_fld_match=True, - name_base="Lucille Bluth", - name_compare="Lucille Bluth", - name_match=True, - ), - ] - ) + compare = SparkCompare(df1, df2, "acct_id") - assert comparison3.rows_both_all.count() == 5 - assert expected_df.union(comparison3.rows_both_all).distinct().count() == 5 - - -def test_columns_with_unequal_values_text_is_aligned(comparison4): - stdout = io.StringIO() - - comparison4.report(file=stdout) - stdout.seek(0) # Back up to the beginning of the stream - - text_alignment_validator( - report=stdout, - section_start="****** Columns with Unequal Values ******", - section_end="\n", - left_indices=(1, 2, 3, 4), - right_indices=(5, 6), - column_regexes=[ - r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+ - (\#\sMatches) \s+ (\#\sMismatches)""", - r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""", - r"""(dollar_amt) \s+ (dollar_amt) \s+ (bigint) \s+ (double) \s+ (2) \s+ (2)""", - r"""(float_fld) \s+ (float_fld) \s+ (double) \s+ (double) \s+ (1) \s+ (3)""", - r"""(super_duper_big_long_name) \s+ (name) \s+ (string) \s+ (string) \s+ (3) \s+ (1)\s*""", - ], - ) + output = compare.sample_mismatch(column="name", sample_count=1) + assert output.shape[0] == 1 + assert (output.name_df1 != output.name_df2).all() + output = compare.sample_mismatch(column="name", sample_count=2) + assert output.shape[0] == 2 + assert (output.name_df1 != output.name_df2).all() -def test_columns_with_unequal_values_text_is_aligned_with_known_differences( - comparison_kd1, -): - stdout = io.StringIO() - - comparison_kd1.report(file=stdout) - stdout.seek(0) # Back up to the beginning of the stream - - text_alignment_validator( - report=stdout, - section_start="****** Columns with Unequal Values ******", - section_end="\n", - left_indices=(1, 2, 3, 4), - right_indices=(5, 6, 7), - column_regexes=[ - r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+ - (\#\sMatches) \s+ (\#\sKnown\sDiffs) \s+ (\#\sMismatches)""", - r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""", - r"""(stat_cd) \s+ (STATC) \s+ (string) \s+ (string) \s+ (2) \s+ (2) \s+ (1)""", - r"""(open_dt) \s+ (ACCOUNT_OPEN) \s+ (date) \s+ (bigint) \s+ (0) \s+ (5) \s+ (0)""", - r"""(cd) \s+ (CODE) \s+ (string) \s+ (double) \s+ (0) \s+ (5) \s+ (0)\s*""", - ], - ) + output = compare.sample_mismatch(column="name", sample_count=3) + assert output.shape[0] == 2 + assert (output.name_df1 != output.name_df2).all() -def test_columns_with_unequal_values_text_is_aligned_with_custom_known_differences( - comparison_kd2, -): - stdout = io.StringIO() - - comparison_kd2.report(file=stdout) - stdout.seek(0) # Back up to the beginning of the stream - - text_alignment_validator( - report=stdout, - section_start="****** Columns with Unequal Values ******", - section_end="\n", - left_indices=(1, 2, 3, 4), - right_indices=(5, 6, 7), - column_regexes=[ - r"""(Base\sColumn\sName) \s+ 
(Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+ - (\#\sMatches) \s+ (\#\sKnown\sDiffs) \s+ (\#\sMismatches)""", - r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""", - r"""(stat_cd) \s+ (STATC) \s+ (string) \s+ (string) \s+ (2) \s+ (2) \s+ (1)""", - r"""(open_dt) \s+ (ACCOUNT_OPEN) \s+ (date) \s+ (bigint) \s+ (0) \s+ (0) \s+ (5)""", - r"""(cd) \s+ (CODE) \s+ (string) \s+ (double) \s+ (0) \s+ (5) \s+ (0)\s*""", - ], - ) +@pandas_version +def test_all_mismatch_not_ignore_matching_cols_no_cols_matching(): + data1 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + data2 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.4,George Michael Bluth,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Robert Loblaw,345.12, + 10000001238,1.05,Loose Seal Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=",")) + df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=",")) + compare = SparkCompare(df1, df2, "acct_id") + + output = compare.all_mismatch() + assert output.shape[0] == 4 + assert output.shape[1] == 10 + + assert (output.name_df1 != output.name_df2).values.sum() == 2 + assert (~(output.name_df1 != output.name_df2)).values.sum() == 2 + + assert (output.dollar_amt_df1 != output.dollar_amt_df2).values.sum() == 1 + assert (~(output.dollar_amt_df1 != output.dollar_amt_df2)).values.sum() == 3 + + assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3 + assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1 + + assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4 + assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0 + + +@pandas_version +def test_all_mismatch_not_ignore_matching_cols_some_cols_matching(): + # Columns dollar_amt and name are matching + data1 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + + data2 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Bob Loblaw,345.12, + 10000001238,1.05,Lucille Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=",")) + df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=",")) + compare = SparkCompare(df1, df2, "acct_id") + + output = compare.all_mismatch() + assert output.shape[0] == 4 + assert output.shape[1] == 10 + + assert (output.name_df1 != output.name_df2).values.sum() == 0 + assert (~(output.name_df1 != output.name_df2)).values.sum() == 4 + + assert (output.dollar_amt_df1 != output.dollar_amt_df2).values.sum() == 0 + assert (~(output.dollar_amt_df1 != output.dollar_amt_df2)).values.sum() == 4 + + assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3 + assert 
(~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1 + + assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4 + assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0 + + +@pandas_version +def test_all_mismatch_ignore_matching_cols_some_cols_matching_diff_rows(): + # Case where there are rows on either dataset which don't match up. + # Columns dollar_amt and name are matching + data1 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + 10000001241,1111.05,Lucille Bluth, + """ -def test_columns_with_unequal_values_text_is_aligned_for_decimals(comparison_decimal): - stdout = io.StringIO() + data2 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Bob Loblaw,345.12, + 10000001238,1.05,Lucille Bluth,111, + """ + df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=",")) + df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=",")) + compare = SparkCompare(df1, df2, "acct_id") - comparison_decimal.report(file=stdout) - stdout.seek(0) # Back up to the beginning of the stream + output = compare.all_mismatch(ignore_matching_cols=True) - text_alignment_validator( - report=stdout, - section_start="****** Columns with Unequal Values ******", - section_end="\n", - left_indices=(1, 2, 3, 4), - right_indices=(5, 6), - column_regexes=[ - r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+ - (\#\sMatches) \s+ (\#\sMismatches)""", - r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""", - r"""(dollar_amt) \s+ (dollar_amt) \s+ (decimal\(8,2\)) \s+ (double) \s+ (1) \s+ (1)""", - ], - ) + assert output.shape[0] == 4 + assert output.shape[1] == 6 + assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3 + assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1 -def test_schema_differences_text_is_aligned(comparison4): - stdout = io.StringIO() + assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4 + assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0 - comparison4.report(file=stdout) - comparison4.report() - stdout.seek(0) # Back up to the beginning of the stream + assert not ("name_df1" in output and "name_df2" in output) + assert not ("dollar_amt_df1" in output and "dollar_amt_df1" in output) - text_alignment_validator( - report=stdout, - section_start="****** Schema Differences ******", - section_end="\n", - left_indices=(1, 2, 3, 4), - right_indices=(), - column_regexes=[ - r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype)""", - r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+)""", - r"""(dollar_amt) \s+ (dollar_amt) \s+ (bigint) \s+ (double)""", - ], - ) +@pandas_version +def test_all_mismatch_ignore_matching_cols_some_calls_matching(): + # Columns dollar_amt and name are matching + data1 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 
10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ -def test_schema_differences_text_is_aligned_for_decimals(comparison_decimal): - stdout = io.StringIO() + data2 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Bob Loblaw,345.12, + 10000001238,1.05,Lucille Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=",")) + df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=",")) + compare = SparkCompare(df1, df2, "acct_id") - comparison_decimal.report(file=stdout) - stdout.seek(0) # Back up to the beginning of the stream + output = compare.all_mismatch(ignore_matching_cols=True) - text_alignment_validator( - report=stdout, - section_start="****** Schema Differences ******", - section_end="\n", - left_indices=(1, 2, 3, 4), - right_indices=(), - column_regexes=[ - r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype)""", - r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+)""", - r"""(dollar_amt) \s+ (dollar_amt) \s+ (decimal\(8,2\)) \s+ (double)""", - ], - ) + assert output.shape[0] == 4 + assert output.shape[1] == 6 + assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3 + assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1 -def test_base_only_columns_text_is_aligned(comparison4): - stdout = io.StringIO() + assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4 + assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0 - comparison4.report(file=stdout) - stdout.seek(0) # Back up to the beginning of the stream + assert not ("name_df1" in output and "name_df2" in output) + assert not ("dollar_amt_df1" in output and "dollar_amt_df1" in output) - text_alignment_validator( - report=stdout, - section_start="****** Columns In Base Only ******", - section_end="\n", - left_indices=(1, 2), - right_indices=(), - column_regexes=[ - r"""(Column\sName) \s+ (Dtype)""", - r"""(-+) \s+ (-+)""", - r"""(date_fld) \s+ (date)""", - ], - ) +@pandas_version +def test_all_mismatch_ignore_matching_cols_no_cols_matching(): + data1 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.45,George Maharis,14530.1555,2017-01-01 + 10000001235,0.45,Michael Bluth,1,2017-01-01 + 10000001236,1345,George Bluth,,2017-01-01 + 10000001237,123456,Bob Loblaw,345.12,2017-01-01 + 10000001239,1.05,Lucille Bluth,,2017-01-01 + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ -def test_compare_only_columns_text_is_aligned(comparison4): - stdout = io.StringIO() + data2 = """acct_id,dollar_amt,name,float_fld,date_fld + 10000001234,123.4,George Michael Bluth,14530.155, + 10000001235,0.45,Michael Bluth,, + 10000001236,1345,George Bluth,1, + 10000001237,123456,Robert Loblaw,345.12, + 10000001238,1.05,Loose Seal Bluth,111, + 10000001240,123.45,George Maharis,14530.1555,2017-01-02 + """ + df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=",")) + df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=",")) + compare = SparkCompare(df1, df2, "acct_id") + + output = compare.all_mismatch() + assert output.shape[0] == 4 + assert output.shape[1] == 10 + + assert (output.name_df1 != output.name_df2).values.sum() == 2 + assert (~(output.name_df1 != output.name_df2)).values.sum() == 2 + + assert (output.dollar_amt_df1 != output.dollar_amt_df2).values.sum() == 1 + assert (~(output.dollar_amt_df1 != 
output.dollar_amt_df2)).values.sum() == 3 + + assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3 + assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1 + + assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4 + assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0 + + +@pandas_version +@pytest.mark.parametrize( + "column,expected", + [ + ("base", 0), + ("floats", 0.2), + ("decimals", 0.1), + ("null_floats", 0.1), + ("strings", 0.1), + ("mixed_strings", 1), + ("infinity", np.inf), + ], +) +def test_calculate_max_diff(column, expected): + MAX_DIFF_DF = ps.DataFrame( + { + "base": [1, 1, 1, 1, 1], + "floats": [1.1, 1.1, 1.1, 1.2, 0.9], + "decimals": [ + Decimal("1.1"), + Decimal("1.1"), + Decimal("1.1"), + Decimal("1.1"), + Decimal("1.1"), + ], + "null_floats": [np.nan, 1.1, 1, 1, 1], + "strings": ["1", "1", "1", "1.1", "1"], + "mixed_strings": ["1", "1", "1", "2", "some string"], + "infinity": [1, 1, 1, 1, np.inf], + } + ) + assert np.isclose( + calculate_max_diff(MAX_DIFF_DF["base"], MAX_DIFF_DF[column]), expected + ) - comparison4.report(file=stdout) - stdout.seek(0) # Back up to the beginning of the stream - text_alignment_validator( - report=stdout, - section_start="****** Columns In Compare Only ******", - section_end="\n", - left_indices=(1, 2), - right_indices=(), - column_regexes=[ - r"""(Column\sName) \s+ (Dtype)""", - r"""(-+) \s+ (-+)""", - r"""(accnt_purge) \s+ (boolean)""", - ], +@pandas_version +def test_dupes_with_nulls_strings(): + df1 = ps.DataFrame( + { + "fld_1": [1, 2, 2, 3, 3, 4, 5, 5], + "fld_2": ["A", np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + "fld_3": [1, 2, 2, 3, 3, 4, 5, 5], + } ) + df2 = ps.DataFrame( + { + "fld_1": [1, 2, 3, 4, 5], + "fld_2": ["A", np.nan, np.nan, np.nan, np.nan], + "fld_3": [1, 2, 3, 4, 5], + } + ) + comp = SparkCompare(df1, df2, join_columns=["fld_1", "fld_2"]) + assert comp.subset() + + +@pandas_version +def test_dupes_with_nulls_ints(): + df1 = ps.DataFrame( + { + "fld_1": [1, 2, 2, 3, 3, 4, 5, 5], + "fld_2": [1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + "fld_3": [1, 2, 2, 3, 3, 4, 5, 5], + } + ) + df2 = ps.DataFrame( + { + "fld_1": [1, 2, 3, 4, 5], + "fld_2": [1, np.nan, np.nan, np.nan, np.nan], + "fld_3": [1, 2, 3, 4, 5], + } + ) + comp = SparkCompare(df1, df2, join_columns=["fld_1", "fld_2"]) + assert comp.subset() + + +@pandas_version +@pytest.mark.parametrize( + "dataframe,expected", + [ + (ps.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}), ps.Series([0, 0, 0])), + ( + ps.DataFrame({"a": ["a", "a", "DATACOMPY_NULL"], "b": [1, 1, 2]}), + ps.Series([0, 1, 0]), + ), + (ps.DataFrame({"a": [-999, 2, 3], "b": [1, 2, 3]}), ps.Series([0, 0, 0])), + ( + ps.DataFrame({"a": [1, np.nan, np.nan], "b": [1, 2, 2]}), + ps.Series([0, 0, 1]), + ), + ( + ps.DataFrame({"a": ["1", np.nan, np.nan], "b": ["1", "2", "2"]}), + ps.Series([0, 0, 1]), + ), + ( + ps.DataFrame( + {"a": [datetime(2018, 1, 1), np.nan, np.nan], "b": ["1", "2", "2"]} + ), + ps.Series([0, 0, 1]), + ), + ], +) +def test_generate_id_within_group(dataframe, expected): + assert (generate_id_within_group(dataframe, ["a", "b"]) == expected).all() + + +@pandas_version +def test_lower(): + """This function tests the toggle to use lower case for column names or not""" + # should match + df1 = ps.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]}) + df2 = ps.DataFrame({"a": [1, 2, 3], "B": [0, 1, 2]}) + compare = SparkCompare(df1, df2, join_columns=["a"]) + assert compare.matches() + # should not 
match + df1 = ps.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]}) + df2 = ps.DataFrame({"a": [1, 2, 3], "B": [0, 1, 2]}) + compare = SparkCompare(df1, df2, join_columns=["a"], cast_column_names_lower=False) + assert not compare.matches() + + # test join column + # should match + df1 = ps.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]}) + df2 = ps.DataFrame({"A": [1, 2, 3], "B": [0, 1, 2]}) + compare = SparkCompare(df1, df2, join_columns=["a"]) + assert compare.matches() + # should fail because "a" is not found in df2 + df1 = ps.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]}) + df2 = ps.DataFrame({"A": [1, 2, 3], "B": [0, 1, 2]}) + expected_message = "df2 must have all columns from join_columns" + with raises(ValueError, match=expected_message): + compare = SparkCompare( + df1, df2, join_columns=["a"], cast_column_names_lower=False + ) -def text_alignment_validator( - report, section_start, section_end, left_indices, right_indices, column_regexes -): - r"""Check to make sure that report output columns are vertically aligned. - - Parameters - ---------- - report: An iterable returning lines of report output to be validated. - section_start: A string that represents the beginning of the section to be validated. - section_end: A string that represents the end of the section to be validated. - left_indices: The match group indexes (starting with 1) that should be left-aligned - in the output column. - right_indices: The match group indexes (starting with 1) that should be right-aligned - in the output column. - column_regexes: A list of regular expressions representing the expected output, with - each column enclosed with parentheses to return a match. The regular expression will - use the "X" flag, so it may contain whitespace, and any whitespace to be matched - should be explicitly given with \s. The first line will represent the alignments - that are expected in the following lines. The number of match groups should cover - all of the indices given in left/right_indices. - - Runs assertions for every match group specified by left/right_indices to ensure that - all lines past the first are either left- or right-aligned with the same match group - on the first line. - """ - - at_column_section = False - processed_first_line = False - match_positions = [None] * (len(left_indices + right_indices) + 1) - - for line in report: - if at_column_section: - if line == section_end: # Detect end of section and stop - break - - if ( - not processed_first_line - ): # First line in section - capture text start/end positions - matches = re.search(column_regexes[0], line, re.X) - assert matches is not None # Make sure we found at least this... 
- - for n in left_indices: - match_positions[n] = matches.start(n) - for n in right_indices: - match_positions[n] = matches.end(n) - processed_first_line = True - else: # Match the stuff after the header text - match = None - for regex in column_regexes[1:]: - match = re.search(regex, line, re.X) - if match: - break - - if not match: - raise AssertionError(f'Did not find a match for line: "{line}"') - - for n in left_indices: - assert match_positions[n] == match.start(n) - for n in right_indices: - assert match_positions[n] == match.end(n) - - if not at_column_section and section_start in line: - at_column_section = True - - -def test_unicode_columns(spark_session): - df1 = spark_session.createDataFrame( - [ - (1, "foo", "test"), - (2, "bar", "test"), - ], - ["id", "例", "予測対象日"], +@pandas_version +def test_integer_column_names(): + """This function tests that integer column names would also work""" + df1 = ps.DataFrame({1: [1, 2, 3], 2: [0, 1, 2]}) + df2 = ps.DataFrame({1: [1, 2, 3], 2: [0, 1, 2]}) + compare = SparkCompare(df1, df2, join_columns=[1]) + assert compare.matches() + + +@pandas_version +@mock.patch("datacompy.spark.render") +def test_save_html(mock_render): + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) + compare = SparkCompare(df1, df2, join_columns=["a"]) + + m = mock.mock_open() + with mock.patch("datacompy.spark.open", m, create=True): + # assert without HTML call + compare.report() + assert mock_render.call_count == 4 + m.assert_not_called() + + mock_render.reset_mock() + m = mock.mock_open() + with mock.patch("datacompy.spark.open", m, create=True): + # assert with HTML call + compare.report(html_file="test.html") + assert mock_render.call_count == 4 + m.assert_called_with("test.html", "w") + + +def test_pandas_version(): + expected_message = "It seems like you are running Pandas 2+. Please note that Pandas 2+ will only be supported in Spark 4+. See: https://issues.apache.org/jira/browse/SPARK-44101. If you need to use Spark DataFrame with Pandas 2+ then consider using Fugue otherwise downgrade to Pandas 1.5.3" + df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) + df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) + with mock.patch("pandas.__version__", "2.0.0"): + with raises(Exception, match=re.escape(expected_message)): + SparkCompare(df1, df2, join_columns=["a"]) + + with mock.patch("pandas.__version__", "1.5.3"): + SparkCompare(df1, df2, join_columns=["a"]) + + +@pandas_version +def test_unicode_columns(): + df1 = ps.DataFrame( + [{"a": 1, "例": 2, "予測対象日": "test"}, {"a": 1, "例": 3, "予測対象日": "test"}] ) - df2 = spark_session.createDataFrame( - [ - (1, "foo", "test"), - (2, "baz", "test"), - ], - ["id", "例", "予測対象日"], + df2 = ps.DataFrame( + [{"a": 1, "例": 2, "予測対象日": "test"}, {"a": 1, "例": 3, "予測対象日": "test"}] ) - compare = SparkCompare(spark_session, df1, df2, join_columns=["例"]) + compare = SparkCompare(df1, df2, join_columns=["例"]) + assert compare.matches() # Just render the report to make sure it renders. 
- compare.report() + t = compare.report() From eaacf32d96fc50f49c4c44a89b55f93410161682 Mon Sep 17 00:00:00 2001 From: Faisal Date: Mon, 29 Apr 2024 12:06:12 -0300 Subject: [PATCH 2/5] Fix suffix on all mismatches reports (#293) * Fix all_mismatch report attribute to use defined df*_name by users * Add contributor credit * adding in additional changes for df names * adding in additional changes for df names --------- Co-authored-by: enzorooo --- CONTRIBUTORS | 3 +- datacompy/core.py | 44 ++++++++++++++++--------- datacompy/polars.py | 40 ++++++++++++++--------- datacompy/spark.py | 77 ++++++++++++++++++++++++++++++-------------- tests/test_polars.py | 2 +- 5 files changed, 108 insertions(+), 58 deletions(-) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 185f3b4f..e59e6454 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -3,4 +3,5 @@ - Usman Azhar - Mark Zhou - Ian Whitestone -- Faisal Dosani \ No newline at end of file +- Faisal Dosani +- Lorenzo Mercado \ No newline at end of file diff --git a/datacompy/core.py b/datacompy/core.py index a1730768..042dffb4 100644 --- a/datacompy/core.py +++ b/datacompy/core.py @@ -20,6 +20,7 @@ PROC COMPARE in SAS - i.e. human-readable reporting on the difference between two dataframes. """ + import logging import os from typing import Any, Dict, List, Optional, Union, cast @@ -283,7 +284,11 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: self.df2[column] = self.df2[column].str.strip() outer_join = self.df1.merge( - self.df2, how="outer", suffixes=("_df1", "_df2"), indicator=True, **params + self.df2, + how="outer", + suffixes=("_" + self.df1_name, "_" + self.df2_name), + indicator=True, + **params, ) # Clean up temp columns for duplicate row matching if self._any_dupes: @@ -295,8 +300,8 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: self.df1.drop(order_column, axis=1, inplace=True) self.df2.drop(order_column, axis=1, inplace=True) - df1_cols = get_merged_columns(self.df1, outer_join, "_df1") - df2_cols = get_merged_columns(self.df2, outer_join, "_df2") + df1_cols = get_merged_columns(self.df1, outer_join, self.df1_name) + df2_cols = get_merged_columns(self.df2, outer_join, self.df2_name) LOG.debug("Selecting df1 unique rows") self.df1_unq_rows = outer_join[outer_join["_merge"] == "left_only"][ @@ -334,8 +339,8 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: max_diff = 0.0 null_diff = 0 else: - col_1 = column + "_df1" - col_2 = column + "_df2" + col_1 = column + "_" + self.df1_name + col_2 = column + "_" + self.df2_name col_match = column + "_match" self.intersect_rows[col_match] = columns_equal( self.intersect_rows[col_1], @@ -484,7 +489,10 @@ def sample_mismatch( match_cnt = col_match.sum() sample_count = min(sample_count, row_cnt - match_cnt) sample = self.intersect_rows[~col_match].sample(sample_count) - return_cols = self.join_columns + [column + "_df1", column + "_df2"] + return_cols = self.join_columns + [ + column + "_" + self.df1_name, + column + "_" + self.df2_name, + ] to_return = sample[return_cols] if for_display: to_return.columns = pd.Index( @@ -517,8 +525,8 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame: orig_col_name = col[:-6] col_comparison = columns_equal( - self.intersect_rows[orig_col_name + "_df1"], - self.intersect_rows[orig_col_name + "_df2"], + self.intersect_rows[orig_col_name + "_" + self.df1_name], + self.intersect_rows[orig_col_name + "_" + self.df2_name], self.rel_tol, self.abs_tol, self.ignore_spaces, @@ -530,7 +538,12 @@ def 
all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame: ): LOG.debug(f"Adding column {orig_col_name} to the result.") match_list.append(col) - return_list.extend([orig_col_name + "_df1", orig_col_name + "_df2"]) + return_list.extend( + [ + orig_col_name + "_" + self.df1_name, + orig_col_name + "_" + self.df2_name, + ] + ) elif ignore_matching_cols: LOG.debug( f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result." @@ -613,7 +626,6 @@ def df_to_str(pdf: pd.DataFrame) -> str: ) # Column Matching - cnt_intersect = self.intersect_rows.shape[0] report += render( "column_comparison.txt", len([col for col in self.column_stats if col["unequal_cnt"] > 0]), @@ -804,7 +816,7 @@ def columns_equal( compare = pd.Series( (col_1 == col_2) | (col_1.isnull() & col_2.isnull()) ) - except: + except Exception: # Blanket exception should just return all False compare = pd.Series(False, index=col_1.index) compare.index = col_1.index @@ -842,13 +854,13 @@ def compare_string_and_date_columns( (pd.to_datetime(obj_column) == date_column) | (obj_column.isnull() & date_column.isnull()) ) - except: + except Exception: try: return pd.Series( (pd.to_datetime(obj_column, format="mixed") == date_column) | (obj_column.isnull() & date_column.isnull()) ) - except: + except Exception: return pd.Series(False, index=col_1.index) @@ -871,8 +883,8 @@ def get_merged_columns( for col in original_df.columns: if col in merged_df.columns: columns.append(col) - elif col + suffix in merged_df.columns: - columns.append(col + suffix) + elif col + "_" + suffix in merged_df.columns: + columns.append(col + "_" + suffix) else: raise ValueError("Column not found: %s", col) return columns @@ -920,7 +932,7 @@ def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> floa """ try: return cast(float, (col_1.astype(float) - col_2.astype(float)).abs().max()) - except: + except Exception: return 0.0 diff --git a/datacompy/polars.py b/datacompy/polars.py index 814a7cd6..aca96296 100644 --- a/datacompy/polars.py +++ b/datacompy/polars.py @@ -20,6 +20,7 @@ PROC COMPARE in SAS - i.e. human-readable reporting on the difference between two dataframes. 
""" + import logging import os from copy import deepcopy @@ -265,9 +266,9 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: df2_non_join_columns = OrderedSet(df2.columns) - OrderedSet(temp_join_columns) for c in df1_non_join_columns: - df1 = df1.rename({c: c + "_df1"}) + df1 = df1.rename({c: c + "_" + self.df1_name}) for c in df2_non_join_columns: - df2 = df2.rename({c: c + "_df2"}) + df2 = df2.rename({c: c + "_" + self.df2_name}) # generate merge indicator df1 = df1.with_columns(_merge_left=pl.lit(True)) @@ -290,8 +291,8 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: if self._any_dupes: outer_join = outer_join.drop(order_column) - df1_cols = get_merged_columns(self.df1, outer_join, "_df1") - df2_cols = get_merged_columns(self.df2, outer_join, "_df2") + df1_cols = get_merged_columns(self.df1, outer_join, self.df1_name) + df2_cols = get_merged_columns(self.df2, outer_join, self.df2_name) LOG.debug("Selecting df1 unique rows") self.df1_unq_rows = outer_join.filter( @@ -333,8 +334,8 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: max_diff = 0.0 null_diff = 0 else: - col_1 = column + "_df1" - col_2 = column + "_df2" + col_1 = column + "_" + self.df1_name + col_2 = column + "_" + self.df2_name col_match = column + "_match" self.intersect_rows = self.intersect_rows.with_columns( columns_equal( @@ -499,7 +500,10 @@ def sample_mismatch( sample = self.intersect_rows.filter(pl.col(column + "_match") != True).sample( sample_count ) - return_cols = self.join_columns + [column + "_df1", column + "_df2"] + return_cols = self.join_columns + [ + column + "_" + self.df1_name, + column + "_" + self.df2_name, + ] to_return = sample[return_cols] if for_display: to_return.columns = self.join_columns + [ @@ -529,8 +533,8 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame": orig_col_name = col[:-6] col_comparison = columns_equal( - self.intersect_rows[orig_col_name + "_df1"], - self.intersect_rows[orig_col_name + "_df2"], + self.intersect_rows[orig_col_name + "_" + self.df1_name], + self.intersect_rows[orig_col_name + "_" + self.df2_name], self.rel_tol, self.abs_tol, self.ignore_spaces, @@ -542,7 +546,12 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame": ): LOG.debug(f"Adding column {orig_col_name} to the result.") match_list.append(col) - return_list.extend([orig_col_name + "_df1", orig_col_name + "_df2"]) + return_list.extend( + [ + orig_col_name + "_" + self.df1_name, + orig_col_name + "_" + self.df2_name, + ] + ) elif ignore_matching_cols: LOG.debug( f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result." 
@@ -622,7 +631,6 @@ def df_to_str(pdf: "pl.DataFrame") -> str: ) # Column Matching - cnt_intersect = self.intersect_rows.shape[0] report += render( "column_comparison.txt", len([col for col in self.column_stats if col["unequal_cnt"] > 0]), @@ -824,7 +832,7 @@ def columns_equal( compare = pl.Series( (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null()) ) - except: + except Exception: # Blanket exception should just return all False compare = pl.Series(False * col_1.shape[0]) return compare @@ -861,7 +869,7 @@ def compare_string_and_date_columns( (str_column.str.to_datetime().eq_missing(date_column)) | (str_column.is_null() & date_column.is_null()) ) - except: + except Exception: return pl.Series([False] * col_1.shape[0]) @@ -884,8 +892,8 @@ def get_merged_columns( for col in original_df.columns: if col in merged_df.columns: columns.append(col) - elif col + suffix in merged_df.columns: - columns.append(col + suffix) + elif col + "_" + suffix in merged_df.columns: + columns.append(col + "_" + suffix) else: raise ValueError("Column not found: %s", col) return columns @@ -935,7 +943,7 @@ def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float: return cast( float, (col_1.cast(pl.Float64) - col_2.cast(pl.Float64)).abs().max() ) - except: + except Exception: return 0.0 diff --git a/datacompy/spark.py b/datacompy/spark.py index cfa90397..070a58e5 100644 --- a/datacompy/spark.py +++ b/datacompy/spark.py @@ -264,23 +264,28 @@ def _dataframe_merge(self, ignore_spaces): ) - OrderedSet(self.join_columns) for c in non_join_columns: - df1.rename(columns={c: c + "_df1"}, inplace=True) - df2.rename(columns={c: c + "_df2"}, inplace=True) + df1.rename(columns={c: c + "_" + self.df1_name}, inplace=True) + df2.rename(columns={c: c + "_" + self.df2_name}, inplace=True) # generate merge indicator df1["_merge_left"] = True df2["_merge_right"] = True for c in self.join_columns: - df1.rename(columns={c: c + "_df1"}, inplace=True) - df2.rename(columns={c: c + "_df2"}, inplace=True) + df1.rename(columns={c: c + "_" + self.df1_name}, inplace=True) + df2.rename(columns={c: c + "_" + self.df2_name}, inplace=True) # cache df1.spark.cache() df2.spark.cache() # NULL SAFE Outer join using ON - on = " and ".join([f"df1.`{c}_df1` <=> df2.`{c}_df2`" for c in params["on"]]) + on = " and ".join( + [ + f"df1.`{c}_{self.df1_name}` <=> df2.`{c}_{self.df2_name}`" + for c in params["on"] + ] + ) outer_join = ps.sql( """ SELECT * FROM @@ -311,13 +316,29 @@ def _dataframe_merge(self, ignore_spaces): # Clean up temp columns for duplicate row matching if self._any_dupes: outer_join = outer_join.drop( - [order_column + "_df1", order_column + "_df2"], axis=1 + [ + order_column + "_" + self.df1_name, + order_column + "_" + self.df2_name, + ], + axis=1, + ) + df1 = df1.drop( + [ + order_column + "_" + self.df1_name, + order_column + "_" + self.df2_name, + ], + axis=1, + ) + df2 = df2.drop( + [ + order_column + "_" + self.df1_name, + order_column + "_" + self.df2_name, + ], + axis=1, ) - df1 = df1.drop([order_column + "_df1", order_column + "_df2"], axis=1) - df2 = df2.drop([order_column + "_df1", order_column + "_df2"], axis=1) - df1_cols = get_merged_columns(df1, outer_join, "_df1") - df2_cols = get_merged_columns(df2, outer_join, "_df2") + df1_cols = get_merged_columns(df1, outer_join, self.df1_name) + df2_cols = get_merged_columns(df2, outer_join, self.df2_name) LOG.debug("Selecting df1 unique rows") self.df1_unq_rows = outer_join[outer_join["_merge"] == "left_only"][ @@ -356,8 +377,8 @@ def 
_intersect_compare(self, ignore_spaces, ignore_case): max_diff = 0 null_diff = 0 else: - col_1 = column + "_df1" - col_2 = column + "_df2" + col_1 = column + "_" + self.df1_name + col_2 = column + "_" + self.df2_name col_match = column + "_match" self.intersect_rows[col_match] = columns_equal( self.intersect_rows[col_1], @@ -511,9 +532,12 @@ def sample_mismatch(self, column, sample_count=10, for_display=False): sample = self.intersect_rows[~col_match].head(sample_count) for c in self.join_columns: - sample[c] = sample[c + "_df1"] + sample[c] = sample[c + "_" + self.df1_name] - return_cols = self.join_columns + [column + "_df1", column + "_df2"] + return_cols = self.join_columns + [ + column + "_" + self.df1_name, + column + "_" + self.df2_name, + ] to_return = sample[return_cols] if for_display: to_return.columns = self.join_columns + [ @@ -543,8 +567,8 @@ def all_mismatch(self, ignore_matching_cols=False): orig_col_name = col[:-6] col_comparison = columns_equal( - self.intersect_rows[orig_col_name + "_df1"], - self.intersect_rows[orig_col_name + "_df2"], + self.intersect_rows[orig_col_name + "_" + self.df1_name], + self.intersect_rows[orig_col_name + "_" + self.df2_name], self.rel_tol, self.abs_tol, self.ignore_spaces, @@ -556,7 +580,12 @@ def all_mismatch(self, ignore_matching_cols=False): ): LOG.debug(f"Adding column {orig_col_name} to the result.") match_list.append(col) - return_list.extend([orig_col_name + "_df1", orig_col_name + "_df2"]) + return_list.extend( + [ + orig_col_name + "_" + self.df1_name, + orig_col_name + "_" + self.df2_name, + ] + ) elif ignore_matching_cols: LOG.debug( f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result." @@ -566,8 +595,8 @@ def all_mismatch(self, ignore_matching_cols=False): updated_join_columns = [] for c in self.join_columns: - updated_join_columns.append(c + "_df1") - updated_join_columns.append(c + "_df2") + updated_join_columns.append(c + "_" + self.df1_name) + updated_join_columns.append(c + "_" + self.df2_name) return self.intersect_rows[~mm_bool][updated_join_columns + return_list] @@ -818,7 +847,7 @@ def columns_equal( col_1_temp.isnull() & col_2_temp.isnull() ) - except: + except Exception: # Blanket exception should just return all False compare = ps.Series(False, index=col_1.index.to_numpy()) return compare @@ -855,7 +884,7 @@ def compare_string_and_date_columns(col_1, col_2): | (obj_column.isnull() & date_column.isnull()) ).to_numpy() ) # force compute - except: + except Exception: compare = ps.Series(False, index=col_1.index.to_numpy()) return compare @@ -877,8 +906,8 @@ def get_merged_columns(original_df, merged_df, suffix): for col in original_df.columns: if col in merged_df.columns: columns.append(col) - elif col + suffix in merged_df.columns: - columns.append(col + suffix) + elif col + "_" + suffix in merged_df.columns: + columns.append(col + "_" + suffix) else: raise ValueError("Column not found: %s", col) return columns @@ -926,7 +955,7 @@ def calculate_max_diff(col_1, col_2): """ try: return (col_1.astype(float) - col_2.astype(float)).abs().max() - except: + except Exception: return 0 diff --git a/tests/test_polars.py b/tests/test_polars.py index aabbcad1..679a9ab7 100644 --- a/tests/test_polars.py +++ b/tests/test_polars.py @@ -1231,7 +1231,7 @@ def test_dupes_with_nulls(): ), ( pl.DataFrame( - {"a": [datetime(2018, 1, 1), np.nan, np.nan], "b": ["1", "2", "2"]} + {"a": [datetime(2018, 1, 1), None, None], "b": ["1", "2", "2"]} ), pl.Series([1, 1, 2]), ), From 
d242d493564e8e289fdf3372667eb6f6cbe3d012 Mon Sep 17 00:00:00 2001 From: Mark Elliot <123787712+mark-thm@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:45:22 -0400 Subject: [PATCH 3/5] Update pandas support to 2.2.2 (#295) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8ece29a1..3bde2d70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ maintainers = [ { name="Faisal Dosani", email="faisal.dosani@capitalone.com" } ] license = {text = "Apache Software License"} -dependencies = ["pandas<=2.2.1,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.8.7,>=0.8.7"] +dependencies = ["pandas<=2.2.2,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.8.7,>=0.8.7"] requires-python = ">=3.9.0" classifiers = [ "Intended Audience :: Developers", From bcedfdc73840c2e189a1f5b7cae84e62d21c22f1 Mon Sep 17 00:00:00 2001 From: Faisal Date: Tue, 30 Apr 2024 10:44:25 -0300 Subject: [PATCH 4/5] bump pandas actions to 2.2.2 (#296) --- .github/workflows/test-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index 2c06b2d2..c7899383 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -21,7 +21,7 @@ jobs: matrix: python-version: [3.9, '3.10', '3.11'] spark-version: [3.2.4, 3.3.4, 3.4.2, 3.5.1] - pandas-version: [2.2.1, 1.5.3] + pandas-version: [2.2.2, 1.5.3] exclude: - python-version: '3.11' spark-version: 3.2.4 From 889235ce0fd398ff3b50c55ebaab8b0d01a943e9 Mon Sep 17 00:00:00 2001 From: Faisal Date: Tue, 30 Apr 2024 10:44:43 -0300 Subject: [PATCH 5/5] Fugue count matching rows (#294) * adding in count_matching_rows * linting / cleanup --- datacompy/__init__.py | 1 + datacompy/fugue.py | 96 ++++++++++++++++++++++++++- tests/test_fugue/conftest.py | 18 ++++- tests/test_fugue/test_duckdb.py | 38 +++++++++++ tests/test_fugue/test_fugue_pandas.py | 40 ++++++++++- tests/test_fugue/test_fugue_polars.py | 35 ++++++++++ tests/test_fugue/test_fugue_spark.py | 42 ++++++++++++ 7 files changed, 266 insertions(+), 4 deletions(-) diff --git a/datacompy/__init__.py b/datacompy/__init__.py index b47d27c3..b43027ae 100644 --- a/datacompy/__init__.py +++ b/datacompy/__init__.py @@ -19,6 +19,7 @@ from datacompy.fugue import ( all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, report, diff --git a/datacompy/fugue.py b/datacompy/fugue.py index 2ac4889a..8bc01d33 100644 --- a/datacompy/fugue.py +++ b/datacompy/fugue.py @@ -291,6 +291,101 @@ def all_rows_overlap( return all(overlap) +def count_matching_rows( + df1: AnyDataFrame, + df2: AnyDataFrame, + join_columns: Union[str, List[str]], + abs_tol: float = 0, + rel_tol: float = 0, + df1_name: str = "df1", + df2_name: str = "df2", + ignore_spaces: bool = False, + ignore_case: bool = False, + cast_column_names_lower: bool = True, + parallelism: Optional[int] = None, + strict_schema: bool = False, +) -> int: + """Count the number of rows match (on overlapping fields) + + Parameters + ---------- + df1 : ``AnyDataFrame`` + First dataframe to check + df2 : ``AnyDataFrame`` + Second dataframe to check + join_columns : list or str, optional + Column(s) to join dataframes on. If a string is passed in, that one + column will be used. + abs_tol : float, optional + Absolute tolerance between two values. + rel_tol : float, optional + Relative tolerance between two values. 
+ df1_name : str, optional + A string name for the first dataframe. This allows the reporting to + print out an actual name instead of "df1", and allows human users to + more easily track the dataframes. + df2_name : str, optional + A string name for the second dataframe + ignore_spaces : bool, optional + Flag to strip whitespace (including newlines) from string columns (including any join + columns) + ignore_case : bool, optional + Flag to ignore the case of string columns + cast_column_names_lower: bool, optional + Boolean indicator that controls of column names will be cast into lower case + parallelism: int, optional + An integer representing the amount of parallelism. Entering a value for this + will force to use of Fugue over just vanilla Pandas + strict_schema: bool, optional + The schema must match exactly if set to ``True``. This includes the names and types. Allows for a fast fail. + + Returns + ------- + int + Number of matching rows + """ + if ( + isinstance(df1, pd.DataFrame) + and isinstance(df2, pd.DataFrame) + and parallelism is None # user did not specify parallelism + and fa.get_current_parallelism() == 1 # currently on a local execution engine + ): + comp = Compare( + df1=df1, + df2=df2, + join_columns=join_columns, + abs_tol=abs_tol, + rel_tol=rel_tol, + df1_name=df1_name, + df2_name=df2_name, + ignore_spaces=ignore_spaces, + ignore_case=ignore_case, + cast_column_names_lower=cast_column_names_lower, + ) + return comp.count_matching_rows() + + try: + count_matching_rows = _distributed_compare( + df1=df1, + df2=df2, + join_columns=join_columns, + return_obj_func=lambda comp: comp.count_matching_rows(), + abs_tol=abs_tol, + rel_tol=rel_tol, + df1_name=df1_name, + df2_name=df2_name, + ignore_spaces=ignore_spaces, + ignore_case=ignore_case, + cast_column_names_lower=cast_column_names_lower, + parallelism=parallelism, + strict_schema=strict_schema, + ) + except _StrictSchemaError: + return False + + return sum(count_matching_rows) + + def report( df1: AnyDataFrame, df2: AnyDataFrame, @@ -460,7 +555,6 @@ def _any(col: str) -> int: any_mismatch = len(match_sample) > 0 # Column Matching - cnt_intersect = shape0("intersect_rows_shape") rpt += render( "column_comparison.txt", len([col for col in column_stats if col["unequal_cnt"] > 0]), diff --git a/tests/test_fugue/conftest.py b/tests/test_fugue/conftest.py index 6a5683d2..a2ca99b1 100644 --- a/tests/test_fugue/conftest.py +++ b/tests/test_fugue/conftest.py @@ -1,6 +1,6 @@ -import pytest import numpy as np import pandas as pd +import pytest @pytest.fixture @@ -24,7 +24,8 @@ def ref_df(): c=np.random.choice(["aaa", "b_c", "csd"], 100), ) ) - return [df1, df1_copy, df2, df3, df4] + df5 = df1.sample(frac=0.1) + return [df1, df1_copy, df2, df3, df4, df5] @pytest.fixture @@ -87,3 +88,16 @@ def large_diff_df2(): np.random.seed(0) data = np.random.randint(6, 11, size=10000) return pd.DataFrame({"x": data, "y": np.array([9] * 10000)}).convert_dtypes() + + +@pytest.fixture +def count_matching_rows_df(): + np.random.seed(0) + df1 = pd.DataFrame( + dict( + a=np.arange(0, 100), + b=np.arange(0, 100), + ) + ) + df2 = df1.sample(frac=0.1) + return [df1, df2] diff --git a/tests/test_fugue/test_duckdb.py b/tests/test_fugue/test_duckdb.py index daed1edd..3643f22d 100644 --- a/tests/test_fugue/test_duckdb.py +++ b/tests/test_fugue/test_duckdb.py @@ -20,6 +20,7 @@ from datacompy import ( all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, unq_columns, @@ -138,3 +139,40 @@ def test_all_rows_overlap_duckdb( 
duckdb.sql("SELECT 'a' AS a, 'b' AS b"), join_columns="a", ) + + +def test_count_matching_rows_duckdb(count_matching_rows_df): + with duckdb.connect(): + df1 = duckdb.from_df(count_matching_rows_df[0]) + df1_copy = duckdb.from_df(count_matching_rows_df[0]) + df2 = duckdb.from_df(count_matching_rows_df[1]) + + assert ( + count_matching_rows( + df1, + df1_copy, + join_columns="a", + ) + == 100 + ) + assert count_matching_rows(df1, df2, join_columns="a") == 10 + # Fugue + + assert ( + count_matching_rows( + df1, + df1_copy, + join_columns="a", + parallelism=2, + ) + == 100 + ) + assert ( + count_matching_rows( + df1, + df2, + join_columns="a", + parallelism=2, + ) + == 10 + ) diff --git a/tests/test_fugue/test_fugue_pandas.py b/tests/test_fugue/test_fugue_pandas.py index 77884c2c..4fd74ce7 100644 --- a/tests/test_fugue/test_fugue_pandas.py +++ b/tests/test_fugue/test_fugue_pandas.py @@ -24,6 +24,7 @@ Compare, all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, report, @@ -144,7 +145,6 @@ def test_report_pandas( def test_unique_columns_native(ref_df): df1 = ref_df[0] - df1_copy = ref_df[1] df2 = ref_df[2] df3 = ref_df[3] @@ -192,3 +192,41 @@ def test_all_rows_overlap_native( # Fugue assert all_rows_overlap(ref_df[0], shuffle_df, join_columns="a", parallelism=2) assert not all_rows_overlap(ref_df[0], ref_df[4], join_columns="a", parallelism=2) + + +def test_count_matching_rows_native(count_matching_rows_df): + # defaults to Compare class + assert ( + count_matching_rows( + count_matching_rows_df[0], + count_matching_rows_df[0].copy(), + join_columns="a", + ) + == 100 + ) + assert ( + count_matching_rows( + count_matching_rows_df[0], count_matching_rows_df[1], join_columns="a" + ) + == 10 + ) + # Fugue + + assert ( + count_matching_rows( + count_matching_rows_df[0], + count_matching_rows_df[0].copy(), + join_columns="a", + parallelism=2, + ) + == 100 + ) + assert ( + count_matching_rows( + count_matching_rows_df[0], + count_matching_rows_df[1], + join_columns="a", + parallelism=2, + ) + == 10 + ) diff --git a/tests/test_fugue/test_fugue_polars.py b/tests/test_fugue/test_fugue_polars.py index fdb2212a..dcd19a94 100644 --- a/tests/test_fugue/test_fugue_polars.py +++ b/tests/test_fugue/test_fugue_polars.py @@ -20,6 +20,7 @@ from datacompy import ( all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, unq_columns, @@ -122,3 +123,37 @@ def test_all_rows_overlap_polars( assert all_rows_overlap(rdf, rdf_copy, join_columns="a") assert all_rows_overlap(rdf, sdf, join_columns="a") assert not all_rows_overlap(rdf, rdf4, join_columns="a") + + +def test_count_matching_rows_polars(count_matching_rows_df): + df1 = pl.from_pandas(count_matching_rows_df[0]) + df2 = pl.from_pandas(count_matching_rows_df[1]) + assert ( + count_matching_rows( + df1, + df1.clone(), + join_columns="a", + ) + == 100 + ) + assert count_matching_rows(df1, df2, join_columns="a") == 10 + # Fugue + + assert ( + count_matching_rows( + df1, + df1.clone(), + join_columns="a", + parallelism=2, + ) + == 100 + ) + assert ( + count_matching_rows( + df1, + df2, + join_columns="a", + parallelism=2, + ) + == 10 + ) diff --git a/tests/test_fugue/test_fugue_spark.py b/tests/test_fugue/test_fugue_spark.py index 99da708b..efc895ff 100644 --- a/tests/test_fugue/test_fugue_spark.py +++ b/tests/test_fugue/test_fugue_spark.py @@ -22,6 +22,7 @@ Compare, all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, report, @@ -200,3 +201,44 @@ def 
test_all_rows_overlap_spark( spark_session.sql("SELECT 'a' AS a, 'b' AS b"), join_columns="a", ) + + +def test_count_matching_rows_spark(spark_session, count_matching_rows_df): + count_matching_rows_df[0].iteritems = count_matching_rows_df[ + 0 + ].items # pandas 2 compatibility + count_matching_rows_df[1].iteritems = count_matching_rows_df[ + 1 + ].items # pandas 2 compatibility + df1 = spark_session.createDataFrame(count_matching_rows_df[0]) + df1_copy = spark_session.createDataFrame(count_matching_rows_df[0]) + df2 = spark_session.createDataFrame(count_matching_rows_df[1]) + assert ( + count_matching_rows( + df1, + df1_copy, + join_columns="a", + ) + == 100 + ) + assert count_matching_rows(df1, df2, join_columns="a") == 10 + # Fugue + + assert ( + count_matching_rows( + df1, + df1_copy, + join_columns="a", + parallelism=2, + ) + == 100 + ) + assert ( + count_matching_rows( + df1, + df2, + join_columns="a", + parallelism=2, + ) + == 10 + )
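
For quick reference, a minimal usage sketch of the new `count_matching_rows` helper introduced in PATCH 5/5. It is modelled on the fixtures in `tests/test_fugue/conftest.py`; the frames below are illustrative assumptions, not part of the patch itself.

```python
import pandas as pd
from datacompy import count_matching_rows

# Two frames sharing a join key; df2 keeps 10 of df1's 100 rows.
# (Illustrative data, mirroring the count_matching_rows_df fixture.)
df1 = pd.DataFrame({"a": range(100), "b": range(100)})
df2 = df1.sample(frac=0.1)

# With plain pandas inputs and no parallelism, this falls back to the
# native Compare class.
assert count_matching_rows(df1, df1.copy(), join_columns="a") == 100
assert count_matching_rows(df1, df2, join_columns="a") == 10

# Passing parallelism forces the distributed (Fugue) code path.
assert count_matching_rows(df1, df2, join_columns="a", parallelism=2) == 10
```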