Add support for Spark Connect DataFrames #15

Open · wants to merge 4 commits into base: main
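This PR extends `assert_pyspark_df_equal` to accept Spark Connect DataFrames (`pyspark.sql.connect.dataframe.DataFrame`) alongside classic `pyspark.sql.DataFrame` instances, while still rejecting comparisons that mix the two flavours. A minimal usage sketch, assuming `pyspark[connect]>=3.4` is installed and a Spark Connect server is reachable at the illustrative endpoint `sc://localhost:15002`:

```python
from pyspark.sql import SparkSession

from src.pyspark_test import assert_pyspark_df_equal

# A Spark Connect session: DataFrames it creates are instances of
# pyspark.sql.connect.dataframe.DataFrame, not pyspark.sql.DataFrame.
# The endpoint below is illustrative; point it at a real server.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

left = spark.range(3)
right = spark.range(3)

# Both sides are Connect DataFrames with equal content, so this passes.
assert_pyspark_df_equal(left, right)
```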
4 changes: 2 additions & 2 deletions Dockerfile
@@ -5,7 +5,7 @@ RUN apk --update add openjdk8-jre gcc musl-dev bash
 ENV JAVA_HOME /usr/
 
 # Hadoop
-ENV HADOOP_VERSION 2.7.2
+ENV HADOOP_VERSION 3.3.3
 ENV HADOOP_HOME /usr/hadoop-$HADOOP_VERSION
 ENV HADOOP_CONF_DIR=$HADOOP_HOME/etc/hadoop
 ENV PATH $PATH:$HADOOP_HOME/bin
@@ -14,7 +14,7 @@ RUN wget "http://archive.apache.org/dist/hadoop/common/hadoop-$HADOOP_VERSION/ha
     && rm "hadoop-$HADOOP_VERSION.tar.gz"
 
 # Spark
-ENV SPARK_VERSION 2.4.8
+ENV SPARK_VERSION 3.3.3
 ENV SPARK_PACKAGE spark-$SPARK_VERSION
 ENV SPARK_HOME /usr/$SPARK_PACKAGE-bin-without-hadoop
 ENV PYSPARK_PYTHON python
4 changes: 2 additions & 2 deletions requirements.in
@@ -1,4 +1,4 @@
 black==20.8b1
-pyspark==2.4.7
+pyspark[connect]==3.4.0
 pytest-testdox==2.0.1
-pytest==6.1.1
+pytest==7.3.1
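The `[connect]` extra is what pulls in the gRPC and Arrow stack (`grpcio`, `grpcio-status`, `googleapis-common-protos`, `pandas`, `pyarrow`) visible in the regenerated lockfile below. A quick probe, mirroring the import guard this PR adds to src/pyspark_test.py, to confirm the extra landed:

```python
# Probe for the Spark Connect dependencies; mirrors the PR's try/except guard.
try:
    from pyspark.sql.connect.dataframe import DataFrame  # noqa: F401

    print("pyspark[connect] dependencies are available")
except ImportError:
    print("pyspark was installed without the [connect] extra")
```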
63 changes: 47 additions & 16 deletions requirements.txt
@@ -1,50 +1,81 @@
 #
-# This file is autogenerated by pip-compile
-# To update, run:
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
 #
 #    pip-compile requirements.in
 #
 appdirs==1.4.4
     # via black
-attrs==20.2.0
-    # via pytest
 black==20.8b1
     # via -r requirements.in
 click==7.1.2
     # via black
+exceptiongroup==1.2.0
+    # via pytest
+googleapis-common-protos==1.62.0
+    # via
+    #   grpcio-status
+    #   pyspark
+grpcio==1.60.0
+    # via
+    #   grpcio-status
+    #   pyspark
+grpcio-status==1.60.0
+    # via pyspark
 iniconfig==1.0.1
     # via pytest
 mypy-extensions==0.4.3
     # via black
+numpy==1.26.2
+    # via
+    #   pandas
+    #   pyarrow
+    #   pyspark
 packaging==20.4
     # via pytest
+pandas==2.1.4
+    # via pyspark
 pathspec==0.8.0
     # via black
 pluggy==0.13.1
     # via pytest
-py4j==0.10.7
+protobuf==4.25.1
+    # via
+    #   googleapis-common-protos
+    #   grpcio-status
+py4j==0.10.9.7
     # via pyspark
-py==1.10.0
-    # via pytest
+pyarrow==14.0.1
+    # via pyspark
 pyparsing==2.4.7
     # via packaging
-pyspark==2.4.7
-    # via -r requirements.in
-pytest-testdox==2.0.1
-    # via -r requirements.in
-pytest==6.1.1
+pyspark[connect]==3.4.0
+    # via
+    #   -r requirements.in
+    #   pyspark
+pytest==7.3.1
     # via
     #   -r requirements.in
     #   pytest-testdox
+pytest-testdox==2.0.1
+    # via -r requirements.in
+python-dateutil==2.8.2
+    # via pandas
+pytz==2023.3.post1
+    # via pandas
 regex==2020.10.28
     # via black
 six==1.15.0
-    # via packaging
-toml==0.10.1
     # via
-    #   black
-    #   pytest
+    #   packaging
+    #   python-dateutil
+toml==0.10.1
+    # via black
+tomli==2.0.1
+    # via pytest
 typed-ast==1.4.1
     # via black
 typing-extensions==3.7.4.3
     # via black
+tzdata==2023.3
+    # via pandas
41 changes: 32 additions & 9 deletions src/pyspark_test.py
@@ -2,14 +2,36 @@
 
 import pyspark
 
+try:
+    from pyspark.sql.connect.dataframe import DataFrame as CDF
 
-def _check_isinstance(left: Any, right: Any, cls):
-    assert isinstance(
-        left, cls
-    ), f"Left expected type {cls}, found {type(left)} instead"
-    assert isinstance(
-        right, cls
-    ), f"Right expected type {cls}, found {type(right)} instead"
+    has_connect_deps = True
+except ImportError:
+    has_connect_deps = False
+
+
+def _check_isinstance_df(left: Any, right: Any):
+    types_to_test = [pyspark.sql.DataFrame]
+    msg_string = ""
+    # If the Spark Connect dependencies are not available, the input cannot be a
+    # Spark Connect DataFrame, so we can safely skip that check.
+    if has_connect_deps:
+        types_to_test.append(CDF)
+        msg_string = f" or {CDF}"
+
+    left_good = any(map(lambda x: isinstance(left, x), types_to_test))
+    right_good = any(map(lambda x: isinstance(right, x), types_to_test))
+    assert (
+        left_good
+    ), f"Left expected type {pyspark.sql.DataFrame}{msg_string}, found {type(left)} instead"
+    assert (
+        right_good
+    ), f"Right expected type {pyspark.sql.DataFrame}{msg_string}, found {type(right)} instead"
+
+    # Check that both sides are of the same DataFrame type.
+    assert type(left) == type(
+        right
+    ), f"Left and right DataFrames are not of the same type: {type(left)} != {type(right)}"
 
 
 def _check_columns(
@@ -39,7 +61,8 @@ def _check_schema(
 
 
 def _check_df_content(
-    left_df: pyspark.sql.DataFrame, right_df: pyspark.sql.DataFrame,
+    left_df: pyspark.sql.DataFrame,
+    right_df: pyspark.sql.DataFrame,
 ):
     left_df_list = left_df.collect()
     right_df_list = right_df.collect()
@@ -88,7 +111,7 @@ def assert_pyspark_df_equal(
     """
 
     # Check if
-    _check_isinstance(left_df, right_df, pyspark.sql.DataFrame)
+    _check_isinstance_df(left_df, right_df)
 
     # Check Column Names
     if check_column_names:
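With these changes, `_check_isinstance_df` accepts either DataFrame flavour but still requires both sides to be the same flavour. A minimal behaviour sketch, assuming a classic local session (no Connect server needed):

```python
import pyspark

from src.pyspark_test import _check_isinstance_df

spark = pyspark.sql.SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(2)

# Two classic DataFrames of the same type: passes silently.
_check_isinstance_df(df, df)

# A non-DataFrame on either side trips the isinstance assertion.
try:
    _check_isinstance_df(df, "not a dataframe")
except AssertionError as exc:
    print(exc)  # Right expected type <class 'pyspark.sql.dataframe.DataFrame'>..., found <class 'str'> instead
```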
18 changes: 17 additions & 1 deletion tests/unit_test/test_assert_pyspark_df_equal.py
@@ -12,6 +12,7 @@
 )
 
 from src.pyspark_test import assert_pyspark_df_equal
+from src.pyspark_test import _check_isinstance_df
 
 
 class TestAssertPysparkDfEqual:
@@ -68,7 +69,7 @@ def test_assert_pyspark_df_equal_one_is_not_pyspark_df(
         right_df = "Demo"
         with pytest.raises(
             AssertionError,
-            match="Right expected type <class 'pyspark.sql.dataframe.DataFrame'>, found <class 'str'> instead",
+            match="Right expected type <class 'pyspark.sql.dataframe.DataFrame'> or .*?, found <class 'str'> instead",
         ):
             assert_pyspark_df_equal(left_df, right_df)
@@ -324,3 +325,18 @@ def test_assert_pyspark_df_equal_different_row_count(
             match="Number of rows are not same.\n \n Actual Rows: 2\n Expected Rows: 3",
         ):
             assert_pyspark_df_equal(left_df, right_df)
+
+    def test_instance_checks_for_spark_connect(
+        self, spark_session: pyspark.sql.SparkSession
+    ):
+        from pyspark.sql.connect.dataframe import DataFrame as CDF
+        left_df = spark_session.range(1)
+        right_df = spark_session.range(1)
+        _check_isinstance_df(left_df, right_df)
+
+        left_df = CDF.withPlan(None, None)
+        right_df = CDF.withPlan(None, None)
+        _check_isinstance_df(left_df, right_df)
+
+        with pytest.raises(AssertionError):
+            _check_isinstance_df(spark_session.range(1), right_df)
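The new test leans on the `spark_session` fixture, which sits outside this diff. A minimal sketch of such a fixture, assuming the repository defines it in a conftest.py (the real fixture may differ):

```python
# Hypothetical conftest.py fixture; the repository's actual fixture may differ.
import pyspark
import pytest


@pytest.fixture(scope="session")
def spark_session():
    session = (
        pyspark.sql.SparkSession.builder.master("local[1]")
        .appName("pyspark-test")
        .getOrCreate()
    )
    yield session
    session.stop()
```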