Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test polars support #826

Merged
merged 26 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
20081ae
Rename polars_missing_msg to uppercase (POLARS_MISSING_MSG)
TheooJ Nov 10, 2023
0f8aa2b
Test polars inputs in test_joiner
TheooJ Nov 10, 2023
cc30981
Test polars inputs in test_deduplicate
TheooJ Nov 10, 2023
7359a19
Test polars inputs in test_fuzzy_join
TheooJ Nov 10, 2023
a2f413c
Test polars inputs in test_minhash_encoder
TheooJ Nov 10, 2023
54367db
Rename list of tested modules to MODULES
TheooJ Nov 10, 2023
207b90c
Test polars inputs in test_gap_encoder. Add dict of possible NULL opt…
TheooJ Nov 10, 2023
84ed228
Test polars inputs in test_datetime_encoder. Lots of tests don't pass
TheooJ Nov 10, 2023
69eb2d1
Store comparison utils in list of tuples instead of dictionaries
TheooJ Nov 14, 2023
ba6d613
Merge branch 'main' into test_polars
TheooJ Nov 14, 2023
e623a72
Adapt test_interpolation_join for polars. All tests xfail because df.…
TheooJ Nov 14, 2023
8a7429c
Remove NULL dict in test_gap_encoder
TheooJ Nov 17, 2023
ae88ba3
Xfail set_output to polars in test_similarity_encoder
TheooJ Nov 17, 2023
c07047a
Create dfs with pandas, then convert them in px.df in test_interpolat…
TheooJ Nov 17, 2023
02075de
Format
TheooJ Nov 19, 2023
aab0874
Remove polars testing from test_deduplicate, as deduplication doesn't depend on polars
TheooJ Nov 19, 2023
06a87d0
Merge branch 'main' into test_polars
TheooJ Nov 19, 2023
5b80fd4
Format
TheooJ Nov 19, 2023
d9163d4
Create pd.DataFrames first in test_datetime_encoder
TheooJ Nov 19, 2023
e4046ee
Fix error when polars isn't available
TheooJ Nov 20, 2023
b579028
Create function to test if the polars module is available. Use it to …
TheooJ Nov 21, 2023
45b3af9
Move functions to test modules to _utils.py
TheooJ Nov 21, 2023
e0812de
Move functions to test modules into a new _test_utils.py in _dataframe
TheooJ Nov 21, 2023
569d7cc
Rename is_namespace into is_module
TheooJ Nov 21, 2023
74c07f4
Merge branch 'main' into test_polars
TheooJ Nov 21, 2023
16adcaf
Merge main
TheooJ Nov 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions skrub/_dataframe/tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
}
)
else:
polars_missing_msg = "Polars is not available"
pytest.skip(reason=polars_missing_msg, allow_module_level=True)
POLARS_MISSING_MSG = "Polars is not available"
pytest.skip(reason=POLARS_MISSING_MSG, allow_module_level=True)

Check warning on line 26 in skrub/_dataframe/tests/test_polars.py

View check run for this annotation

Codecov / codecov/patch

skrub/_dataframe/tests/test_polars.py#L25-L26

Added lines #L25 - L26 were not covered by tests


def test_join():
Expand Down
6 changes: 3 additions & 3 deletions skrub/tests/test_agg_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
)


assert_tuples = [(main, pd, assert_frame_equal)]
ASSERT_TUPLES = [(main, pd, assert_frame_equal)]
if POLARS_SETUP:
assert_tuples.append((pl.DataFrame(main), pl, assert_frame_equal_pl))
ASSERT_TUPLES.append((pl.DataFrame(main), pl, assert_frame_equal_pl))


@pytest.mark.parametrize("use_X_placeholder", [False, True])
@pytest.mark.parametrize(
"X, px, assert_frame_equal_",
assert_tuples,
ASSERT_TUPLES,
)
def test_simple_fit_transform(use_X_placeholder, X, px, assert_frame_equal_):
aux = X if not use_X_placeholder else "X"
Expand Down
69 changes: 56 additions & 13 deletions skrub/tests/test_datetime_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,24 @@
from numpy.testing import assert_allclose, assert_array_equal
from pandas.testing import assert_frame_equal

