From a42efa0ceb5541ec24ada849f788369f086c2ff9 Mon Sep 17 00:00:00 2001 From: Faisal Date: Wed, 17 Jul 2024 22:06:40 -0300 Subject: [PATCH 1/3] adding Polars v1 tweaks for testing (#325) --- datacompy/polars.py | 2 +- pyproject.toml | 2 +- tests/test_polars.py | 54 ++++++++++++++++++++++++++++---------------- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/datacompy/polars.py b/datacompy/polars.py index 3dbf82ff..1c0f8d56 100644 --- a/datacompy/polars.py +++ b/datacompy/polars.py @@ -274,7 +274,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: df1 = df1.with_columns(_merge_left=pl.lit(True)) df2 = df2.with_columns(_merge_right=pl.lit(True)) - outer_join = df1.join(df2, how="outer_coalesce", join_nulls=True, **params) + outer_join = df1.join(df2, how="full", coalesce=True, join_nulls=True, **params) # process merge indicator outer_join = outer_join.with_columns( diff --git a/pyproject.toml b/pyproject.toml index 806c7ac2..19075892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ maintainers = [ { name="Faisal Dosani", email="faisal.dosani@capitalone.com" } ] license = {text = "Apache Software License"} -dependencies = ["pandas<=2.2.2,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.9.1,>=0.8.7", "polars<=0.20.31,>=0.20.4"] +dependencies = ["pandas<=2.2.2,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.9.1,>=0.8.7", "polars<=1.1.0,>=0.20.4"] requires-python = ">=3.9.0" classifiers = [ "Intended Audience :: Developers", diff --git a/tests/test_polars.py b/tests/test_polars.py index c878cbba..0640cd29 100644 --- a/tests/test_polars.py +++ b/tests/test_polars.py @@ -389,18 +389,14 @@ def test_compare_df_setter_bad(): PolarsCompare(df, df.clone(), ["b"]) with raises(DuplicateError, match="duplicate column names found"): PolarsCompare(df_same_col_names, df_same_col_names.clone(), ["a"]) - assert ( - PolarsCompare(df_dupe, df_dupe.clone(), ["a", "b"]) - .df1.drop("_merge_left") - .equals(df_dupe) - ) + assert PolarsCompare(df_dupe, df_dupe.clone(), ["a", "b"]).df1.equals(df_dupe) def test_compare_df_setter_good(): df1 = pl.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}]) df2 = pl.DataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}]) compare = PolarsCompare(df1, df2, ["a"]) - assert compare.df1.drop("_merge_left").equals(df1) + assert compare.df1.equals(df1) assert compare.df2.equals(df2) assert compare.join_columns == ["a"] compare = PolarsCompare(df1, df2, ["A", "b"]) @@ -1177,10 +1173,12 @@ def test_all_mismatch_ignore_matching_cols_no_cols_matching(): "strings": ["1", "1", "1", "1.1", "1"], "mixed_strings": ["1", "1", "1", "2", "some string"], "infinity": [1, 1, 1, 1, np.inf], - } + }, + strict=False, ) +@pytest.mark.skipif(pl.__version__ < "1.0.0", reason="polars breaking changes") @pytest.mark.parametrize( "column,expected", [ @@ -1204,10 +1202,12 @@ def test_dupes_with_nulls(): { "fld_1": [1, 2, 2, 3, 3, 4, 5, 5], "fld_2": ["A", np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], - } + }, + strict=False, ) df2 = pl.DataFrame( - {"fld_1": [1, 2, 3, 4, 5], "fld_2": ["A", np.nan, np.nan, np.nan, np.nan]} + {"fld_1": [1, 2, 3, 4, 5], "fld_2": ["A", np.nan, np.nan, np.nan, np.nan]}, + strict=False, ) comp = PolarsCompare(df1, df2, join_columns=["fld_1", "fld_2"]) assert comp.subset() @@ -1216,25 +1216,36 @@ def test_dupes_with_nulls(): @pytest.mark.parametrize( "dataframe,expected", [ - (pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}), pl.Series([1, 1, 1])), ( - pl.DataFrame({"a": ["a", "a", 
"DATACOMPY_NULL"], "b": [1, 1, 2]}), - pl.Series([1, 2, 1]), + pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}), + pl.Series([1, 1, 1], strict=False), + ), + ( + pl.DataFrame( + {"a": ["a", "a", "DATACOMPY_NULL"], "b": [1, 1, 2]}, strict=False + ), + pl.Series([1, 2, 1], strict=False), ), - (pl.DataFrame({"a": [-999, 2, 3], "b": [1, 2, 3]}), pl.Series([1, 1, 1])), ( - pl.DataFrame({"a": [1, np.nan, np.nan], "b": [1, 2, 2]}), - pl.Series([1, 1, 2]), + pl.DataFrame({"a": [-999, 2, 3], "b": [1, 2, 3]}, strict=False), + pl.Series([1, 1, 1], strict=False), ), ( - pl.DataFrame({"a": ["1", np.nan, np.nan], "b": ["1", "2", "2"]}), - pl.Series([1, 1, 2]), + pl.DataFrame({"a": [1, np.nan, np.nan], "b": [1, 2, 2]}, strict=False), + pl.Series([1, 1, 2], strict=False), ), ( pl.DataFrame( - {"a": [datetime(2018, 1, 1), None, None], "b": ["1", "2", "2"]} + {"a": ["1", np.nan, np.nan], "b": ["1", "2", "2"]}, strict=False ), - pl.Series([1, 1, 2]), + pl.Series([1, 1, 2], strict=False), + ), + ( + pl.DataFrame( + {"a": [datetime(2018, 1, 1), None, None], "b": ["1", "2", "2"]}, + strict=False, + ), + pl.Series([1, 1, 2], strict=False), ), ], ) @@ -1242,11 +1253,14 @@ def test_generate_id_within_group(dataframe, expected): assert (generate_id_within_group(dataframe, ["a", "b"]) == expected).all() +@pytest.mark.skipif(pl.__version__ < "1.0.0", reason="polars breaking changes") @pytest.mark.parametrize( "dataframe, message", [ ( - pl.DataFrame({"a": [1, np.nan, "DATACOMPY_NULL"], "b": [1, 2, 3]}), + pl.DataFrame( + {"a": [1, None, "DATACOMPY_NULL"], "b": [1, 2, 3]}, strict=False + ), "DATACOMPY_NULL was found in your join columns", ) ], From 3c7895c3bbfa9714140030a3c28f64fb0c2d9087 Mon Sep 17 00:00:00 2001 From: Jacob Dawang Date: Sat, 10 Aug 2024 11:57:25 -0600 Subject: [PATCH 2/3] Ruff (#326) * Ruff safe fixes * Ruff unsafe fixes * Manual fixes * Update workflow * Reexport from base * Fix boolean comparison * Simplify boolean expr * Simplify subset bool expr --- .github/workflows/test-package.yml | 14 +++ .pre-commit-config.yaml | 31 +++-- datacompy/__init__.py | 46 +++++++- datacompy/base.py | 40 +++++-- datacompy/core.py | 105 ++++++++++------- datacompy/fugue.py | 49 ++++---- datacompy/polars.py | 125 +++++++++++--------- datacompy/spark/__init__.py | 1 + datacompy/spark/legacy.py | 164 ++++++++++++-------------- datacompy/spark/pandas.py | 106 +++++++++-------- datacompy/spark/sql.py | 143 ++++++++++++---------- pyproject.toml | 61 ++++++++-- tests/test_core.py | 12 +- tests/test_fugue/conftest.py | 50 ++++---- tests/test_fugue/test_duckdb.py | 6 +- tests/test_fugue/test_fugue_pandas.py | 8 +- tests/test_fugue/test_fugue_polars.py | 6 +- tests/test_fugue/test_fugue_spark.py | 8 +- tests/test_polars.py | 21 ++-- tests/test_spark/test_legacy_spark.py | 15 ++- tests/test_spark/test_pandas_spark.py | 33 ++++-- tests/test_spark/test_sql_spark.py | 14 +-- 22 files changed, 616 insertions(+), 442 deletions(-) diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index b710d6e5..1462fc91 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -13,6 +13,20 @@ permissions: contents: read jobs: + lint-and-format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: "3.9" + - name: Install dependencies + run: python -m pip install .[qa] + - name: Linting by ruff + run: ruff check + - name: Formatting by ruff + run: ruff format --check test-dev-install: 
runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fca74087..d9fa3008 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,25 @@ repos: - - repo: https://github.com/psf/black - rev: 23.3.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.7 hooks: - - id: black - types: [file, python] - language_version: python3.10 - - repo: https://github.com/pycqa/isort - rev: 5.12.0 + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format + types_or: [ python, jupyter ] + # # Mypy: Optional static type checking + # # https://github.com/pre-commit/mirrors-mypy + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.11.1 + # hooks: + # - id: mypy + # exclude: ^(docs|tests)\/ + # language_version: python3.9 + # args: [--namespace-packages, --explicit-package-bases, --ignore-missing-imports, --non-interactive, --install-types] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 hooks: - - id: isort - name: isort (python) \ No newline at end of file + - id: trailing-whitespace + - id: debug-statements + - id: end-of-file-fixer diff --git a/datacompy/__init__.py b/datacompy/__init__.py index 6b02f0d8..fad85ae6 100644 --- a/datacompy/__init__.py +++ b/datacompy/__init__.py @@ -12,14 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""DataComPy is a package to compare two Pandas DataFrames. + +Originally started to be something of a replacement for SAS's PROC COMPARE for Pandas DataFrames with some more functionality than just Pandas.DataFrame.equals(Pandas.DataFrame) (in that it prints out some stats, and lets you tweak how accurate matches have to be). +Then extended to carry that functionality over to Spark Dataframes. +""" __version__ = "0.13.2" import platform from warnings import warn -from .core import * # noqa: F403 -from .fugue import ( # noqa: F401 +from datacompy.base import BaseCompare, temp_column_name +from datacompy.core import ( + Compare, + calculate_max_diff, + columns_equal, + compare_string_and_date_columns, + generate_id_within_group, + get_merged_columns, + render, +) +from datacompy.fugue import ( all_columns_match, all_rows_overlap, count_matching_rows, @@ -28,9 +42,31 @@ report, unq_columns, ) -from .polars import PolarsCompare # noqa: F401 -from .spark.pandas import SparkPandasCompare # noqa: F401 -from .spark.sql import SparkSQLCompare # noqa: F401 +from datacompy.polars import PolarsCompare +from datacompy.spark.pandas import SparkPandasCompare +from datacompy.spark.sql import SparkSQLCompare + +__all__ = [ + "BaseCompare", + "Compare", + "PolarsCompare", + "SparkPandasCompare", + "SparkSQLCompare", + "all_columns_match", + "all_rows_overlap", + "calculate_max_diff", + "columns_equal", + "compare_string_and_date_columns", + "count_matching_rows", + "generate_id_within_group", + "get_merged_columns", + "intersect_columns", + "is_match", + "render", + "report", + "temp_column_name", + "unq_columns", +] major = platform.python_version_tuple()[0] minor = platform.python_version_tuple()[1] diff --git a/datacompy/base.py b/datacompy/base.py index 6ac54afe..d79fa0c0 100644 --- a/datacompy/base.py +++ b/datacompy/base.py @@ -14,7 +14,7 @@ # limitations under the License. """ -Compare two Pandas DataFrames +Compare two Pandas DataFrames. 
Originally this package was meant to provide similar functionality to PROC COMPARE in SAS - i.e. human-readable reporting on the difference between @@ -31,36 +31,42 @@ class BaseCompare(ABC): + """Base comparison class.""" + @property def df1(self) -> Any: + """Get the first dataframe.""" return self._df1 # type: ignore @df1.setter @abstractmethod def df1(self, df1: Any) -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" pass @property def df2(self) -> Any: + """Get the second dataframe.""" return self._df2 # type: ignore @df2.setter @abstractmethod def df2(self, df2: Any) -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" pass @abstractmethod def _validate_dataframe( self, index: str, cast_column_names_lower: bool = True ) -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" pass @abstractmethod def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: - """Actually run the comparison. This tries to run df1.equals(df2) + """Run the comparison. + + This tries to run df1.equals(df2) first so that if they're truly equal we can tell. This method will log out information about what is different between @@ -70,23 +76,25 @@ def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: @abstractmethod def df1_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df1""" + """Get columns that are unique to df1.""" pass @abstractmethod def df2_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df2""" + """Get columns that are unique to df2.""" pass @abstractmethod def intersect_columns(self) -> OrderedSet[str]: - """Get columns that are shared between the two dataframes""" + """Get columns that are shared between the two dataframes.""" pass @abstractmethod def _dataframe_merge(self, ignore_spaces: bool) -> None: - """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1 - and df1 & df2 + """Merge df1 to df2 on the join columns. + + To get df1 - df2, df2 - df1 + and df1 & df2. If ``on_index`` is True, this will join on index values, otherwise it will join on the ``join_columns``. 
@@ -95,40 +103,49 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:
 
     @abstractmethod
     def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
+        """Compare the intersection of the two dataframes."""
         pass
 
     @abstractmethod
     def all_columns_match(self) -> bool:
+        """Check if all columns match."""
         pass
 
     @abstractmethod
     def all_rows_overlap(self) -> bool:
+        """Check if all rows overlap."""
         pass
 
     @abstractmethod
     def count_matching_rows(self) -> int:
+        """Count the number of matching rows."""
        pass
 
     @abstractmethod
     def intersect_rows_match(self) -> bool:
+        """Check if the intersecting rows all match."""
         pass
 
     @abstractmethod
     def matches(self, ignore_extra_columns: bool = False) -> bool:
+        """Check if the dataframes match."""
         pass
 
     @abstractmethod
     def subset(self) -> bool:
+        """Check if one dataframe is a subset of the other."""
         pass
 
     @abstractmethod
     def sample_mismatch(
         self, column: str, sample_count: int = 10, for_display: bool = False
     ) -> Any:
+        """Get a sample of rows that mismatch."""
         pass
 
     @abstractmethod
     def all_mismatch(self, ignore_matching_cols: bool = False) -> Any:
+        """Get all rows that mismatch."""
         pass
 
     @abstractmethod
@@ -138,11 +155,12 @@ def report(
         column_count: int = 10,
         html_file: Optional[str] = None,
     ) -> str:
+        """Return a string representation of a report."""
         pass
 
 
 def temp_column_name(*dataframes) -> str:
-    """Gets a temp column name that isn't included in columns of any dataframes
+    """Get a temp column name that isn't included in columns of any dataframes.
 
     Parameters
     ----------
diff --git a/datacompy/core.py b/datacompy/core.py
index d07cac96..0089dc38 100644
--- a/datacompy/core.py
+++ b/datacompy/core.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 """
-Compare two Pandas DataFrames
+Compare two Pandas DataFrames.
 
 Originally this package was meant to provide similar functionality to
 PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
@@ -29,7 +29,7 @@
 import pandas as pd
 from ordered_set import OrderedSet
 
-from .base import BaseCompare, temp_column_name
+from datacompy.base import BaseCompare, temp_column_name
 
 LOG = logging.getLogger(__name__)
 
@@ -131,11 +131,12 @@ def __init__(
 
     @property
     def df1(self) -> pd.DataFrame:
+        """Get the first dataframe."""
         return self._df1
 
     @df1.setter
     def df1(self, df1: pd.DataFrame) -> None:
-        """Check that it is a dataframe and has the join columns"""
+        """Check that it is a dataframe and has the join columns."""
         self._df1 = df1
         self._validate_dataframe(
             "df1", cast_column_names_lower=self.cast_column_names_lower
@@ -143,11 +144,12 @@ def df1(self, df1: pd.DataFrame) -> None:
 
     @property
     def df2(self) -> pd.DataFrame:
+        """Get the second dataframe."""
         return self._df2
 
     @df2.setter
     def df2(self, df2: pd.DataFrame) -> None:
-        """Check that it is a dataframe and has the join columns"""
+        """Check that it is a dataframe and has the join columns."""
         self._df2 = df2
         self._validate_dataframe(
             "df2", cast_column_names_lower=self.cast_column_names_lower
@@ -156,7 +158,7 @@ def df2(self, df2: pd.DataFrame) -> None:
     def _validate_dataframe(
         self, index: str, cast_column_names_lower: bool = True
     ) -> None:
-        """Check that it is a dataframe and has the join columns
+        """Check that it is a dataframe and has the join columns.
 
         Parameters
         ----------
@@ -192,7 +194,9 @@ def _validate_dataframe(
             self._any_dupes = True
 
     def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
-        """Actually run the comparison. This tries to run df1.equals(df2)
+        """Run the comparison. 
+ + This tries to run df1.equals(df2) first so that if they're truly equal we can tell. This method will log out information about what is different between @@ -224,24 +228,26 @@ def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: LOG.info("df1 does not match df2") def df1_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df1""" + """Get columns that are unique to df1.""" return cast( OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns) ) def df2_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df2""" + """Get columns that are unique to df2.""" return cast( OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns) ) def intersect_columns(self) -> OrderedSet[str]: - """Get columns that are shared between the two dataframes""" + """Get columns that are shared between the two dataframes.""" return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns) def _dataframe_merge(self, ignore_spaces: bool) -> None: - """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1 - and df1 & df2 + """Merge df1 to df2 on the join columns. + + To get df1 - df2, df2 - df1 + and df1 & df2. If ``on_index`` is True, this will join on index values, otherwise it will join on the ``join_columns``. @@ -324,7 +330,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: ) def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: - """Run the comparison on the intersect dataframe + """Run the comparison on the intersect dataframe. This loops through all columns that are shared between df1 and df2, and creates a column column_match which is True for matches, False @@ -385,11 +391,11 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: ) def all_columns_match(self) -> bool: - """Whether the columns all match in the dataframes""" + """Whether the columns all match in the dataframes.""" return self.df1_unq_columns() == self.df2_unq_columns() == set() def all_rows_overlap(self) -> bool: - """Whether the rows are all present in both dataframes + """Whether the rows are all present in both dataframes. Returns ------- @@ -400,7 +406,7 @@ def all_rows_overlap(self) -> bool: return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0 def count_matching_rows(self) -> int: - """Count the number of rows match (on overlapping fields) + """Count the number of rows match (on overlapping fields). Returns ------- @@ -414,7 +420,7 @@ def count_matching_rows(self) -> int: return self.intersect_rows[match_columns].all(axis=1).sum() def intersect_rows_match(self) -> bool: - """Check whether the intersect rows all match""" + """Check whether the intersect rows all match.""" actual_length = self.intersect_rows.shape[0] return self.count_matching_rows() == actual_length @@ -431,14 +437,11 @@ def matches(self, ignore_extra_columns: bool = False) -> bool: bool True or False if the dataframes match. """ - if not ignore_extra_columns and not self.all_columns_match(): - return False - elif not self.all_rows_overlap(): - return False - elif not self.intersect_rows_match(): - return False - else: - return True + return ( + (ignore_extra_columns or self.all_columns_match()) + and self.all_rows_overlap() + and self.intersect_rows_match() + ) def subset(self) -> bool: """Return True if dataframe 2 is a subset of dataframe 1. @@ -452,19 +455,18 @@ def subset(self) -> bool: bool True if dataframe 2 is a subset of dataframe 1. 
""" - if not self.df2_unq_columns() == set(): - return False - elif not len(self.df2_unq_rows) == 0: - return False - elif not self.intersect_rows_match(): - return False - else: - return True + return ( + self.df2_unq_columns() == set() + and len(self.df2_unq_rows) == 0 + and self.intersect_rows_match() + ) def sample_mismatch( self, column: str, sample_count: int = 10, for_display: bool = False ) -> pd.DataFrame: - """Returns a sample sub-dataframe which contains the identifying + """Return sample mismatches. + + Gets a sub-dataframe which contains the identifying columns, and df1 and df2 versions of the column. Parameters @@ -489,15 +491,16 @@ def sample_mismatch( match_cnt = col_match.sum() sample_count = min(sample_count, row_cnt - match_cnt) sample = self.intersect_rows[~col_match].sample(sample_count) - return_cols = self.join_columns + [ + return_cols = [ + *self.join_columns, column + "_" + self.df1_name, column + "_" + self.df2_name, ] to_return = sample[return_cols] if for_display: to_return.columns = pd.Index( - self.join_columns - + [ + [ + *self.join_columns, column + " (" + self.df1_name + ")", column + " (" + self.df2_name + ")", ] @@ -505,7 +508,9 @@ def sample_mismatch( return to_return def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame: - """All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join + """Get all rows with any columns that have a mismatch. + + Returns all df1 and df2 versions of the columns and join columns. Parameters @@ -558,7 +563,9 @@ def report( column_count: int = 10, html_file: Optional[str] = None, ) -> str: - """Returns a string representation of a report. The representation can + """Return a string representation of a report. + + The representation can then be printed or saved to a file. Parameters @@ -630,7 +637,7 @@ def df_to_str(pdf: pd.DataFrame) -> str: "column_comparison.txt", len([col for col in self.column_stats if col["unequal_cnt"] > 0]), len([col for col in self.column_stats if col["unequal_cnt"] == 0]), - sum([col["unequal_cnt"] for col in self.column_stats]), + sum(col["unequal_cnt"] for col in self.column_stats), ) match_stats = [] @@ -719,7 +726,9 @@ def df_to_str(pdf: pd.DataFrame) -> str: def render(filename: str, *fields: Union[int, float, str]) -> str: - """Renders out an individual template. This basically just reads in a + """Render out an individual template. + + This basically just reads in a template file, and applies ``.format()`` on the fields. Parameters @@ -748,7 +757,9 @@ def columns_equal( ignore_spaces: bool = False, ignore_case: bool = False, ) -> "pd.Series[bool]": - """Compares two columns from a dataframe, returning a True/False series, + """Compare two columns from a dataframe. + + Returns a True/False series, with the same index as column 1. - Two nulls (np.nan) will evaluate to True. @@ -826,7 +837,9 @@ def columns_equal( def compare_string_and_date_columns( col_1: "pd.Series[Any]", col_2: "pd.Series[Any]" ) -> "pd.Series[bool]": - """Compare a string column and date column, value-wise. This tries to + """Compare a string column and date column, value-wise. + + This tries to convert a string column to a date column and compare that way. 
Parameters @@ -867,7 +880,7 @@ def compare_string_and_date_columns( def get_merged_columns( original_df: pd.DataFrame, merged_df: pd.DataFrame, suffix: str ) -> List[str]: - """Gets the columns from an original dataframe, in the new merged dataframe + """Get the columns from an original dataframe, in the new merged dataframe. Parameters ---------- @@ -891,7 +904,7 @@ def get_merged_columns( def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> float: - """Get a maximum difference between two columns + """Get a maximum difference between two columns. Parameters ---------- @@ -914,7 +927,9 @@ def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> floa def generate_id_within_group( dataframe: pd.DataFrame, join_columns: List[str] ) -> "pd.Series[int]": - """Generate an ID column that can be used to deduplicate identical rows. The series generated + """Generate an ID column that can be used to deduplicate identical rows. + + The series generated is the order within a unique group, and it handles nulls. Parameters diff --git a/datacompy/fugue.py b/datacompy/fugue.py index 8bc01d33..f2983c49 100644 --- a/datacompy/fugue.py +++ b/datacompy/fugue.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Compare two DataFrames that are supported by Fugue -""" +"""Compare two DataFrames that are supported by Fugue.""" import logging import pickle @@ -29,14 +27,14 @@ from ordered_set import OrderedSet from triad import Schema -from .core import Compare, render +from datacompy.core import Compare, render LOG = logging.getLogger(__name__) HASH_COL = "__datacompy__hash__" def unq_columns(df1: AnyDataFrame, df2: AnyDataFrame) -> OrderedSet[str]: - """Get columns that are unique to df1 + """Get columns that are unique to df1. Parameters ---------- @@ -57,7 +55,7 @@ def unq_columns(df1: AnyDataFrame, df2: AnyDataFrame) -> OrderedSet[str]: def intersect_columns(df1: AnyDataFrame, df2: AnyDataFrame) -> OrderedSet[str]: - """Get columns that are shared between the two dataframes + """Get columns that are shared between the two dataframes. Parameters ---------- @@ -78,7 +76,7 @@ def intersect_columns(df1: AnyDataFrame, df2: AnyDataFrame) -> OrderedSet[str]: def all_columns_match(df1: AnyDataFrame, df2: AnyDataFrame) -> bool: - """Whether the columns all match in the dataframes + """Whether the columns all match in the dataframes. Parameters ---------- @@ -209,7 +207,7 @@ def all_rows_overlap( parallelism: Optional[int] = None, strict_schema: bool = False, ) -> bool: - """Check if the rows are all present in both dataframes + """Check if the rows are all present in both dataframes. Parameters ---------- @@ -305,7 +303,7 @@ def count_matching_rows( parallelism: Optional[int] = None, strict_schema: bool = False, ) -> int: - """Count the number of rows match (on overlapping fields) + """Count the number of rows match (on overlapping fields). Parameters ---------- @@ -402,7 +400,9 @@ def report( html_file: Optional[str] = None, parallelism: Optional[int] = None, ) -> str: - """Returns a string representation of a report. The representation can + """Return a string representation of a report. + + The representation can then be printed or saved to a file. 
Both df1 and df2 should be dataframes containing all of the join_columns, @@ -559,7 +559,7 @@ def _any(col: str) -> int: "column_comparison.txt", len([col for col in column_stats if col["unequal_cnt"] > 0]), len([col for col in column_stats if col["unequal_cnt"] == 0]), - sum([col["unequal_cnt"] for col in column_stats]), + sum(col["unequal_cnt"] for col in column_stats), ) match_stats = [] @@ -652,7 +652,7 @@ def _distributed_compare( parallelism: Optional[int] = None, strict_schema: bool = False, ) -> List[Any]: - """Compare the data distributively using the core Compare class + """Compare the data distributively using the core Compare class. Both df1 and df2 should be dataframes containing all of the join_columns, with unique column names. Differences between values are compared to @@ -698,7 +698,6 @@ def _distributed_compare( List[Any] Returns the list of objects returned from the return_obj_func """ - tdf1 = fa.as_fugue_df(df1) tdf2 = fa.as_fugue_df(df2) @@ -716,9 +715,8 @@ def _distributed_compare( ) hash_cols = [col.lower() for col in hash_cols] - if strict_schema: - if tdf1.schema != tdf2.schema: - raise _StrictSchemaError() + if strict_schema and tdf1.schema != tdf2.schema: + raise _StrictSchemaError() # check that hash columns exist assert hash_cols in tdf1.schema, f"{hash_cols} not found in {tdf1.schema}" @@ -726,7 +724,7 @@ def _distributed_compare( df1_schema = tdf1.schema df2_schema = tdf2.schema - str_cols = set(f.name for f in tdf1.schema.fields if pa.types.is_string(f.type)) + str_cols = {f.name for f in tdf1.schema.fields if pa.types.is_string(f.type)} bucket = ( parallelism if parallelism is not None else fa.get_current_parallelism() * 2 ) @@ -753,13 +751,13 @@ def _serialize(dfs: Iterable[pd.DataFrame], left: bool) -> Iterable[Dict[str, An tdf1, _serialize, schema="key:int,left:bool,data:binary", - params=dict(left=True), + params={"left": True}, ), fa.transform( tdf2, _serialize, schema="key:int,left:bool,data:binary", - params=dict(left=False), + params={"left": False}, ), distinct=False, ) @@ -814,7 +812,7 @@ def _comp(df: List[Dict[str, Any]]) -> List[List[Any]]: objs = fa.as_array( fa.transform( - ser, _comp, schema="obj:binary", partition=dict(by="key", num=bucket) + ser, _comp, schema="obj:binary", partition={"by": "key", "num": bucket} ) ) return [pickle.loads(row[0]) for row in objs] @@ -825,11 +823,10 @@ def _get_compare_result( ) -> Dict[str, Any]: mismatch_samples: Dict[str, pd.DataFrame] = {} for column in compare.column_stats: - if not column["all_match"]: - if column["unequal_cnt"] > 0: - mismatch_samples[column["column"]] = compare.sample_mismatch( - column["column"], sample_count, for_display=True - ) + if not column["all_match"] and column["unequal_cnt"] > 0: + mismatch_samples[column["column"]] = compare.sample_mismatch( + column["column"], sample_count, for_display=True + ) df1_unq_rows_sample: Any = None if min(sample_count, compare.df1_unq_rows.shape[0]) > 0: @@ -916,6 +913,6 @@ def _sample(df: pd.DataFrame, sample_count: int) -> pd.DataFrame: class _StrictSchemaError(Exception): - """Exception raised when strict schema is enabled and the schemas do not match""" + """Exception raised when strict schema is enabled and the schemas do not match.""" pass diff --git a/datacompy/polars.py b/datacompy/polars.py index 1c0f8d56..c9758548 100644 --- a/datacompy/polars.py +++ b/datacompy/polars.py @@ -14,7 +14,7 @@ # limitations under the License. """ -Compare two Polars DataFrames +Compare two Polars DataFrames. 
Originally this package was meant to provide similar functionality to PROC COMPARE in SAS - i.e. human-readable reporting on the difference between @@ -29,7 +29,7 @@ import numpy as np from ordered_set import OrderedSet -from .base import BaseCompare, temp_column_name +from datacompy.base import BaseCompare, temp_column_name try: import polars as pl @@ -123,19 +123,20 @@ def __init__( self.rel_tol = rel_tol self.ignore_spaces = ignore_spaces self.ignore_case = ignore_case - self.df1_unq_rows: "pl.DataFrame" - self.df2_unq_rows: "pl.DataFrame" - self.intersect_rows: "pl.DataFrame" + self.df1_unq_rows: pl.DataFrame + self.df2_unq_rows: pl.DataFrame + self.intersect_rows: pl.DataFrame self.column_stats: List[Dict[str, Any]] = [] self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case) @property def df1(self) -> "pl.DataFrame": + """Get the first dataframe.""" return self._df1 @df1.setter def df1(self, df1: "pl.DataFrame") -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" self._df1 = df1 self._validate_dataframe( "df1", cast_column_names_lower=self.cast_column_names_lower @@ -143,11 +144,12 @@ def df1(self, df1: "pl.DataFrame") -> None: @property def df2(self) -> "pl.DataFrame": + """Get the second dataframe.""" return self._df2 @df2.setter def df2(self, df2: "pl.DataFrame") -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" self._df2 = df2 self._validate_dataframe( "df2", cast_column_names_lower=self.cast_column_names_lower @@ -156,7 +158,7 @@ def df2(self, df2: "pl.DataFrame") -> None: def _validate_dataframe( self, index: str, cast_column_names_lower: bool = True ) -> None: - """Check that it is a dataframe and has the join columns + """Check that it is a dataframe and has the join columns. Parameters ---------- @@ -183,7 +185,9 @@ def _validate_dataframe( self._any_dupes = True def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: - """Actually run the comparison. This tries to run df1.equals(df2) + """Run the comparison. + + This tries to run df1.equals(df2) first so that if they're truly equal we can tell. This method will log out information about what is different between @@ -215,24 +219,26 @@ def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: LOG.info("df1 does not match df2") def df1_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df1""" + """Get columns that are unique to df1.""" return cast( OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns) ) def df2_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df2""" + """Get columns that are unique to df2.""" return cast( OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns) ) def intersect_columns(self) -> OrderedSet[str]: - """Get columns that are shared between the two dataframes""" + """Get columns that are shared between the two dataframes.""" return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns) def _dataframe_merge(self, ignore_spaces: bool) -> None: - """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1 - and df1 & df2 + """Merge df1 to df2 on the join columns. + + To get df1 - df2, df2 - df1 + and df1 & df2. 
""" params: Dict[str, Any] LOG.debug("Outer joining") @@ -278,20 +284,11 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: # process merge indicator outer_join = outer_join.with_columns( - pl.when( - (pl.col("_merge_left") == True) - & (pl.col("_merge_right") == True) # noqa: E712 - ) + pl.when(pl.col("_merge_left") & pl.col("_merge_right")) .then(pl.lit("both")) - .when( - (pl.col("_merge_left") == True) - & (pl.col("_merge_right").is_null()) # noqa: E712 - ) + .when(pl.col("_merge_left") & pl.col("_merge_right").is_null()) .then(pl.lit("left_only")) - .when( - (pl.col("_merge_left").is_null()) - & (pl.col("_merge_right") == True) # noqa: E712 - ) + .when(pl.col("_merge_left").is_null() & pl.col("_merge_right")) .then(pl.lit("right_only")) .alias("_merge") ) @@ -325,7 +322,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: ) def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: - """Run the comparison on the intersect dataframe + """Run the comparison on the intersect dataframe. This loops through all columns that are shared between df1 and df2, and creates a column column_match which is True for matches, False @@ -390,11 +387,11 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: ) def all_columns_match(self) -> bool: - """Whether the columns all match in the dataframes""" + """Whether the columns all match in the dataframes.""" return self.df1_unq_columns() == self.df2_unq_columns() == set() def all_rows_overlap(self) -> bool: - """Whether the rows are all present in both dataframes + """Whether the rows are all present in both dataframes. Returns ------- @@ -405,7 +402,7 @@ def all_rows_overlap(self) -> bool: return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0 def count_matching_rows(self) -> int: - """Count the number of rows match (on overlapping fields) + """Count the number of rows match (on overlapping fields). Returns ------- @@ -432,7 +429,7 @@ def count_matching_rows(self) -> int: return 0 def intersect_rows_match(self) -> bool: - """Check whether the intersect rows all match""" + """Check whether the intersect rows all match.""" actual_length = self.intersect_rows.shape[0] return self.count_matching_rows() == actual_length @@ -449,14 +446,11 @@ def matches(self, ignore_extra_columns: bool = False) -> bool: bool True or False if the dataframes match. """ - if not ignore_extra_columns and not self.all_columns_match(): - return False - elif not self.all_rows_overlap(): - return False - elif not self.intersect_rows_match(): - return False - else: - return True + return ( + (ignore_extra_columns or self.all_columns_match()) + and self.all_rows_overlap() + and self.intersect_rows_match() + ) def subset(self) -> bool: """Return True if dataframe 2 is a subset of dataframe 1. @@ -470,19 +464,18 @@ def subset(self) -> bool: bool True if dataframe 2 is a subset of dataframe 1. """ - if not self.df2_unq_columns() == set(): - return False - elif not len(self.df2_unq_rows) == 0: - return False - elif not self.intersect_rows_match(): - return False - else: - return True + return ( + self.df2_unq_columns() == set() + and len(self.df2_unq_rows) == 0 + and self.intersect_rows_match() + ) def sample_mismatch( self, column: str, sample_count: int = 10, for_display: bool = False ) -> "pl.DataFrame": - """Returns a sample sub-dataframe which contains the identifying + """Return sample mismatches. + + Get a sub-dataframe which contains the identifying columns, and df1 and df2 versions of the column. 
Parameters @@ -509,20 +502,24 @@ def sample_mismatch( sample = self.intersect_rows.filter( pl.col(column + "_match") != True # noqa: E712 ).sample(sample_count) - return_cols = self.join_columns + [ + return_cols = [ + *self.join_columns, column + "_" + self.df1_name, column + "_" + self.df2_name, ] to_return = sample[return_cols] if for_display: - to_return.columns = self.join_columns + [ + to_return.columns = [ + *self.join_columns, column + " (" + self.df1_name + ")", column + " (" + self.df2_name + ")", ] return to_return def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame": - """All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join + """Get all rows with any columns that have a mismatch. + + Returns all df1 and df2 versions of the columns and join columns. Parameters @@ -577,7 +574,9 @@ def report( column_count: int = 10, html_file: Optional[str] = None, ) -> str: - """Returns a string representation of a report. The representation can + """Return a string representation of a report. + + The representation can then be printed or saved to a file. Parameters @@ -644,7 +643,7 @@ def df_to_str(pdf: "pl.DataFrame") -> str: "column_comparison.txt", len([col for col in self.column_stats if col["unequal_cnt"] > 0]), len([col for col in self.column_stats if col["unequal_cnt"] == 0]), - sum([col["unequal_cnt"] for col in self.column_stats]), + sum(col["unequal_cnt"] for col in self.column_stats), ) match_stats = [] @@ -737,7 +736,9 @@ def df_to_str(pdf: "pl.DataFrame") -> str: def render(filename: str, *fields: Union[int, float, str]) -> str: - """Renders out an individual template. This basically just reads in a + """Render out an individual template. + + This basically just reads in a template file, and applies ``.format()`` on the fields. Parameters @@ -766,7 +767,9 @@ def columns_equal( ignore_spaces: bool = False, ignore_case: bool = False, ) -> "pl.Series": - """Compares two columns from a dataframe, returning a True/False series, + """Compare two columns from a dataframe. + + Returns a True/False series, with the same index as column 1. - Two nulls (np.nan) will evaluate to True. @@ -850,7 +853,9 @@ def columns_equal( def compare_string_and_date_columns( col_1: "pl.Series", col_2: "pl.Series" ) -> "pl.Series": - """Compare a string column and date column, value-wise. This tries to + """Compare a string column and date column, value-wise. + + This tries to convert a string column to a date column and compare that way. Parameters @@ -885,7 +890,7 @@ def compare_string_and_date_columns( def get_merged_columns( original_df: "pl.DataFrame", merged_df: "pl.DataFrame", suffix: str ) -> List[str]: - """Gets the columns from an original dataframe, in the new merged dataframe + """Get the columns from an original dataframe, in the new merged dataframe. Parameters ---------- @@ -909,7 +914,7 @@ def get_merged_columns( def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float: - """Get a maximum difference between two columns + """Get a maximum difference between two columns. Parameters ---------- @@ -934,7 +939,9 @@ def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float: def generate_id_within_group( dataframe: "pl.DataFrame", join_columns: List[str] ) -> "pl.Series": - """Generate an ID column that can be used to deduplicate identical rows. The series generated + """Generate an ID column that can be used to deduplicate identical rows. 
+ + The series generated is the order within a unique group, and it handles nulls. Parameters diff --git a/datacompy/spark/__init__.py b/datacompy/spark/__init__.py index e69de29b..2ecc07fa 100644 --- a/datacompy/spark/__init__.py +++ b/datacompy/spark/__init__.py @@ -0,0 +1 @@ +"""Spark comparisons.""" diff --git a/datacompy/spark/legacy.py b/datacompy/spark/legacy.py index b23b9cb2..3555b8b0 100644 --- a/datacompy/spark/legacy.py +++ b/datacompy/spark/legacy.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Legacy spark comparison.""" import sys from enum import Enum @@ -34,11 +35,17 @@ class MatchType(Enum): + """Types of matches.""" + MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3) -# Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal". def decimal_comparator() -> str: + """Check equality with decimal(X, Y) types. + + Otherwise treated as the string "decimal". + """ + class DecimalComparator(str): def __eq__(self, other: str) -> bool: # type: ignore[override] return len(other) >= 7 and other[0:7] == "decimal" @@ -58,10 +65,11 @@ def __eq__(self, other: str) -> bool: # type: ignore[override] def _is_comparable(type1: str, type2: str) -> bool: - """Checks if two Spark data types can be safely compared. + """Check if two Spark data types can be safely compared. + Two data types are considered comparable if any of the following apply: 1. Both data types are the same - 2. Both data types are numeric + 2. Both data types are numeric. Parameters ---------- @@ -75,7 +83,6 @@ def _is_comparable(type1: str, type2: str) -> bool: bool True if both data types are comparable """ - return type1 == type2 or ( type1 in NUMERIC_SPARK_TYPES and type2 in NUMERIC_SPARK_TYPES ) @@ -194,11 +201,11 @@ def __init__( self._base_row_count: Optional[int] = None self._compare_row_count: Optional[int] = None self._common_row_count: Optional[int] = None - self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None - self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None - self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None - self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None - self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None + self._joined_dataframe: Optional[pyspark.sql.DataFrame] = None + self._rows_only_base: Optional[pyspark.sql.DataFrame] = None + self._rows_only_compare: Optional[pyspark.sql.DataFrame] = None + self._all_matched_rows: Optional[pyspark.sql.DataFrame] = None + self._all_rows_mismatched: Optional[pyspark.sql.DataFrame] = None self.columns_match_dict: Dict[str, Any] = {} # drop the duplicates before actual comparison made. @@ -225,13 +232,15 @@ def _tuplizer( @property def columns_in_both(self) -> Set[str]: - """set[str]: Get columns in both dataframes""" + """set[str]: Get columns in both dataframes.""" return set(self.base_df.columns) & set(self.compare_df.columns) @property def columns_compared(self) -> List[str]: - """list[str]: Get columns to be compared in both dataframes (all - columns in both excluding the join key(s)""" + """Get columns to be compared in both dataframes. + + All columns in both excluding the join key(s). 
+        """
         return [
             column
             for column in list(self.columns_in_both)
@@ -240,17 +249,17 @@ def columns_compared(self) -> List[str]:
 
     @property
     def columns_only_base(self) -> Set[str]:
-        """set[str]: Get columns that are unique to the base dataframe"""
+        """set[str]: Get columns that are unique to the base dataframe."""
         return set(self.base_df.columns) - set(self.compare_df.columns)
 
     @property
     def columns_only_compare(self) -> Set[str]:
-        """set[str]: Get columns that are unique to the compare dataframe"""
+        """set[str]: Get columns that are unique to the compare dataframe."""
         return set(self.compare_df.columns) - set(self.base_df.columns)
 
     @property
     def base_row_count(self) -> int:
-        """int: Get the count of rows in the de-duped base dataframe"""
+        """int: Get the count of rows in the de-duped base dataframe."""
         if self._base_row_count is None:
             self._base_row_count = self.base_df.count()
 
@@ -258,7 +267,7 @@ def base_row_count(self) -> int:
 
     @property
     def compare_row_count(self) -> int:
-        """int: Get the count of rows in the de-duped compare dataframe"""
+        """int: Get the count of rows in the de-duped compare dataframe."""
         if self._compare_row_count is None:
             self._compare_row_count = self.compare_df.count()
 
@@ -266,7 +275,7 @@ def compare_row_count(self) -> int:
 
     @property
     def common_row_count(self) -> int:
-        """int: Get the count of rows in common between base and compare dataframes"""
+        """int: Get the count of rows in common between base and compare dataframes."""
         if self._common_row_count is None:
             common_rows = self._get_or_create_joined_dataframe()
             self._common_row_count = common_rows.count()
 
@@ -274,19 +283,19 @@ def common_row_count(self) -> int:
         return self._common_row_count
 
     def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame":
-        """Get the rows only from base data frame"""
+        """Get the rows only from base data frame."""
         return self.base_df.select(self._join_column_names).subtract(
             self.compare_df.select(self._join_column_names)
         )
 
     def _get_compare_rows(self) -> "pyspark.sql.DataFrame":
-        """Get the rows only from compare data frame"""
+        """Get the rows only from compare data frame."""
         return self.compare_df.select(self._join_column_names).subtract(
             self.base_df.select(self._join_column_names)
         )
 
     def _print_columns_summary(self, myfile: TextIO) -> None:
-        """Prints the column summary details"""
+        """Print the column summary details."""
         print("\n****** Column Summary ******", file=myfile)
         print(
             f"Number of columns in common with matching schemas: {len(self._columns_with_matching_schema())}",
@@ -306,8 +315,7 @@ def _print_columns_summary(self, myfile: TextIO) -> None:
         )
 
     def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None:
-        """Prints the columns and data types only in either the base or compare datasets"""
-
+        """Print the columns and data types only in either the base or compare datasets."""
         if base_or_compare.upper() == "BASE":
             columns = self.columns_only_base
             df = self.base_df
@@ -335,7 +343,7 @@ def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None:
             print((format_pattern + " {:13s}").format(column, col_type), file=myfile)
 
     def _columns_with_matching_schema(self) -> Dict[str, str]:
-        """This function will identify the columns which has matching schema"""
+        """Identify the columns which have matching schemas."""
         col_schema_match = {}
         base_columns_dict = dict(self.base_df.dtypes)
         compare_columns_dict = dict(self.compare_df.dtypes)
 
@@ -349,7 +357,7 @@ def _columns_with_matching_schema(self) -> Dict[str, str]:
 
         return col_schema_match
 
     def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]:
-        """This function will identify the columns which has different schema"""
+        """Identify the columns which have different schemas."""
         col_schema_diff = {}
         base_columns_dict = dict(self.base_df.dtypes)
         compare_columns_dict = dict(self.compare_df.dtypes)
@@ -361,15 +369,15 @@ def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]:
                 compare_column_type is not None
                 and base_type not in compare_column_type
             ):
-                col_schema_diff[base_row] = dict(
-                    base_type=base_type,
-                    compare_type=compare_column_type,
-                )
+                col_schema_diff[base_row] = {
+                    "base_type": base_type,
+                    "compare_type": compare_column_type,
+                }
         return col_schema_diff
 
     @property
     def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]:
-        """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches"""
+        """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches."""
         if self._all_rows_mismatched is None:
             self._merge_dataframes()
 
@@ -377,7 +385,7 @@ def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]:
 
     @property
     def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]:
-        """pyspark.sql.DataFrame: Returns all rows in both dataframes"""
+        """pyspark.sql.DataFrame: Returns all rows in both dataframes."""
         if self._all_matched_rows is None:
             self._merge_dataframes()
 
@@ -385,7 +393,7 @@ def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]:
 
     @property
     def rows_only_base(self) -> "pyspark.sql.DataFrame":
-        """pyspark.sql.DataFrame: Returns rows only in the base dataframe"""
+        """pyspark.sql.DataFrame: Returns rows only in the base dataframe."""
         if not self._rows_only_base:
             base_rows = self._get_unq_base_rows()
             base_rows.createOrReplaceTempView("baseRows")
@@ -396,8 +404,8 @@ def rows_only_base(self) -> "pyspark.sql.DataFrame":
                 for name in self._join_column_names
             ]
         )
-        sql_query = "select A.* from baseTable as A, baseRows as B where {}".format(
-            join_condition
-        )
+        sql_query = (
+            f"select A.* from baseTable as A, baseRows as B where {join_condition}"
+        )
         self._rows_only_base = self.spark.sql(sql_query)
 
@@ -408,7 +416,7 @@ def rows_only_base(self) -> "pyspark.sql.DataFrame":
 
     @property
     def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
-        """pyspark.sql.DataFrame: Returns rows only in the compare dataframe"""
+        """pyspark.sql.DataFrame: Returns rows only in the compare dataframe."""
         if not self._rows_only_compare:
             compare_rows = self._get_compare_rows()
             compare_rows.createOrReplaceTempView("compareRows")
@@ -419,11 +427,7 @@ def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
                 for name in self._join_column_names
             ]
         )
-        sql_query = (
-            "select A.* from compareTable as A, compareRows as B where {}".format(
-                where_condition
-            )
-        )
+        sql_query = f"select A.* from compareTable as A, compareRows as B where {where_condition}"
         self._rows_only_compare = self.spark.sql(sql_query)
 
         if self.cache_intermediates:
@@ -432,10 +436,10 @@ def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
         return self._rows_only_compare
 
     def _generate_select_statement(self, match_data: bool = True) -> str:
-        """This function is to generate the select statement to be used later in the query."""
+        """Generate the select statement to be used later in the query."""
        base_only = list(set(self.base_df.columns) - set(self.compare_df.columns))
         compare_only = list(set(self.compare_df.columns) - set(self.base_df.columns))
-        sorted_list = sorted(list(chain(base_only, compare_only, self.columns_in_both)))
+        sorted_list 
= sorted(chain(base_only, compare_only, self.columns_in_both)) select_statement = "" for column_name in sorted_list: @@ -473,14 +477,12 @@ def _generate_select_statement(self, match_data: bool = True) -> str: return select_statement def _merge_dataframes(self) -> None: - """Merges the two dataframes and creates self._all_matched_rows and self._all_rows_mismatched.""" + """Merge the two dataframes and creates self._all_matched_rows and self._all_rows_mismatched.""" full_joined_dataframe = self._get_or_create_joined_dataframe() full_joined_dataframe.createOrReplaceTempView("full_matched_table") select_statement = self._generate_select_statement(False) - select_query = """SELECT {} FROM full_matched_table A""".format( - select_statement - ) + select_query = f"""SELECT {select_statement} FROM full_matched_table A""" self._all_matched_rows = self.spark.sql(select_query).orderBy( self._join_column_names # type: ignore[arg-type] ) @@ -489,7 +491,7 @@ def _merge_dataframes(self) -> None: where_cond = " OR ".join( ["A.`" + name + "_match`= False" for name in self.columns_compared] ) - mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond) + mismatch_query = f"""SELECT * FROM matched_table A WHERE {where_cond}""" self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy( self._join_column_names # type: ignore[arg-type] ) @@ -507,13 +509,11 @@ def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame": self.base_df.createOrReplaceTempView("base_table") self.compare_df.createOrReplaceTempView("compare_table") - join_query = r""" - SELECT {} + join_query = rf""" + SELECT {select_statement} FROM base_table A JOIN compare_table B - ON {}""".format( - select_statement, join_condition - ) + ON {join_condition}""" self._joined_dataframe = self.spark.sql(join_query) if self.cache_intermediates: @@ -536,9 +536,7 @@ def _print_num_of_rows_with_column_equality(self, myfile: TextIO) -> None: ] ) match_query = ( - r"""SELECT count(*) AS row_count FROM matched_df A WHERE {}""".format( - where_cond - ) + rf"""SELECT count(*) AS row_count FROM matched_df A WHERE {where_cond}""" ) all_rows_matched = self.spark.sql(match_query) all_rows_matched_head = all_rows_matched.head() @@ -554,16 +552,16 @@ def _print_num_of_rows_with_column_equality(self, myfile: TextIO) -> None: print(f"Number of rows with all columns equal: {matched_rows}", file=myfile) def _populate_columns_match_dict(self) -> None: - """ + """Populate the dictionary of matches in a dataframe. + side effects: columns_match_dict assigned to { column -> match_type_counts } where: column (string): Name of a column that exists in both the base and comparison columns - match_type_counts (list of int with size = len(MatchType)): The number of each match type seen for this column (in order of the MatchType enum values) + match_type_counts (list of int with size = len(MatchType)): The number of each match type seen for this column (in order of the MatchType enum values). 
returns: None """ - match_dataframe = self._get_or_create_joined_dataframe().select( *self.columns_compared ) @@ -591,56 +589,48 @@ def _create_select_statement(self, name: str) -> str: match_type_comparison = "" for k in MatchType: match_type_comparison += ( - " WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format( - name=name, match_value=str(k.value), match_name=k.name - ) + f" WHEN (A.`{name}`={k.value!s}) THEN '{k.name}'" ) - return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format( - name=name, - match_failure=MatchType.MISMATCH.value, - match_type_comparison=match_type_comparison, - ) + return f"A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={MatchType.MISMATCH.value}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` " else: - return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format( - name=name, match_failure=MatchType.MISMATCH.value - ) + return f"A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={MatchType.MISMATCH.value}) THEN False ELSE True END AS `{name}_match` " def _create_case_statement(self, name: str) -> str: - equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"] + equal_comparisons = [f"(A.`{name}` IS NULL AND B.`{name}` IS NULL)"] known_diff_comparisons = ["(FALSE)"] - base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0] - compare_dtype = [d[1] for d in self.compare_df.dtypes if d[0] == name][0] + base_dtype = next(d[1] for d in self.base_df.dtypes if d[0] == name) + compare_dtype = next(d[1] for d in self.compare_df.dtypes if d[0] == name) if _is_comparable(base_dtype, compare_dtype): if (base_dtype in NUMERIC_SPARK_TYPES) and ( compare_dtype in NUMERIC_SPARK_TYPES ): # numeric tolerance comparison equal_comparisons.append( - "((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=(" + f"((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=(" + str(self.abs_tol) + "+(" + str(self.rel_tol) - + "*abs(A.`{name}`)))))" + + f"*abs(A.`{name}`)))))" ) else: # non-numeric comparison - equal_comparisons.append("((A.`{name}`=B.`{name}`))") + equal_comparisons.append(f"((A.`{name}`=B.`{name}`))") if self._known_differences: - new_input = "B.`{name}`" + new_input = f"B.`{name}`" for kd in self._known_differences: if compare_dtype in kd["types"]: if "flags" in kd and "nullcheck" in kd["flags"]: known_diff_comparisons.append( "((" + kd["transformation"].format(new_input, input=new_input) - + ") is null AND A.`{name}` is null)" + + f") is null AND A.`{name}` is null)" ) else: known_diff_comparisons.append( "((" + kd["transformation"].format(new_input, input=new_input) - + ") = A.`{name}`)" + + f") = A.`{name}`)" ) case_string = ( @@ -649,7 +639,7 @@ def _create_case_statement(self, name: str) -> str: + ") THEN {match_success} WHEN (" + " OR ".join(known_diff_comparisons) + ") THEN {match_known_difference} ELSE {match_failure} END) " - + "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`" + + f"AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`" ) return case_string.format( @@ -696,9 +686,7 @@ def _print_schema_diff_details(self, myfile: TextIO) -> None: [len(self._base_to_compare_name(key)) for key in schema_diff_dict] + [19] ) - format_pattern = "{{:{base}s}} 
{{:{compare}s}}".format( - base=base_name_max, compare=compare_name_max - ) + format_pattern = f"{{:{base_name_max}s}} {{:{compare_name_max}s}}" print("\n****** Schema Differences ******", file=myfile) print( @@ -729,9 +717,11 @@ def _print_schema_diff_details(self, myfile: TextIO) -> None: ) def _base_to_compare_name(self, base_name: str) -> str: - """Translates a column name in the base dataframe to its counterpart in the - compare dataframe, if they are different.""" + """Translate a column name. + Translates a column in the base dataframe to its counterpart in the + compare dataframe if they are different. + """ if base_name in self.column_mapping: return self.column_mapping[base_name] else: @@ -895,15 +885,16 @@ def _print_row_matches_by_column(self, myfile: TextIO) -> None: / self.common_row_count + 0.0 ) - output_row.append("{:02.5f}".format(match_rate)) + output_row.append(f"{match_rate:02.5f}") if num_known_diffs is not None: output_row.insert(len(output_row) - 1, str(num_known_diffs)) print(format_pattern.format(*output_row), file=myfile) # noinspection PyUnresolvedReferences def report(self, file: TextIO = sys.stdout) -> None: - """Creates a comparison report and prints it to the file specified - (stdout by default). + """Create a comparison report and print it to the file specified. + + Prints to stdout by default. Parameters ---------- @@ -917,7 +908,6 @@ def report(self, file: TextIO = sys.stdout) -> None: >>> with open('my_report.txt', 'w') as report_file: ... comparison.report(file=report_file) """ - self._print_columns_summary(file) self._print_schema_diff_details(file) self._print_only_columns("BASE", file) diff --git a/datacompy/spark/pandas.py b/datacompy/spark/pandas.py index 09c983cf..c946395d 100644 --- a/datacompy/spark/pandas.py +++ b/datacompy/spark/pandas.py @@ -14,7 +14,7 @@ # limitations under the License. """ -Compare two Pandas on Spark DataFrames +Compare two Pandas on Spark DataFrames. Originally this package was meant to provide similar functionality to PROC COMPARE in SAS - i.e. human-readable reporting on the difference between @@ -28,7 +28,7 @@ import pandas as pd from ordered_set import OrderedSet -from ..base import BaseCompare, temp_column_name +from datacompy.base import BaseCompare, temp_column_name try: import pyspark.pandas as ps @@ -131,11 +131,12 @@ def __init__( @property def df1(self) -> "ps.DataFrame": + "Get the first dataframe." return self._df1 @df1.setter def df1(self, df1: "ps.DataFrame") -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" self._df1 = df1 self._validate_dataframe( "df1", cast_column_names_lower=self.cast_column_names_lower @@ -143,11 +144,12 @@ def df1(self, df1: "ps.DataFrame") -> None: @property def df2(self) -> "ps.DataFrame": + """Get the second dataframe.""" return self._df2 @df2.setter def df2(self, df2: "ps.DataFrame") -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" self._df2 = df2 self._validate_dataframe( "df2", cast_column_names_lower=self.cast_column_names_lower @@ -156,7 +158,7 @@ def df2(self, df2: "ps.DataFrame") -> None: def _validate_dataframe( self, index: str, cast_column_names_lower: bool = True ) -> None: - """Check that it is a dataframe and has the join columns + """Check that it is a dataframe and has the join columns. 
Parameters ---------- @@ -188,7 +190,9 @@ def _validate_dataframe( self._any_dupes = True def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: - """Actually run the comparison. This tries to run df1.equals(df2) + """Run the comparison. + + This tries to run df1.equals(df2) first so that if they're truly equal we can tell. This method will log out information about what is different between @@ -224,22 +228,22 @@ def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: LOG.info("df1 does not match df2") def df1_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df1""" + """Get columns that are unique to df1.""" return OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns) def df2_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df2""" + """Get columns that are unique to df2.""" return OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns) def intersect_columns(self) -> OrderedSet[str]: - """Get columns that are shared between the two dataframes""" + """Get columns that are shared between the two dataframes.""" return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns) def _dataframe_merge(self, ignore_spaces: bool) -> None: - """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1 - and df1 & df2 - """ + """Merge df1 to df2 on the join columns. + Get df1 - df2, df2 - df1 and df1 & df2. + """ LOG.debug("Outer joining") df1 = self.df1.copy() @@ -297,7 +301,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: """ SELECT * FROM {df1} df1 FULL OUTER JOIN {df2} df2 - ON + ON """ + on, df1=df1, @@ -372,7 +376,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: self.intersect_rows.spark.cache() def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: - """Run the comparison on the intersect dataframe + """Run the comparison on the intersect dataframe. This loops through all columns that are shared between df1 and df2, and creates a column column_match which is True for matches, False @@ -440,11 +444,11 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: ) def all_columns_match(self) -> bool: - """Whether the columns all match in the dataframes""" + """Whether the columns all match in the dataframes.""" return self.df1_unq_columns() == self.df2_unq_columns() == set() def all_rows_overlap(self) -> bool: - """Whether the rows are all present in both dataframes + """Whether the rows are all present in both dataframes. Returns ------- @@ -455,7 +459,7 @@ def all_rows_overlap(self) -> bool: return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0 def count_matching_rows(self) -> bool: - """Count the number of rows match (on overlapping fields) + """Count the number of rows match (on overlapping fields). Returns ------- @@ -479,7 +483,7 @@ def count_matching_rows(self) -> bool: return match_columns_count def intersect_rows_match(self) -> bool: - """Check whether the intersect rows all match""" + """Check whether the intersect rows all match.""" actual_length = self.intersect_rows.shape[0] return self.count_matching_rows() == actual_length @@ -491,14 +495,11 @@ def matches(self, ignore_extra_columns: bool = False) -> bool: ignore_extra_columns : bool Ignores any columns in one dataframe and not in the other. 
""" - if not ignore_extra_columns and not self.all_columns_match(): - return False - elif not self.all_rows_overlap(): - return False - elif not self.intersect_rows_match(): - return False - else: - return True + return ( + (ignore_extra_columns or self.all_columns_match()) + and self.all_rows_overlap() + and self.intersect_rows_match() + ) def subset(self) -> bool: """Return True if dataframe 2 is a subset of dataframe 1. @@ -512,19 +513,18 @@ def subset(self) -> bool: bool True if dataframe 2 is a subset of dataframe 1. """ - if not self.df2_unq_columns() == set(): - return False - elif not len(self.df2_unq_rows) == 0: - return False - elif not self.intersect_rows_match(): - return False - else: - return True + return ( + self.df2_unq_columns() == set() + and len(self.df2_unq_rows) == 0 + and self.intersect_rows_match() + ) def sample_mismatch( self, column: str, sample_count: int = 10, for_display: bool = False ) -> "ps.DataFrame": - """Returns a sample sub-dataframe which contains the identifying + """Return sample mismatches. + + Gets a sub-dataframe which contains the identifying columns, and df1 and df2 versions of the column. Parameters @@ -553,20 +553,24 @@ def sample_mismatch( for c in self.join_columns: sample[c] = sample[c + "_" + self.df1_name] - return_cols = self.join_columns + [ + return_cols = [ + *self.join_columns, column + "_" + self.df1_name, column + "_" + self.df2_name, ] to_return = sample[return_cols] if for_display: - to_return.columns = self.join_columns + [ + to_return.columns = [ + *self.join_columns, column + " (" + self.df1_name + ")", column + " (" + self.df2_name + ")", ] return to_return def all_mismatch(self, ignore_matching_cols: bool = False) -> "ps.DataFrame": - """All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join + """Get all rows with any columns that have a mismatch. + + Returns all df1 and df2 versions of the columns and join columns. Parameters @@ -625,7 +629,9 @@ def report( column_count: int = 10, html_file: Optional[str] = None, ) -> str: - """Returns a string representation of a report. The representation can + """Return a string representation of a report. + + The representation can then be printed or saved to a file. Parameters @@ -688,7 +694,7 @@ def report( "column_comparison.txt", len([col for col in self.column_stats if col["unequal_cnt"] > 0]), len([col for col in self.column_stats if col["unequal_cnt"] == 0]), - sum([col["unequal_cnt"] for col in self.column_stats]), + sum(col["unequal_cnt"] for col in self.column_stats), ) match_stats = [] @@ -777,7 +783,9 @@ def report( def render(filename: str, *fields: Union[int, float, str]) -> str: - """Renders out an individual template. This basically just reads in a + """Render out an individual template. + + This basically just reads in a template file, and applies ``.format()`` on the fields. Parameters @@ -806,7 +814,9 @@ def columns_equal( ignore_spaces: bool = False, ignore_case: bool = False, ) -> "ps.Series": - """Compares two columns from a dataframe, returning a True/False series, + """Compare two columns from a dataframe. + + Returns a True/False series, with the same index as column 1. - Two nulls (np.nan) will evaluate to True. @@ -885,7 +895,9 @@ def columns_equal( def compare_string_and_date_columns( col_1: "ps.Series", col_2: "ps.Series" ) -> "ps.Series": - """Compare a string column and date column, value-wise. This tries to + """Compare a string column and date column, value-wise. 
+ + This tries to convert a string column to a date column and compare that way. Parameters @@ -925,7 +937,7 @@ def get_merged_columns( merged_df: "ps.DataFrame", suffix: str, ) -> List[str]: - """Gets the columns from an original dataframe, in the new merged dataframe + """Get the columns from an original dataframe in the new merged dataframe. Parameters ---------- @@ -954,7 +966,7 @@ def get_merged_columns( def calculate_max_diff(col_1: "ps.DataFrame", col_2: "ps.DataFrame") -> float: - """Get a maximum difference between two columns + """Get a maximum difference between two columns. Parameters ---------- @@ -977,7 +989,9 @@ def calculate_max_diff(col_1: "ps.DataFrame", col_2: "ps.DataFrame") -> float: def generate_id_within_group( dataframe: "ps.DataFrame", join_columns: List[str] ) -> "ps.Series": - """Generate an ID column that can be used to deduplicate identical rows. The series generated + """Generate an ID column that can be used to deduplicate identical rows. + + The series generated is the order within a unique group, and it handles nulls. Parameters diff --git a/datacompy/spark/sql.py b/datacompy/spark/sql.py index 1152a61e..f99b7174 100644 --- a/datacompy/spark/sql.py +++ b/datacompy/spark/sql.py @@ -14,7 +14,7 @@ # limitations under the License. """ -Compare two PySpark SQL DataFrames +Compare two PySpark SQL DataFrames. Originally this package was meant to provide similar functionality to PROC COMPARE in SAS - i.e. human-readable reporting on the difference between @@ -29,7 +29,7 @@ import pandas as pd from ordered_set import OrderedSet -from ..base import BaseCompare, temp_column_name +from datacompy.base import BaseCompare, temp_column_name try: import pyspark.sql @@ -55,8 +55,12 @@ LOG = logging.getLogger(__name__) -# Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal". def decimal_comparator(): + """Check equality with decimal(X, Y) types. + + Otherwise treated as the string "decimal". 
+ """ + class DecimalComparator(str): def __eq__(self, other): return len(other) >= 7 and other[0:7] == "decimal" @@ -160,19 +164,20 @@ def __init__( self.rel_tol = rel_tol self.ignore_spaces = ignore_spaces self.ignore_case = ignore_case - self.df1_unq_rows: "pyspark.sql.DataFrame" - self.df2_unq_rows: "pyspark.sql.DataFrame" - self.intersect_rows: "pyspark.sql.DataFrame" + self.df1_unq_rows: pyspark.sql.DataFrame + self.df2_unq_rows: pyspark.sql.DataFrame + self.intersect_rows: pyspark.sql.DataFrame self.column_stats: List = [] self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case) @property def df1(self) -> "pyspark.sql.DataFrame": + """Get the first dataframe.""" return self._df1 @df1.setter def df1(self, df1: "pyspark.sql.DataFrame") -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" self._df1 = df1 self._validate_dataframe( "df1", cast_column_names_lower=self.cast_column_names_lower @@ -180,11 +185,12 @@ def df1(self, df1: "pyspark.sql.DataFrame") -> None: @property def df2(self) -> "pyspark.sql.DataFrame": + """Get the second dataframe.""" return self._df2 @df2.setter def df2(self, df2: "pyspark.sql.DataFrame") -> None: - """Check that it is a dataframe and has the join columns""" + """Check that it is a dataframe and has the join columns.""" self._df2 = df2 self._validate_dataframe( "df2", cast_column_names_lower=self.cast_column_names_lower @@ -193,7 +199,7 @@ def df2(self, df2: "pyspark.sql.DataFrame") -> None: def _validate_dataframe( self, index: str, cast_column_names_lower: bool = True ) -> None: - """Check that it is a dataframe and has the join columns + """Check that it is a dataframe and has the join columns. Parameters ---------- @@ -247,7 +253,9 @@ def _validate_dataframe( self._any_dupes = True def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: - """Actually run the comparison. This tries to run df1.equals(df2) + """Actually run the comparison. + + This tries to run df1.equals(df2) first so that if they're truly equal we can tell. This method will log out information about what is different between @@ -276,20 +284,21 @@ def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None: LOG.info("df1 does not match df2") def df1_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df1""" + """Get columns that are unique to df1.""" return OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns) def df2_unq_columns(self) -> OrderedSet[str]: - """Get columns that are unique to df2""" + """Get columns that are unique to df2.""" return OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns) def intersect_columns(self) -> OrderedSet[str]: - """Get columns that are shared between the two dataframes""" + """Get columns that are shared between the two dataframes.""" return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns) def _dataframe_merge(self, ignore_spaces: bool) -> None: - """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1 - and df1 & df2 + """Merge df1 to df2 on the join columns. + + To get df1 - df2, df2 - df1 and df1 & df2. 
""" LOG.debug("Outer joining") @@ -327,13 +336,15 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: if ignore_spaces: for column in self.join_columns: - if [dtype for name, dtype in df1.dtypes if name == column][ - 0 - ] == "string": + if ( + next(dtype for name, dtype in df1.dtypes if name == column) + == "string" + ): df1 = df1.withColumn(column, trim(col(column))) - if [dtype for name, dtype in df2.dtypes if name == column][ - 0 - ] == "string": + if ( + next(dtype for name, dtype in df2.dtypes if name == column) + == "string" + ): df2 = df2.withColumn(column, trim(col(column))) df1_non_join_columns = OrderedSet(df1.columns) - OrderedSet(temp_join_columns) @@ -369,7 +380,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: """ SELECT * FROM df1 FULL OUTER JOIN df2 - ON + ON """ + on ) @@ -434,7 +445,7 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None: self.intersect_rows.cache() def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None: - """Run the comparison on the intersect dataframe + """Run the comparison on the intersect dataframe. This loops through all columns that are shared between df1 and df2, and creates a column column_match which is True for matches, False @@ -522,7 +533,7 @@ def all_rows_overlap(self) -> bool: return self.df1_unq_rows.count() == self.df2_unq_rows.count() == 0 def count_matching_rows(self) -> int: - """Count the number of rows match (on overlapping fields) + """Count the number of rows match (on overlapping fields). Returns ------- @@ -544,7 +555,7 @@ def count_matching_rows(self) -> int: return match_columns_count def intersect_rows_match(self) -> bool: - """Check whether the intersect rows all match""" + """Check whether the intersect rows all match.""" actual_length = self.intersect_rows.count() return self.count_matching_rows() == actual_length @@ -556,14 +567,11 @@ def matches(self, ignore_extra_columns: bool = False) -> bool: ignore_extra_columns : bool Ignores any columns in one dataframe and not in the other. """ - if not ignore_extra_columns and not self.all_columns_match(): - return False - elif not self.all_rows_overlap(): - return False - elif not self.intersect_rows_match(): - return False - else: - return True + return ( + (ignore_extra_columns or self.all_columns_match()) + and self.all_rows_overlap() + and self.intersect_rows_match() + ) def subset(self) -> bool: """Return True if dataframe 2 is a subset of dataframe 1. @@ -577,19 +585,18 @@ def subset(self) -> bool: bool True if dataframe 2 is a subset of dataframe 1. """ - if not self.df2_unq_columns() == set(): - return False - elif not self.df2_unq_rows.count() == 0: - return False - elif not self.intersect_rows_match(): - return False - else: - return True + return ( + self.df2_unq_columns() == set() + and self.df2_unq_rows.count() == 0 + and self.intersect_rows_match() + ) def sample_mismatch( self, column: str, sample_count: int = 10, for_display: bool = False ) -> "pyspark.sql.DataFrame": - """Returns a sample sub-dataframe which contains the identifying + """Return sample mismatches. + + Gets a sub-dataframe which contains the identifying columns, and df1 and df2 versions of the column. 
Parameters @@ -624,7 +631,8 @@ def sample_mismatch( for c in self.join_columns: sample = sample.withColumnRenamed(c + "_" + self.df1_name, c) - return_cols = self.join_columns + [ + return_cols = [ + *self.join_columns, column + "_" + self.df1_name, column + "_" + self.df2_name, ] @@ -632,8 +640,8 @@ def sample_mismatch( if for_display: return to_return.toDF( - *self.join_columns - + [ + *[ + *self.join_columns, column + " (" + self.df1_name + ")", column + " (" + self.df2_name + ")", ] @@ -643,7 +651,9 @@ def sample_mismatch( def all_mismatch( self, ignore_matching_cols: bool = False ) -> "pyspark.sql.DataFrame": - """All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join + """Get all rows with any columns that have a mismatch. + + Returns all df1 and df2 versions of the columns and join columns. Parameters @@ -708,7 +718,9 @@ def report( column_count: int = 10, html_file: Optional[str] = None, ) -> str: - """Returns a string representation of a report. The representation can + """Return a string representation of a report. + + The representation can then be printed or saved to a file. Parameters @@ -771,7 +783,7 @@ def report( "column_comparison.txt", len([col for col in self.column_stats if col["unequal_cnt"] > 0]), len([col for col in self.column_stats if col["unequal_cnt"] == 0]), - sum([col["unequal_cnt"] for col in self.column_stats]), + sum(col["unequal_cnt"] for col in self.column_stats), ) match_stats = [] @@ -870,7 +882,9 @@ def report( def render(filename: str, *fields: Union[int, float, str]) -> str: - """Renders out an individual template. This basically just reads in a + """Render out an individual template. + + This basically just reads in a template file, and applies ``.format()`` on the fields. Parameters @@ -901,8 +915,9 @@ def columns_equal( ignore_spaces: bool = False, ignore_case: bool = False, ) -> "pyspark.sql.DataFrame": - """Compares two columns from a dataframe, returning a True/False series, - with the same index as column 1. + """Compare two columns from a dataframe. + + Returns a True/False series with the same index as column 1. - Two nulls (np.nan) will evaluate to True. - A null and a non-null value will evaluate to False. @@ -976,9 +991,7 @@ def columns_equal( ) else: LOG.debug( - "Skipping {}({}) and {}({}), columns are not comparable".format( - col_1, base_dtype, col_2, compare_dtype - ) + f"Skipping {col_1}({base_dtype}) and {col_2}({compare_dtype}), columns are not comparable" ) dataframe = dataframe.withColumn(col_match, lit(False)) return dataframe @@ -989,7 +1002,7 @@ def get_merged_columns( merged_df: "pyspark.sql.DataFrame", suffix: str, ) -> List[str]: - """Gets the columns from an original dataframe, in the new merged dataframe + """Get the columns from an original dataframe, in the new merged dataframe. Parameters ---------- @@ -1020,7 +1033,7 @@ def get_merged_columns( def calculate_max_diff( dataframe: "pyspark.sql.DataFrame", col_1: str, col_2: str ) -> float: - """Get a maximum difference between two columns + """Get a maximum difference between two columns. Parameters ---------- @@ -1055,7 +1068,7 @@ def calculate_max_diff( def calculate_null_diff( dataframe: "pyspark.sql.DataFrame", col_1: str, col_2: str ) -> int: - """Get the null differences between two columns + """Get the null differences between two columns. 
Parameters ---------- @@ -1099,7 +1112,9 @@ def calculate_null_diff( def _generate_id_within_group( dataframe: "pyspark.sql.DataFrame", join_columns: List[str], order_column_name: str ) -> "pyspark.sql.DataFrame": - """Generate an ID column that can be used to deduplicate identical rows. The series generated + """Generate an ID column that can be used to deduplicate identical rows. + + The series generated is the order within a unique group, and it handles nulls. Requires a ``__index`` column. Parameters @@ -1129,7 +1144,7 @@ def _generate_id_within_group( return ( dataframe.select( - *(col(c).cast("string").alias(c) for c in join_columns + ["__index"]) + *(col(c).cast("string").alias(c) for c in [*join_columns, "__index"]) ) .fillna(default_value) .withColumn( @@ -1141,7 +1156,7 @@ def _generate_id_within_group( ) else: return ( - dataframe.select(join_columns + ["__index"]) + dataframe.select([*join_columns, "__index"]) .withColumn( order_column_name, row_number().over(Window.orderBy("__index").partitionBy(join_columns)) @@ -1154,7 +1169,7 @@ def _generate_id_within_group( def _get_column_dtypes( dataframe: "pyspark.sql.DataFrame", col_1: "str", col_2: "str" ) -> Tuple[str, str]: - """Get the dtypes of two columns + """Get the dtypes of two columns. Parameters ---------- @@ -1170,13 +1185,13 @@ def _get_column_dtypes( Tuple(str, str) Tuple of base and compare datatype """ - base_dtype = [d[1] for d in dataframe.dtypes if d[0] == col_1][0] - compare_dtype = [d[1] for d in dataframe.dtypes if d[0] == col_2][0] + base_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_1) + compare_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_2) return base_dtype, compare_dtype def _is_comparable(type1: str, type2: str) -> bool: - """Checks if two Spark data types can be safely compared. + """Check if two Spark data types can be safely compared. Two data types are considered comparable if any of the following apply: 1. Both data types are the same diff --git a/pyproject.toml b/pyproject.toml index 19075892..633dc018 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,18 +62,63 @@ docs = ["sphinx", "furo", "myst-parser"] tests = ["pytest", "pytest-cov"] tests-spark = ["pytest", "pytest-cov", "pytest-spark"] -qa = ["pre-commit", "black", "isort", "mypy", "pandas-stubs"] +qa = ["pre-commit", "ruff==0.5.7", "mypy", "pandas-stubs"] build = ["build", "twine", "wheel"] edgetest = ["edgetest", "edgetest-conda"] dev = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[qa]", "datacompy[build]"] -[tool.isort] -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true -line_length = 88 -profile = "black" +# Linters, formatters and type checkers +[tool.ruff] +extend-include = ["*.ipynb"] +target-version = "py39" +src = ["src"] + +[tool.ruff.lint] +preview = true +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "D", # pydocstyle + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + # "A", # flake8-builtins + "C4", # flake8-comprehensions + #"C901", # mccabe complexity + # "G", # flake8-logging-format + "T20", # flake8-print + "TID252", # flake8-tidy-imports ban relative imports + # "ARG", # flake8-unused-arguments + "SIM", # flake8-simplify + "NPY", # numpy rules + "LOG", # flake8-logging + "RUF", # Ruff errors +] + +ignore = [ + "E111", # Check indentation level. Using formatter instead. + "E114", # Check indentation level. 
Using formatter instead. + "E117", # Check indentation level. Using formatter instead. + "E203", # Check whitespace. Using formatter instead. + "E501", # Line too long. Using formatter instead. + "D206", # Docstring indentation. Using formatter instead. + "D300", # Use triple single quotes. Using formatter instead. + "SIM108", # Use ternary operator instead of if-else blocks. + "SIM105", # Use `contextlib.suppress(FileNotFoundError)` instead of `try`-`except`-`pass` + "UP035", # `typing.x` is deprecated, use `x` instead + "UP006", # `typing.x` is deprecated, use `x` instead +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402"] +"**/{tests,docs}/*" = ["E402", "D", "F841", "ARG"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.lint.pydocstyle] +convention = "numpy" [tool.mypy] strict = true diff --git a/tests/test_core.py b/tests/test_core.py index 103b9f3b..14298e09 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -16,6 +16,7 @@ """ Testing out the datacompy functionality """ + import io import logging import sys @@ -23,14 +24,13 @@ from decimal import Decimal from unittest import mock +import datacompy import numpy as np import pandas as pd import pytest from pandas.testing import assert_series_equal from pytest import raises -import datacompy - logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) @@ -509,7 +509,8 @@ def test_columns_maintain_order_through_set_operations(): def test_10k_rows(): - df1 = pd.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"]) + rng = np.random.default_rng() + df1 = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["b", "c"]) df1.reset_index(inplace=True) df1.columns = ["a", "b", "c"] df2 = df1.copy() @@ -552,7 +553,8 @@ def test_not_subset(caplog): def test_large_subset(): - df1 = pd.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"]) + rng = np.random.default_rng() + df1 = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["b", "c"]) df1.reset_index(inplace=True) df1.columns = ["a", "b", "c"] df2 = df1[["a", "b"]].sample(50).copy() @@ -686,7 +688,7 @@ def test_temp_column_name_one_already(): assert actual == "_temp_0" -### Duplicate testing! +# Duplicate testing! 
def test_simple_dupes_one_field(): df1 = pd.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) df2 = pd.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) diff --git a/tests/test_fugue/conftest.py b/tests/test_fugue/conftest.py index a2ca99b1..a73db831 100644 --- a/tests/test_fugue/conftest.py +++ b/tests/test_fugue/conftest.py @@ -5,24 +5,24 @@ @pytest.fixture def ref_df(): - np.random.seed(0) + rng = np.random.default_rng(0) df1 = pd.DataFrame( - dict( - a=np.random.randint(0, 10, 100), - b=np.random.rand(100), - c=np.random.choice(["aaa", "b_c", "csd"], 100), - ) + { + "a": rng.integers(0, 10, 100), + "b": rng.uniform(size=100), + "c": rng.choice(["aaa", "b_c", "csd"], 100), + } ) df1_copy = df1.copy() df2 = df1.copy().drop(columns=["c"]) df3 = df1.copy().drop(columns=["a", "b"]) df4 = pd.DataFrame( - dict( - a=np.random.randint(1, 12, 100), # shift the join col - b=np.random.rand(100), - c=np.random.choice(["aaa", "b_c", "csd"], 100), - ) + { + "a": rng.integers(1, 12, 100), # shift the join col + "b": rng.uniform(size=100), + "c": rng.choice(["aaa", "b_c", "csd"], 100), + } ) df5 = df1.sample(frac=0.1) return [df1, df1_copy, df2, df3, df4, df5] @@ -55,49 +55,47 @@ def upper_col_df(shuffle_df): @pytest.fixture def simple_diff_df1(): - return pd.DataFrame(dict(aa=[0, 1, 0], bb=[2.1, 3.1, 4.1])).convert_dtypes() + return pd.DataFrame({"aa": [0, 1, 0], "bb": [2.1, 3.1, 4.1]}).convert_dtypes() @pytest.fixture def simple_diff_df2(): return pd.DataFrame( - dict(aa=[1, 0, 1], bb=[3.1, 4.1, 5.1], cc=["a", "b", "c"]) + {"aa": [1, 0, 1], "bb": [3.1, 4.1, 5.1], "cc": ["a", "b", "c"]} ).convert_dtypes() @pytest.fixture def no_intersection_diff_df1(): - np.random.seed(0) - return pd.DataFrame(dict(x=["a"], y=[0.1])).convert_dtypes() + return pd.DataFrame({"x": ["a"], "y": [0.1]}).convert_dtypes() @pytest.fixture def no_intersection_diff_df2(): - return pd.DataFrame(dict(x=["b"], y=[1.1])).convert_dtypes() + return pd.DataFrame({"x": ["b"], "y": [1.1]}).convert_dtypes() @pytest.fixture def large_diff_df1(): - np.random.seed(0) - data = np.random.randint(0, 7, size=10000) + rng = np.random.default_rng(0) + data = rng.integers(0, 7, size=10000) return pd.DataFrame({"x": data, "y": np.array([9] * 10000)}).convert_dtypes() @pytest.fixture def large_diff_df2(): - np.random.seed(0) - data = np.random.randint(6, 11, size=10000) + rng = np.random.default_rng(0) + data = rng.integers(6, 11, size=10000) return pd.DataFrame({"x": data, "y": np.array([9] * 10000)}).convert_dtypes() @pytest.fixture def count_matching_rows_df(): - np.random.seed(0) df1 = pd.DataFrame( - dict( - a=np.arange(0, 100), - b=np.arange(0, 100), - ) + { + "a": np.arange(0, 100), + "b": np.arange(0, 100), + } ) - df2 = df1.sample(frac=0.1) + df2 = df1.sample(frac=0.1, random_state=0) return [df1, df2] diff --git a/tests/test_fugue/test_duckdb.py b/tests/test_fugue/test_duckdb.py index 3643f22d..715f1012 100644 --- a/tests/test_fugue/test_duckdb.py +++ b/tests/test_fugue/test_duckdb.py @@ -13,10 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test fugue functionality with duckdb.""" -import pytest -from ordered_set import OrderedSet -from pytest import raises +import pytest from datacompy import ( all_columns_match, all_rows_overlap, @@ -25,6 +23,8 @@ is_match, unq_columns, ) +from ordered_set import OrderedSet +from pytest import raises duckdb = pytest.importorskip("duckdb") diff --git a/tests/test_fugue/test_fugue_pandas.py b/tests/test_fugue/test_fugue_pandas.py index 4fd74ce7..2f12a5cf 100644 --- a/tests/test_fugue/test_fugue_pandas.py +++ b/tests/test_fugue/test_fugue_pandas.py @@ -13,13 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test the fugue functionality with pandas.""" + from io import StringIO import pandas as pd -from ordered_set import OrderedSet -from pytest import raises -from test_fugue_helpers import _compare_report - from datacompy import ( Compare, all_columns_match, @@ -30,6 +27,9 @@ report, unq_columns, ) +from ordered_set import OrderedSet +from pytest import raises +from test_fugue_helpers import _compare_report def test_is_match_native( diff --git a/tests/test_fugue/test_fugue_polars.py b/tests/test_fugue/test_fugue_polars.py index dcd19a94..7f56d218 100644 --- a/tests/test_fugue/test_fugue_polars.py +++ b/tests/test_fugue/test_fugue_polars.py @@ -13,10 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test fugue and polars.""" -import pytest -from ordered_set import OrderedSet -from pytest import raises +import pytest from datacompy import ( all_columns_match, all_rows_overlap, @@ -25,6 +23,8 @@ is_match, unq_columns, ) +from ordered_set import OrderedSet +from pytest import raises pl = pytest.importorskip("polars") diff --git a/tests/test_fugue/test_fugue_spark.py b/tests/test_fugue/test_fugue_spark.py index efc895ff..ae317eb8 100644 --- a/tests/test_fugue/test_fugue_spark.py +++ b/tests/test_fugue/test_fugue_spark.py @@ -13,11 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test fugue and spark.""" -import pytest -from ordered_set import OrderedSet -from pytest import raises -from test_fugue_helpers import _compare_report +import pytest from datacompy import ( Compare, all_columns_match, @@ -28,6 +25,9 @@ report, unq_columns, ) +from ordered_set import OrderedSet +from pytest import raises +from test_fugue_helpers import _compare_report pyspark = pytest.importorskip("pyspark") diff --git a/tests/test_polars.py b/tests/test_polars.py index 0640cd29..7fb487f6 100644 --- a/tests/test_polars.py +++ b/tests/test_polars.py @@ -30,17 +30,16 @@ pytest.importorskip("polars") -import polars as pl # noqa: E402 -from polars.exceptions import ComputeError, DuplicateError # noqa: E402 -from polars.testing import assert_series_equal # noqa: E402 - -from datacompy import PolarsCompare # noqa: E402 -from datacompy.polars import ( # noqa: E402 +import polars as pl +from datacompy import PolarsCompare +from datacompy.polars import ( calculate_max_diff, columns_equal, generate_id_within_group, temp_column_name, ) +from polars.exceptions import ComputeError, DuplicateError +from polars.testing import assert_series_equal logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) @@ -469,7 +468,8 @@ def test_columns_maintain_order_through_set_operations(): def test_10k_rows(): - df1 = pl.DataFrame(np.random.randint(0, 100, size=(10000, 2)), schema=["b", "c"]) + rng = np.random.default_rng() + df1 = pl.DataFrame(rng.integers(0, 100, size=(10000, 2)), schema=["b", "c"]) df1 = df1.with_row_index() df1.columns = ["a", "b", "c"] df2 = df1.clone() @@ -512,7 +512,8 @@ def test_not_subset(caplog): def test_large_subset(): - df1 = pl.DataFrame(np.random.randint(0, 100, size=(10000, 2)), schema=["b", "c"]) + rng = np.random.default_rng() + df1 = pl.DataFrame(rng.integers(0, 100, size=(10000, 2)), schema=["b", "c"]) df1 = df1.with_row_index() df1.columns = ["a", "b", "c"] df2 = df1[["a", "b"]].sample(50).clone() @@ -612,7 +613,7 @@ def test_temp_column_name_one_already(): assert actual == "_temp_0" -### Duplicate testing! +# Duplicate testing! 
def test_simple_dupes_one_field(): df1 = pl.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) df2 = pl.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) @@ -855,7 +856,7 @@ def test_sample_mismatch(): output = compare.sample_mismatch(column="name", sample_count=3) assert output.shape[0] == 2 - assert (["name_df1"] != output["name_df2"]).all() + assert (output["name_df2"] != ["name_df1"]).all() def test_all_mismatch_not_ignore_matching_cols_no_cols_matching(): diff --git a/tests/test_spark/test_legacy_spark.py b/tests/test_spark/test_legacy_spark.py index 3a1cebe5..74fa5668 100644 --- a/tests/test_spark/test_legacy_spark.py +++ b/tests/test_spark/test_legacy_spark.py @@ -23,8 +23,13 @@ pytest.importorskip("pyspark") -from pyspark.sql import Row # noqa: E402 -from pyspark.sql.types import ( # noqa: E402 +from datacompy.spark.legacy import ( + NUMERIC_SPARK_TYPES, + LegacySparkCompare, + _is_comparable, +) +from pyspark.sql import Row +from pyspark.sql.types import ( DateType, DecimalType, DoubleType, @@ -34,12 +39,6 @@ StructType, ) -from datacompy.spark.legacy import ( # noqa: E402 - NUMERIC_SPARK_TYPES, - LegacySparkCompare, - _is_comparable, -) - # Turn off py4j debug messages for all tests in this module logging.getLogger("py4j").setLevel(logging.INFO) diff --git a/tests/test_spark/test_pandas_spark.py b/tests/test_spark/test_pandas_spark.py index 5517ff9a..f31f8ae8 100644 --- a/tests/test_spark/test_pandas_spark.py +++ b/tests/test_spark/test_pandas_spark.py @@ -32,16 +32,15 @@ pytest.importorskip("pyspark") -import pyspark.pandas as ps # noqa: E402 -from pandas.testing import assert_series_equal # noqa: E402 - -from datacompy.spark.pandas import ( # noqa: E402 +import pyspark.pandas as ps +from datacompy.spark.pandas import ( SparkPandasCompare, calculate_max_diff, columns_equal, generate_id_within_group, temp_column_name, ) +from pandas.testing import assert_series_equal logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) @@ -507,7 +506,8 @@ def test_columns_maintain_order_through_set_operations(): @pandas_version def test_10k_rows(): - df1 = ps.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"]) + rng = np.random.default_rng() + df1 = ps.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["b", "c"]) df1.reset_index(inplace=True) df1.columns = ["a", "b", "c"] df2 = df1.copy() @@ -553,7 +553,8 @@ def test_not_subset(caplog): @pandas_version def test_large_subset(): - df1 = ps.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"]) + rng = np.random.default_rng() + df1 = ps.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["b", "c"]) df1.reset_index(inplace=True) df1.columns = ["a", "b", "c"] df2 = df1[["a", "b"]].head(50).copy() @@ -665,7 +666,7 @@ def test_temp_column_name_one_already(): assert actual == "_temp_0" -### Duplicate testing! +# Duplicate testing! @pandas_version def test_simple_dupes_one_field(): df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) @@ -1304,9 +1305,11 @@ def test_pandas_version(): expected_message = "It seems like you are running Pandas 2+. Please note that Pandas 2+ will only be supported in Spark 4+. See: https://issues.apache.org/jira/browse/SPARK-44101. 
If you need to use Spark DataFrame with Pandas 2+ then consider using Fugue otherwise downgrade to Pandas 1.5.3" df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}]) - with mock.patch("pandas.__version__", "2.0.0"): - with raises(Exception, match=re.escape(expected_message)): - SparkPandasCompare(df1, df2, join_columns=["a"]) + with ( + mock.patch("pandas.__version__", "2.0.0"), + raises(Exception, match=re.escape(expected_message)), + ): + SparkPandasCompare(df1, df2, join_columns=["a"]) with mock.patch("pandas.__version__", "1.5.3"): SparkPandasCompare(df1, df2, join_columns=["a"]) @@ -1315,10 +1318,16 @@ def test_pandas_version(): @pandas_version def test_unicode_columns(): df1 = ps.DataFrame( - [{"a": 1, "例": 2, "予測対象日": "test"}, {"a": 1, "例": 3, "予測対象日": "test"}] + [ + {"a": 1, "例": 2, "予測対象日": "test"}, + {"a": 1, "例": 3, "予測対象日": "test"}, + ] ) df2 = ps.DataFrame( - [{"a": 1, "例": 2, "予測対象日": "test"}, {"a": 1, "例": 3, "予測対象日": "test"}] + [ + {"a": 1, "例": 2, "予測対象日": "test"}, + {"a": 1, "例": 3, "予測対象日": "test"}, + ] ) compare = SparkPandasCompare(df1, df2, join_columns=["例"]) assert compare.matches() diff --git a/tests/test_spark/test_sql_spark.py b/tests/test_spark/test_sql_spark.py index 8fb4cfcf..50fbe901 100644 --- a/tests/test_spark/test_sql_spark.py +++ b/tests/test_spark/test_sql_spark.py @@ -19,7 +19,6 @@ import io import logging -import re import sys from datetime import datetime from decimal import Decimal @@ -33,15 +32,14 @@ pytest.importorskip("pyspark") -from pandas.testing import assert_series_equal # noqa: E402 - -from datacompy.spark.sql import ( # noqa: E402 +from datacompy.spark.sql import ( SparkSQLCompare, _generate_id_within_group, calculate_max_diff, columns_equal, temp_column_name, ) +from pandas.testing import assert_series_equal logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) @@ -493,7 +491,8 @@ def test_columns_maintain_order_through_set_operations(spark_session): def test_10k_rows(spark_session): - pdf = pd.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"]) + rng = np.random.default_rng() + pdf = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["b", "c"]) pdf.reset_index(inplace=True) pdf.columns = ["a", "b", "c"] pdf2 = pdf.copy() @@ -543,7 +542,8 @@ def test_not_subset(spark_session, caplog): def test_large_subset(spark_session): - pdf = pd.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"]) + rng = np.random.default_rng() + pdf = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["b", "c"]) pdf.reset_index(inplace=True) pdf.columns = ["a", "b", "c"] pdf2 = pdf[["a", "b"]].head(50).copy() @@ -662,7 +662,7 @@ def test_temp_column_name_one_already(spark_session): assert actual == "_temp_0" -### Duplicate testing! +# Duplicate testing! def test_simple_dupes_one_field(spark_session): From 735b1b3ac5e9402d6283fdd9bf902dc56eb07223 Mon Sep 17 00:00:00 2001 From: Faisal Date: Wed, 11 Sep 2024 13:13:20 -0300 Subject: [PATCH 3/3] updating polars version and bumping to 0.13.3 (#331) --- datacompy/__init__.py | 2 +- pyproject.toml | 2 +- tests/test_polars.py | 8 ++++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/datacompy/__init__.py b/datacompy/__init__.py index fad85ae6..8dfa816a 100644 --- a/datacompy/__init__.py +++ b/datacompy/__init__.py @@ -18,7 +18,7 @@ Then extended to carry that functionality over to Spark Dataframes. 
""" -__version__ = "0.13.2" +__version__ = "0.13.3" import platform from warnings import warn diff --git a/pyproject.toml b/pyproject.toml index 633dc018..63a5464e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ maintainers = [ { name="Faisal Dosani", email="faisal.dosani@capitalone.com" } ] license = {text = "Apache Software License"} -dependencies = ["pandas<=2.2.2,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.9.1,>=0.8.7", "polars<=1.1.0,>=0.20.4"] +dependencies = ["pandas<=2.2.2,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.9.1,>=0.8.7", "polars<=1.7.0,>=0.20.4"] requires-python = ">=3.9.0" classifiers = [ "Intended Audience :: Developers", diff --git a/tests/test_polars.py b/tests/test_polars.py index 7fb487f6..779974b0 100644 --- a/tests/test_polars.py +++ b/tests/test_polars.py @@ -386,7 +386,9 @@ def test_compare_df_setter_bad(): PolarsCompare("a", "a", ["a"]) with raises(ValueError, match="df1 must have all columns from join_columns"): PolarsCompare(df, df.clone(), ["b"]) - with raises(DuplicateError, match="duplicate column names found"): + with raises( + DuplicateError, match="column with name 'a' has more than one occurrences" + ): PolarsCompare(df_same_col_names, df_same_col_names.clone(), ["a"]) assert PolarsCompare(df_dupe, df_dupe.clone(), ["a", "b"]).df1.equals(df_dupe) @@ -416,7 +418,9 @@ def test_compare_df_setter_bad_index(): df = pl.DataFrame([{"a": 1, "A": 2}, {"a": 2, "A": 2}]) with raises(TypeError, match="df1 must be a Polars DataFrame"): PolarsCompare("a", "a", join_columns="a") - with raises(DuplicateError, match="duplicate column names found"): + with raises( + DuplicateError, match="column with name 'a' has more than one occurrences" + ): PolarsCompare(df, df.clone(), join_columns="a")