Skip to content

Commit

Permalink
ENH Polars support in Joiner (#945)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheooJ authored Jun 21, 2024
1 parent a47c0a5 commit 7fe0f27
Show file tree
Hide file tree
Showing 15 changed files with 481 additions and 327 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ It is currently undergoing fast development and backward compatibility is not en

Major changes
-------------
* The :class:`Joiner` has been adapted to support polars dataframes. :pr:`945` by :user:`Théo Jolivet <TheooJ>`.

* The :class:`TableVectorizer` now consistently applies the same transformation
across different calls to `transform`. There also have been some breaking
changes to its functionality: (i) all transformations are now applied
Expand Down
10 changes: 4 additions & 6 deletions skrub/_agg_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,8 @@ def transform(self, X):
X, _ = self._check_dataframes(X, self.aux_table_)
_join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)")

skrub_px, _ = get_df_namespace(self.aux_table_)
X = skrub_px.join(
left=X,
X = _join_utils.left_join(
X,
right=self.aux_table_,
left_on=self._main_key,
right_on=self._aux_key,
Expand Down Expand Up @@ -439,10 +438,9 @@ def transform(self, X):
The augmented input.
"""
check_is_fitted(self, "y_")
skrub_px, _ = get_df_namespace(X)

return skrub_px.join(
left=X,
return _join_utils.left_join(
X,
right=self.y_,
left_on=self.main_key_,
right_on=self.main_key_,
Expand Down
7 changes: 7 additions & 0 deletions skrub/_dataframe/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
"sample",
"head",
"replace",
"with_columns",
]

#
Expand Down Expand Up @@ -1007,3 +1008,9 @@ def _replace_pandas(col, old, new):
@replace.specialize("polars", argument_type="Column")
def _replace_polars(col, old, new):
return col.replace(old, new)


def with_columns(df, **new_cols):
    """Return a copy of ``df`` with the given columns added or replaced.

    Existing columns are kept (in their original order); each keyword
    argument becomes a column named after the keyword, converted with
    ``make_column_like`` so it matches ``df``'s dataframe library. A keyword
    matching an existing column name replaces that column in place.
    """
    # Start from the current columns so their order is preserved.
    out_cols = {name: col(df, name) for name in column_names(df)}
    for name, values in new_cols.items():
        out_cols[name] = make_column_like(df, values, name)
    return make_dataframe_like(df, out_cols)
44 changes: 1 addition & 43 deletions skrub/_dataframe/_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,49 +104,7 @@ def aggregate(
]
sorted_cols = sorted(base_group.columns)

return base_group[sorted_cols]


def join(left, right, left_on, right_on):
    """Left join two :obj:`pandas.DataFrame`.

    This function uses the ``dataframe.merge`` method from Pandas.

    Parameters
    ----------
    left : pd.DataFrame
        The left dataframe to left-join.
    right : pd.DataFrame
        The right dataframe to left-join.
    left_on : str or Iterable[str]
        Left keys to merge on.
    right_on : str or Iterable[str]
        Right keys to merge on.

    Returns
    -------
    merged : pd.DataFrame,
        The merged output.
    """
    # Reject anything that is not a pandas DataFrame up front, so the
    # error names both offending types instead of failing inside merge.
    if not isinstance(left, pd.DataFrame) or not isinstance(right, pd.DataFrame):
        raise TypeError(
            "'left' and 'right' must be pandas dataframes, "
            f"got {type(left)!r} and {type(right)!r}."
        )
    merged = left.merge(right, how="left", left_on=left_on, right_on=right_on)
    return merged
return base_group[sorted_cols].reset_index(drop=False)


def get_named_agg(table, cols, operations):
Expand Down
46 changes: 0 additions & 46 deletions skrub/_dataframe/_polars.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""
Polars specialization of the aggregate and join operations.
"""
import inspect

try:
import polars as pl
import polars.selectors as cs
Expand Down Expand Up @@ -91,50 +89,6 @@ def aggregate(
return table.select(sorted_cols)


def join(left, right, left_on, right_on):
    """Left join two :obj:`polars.DataFrame` or :obj:`polars.LazyFrame`.

    This function uses the ``dataframe.join`` method from Polars.
    Note that the input dataframes type must agree: either both
    Polars dataframes or both Polars lazyframes.
    Mixing polars dataframe with lazyframe will raise an error.

    Parameters
    ----------
    left : pl.DataFrame or pl.LazyFrame
        The left dataframe of the left-join.
    right : pl.DataFrame or pl.LazyFrame
        The right dataframe of the left-join.
    left_on : str or Iterable[str]
        Left keys to merge on.
    right_on : str or Iterable[str]
        Right keys to merge on.

    Returns
    -------
    merged : pl.DataFrame or pl.LazyFrame
        The merged output.
    """
    both_eager = isinstance(left, pl.DataFrame) and isinstance(right, pl.DataFrame)
    both_lazy = isinstance(left, pl.LazyFrame) and isinstance(right, pl.LazyFrame)
    if not (both_eager or both_lazy):
        raise TypeError(
            "'left' and 'right' must be polars dataframes or lazyframes, "
            f"got {type(left)!r} and {type(right)!r}."
        )
    # Newer polars versions changed the default key-coalescing behavior;
    # pass coalesce=True explicitly when the installed version supports it.
    join_kwargs = {}
    if "coalesce" in inspect.signature(left.join).parameters:
        join_kwargs["coalesce"] = True
    return left.join(
        right, how="left", left_on=left_on, right_on=right_on, **join_kwargs
    )


def get_aggfuncs(cols, operations):
"""List Polars aggregation functions.
Expand Down
40 changes: 40 additions & 0 deletions skrub/_dataframe/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_not_implemented():
"reset_index",
"copy_index",
"index",
"with_columns",
}
for func_name in sorted(set(ns.__all__) - has_default_impl):
func = getattr(ns, func_name)
Expand Down Expand Up @@ -147,6 +148,10 @@ def test_make_column_like(df_module, example_data_dict):
)
assert ns.dataframe_module_name(col) == df_module.name

col = df_module.make_column("old_name", [1, 2, 3])
expected = df_module.make_column("new_name", [1, 2, 3])
df_module.assert_column_equal(ns.make_column_like(col, col, "new_name"), expected)


def test_null_value_for(df_module):
assert ns.null_value_for(df_module.example_dataframe) is None
Expand Down Expand Up @@ -645,3 +650,38 @@ def same(c1, c2):

same(ns.drop_nulls(s), col([1.1, 2.2, float("inf")]))
same(ns.fill_nulls(s, -1.0), col([1.1, -1.0, 2.2, -1.0, float("inf")]))


def test_with_columns(df_module):
    """``with_columns`` adds, extends, and replaces columns across backends."""

    def normalize(result):
        # For pandas with nullable dtypes, make_column_like returns an
        # old-style / numpy dtypes Series; convert before comparing.
        if df_module.description == "pandas-nullable-dtypes":
            return ns.pandas_convert_dtypes(result)
        return result

    df = df_module.make_dataframe({"a": [1, 2], "b": [3, 4]})

    # Add one new col
    out = normalize(ns.with_columns(df, **{"c": [5, 6]}))
    df_module.assert_frame_equal(
        out, df_module.make_dataframe({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    )

    # Add multiple new cols
    out = normalize(ns.with_columns(df, **{"c": [5, 6], "d": [7, 8]}))
    df_module.assert_frame_equal(
        out,
        df_module.make_dataframe(
            {"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]}
        ),
    )

    # Pass a col instead of an array
    out = normalize(ns.with_columns(df, **{"c": df_module.make_column("c", [5, 6])}))
    df_module.assert_frame_equal(
        out, df_module.make_dataframe({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    )

    # Replace col
    out = normalize(ns.with_columns(df, **{"a": [5, 6]}))
    df_module.assert_frame_equal(
        out, df_module.make_dataframe({"a": [5, 6], "b": [3, 4]})
    )
16 changes: 3 additions & 13 deletions skrub/_dataframe/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from skrub._dataframe._pandas import (
aggregate,
join,
rename_columns,
)

Expand All @@ -18,12 +17,6 @@
)


def test_join():
    """A left self-join through ``join`` matches pandas ``merge`` directly."""
    result = join(left=main, right=main, left_on="movieId", right_on="movieId")
    reference = main.merge(main, how="left", on="movieId")
    assert_frame_equal(result, reference)


def test_simple_agg():
aggregated = aggregate(
table=main,
Expand All @@ -36,7 +29,7 @@ def test_simple_agg():
"genre_mode": ("genre", pd.Series.mode),
"rating_mean": ("rating", "mean"),
}
expected = main.groupby("movieId").agg(**aggfunc)
expected = main.groupby("movieId").agg(**aggfunc).reset_index()
assert_frame_equal(aggregated, expected)


Expand All @@ -56,7 +49,7 @@ def test_value_counts_agg():
"rating_4.0_user": [3.0, 1.0],
"userId": [1, 2],
}
)
).reset_index(drop=False)
assert_frame_equal(aggregated, expected)

aggregated = aggregate(
Expand All @@ -73,14 +66,11 @@ def test_value_counts_agg():
"rating_(3.0, 4.0]_user": [3, 1],
"userId": [1, 2],
}
)
).reset_index(drop=False)
assert_frame_equal(aggregated, expected)


def test_incorrect_dataframe_inputs():
with pytest.raises(TypeError, match=r"(?=.*pandas dataframes)(?=.*array)"):
join(left=main.values, right=main, left_on="movieId", right_on="movieId")

with pytest.raises(TypeError, match=r"(?=.*pandas dataframe)(?=.*array)"):
aggregate(
table=main.values,
Expand Down
16 changes: 0 additions & 16 deletions skrub/_dataframe/tests/test_polars.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import inspect

import pandas as pd
import pytest

from skrub._dataframe._polars import (
aggregate,
join,
rename_columns,
)
from skrub.conftest import _POLARS_INSTALLED
Expand All @@ -27,16 +24,6 @@
pytest.skip(reason=POLARS_MISSING_MSG, allow_module_level=True)


def test_join():
    """A left self-join through ``join`` matches polars ``DataFrame.join``."""
    result = join(left=main, right=main, left_on="movieId", right_on="movieId")
    # Mirror the coalesce handling: pass coalesce=True only on polars
    # versions whose join() accepts it.
    extra_kwargs = (
        {"coalesce": True}
        if "coalesce" in inspect.signature(main.join).parameters
        else {}
    )
    reference = main.join(main, on="movieId", how="left", **extra_kwargs)
    assert_frame_equal(result, reference)


def test_simple_agg():
aggregated = aggregate(
table=main,
Expand Down Expand Up @@ -68,9 +55,6 @@ def test_mode_agg():


def test_incorrect_dataframe_inputs():
with pytest.raises(TypeError, match=r"(?=.*polars dataframes)(?=.*pandas)"):
join(left=pd.DataFrame(main), right=main, left_on="movieId", right_on="movieId")

with pytest.raises(TypeError, match=r"(?=.*polars dataframe)(?=.*pandas)"):
aggregate(
table=pd.DataFrame(main),
Expand Down
10 changes: 6 additions & 4 deletions skrub/_fuzzy_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
"""
import numpy as np

from skrub import _join_utils
from skrub._joiner import DEFAULT_REF_DIST, DEFAULT_STRING_ENCODER, Joiner
from . import _dataframe as sbd
from . import _join_utils
from . import _selectors as s
from ._joiner import DEFAULT_REF_DIST, DEFAULT_STRING_ENCODER, Joiner


def fuzzy_join(
Expand Down Expand Up @@ -210,7 +212,7 @@ def fuzzy_join(
add_match_info=True,
).fit_transform(left)
if drop_unmatched:
join = join[join["skrub_Joiner_match_accepted"]]
join = sbd.filter(join, sbd.col(join, "skrub_Joiner_match_accepted"))
if not add_match_info:
join = join.drop(Joiner.match_info_columns, axis=1)
join = s.select(join, ~s.cols(*Joiner.match_info_columns))
return join
Loading

0 comments on commit 7fe0f27

Please sign in to comment.