from skrub._dataframe._polars import POLARS_SETUP
from skrub._datetime_encoder import (
TIME_LEVELS,
DatetimeEncoder,
_is_pandas_format_mixed_available,
to_datetime,
)

MODULES = [pd]
ASSERT_TUPLES = [(pd, assert_frame_equal)]

if POLARS_SETUP:
import polars as pl
from polars.testing import assert_frame_equal as assert_frame_equal_pl

MODULES.append(pl)
ASSERT_TUPLES.append((pl, assert_frame_equal_pl))

NANOSECONDS_FORMAT = (
"%Y-%m-%d %H:%M:%S.%f" if _is_pandas_format_mixed_available() else None
)
Expand Down Expand Up @@ -119,6 +130,7 @@ def get_mixed_datetime_format(as_array=False):
return df


@pytest.mark.parametrize("px", MODULES)
@pytest.mark.parametrize("as_array", [True, False])
@pytest.mark.parametrize(
"get_data_func, features, format",
Expand All @@ -135,6 +147,7 @@ def get_mixed_datetime_format(as_array=False):
)
@pytest.mark.parametrize("resolution", TIME_LEVELS)
def test_fit(
px,
as_array,
get_data_func,
features,
Expand Down Expand Up @@ -175,8 +188,10 @@ def test_fit(
assert enc.get_feature_names_out() == expected_feature_names


def test_format_nan():
@pytest.mark.parametrize("px", MODULES)
def test_format_nan(px):
X = get_nan_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder().fit(X)
expected_index_to_format = {
0: "%Y-%m-%d %H:%M:%S",
Expand All @@ -186,14 +201,18 @@ def test_format_nan():
assert enc.index_to_format_ == expected_index_to_format


def test_format_nz():
@pytest.mark.parametrize("px", MODULES)
def test_format_nz(px):
X = get_tz_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder().fit(X)
assert enc.index_to_format_ == {0: "%Y-%m-%d %H:%M:%S%z"}


def test_resolution_none():
@pytest.mark.parametrize("px", MODULES)
def test_resolution_none(px):
X = get_datetime()
px.DataFrame(X)
enc = DatetimeEncoder(
resolution=None,
add_total_seconds=False,
Expand All @@ -205,8 +224,10 @@ def test_resolution_none():
assert enc.get_feature_names_out() == []


def test_transform_date():
@pytest.mark.parametrize("px", MODULES)
def test_transform_date(px):
X = get_date()
X = px.DataFrame(X)
enc = DatetimeEncoder(
add_total_seconds=False,
)
Expand All @@ -224,8 +245,10 @@ def test_transform_date():
assert_array_equal(X_trans, expected_result)


def test_transform_datetime():
@pytest.mark.parametrize("px", MODULES)
def test_transform_datetime(px):
X = get_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder(
resolution="second",
add_total_seconds=False,
Expand All @@ -242,8 +265,10 @@ def test_transform_datetime():
assert_array_equal(X_trans, expected_X_trans)


def test_transform_tz():
@pytest.mark.parametrize("px", MODULES)
def test_transform_tz(px):
X = get_tz_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder(
add_total_seconds=True,
)
Expand All @@ -259,8 +284,10 @@ def test_transform_tz():
assert_allclose(X_trans, expected_X_trans)


def test_transform_nan():
@pytest.mark.parametrize("px", MODULES)
def test_transform_nan(px):
X = get_nan_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder(
add_total_seconds=True,
)
Expand Down Expand Up @@ -323,8 +350,17 @@ def test_transform_nan():
assert_allclose(X_trans, expected_X_trans)


def test_mixed_type_dataframe():
@pytest.mark.parametrize("px", MODULES)
def test_mixed_type_dataframe(px):
if px is pl:
pytest.xfail(
reason=(
"to_datetime(X) raises polars.exceptions.ComputeError: cannot cast"
" 'Object' type"
)
)
X = get_mixed_type_dataframe()
X = px.DataFrame(X)
enc = DatetimeEncoder().fit(X)
assert enc.index_to_format_ == {0: "%Y-%m-%d", 4: "%d/%m/%Y"}

Expand All @@ -343,19 +379,23 @@ def test_mixed_type_dataframe():
assert X_dt.dtype == np.object_


def test_indempotency():
@pytest.mark.parametrize("px, assert_frame_equal_", ASSERT_TUPLES)
def test_indempotency(px, assert_frame_equal_):
df = get_mixed_datetime_format()
df = px.DataFrame(df)
df_dt = to_datetime(df)
df_dt_2 = to_datetime(df_dt)
assert_frame_equal(df_dt, df_dt_2)
assert_frame_equal_(df_dt, df_dt_2)

X_trans = DatetimeEncoder().fit_transform(df)
X_trans_2 = DatetimeEncoder().fit_transform(df_dt)
assert_array_equal(X_trans, X_trans_2)


def test_datetime_encoder_invalid_params():
@pytest.mark.parametrize("px", MODULES)
def test_datetime_encoder_invalid_params(px):
X = get_datetime()
X = px.DataFrame(X)

with pytest.raises(ValueError, match=r"(?=.*'resolution' options)"):
DatetimeEncoder(resolution="hello").fit(X)
Expand Down Expand Up @@ -419,8 +459,10 @@ def test_to_datetime_format_param():
assert_array_equal(out, expected_out)


def test_mixed_datetime_format():
@pytest.mark.parametrize("px, assert_frame_equal_", ASSERT_TUPLES)
def test_mixed_datetime_format(px, assert_frame_equal_):
df = get_mixed_datetime_format()
df = px.DataFrame(df)

df_dt = to_datetime(df)
expected_df_dt = pd.DataFrame(
Expand All @@ -433,7 +475,8 @@ def test_mixed_datetime_format():
]
)
)
assert_frame_equal(df_dt, expected_df_dt)
expected_df_dt = px.DataFrame(expected_df_dt)
assert_frame_equal_(df_dt, expected_df_dt)

series_dt = to_datetime(df["a"])
expected_series_dt = expected_df_dt["a"]
Expand Down
14 changes: 7 additions & 7 deletions skrub/tests/test_deduplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@


@pytest.mark.parametrize(
["entries_per_category", "prob_mistake_per_letter"],
[[[500, 100, 1500], 0.05], [[100, 100], 0.02], [[200, 50, 30, 200, 800], 0.01]],
"entries_per_category, prob_mistake_per_letter",
[([500, 100, 1500], 0.05), ([100, 100], 0.02), ([200, 50, 30, 200, 800], 0.01)],
)
def test_deduplicate(
entries_per_category: list[int],
prob_mistake_per_letter: float,
seed: int = 123,
) -> None:
):
rng = np.random.RandomState(seed)

# hard coded to fix ground truth string similarities
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_deduplicate(
assert np.isin(unique_other_analyzer, recovered_categories).all()


def test_compute_ngram_distance() -> None:
def test_compute_ngram_distance():
words = np.array(["aac", "aaa", "aaab", "aaa", "aaab", "aaa", "aaab", "aaa"])
distance = compute_ngram_distance(words)
distance = squareform(distance)
Expand All @@ -70,15 +70,15 @@ def test_compute_ngram_distance() -> None:
assert np.allclose(distance[words == un_word][:, words == un_word], 0)


def test__guess_clusters() -> None:
def test__guess_clusters():
words = np.array(["aac", "aaa", "aaab", "aaa", "aaab", "aaa", "aaab", "aaa"])
distance = compute_ngram_distance(words)
Z = linkage(distance, method="average")
n_clusters = _guess_clusters(Z, distance)
assert n_clusters == len(np.unique(words))


def test__create_spelling_correction(seed: int = 123) -> None:
def test__create_spelling_correction(seed: int = 123):
rng = np.random.RandomState(seed)
n_clusters = 3
samples_per_cluster = 10
Expand Down Expand Up @@ -116,7 +116,7 @@ def default_deduplicate(n: int = 500, random_state=0):
return X, y


def test_parallelism() -> None:
def test_parallelism():
"""Tests that parallelism works with different backends and n_jobs."""

X, y = default_deduplicate(n=200)
Expand Down
Loading
Loading