diff --git a/.github/workflows/edgetest.yml b/.github/workflows/edgetest.yml
index b53cf46c..e47ca6ff 100644
--- a/.github/workflows/edgetest.yml
+++ b/.github/workflows/edgetest.yml
@@ -5,17 +5,51 @@ name: Run edgetest
on:
schedule:
- cron: '30 17 * * 5'
+ workflow_dispatch: # allows manual dispatch
jobs:
edgetest:
runs-on: ubuntu-latest
name: running edgetest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
with:
ref: develop
- - id: run-edgetest
- uses: edgetest-dev/run-edgetest-action@v1.4
+
+ - name: Set up Python 3.10
+ uses: conda-incubator/setup-miniconda@v2
with:
- edgetest-flags: '-c pyproject.toml --export'
- base-branch: 'develop'
- skip-pr: 'false'
\ No newline at end of file
+ auto-update-conda: true
+ python-version: '3.10'
+ channels: conda-forge
+
+ - name: Setup Java JDK
+ uses: actions/setup-java@v3
+ with:
+ java-version: '8'
+ distribution: 'adopt'
+
+ - name: Install edgetest
+ shell: bash -el {0}
+ run: |
+ conda install pip
+ conda install edgetest edgetest-conda
+ python -m pip install .[dev]
+
+ - name: Run edgetest
+ shell: bash -el {0}
+ run: |
+ edgetest -c pyproject.toml --export
+
+ - name: Create Pull Request
+ uses: peter-evans/create-pull-request@v3
+ with:
+ branch: edgetest-patch
+ base: develop
+ delete-branch: true
+ title: Changes by run-edgetest action
+ commit-message: '[edgetest] automated change'
+ body: Automated changes by [run-edgetest-action](https://github.com/edgetest-dev/run-edgetest-action) GitHub action
+ add-paths: |
+ requirements.txt
+ setup.cfg
+ pyproject.toml
\ No newline at end of file
diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml
index 1e4536b3..a8390f0e 100644
--- a/.github/workflows/test-package.yml
+++ b/.github/workflows/test-package.yml
@@ -9,39 +9,101 @@ on:
pull_request:
branches: [develop, main]
+permissions:
+ contents: read
+
jobs:
- build:
+ test-dev-install:
- runs-on: ubuntu-latest
+ runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: [3.8, 3.9, '3.10']
- spark-version: [3.0.3, 3.1.3, 3.2.3, 3.3.1, 3.4.0]
+ python-version: [3.8, 3.9, '3.10', '3.11']
+ spark-version: [3.1.3, 3.2.4, 3.3.4, 3.4.2, 3.5.0]
+ exclude:
+ - python-version: '3.11'
+ spark-version: 3.1.3
+ - python-version: '3.11'
+ spark-version: 3.2.4
+ - python-version: '3.11'
+ spark-version: 3.3.4
env:
- PYTHON_VERSION: ${{ matrix.python-version }}
+ PYTHON_VERSION: ${{ matrix.python-version }}
SPARK_VERSION: ${{ matrix.spark-version }}
steps:
- - uses: actions/checkout@v2
-
+ - uses: actions/checkout@v3
+
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
-
+
- name: Setup Java JDK
uses: actions/setup-java@v3
with:
java-version: '8'
distribution: 'adopt'
-
- - name: Install Spark
+
+ - name: Install Spark and datacompy
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-spark pypandoc
python -m pip install pyspark==${{ matrix.spark-version }}
- python -m pip install .[dev,spark]
+ python -m pip install .[dev]
+ - name: Test with pytest
+ run: |
+ python -m pytest tests/
+
+ test-bare-install:
+
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: [3.8, 3.9, '3.10', '3.11']
+ env:
+ PYTHON_VERSION: ${{ matrix.python-version }}
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install datacompy
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install .[tests]
+ - name: Test with pytest
+ run: |
+ python -m pytest tests/
+
+ test-fugue-install-no-spark:
+
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: [3.8, 3.9, '3.10', '3.11']
+ env:
+ PYTHON_VERSION: ${{ matrix.python-version }}
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install datacompy
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install .[tests,duckdb,polars,dask,ray]
- name: Test with pytest
run: |
python -m pytest tests/
diff --git a/README.md b/README.md
index c80f30ad..fefe447b 100644
--- a/README.md
+++ b/README.md
@@ -38,308 +38,31 @@ pip install datacompy[ray]
```
+### In-scope Spark versions
+Different versions of Spark play nicely with only certain versions of Python. Below is a matrix of what we test with:
-## Pandas Detail
+|             | Spark 3.1.3 | Spark 3.2.4 | Spark 3.3.4 | Spark 3.4.2 | Spark 3.5.0 |
+|-------------|--------------|-------------|-------------|-------------|-------------|
+| Python 3.8 | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Python 3.9 | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Python 3.10 | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Python 3.11 | ❌ | ❌ | ❌ | ✅ | ✅ |
+| Python 3.12 | ❌ | ❌ | ❌ | ❌ | ❌ |
-DataComPy will try to join two dataframes either on a list of join columns, or
-on indexes. If the two dataframes have duplicates based on join values, the
-match process sorts by the remaining fields and joins based on that row number.
-Column-wise comparisons attempt to match values even when dtypes don't match.
-So if, for example, you have a column with ``decimal.Decimal`` values in one
-dataframe and an identically-named column with ``float64`` dtype in another,
-it will tell you that the dtypes are different but will still try to compare the
-values.
-
-
-### Basic Usage
-
-```python
-
-from io import StringIO
-import pandas as pd
-import datacompy
-
-data1 = """acct_id,dollar_amt,name,float_fld,date_fld
-10000001234,123.45,George Maharis,14530.1555,2017-01-01
-10000001235,0.45,Michael Bluth,1,2017-01-01
-10000001236,1345,George Bluth,,2017-01-01
-10000001237,123456,Bob Loblaw,345.12,2017-01-01
-10000001239,1.05,Lucille Bluth,,2017-01-01
-"""
-
-data2 = """acct_id,dollar_amt,name,float_fld
-10000001234,123.4,George Michael Bluth,14530.155
-10000001235,0.45,Michael Bluth,
-10000001236,1345,George Bluth,1
-10000001237,123456,Robert Loblaw,345.12
-10000001238,1.05,Loose Seal Bluth,111
-"""
-
-df1 = pd.read_csv(StringIO(data1))
-df2 = pd.read_csv(StringIO(data2))
-
-compare = datacompy.Compare(
- df1,
- df2,
- join_columns='acct_id', #You can also specify a list of columns
- abs_tol=0, #Optional, defaults to 0
- rel_tol=0, #Optional, defaults to 0
- df1_name='Original', #Optional, defaults to 'df1'
- df2_name='New' #Optional, defaults to 'df2'
- )
-compare.matches(ignore_extra_columns=False)
-# False
-
-# This method prints out a human-readable report summarizing and sampling differences
-print(compare.report())
-```
-
-See docs for more detailed usage instructions and an example of the report output.
-
-
-### Things that are happening behind the scenes
-
-- You pass in two dataframes (``df1``, ``df2``) to ``datacompy.Compare`` and a
- column to join on (or list of columns) to ``join_columns``. By default the
- comparison needs to match values exactly, but you can pass in ``abs_tol``
- and/or ``rel_tol`` to apply absolute and/or relative tolerances for numeric columns.
-
- - You can pass in ``on_index=True`` instead of ``join_columns`` to join on
- the index instead.
-
-- The class validates that you passed dataframes, that they contain all of the
- columns in `join_columns` and have unique column names other than that. The
- class also lowercases all column names to disambiguate.
-- On initialization the class validates inputs, and runs the comparison.
-- ``Compare.matches()`` will return ``True`` if the dataframes match, ``False``
- otherwise.
-
- - You can pass in ``ignore_extra_columns=True`` to not return ``False`` just
- because there are non-overlapping column names (will still check on
- overlapping columns)
- - NOTE: if you only want to validate whether a dataframe matches exactly or
- not, you should look at ``pandas.testing.assert_frame_equal``. The main
- use case for ``datacompy`` is when you need to interpret the difference
- between two dataframes.
-
-- Compare also has some shortcuts like
-
- - ``intersect_rows``, ``df1_unq_rows``, ``df2_unq_rows`` for getting
- intersection, just df1 and just df2 records (DataFrames)
- - ``intersect_columns()``, ``df1_unq_columns()``, ``df2_unq_columns()`` for
- getting intersection, just df1 and just df2 columns (Sets)
-
-- You can turn on logging to see more detailed logs.
-
-
-## Fugue Detail
-
-[Fugue](https://github.com/fugue-project/fugue) is a Python library that provides a unified interface
-for data processing on Pandas, DuckDB, Polars, Arrow, Spark, Dask, Ray, and many other backends.
-DataComPy integrates with Fugue to provide a simple way to compare data across these backends.
-
-### Basic Usage
-
-The following usage example compares two Pandas dataframes, it is equivalent to the Pandas example above.
-
-```python
-from io import StringIO
-import pandas as pd
-import datacompy
-
-data1 = """acct_id,dollar_amt,name,float_fld,date_fld
-10000001234,123.45,George Maharis,14530.1555,2017-01-01
-10000001235,0.45,Michael Bluth,1,2017-01-01
-10000001236,1345,George Bluth,,2017-01-01
-10000001237,123456,Bob Loblaw,345.12,2017-01-01
-10000001239,1.05,Lucille Bluth,,2017-01-01
-"""
-
-data2 = """acct_id,dollar_amt,name,float_fld
-10000001234,123.4,George Michael Bluth,14530.155
-10000001235,0.45,Michael Bluth,
-10000001236,1345,George Bluth,1
-10000001237,123456,Robert Loblaw,345.12
-10000001238,1.05,Loose Seal Bluth,111
-"""
-
-df1 = pd.read_csv(StringIO(data1))
-df2 = pd.read_csv(StringIO(data2))
-
-datacompy.is_match(
- df1,
- df2,
- join_columns='acct_id', #You can also specify a list of columns
- abs_tol=0, #Optional, defaults to 0
- rel_tol=0, #Optional, defaults to 0
- df1_name='Original', #Optional, defaults to 'df1'
- df2_name='New' #Optional, defaults to 'df2'
-)
-# False
-
-# This method prints out a human-readable report summarizing and sampling differences
-print(datacompy.report(
- df1,
- df2,
- join_columns='acct_id', #You can also specify a list of columns
- abs_tol=0, #Optional, defaults to 0
- rel_tol=0, #Optional, defaults to 0
- df1_name='Original', #Optional, defaults to 'df1'
- df2_name='New' #Optional, defaults to 'df2'
-))
-```
-
-In order to compare dataframes of different backends, you just need to replace ``df1`` and ``df2`` with
-dataframes of different backends. Just pass in Dataframes such as Pandas dataframes, DuckDB relations,
-Polars dataframes, Arrow tables, Spark dataframes, Dask dataframes or Ray datasets. For example,
-to compare a Pandas dataframe with a Spark dataframe:
-
-```python
-from pyspark.sql import SparkSession
-
-spark = SparkSession.builder.getOrCreate()
-spark_df2 = spark.createDataFrame(df2)
-datacompy.is_match(
- df1,
- spark_df2,
- join_columns='acct_id',
-)
-```
-
-Notice that in order to use a specific backend, you need to have the corresponding library installed.
-For example, if you want compare Ray datasets, you must do
-
-```shell
-pip install datacompy[ray]
-```
-
-
-### How it works
-
-DataComPy uses Fugue to partition the two dataframes into chunks, and then compare each chunk in parallel
-using the Pandas-based ``Compare``. The comparison results are then aggregated to produce the final result.
-Different from the join operation used in ``SparkCompare``, the Fugue version uses the ``cogroup -> map``
-like semantic (not exactly the same, Fugue adopts a coarse version to achieve great performance), which
-guarantees full data comparison with consistent result compared to Pandas-based ``Compare``.
-
-
-## Spark Detail
-
-:::{important}
-With version ``v0.9.0`` SparkCompare now uses Null Safe (``<=>``) comparisons
+:::{note}
+At this time Python ``3.12`` is not supported by Spark, nor by Ray within Fugue.
:::
-DataComPy's ``SparkCompare`` class will join two dataframes either on a list of join
-columns. It has the capability to map column names that may be different in each
-dataframe, including in the join columns. You are responsible for creating the
-dataframes from any source which Spark can handle and specifying a unique join
-key. If there are duplicates in either dataframe by join key, the match process
-will remove the duplicates before joining (and tell you how many duplicates were
-found).
-
-As with the Pandas-based ``Compare`` class, comparisons will be attempted even
-if dtypes don't match. Any schema differences will be reported in the output
-as well as in any mismatch reports, so that you can assess whether or not a
-type mismatch is a problem or not.
-
-The main reasons why you would choose to use ``SparkCompare`` over ``Compare``
-are that your data is too large to fit into memory, or you're comparing data
-that works well in a Spark environment, like partitioned Parquet, CSV, or JSON
-files, or Cerebro tables.
-
-### Performance Implications
-
-
-Spark scales incredibly well, so you can use ``SparkCompare`` to compare
-billions of rows of data, provided you spin up a big enough cluster. Still,
-joining billions of rows of data is an inherently large task, so there are a
-couple of things you may want to take into consideration when getting into the
-cliched realm of "big data":
-
-* ``SparkCompare`` will compare all columns in common in the dataframes and
- report on the rest. If there are columns in the data that you don't care to
- compare, use a ``select`` statement/method on the dataframe(s) to filter
- those out. Particularly when reading from wide Parquet files, this can make
- a huge difference when the columns you don't care about don't have to be
- read into memory and included in the joined dataframe.
-* For large datasets, adding ``cache_intermediates=True`` to the ``SparkCompare``
- call can help optimize performance by caching certain intermediate dataframes
- in memory, like the de-duped version of each input dataset, or the joined
- dataframe. Otherwise, Spark's lazy evaluation will recompute those each time
- it needs the data in a report or as you access instance attributes. This may
- be fine for smaller dataframes, but will be costly for larger ones. You do
- need to ensure that you have enough free cache memory before you do this, so
- this parameter is set to False by default.
-
-
-### Basic Usage
-
-```python
-
- import datetime
- import datacompy
- from pyspark.sql import Row
-
- # This example assumes you have a SparkSession named "spark" in your environment, as you
- # do when running `pyspark` from the terminal or in a Databricks notebook (Spark v2.0 and higher)
-
- data1 = [
- Row(acct_id=10000001234, dollar_amt=123.45, name='George Maharis', float_fld=14530.1555,
- date_fld=datetime.date(2017, 1, 1)),
- Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=1.0,
- date_fld=datetime.date(2017, 1, 1)),
- Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=None,
- date_fld=datetime.date(2017, 1, 1)),
- Row(acct_id=10000001237, dollar_amt=123456.0, name='Bob Loblaw', float_fld=345.12,
- date_fld=datetime.date(2017, 1, 1)),
- Row(acct_id=10000001239, dollar_amt=1.05, name='Lucille Bluth', float_fld=None,
- date_fld=datetime.date(2017, 1, 1))
- ]
-
- data2 = [
- Row(acct_id=10000001234, dollar_amt=123.4, name='George Michael Bluth', float_fld=14530.155),
- Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=None),
- Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=1.0),
- Row(acct_id=10000001237, dollar_amt=123456.0, name='Robert Loblaw', float_fld=345.12),
- Row(acct_id=10000001238, dollar_amt=1.05, name='Loose Seal Bluth', float_fld=111.0)
- ]
-
- base_df = spark.createDataFrame(data1)
- compare_df = spark.createDataFrame(data2)
-
- comparison = datacompy.SparkCompare(spark, base_df, compare_df, join_columns=['acct_id'])
-
- # This prints out a human-readable report summarizing differences
- comparison.report()
-```
-
-### Using SparkCompare on EMR or standalone Spark
-
-1. Set proxy variables
-2. Create a virtual environment, if desired (``virtualenv venv; source venv/bin/activate``)
-3. Pip install datacompy and requirements
-4. Ensure your SPARK_HOME environment variable is set (this is probably ``/usr/lib/spark`` but may
- differ based on your installation)
-5. Augment your PYTHONPATH environment variable with
- ``export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.10.4-src.zip:$SPARK_HOME/python:$PYTHONPATH``
- (note that your version of py4j may differ depending on the version of Spark you're using)
-
-
-### Using SparkCompare on Databricks
-
-1. Clone this repository locally
-2. Create a datacompy egg by running ``python setup.py bdist_egg`` from the repo root directory.
-3. From the Databricks front page, click the "Library" link under the "New" section.
-4. On the New library page:
- a. Change source to "Upload Python Egg or PyPi"
- b. Under "Upload Egg", Library Name should be "datacompy"
- c. Drag the egg file in datacompy/dist/ to the "Drop library egg here to upload" box
- d. Click the "Create Library" button
-5. Once the library has been created, from the library page (which you can find in your /Users/{login} workspace),
- you can choose clusters to attach the library to.
-6. ``import datacompy`` in a notebook attached to the cluster that the library is attached to and enjoy!
+## Supported backends
+- Pandas: ([See documentation](https://capitalone.github.io/datacompy/pandas_usage.html))
+- Spark: ([See documentation](https://capitalone.github.io/datacompy/spark_usage.html))
+- Polars (Experimental): ([See documentation](https://capitalone.github.io/datacompy/polars_usage.html))
+- Fugue is a Python library that provides a unified interface for data processing on Pandas, DuckDB, Polars, Arrow,
+ Spark, Dask, Ray, and many other backends. DataComPy integrates with Fugue to provide a simple way to compare data
+ across these backends. Please note that Fugue will use the Pandas (Native) logic at its lowest level
+ ([See documentation](https://capitalone.github.io/datacompy/fugue_usage.html))
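+
+As a quick illustration, here is a minimal sketch of the new Polars backend (illustrative data; see the Polars
+documentation link above for full usage). It mirrors the Pandas ``Compare`` API:
+
+```python
+import polars as pl
+import datacompy
+
+# Two small frames sharing the "acct_id" join key (made-up data)
+df1 = pl.DataFrame({"acct_id": [1, 2, 3], "dollar_amt": [123.45, 0.45, 1345.0]})
+df2 = pl.DataFrame({"acct_id": [1, 2, 4], "dollar_amt": [123.40, 0.45, 1.05]})
+
+compare = datacompy.PolarsCompare(df1, df2, join_columns="acct_id")
+compare.matches()
+# False
+
+# Human-readable report summarizing and sampling differences
+print(compare.report())
+```
+
+The Fugue-based functions (``datacompy.is_match`` and ``datacompy.report``) accept the same ``join_columns``
+argument and can take dataframes from different backends, as covered in the Fugue documentation linked above.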
## Contributors
diff --git a/datacompy/__init__.py b/datacompy/__init__.py
index 7608f8cb..2231c88f 100644
--- a/datacompy/__init__.py
+++ b/datacompy/__init__.py
@@ -1,5 +1,5 @@
#
-# Copyright 2023 Capital One Services, LLC
+# Copyright 2024 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "0.10.5"
+__version__ = "0.11.0"
from datacompy.core import *
from datacompy.fugue import (
@@ -24,4 +24,5 @@
report,
unq_columns,
)
+from datacompy.polars import PolarsCompare
from datacompy.spark import NUMERIC_SPARK_TYPES, SparkCompare
diff --git a/datacompy/base.py b/datacompy/base.py
new file mode 100644
index 00000000..23a815fc
--- /dev/null
+++ b/datacompy/base.py
@@ -0,0 +1,141 @@
+#
+# Copyright 2024 Capital One Services, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Base class for comparing two DataFrames
+
+Originally this package was meant to provide similar functionality to
+PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
+two dataframes.
+"""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+
+from ordered_set import OrderedSet
+
+LOG = logging.getLogger(__name__)
+
+
+class BaseCompare(ABC):
+ @property
+ def df1(self) -> Any:
+ return self._df1 # type: ignore
+
+ @df1.setter
+ @abstractmethod
+ def df1(self, df1: Any) -> None:
+ """Check that it is a dataframe and has the join columns"""
+ pass
+
+ @property
+ def df2(self) -> Any:
+ return self._df2 # type: ignore
+
+ @df2.setter
+ @abstractmethod
+ def df2(self, df2: Any) -> None:
+ """Check that it is a dataframe and has the join columns"""
+ pass
+
+ @abstractmethod
+ def _validate_dataframe(
+ self, index: str, cast_column_names_lower: bool = True
+ ) -> None:
+ """Check that it is a dataframe and has the join columns"""
+ pass
+
+ @abstractmethod
+ def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
+ """Actually run the comparison. This tries to run df1.equals(df2)
+ first so that if they're truly equal we can tell.
+
+ This method will log out information about what is different between
+ the two dataframes, and will also return a boolean.
+ """
+ pass
+
+ @abstractmethod
+ def df1_unq_columns(self) -> OrderedSet[str]:
+ """Get columns that are unique to df1"""
+ pass
+
+ @abstractmethod
+ def df2_unq_columns(self) -> OrderedSet[str]:
+ """Get columns that are unique to df2"""
+ pass
+
+ @abstractmethod
+ def intersect_columns(self) -> OrderedSet[str]:
+ """Get columns that are shared between the two dataframes"""
+ pass
+
+ @abstractmethod
+ def _dataframe_merge(self, ignore_spaces: bool) -> None:
+ """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1
+ and df1 & df2
+
+ If ``on_index`` is True, this will join on index values, otherwise it
+ will join on the ``join_columns``.
+ """
+ pass
+
+ @abstractmethod
+ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
+ pass
+
+ @abstractmethod
+ def all_columns_match(self) -> bool:
+ pass
+
+ @abstractmethod
+ def all_rows_overlap(self) -> bool:
+ pass
+
+ @abstractmethod
+ def count_matching_rows(self) -> int:
+ pass
+
+ @abstractmethod
+ def intersect_rows_match(self) -> bool:
+ pass
+
+ @abstractmethod
+ def matches(self, ignore_extra_columns: bool = False) -> bool:
+ pass
+
+ @abstractmethod
+ def subset(self) -> bool:
+ pass
+
+ @abstractmethod
+ def sample_mismatch(
+ self, column: str, sample_count: int = 10, for_display: bool = False
+ ) -> Any:
+ pass
+
+ @abstractmethod
+ def all_mismatch(self, ignore_matching_cols: bool = False) -> Any:
+ pass
+
+ @abstractmethod
+ def report(
+ self,
+ sample_count: int = 10,
+ column_count: int = 10,
+ html_file: Optional[str] = None,
+ ) -> str:
+ pass
diff --git a/datacompy/core.py b/datacompy/core.py
index 9213c0e7..a1730768 100644
--- a/datacompy/core.py
+++ b/datacompy/core.py
@@ -1,5 +1,5 @@
#
-# Copyright 2020 Capital One Services, LLC
+# Copyright 2024 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,18 +20,20 @@
PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
two dataframes.
"""
-
import logging
import os
+from typing import Any, Dict, List, Optional, Union, cast
import numpy as np
import pandas as pd
from ordered_set import OrderedSet
+from datacompy.base import BaseCompare
+
LOG = logging.getLogger(__name__)
-class Compare:
+class Compare(BaseCompare):
"""Comparison class to be used to compare whether two dataframes as equal.
Both df1 and df2 should be dataframes containing all of the join_columns,
@@ -79,18 +81,18 @@ class Compare:
def __init__(
self,
- df1,
- df2,
- join_columns=None,
- on_index=False,
- abs_tol=0,
- rel_tol=0,
- df1_name="df1",
- df2_name="df2",
- ignore_spaces=False,
- ignore_case=False,
- cast_column_names_lower=True,
- ):
+ df1: pd.DataFrame,
+ df2: pd.DataFrame,
+ join_columns: Optional[Union[List[str], str]] = None,
+ on_index: bool = False,
+ abs_tol: float = 0,
+ rel_tol: float = 0,
+ df1_name: str = "df1",
+ df2_name: str = "df2",
+ ignore_spaces: bool = False,
+ ignore_case: bool = False,
+ cast_column_names_lower: bool = True,
+ ) -> None:
self.cast_column_names_lower = cast_column_names_lower
if on_index and join_columns is not None:
raise Exception("Only provide on_index or join_columns")
@@ -107,11 +109,11 @@ def __init__(
else:
self.join_columns = [
str(col).lower() if self.cast_column_names_lower else str(col)
- for col in join_columns
+ for col in cast(List[str], join_columns)
]
self.on_index = False
- self._any_dupes = False
+ self._any_dupes: bool = False
self.df1 = df1
self.df2 = df2
self.df1_name = df1_name
@@ -120,16 +122,18 @@ def __init__(
self.rel_tol = rel_tol
self.ignore_spaces = ignore_spaces
self.ignore_case = ignore_case
- self.df1_unq_rows = self.df2_unq_rows = self.intersect_rows = None
- self.column_stats = []
- self._compare(ignore_spaces, ignore_case)
+ self.df1_unq_rows: pd.DataFrame
+ self.df2_unq_rows: pd.DataFrame
+ self.intersect_rows: pd.DataFrame
+ self.column_stats: List[Dict[str, Any]] = []
+ self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case)
@property
- def df1(self):
+ def df1(self) -> pd.DataFrame:
return self._df1
@df1.setter
- def df1(self, df1):
+ def df1(self, df1: pd.DataFrame) -> None:
"""Check that it is a dataframe and has the join columns"""
self._df1 = df1
self._validate_dataframe(
@@ -137,18 +141,20 @@ def df1(self, df1):
)
@property
- def df2(self):
+ def df2(self) -> pd.DataFrame:
return self._df2
@df2.setter
- def df2(self, df2):
+ def df2(self, df2: pd.DataFrame) -> None:
"""Check that it is a dataframe and has the join columns"""
self._df2 = df2
self._validate_dataframe(
"df2", cast_column_names_lower=self.cast_column_names_lower
)
- def _validate_dataframe(self, index, cast_column_names_lower=True):
+ def _validate_dataframe(
+ self, index: str, cast_column_names_lower: bool = True
+ ) -> None:
"""Check that it is a dataframe and has the join columns
Parameters
@@ -163,9 +169,11 @@ def _validate_dataframe(self, index, cast_column_names_lower=True):
raise TypeError(f"{index} must be a pandas DataFrame")
if cast_column_names_lower:
- dataframe.columns = [str(col).lower() for col in dataframe.columns]
+ dataframe.columns = pd.Index(
+ [str(col).lower() for col in dataframe.columns]
+ )
else:
- dataframe.columns = [str(col) for col in dataframe.columns]
+ dataframe.columns = pd.Index([str(col) for col in dataframe.columns])
# Check if join_columns are present in the dataframe
if not set(self.join_columns).issubset(set(dataframe.columns)):
raise ValueError(f"{index} must have all columns from join_columns")
@@ -182,7 +190,7 @@ def _validate_dataframe(self, index, cast_column_names_lower=True):
):
self._any_dupes = True
- def _compare(self, ignore_spaces, ignore_case):
+ def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
"""Actually run the comparison. This tries to run df1.equals(df2)
first so that if they're truly equal we can tell.
@@ -214,26 +222,31 @@ def _compare(self, ignore_spaces, ignore_case):
else:
LOG.info("df1 does not match df2")
- def df1_unq_columns(self):
+ def df1_unq_columns(self) -> OrderedSet[str]:
"""Get columns that are unique to df1"""
- return OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns)
+ return cast(
+ OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns)
+ )
- def df2_unq_columns(self):
+ def df2_unq_columns(self) -> OrderedSet[str]:
"""Get columns that are unique to df2"""
- return OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns)
+ return cast(
+ OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns)
+ )
- def intersect_columns(self):
+ def intersect_columns(self) -> OrderedSet[str]:
"""Get columns that are shared between the two dataframes"""
return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns)
- def _dataframe_merge(self, ignore_spaces):
+ def _dataframe_merge(self, ignore_spaces: bool) -> None:
"""Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1
and df1 & df2
If ``on_index`` is True, this will join on index values, otherwise it
will join on the ``join_columns``.
"""
-
+ params: Dict[str, Any]
+ index_column: str
LOG.debug("Outer joining")
if self._any_dupes:
LOG.debug("Duplicate rows found, deduping by order of remaining fields")
@@ -275,11 +288,10 @@ def _dataframe_merge(self, ignore_spaces):
# Clean up temp columns for duplicate row matching
if self._any_dupes:
if self.on_index:
- outer_join.index = outer_join[index_column]
- outer_join.drop(index_column, axis=1, inplace=True)
+ outer_join.set_index(keys=index_column, drop=True, inplace=True)
self.df1.drop(index_column, axis=1, inplace=True)
self.df2.drop(index_column, axis=1, inplace=True)
- outer_join.drop(order_column, axis=1, inplace=True)
+ outer_join.drop(labels=order_column, axis=1, inplace=True)
self.df1.drop(order_column, axis=1, inplace=True)
self.df2.drop(order_column, axis=1, inplace=True)
@@ -306,7 +318,7 @@ def _dataframe_merge(self, ignore_spaces):
f"Number of rows in df1 and df2 (not necessarily equal): {len(self.intersect_rows)}"
)
- def _intersect_compare(self, ignore_spaces, ignore_case):
+ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
"""Run the comparison on the intersect dataframe
This loops through all columns that are shared between df1 and df2, and
@@ -319,7 +331,7 @@ def _intersect_compare(self, ignore_spaces, ignore_case):
if column in self.join_columns:
match_cnt = row_cnt
col_match = ""
- max_diff = 0
+ max_diff = 0.0
null_diff = 0
else:
col_1 = column + "_df1"
@@ -367,11 +379,11 @@ def _intersect_compare(self, ignore_spaces, ignore_case):
}
)
- def all_columns_match(self):
+ def all_columns_match(self) -> bool:
"""Whether the columns all match in the dataframes"""
return self.df1_unq_columns() == self.df2_unq_columns() == set()
- def all_rows_overlap(self):
+ def all_rows_overlap(self) -> bool:
"""Whether the rows are all present in both dataframes
Returns
@@ -382,7 +394,7 @@ def all_rows_overlap(self):
"""
return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0
- def count_matching_rows(self):
+ def count_matching_rows(self) -> int:
"""Count the number of rows match (on overlapping fields)
Returns
@@ -396,18 +408,23 @@ def count_matching_rows(self):
match_columns.append(column + "_match")
return self.intersect_rows[match_columns].all(axis=1).sum()
- def intersect_rows_match(self):
+ def intersect_rows_match(self) -> bool:
"""Check whether the intersect rows all match"""
actual_length = self.intersect_rows.shape[0]
return self.count_matching_rows() == actual_length
- def matches(self, ignore_extra_columns=False):
+ def matches(self, ignore_extra_columns: bool = False) -> bool:
"""Return True or False if the dataframes match.
Parameters
----------
ignore_extra_columns : bool
Ignores any columns in one dataframe and not in the other.
+
+ Returns
+ -------
+ bool
+ True or False if the dataframes match.
"""
if not ignore_extra_columns and not self.all_columns_match():
return False
@@ -418,12 +435,17 @@ def matches(self, ignore_extra_columns=False):
else:
return True
- def subset(self):
+ def subset(self) -> bool:
"""Return True if dataframe 2 is a subset of dataframe 1.
Dataframe 2 is considered a subset if all of its columns are in
dataframe 1, and all of its rows match rows in dataframe 1 for the
shared columns.
+
+ Returns
+ -------
+ bool
+ True if dataframe 2 is a subset of dataframe 1.
"""
if not self.df2_unq_columns() == set():
return False
@@ -434,7 +456,9 @@ def subset(self):
else:
return True
- def sample_mismatch(self, column, sample_count=10, for_display=False):
+ def sample_mismatch(
+ self, column: str, sample_count: int = 10, for_display: bool = False
+ ) -> pd.DataFrame:
"""Returns a sample sub-dataframe which contains the identifying
columns, and df1 and df2 versions of the column.
@@ -463,13 +487,16 @@ def sample_mismatch(self, column, sample_count=10, for_display=False):
return_cols = self.join_columns + [column + "_df1", column + "_df2"]
to_return = sample[return_cols]
if for_display:
- to_return.columns = self.join_columns + [
- column + " (" + self.df1_name + ")",
- column + " (" + self.df2_name + ")",
- ]
+ to_return.columns = pd.Index(
+ self.join_columns
+ + [
+ column + " (" + self.df1_name + ")",
+ column + " (" + self.df2_name + ")",
+ ]
+ )
return to_return
- def all_mismatch(self, ignore_matching_cols=False):
+ def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame:
"""All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join
columns.
@@ -512,7 +539,12 @@ def all_mismatch(self, ignore_matching_cols=False):
mm_bool = self.intersect_rows[match_list].all(axis="columns")
return self.intersect_rows[~mm_bool][self.join_columns + return_list]
- def report(self, sample_count=10, column_count=10, html_file=None):
+ def report(
+ self,
+ sample_count: int = 10,
+ column_count: int = 10,
+ html_file: Optional[str] = None,
+ ) -> str:
"""Returns a string representation of a report. The representation can
then be printed or saved to a file.
@@ -533,7 +565,7 @@ def report(self, sample_count=10, column_count=10, html_file=None):
The report, formatted kinda nicely.
"""
- def df_to_str(pdf):
+ def df_to_str(pdf: pd.DataFrame) -> str:
if not self.on_index:
pdf = pdf.reset_index(drop=True)
return pdf.to_string()
@@ -674,7 +706,7 @@ def df_to_str(pdf):
return report
-def render(filename, *fields):
+def render(filename: str, *fields: Union[int, float, str]) -> str:
"""Renders out an individual template. This basically just reads in a
template file, and applies ``.format()`` on the fields.
@@ -697,8 +729,13 @@ def render(filename, *fields):
def columns_equal(
- col_1, col_2, rel_tol=0, abs_tol=0, ignore_spaces=False, ignore_case=False
-):
+ col_1: "pd.Series[Any]",
+ col_2: "pd.Series[Any]",
+ rel_tol: float = 0,
+ abs_tol: float = 0,
+ ignore_spaces: bool = False,
+ ignore_case: bool = False,
+) -> "pd.Series[bool]":
"""Compares two columns from a dataframe, returning a True/False series,
with the same index as column 1.
@@ -731,6 +768,7 @@ def columns_equal(
A series of Boolean values. True == the values match, False == the
values don't match.
"""
+ compare: pd.Series[bool]
try:
compare = pd.Series(
np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)
@@ -773,7 +811,9 @@ def columns_equal(
return compare
-def compare_string_and_date_columns(col_1, col_2):
+def compare_string_and_date_columns(
+ col_1: "pd.Series[Any]", col_2: "pd.Series[Any]"
+) -> "pd.Series[bool]":
"""Compare a string column and date column, value-wise. This tries to
convert a string column to a date column and compare that way.
@@ -812,7 +852,9 @@ def compare_string_and_date_columns(col_1, col_2):
return pd.Series(False, index=col_1.index)
-def get_merged_columns(original_df, merged_df, suffix):
+def get_merged_columns(
+ original_df: pd.DataFrame, merged_df: pd.DataFrame, suffix: str
+) -> List[str]:
"""Gets the columns from an original dataframe, in the new merged dataframe
Parameters
@@ -836,7 +878,7 @@ def get_merged_columns(original_df, merged_df, suffix):
return columns
-def temp_column_name(*dataframes):
+def temp_column_name(*dataframes: pd.DataFrame) -> str:
"""Gets a temp column name that isn't included in columns of any dataframes
Parameters
@@ -861,7 +903,7 @@ def temp_column_name(*dataframes):
return temp_column
-def calculate_max_diff(col_1, col_2):
+def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> float:
"""Get a maximum difference between two columns
Parameters
@@ -877,12 +919,14 @@ def calculate_max_diff(col_1, col_2):
Numeric field, or zero.
"""
try:
- return (col_1.astype(float) - col_2.astype(float)).abs().max()
+ return cast(float, (col_1.astype(float) - col_2.astype(float)).abs().max())
except:
- return 0
+ return 0.0
-def generate_id_within_group(dataframe, join_columns):
+def generate_id_within_group(
+ dataframe: pd.DataFrame, join_columns: List[str]
+) -> "pd.Series[int]":
"""Generate an ID column that can be used to deduplicate identical rows. The series generated
is the order within a unique group, and it handles nulls.
diff --git a/datacompy/fugue.py b/datacompy/fugue.py
index 80038aa2..2ac4889a 100644
--- a/datacompy/fugue.py
+++ b/datacompy/fugue.py
@@ -1,5 +1,5 @@
#
-# Copyright 2023 Capital One Services, LLC
+# Copyright 2024 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@
import logging
import pickle
from collections import defaultdict
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
import fugue.api as fa
import pandas as pd
@@ -35,7 +35,7 @@
HASH_COL = "__datacompy__hash__"
-def unq_columns(df1: AnyDataFrame, df2: AnyDataFrame):
+def unq_columns(df1: AnyDataFrame, df2: AnyDataFrame) -> OrderedSet[str]:
"""Get columns that are unique to df1
Parameters
@@ -53,10 +53,10 @@ def unq_columns(df1: AnyDataFrame, df2: AnyDataFrame):
"""
col1 = fa.get_column_names(df1)
col2 = fa.get_column_names(df2)
- return OrderedSet(col1) - OrderedSet(col2)
+ return cast(OrderedSet[str], OrderedSet(col1) - OrderedSet(col2))
-def intersect_columns(df1: AnyDataFrame, df2: AnyDataFrame):
+def intersect_columns(df1: AnyDataFrame, df2: AnyDataFrame) -> OrderedSet[str]:
"""Get columns that are shared between the two dataframes
Parameters
@@ -77,7 +77,7 @@ def intersect_columns(df1: AnyDataFrame, df2: AnyDataFrame):
return OrderedSet(col1) & OrderedSet(col2)
-def all_columns_match(df1: AnyDataFrame, df2: AnyDataFrame):
+def all_columns_match(df1: AnyDataFrame, df2: AnyDataFrame) -> bool:
"""Whether the columns all match in the dataframes
Parameters
@@ -302,9 +302,9 @@ def report(
ignore_spaces: bool = False,
ignore_case: bool = False,
cast_column_names_lower: bool = True,
- sample_count=10,
- column_count=10,
- html_file=None,
+ sample_count: int = 10,
+ column_count: int = 10,
+ html_file: Optional[str] = None,
parallelism: Optional[int] = None,
) -> str:
"""Returns a string representation of a report. The representation can
@@ -320,7 +320,7 @@ def report(
First dataframe to check
df2 : ``AnyDataFrame``
Second dataframe to check
- join_columns : list or str, optional
+ join_columns : list or str
Column(s) to join dataframes on. If a string is passed in, that one
column will be used.
abs_tol : float, optional
@@ -406,7 +406,7 @@ def report(
def shape0(col: str) -> int:
return sum(x[col][0] for x in res)
- def shape1(col: str) -> int:
+ def shape1(col: str) -> Any:
return first[col][1]
def _sum(col: str) -> int:
@@ -454,6 +454,8 @@ def _any(col: str) -> int:
"Yes" if _any("_any_dupes") else "No",
)
+ column_stats: List[Dict[str, Any]]
+ match_sample: List[pd.DataFrame]
column_stats, match_sample = _aggregate_stats(res, sample_count=sample_count)
any_mismatch = len(match_sample) > 0
@@ -673,7 +675,10 @@ def _deserialize(
) -> pd.DataFrame:
arr = [pickle.loads(r["data"]) for r in df if r["left"] == left]
if len(arr) > 0:
- return pd.concat(arr).sort_values(schema.names).reset_index(drop=True)
+ return cast(
+ pd.DataFrame,
+ pd.concat(arr).sort_values(schema.names).reset_index(drop=True),
+ )
# The following is how to construct an empty pandas dataframe with
# the correct schema, it avoids pandas schema inference which is wrong.
# This is not needed when upgrading to Fugue >= 0.8.7
@@ -772,7 +777,7 @@ def _get_compare_result(
def _aggregate_stats(
- compares, sample_count
+ compares: List[Any], sample_count: int
) -> Tuple[List[Dict[str, Any]], List[pd.DataFrame]]:
samples = defaultdict(list)
stats = []
@@ -798,9 +803,16 @@ def _aggregate_stats(
)
.reset_index(drop=False)
)
- return df.to_dict(orient="records"), [
- _sample(pd.concat(v), sample_count=sample_count) for v in samples.values()
- ]
+ return cast(
+ Tuple[List[Dict[str, Any]], List[pd.DataFrame]],
+ (
+ df.to_dict(orient="records"),
+ [
+ _sample(pd.concat(v), sample_count=sample_count)
+ for v in samples.values()
+ ],
+ ),
+ )
def _sample(df: pd.DataFrame, sample_count: int) -> pd.DataFrame:
diff --git a/datacompy/polars.py b/datacompy/polars.py
new file mode 100644
index 00000000..814a7cd6
--- /dev/null
+++ b/datacompy/polars.py
@@ -0,0 +1,984 @@
+#
+# Copyright 2024 Capital One Services, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Compare two Polars DataFrames
+
+Originally this package was meant to provide similar functionality to
+PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
+two dataframes.
+"""
+import logging
+import os
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Union, cast
+
+import numpy as np
+from ordered_set import OrderedSet
+
+from datacompy.base import BaseCompare
+
+try:
+ import polars as pl
+ from polars.exceptions import ComputeError, InvalidOperationError
+except ImportError:
+ pass # Let non-Polars people at least enjoy the loveliness of the pandas datacompy functionality
+
+LOG = logging.getLogger(__name__)
+
+STRING_TYPE = ["String", "Utf8"]
+DATE_TYPE = ["Date", "Datetime"]
+
+
+class PolarsCompare(BaseCompare):
+ """Comparison class to be used to compare whether two dataframes as equal.
+
+ Both df1 and df2 should be dataframes containing all of the join_columns,
+ with unique column names. Differences between values are compared to
+ abs_tol + rel_tol * abs(df2['value']).
+
+ Parameters
+ ----------
+ df1 : Polars ``DataFrame``
+ First dataframe to check
+ df2 : Polars ``DataFrame``
+ Second dataframe to check
+ join_columns : list or str
+ Column(s) to join dataframes on. If a string is passed in, that one
+ column will be used.
+ abs_tol : float, optional
+ Absolute tolerance between two values.
+ rel_tol : float, optional
+ Relative tolerance between two values.
+ df1_name : str, optional
+ A string name for the first dataframe. This allows the reporting to
+ print out an actual name instead of "df1", and allows human users to
+ more easily track the dataframes.
+ df2_name : str, optional
+ A string name for the second dataframe
+ ignore_spaces : bool, optional
+ Flag to strip whitespace (including newlines) from string columns (including any join
+ columns)
+ ignore_case : bool, optional
+ Flag to ignore the case of string columns
+ cast_column_names_lower: bool, optional
+ Boolean indicator that controls whether column names will be cast into lower case
+
+ Attributes
+ ----------
+ df1_unq_rows : Polars ``DataFrame``
+ All records that are only in df1 (based on a join on join_columns)
+ df2_unq_rows : Polars ``DataFrame``
+ All records that are only in df2 (based on a join on join_columns)
+ """
+
+ def __init__(
+ self,
+ df1: "pl.DataFrame",
+ df2: "pl.DataFrame",
+ join_columns: Union[List[str], str],
+ abs_tol: float = 0,
+ rel_tol: float = 0,
+ df1_name: str = "df1",
+ df2_name: str = "df2",
+ ignore_spaces: bool = False,
+ ignore_case: bool = False,
+ cast_column_names_lower: bool = True,
+ ) -> None:
+ self.cast_column_names_lower = cast_column_names_lower
+
+ if isinstance(join_columns, str):
+ self.join_columns = [
+ str(join_columns).lower()
+ if self.cast_column_names_lower
+ else str(join_columns)
+ ]
+ elif isinstance(join_columns, list):
+ self.join_columns = [
+ str(col).lower() if self.cast_column_names_lower else str(col)
+ for col in join_columns
+ ]
+ else:
+ raise TypeError(f"{join_columns} must be a string or list of string(s)")
+
+ self._any_dupes: bool = False
+ self.df1 = df1
+ self.df2 = df2
+ self.df1_name = df1_name
+ self.df2_name = df2_name
+ self.abs_tol = abs_tol
+ self.rel_tol = rel_tol
+ self.ignore_spaces = ignore_spaces
+ self.ignore_case = ignore_case
+ self.df1_unq_rows: "pl.DataFrame"
+ self.df2_unq_rows: "pl.DataFrame"
+ self.intersect_rows: "pl.DataFrame"
+ self.column_stats: List[Dict[str, Any]] = []
+ self._compare(ignore_spaces=ignore_spaces, ignore_case=ignore_case)
+
+ @property
+ def df1(self) -> "pl.DataFrame":
+ return self._df1
+
+ @df1.setter
+ def df1(self, df1: "pl.DataFrame") -> None:
+ """Check that it is a dataframe and has the join columns"""
+ self._df1 = df1
+ self._validate_dataframe(
+ "df1", cast_column_names_lower=self.cast_column_names_lower
+ )
+
+ @property
+ def df2(self) -> "pl.DataFrame":
+ return self._df2
+
+ @df2.setter
+ def df2(self, df2: "pl.DataFrame") -> None:
+ """Check that it is a dataframe and has the join columns"""
+ self._df2 = df2
+ self._validate_dataframe(
+ "df2", cast_column_names_lower=self.cast_column_names_lower
+ )
+
+ def _validate_dataframe(
+ self, index: str, cast_column_names_lower: bool = True
+ ) -> None:
+ """Check that it is a dataframe and has the join columns
+
+ Parameters
+ ----------
+ index : str
+ The "index" of the dataframe - df1 or df2.
+ cast_column_names_lower: bool, optional
+ Boolean indicator that controls whether column names will be cast into lower case
+ """
+ dataframe = getattr(self, index)
+ if not isinstance(dataframe, pl.DataFrame):
+ raise TypeError(f"{index} must be a Polars DataFrame")
+
+ if cast_column_names_lower:
+ dataframe.columns = [str(col).lower() for col in dataframe.columns]
+
+ # Check if join_columns are present in the dataframe
+ if not set(self.join_columns).issubset(set(dataframe.columns)):
+ raise ValueError(f"{index} must have all columns from join_columns")
+
+ if len(set(dataframe.columns)) < len(dataframe.columns):
+ raise ValueError(f"{index} must have unique column names")
+
+ if len(dataframe.unique(subset=self.join_columns)) < len(dataframe):
+ self._any_dupes = True
+
+ def _compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
+ """Actually run the comparison. This tries to run df1.equals(df2)
+ first so that if they're truly equal we can tell.
+
+ This method will log out information about what is different between
+ the two dataframes, and will also return a boolean.
+ """
+ LOG.debug("Checking equality")
+ if self.df1.equals(self.df2):
+ LOG.info("df1 Polars.DataFrame.equals df2")
+ else:
+ LOG.info("df1 does not Polars.DataFrame.equals df2")
+ LOG.info(f"Number of columns in common: {len(self.intersect_columns())}")
+ LOG.debug("Checking column overlap")
+ for col in self.df1_unq_columns():
+ LOG.info(f"Column in df1 and not in df2: {col}")
+ LOG.info(
+ f"Number of columns in df1 and not in df2: {len(self.df1_unq_columns())}"
+ )
+ for col in self.df2_unq_columns():
+ LOG.info(f"Column in df2 and not in df1: {col}")
+ LOG.info(
+ f"Number of columns in df2 and not in df1: {len(self.df2_unq_columns())}"
+ )
+ LOG.debug("Merging dataframes")
+ self._dataframe_merge(ignore_spaces)
+ self._intersect_compare(ignore_spaces, ignore_case)
+ if self.matches():
+ LOG.info("df1 matches df2")
+ else:
+ LOG.info("df1 does not match df2")
+
+ def df1_unq_columns(self) -> OrderedSet[str]:
+ """Get columns that are unique to df1"""
+ return cast(
+ OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns)
+ )
+
+ def df2_unq_columns(self) -> OrderedSet[str]:
+ """Get columns that are unique to df2"""
+ return cast(
+ OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns)
+ )
+
+ def intersect_columns(self) -> OrderedSet[str]:
+ """Get columns that are shared between the two dataframes"""
+ return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns)
+
+ def _dataframe_merge(self, ignore_spaces: bool) -> None:
+ """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1
+ and df1 & df2
+ """
+ params: Dict[str, Any]
+ LOG.debug("Outer joining")
+
+ df1 = self.df1.clone()
+ df2 = self.df2.clone()
+ temp_join_columns = deepcopy(self.join_columns)
+
+ if self._any_dupes:
+ LOG.debug("Duplicate rows found, deduping by order of remaining fields")
+ # Create order column for uniqueness of match
+ order_column = temp_column_name(df1, df2)
+ df1 = df1.with_columns(
+ generate_id_within_group(df1, temp_join_columns).alias(order_column)
+ )
+ df2 = df2.with_columns(
+ generate_id_within_group(df2, temp_join_columns).alias(order_column)
+ )
+ temp_join_columns.append(order_column)
+
+ params = {"on": temp_join_columns}
+
+ if ignore_spaces:
+ for column in self.join_columns:
+ if str(df1[column].dtype) in STRING_TYPE:
+ df1 = df1.with_columns(pl.col(column).str.strip_chars())
+ if str(df2[column].dtype) in STRING_TYPE:
+ df2 = df2.with_columns(pl.col(column).str.strip_chars())
+
+ df1_non_join_columns = OrderedSet(df1.columns) - OrderedSet(temp_join_columns)
+ df2_non_join_columns = OrderedSet(df2.columns) - OrderedSet(temp_join_columns)
+
+ for c in df1_non_join_columns:
+ df1 = df1.rename({c: c + "_df1"})
+ for c in df2_non_join_columns:
+ df2 = df2.rename({c: c + "_df2"})
+
+ # generate merge indicator
+ df1 = df1.with_columns(_merge_left=pl.lit(True))
+ df2 = df2.with_columns(_merge_right=pl.lit(True))
+
+ outer_join = df1.join(df2, how="outer_coalesce", join_nulls=True, **params)
+
+ # process merge indicator
+ outer_join = outer_join.with_columns(
+ pl.when((pl.col("_merge_left") == True) & (pl.col("_merge_right") == True))
+ .then(pl.lit("both"))
+ .when((pl.col("_merge_left") == True) & (pl.col("_merge_right").is_null()))
+ .then(pl.lit("left_only"))
+ .when((pl.col("_merge_left").is_null()) & (pl.col("_merge_right") == True))
+ .then(pl.lit("right_only"))
+ .alias("_merge")
+ )
+
+ # Clean up temp columns for duplicate row matching
+ if self._any_dupes:
+ outer_join = outer_join.drop(order_column)
+
+ df1_cols = get_merged_columns(self.df1, outer_join, "_df1")
+ df2_cols = get_merged_columns(self.df2, outer_join, "_df2")
+
+ LOG.debug("Selecting df1 unique rows")
+ self.df1_unq_rows = outer_join.filter(
+ outer_join["_merge"] == "left_only"
+ ).select(df1_cols)
+ self.df1_unq_rows.columns = self.df1.columns
+
+ LOG.debug("Selecting df2 unique rows")
+ self.df2_unq_rows = outer_join.filter(
+ outer_join["_merge"] == "right_only"
+ ).select(df2_cols)
+ self.df2_unq_rows.columns = self.df2.columns
+
+ LOG.info(f"Number of rows in df1 and not in df2: {len(self.df1_unq_rows)}")
+ LOG.info(f"Number of rows in df2 and not in df1: {len(self.df2_unq_rows)}")
+
+ LOG.debug("Selecting intersecting rows")
+ self.intersect_rows = outer_join.filter(outer_join["_merge"] == "both")
+ LOG.info(
+ f"Number of rows in df1 and df2 (not necessarily equal): {len(self.intersect_rows)}"
+ )
+
+ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
+ """Run the comparison on the intersect dataframe
+
+ This loops through all columns that are shared between df1 and df2, and
+ creates a column column_match which is True for matches, False
+ otherwise.
+ """
+ match_cnt: Union[int, float]
+ null_diff: Union[int, float]
+
+ LOG.debug("Comparing intersection")
+ row_cnt = len(self.intersect_rows)
+ for column in self.intersect_columns():
+ if column in self.join_columns:
+ match_cnt = row_cnt
+ col_match = ""
+ max_diff = 0.0
+ null_diff = 0
+ else:
+ col_1 = column + "_df1"
+ col_2 = column + "_df2"
+ col_match = column + "_match"
+ self.intersect_rows = self.intersect_rows.with_columns(
+ columns_equal(
+ self.intersect_rows[col_1],
+ self.intersect_rows[col_2],
+ self.rel_tol,
+ self.abs_tol,
+ ignore_spaces,
+ ignore_case,
+ ).alias(col_match)
+ )
+ match_cnt = self.intersect_rows[col_match].sum()
+ max_diff = calculate_max_diff(
+ self.intersect_rows[col_1], self.intersect_rows[col_2]
+ )
+ null_diff = (
+ (self.intersect_rows[col_1].is_null())
+ ^ (self.intersect_rows[col_2].is_null())
+ ).sum()
+ if row_cnt > 0:
+ match_rate = float(match_cnt) / row_cnt
+ else:
+ match_rate = 0
+ LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match")
+
+ self.column_stats.append(
+ {
+ "column": column,
+ "match_column": col_match,
+ "match_cnt": match_cnt,
+ "unequal_cnt": row_cnt - match_cnt,
+ "dtype1": str(self.df1[column].dtype),
+ "dtype2": str(self.df2[column].dtype),
+ "all_match": all(
+ (
+ self.df1[column].dtype == self.df2[column].dtype,
+ row_cnt == match_cnt,
+ )
+ ),
+ "max_diff": max_diff,
+ "null_diff": null_diff,
+ }
+ )
+
+ def all_columns_match(self) -> bool:
+ """Whether the columns all match in the dataframes"""
+ return self.df1_unq_columns() == self.df2_unq_columns() == set()
+
+ def all_rows_overlap(self) -> bool:
+ """Whether the rows are all present in both dataframes
+
+ Returns
+ -------
+ bool
+ True if all rows in df1 are in df2 and vice versa (based on
+ existence for join option)
+ """
+ return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0
+
+ def count_matching_rows(self) -> int:
+ """Count the number of rows match (on overlapping fields)
+
+ Returns
+ -------
+ int
+ Number of matching rows
+ """
+ match_columns = []
+ for column in self.intersect_columns():
+ if column not in self.join_columns:
+ match_columns.append(column + "_match")
+
+ if len(match_columns) > 0:
+ return int(
+ self.intersect_rows[match_columns]
+ .select(pl.all_horizontal(match_columns).alias("__sum"))
+ .sum()
+ .item()
+ )
+ else:
+ # corner case where it is just the join columns that make up the dataframes
+ if len(self.intersect_rows) > 0:
+ return len(self.intersect_rows)
+ else:
+ return 0
+
+ def intersect_rows_match(self) -> bool:
+ """Check whether the intersect rows all match"""
+ actual_length = self.intersect_rows.shape[0]
+ return self.count_matching_rows() == actual_length
+
+ def matches(self, ignore_extra_columns: bool = False) -> bool:
+ """Return True or False if the dataframes match.
+
+ Parameters
+ ----------
+ ignore_extra_columns : bool
+ Ignores any columns in one dataframe and not in the other.
+
+ Returns
+ -------
+ bool
+ True or False if the dataframes match.
+ """
+ if not ignore_extra_columns and not self.all_columns_match():
+ return False
+ elif not self.all_rows_overlap():
+ return False
+ elif not self.intersect_rows_match():
+ return False
+ else:
+ return True
+
+ def subset(self) -> bool:
+ """Return True if dataframe 2 is a subset of dataframe 1.
+
+ Dataframe 2 is considered a subset if all of its columns are in
+ dataframe 1, and all of its rows match rows in dataframe 1 for the
+ shared columns.
+
+ Returns
+ -------
+ bool
+ True if dataframe 2 is a subset of dataframe 1.
+ """
+ if not self.df2_unq_columns() == set():
+ return False
+ elif not len(self.df2_unq_rows) == 0:
+ return False
+ elif not self.intersect_rows_match():
+ return False
+ else:
+ return True
+
+ def sample_mismatch(
+ self, column: str, sample_count: int = 10, for_display: bool = False
+ ) -> "pl.DataFrame":
+ """Returns a sample sub-dataframe which contains the identifying
+ columns, and df1 and df2 versions of the column.
+
+ Parameters
+ ----------
+ column : str
+ The raw column name (i.e. without ``_df1`` appended)
+ sample_count : int, optional
+ The number of sample records to return. Defaults to 10.
+ for_display : bool, optional
+ Whether this is just going to be used for display (overwrite the
+ column names)
+
+ Returns
+ -------
+ Polars.DataFrame
+ A sample of the intersection dataframe, containing only the
+ "pertinent" columns, for rows that don't match on the provided
+ column.
+ """
+ row_cnt = self.intersect_rows.shape[0]
+ col_match = self.intersect_rows[column + "_match"]
+ match_cnt = col_match.sum()
+ sample_count = min(sample_count, row_cnt - match_cnt) # type: ignore
+ sample = self.intersect_rows.filter(pl.col(column + "_match") != True).sample(
+ sample_count
+ )
+ return_cols = self.join_columns + [column + "_df1", column + "_df2"]
+ to_return = sample[return_cols]
+ if for_display:
+ to_return.columns = self.join_columns + [
+ column + " (" + self.df1_name + ")",
+ column + " (" + self.df2_name + ")",
+ ]
+ return to_return
+
+ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
+ """All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join
+ columns.
+
+ Parameters
+ ----------
+ ignore_matching_cols : bool, optional
+ Whether showing the matching columns in the output or not. The default is False.
+
+ Returns
+ -------
+ Polars.DataFrame
+ All rows of the intersection dataframe, containing any columns, that don't match.
+ """
+ match_list = []
+ return_list = []
+ for col in self.intersect_rows.columns:
+ if col.endswith("_match"):
+ orig_col_name = col[:-6]
+
+ col_comparison = columns_equal(
+ self.intersect_rows[orig_col_name + "_df1"],
+ self.intersect_rows[orig_col_name + "_df2"],
+ self.rel_tol,
+ self.abs_tol,
+ self.ignore_spaces,
+ self.ignore_case,
+ )
+
+ if not ignore_matching_cols or (
+ ignore_matching_cols and not col_comparison.all()
+ ):
+ LOG.debug(f"Adding column {orig_col_name} to the result.")
+ match_list.append(col)
+ return_list.extend([orig_col_name + "_df1", orig_col_name + "_df2"])
+ elif ignore_matching_cols:
+ LOG.debug(
+ f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
+ )
+ return (
+ self.intersect_rows.with_columns(__all=pl.all_horizontal(match_list))
+ .filter(pl.col("__all") != True)
+ .select(self.join_columns + return_list)
+ )
+
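`all_mismatch` complements the sampler by returning every intersecting row with at least one differing column; with `ignore_matching_cols=True`, fully equal columns are dropped from the output. A sketch, again with assumed names:

```python
# Hypothetical sketch; class/import names are assumed.
import polars as pl
from datacompy.polars import Compare  # assumed import path

compare = Compare(
    pl.DataFrame({"id": [1, 2], "amt": [10.0, 20.0], "flag": ["y", "y"]}),
    pl.DataFrame({"id": [1, 2], "amt": [10.0, 99.0], "flag": ["y", "y"]}),
    join_columns=["id"],
)
diffs = compare.all_mismatch(ignore_matching_cols=True)
print(diffs.columns)  # something like ['id', 'amt_df1', 'amt_df2'] -- "flag" matches everywhere, so it is dropped
```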
+ def report(
+ self,
+ sample_count: int = 10,
+ column_count: int = 10,
+ html_file: Optional[str] = None,
+ ) -> str:
+ """Returns a string representation of a report. The representation can
+ then be printed or saved to a file.
+
+ Parameters
+ ----------
+ sample_count : int, optional
+ The number of sample records to return. Defaults to 10.
+
+ column_count : int, optional
+ The number of columns to display in the sample records output. Defaults to 10.
+
+ html_file : str, optional
+ HTML file name to save report output to. If ``None`` the file creation will be skipped.
+
+ Returns
+ -------
+ str
+            The report, formatted somewhat nicely.
+ """
+
+ def df_to_str(pdf: "pl.DataFrame") -> str:
+ return pdf.to_pandas().to_string()
+
+ # Header
+ report = render("header.txt")
+ df_header = pl.DataFrame(
+ {
+ "DataFrame": [self.df1_name, self.df2_name],
+ "Columns": [self.df1.shape[1], self.df2.shape[1]],
+ "Rows": [self.df1.shape[0], self.df2.shape[0]],
+ }
+ )
+ report += df_to_str(df_header[["DataFrame", "Columns", "Rows"]])
+ report += "\n\n"
+
+ # Column Summary
+ report += render(
+ "column_summary.txt",
+ len(self.intersect_columns()),
+ len(self.df1_unq_columns()),
+ len(self.df2_unq_columns()),
+ self.df1_name,
+ self.df2_name,
+ )
+
+ # Row Summary
+ match_on = ", ".join(self.join_columns)
+ report += render(
+ "row_summary.txt",
+ match_on,
+ self.abs_tol,
+ self.rel_tol,
+ self.intersect_rows.shape[0],
+ self.df1_unq_rows.shape[0],
+ self.df2_unq_rows.shape[0],
+ self.intersect_rows.shape[0] - self.count_matching_rows(),
+ self.count_matching_rows(),
+ self.df1_name,
+ self.df2_name,
+ "Yes" if self._any_dupes else "No",
+ )
+
+ # Column Matching
+ cnt_intersect = self.intersect_rows.shape[0]
+ report += render(
+ "column_comparison.txt",
+ len([col for col in self.column_stats if col["unequal_cnt"] > 0]),
+ len([col for col in self.column_stats if col["unequal_cnt"] == 0]),
+ sum([col["unequal_cnt"] for col in self.column_stats]),
+ )
+
+ match_stats = []
+ match_sample = []
+ any_mismatch = False
+ for column in self.column_stats:
+ if not column["all_match"]:
+ any_mismatch = True
+ match_stats.append(
+ {
+ "Column": column["column"],
+ f"{self.df1_name} dtype": column["dtype1"],
+ f"{self.df2_name} dtype": column["dtype2"],
+ "# Unequal": column["unequal_cnt"],
+ "Max Diff": column["max_diff"],
+ "# Null Diff": column["null_diff"],
+ }
+ )
+ if column["unequal_cnt"] > 0:
+ match_sample.append(
+ self.sample_mismatch(
+ column["column"], sample_count, for_display=True
+ )
+ )
+
+ if any_mismatch:
+ report += "Columns with Unequal Values or Types\n"
+ report += "------------------------------------\n"
+ report += "\n"
+ df_match_stats = pl.DataFrame(match_stats)
+ df_match_stats = df_match_stats.sort("Column")
+ # Have to specify again for sorting
+ report += (
+ df_match_stats[
+ [
+ "Column",
+ f"{self.df1_name} dtype",
+ f"{self.df2_name} dtype",
+ "# Unequal",
+ "Max Diff",
+ "# Null Diff",
+ ]
+ ]
+ .to_pandas()
+ .to_string()
+ )
+ report += "\n\n"
+
+ if sample_count > 0:
+ report += "Sample Rows with Unequal Values\n"
+ report += "-------------------------------\n"
+ report += "\n"
+ for sample in match_sample:
+ report += df_to_str(sample)
+ report += "\n\n"
+
+ if min(sample_count, self.df1_unq_rows.shape[0]) > 0:
+ report += (
+ f"Sample Rows Only in {self.df1_name} (First {column_count} Columns)\n"
+ )
+ report += (
+ f"---------------------------------------{'-' * len(self.df1_name)}\n"
+ )
+ report += "\n"
+ columns = self.df1_unq_rows.columns[:column_count]
+ unq_count = min(sample_count, self.df1_unq_rows.shape[0])
+ report += df_to_str(self.df1_unq_rows.sample(unq_count)[columns])
+ report += "\n\n"
+
+ if min(sample_count, self.df2_unq_rows.shape[0]) > 0:
+ report += (
+ f"Sample Rows Only in {self.df2_name} (First {column_count} Columns)\n"
+ )
+ report += (
+ f"---------------------------------------{'-' * len(self.df2_name)}\n"
+ )
+ report += "\n"
+ columns = self.df2_unq_rows.columns[:column_count]
+ unq_count = min(sample_count, self.df2_unq_rows.shape[0])
+ report += df_to_str(self.df2_unq_rows.sample(unq_count)[columns])
+ report += "\n\n"
+
+ if html_file:
+            html_report = report.replace("\n", "<br>").replace(" ", "&nbsp;")
+            html_report = f"<pre>{html_report}</pre>"
+            with open(html_file, "w") as f:
+                f.write(html_report)
+
+        return report
+
+
+def render(filename: str, *fields: Union[int, float, str]) -> str:
+    """Renders out an individual template. This basically just reads in a
+    template file, and applies ``.format()`` on the fields.
+
+    Parameters
+    ----------
+    filename : str
+        The file that contains the template. Will automagically prepend the
+        templates directory before opening
+    fields : list
+        Fields to be rendered out in the template
+
+    Returns
+    -------
+    str
+        The fully rendered out file.
+    """
+    this_dir = os.path.dirname(os.path.realpath(__file__))
+    with open(os.path.join(this_dir, "templates", filename)) as file_open:
+        return file_open.read().format(*fields)
+
+
+def columns_equal(
+    col_1: "pl.Series",
+    col_2: "pl.Series",
+    rel_tol: float = 0,
+    abs_tol: float = 0,
+    ignore_spaces: bool = False,
+    ignore_case: bool = False,
+) -> "pl.Series":
+    """Compares two columns from a dataframe, returning a True/False series,
+    with the same index as column 1.
+
+    - Two nulls (np.nan) will evaluate to True.
+    - A null and a non-null value will evaluate to False.
+    - Numeric values will use the relative and absolute tolerances.
+    - Decimal values (decimal.Decimal) will attempt to be converted to floats
+      before comparing
+    - Non-numeric values (i.e. where np.isclose can't be used) will just
+      trigger True on two nulls or exact matches.
+
+    Parameters
+    ----------
+    col_1 : Polars.Series
+        The first column to look at
+    col_2 : Polars.Series
+        The second column
+    rel_tol : float, optional
+        Relative tolerance
+    abs_tol : float, optional
+        Absolute tolerance
+    ignore_spaces : bool, optional
+        Flag to strip whitespace (including newlines) from string columns
+    ignore_case : bool, optional
+        Flag to ignore the case of string columns
+
+    Returns
+    -------
+    Polars.Series
+        A series of Boolean values. True == the values match, False == the
+        values don't match.
+    """
+    compare: pl.Series
+    try:
+        compare = pl.Series(
+            np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)
+        )
+    except TypeError:
+        try:
+            if col_1.dtype in DATE_TYPE or col_2.dtype in DATE_TYPE:
+                raise TypeError("Found date, moving to alternative logic")
+
+            compare = pl.Series(
+                np.isclose(
+                    col_1.cast(pl.Float64, strict=True),
+                    col_2.cast(pl.Float64, strict=True),
+                    rtol=rel_tol,
+                    atol=abs_tol,
+                    equal_nan=True,
+                )
+            )
+        except (ValueError, TypeError, InvalidOperationError, ComputeError):
+            try:
+                if ignore_spaces:
+                    if str(col_1.dtype) in STRING_TYPE:
+                        col_1 = col_1.str.strip_chars()
+                    if str(col_2.dtype) in STRING_TYPE:
+                        col_2 = col_2.str.strip_chars()
+
+                if ignore_case:
+                    if str(col_1.dtype) in STRING_TYPE:
+                        col_1 = col_1.str.to_uppercase()
+                    if str(col_2.dtype) in STRING_TYPE:
+                        col_2 = col_2.str.to_uppercase()
+
+                if (
+                    str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE
+                ) or (
+                    str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE
+                ):
+                    compare = compare_string_and_date_columns(col_1, col_2)
+                else:
+                    compare = pl.Series(
+                        (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())
+                    )
+            except:
+                # Blanket exception should just return all False
+                compare = pl.Series([False] * col_1.shape[0])
+    return compare
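As a rough illustration of the fallback chain above (a standalone sketch, not part of the patch; the import path for this helper is an assumption), tolerant numeric comparison and the string-normalisation flags behave roughly like this:

```python
# Illustrative sketch only; the import path is assumed, not shown in this diff.
import polars as pl
from datacompy.polars import columns_equal  # assumed location of the helper above

s1 = pl.Series([1.000, 2.0, None])
s2 = pl.Series([1.004, 2.5, None])
print(columns_equal(s1, s2, abs_tol=0.01).to_list())  # [True, False, True]

# String columns fall through to the normalised equality branch
print(columns_equal(pl.Series([" a"]), pl.Series(["A"]),
                    ignore_spaces=True, ignore_case=True).to_list())  # [True]
```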
+
+
+def compare_string_and_date_columns(
+    col_1: "pl.Series", col_2: "pl.Series"
+) -> "pl.Series":
+    """Compare a string column and date column, value-wise. This tries to
+    convert a string column to a date column and compare that way.
+
+    Parameters
+    ----------
+    col_1 : Polars.Series
+        The first column to look at
+    col_2 : Polars.Series
+        The second column
+
+    Returns
+    -------
+    Polars.Series
+        A series of Boolean values. True == the values match, False == the
+        values don't match.
+    """
+    if str(col_1.dtype) in STRING_TYPE:
+        str_column = col_1
+        date_column = col_2
+    else:
+        str_column = col_2
+        date_column = col_1
+
+    try:  # datetime is inferred
+        return pl.Series(
+            (str_column.str.to_datetime().eq_missing(date_column))
+            | (str_column.is_null() & date_column.is_null())
+        )
+    except:
+        return pl.Series([False] * col_1.shape[0])
+
+
+def get_merged_columns(
+    original_df: "pl.DataFrame", merged_df: "pl.DataFrame", suffix: str
+) -> List[str]:
+    """Gets the columns from an original dataframe, in the new merged dataframe
+
+    Parameters
+    ----------
+    original_df : Polars.DataFrame
+        The original, pre-merge dataframe
+    merged_df : Polars.DataFrame
+        Post-merge with another dataframe, with suffixes added in.
+    suffix : str
+        What suffix was used to distinguish when the original dataframe was
+        overlapping with the other merged dataframe.
+    """
+    columns = []
+    for col in original_df.columns:
+        if col in merged_df.columns:
+            columns.append(col)
+        elif col + suffix in merged_df.columns:
+            columns.append(col + suffix)
+        else:
+            raise ValueError("Column not found: %s", col)
+    return columns
+
+
+def temp_column_name(*dataframes: "pl.DataFrame") -> str:
+    """Gets a temp column name that isn't included in columns of any dataframes
+
+    Parameters
+    ----------
+    dataframes : list of Polars.DataFrame
+        The DataFrames to create a temporary column name for
+
+    Returns
+    -------
+    str
+        String column name that looks like '_temp_x' for some integer x
+    """
+    i = 0
+    while True:
+        temp_column = f"_temp_{i}"
+        unique = True
+        for dataframe in dataframes:
+            if temp_column in dataframe.columns:
+                i += 1
+                unique = False
+        if unique:
+            return temp_column
+
+
+def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float:
+    """Get a maximum difference between two columns
+
+    Parameters
+    ----------
+    col_1 : Polars.Series
+        The first column
+    col_2 : Polars.Series
+        The second column
+
+    Returns
+    -------
+    Numeric
+        Numeric field, or zero.
+    """
+    try:
+        return cast(
+            float, (col_1.cast(pl.Float64) - col_2.cast(pl.Float64)).abs().max()
+        )
+    except:
+        return 0.0
+ """ + default_value = "DATACOMPY_NULL" + if ( + dataframe.select(pl.any_horizontal(pl.col(join_columns).is_null())) + .to_series() + .any() + ): + if ( + dataframe.select( + pl.any_horizontal(pl.col(join_columns).cast(pl.String) == default_value) + ) + .to_series() + .any() + ): + raise ValueError(f"{default_value} was found in your join columns") + return ( + dataframe[join_columns] + .cast(pl.String) + .fill_null(default_value) + .select(rn=pl.col(dataframe.columns[0]).cum_count().over(join_columns)) + .to_series() + ) + else: + return dataframe.select( + rn=pl.col(dataframe.columns[0]).cum_count().over(join_columns) + ).to_series() diff --git a/datacompy/py.typed b/datacompy/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/datacompy/spark.py b/datacompy/spark.py index a285036b..45fe3419 100644 --- a/datacompy/spark.py +++ b/datacompy/spark.py @@ -1,5 +1,5 @@ # -# Copyright 2020 Capital One Services, LLC +# Copyright 2024 Capital One Services, LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. - import sys from enum import Enum from itertools import chain +from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union try: + import pyspark from pyspark.sql import functions as F except ImportError: pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality @@ -29,9 +30,9 @@ class MatchType(Enum): # Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal". -def decimal_comparator(): +def decimal_comparator() -> str: class DecimalComparator(str): - def __eq__(self, other): + def __eq__(self, other: str) -> bool: # type: ignore[override] return len(other) >= 7 and other[0:7] == "decimal" return DecimalComparator("decimal") @@ -48,7 +49,7 @@ def __eq__(self, other): ] -def _is_comparable(type1, type2): +def _is_comparable(type1: str, type2: str) -> bool: """Checks if two Spark data types can be safely compared. Two data types are considered comparable if any of the following apply: 1. 
diff --git a/datacompy/py.typed b/datacompy/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/datacompy/spark.py b/datacompy/spark.py
index a285036b..45fe3419 100644
--- a/datacompy/spark.py
+++ b/datacompy/spark.py
@@ -1,5 +1,5 @@
 #
-# Copyright 2020 Capital One Services, LLC
+# Copyright 2024 Capital One Services, LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import sys
 from enum import Enum
 from itertools import chain
+from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union
 
 try:
+    import pyspark
     from pyspark.sql import functions as F
 except ImportError:
     pass  # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality
@@ -29,9 +30,9 @@ class MatchType(Enum):
 
 
 # Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal".
-def decimal_comparator():
+def decimal_comparator() -> str:
     class DecimalComparator(str):
-        def __eq__(self, other):
+        def __eq__(self, other: str) -> bool:  # type: ignore[override]
             return len(other) >= 7 and other[0:7] == "decimal"
 
     return DecimalComparator("decimal")
@@ -48,7 +49,7 @@ def __eq__(self, other):
 ]
 
 
-def _is_comparable(type1, type2):
+def _is_comparable(type1: str, type2: str) -> bool:
     """Checks if two Spark data types can be safely compared.
     Two data types are considered comparable if any of the following apply:
         1. Both data types are the same
@@ -141,17 +142,17 @@ class SparkCompare:
     def __init__(
         self,
-        spark_session,
-        base_df,
-        compare_df,
-        join_columns,
-        column_mapping=None,
-        cache_intermediates=False,
-        known_differences=None,
-        rel_tol=0,
-        abs_tol=0,
-        show_all_columns=False,
-        match_rates=False,
+        spark_session: "pyspark.sql.SparkSession",
+        base_df: "pyspark.sql.DataFrame",
+        compare_df: "pyspark.sql.DataFrame",
+        join_columns: List[Union[str, Tuple[str, str]]],
+        column_mapping: Optional[List[Tuple[str, str]]] = None,
+        cache_intermediates: bool = False,
+        known_differences: Optional[List[Dict[str, Any]]] = None,
+        rel_tol: float = 0,
+        abs_tol: float = 0,
+        show_all_columns: bool = False,
+        match_rates: bool = False,
     ):
         self.rel_tol = rel_tol
         self.abs_tol = abs_tol
@@ -164,7 +165,7 @@ def __init__(
         self._original_compare_df = compare_df
         self.cache_intermediates = cache_intermediates
 
-        self.join_columns = self._tuplizer(join_columns)
+        self.join_columns = self._tuplizer(input_list=join_columns)
         self._join_column_names = [name[0] for name in self.join_columns]
 
         self._known_differences = known_differences
@@ -182,13 +183,15 @@ def __init__(
         self.spark = spark_session
         self.base_unq_rows = self.compare_unq_rows = None
-        self._base_row_count = self._compare_row_count = self._common_row_count = None
-        self._joined_dataframe = None
-        self._rows_only_base = None
-        self._rows_only_compare = None
-        self._all_matched_rows = None
-        self._all_rows_mismatched = None
-        self.columns_match_dict = {}
+        self._base_row_count: Optional[int] = None
+        self._compare_row_count: Optional[int] = None
+        self._common_row_count: Optional[int] = None
+        self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None
+        self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None
+        self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None
+        self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None
+        self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None
+        self.columns_match_dict: Dict[str, Any] = {}
 
         # drop the duplicates before actual comparison made.
         self.base_df = base_df.dropDuplicates(self._join_column_names)
@@ -200,8 +203,10 @@ def __init__(
             self.compare_df.cache()
             self._compare_row_count = self.compare_df.count()
 
-    def _tuplizer(self, input_list):
-        join_columns = []
+    def _tuplizer(
+        self, input_list: List[Union[str, Tuple[str, str]]]
+    ) -> List[Tuple[str, str]]:
+        join_columns: List[Tuple[str, str]] = []
         for val in input_list:
             if isinstance(val, str):
                 join_columns.append((val, val))
@@ -211,12 +216,12 @@ def _tuplizer(self, input_list):
         return join_columns
 
     @property
-    def columns_in_both(self):
+    def columns_in_both(self) -> Set[str]:
         """set[str]: Get columns in both dataframes"""
         return set(self.base_df.columns) & set(self.compare_df.columns)
 
     @property
-    def columns_compared(self):
+    def columns_compared(self) -> List[str]:
         """list[str]: Get columns to be compared in both dataframes (all
         columns in both excluding the join key(s)"""
         return [
@@ -226,17 +231,17 @@ def columns_compared(self):
         ]
 
     @property
-    def columns_only_base(self):
+    def columns_only_base(self) -> Set[str]:
         """set[str]: Get columns that are unique to the base dataframe"""
         return set(self.base_df.columns) - set(self.compare_df.columns)
 
     @property
-    def columns_only_compare(self):
+    def columns_only_compare(self) -> Set[str]:
         """set[str]: Get columns that are unique to the compare dataframe"""
         return set(self.compare_df.columns) - set(self.base_df.columns)
 
     @property
-    def base_row_count(self):
+    def base_row_count(self) -> int:
         """int: Get the count of rows in the de-duped base dataframe"""
         if self._base_row_count is None:
             self._base_row_count = self.base_df.count()
@@ -244,7 +249,7 @@ def base_row_count(self):
         return self._base_row_count
 
     @property
-    def compare_row_count(self):
+    def compare_row_count(self) -> int:
         """int: Get the count of rows in the de-duped compare dataframe"""
         if self._compare_row_count is None:
             self._compare_row_count = self.compare_df.count()
@@ -252,7 +257,7 @@ def compare_row_count(self):
         return self._compare_row_count
 
     @property
-    def common_row_count(self):
+    def common_row_count(self) -> int:
         """int: Get the count of rows in common between base and compare dataframes"""
         if self._common_row_count is None:
             common_rows = self._get_or_create_joined_dataframe()
@@ -260,19 +265,19 @@ def common_row_count(self):
         return self._common_row_count
 
-    def _get_unq_base_rows(self):
+    def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame":
         """Get the rows only from base data frame"""
         return self.base_df.select(self._join_column_names).subtract(
             self.compare_df.select(self._join_column_names)
         )
 
-    def _get_compare_rows(self):
+    def _get_compare_rows(self) -> "pyspark.sql.DataFrame":
         """Get the rows only from compare data frame"""
         return self.compare_df.select(self._join_column_names).subtract(
             self.base_df.select(self._join_column_names)
         )
 
-    def _print_columns_summary(self, myfile):
+    def _print_columns_summary(self, myfile: TextIO) -> None:
         """Prints the column summary details"""
         print("\n****** Column Summary ******", file=myfile)
         print(
@@ -292,7 +297,7 @@ def _print_columns_summary(self, myfile):
             file=myfile,
         )
 
-    def _print_only_columns(self, base_or_compare, myfile):
+    def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None:
         """Prints the columns and data types only in either the base or compare
         datasets"""
         if base_or_compare.upper() == "BASE":
@@ -321,7 +326,7 @@ def _print_only_columns(self, base_or_compare, myfile):
             col_type = df.select(column).dtypes[0][1]
             print((format_pattern + " {:13s}").format(column, col_type), file=myfile)
 
-    def _columns_with_matching_schema(self):
+    def _columns_with_matching_schema(self) -> Dict[str, str]:
         """This function will identify the columns which has matching schema"""
         col_schema_match = {}
         base_columns_dict = dict(self.base_df.dtypes)
@@ -329,12 +334,13 @@ def _columns_with_matching_schema(self):
 
         for base_row, base_type in base_columns_dict.items():
             if base_row in compare_columns_dict:
-                if base_type in compare_columns_dict.get(base_row):
-                    col_schema_match[base_row] = compare_columns_dict.get(base_row)
+                compare_column_type = compare_columns_dict.get(base_row)
+                if compare_column_type is not None and base_type in compare_column_type:
+                    col_schema_match[base_row] = compare_column_type
 
         return col_schema_match
 
-    def _columns_with_schemadiff(self):
+    def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]:
         """This function will identify the columns which has different schema"""
         col_schema_diff = {}
         base_columns_dict = dict(self.base_df.dtypes)
@@ -342,15 +348,19 @@ def _columns_with_schemadiff(self):
 
         for base_row, base_type in base_columns_dict.items():
             if base_row in compare_columns_dict:
-                if base_type not in compare_columns_dict.get(base_row):
+                compare_column_type = compare_columns_dict.get(base_row)
+                if (
+                    compare_column_type is not None
+                    and base_type not in compare_column_type
+                ):
                     col_schema_diff[base_row] = dict(
                         base_type=base_type,
-                        compare_type=compare_columns_dict.get(base_row),
+                        compare_type=compare_column_type,
                     )
         return col_schema_diff
 
     @property
-    def rows_both_mismatch(self):
+    def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches"""
         if self._all_rows_mismatched is None:
             self._merge_dataframes()
@@ -358,7 +368,7 @@ def rows_both_mismatch(self):
         return self._all_rows_mismatched
 
     @property
-    def rows_both_all(self):
+    def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns all rows in both dataframes"""
         if self._all_matched_rows is None:
             self._merge_dataframes()
@@ -366,7 +376,7 @@ def rows_both_all(self):
         return self._all_matched_rows
 
     @property
-    def rows_only_base(self):
+    def rows_only_base(self) -> "pyspark.sql.DataFrame":
         """pyspark.sql.DataFrame: Returns rows only in the base dataframe"""
         if not self._rows_only_base:
             base_rows = self._get_unq_base_rows()
@@ -386,7 +396,7 @@ def rows_only_base(self):
         return self._rows_only_base
 
     @property
-    def rows_only_compare(self):
+    def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
         """pyspark.sql.DataFrame: Returns rows only in the compare dataframe"""
         if not self._rows_only_compare:
             compare_rows = self._get_compare_rows()
@@ -407,7 +417,7 @@ def rows_only_compare(self):
 
         return self._rows_only_compare
 
-    def _generate_select_statement(self, match_data=True):
+    def _generate_select_statement(self, match_data: bool = True) -> str:
         """This function is to generate the select statement to be used later in the query."""
         base_only = list(set(self.base_df.columns) - set(self.compare_df.columns))
         compare_only = list(set(self.compare_df.columns) - set(self.base_df.columns))
@@ -440,7 +450,7 @@ def _generate_select_statement(self, match_data=True):
 
         return select_statement
 
-    def _merge_dataframes(self):
+    def _merge_dataframes(self) -> None:
         """Merges the two dataframes and creates self._all_matched_rows and self._all_rows_mismatched."""
         full_joined_dataframe = self._get_or_create_joined_dataframe()
         full_joined_dataframe.createOrReplaceTempView("full_matched_table")
@@ -449,9 +459,8 @@ def _merge_dataframes(self):
         select_query = """SELECT {} FROM full_matched_table A""".format(
             select_statement
         )
-
         self._all_matched_rows = self.spark.sql(select_query).orderBy(
-            self._join_column_names
+            self._join_column_names  # type: ignore[arg-type]
         )
         self._all_matched_rows.createOrReplaceTempView("matched_table")
@@ -460,10 +469,10 @@
         )
         mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond)
         self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy(
-            self._join_column_names
+            self._join_column_names  # type: ignore[arg-type]
         )
 
-    def _get_or_create_joined_dataframe(self):
+    def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
         if self._joined_dataframe is None:
             join_condition = " AND ".join(
                 ["A." + name + "<=>B." + name for name in self._join_column_names]
             )
@@ -488,7 +497,7 @@ def _get_or_create_joined_dataframe(self):
 
         return self._joined_dataframe
 
-    def _print_num_of_rows_with_column_equality(self, myfile):
+    def _print_num_of_rows_with_column_equality(self, myfile: TextIO) -> None:
         # match_dataframe contains columns from both dataframes with flag to indicate if columns matched
         match_dataframe = self._get_or_create_joined_dataframe().select(
             *self.columns_compared
@@ -507,7 +516,10 @@ def _print_num_of_rows_with_column_equality(self, myfile):
             )
         )
         all_rows_matched = self.spark.sql(match_query)
-        matched_rows = all_rows_matched.head()[0]
+        all_rows_matched_head = all_rows_matched.head()
+        matched_rows = (
+            all_rows_matched_head[0] if all_rows_matched_head is not None else 0
+        )
 
         print("\n****** Row Comparison ******", file=myfile)
         print(
@@ -516,7 +528,7 @@ def _print_num_of_rows_with_column_equality(self, myfile):
         )
         print(f"Number of rows with all columns equal: {matched_rows}", file=myfile)
 
-    def _populate_columns_match_dict(self):
+    def _populate_columns_match_dict(self) -> None:
         """
        side effects:
            columns_match_dict assigned to { column -> match_type_counts }
@@ -531,7 +543,7 @@ def _populate_columns_match_dict(self):
             *self.columns_compared
         )
 
-        def helper(c):
+        def helper(c: str) -> "pyspark.sql.Column":
             # Create a predicate for each match type, comparing column values to the match type value
             predicates = [F.col(c) == k.value for k in MatchType]
             # Create a tuple(number of match types found for each match type in this column)
@@ -541,15 +553,15 @@ def helper(c):
         # For each column, create a single tuple. This tuple's values correspond to the number of times
         # each match type appears in that column
-        match_data = match_dataframe.agg(
+        match_data_agg = match_dataframe.agg(
             *[helper(col) for col in self.columns_compared]
         ).collect()
-        match_data = match_data[0]
+        match_data = match_data_agg[0]
 
         for c in self.columns_compared:
             self.columns_match_dict[c] = match_data[c]
 
-    def _create_select_statement(self, name):
+    def _create_select_statement(self, name: str) -> str:
         if self._known_differences:
             match_type_comparison = ""
             for k in MatchType:
@@ -568,7 +580,7 @@ def _create_select_statement(self, name):
                 name=name, match_failure=MatchType.MISMATCH.value
             )
 
-    def _create_case_statement(self, name):
+    def _create_case_statement(self, name: str) -> str:
         equal_comparisons = ["(A.{name} IS NULL AND B.{name} IS NULL)"]
         known_diff_comparisons = ["(FALSE)"]
 
@@ -622,7 +634,7 @@ def _create_case_statement(self, name):
             match_failure=MatchType.MISMATCH.value,
         )
 
-    def _print_row_summary(self, myfile):
+    def _print_row_summary(self, myfile: TextIO) -> None:
         base_df_cnt = self.base_df.count()
         compare_df_cnt = self.compare_df.count()
         base_df_with_dup_cnt = self._original_base_df.count()
@@ -647,7 +659,7 @@ def _print_row_summary(self, myfile):
             file=myfile,
         )
 
-    def _print_schema_diff_details(self, myfile):
+    def _print_schema_diff_details(self, myfile: TextIO) -> None:
         schema_diff_dict = self._columns_with_schemadiff()
 
         if not schema_diff_dict:  # If there are no differences, don't print the section
@@ -691,7 +703,7 @@ def _print_schema_diff_details(self, myfile):
             file=myfile,
         )
 
-    def _base_to_compare_name(self, base_name):
+    def _base_to_compare_name(self, base_name: str) -> str:
         """Translates a column name in the base dataframe to its counterpart in the
         compare dataframe, if they are different."""
 
@@ -703,7 +715,7 @@ def _base_to_compare_name(self, base_name):
                 return name[1]
         return base_name
 
-    def _print_row_matches_by_column(self, myfile):
+    def _print_row_matches_by_column(self, myfile: TextIO) -> None:
         self._populate_columns_match_dict()
         columns_with_mismatches = {
             key: self.columns_match_dict[key]
@@ -852,7 +864,7 @@ def _print_row_matches_by_column(self, myfile):
             print(format_pattern.format(*output_row), file=myfile)
 
     # noinspection PyUnresolvedReferences
-    def report(self, file=sys.stdout):
+    def report(self, file: TextIO = sys.stdout) -> None:
         """Creates a comparison report and prints it to the file specified
         (stdout by default).
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 00407433..33c9c03b 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -20,7 +20,7 @@
 # -- Project information -----------------------------------------------------
 
 project = "datacompy"
-copyright = "2023, Capital One"
+copyright = "2024, Capital One"
 author = "Ian Robertson, Dan Coates, Faisal Dosani"
 
 # The full version, including alpha/beta/rc tags
diff --git a/docs/source/index.rst b/docs/source/index.rst
index bcb00f66..0ac03d6d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -11,6 +11,7 @@ Contents
 
    Installation