From 849a70156becc5657920fd25f7de37eea095f455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 12 Jun 2024 20:35:30 +0200 Subject: [PATCH 01/53] Plan out future changes, dispatch with_columns --- skrub/_dataframe/_common.py | 14 ++++++++ skrub/_joiner.py | 64 +++++++++++++++++++++++++------------ 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py index e79bd562e..61bec4ef3 100644 --- a/skrub/_dataframe/_common.py +++ b/skrub/_dataframe/_common.py @@ -82,6 +82,7 @@ "sample", "head", "replace", + "with_columns", ] # @@ -955,3 +956,16 @@ def _replace_pandas(column, old, new): @replace.specialize("polars") def _replace_polars(column, old, new): return column.replace(old, new) + + +@dispatch +def with_columns(df, **new_columns): + raise NotImplementedError() + + +@with_columns.specialize("pandas", argument_type="DataFrame") +@with_columns.specialize("polars", argument_type="DataFrame") +def _with_columns_dataframe(df, **new_columns): + columns = {col_name: col(df, col_name) for col_name in column_names(df)} + columns.update({n: make_column_like(df, c, n) for n, c in new_columns.items()}) + return make_dataframe_like(df, columns) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 54d7ae956..0f67c0060 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -12,6 +12,9 @@ from sklearn.utils.validation import check_is_fitted from skrub import _join_utils, _matching, _utils +from skrub._dataframe import _common as ns + +# TODO: rm from skrub._dataframe._namespace import is_pandas, is_polars from skrub._datetime_encoder import DatetimeEncoder @@ -26,6 +29,7 @@ def _as_str(column): DEFAULT_STRING_ENCODER = make_pipeline( FunctionTransformer(_as_str), HashingVectorizer(analyzer="char_wb", ngram_range=(2, 4)), + # TODO: Remove sparse output from Tfidf to work with TableVectorizer TfidfTransformer(), ) _DATETIME_ENCODER = DatetimeEncoder(resolution=None, add_total_seconds=True) @@ -51,7 +55,9 @@ def _make_vectorizer(table, string_encoder, rescale): # TODO remove use of ColumnTransformer, select_dtypes & pandas-specific code transformers = [ (clone(string_encoder), c) - for c in table.select_dtypes(include=["string", "category", "object"]).columns + for c in table.select_dtypes( + include=["string", "category", "object"] + ).columns # TODO: Use selector (s.expand) ] num_columns = table.select_dtypes(include="number").columns if not num_columns.empty: @@ -120,7 +126,7 @@ class Joiner(TransformerMixin, BaseEstimator): Parameters ---------- - aux_table : :obj:`~pandas.DataFrame` + aux_table : :obj:`~pandas.DataFrame` or :obj:`~polars.DataFrame` The auxiliary table, which will be fuzzy-joined to the main table when calling `transform`. key : str or iterable of str, default=None @@ -145,7 +151,7 @@ class Joiner(TransformerMixin, BaseEstimator): although rescaled distances can be greater than 1 for some choices of `ref_dist`. `None`, `"inf"`, `float("inf")` or `numpy.inf` mean that no matches are rejected. - ref_dist : reference distance for rescaling, default = 'random_pairs' + ref_dist : reference distance for rescaling, default='random_pairs' Options are {"random_pairs", "second_neighbor", "self_join_neighbor", "no_rescaling"}. See above for a description of each option. To facilitate the choice of `max_dist`, distances between rows in @@ -178,7 +184,7 @@ class Joiner(TransformerMixin, BaseEstimator): See Also -------- AggJoiner : - Aggregate auxiliary dataframes before joining them on a base dataframe. + Aggregate an auxiliary dataframe before joining it on a base dataframe. fuzzy_join : Join two tables (dataframes) based on approximate column matching. This @@ -247,6 +253,7 @@ def __init__( ) self.add_match_info = add_match_info + # TODO: rm def _check_dataframe(self, dataframe): # TODO: add support for polars, ATM we just convert to pandas if is_polars(dataframe): @@ -271,7 +278,7 @@ def _check_max_dist(self): def _check_ref_dist(self): if self.ref_dist not in _MATCHERS: raise ValueError( - f"ref_dist should be one of {list(_MATCHERS.keys())}, got" + f"`ref_dist` should be one of {list(_MATCHERS.keys())}, got" f" {self.ref_dist!r}" ) self._matching = _MATCHERS[self.ref_dist]() @@ -281,7 +288,7 @@ def fit(self, X, y=None): Parameters ---------- - X : :obj:`~pandas.DataFrame`, shape [n_samples, n_features] + X : :obj:`~pandas.DataFrame` or :obj:`~polars.DataFrame` The main table, to be joined to the auxiliary ones. y : None Unused, only here for compatibility. @@ -305,11 +312,11 @@ def fit(self, X, y=None): X, self._aux_table, self.suffix, main_table_name="X" ) self.vectorizer_ = _make_vectorizer( - self._aux_table[self._aux_key], + ns.col(self._aux_table, self._aux_key), self.string_encoder, rescale=self.ref_dist != "no_rescaling", ) - aux = self.vectorizer_.fit_transform(self._aux_table[self._aux_key]) + aux = self.vectorizer_.fit_transform(ns.col(self._aux_table, self._aux_key)) self._matching.fit(aux) return self @@ -318,50 +325,65 @@ def transform(self, X, y=None): Parameters ---------- - X : :obj:`~pandas.DataFrame`, shape [n_samples, n_features] + X : :obj:`~pandas.DataFrame` or :obj:`~polars.DataFrame` The main table, to be joined to the auxiliary ones. y : None Unused, only here for compatibility. Returns ------- - :obj:`~pandas.DataFrame` + :obj:`~pandas.DataFrame` or :obj:`~polars.DataFrame` The final joined table. """ del y check_is_fitted(self, "vectorizer_") + # rm input_is_polars = is_polars(X) X = self._check_dataframe(X) _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) + # TODO: dispatch select, set_axis in df._common main = self.vectorizer_.transform( - X[self._main_key].set_axis(self._aux_key, axis="columns") + ns.col(X, self._main_key).set_axis(self._aux_key, axis="columns") + ) + _match_result = self._matching.match(main, self.max_dist_) + match_result = ns.make_dataframe_like(X, _match_result) + aux_table = ns.reset_index( + _join_utils.add_column_name_suffix(self._aux_table, self.suffix) ) - match_result = self._matching.match(main, self.max_dist_) - aux_table = _join_utils.add_column_name_suffix( - self._aux_table, self.suffix - ).reset_index(drop=True) - matching_col = match_result["index"].copy() - matching_col[~match_result["match_accepted"]] = -1 + # dispatch copy ? + matching_col = ns.col(match_result, "index").copy() + matching_col[~ns.col(match_result, "match_accepted")] = -1 token = _utils.random_string() left_key_name = f"skrub_left_key_{token}" right_key_name = f"skrub_right_key_{token}" - left = X.assign(**{left_key_name: matching_col}) - right = aux_table.assign(**{right_key_name: np.arange(aux_table.shape[0])}) + left = ns.with_columns(X, **{left_key_name: matching_col}) + right = ns.with_columns( + aux_table, **{right_key_name: np.arange(ns.shape(aux_table)[0])} + ) + # TODO: dispatch in ``_join_utils`` + # TODO: check pd vs pl behavior and how can we use the duplicates -> in PR WIP join = pd.merge( left, right, left_on=left_key_name, right_on=right_key_name, - suffixes=("", ""), + suffixes=( + "", + "", + ), how="left", ) + # TODO: dispatch ``drop`` join = join.drop([left_key_name, right_key_name], axis=1) if self.add_match_info: for info_key, info_col_name in self._match_info_key_renaming.items(): - join[info_col_name] = match_result[info_key] + join = ns.with_columns( + join, **{info_col_name: ns.col(match_result, info_key)} + ) + # TODO: remove this part if input_is_polars: import polars as pl From 66042e75d0e35df2a1e05987494806abe385384e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 11:26:33 +0200 Subject: [PATCH 02/53] Dispatch left_join in _join_utils --- skrub/_dataframe/_pandas.py | 42 ---------- skrub/_dataframe/_polars.py | 46 ----------- skrub/_dataframe/tests/test_pandas.py | 12 +-- skrub/_dataframe/tests/test_polars.py | 12 +-- skrub/_join_utils.py | 111 ++++++++++++++++++++++++++ skrub/_joiner.py | 4 +- skrub/tests/test_join_utils.py | 43 ++++++++++ 7 files changed, 161 insertions(+), 109 deletions(-) diff --git a/skrub/_dataframe/_pandas.py b/skrub/_dataframe/_pandas.py index 5e9f369bf..25f244491 100644 --- a/skrub/_dataframe/_pandas.py +++ b/skrub/_dataframe/_pandas.py @@ -107,48 +107,6 @@ def aggregate( return base_group[sorted_cols] -def join( - left, - right, - left_on, - right_on, -): - """Left join two :obj:`pandas.DataFrame`. - - This function uses the ``dataframe.merge`` method from Pandas. - - Parameters - ---------- - left : pd.DataFrame - The left dataframe to left-join. - - right : pd.DataFrame - The right dataframe to left-join. - - left_on : str or Iterable[str] - Left keys to merge on. - - right_on : str or Iterable[str] - Right keys to merge on. - - Returns - ------- - merged : pd.DataFrame, - The merged output. - """ - if not (isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame)): - raise TypeError( - "'left' and 'right' must be pandas dataframes, " - f"got {type(left)!r} and {type(right)!r}." - ) - return left.merge( - right, - how="left", - left_on=left_on, - right_on=right_on, - ) - - def get_named_agg(table, cols, operations): """Map aggregation tuples to their output key. diff --git a/skrub/_dataframe/_polars.py b/skrub/_dataframe/_polars.py index 8f1db8656..04af24156 100644 --- a/skrub/_dataframe/_polars.py +++ b/skrub/_dataframe/_polars.py @@ -1,8 +1,6 @@ """ Polars specialization of the aggregate and join operations. """ -import inspect - try: import polars as pl import polars.selectors as cs @@ -91,50 +89,6 @@ def aggregate( return table.select(sorted_cols) -def join(left, right, left_on, right_on): - """Left join two :obj:`polars.DataFrame` or :obj:`polars.LazyFrame`. - - This function uses the ``dataframe.join`` method from Polars. - - Note that the input dataframes type must agree: either both - Polars dataframes or both Polars lazyframes. - - Mixing polars dataframe with lazyframe will raise an error. - - Parameters - ---------- - left : pl.DataFrame or pl.LazyFrame - The left dataframe of the left-join. - - right : pl.DataFrame or pl.LazyFrame - The right dataframe of the left-join. - - left_on : str or Iterable[str] - Left keys to merge on. - - right_on : str or Iterable[str] - Right keys to merge on. - - Returns - ------- - merged : pl.DataFrame or pl.LazyFrame - The merged output. - """ - is_dataframe = isinstance(left, pl.DataFrame) and isinstance(right, pl.DataFrame) - is_lazyframe = isinstance(left, pl.LazyFrame) and isinstance(right, pl.LazyFrame) - if is_dataframe or is_lazyframe: - if "coalesce" in inspect.signature(left.join).parameters: - kw = {"coalesce": True} - else: - kw = {} - return left.join(right, how="left", left_on=left_on, right_on=right_on, **kw) - else: - raise TypeError( - "'left' and 'right' must be polars dataframes or lazyframes, " - f"got {type(left)!r} and {type(right)!r}." - ) - - def get_aggfuncs(cols, operations): """List Polars aggregation functions. diff --git a/skrub/_dataframe/tests/test_pandas.py b/skrub/_dataframe/tests/test_pandas.py index 7f4d6141b..62e328792 100644 --- a/skrub/_dataframe/tests/test_pandas.py +++ b/skrub/_dataframe/tests/test_pandas.py @@ -4,7 +4,6 @@ from skrub._dataframe._pandas import ( aggregate, - join, rename_columns, ) @@ -18,12 +17,6 @@ ) -def test_join(): - joined = join(left=main, right=main, left_on="movieId", right_on="movieId") - expected = main.merge(main, on="movieId", how="left") - assert_frame_equal(joined, expected) - - def test_simple_agg(): aggregated = aggregate( table=main, @@ -78,8 +71,9 @@ def test_value_counts_agg(): def test_incorrect_dataframe_inputs(): - with pytest.raises(TypeError, match=r"(?=.*pandas dataframes)(?=.*array)"): - join(left=main.values, right=main, left_on="movieId", right_on="movieId") + # TODO: deal with this + # with pytest.raises(TypeError, match=r"(?=.*pandas dataframes)(?=.*array)"): + # join(left=main.values, right=main, left_on="movieId", right_on="movieId") with pytest.raises(TypeError, match=r"(?=.*pandas dataframe)(?=.*array)"): aggregate( diff --git a/skrub/_dataframe/tests/test_polars.py b/skrub/_dataframe/tests/test_polars.py index fdc682209..310a5ef6b 100644 --- a/skrub/_dataframe/tests/test_polars.py +++ b/skrub/_dataframe/tests/test_polars.py @@ -3,7 +3,6 @@ from skrub._dataframe._polars import ( aggregate, - join, rename_columns, ) from skrub.conftest import _POLARS_INSTALLED @@ -25,12 +24,6 @@ pytest.skip(reason=POLARS_MISSING_MSG, allow_module_level=True) -def test_join(): - joined = join(left=main, right=main, left_on="movieId", right_on="movieId") - expected = main.join(main, on="movieId", how="left", coalesce=True) - assert_frame_equal(joined, expected) - - def test_simple_agg(): aggregated = aggregate( table=main, @@ -62,8 +55,9 @@ def test_mode_agg(): def test_incorrect_dataframe_inputs(): - with pytest.raises(TypeError, match=r"(?=.*polars dataframes)(?=.*pandas)"): - join(left=pd.DataFrame(main), right=main, left_on="movieId", right_on="movieId") + # TODO: deal with this + # with pytest.raises(TypeError, match=r"(?=.*polars dataframes)(?=.*pandas)"): + # join(left=pd.DataFrame(main), right=main, left_on="movieId", right_on="movieId") # noqa: E501 with pytest.raises(TypeError, match=r"(?=.*polars dataframe)(?=.*pandas)"): aggregate( diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index bd1e62607..7a88f6555 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -1,10 +1,20 @@ """Utilities specific to the JOIN operations.""" +import inspect import re +import pandas as pd + +try: + import polars as pl +except ImportError: + pass + from skrub import _utils from skrub._dataframe._namespace import get_df_namespace +from ._dispatch import dispatch + def check_key( main_key, @@ -218,3 +228,104 @@ def _get_new_name(suggested_name, forbidden_names): return suggested_name token = _utils.random_string() return f"{untagged_name}__skrub_{token}__" + + +@dispatch +def left_join( + left, + right, + left_on, + right_on, + suffixes, + how, +): + raise NotImplementedError() + + +@left_join.specialize("pandas", argument_type="DataFrame") +def _left_join_pandas( + left, + right, + left_on, + right_on, + suffixes, +): + """Left join two :obj:`pandas.DataFrame`. + + This function uses the ``dataframe.merge`` method from Pandas. + + Parameters + ---------- + left : pd.DataFrame + The left dataframe to left-join. + + right : pd.DataFrame + The right dataframe to left-join. + + left_on : str or Iterable[str] + Left keys to merge on. + + right_on : str or Iterable[str] + Right keys to merge on. + + Returns + ------- + merged : pd.DataFrame, + The merged output. + """ + if not (isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame)): + raise TypeError( + "'left' and 'right' must be pandas dataframes, " + f"got {type(left)!r} and {type(right)!r}." + ) + return left.merge( + right, + how="left", + left_on=left_on, + right_on=right_on, + ).drop(columns=right_on) + + +@left_join.specialize("polars", argument_type="DataFrame") +def _left_join_polars(left, right, left_on, right_on, suffixes): + """Left join two :obj:`polars.DataFrame` or :obj:`polars.LazyFrame`. + + This function uses the ``dataframe.join`` method from Polars. + + Note that the input dataframes type must agree: either both + Polars dataframes or both Polars lazyframes. + + Mixing polars dataframe with lazyframe will raise an error. + + Parameters + ---------- + left : pl.DataFrame or pl.LazyFrame + The left dataframe of the left-join. + + right : pl.DataFrame or pl.LazyFrame + The right dataframe of the left-join. + + left_on : str or Iterable[str] + Left keys to merge on. + + right_on : str or Iterable[str] + Right keys to merge on. + + Returns + ------- + merged : pl.DataFrame or pl.LazyFrame + The merged output. + """ + is_dataframe = isinstance(left, pl.DataFrame) and isinstance(right, pl.DataFrame) + is_lazyframe = isinstance(left, pl.LazyFrame) and isinstance(right, pl.LazyFrame) + if is_dataframe or is_lazyframe: + if "coalesce" in inspect.signature(left.join).parameters: + kw = {"coalesce": True} + else: + kw = {} + return left.join(right, how="left", left_on=left_on, right_on=right_on, **kw) + else: + raise TypeError( + "'left' and 'right' must be polars dataframes or lazyframes, " + f"got {type(left)!r} and {type(right)!r}." + ) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 0f67c0060..1cd9e86c4 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -3,7 +3,6 @@ """ import numpy as np -import pandas as pd from sklearn.base import BaseEstimator, TransformerMixin, clone from sklearn.compose import make_column_transformer from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer @@ -365,7 +364,7 @@ def transform(self, X, y=None): ) # TODO: dispatch in ``_join_utils`` # TODO: check pd vs pl behavior and how can we use the duplicates -> in PR WIP - join = pd.merge( + join = _join_utils.left_join( left, right, left_on=left_key_name, @@ -374,7 +373,6 @@ def transform(self, X, y=None): "", "", ), - how="left", ) # TODO: dispatch ``drop`` join = join.drop([left_key_name, right_key_name], axis=1) diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index 678754b1a..b97ac4ccb 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -2,6 +2,7 @@ import pytest from skrub import _join_utils +from skrub._dataframe._testing_utils import assert_frame_equal @pytest.mark.parametrize( @@ -72,3 +73,45 @@ def test_add_column_name_suffix(): assert list(df.columns) == ["one", "two three", "x"] df = _join_utils.add_column_name_suffix(df, "_y") assert list(df.columns) == ["one_y", "two three_y", "x_y"] + + +def test_left_join(df_module): + # Test all left keys in right dataframe + left = df_module.make_dataframe({"left_key": [1, 2, 2], "left_col": [10, 20, 30]}) + right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) + + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="right_key", suffixes=None + ) + + expected = df_module.make_dataframe( + { + "left_key": [1, 2, 2], + "left_col": [10, 20, 30], + "right_col": ["a", "b", "b"], + } + ) + assert_frame_equal(joined, expected) + + # Test all left keys not it right dataframe + # TODO: modify default behavior of `left_join` + # Will break for polars or pandas by default + # Polars keep right_key by default when it doesn't find a match + # left = df_module.make_dataframe( + # {"left_key": [1, 2, 2, 3], "left_col": [10, 20, 30, 40]} + # ) + # right = df_module.make_dataframe( + # {"right_key": [2, 9, 3, 5, 6], "right_col": ["a", "b", "c", "d", "e"]} + # ) + # joined = _join_utils.left_join( + # left, right=right, left_on="left_key", right_on="right_key", suffixes=None + # ) + # expected = df_module.make_dataframe( + # { + # "left_key": [1, 2, 2, 3], + # "left_col": [10, 20, 30, 40], + # "right_key": [None, 2, 2, 3], + # "right_col": [None, "a", "a", "c"], + # } + # ) + assert_frame_equal(joined, expected) From 6ed32034d457edf1c300751f4f6c81ffe00e2975 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 11:27:41 +0200 Subject: [PATCH 03/53] Iter tests with_columns --- skrub/_dataframe/tests/test_common.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index 060739f33..15c17a5f2 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -1,3 +1,4 @@ +skrub/_dataframe/tests/test_common.py """ Note: most tests in this file use the ``df_module`` fixture, which is defined in ``skrub.conftest``. See the corresponding docstrings for details. @@ -589,3 +590,11 @@ def same(c1, c2): same(ns.drop_nulls(s), col([1.1, 2.2, float("inf")])) same(ns.fill_nulls(s, -1.0), col([1.1, -1.0, 2.2, -1.0, float("inf")])) + + +def test_with_columns(df_module): + # TODO: test one new col + # TODO: test multiple new cols + # TODO: test array (+ 1 test in ns.make_column_like(s) = s) + # TODO: test replace col + pass From 72cad9beb61fa700913f4b70decdd75973ff4ec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 11:51:39 +0200 Subject: [PATCH 04/53] Iter left_join --- skrub/_join_utils.py | 5 +---- skrub/_joiner.py | 8 ++------ skrub/tests/test_join_utils.py | 4 ++-- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index 7a88f6555..486b992d4 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -236,8 +236,6 @@ def left_join( right, left_on, right_on, - suffixes, - how, ): raise NotImplementedError() @@ -248,7 +246,6 @@ def _left_join_pandas( right, left_on, right_on, - suffixes, ): """Left join two :obj:`pandas.DataFrame`. @@ -287,7 +284,7 @@ def _left_join_pandas( @left_join.specialize("polars", argument_type="DataFrame") -def _left_join_polars(left, right, left_on, right_on, suffixes): +def _left_join_polars(left, right, left_on, right_on): """Left join two :obj:`polars.DataFrame` or :obj:`polars.LazyFrame`. This function uses the ``dataframe.join`` method from Polars. diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 1cd9e86c4..935a2217f 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -369,13 +369,9 @@ def transform(self, X, y=None): right, left_on=left_key_name, right_on=right_key_name, - suffixes=( - "", - "", - ), - ) + ) # TODO: dispatch ``drop`` - join = join.drop([left_key_name, right_key_name], axis=1) + join = join.drop([left_key_name], axis=1) if self.add_match_info: for info_key, info_col_name in self._match_info_key_renaming.items(): join = ns.with_columns( diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index b97ac4ccb..a8f5cac44 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -81,7 +81,7 @@ def test_left_join(df_module): right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) joined = _join_utils.left_join( - left, right=right, left_on="left_key", right_on="right_key", suffixes=None + left, right=right, left_on="left_key", right_on="right_key" ) expected = df_module.make_dataframe( @@ -104,7 +104,7 @@ def test_left_join(df_module): # {"right_key": [2, 9, 3, 5, 6], "right_col": ["a", "b", "c", "d", "e"]} # ) # joined = _join_utils.left_join( - # left, right=right, left_on="left_key", right_on="right_key", suffixes=None + # left, right=right, left_on="left_key", right_on="right_key" # ) # expected = df_module.make_dataframe( # { From aab6390f733e77676aaf1c9c2621c08fced4c237 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 11:52:16 +0200 Subject: [PATCH 05/53] Use left_join in AggJoiner & AggTarget --- skrub/_agg_joiner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/skrub/_agg_joiner.py b/skrub/_agg_joiner.py index 159374a46..3b6899ba2 100644 --- a/skrub/_agg_joiner.py +++ b/skrub/_agg_joiner.py @@ -6,6 +6,7 @@ table with the main table. """ from typing import Iterable +from skrub import _join_utils import numpy as np from sklearn.base import BaseEstimator, TransformerMixin @@ -290,8 +291,8 @@ def transform(self, X): X, _ = self._check_dataframes(X, self.aux_table_) _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") - skrub_px, _ = get_df_namespace(self.aux_table_) - X = skrub_px.join( + # skrub_px, _ = get_df_namespace(self.aux_table_) + X = _join_utils.left_join( left=X, right=self.aux_table_, left_on=self._main_key, @@ -442,7 +443,7 @@ def transform(self, X): check_is_fitted(self, "y_") skrub_px, _ = get_df_namespace(X) - return skrub_px.join( + return _join_utils.left_join( left=X, right=self.y_, left_on=self.main_key_, From d75e57230fb39f58ab25b31029e57417fa990e31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 12:29:09 +0200 Subject: [PATCH 06/53] . --- skrub/_agg_joiner.py | 6 ++---- skrub/_dataframe/tests/test_common.py | 1 - skrub/_joiner.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/skrub/_agg_joiner.py b/skrub/_agg_joiner.py index 3b6899ba2..8751d2b3b 100644 --- a/skrub/_agg_joiner.py +++ b/skrub/_agg_joiner.py @@ -6,7 +6,6 @@ table with the main table. """ from typing import Iterable -from skrub import _join_utils import numpy as np from sklearn.base import BaseEstimator, TransformerMixin @@ -293,7 +292,7 @@ def transform(self, X): # skrub_px, _ = get_df_namespace(self.aux_table_) X = _join_utils.left_join( - left=X, + X, right=self.aux_table_, left_on=self._main_key, right_on=self._aux_key, @@ -441,10 +440,9 @@ def transform(self, X): The augmented input. """ check_is_fitted(self, "y_") - skrub_px, _ = get_df_namespace(X) return _join_utils.left_join( - left=X, + X, right=self.y_, left_on=self.main_key_, right_on=self.main_key_, diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index 15c17a5f2..181fcc0fd 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -1,4 +1,3 @@ -skrub/_dataframe/tests/test_common.py """ Note: most tests in this file use the ``df_module`` fixture, which is defined in ``skrub.conftest``. See the corresponding docstrings for details. diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 935a2217f..b7ed89444 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -369,7 +369,7 @@ def transform(self, X, y=None): right, left_on=left_key_name, right_on=right_key_name, - ) + ) # TODO: dispatch ``drop`` join = join.drop([left_key_name], axis=1) if self.add_match_info: From e047b7e251f630b94bed803fe5752ac9ff613883 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 12:35:28 +0200 Subject: [PATCH 07/53] Only drop right_on col when it is not equal to left_on --- skrub/_join_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index 486b992d4..c1aaa638f 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -275,12 +275,18 @@ def _left_join_pandas( "'left' and 'right' must be pandas dataframes, " f"got {type(left)!r} and {type(right)!r}." ) - return left.merge( + + joined = left.merge( right, how="left", left_on=left_on, right_on=right_on, - ).drop(columns=right_on) + ) + + if left_on == right_on: + return joined + else: + return joined.drop(columns=right_on) @left_join.specialize("polars", argument_type="DataFrame") From 6c4c17b07c587ef9d4c614b276044b1f5e2c5a56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 14:56:50 +0200 Subject: [PATCH 08/53] Switch to default implem for with_columns --- skrub/_dataframe/_common.py | 7 ------- skrub/_dataframe/tests/test_common.py | 1 + 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py index 61bec4ef3..0804676be 100644 --- a/skrub/_dataframe/_common.py +++ b/skrub/_dataframe/_common.py @@ -958,14 +958,7 @@ def _replace_polars(column, old, new): return column.replace(old, new) -@dispatch def with_columns(df, **new_columns): - raise NotImplementedError() - - -@with_columns.specialize("pandas", argument_type="DataFrame") -@with_columns.specialize("polars", argument_type="DataFrame") -def _with_columns_dataframe(df, **new_columns): columns = {col_name: col(df, col_name) for col_name in column_names(df)} columns.update({n: make_column_like(df, c, n) for n, c in new_columns.items()}) return make_dataframe_like(df, columns) diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index 181fcc0fd..b6fbdbfd1 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -29,6 +29,7 @@ def test_not_implemented(): "to_column_list", "reset_index", "index", + "with_columns", } for func_name in sorted(set(ns.__all__) - has_default_impl): func = getattr(ns, func_name) From 595aab446adb6667469ec91ac5caebd2e1151b9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 16:33:19 +0200 Subject: [PATCH 09/53] Format --- skrub/_dataframe/_common.py | 4 ++-- skrub/_joiner.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py index 4d643f297..501cec2ae 100644 --- a/skrub/_dataframe/_common.py +++ b/skrub/_dataframe/_common.py @@ -958,7 +958,7 @@ def _replace_polars(col, old, new): return col.replace(old, new) -def with_columns(df, **new_columns): +def with_columns(df, **new_cols): columns = {col_name: col(df, col_name) for col_name in column_names(df)} - columns.update({n: make_column_like(df, c, n) for n, c in new_columns.items()}) + columns.update({n: make_column_like(df, c, n) for n, c in new_cols.items()}) return make_dataframe_like(df, columns) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index b7ed89444..1af50c797 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -352,7 +352,7 @@ def transform(self, X, y=None): aux_table = ns.reset_index( _join_utils.add_column_name_suffix(self._aux_table, self.suffix) ) - # dispatch copy ? + # Remove ns.col matching_col = ns.col(match_result, "index").copy() matching_col[~ns.col(match_result, "match_accepted")] = -1 token = _utils.random_string() @@ -362,8 +362,6 @@ def transform(self, X, y=None): right = ns.with_columns( aux_table, **{right_key_name: np.arange(ns.shape(aux_table)[0])} ) - # TODO: dispatch in ``_join_utils`` - # TODO: check pd vs pl behavior and how can we use the duplicates -> in PR WIP join = _join_utils.left_join( left, right, From 2caa2273e4dd37f3a5cd540b94335aef42d7b5e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 17:08:01 +0200 Subject: [PATCH 10/53] Iter dispatch Joiner --- skrub/_joiner.py | 67 +++++++++++++++--------------------------------- 1 file changed, 20 insertions(+), 47 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 1af50c797..e3acac369 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -11,18 +11,17 @@ from sklearn.utils.validation import check_is_fitted from skrub import _join_utils, _matching, _utils +from skrub import _selectors as s +from skrub._check_input import CheckInputDataFrame from skrub._dataframe import _common as ns - -# TODO: rm -from skrub._dataframe._namespace import is_pandas, is_polars from skrub._datetime_encoder import DatetimeEncoder +from skrub._to_str import ToStr -from . import _selectors as s from ._wrap_transformer import wrap_transformer -def _as_str(column): - return column.fillna("").astype(str) +def _as_str(col): + return ToStr().fit_transform(col) DEFAULT_STRING_ENCODER = make_pipeline( @@ -53,18 +52,15 @@ def _make_vectorizer(table, string_encoder, rescale): """ # TODO remove use of ColumnTransformer, select_dtypes & pandas-specific code transformers = [ - (clone(string_encoder), c) - for c in table.select_dtypes( - include=["string", "category", "object"] - ).columns # TODO: Use selector (s.expand) + (clone(string_encoder), c) for c in (s.string() | s.categorical()).expand(table) ] - num_columns = table.select_dtypes(include="number").columns - if not num_columns.empty: + num_columns = s.numeric().expand(table) + if num_columns: transformers.append( (StandardScaler() if rescale else "passthrough", num_columns) ) - dt_columns = table.select_dtypes(["datetime", "datetimetz"]).columns - if not dt_columns.empty: + dt_columns = s.any_date().expand(table) + if dt_columns: transformers.append( ( make_pipeline( @@ -252,18 +248,6 @@ def __init__( ) self.add_match_info = add_match_info - # TODO: rm - def _check_dataframe(self, dataframe): - # TODO: add support for polars, ATM we just convert to pandas - if is_polars(dataframe): - return dataframe.to_pandas() - if is_pandas(dataframe): - return dataframe - raise TypeError( - f"{self.__class__.__qualname__} only operates on Pandas or Polars" - " dataframes." - ) - def _check_max_dist(self): if ( self.max_dist is None @@ -298,8 +282,9 @@ def fit(self, X, y=None): Fitted Joiner instance (self). """ del y - X = self._check_dataframe(X) - self._aux_table = self._check_dataframe(self.aux_table) + self._aux_table = CheckInputDataFrame().fit_transform(self.aux_table) + self._main_check_input = CheckInputDataFrame() + X = self._main_check_input.fit_transform(X) self._check_ref_dist() self._check_max_dist() self._main_key, self._aux_key = _join_utils.check_key( @@ -336,25 +321,21 @@ def transform(self, X, y=None): """ del y check_is_fitted(self, "vectorizer_") - # rm - input_is_polars = is_polars(X) - X = self._check_dataframe(X) + X = self._main_check_input.transform(X) _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) - # TODO: dispatch select, set_axis in df._common + # TODO: dispatch set_axis in df._common main = self.vectorizer_.transform( ns.col(X, self._main_key).set_axis(self._aux_key, axis="columns") ) - _match_result = self._matching.match(main, self.max_dist_) - match_result = ns.make_dataframe_like(X, _match_result) aux_table = ns.reset_index( _join_utils.add_column_name_suffix(self._aux_table, self.suffix) ) - # Remove ns.col - matching_col = ns.col(match_result, "index").copy() - matching_col[~ns.col(match_result, "match_accepted")] = -1 + _match_result = self._matching.match(main, self.max_dist_) + matching_col = _match_result["index"].copy() + matching_col[~_match_result["match_accepted"]] = -1 token = _utils.random_string() left_key_name = f"skrub_left_key_{token}" right_key_name = f"skrub_right_key_{token}" @@ -368,16 +349,8 @@ def transform(self, X, y=None): left_on=left_key_name, right_on=right_key_name, ) - # TODO: dispatch ``drop`` - join = join.drop([left_key_name], axis=1) + join = s.select(join, ~s.cols(left_key_name)) if self.add_match_info: for info_key, info_col_name in self._match_info_key_renaming.items(): - join = ns.with_columns( - join, **{info_col_name: ns.col(match_result, info_key)} - ) - # TODO: remove this part - if input_is_polars: - import polars as pl - - join = pl.from_pandas(join) + join = ns.with_columns(join, **{info_col_name: _match_result[info_key]}) return join From edac2d59513895e842e8098a98ba487c16840e27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 17:53:27 +0200 Subject: [PATCH 11/53] Remove old test in pandas and polars --- skrub/_dataframe/tests/test_pandas.py | 4 ---- skrub/_dataframe/tests/test_polars.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/skrub/_dataframe/tests/test_pandas.py b/skrub/_dataframe/tests/test_pandas.py index 62e328792..e64cf48e9 100644 --- a/skrub/_dataframe/tests/test_pandas.py +++ b/skrub/_dataframe/tests/test_pandas.py @@ -71,10 +71,6 @@ def test_value_counts_agg(): def test_incorrect_dataframe_inputs(): - # TODO: deal with this - # with pytest.raises(TypeError, match=r"(?=.*pandas dataframes)(?=.*array)"): - # join(left=main.values, right=main, left_on="movieId", right_on="movieId") - with pytest.raises(TypeError, match=r"(?=.*pandas dataframe)(?=.*array)"): aggregate( table=main.values, diff --git a/skrub/_dataframe/tests/test_polars.py b/skrub/_dataframe/tests/test_polars.py index 310a5ef6b..2a2a2fa16 100644 --- a/skrub/_dataframe/tests/test_polars.py +++ b/skrub/_dataframe/tests/test_polars.py @@ -55,10 +55,6 @@ def test_mode_agg(): def test_incorrect_dataframe_inputs(): - # TODO: deal with this - # with pytest.raises(TypeError, match=r"(?=.*polars dataframes)(?=.*pandas)"): - # join(left=pd.DataFrame(main), right=main, left_on="movieId", right_on="movieId") # noqa: E501 - with pytest.raises(TypeError, match=r"(?=.*polars dataframe)(?=.*pandas)"): aggregate( table=pd.DataFrame(main), From ec13d35dee2c6dd4cfa4288312d466975b08f143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 18:19:38 +0200 Subject: [PATCH 12/53] Test with_columns --- skrub/_dataframe/_common.py | 6 ++--- skrub/_dataframe/tests/test_common.py | 36 +++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py index 501cec2ae..6b850164b 100644 --- a/skrub/_dataframe/_common.py +++ b/skrub/_dataframe/_common.py @@ -959,6 +959,6 @@ def _replace_polars(col, old, new): def with_columns(df, **new_cols): - columns = {col_name: col(df, col_name) for col_name in column_names(df)} - columns.update({n: make_column_like(df, c, n) for n, c in new_cols.items()}) - return make_dataframe_like(df, columns) + cols = {col_name: col(df, col_name) for col_name in column_names(df)} + cols.update({n: make_column_like(df, c, n) for n, c in new_cols.items()}) + return make_dataframe_like(df, cols) diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index b6fbdbfd1..a86f5ce35 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -593,8 +593,34 @@ def same(c1, c2): def test_with_columns(df_module): - # TODO: test one new col - # TODO: test multiple new cols - # TODO: test array (+ 1 test in ns.make_column_like(s) = s) - # TODO: test replace col - pass + df = df_module.make_dataframe({"a": [1, 2], "b": [3, 4]}) + + # Add one new col + out = ns.with_columns(df, **{"c": [5, 6]}) + if df_module.description == "pandas-nullable-dtypes": + out = ns.pandas_convert_dtypes(out) + expected = df_module.make_dataframe({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + df_module.assert_frame_equal(out, expected) + + # Add multiple new cols + out = ns.with_columns(df, **{"c": [5, 6], "d": [7, 8]}) + if df_module.description == "pandas-nullable-dtypes": + out = ns.pandas_convert_dtypes(out) + expected = df_module.make_dataframe( + {"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]} + ) + df_module.assert_frame_equal(out, expected) + + # Pass a col instead of an array + out = ns.with_columns(df, **{"c": df_module.make_column("c", [5, 6])}) + if df_module.description == "pandas-nullable-dtypes": + out = ns.pandas_convert_dtypes(out) + expected = df_module.make_dataframe({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + df_module.assert_frame_equal(out, expected) + + # Replace col + out = ns.with_columns(df, **{"a": [5, 6]}) + if df_module.description == "pandas-nullable-dtypes": + out = ns.pandas_convert_dtypes(out) + expected = df_module.make_dataframe({"a": [5, 6], "b": [3, 4]}) + df_module.assert_frame_equal(out, expected) From c3e290e62bdf193f7c9f53f3291265c5a4596dbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 18:22:11 +0200 Subject: [PATCH 13/53] More left_join tests --- skrub/tests/test_join_utils.py | 37 ++++++++++++++++------------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index a8f5cac44..3d06d89de 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -1,3 +1,4 @@ +import numpy as np import pandas as pd import pytest @@ -94,24 +95,20 @@ def test_left_join(df_module): assert_frame_equal(joined, expected) # Test all left keys not it right dataframe - # TODO: modify default behavior of `left_join` - # Will break for polars or pandas by default - # Polars keep right_key by default when it doesn't find a match - # left = df_module.make_dataframe( - # {"left_key": [1, 2, 2, 3], "left_col": [10, 20, 30, 40]} - # ) - # right = df_module.make_dataframe( - # {"right_key": [2, 9, 3, 5, 6], "right_col": ["a", "b", "c", "d", "e"]} - # ) - # joined = _join_utils.left_join( - # left, right=right, left_on="left_key", right_on="right_key" - # ) - # expected = df_module.make_dataframe( - # { - # "left_key": [1, 2, 2, 3], - # "left_col": [10, 20, 30, 40], - # "right_key": [None, 2, 2, 3], - # "right_col": [None, "a", "a", "c"], - # } - # ) + left = df_module.make_dataframe( + {"left_key": [1, 2, 2, 3], "left_col": [10, 20, 30, 40]} + ) + right = df_module.make_dataframe( + {"right_key": [2, 9, 3, 5, 6], "right_col": ["a", "b", "c", "d", "e"]} + ) + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="right_key" + ) + expected = df_module.make_dataframe( + { + "left_key": [1, 2, 2, 3], + "left_col": [10, 20, 30, 40], + "right_col": [np.nan, "a", "a", "c"], + } + ) assert_frame_equal(joined, expected) From d09359e5925e4e7376deec517cee1bfb0a1ba69e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 18:24:02 +0200 Subject: [PATCH 14/53] Simplify test --- skrub/tests/test_join_utils.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index 3d06d89de..321d443b9 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -77,14 +77,13 @@ def test_add_column_name_suffix(): def test_left_join(df_module): - # Test all left keys in right dataframe left = df_module.make_dataframe({"left_key": [1, 2, 2], "left_col": [10, 20, 30]}) - right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) + # Test all left keys in right dataframe + right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) joined = _join_utils.left_join( left, right=right, left_on="left_key", right_on="right_key" ) - expected = df_module.make_dataframe( { "left_key": [1, 2, 2], @@ -94,21 +93,16 @@ def test_left_join(df_module): ) assert_frame_equal(joined, expected) - # Test all left keys not it right dataframe - left = df_module.make_dataframe( - {"left_key": [1, 2, 2, 3], "left_col": [10, 20, 30, 40]} - ) - right = df_module.make_dataframe( - {"right_key": [2, 9, 3, 5, 6], "right_col": ["a", "b", "c", "d", "e"]} - ) + # Some left keys not it right dataframe + right = df_module.make_dataframe({"right_key": [2, 3], "right_col": ["a", "c"]}) joined = _join_utils.left_join( left, right=right, left_on="left_key", right_on="right_key" ) expected = df_module.make_dataframe( { - "left_key": [1, 2, 2, 3], - "left_col": [10, 20, 30, 40], - "right_col": [np.nan, "a", "a", "c"], + "left_key": [1, 2, 2], + "left_col": [10, 20, 30], + "right_col": [np.nan, "a", "a"], } ) assert_frame_equal(joined, expected) From 2a04e99acb05dbe60efe4ea5de418cafb3b1e6e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 18:33:04 +0200 Subject: [PATCH 15/53] Test make_column_like on col is col --- skrub/_dataframe/tests/test_common.py | 4 ++++ skrub/tests/test_join_utils.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index a86f5ce35..939138365 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -146,6 +146,9 @@ def test_make_column_like(df_module, example_data_dict): ) assert ns.dataframe_module_name(col) == df_module.name + col = df_module.make_column("name", [1, 2, 3]) + df_module.assert_column_equal(col, ns.make_column_like(col, col, "name")) + def test_null_value_for(df_module): assert ns.null_value_for(df_module.example_dataframe) is None @@ -598,6 +601,7 @@ def test_with_columns(df_module): # Add one new col out = ns.with_columns(df, **{"c": [5, 6]}) if df_module.description == "pandas-nullable-dtypes": + # for pandas, make_column_like will return an old-style / numpy dtypes Series out = ns.pandas_convert_dtypes(out) expected = df_module.make_dataframe({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) df_module.assert_frame_equal(out, expected) diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index 321d443b9..617c3b8d3 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -79,7 +79,7 @@ def test_add_column_name_suffix(): def test_left_join(df_module): left = df_module.make_dataframe({"left_key": [1, 2, 2], "left_col": [10, 20, 30]}) - # Test all left keys in right dataframe + # All left keys in right dataframe right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) joined = _join_utils.left_join( left, right=right, left_on="left_key", right_on="right_key" @@ -106,3 +106,6 @@ def test_left_join(df_module): } ) assert_frame_equal(joined, expected) + + # TODO: test joining on different types doesn't work + # TODO: check adding suffixes From be7ec265abd30890a6af7791fdea8c280808210b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 18:50:14 +0200 Subject: [PATCH 16/53] TODO --- skrub/_fuzzy_join.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/skrub/_fuzzy_join.py b/skrub/_fuzzy_join.py index 86e5749ad..67666e0a6 100644 --- a/skrub/_fuzzy_join.py +++ b/skrub/_fuzzy_join.py @@ -210,7 +210,9 @@ def fuzzy_join( add_match_info=True, ).fit_transform(left) if drop_unmatched: + # TODO: dispatch join = join[join["skrub_Joiner_match_accepted"]] if not add_match_info: + # TODO: use selectors join = join.drop(Joiner.match_info_columns, axis=1) return join From 686de71ebe80a57d4a71ab1a965d278406c3c2e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 13 Jun 2024 18:56:38 +0200 Subject: [PATCH 17/53] Make Joiner work for Polars --- skrub/_joiner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index e3acac369..a294eacb2 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -326,9 +326,8 @@ def transform(self, X, y=None): _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) - # TODO: dispatch set_axis in df._common main = self.vectorizer_.transform( - ns.col(X, self._main_key).set_axis(self._aux_key, axis="columns") + ns.set_column_names(ns.col(X, self._main_key), self._aux_key) ) aux_table = ns.reset_index( _join_utils.add_column_name_suffix(self._aux_table, self.suffix) From 8d6e62422af06c34f32b69bdd6d0f4f8da94ed69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= <57430673+TheooJ@users.noreply.github.com> Date: Sun, 16 Jun 2024 23:45:56 +0200 Subject: [PATCH 18/53] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérôme Dockès --- skrub/_join_utils.py | 129 ++++++++++++------------------------------- 1 file changed, 35 insertions(+), 94 deletions(-) diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index c1aaa638f..a4358d42c 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -230,105 +230,46 @@ def _get_new_name(suggested_name, forbidden_names): return f"{untagged_name}__skrub_{token}__" -@dispatch -def left_join( - left, - right, - left_on, - right_on, -): - raise NotImplementedError() - - -@left_join.specialize("pandas", argument_type="DataFrame") -def _left_join_pandas( - left, - right, - left_on, - right_on, -): - """Left join two :obj:`pandas.DataFrame`. - - This function uses the ``dataframe.merge`` method from Pandas. - - Parameters - ---------- - left : pd.DataFrame - The left dataframe to left-join. - - right : pd.DataFrame - The right dataframe to left-join. - - left_on : str or Iterable[str] - Left keys to merge on. - - right_on : str or Iterable[str] - Right keys to merge on. - - Returns - ------- - merged : pd.DataFrame, - The merged output. - """ - if not (isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame)): - raise TypeError( - "'left' and 'right' must be pandas dataframes, " - f"got {type(left)!r} and {type(right)!r}." - ) - - joined = left.merge( - right, - how="left", - left_on=left_on, - right_on=right_on, - ) - - if left_on == right_on: - return joined +def left_join(left, right, left_on, right_on, rename_right_cols="{}"): + # TODO -- replace AssertionErrors with exceptions with appropriate type & + # message + assert sbd.is_dataframe(left) + assert sbd.is_dataframe(right) + assert sbd.dataframe_module_name(left) == sbd.dataframe_module_name(right) + + left_cols = sbd.column_names(left) + original_right_cols = sbd.column_names(right) + right_cols = map(_utils.renaming_func(rename_right_cols), original_right_cols) + right_cols = pick_column_names(right_cols , forbidden_names=left_cols) + renaming = dict(zip(original_right_cols, right_cols)) + right = sbd.set_column_names(right, right_cols) + if isinstance(right_on, str): + right_on = renaming[right_on] + right_on_selector = s.cols(right_on) else: - return joined.drop(columns=right_on) - + right_on = tuple(renaming[c] for c in right_on) + right_on_selector = s.cols(*right_on) + joined = _do_left_join(left, right, left_on, right_on) + joined = s.select(joined, ~right_on_selector) + return joined -@left_join.specialize("polars", argument_type="DataFrame") -def _left_join_polars(left, right, left_on, right_on): - """Left join two :obj:`polars.DataFrame` or :obj:`polars.LazyFrame`. - This function uses the ``dataframe.join`` method from Polars. - - Note that the input dataframes type must agree: either both - Polars dataframes or both Polars lazyframes. - - Mixing polars dataframe with lazyframe will raise an error. - - Parameters - ---------- - left : pl.DataFrame or pl.LazyFrame - The left dataframe of the left-join. +@dispatch +def _do_left_join(left, right, left_on, right_on): + raise NotImplementedError() - right : pl.DataFrame or pl.LazyFrame - The right dataframe of the left-join. - left_on : str or Iterable[str] - Left keys to merge on. +@_do_left_join.specialize("pandas", argument_type="DataFrame") +def _do_left_join_pandas(left, right, left_on, right_on): + return left.merge(right, left_on=left_on, right_on=right_on, how="left", suffixes=("", "")) - right_on : str or Iterable[str] - Right keys to merge on. - Returns - ------- - merged : pl.DataFrame or pl.LazyFrame - The merged output. - """ - is_dataframe = isinstance(left, pl.DataFrame) and isinstance(right, pl.DataFrame) - is_lazyframe = isinstance(left, pl.LazyFrame) and isinstance(right, pl.LazyFrame) - if is_dataframe or is_lazyframe: - if "coalesce" in inspect.signature(left.join).parameters: - kw = {"coalesce": True} - else: - kw = {} - return left.join(right, how="left", left_on=left_on, right_on=right_on, **kw) +@_do_left_join.specialize("polars", argument_type="DataFrame") +def _do_left_join_polars(left, right, left_on, right_on): + if "coalesce" in inspect.signature(left.join).parameters: + kw = {"coalesce": True} else: - raise TypeError( - "'left' and 'right' must be polars dataframes or lazyframes, " - f"got {type(left)!r} and {type(right)!r}." - ) + kw = {} + return left.join( + right, left_on=left_on, right_on=right_on, how="left", suffix="", **kw + ) From ccc67bafadac4a902642545aac24d3654e97e612 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Mon, 17 Jun 2024 00:58:12 +0200 Subject: [PATCH 19/53] Test & comment left_join --- skrub/_join_utils.py | 73 ++++++++++++++++++++++++++-------- skrub/tests/test_join_utils.py | 62 ++++++++++++++++++++++++++--- 2 files changed, 114 insertions(+), 21 deletions(-) diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index a4358d42c..adda9b0ec 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -3,17 +3,11 @@ import inspect import re -import pandas as pd - -try: - import polars as pl -except ImportError: - pass - +from skrub import _dataframe as sbd +from skrub import _selectors as s from skrub import _utils from skrub._dataframe._namespace import get_df_namespace - -from ._dispatch import dispatch +from skrub._dispatch import dispatch def check_key( @@ -231,16 +225,61 @@ def _get_new_name(suggested_name, forbidden_names): def left_join(left, right, left_on, right_on, rename_right_cols="{}"): - # TODO -- replace AssertionErrors with exceptions with appropriate type & - # message - assert sbd.is_dataframe(left) - assert sbd.is_dataframe(right) - assert sbd.dataframe_module_name(left) == sbd.dataframe_module_name(right) + """Left join two dataframes of the same type. + + The input dataframes type must agree: either both `left` and `right` need to be + :obj:`pandas.DataFrame`, :obj:`polars.DataFrame` or :obj:`polars.LazyFrame`. + Mixing types will raise an error. + + `rename_right_cols` can be used to format the right dataframe columns, e.g. use + "right_.{}" to rename all right cols with a leading "right_". + + If duplicate column names are found between renamed right cols and left cols, + a __skrub___ is added at the end of columns that would otherwise + be duplicates. + + Parameters + ---------- + left : pd.DataFrame or pl.DataFrame or pl.LazyFrame + The left dataframe of the left-join. + right : pd.DataFrame or pl.DataFrame or pl.LazyFrame + The right dataframe of the left-join. + left_on : str or Iterable[str] + Left keys to merge on. + right_on : str or Iterable[str] + Right keys to merge on. + rename_right_cols : str, optional + Formatting used to rename right cols, by default "{}". + + Returns + ------- + pd.DataFrame or pl.DataFrame or pl.LazyFrame + The joined output. + + Raises + ------ + TypeError + If either of `left` and `right` is not a dataframe, or if both types + are not equal. + """ + if not sbd.is_dataframe(left): + raise TypeError( + f"`left` must be a pandas or polars dataframe, got {type(left)}." + ) + if not sbd.is_dataframe(right): + raise TypeError( + f"`right` must be a pandas or polars dataframe, got {type(right)}." + ) + if not sbd.dataframe_module_name(left) == sbd.dataframe_module_name(right): + raise TypeError( + "`left` and `right` must be of the same dataframe type, got" + f"{type(left)} and {type(right)}." + ) left_cols = sbd.column_names(left) original_right_cols = sbd.column_names(right) right_cols = map(_utils.renaming_func(rename_right_cols), original_right_cols) - right_cols = pick_column_names(right_cols , forbidden_names=left_cols) + right_cols = pick_column_names(right_cols, forbidden_names=left_cols) renaming = dict(zip(original_right_cols, right_cols)) right = sbd.set_column_names(right, right_cols) if isinstance(right_on, str): @@ -261,7 +300,9 @@ def _do_left_join(left, right, left_on, right_on): @_do_left_join.specialize("pandas", argument_type="DataFrame") def _do_left_join_pandas(left, right, left_on, right_on): - return left.merge(right, left_on=left_on, right_on=right_on, how="left", suffixes=("", "")) + return left.merge( + right, left_on=left_on, right_on=right_on, how="left", suffixes=("", "") + ) @_do_left_join.specialize("polars", argument_type="DataFrame") diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index 617c3b8d3..f42eeefa8 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -3,7 +3,6 @@ import pytest from skrub import _join_utils -from skrub._dataframe._testing_utils import assert_frame_equal @pytest.mark.parametrize( @@ -91,7 +90,7 @@ def test_left_join(df_module): "right_col": ["a", "b", "b"], } ) - assert_frame_equal(joined, expected) + df_module.assert_frame_equal(joined, expected) # Some left keys not it right dataframe right = df_module.make_dataframe({"right_key": [2, 3], "right_col": ["a", "c"]}) @@ -105,7 +104,60 @@ def test_left_join(df_module): "right_col": [np.nan, "a", "a"], } ) - assert_frame_equal(joined, expected) + df_module.assert_frame_equal(joined, expected) - # TODO: test joining on different types doesn't work - # TODO: check adding suffixes + # Renaming right col + right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) + joined = _join_utils.left_join( + left, + right=right, + left_on="left_key", + right_on="right_key", + rename_right_cols="right.{}", + ) + expected = df_module.make_dataframe( + { + "left_key": [1, 2, 2], + "left_col": [10, 20, 30], + "right.right_col": ["a", "b", "b"], + } + ) + df_module.assert_frame_equal(joined, expected) + + # Left not a df raises TypeError + with pytest.raises( + TypeError, + match=( + "`left` must be a pandas or polars dataframe, got ." + ), + ): + joined = _join_utils.left_join( + np.array([1, 2]), right=right, left_on="left_key", right_on="right_key" + ) + + # Right not a df raises TypeError + with pytest.raises( + TypeError, + match=( + "`right` must be a pandas or polars dataframe, got ." + ), + ): + joined = _join_utils.left_join( + left, right=np.array([1, 2]), left_on="left_key", right_on="right_key" + ) + + # Joining on different types raises TypeError + try: + import polars as pl + except ImportError: + pytest.skip(reason="Polars not available.") + + other_px = pd if df_module.module is pl else pl + right = other_px.DataFrame(left) + + with pytest.raises( + TypeError, match=r"`left` and `right` must be of the same dataframe type" + ): + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="right_key" + ) From d5031d954463c5d7a0d4f1e6ded3484465cf4ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= <57430673+TheooJ@users.noreply.github.com> Date: Tue, 18 Jun 2024 01:04:35 +0200 Subject: [PATCH 20/53] Apply suggestion from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérôme Dockès --- skrub/_join_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index adda9b0ec..6da51ca5a 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -232,7 +232,7 @@ def left_join(left, right, left_on, right_on, rename_right_cols="{}"): Mixing types will raise an error. `rename_right_cols` can be used to format the right dataframe columns, e.g. use - "right_.{}" to rename all right cols with a leading "right_". + "right_.{}" to rename all right cols with a leading "right_.". If duplicate column names are found between renamed right cols and left cols, a __skrub___ is added at the end of columns that would otherwise From d0796ec488e6218d9d652e699fbab4a8777a0cc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 01:16:33 +0200 Subject: [PATCH 21/53] Address more review comments --- skrub/_join_utils.py | 19 +++++++++---------- skrub/_joiner.py | 10 +++++----- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index 6da51ca5a..107973007 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -227,9 +227,8 @@ def _get_new_name(suggested_name, forbidden_names): def left_join(left, right, left_on, right_on, rename_right_cols="{}"): """Left join two dataframes of the same type. - The input dataframes type must agree: either both `left` and `right` need to be - :obj:`pandas.DataFrame`, :obj:`polars.DataFrame` or :obj:`polars.LazyFrame`. - Mixing types will raise an error. + The input dataframes type must agree: both `left` and `right` need to be + pandas or polars dataframes. Mixing types will raise an error. `rename_right_cols` can be used to format the right dataframe columns, e.g. use "right_.{}" to rename all right cols with a leading "right_.". @@ -240,20 +239,20 @@ def left_join(left, right, left_on, right_on, rename_right_cols="{}"): Parameters ---------- - left : pd.DataFrame or pl.DataFrame or pl.LazyFrame + left : dataframe The left dataframe of the left-join. - right : pd.DataFrame or pl.DataFrame or pl.LazyFrame + right : dataframe The right dataframe of the left-join. - left_on : str or Iterable[str] + left_on : str or iterable of str Left keys to merge on. - right_on : str or Iterable[str] + right_on : str or iterable of str Right keys to merge on. - rename_right_cols : str, optional - Formatting used to rename right cols, by default "{}". + rename_right_cols : str or callable, default="{}" + Formatting used to rename right cols. Returns ------- - pd.DataFrame or pl.DataFrame or pl.LazyFrame + dataframe The joined output. Raises diff --git a/skrub/_joiner.py b/skrub/_joiner.py index a294eacb2..bb77c988b 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -121,7 +121,7 @@ class Joiner(TransformerMixin, BaseEstimator): Parameters ---------- - aux_table : :obj:`~pandas.DataFrame` or :obj:`~polars.DataFrame` + aux_table : dataframe The auxiliary table, which will be fuzzy-joined to the main table when calling `transform`. key : str or iterable of str, default=None @@ -261,7 +261,7 @@ def _check_max_dist(self): def _check_ref_dist(self): if self.ref_dist not in _MATCHERS: raise ValueError( - f"`ref_dist` should be one of {list(_MATCHERS.keys())}, got" + f"'ref_dist' should be one of {list(_MATCHERS.keys())}, got" f" {self.ref_dist!r}" ) self._matching = _MATCHERS[self.ref_dist]() @@ -271,7 +271,7 @@ def fit(self, X, y=None): Parameters ---------- - X : :obj:`~pandas.DataFrame` or :obj:`~polars.DataFrame` + X : dataframe The main table, to be joined to the auxiliary ones. y : None Unused, only here for compatibility. @@ -309,14 +309,14 @@ def transform(self, X, y=None): Parameters ---------- - X : :obj:`~pandas.DataFrame` or :obj:`~polars.DataFrame` + X : dataframe The main table, to be joined to the auxiliary ones. y : None Unused, only here for compatibility. Returns ------- - :obj:`~pandas.DataFrame` or :obj:`~polars.DataFrame` + dataframe The final joined table. """ del y From e21fb9a15b2f6d9a15c11af3bb3d2df466a6055b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 01:21:40 +0200 Subject: [PATCH 22/53] Check that make_column_like name is the requested name even if the input column has a different name --- skrub/_dataframe/tests/test_common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index d1fbff4ca..301ae185e 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -147,8 +147,9 @@ def test_make_column_like(df_module, example_data_dict): ) assert ns.dataframe_module_name(col) == df_module.name - col = df_module.make_column("name", [1, 2, 3]) - df_module.assert_column_equal(col, ns.make_column_like(col, col, "name")) + col = df_module.make_column("old_name", [1, 2, 3]) + expected = df_module.make_column("new_name", [1, 2, 3]) + df_module.assert_column_equal(ns.make_column_like(col, col, "new_name"), expected) def test_null_value_for(df_module): From b9688c823d9157bc58262a9cba304e383fd9e149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 15:25:26 +0200 Subject: [PATCH 23/53] CHANGES.rst --- CHANGES.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 9d83fa758..b60421485 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,6 +14,8 @@ It is currently undergoing fast development and backward compatibility is not en Major changes ------------- +* The :class:`Joiner` has been adapted to support polars dataframes. :pr:`945` by :user:`Théo Jolivet `. + * The :class:`TableVectorizer` now consistently applies the same transformation across different calls to `transform`. There also have been some breaking changes to its functionality: (i) all transformations are now applied From 41e385c63b54c3948ce7113c8ee51b9b23e903e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 15:27:08 +0200 Subject: [PATCH 24/53] Change iterable of str to list of str --- skrub/_fuzzy_join.py | 1 + skrub/_join_utils.py | 4 ++-- skrub/_joiner.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/skrub/_fuzzy_join.py b/skrub/_fuzzy_join.py index 67666e0a6..7a878e1a4 100644 --- a/skrub/_fuzzy_join.py +++ b/skrub/_fuzzy_join.py @@ -20,6 +20,7 @@ def fuzzy_join( add_match_info=False, drop_unmatched=False, ): + # TODO: change docstring """Fuzzy (approximate) join. Rows in the left table are joined to their closest match from the right diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index 107973007..4ba25e6fe 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -243,9 +243,9 @@ def left_join(left, right, left_on, right_on, rename_right_cols="{}"): The left dataframe of the left-join. right : dataframe The right dataframe of the left-join. - left_on : str or iterable of str + left_on : str or list of str Left keys to merge on. - right_on : str or iterable of str + right_on : str or list of str Right keys to merge on. rename_right_cols : str or callable, default="{}" Formatting used to rename right cols. diff --git a/skrub/_joiner.py b/skrub/_joiner.py index bb77c988b..d3d25c91a 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -124,14 +124,14 @@ class Joiner(TransformerMixin, BaseEstimator): aux_table : dataframe The auxiliary table, which will be fuzzy-joined to the main table when calling `transform`. - key : str or iterable of str, default=None + key : str or list of str, default=None The column names to use for both `main_key` and `aux_key` when they are the same. Provide either `key` or both `main_key` and `aux_key`. - main_key : str or iterable of str, default=None + main_key : str or list of str, default=None The column names in the main table on which the join will be performed. Can be a string if joining on a single column. If `None`, `aux_key` must also be `None` and `key` must be provided. - aux_key : str or iterable of str, default=None + aux_key : str or list of str, default=None The column names in the auxiliary table on which the join will be performed. Can be a string if joining on a single column. If `None`, `main_key` must also be `None` and `key` must be provided. From 283ddb0d32e5dd5079bede1c140da5fc927b8548 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 16:59:40 +0200 Subject: [PATCH 25/53] Use df_module.assert_frame_equal --- skrub/tests/test_joiner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index d28715486..5c48205ff 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -5,7 +5,6 @@ from skrub import Joiner from skrub._dataframe import _common as ns -from skrub._dataframe._testing_utils import assert_frame_equal @pytest.fixture @@ -83,13 +82,13 @@ def test_multiple_keys(df_module): result = joiner_list.fit_transform(df) expected = ns.concat_horizontal(df, df2) - assert_frame_equal(result, expected) + df_module.assert_frame_equal(result, expected) joiner_list = Joiner( aux_table=df2, aux_key="CA", main_key="Ca", add_match_info=False ) result = joiner_list.fit_transform(df) - assert_frame_equal(result, expected) + df_module.assert_frame_equal(result, expected) def test_pandas_aux_table_index(): From d59c916296b0288b72d351947c46bffb314d902c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 17:17:49 +0200 Subject: [PATCH 26/53] Docstring --- skrub/_join_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py index 4ba25e6fe..1333699e0 100644 --- a/skrub/_join_utils.py +++ b/skrub/_join_utils.py @@ -248,7 +248,8 @@ def left_join(left, right, left_on, right_on, rename_right_cols="{}"): right_on : str or list of str Right keys to merge on. rename_right_cols : str or callable, default="{}" - Formatting used to rename right cols. + Formatting used to rename right cols. If it is a callable, it should + accept strings as an argument. By default, no formatting is applied. Returns ------- From 6b7e379b37a25c302f898e6f36ae36644f9c2c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 17:35:21 +0200 Subject: [PATCH 27/53] Iter test_joiner --- skrub/tests/test_joiner.py | 47 ++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index 5c48205ff..a0e8267b5 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -1,5 +1,4 @@ import numpy as np -import pandas as pd import pytest from numpy.testing import assert_array_equal @@ -48,22 +47,17 @@ def test_fit_transform(main_table, aux_table): def test_wrong_main_key(main_table, aux_table): - false_joiner = Joiner(aux_table=aux_table, main_key="Countryy", aux_key="country") - - with pytest.raises( - ValueError, - match="do not exist in 'X'", - ): - false_joiner.fit(main_table) + wrong_joiner = Joiner(aux_table=aux_table, main_key="wrong_key", aux_key="country") + with pytest.raises(ValueError, match="do not exist in 'X'"): + wrong_joiner.fit(main_table) def test_wrong_aux_key(main_table, aux_table): - false_joiner2 = Joiner(aux_table=aux_table, main_key="Country", aux_key="bad") - with pytest.raises( - ValueError, - match="do not exist in 'aux_table'", - ): - false_joiner2.fit(main_table) + wrong_joiner_2 = Joiner( + aux_table=aux_table, main_key="Country", aux_key="wrong_key" + ) + with pytest.raises(ValueError, match="do not exist in 'aux_table'"): + wrong_joiner_2.fit(main_table) def test_multiple_keys(df_module): @@ -75,8 +69,8 @@ def test_multiple_keys(df_module): ) joiner_list = Joiner( aux_table=df2, - aux_key=["CO", "CA"], main_key=["Co", "Ca"], + aux_key=["CO", "CA"], add_match_info=False, ) result = joiner_list.fit_transform(df) @@ -91,9 +85,9 @@ def test_multiple_keys(df_module): df_module.assert_frame_equal(result, expected) -def test_pandas_aux_table_index(): - main_table = pd.DataFrame({"Country": ["France", "Italia", "Georgia"]}) - aux_table = pd.DataFrame( +def test_pandas_aux_table_index(df_module): + main_table = df_module.make_dataframe({"Country": ["France", "Italia", "Georgia"]}) + aux_table = df_module.make_dataframe( { "Country": ["Germany", "France", "Italy"], "Capital": ["Berlin", "Paris", "Rome"], @@ -114,15 +108,18 @@ def test_pandas_aux_table_index(): ] -def test_bad_ref_dist(): - table = pd.DataFrame({"A": [1, 2]}) - joiner = Joiner(table, key="A", ref_dist="bad") - with pytest.raises(ValueError, match="got 'bad'"): +def test_wrong_ref_dist(df_module): + table = df_module.make_dataframe({"A": [1, 2]}) + joiner = Joiner(table, key="A", ref_dist="wrong_ref_dist") + with pytest.raises( + ValueError, match=r"('ref_dist' should be one of)*(got 'wrong_ref_dist')" + ): joiner.fit(table) @pytest.mark.parametrize("max_dist", [np.inf, float("inf"), "inf", None]) -def test_max_dist(max_dist): - table = pd.DataFrame({"A": [1, 2]}) - joiner = Joiner(table, key="A", max_dist=max_dist, suffix="_").fit(table) +def test_max_dist(main_table, aux_table, max_dist): + joiner = Joiner( + aux_table, main_key="Country", aux_key="country", max_dist=max_dist + ).fit(main_table) assert joiner.max_dist_ == np.inf From e467994d607155408fedea131c85f3f5394e65b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 19:56:55 +0200 Subject: [PATCH 28/53] Add useful error msg for max_dist --- skrub/_joiner.py | 13 +++++++++---- skrub/tests/test_joiner.py | 27 +++++++++++++++++++++------ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index d3d25c91a..138a5c76a 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -1,4 +1,4 @@ -""" +skrub/_joiner.py""" The Joiner provides fuzzy joining as a scikit-learn transformer. """ @@ -138,7 +138,7 @@ class Joiner(TransformerMixin, BaseEstimator): suffix : str, default="" Suffix to append to the `aux_table`'s column names. You can use it to avoid duplicate column names in the join. - max_dist : float, default=np.inf + max_dist : int, float, `None` or `np.inf`, default=`np.inf` Maximum acceptable (rescaled) distance between a row in the `main_table` and its nearest neighbor in the `aux_table`. Rows that are farther apart are not considered to match. By default, the distance @@ -255,13 +255,18 @@ def _check_max_dist(self): and self.max_dist == "inf" ): self.max_dist_ = np.inf - else: + elif isinstance(self.max_dist, int) or isinstance(self.max_dist, float): self.max_dist_ = self.max_dist + else: + raise ValueError( + "'max_dist' should be an int, a float, `None` or `np.inf`. Got" + f" {self.max_dist!r}" + ) def _check_ref_dist(self): if self.ref_dist not in _MATCHERS: raise ValueError( - f"'ref_dist' should be one of {list(_MATCHERS.keys())}, got" + f"'ref_dist' should be one of {list(_MATCHERS.keys())}. Got" f" {self.ref_dist!r}" ) self._matching = _MATCHERS[self.ref_dist]() diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index a0e8267b5..3b651f461 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -79,7 +79,7 @@ def test_multiple_keys(df_module): df_module.assert_frame_equal(result, expected) joiner_list = Joiner( - aux_table=df2, aux_key="CA", main_key="Ca", add_match_info=False + aux_table=df2, main_key="Ca", aux_key="CA", add_match_info=False ) result = joiner_list.fit_transform(df) df_module.assert_frame_equal(result, expected) @@ -108,13 +108,14 @@ def test_pandas_aux_table_index(df_module): ] -def test_wrong_ref_dist(df_module): - table = df_module.make_dataframe({"A": [1, 2]}) - joiner = Joiner(table, key="A", ref_dist="wrong_ref_dist") +def test_wrong_ref_dist(main_table, aux_table): + joiner = Joiner( + aux_table, main_key="Country", aux_key="country", ref_dist="wrong_ref_dist" + ) with pytest.raises( - ValueError, match=r"('ref_dist' should be one of)*(got 'wrong_ref_dist')" + ValueError, match=r"('ref_dist' should be one of)*(Got 'wrong_ref_dist')" ): - joiner.fit(table) + joiner.fit(main_table) @pytest.mark.parametrize("max_dist", [np.inf, float("inf"), "inf", None]) @@ -123,3 +124,17 @@ def test_max_dist(main_table, aux_table, max_dist): aux_table, main_key="Country", aux_key="country", max_dist=max_dist ).fit(main_table) assert joiner.max_dist_ == np.inf + + +def test_wrong_max_dist(main_table, aux_table): + joiner = Joiner( + aux_table, main_key="Country", aux_key="country", max_dist="wrong_max_dist" + ) + with pytest.raises( + ValueError, + match=( + "'max_dist' should be an int, a float, `None` or `np.inf`. Got" + " 'wrong_max_dist'" + ), + ): + joiner.fit(main_table) From 4be4bdc531870ba41ec99074e9c7f69827c18ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 19:56:55 +0200 Subject: [PATCH 29/53] Add useful error msg for max_dist --- skrub/_joiner.py | 11 ++++++++--- skrub/tests/test_joiner.py | 27 +++++++++++++++++++++------ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index d3d25c91a..923c447f5 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -138,7 +138,7 @@ class Joiner(TransformerMixin, BaseEstimator): suffix : str, default="" Suffix to append to the `aux_table`'s column names. You can use it to avoid duplicate column names in the join. - max_dist : float, default=np.inf + max_dist : int, float, `None` or `np.inf`, default=`np.inf` Maximum acceptable (rescaled) distance between a row in the `main_table` and its nearest neighbor in the `aux_table`. Rows that are farther apart are not considered to match. By default, the distance @@ -255,13 +255,18 @@ def _check_max_dist(self): and self.max_dist == "inf" ): self.max_dist_ = np.inf - else: + elif isinstance(self.max_dist, int) or isinstance(self.max_dist, float): self.max_dist_ = self.max_dist + else: + raise ValueError( + "'max_dist' should be an int, a float, `None` or `np.inf`. Got" + f" {self.max_dist!r}" + ) def _check_ref_dist(self): if self.ref_dist not in _MATCHERS: raise ValueError( - f"'ref_dist' should be one of {list(_MATCHERS.keys())}, got" + f"'ref_dist' should be one of {list(_MATCHERS.keys())}. Got" f" {self.ref_dist!r}" ) self._matching = _MATCHERS[self.ref_dist]() diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index a0e8267b5..3b651f461 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -79,7 +79,7 @@ def test_multiple_keys(df_module): df_module.assert_frame_equal(result, expected) joiner_list = Joiner( - aux_table=df2, aux_key="CA", main_key="Ca", add_match_info=False + aux_table=df2, main_key="Ca", aux_key="CA", add_match_info=False ) result = joiner_list.fit_transform(df) df_module.assert_frame_equal(result, expected) @@ -108,13 +108,14 @@ def test_pandas_aux_table_index(df_module): ] -def test_wrong_ref_dist(df_module): - table = df_module.make_dataframe({"A": [1, 2]}) - joiner = Joiner(table, key="A", ref_dist="wrong_ref_dist") +def test_wrong_ref_dist(main_table, aux_table): + joiner = Joiner( + aux_table, main_key="Country", aux_key="country", ref_dist="wrong_ref_dist" + ) with pytest.raises( - ValueError, match=r"('ref_dist' should be one of)*(got 'wrong_ref_dist')" + ValueError, match=r"('ref_dist' should be one of)*(Got 'wrong_ref_dist')" ): - joiner.fit(table) + joiner.fit(main_table) @pytest.mark.parametrize("max_dist", [np.inf, float("inf"), "inf", None]) @@ -123,3 +124,17 @@ def test_max_dist(main_table, aux_table, max_dist): aux_table, main_key="Country", aux_key="country", max_dist=max_dist ).fit(main_table) assert joiner.max_dist_ == np.inf + + +def test_wrong_max_dist(main_table, aux_table): + joiner = Joiner( + aux_table, main_key="Country", aux_key="country", max_dist="wrong_max_dist" + ) + with pytest.raises( + ValueError, + match=( + "'max_dist' should be an int, a float, `None` or `np.inf`. Got" + " 'wrong_max_dist'" + ), + ): + joiner.fit(main_table) From 1ade2b02c5c19cd3c6d9242314fd3ac3d043037b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 20:02:44 +0200 Subject: [PATCH 30/53] . --- skrub/_joiner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 138a5c76a..923c447f5 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -1,4 +1,4 @@ -skrub/_joiner.py""" +""" The Joiner provides fuzzy joining as a scikit-learn transformer. """ From 14a3d973f8e9dfbd55f38e80384a1b4577fbbee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 20:07:34 +0200 Subject: [PATCH 31/53] Handle case where main_key and aux_key are lists --- skrub/_joiner.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 923c447f5..0b4b88500 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -301,11 +301,13 @@ def fit(self, X, y=None): X, self._aux_table, self.suffix, main_table_name="X" ) self.vectorizer_ = _make_vectorizer( - ns.col(self._aux_table, self._aux_key), + s.select(self._aux_table, s.cols(*self._aux_key)), self.string_encoder, rescale=self.ref_dist != "no_rescaling", ) - aux = self.vectorizer_.fit_transform(ns.col(self._aux_table, self._aux_key)) + aux = self.vectorizer_.fit_transform( + s.select(self._aux_table, s.cols(*self._aux_key)) + ) self._matching.fit(aux) return self @@ -332,7 +334,7 @@ def transform(self, X, y=None): X, self._aux_table, self.suffix, main_table_name="X" ) main = self.vectorizer_.transform( - ns.set_column_names(ns.col(X, self._main_key), self._aux_key) + ns.set_column_names(s.select(X, s.cols(*self._main_key)), self._aux_key) ) aux_table = ns.reset_index( _join_utils.add_column_name_suffix(self._aux_table, self.suffix) From 5ab616a2206129bd5a498b777a299798eeb8024d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 20:36:25 +0200 Subject: [PATCH 32/53] Test missing values, numeric, datetimes --- skrub/_fuzzy_join.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skrub/_fuzzy_join.py b/skrub/_fuzzy_join.py index 7a878e1a4..c1c8db0fd 100644 --- a/skrub/_fuzzy_join.py +++ b/skrub/_fuzzy_join.py @@ -215,5 +215,6 @@ def fuzzy_join( join = join[join["skrub_Joiner_match_accepted"]] if not add_match_info: # TODO: use selectors + # join = s.select(join, ~s.cols(Joiner.match_info_columns)) join = join.drop(Joiner.match_info_columns, axis=1) return join From 92f78b03bc1869594defdfcc1d628eb38b08afbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 18 Jun 2024 20:38:57 +0200 Subject: [PATCH 33/53] Test missing values, numeric, datetimes --- skrub/tests/test_joiner.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index 3b651f461..d472ce685 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -1,3 +1,5 @@ +import datetime + import numpy as np import pytest from numpy.testing import assert_array_equal @@ -138,3 +140,30 @@ def test_wrong_max_dist(main_table, aux_table): ), ): joiner.fit(main_table) + + +def test_missing_values(df_module): + df = df_module.make_dataframe({"A": [None, "hollywood", "beverly"]}) + joiner = Joiner(df, key="A", suffix="_") + with pytest.xfail(): + joiner.fit_transform(df) + + +def test_fit_transform_numeric(df_module): + df = df_module.make_dataframe({"A": [4.5, 0.5, 1, -1.5]}) + joiner = Joiner(df, key="A", suffix="_") + joiner.fit_transform(df) + + +def test_fit_transform_datetimes(df_module): + values = [ + datetime.datetime.fromisoformat(dt) + for dt in [ + "2020-02-03T12:30:05", + "2021-03-15T00:37:15", + "2022-02-13T17:03:25", + ] + ] + df = df_module.make_dataframe({"A": values}) + joiner = Joiner(df, key="A", suffix="_") + joiner.fit_transform(df) From 101d7886cbb0f8ed777b95df4cbda5a4d738f688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= <57430673+TheooJ@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:07:26 +0200 Subject: [PATCH 34/53] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérôme Dockès --- skrub/_agg_joiner.py | 1 - skrub/_joiner.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/skrub/_agg_joiner.py b/skrub/_agg_joiner.py index 356907693..88ad42330 100644 --- a/skrub/_agg_joiner.py +++ b/skrub/_agg_joiner.py @@ -289,7 +289,6 @@ def transform(self, X): X, _ = self._check_dataframes(X, self.aux_table_) _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") - # skrub_px, _ = get_df_namespace(self.aux_table_) X = _join_utils.left_join( X, right=self.aux_table_, diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 0b4b88500..0542b57ac 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -13,7 +13,7 @@ from skrub import _join_utils, _matching, _utils from skrub import _selectors as s from skrub._check_input import CheckInputDataFrame -from skrub._dataframe import _common as ns +from . import _dataframe as ns from skrub._datetime_encoder import DatetimeEncoder from skrub._to_str import ToStr @@ -25,7 +25,7 @@ def _as_str(col): DEFAULT_STRING_ENCODER = make_pipeline( - FunctionTransformer(_as_str), + ToStr(), HashingVectorizer(analyzer="char_wb", ngram_range=(2, 4)), # TODO: Remove sparse output from Tfidf to work with TableVectorizer TfidfTransformer(), From 8acbd92b962e9a48c2590b751079288913be0e6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 11:37:02 +0200 Subject: [PATCH 35/53] Apply more suggestions from code review --- skrub/_joiner.py | 41 +++++++++++++++----------------------- skrub/tests/test_joiner.py | 14 ------------- 2 files changed, 16 insertions(+), 39 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 0542b57ac..8e572a3e7 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -7,23 +7,17 @@ from sklearn.compose import make_column_transformer from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import FunctionTransformer, StandardScaler +from sklearn.preprocessing import StandardScaler from sklearn.utils.validation import check_is_fitted -from skrub import _join_utils, _matching, _utils -from skrub import _selectors as s -from skrub._check_input import CheckInputDataFrame -from . import _dataframe as ns -from skrub._datetime_encoder import DatetimeEncoder -from skrub._to_str import ToStr - +from . import _dataframe as sbd +from . import _join_utils, _matching, _utils +from . import _selectors as s +from ._check_input import CheckInputDataFrame +from ._datetime_encoder import DatetimeEncoder +from ._to_str import ToStr from ._wrap_transformer import wrap_transformer - -def _as_str(col): - return ToStr().fit_transform(col) - - DEFAULT_STRING_ENCODER = make_pipeline( ToStr(), HashingVectorizer(analyzer="char_wb", ngram_range=(2, 4)), @@ -255,13 +249,8 @@ def _check_max_dist(self): and self.max_dist == "inf" ): self.max_dist_ = np.inf - elif isinstance(self.max_dist, int) or isinstance(self.max_dist, float): - self.max_dist_ = self.max_dist else: - raise ValueError( - "'max_dist' should be an int, a float, `None` or `np.inf`. Got" - f" {self.max_dist!r}" - ) + self.max_dist_ = self.max_dist def _check_ref_dist(self): if self.ref_dist not in _MATCHERS: @@ -334,9 +323,9 @@ def transform(self, X, y=None): X, self._aux_table, self.suffix, main_table_name="X" ) main = self.vectorizer_.transform( - ns.set_column_names(s.select(X, s.cols(*self._main_key)), self._aux_key) + sbd.set_column_names(s.select(X, s.cols(*self._main_key)), self._aux_key) ) - aux_table = ns.reset_index( + aux_table = sbd.reset_index( _join_utils.add_column_name_suffix(self._aux_table, self.suffix) ) _match_result = self._matching.match(main, self.max_dist_) @@ -345,9 +334,9 @@ def transform(self, X, y=None): token = _utils.random_string() left_key_name = f"skrub_left_key_{token}" right_key_name = f"skrub_right_key_{token}" - left = ns.with_columns(X, **{left_key_name: matching_col}) - right = ns.with_columns( - aux_table, **{right_key_name: np.arange(ns.shape(aux_table)[0])} + left = sbd.with_columns(X, **{left_key_name: matching_col}) + right = sbd.with_columns( + aux_table, **{right_key_name: np.arange(sbd.shape(aux_table)[0])} ) join = _join_utils.left_join( left, @@ -358,5 +347,7 @@ def transform(self, X, y=None): join = s.select(join, ~s.cols(left_key_name)) if self.add_match_info: for info_key, info_col_name in self._match_info_key_renaming.items(): - join = ns.with_columns(join, **{info_col_name: _match_result[info_key]}) + join = sbd.with_columns( + join, **{info_col_name: _match_result[info_key]} + ) return join diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index d472ce685..7b14c583d 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -128,20 +128,6 @@ def test_max_dist(main_table, aux_table, max_dist): assert joiner.max_dist_ == np.inf -def test_wrong_max_dist(main_table, aux_table): - joiner = Joiner( - aux_table, main_key="Country", aux_key="country", max_dist="wrong_max_dist" - ) - with pytest.raises( - ValueError, - match=( - "'max_dist' should be an int, a float, `None` or `np.inf`. Got" - " 'wrong_max_dist'" - ), - ): - joiner.fit(main_table) - - def test_missing_values(df_module): df = df_module.make_dataframe({"A": [None, "hollywood", "beverly"]}) joiner = Joiner(df, key="A", suffix="_") From 4365d9ca725b69e9489d6b93255f3068aeb70cb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 11:41:32 +0200 Subject: [PATCH 36/53] More suggestions from code review --- skrub/_joiner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 8e572a3e7..6cdf3d2da 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -346,8 +346,8 @@ def transform(self, X, y=None): ) join = s.select(join, ~s.cols(left_key_name)) if self.add_match_info: + match_info_dict = {} for info_key, info_col_name in self._match_info_key_renaming.items(): - join = sbd.with_columns( - join, **{info_col_name: _match_result[info_key]} - ) + match_info_dict[info_col_name] = _match_result[info_key] + join = sbd.with_columns(join, **match_info_dict) return join From aa7015db8093fd4055ce19f5f957315848c20314 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 11:47:51 +0200 Subject: [PATCH 37/53] More suggestions from code review --- skrub/_joiner.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 6cdf3d2da..2a9f74666 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -2,12 +2,14 @@ The Joiner provides fuzzy joining as a scikit-learn transformer. """ +from functools import partial + import numpy as np from sklearn.base import BaseEstimator, TransformerMixin, clone from sklearn.compose import make_column_transformer from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import StandardScaler +from sklearn.preprocessing import FunctionTransformer, StandardScaler from sklearn.utils.validation import check_is_fitted from . import _dataframe as sbd @@ -19,6 +21,7 @@ from ._wrap_transformer import wrap_transformer DEFAULT_STRING_ENCODER = make_pipeline( + FunctionTransformer(partial(sbd.fill_nulls, value="")), ToStr(), HashingVectorizer(analyzer="char_wb", ngram_range=(2, 4)), # TODO: Remove sparse output from Tfidf to work with TableVectorizer @@ -328,9 +331,9 @@ def transform(self, X, y=None): aux_table = sbd.reset_index( _join_utils.add_column_name_suffix(self._aux_table, self.suffix) ) - _match_result = self._matching.match(main, self.max_dist_) - matching_col = _match_result["index"].copy() - matching_col[~_match_result["match_accepted"]] = -1 + match_result = self._matching.match(main, self.max_dist_) + matching_col = match_result["index"].copy() + matching_col[~match_result["match_accepted"]] = -1 token = _utils.random_string() left_key_name = f"skrub_left_key_{token}" right_key_name = f"skrub_right_key_{token}" @@ -348,6 +351,6 @@ def transform(self, X, y=None): if self.add_match_info: match_info_dict = {} for info_key, info_col_name in self._match_info_key_renaming.items(): - match_info_dict[info_col_name] = _match_result[info_key] + match_info_dict[info_col_name] = match_result[info_key] join = sbd.with_columns(join, **match_info_dict) return join From da74a0c03c4078fcdf9cb997230707527a9390c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 12:31:48 +0200 Subject: [PATCH 38/53] More suggestions from code review --- skrub/_joiner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 2a9f74666..0c325a099 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -292,6 +292,7 @@ def fit(self, X, y=None): _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) + self._suffix = f"{{}}{self.suffix}".format self.vectorizer_ = _make_vectorizer( s.select(self._aux_table, s.cols(*self._aux_key)), self.string_encoder, @@ -328,9 +329,6 @@ def transform(self, X, y=None): main = self.vectorizer_.transform( sbd.set_column_names(s.select(X, s.cols(*self._main_key)), self._aux_key) ) - aux_table = sbd.reset_index( - _join_utils.add_column_name_suffix(self._aux_table, self.suffix) - ) match_result = self._matching.match(main, self.max_dist_) matching_col = match_result["index"].copy() matching_col[~match_result["match_accepted"]] = -1 @@ -339,13 +337,15 @@ def transform(self, X, y=None): right_key_name = f"skrub_right_key_{token}" left = sbd.with_columns(X, **{left_key_name: matching_col}) right = sbd.with_columns( - aux_table, **{right_key_name: np.arange(sbd.shape(aux_table)[0])} + self._aux_table, + **{right_key_name: np.arange(sbd.shape(self._aux_table)[0])}, ) join = _join_utils.left_join( left, right, left_on=left_key_name, right_on=right_key_name, + rename_right_cols=self._suffix, ) join = s.select(join, ~s.cols(left_key_name)) if self.add_match_info: From c3930881c3e641e8053362a3b4be3c4f5c23f4bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 12:42:25 +0200 Subject: [PATCH 39/53] Add reset_index to _pandas.aggregate --- skrub/_dataframe/_pandas.py | 2 +- skrub/tests/test_agg_joiner.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/skrub/_dataframe/_pandas.py b/skrub/_dataframe/_pandas.py index 25f244491..db21a8c28 100644 --- a/skrub/_dataframe/_pandas.py +++ b/skrub/_dataframe/_pandas.py @@ -104,7 +104,7 @@ def aggregate( ] sorted_cols = sorted(base_group.columns) - return base_group[sorted_cols] + return base_group[sorted_cols].reset_index(drop=False) def get_named_agg(table, cols, operations): diff --git a/skrub/tests/test_agg_joiner.py b/skrub/tests/test_agg_joiner.py index 402a7009b..53340ae38 100644 --- a/skrub/tests/test_agg_joiner.py +++ b/skrub/tests/test_agg_joiner.py @@ -401,6 +401,7 @@ def test_agg_target(main_table, y, col_name): "movieId": [1, 3, 6, 318, 6, 1704], "rating": [4.0, 4.0, 4.0, 3.0, 2.0, 4.0], "genre": ["drama", "drama", "comedy", "sf", "comedy", "sf"], + "index": [0, 0, 0, 1, 1, 1], f"{col_name}_(1.999, 3.0]_user": [0, 0, 0, 2, 2, 2], f"{col_name}_(3.0, 4.0]_user": [3, 3, 3, 1, 1, 1], f"{col_name}_2.0_user": [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], From a7f1ab7c1d963eb4c253e0235ec415e3e671386f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 14:55:28 +0200 Subject: [PATCH 40/53] Dispatch fuzzy_join --- skrub/_fuzzy_join.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/skrub/_fuzzy_join.py b/skrub/_fuzzy_join.py index c1c8db0fd..687405ba3 100644 --- a/skrub/_fuzzy_join.py +++ b/skrub/_fuzzy_join.py @@ -3,8 +3,10 @@ """ import numpy as np -from skrub import _join_utils -from skrub._joiner import DEFAULT_REF_DIST, DEFAULT_STRING_ENCODER, Joiner +from . import _join_utils +from . import _dataframe as sbd +from . import _selectors as s +from ._joiner import DEFAULT_REF_DIST, DEFAULT_STRING_ENCODER, Joiner def fuzzy_join( @@ -20,7 +22,6 @@ def fuzzy_join( add_match_info=False, drop_unmatched=False, ): - # TODO: change docstring """Fuzzy (approximate) join. Rows in the left table are joined to their closest match from the right @@ -212,9 +213,8 @@ def fuzzy_join( ).fit_transform(left) if drop_unmatched: # TODO: dispatch - join = join[join["skrub_Joiner_match_accepted"]] + # join = join[join["skrub_Joiner_match_accepted"]] + join = sbd.where(join, if not add_match_info: - # TODO: use selectors - # join = s.select(join, ~s.cols(Joiner.match_info_columns)) - join = join.drop(Joiner.match_info_columns, axis=1) + join = s.select(join, ~s.cols(*Joiner.match_info_columns)) return join From af684c34b54dbc971afd583285accfc63f2f1dfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 14:56:53 +0200 Subject: [PATCH 41/53] Format fuzzy_join --- skrub/_fuzzy_join.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/skrub/_fuzzy_join.py b/skrub/_fuzzy_join.py index 687405ba3..f63a07b24 100644 --- a/skrub/_fuzzy_join.py +++ b/skrub/_fuzzy_join.py @@ -4,7 +4,6 @@ import numpy as np from . import _join_utils -from . import _dataframe as sbd from . import _selectors as s from ._joiner import DEFAULT_REF_DIST, DEFAULT_STRING_ENCODER, Joiner @@ -213,8 +212,7 @@ def fuzzy_join( ).fit_transform(left) if drop_unmatched: # TODO: dispatch - # join = join[join["skrub_Joiner_match_accepted"]] - join = sbd.where(join, + join = join[join["skrub_Joiner_match_accepted"]] if not add_match_info: join = s.select(join, ~s.cols(*Joiner.match_info_columns)) return join From 5f2fd606fba920d93a0e8f0d3902dcca717def18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 16:20:43 +0200 Subject: [PATCH 42/53] Dispatch filter in fuzzy_join --- skrub/_fuzzy_join.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrub/_fuzzy_join.py b/skrub/_fuzzy_join.py index f63a07b24..9b3ece456 100644 --- a/skrub/_fuzzy_join.py +++ b/skrub/_fuzzy_join.py @@ -3,6 +3,7 @@ """ import numpy as np +from . import _dataframe as sbd from . import _join_utils from . import _selectors as s from ._joiner import DEFAULT_REF_DIST, DEFAULT_STRING_ENCODER, Joiner @@ -211,8 +212,7 @@ def fuzzy_join( add_match_info=True, ).fit_transform(left) if drop_unmatched: - # TODO: dispatch - join = join[join["skrub_Joiner_match_accepted"]] + join = sbd.filter(join, sbd.col(join, "skrub_Joiner_match_accepted")) if not add_match_info: join = s.select(join, ~s.cols(*Joiner.match_info_columns)) return join From 5accc202ab293b6aa3ec4af8b4816754986ff2f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Wed, 19 Jun 2024 16:41:29 +0200 Subject: [PATCH 43/53] Fix pandas aggregate testing --- skrub/_dataframe/tests/test_pandas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skrub/_dataframe/tests/test_pandas.py b/skrub/_dataframe/tests/test_pandas.py index e64cf48e9..25ed23223 100644 --- a/skrub/_dataframe/tests/test_pandas.py +++ b/skrub/_dataframe/tests/test_pandas.py @@ -29,7 +29,7 @@ def test_simple_agg(): "genre_mode": ("genre", pd.Series.mode), "rating_mean": ("rating", "mean"), } - expected = main.groupby("movieId").agg(**aggfunc) + expected = main.groupby("movieId").agg(**aggfunc).reset_index() assert_frame_equal(aggregated, expected) @@ -49,7 +49,7 @@ def test_value_counts_agg(): "rating_4.0_user": [3.0, 1.0], "userId": [1, 2], } - ) + ).reset_index(drop=False) assert_frame_equal(aggregated, expected) aggregated = aggregate( @@ -66,7 +66,7 @@ def test_value_counts_agg(): "rating_(3.0, 4.0]_user": [3, 1], "userId": [1, 2], } - ) + ).reset_index(drop=False) assert_frame_equal(aggregated, expected) From 1f4d49d5239a424401a8829a469d2a39f402fb4e Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Wed, 19 Jun 2024 21:48:21 +0200 Subject: [PATCH 44/53] fix joiner for old sklearn --- skrub/_joiner.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 0c325a099..795f71c2a 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -5,11 +5,13 @@ from functools import partial import numpy as np +import sklearn from sklearn.base import BaseEstimator, TransformerMixin, clone from sklearn.compose import make_column_transformer from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer from sklearn.pipeline import make_pipeline from sklearn.preprocessing import FunctionTransformer, StandardScaler +from sklearn.utils.fixes import parse_version from sklearn.utils.validation import check_is_fitted from . import _dataframe as sbd @@ -39,6 +41,12 @@ DEFAULT_REF_DIST = "random_pairs" +def _compat_df(df): + if parse_version(sklearn.__version__) < parse_version("1.4"): + return sbd.to_pandas(df) + return df + + def _make_vectorizer(table, string_encoder, rescale): """Construct the transformer used to vectorize joining columns. @@ -299,7 +307,7 @@ def fit(self, X, y=None): rescale=self.ref_dist != "no_rescaling", ) aux = self.vectorizer_.fit_transform( - s.select(self._aux_table, s.cols(*self._aux_key)) + _compat_df(s.select(self._aux_table, s.cols(*self._aux_key))) ) self._matching.fit(aux) return self @@ -327,7 +335,11 @@ def transform(self, X, y=None): X, self._aux_table, self.suffix, main_table_name="X" ) main = self.vectorizer_.transform( - sbd.set_column_names(s.select(X, s.cols(*self._main_key)), self._aux_key) + _compat_df( + sbd.set_column_names( + s.select(X, s.cols(*self._main_key)), self._aux_key + ) + ) ) match_result = self._matching.match(main, self.max_dist_) matching_col = match_result["index"].copy() From 86efa017faf495df22776a1a3230887af51f4e31 Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Thu, 20 Jun 2024 07:23:21 +0200 Subject: [PATCH 45/53] specify right key dtype --- skrub/_joiner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 795f71c2a..3616752b7 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -350,7 +350,7 @@ def transform(self, X, y=None): left = sbd.with_columns(X, **{left_key_name: matching_col}) right = sbd.with_columns( self._aux_table, - **{right_key_name: np.arange(sbd.shape(self._aux_table)[0])}, + **{right_key_name: np.arange(sbd.shape(self._aux_table)[0], dtype="int64")}, ) join = _join_utils.left_join( left, From 206fe5436c6f84eb64a00dcbad466652913283f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 20 Jun 2024 17:22:04 +0200 Subject: [PATCH 46/53] Dispatch fuzzy_join tests --- skrub/tests/test_fuzzy_join.py | 133 +++++++++++---------------------- 1 file changed, 43 insertions(+), 90 deletions(-) diff --git a/skrub/tests/test_fuzzy_join.py b/skrub/tests/test_fuzzy_join.py index 528516ae6..70114e4a3 100644 --- a/skrub/tests/test_fuzzy_join.py +++ b/skrub/tests/test_fuzzy_join.py @@ -1,14 +1,13 @@ import warnings import numpy as np -import pandas as pd import pytest from numpy.testing import assert_array_equal from sklearn.feature_extraction.text import HashingVectorizer -from skrub import fuzzy_join +from skrub import ToDatetime, _join_utils, fuzzy_join +from skrub import _selectors as s from skrub._dataframe import _common as ns -from skrub._dataframe._testing_utils import assert_frame_equal @pytest.mark.parametrize( @@ -19,8 +18,6 @@ def test_fuzzy_join(df_module, analyzer): """ Testing if ``fuzzy_join`` results are as expected. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") df1 = df_module.make_dataframe({"a1": ["ana", "lala", "nana et sana", np.nan]}) df2 = df_module.make_dataframe( {"a2": ["anna", "lala et nana", "lana", "sana", np.nan]} @@ -50,16 +47,13 @@ def test_fuzzy_join(df_module, analyzer): # Joining is always done on the left table and thus takes it shape: assert ns.shape(df_joined2) == (len(df2), n_cols) - # TODO: dispatch ``with_columns`` - df1["a2"] = 1 + df1 = ns.with_columns(df1, **{"a2": [1] * ns.shape(df1)[0]}) df_on = fuzzy_join(df_joined, df1, on="a1", suffix="2") assert "a12" in ns.column_names(df_on) def test_max_dist(df_module): - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") left = df_module.make_dataframe({"A": ["aa", "bb"]}) right = df_module.make_dataframe({"A": ["aa", "ba"], "B": [1, 2]}) join = fuzzy_join(left, right, on="A", suffix="r") @@ -69,8 +63,6 @@ def test_max_dist(df_module): def test_perfect_matches(df_module): - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") # non-regression test for https://github.com/skrub-data/skrub/issues/764 # fuzzy_join when all rows had a perfect match used to trigger a division by 0 df = df_module.make_dataframe({"A": [0, 1]}) @@ -87,19 +79,15 @@ def test_fuzzy_join_dtypes(df_module): """ Test that the dtypes of dataframes are maintained after join. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") a = df_module.make_dataframe({"col1": ["aaa", "bbb"], "col2": [1, 2]}) b = df_module.make_dataframe({"col1": ["aaa_", "bbb_"], "col3": [1, 2]}) c = fuzzy_join(a, b, on="col1", suffix="r") - assert ns.dtype(ns.col(a, "col2")).kind == "i" + ns.is_integer(ns.col(a, "col2")) assert ns.dtype(ns.col(c, "col2")) == ns.dtype(ns.col(a, "col2")) assert ns.dtype(ns.col(c, "col3r")) == ns.dtype(ns.col(b, "col3")) def test_missing_keys(df_module): - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") a = df_module.make_dataframe({"col1": ["aaa", "bbb"], "col2": [1, 2]}) b = df_module.make_dataframe({"col1": ["aaa_", "bbb_"], "col3": [1, 2]}) with pytest.raises( @@ -116,8 +104,6 @@ def test_missing_keys(df_module): def test_drop_unmatched(df_module): - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") a = df_module.make_dataframe({"col1": ["aaaa", "bbb", "ddd dd"], "col2": [1, 2, 3]}) b = df_module.make_dataframe( {"col1": ["aaa_", "bbb_", "cc ccc"], "col3": [1, 2, 3]} @@ -137,12 +123,12 @@ def test_drop_unmatched(df_module): assert sum(ns.is_null(ns.col(c2, "col3r"))) > 0 -def test_fuzzy_join_pandas_comparison(): +def test_fuzzy_join_pandas_comparison(df_module): """ Tests if fuzzy_join's output is as similar as possible with `pandas.merge`. """ - left = pd.DataFrame( + left = df_module.make_dataframe( { "key": ["K0", "K1", "K2", "K3"], "A": ["A0", "A1", "A2", "A3"], @@ -150,7 +136,7 @@ def test_fuzzy_join_pandas_comparison(): } ) - right = pd.DataFrame( + right = df_module.make_dataframe( { "key_": ["K0", "K1", "K2", "K3"], "C": ["C0", "C1", "C2", "C3"], @@ -158,20 +144,20 @@ def test_fuzzy_join_pandas_comparison(): } ) - result = pd.merge(left, right, left_on="key", right_on="key_") + result = _join_utils.left_join(left, right, left_on="key", right_on="key_") result_fj = fuzzy_join( left, right, left_on="key", right_on="key_", add_match_info=False ) + # `fuzzy_join`` keeps the vectorized col, so we must drop it + result_fj = s.select(result_fj, ~s.cols("key_")) - assert_frame_equal(result, result_fj) + df_module.assert_frame_equal(result, result_fj) def test_correct_encoder(df_module): """ Test that the encoder error checking is working as intended. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") class TestVectorizer(HashingVectorizer): """ @@ -213,8 +199,6 @@ def test_numerical_column(df_module): """ Testing that ``fuzzy_join`` works with numerical columns. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") left = df_module.make_dataframe({"str1": ["aa", "a", "bb"], "int": [10, 2, 5]}) right = df_module.make_dataframe( { @@ -249,20 +233,20 @@ def test_datetime_column(df_module): """ Testing that ``fuzzy_join`` works with datetime columns. """ - if df_module.name == "polars": - pytest.xfail(reason="Module 'polars' has no attribute 'to_datetime'.") + + def to_dt(lst): + return ToDatetime().fit_transform(df_module.make_column("", lst)) + left = df_module.make_dataframe( { "str1": ["aa", "a", "bb"], - "date": df_module.module.to_datetime( - ["10/10/2022", "12/11/2021", "09/25/2011"] - ), + "date": to_dt(["10/10/2022", "12/11/2021", "09/25/2011"]), } ) right = df_module.make_dataframe( { "str2": ["aa", "bb", "a", "cc", "dd"], - "date": df_module.module.to_datetime( + "date": to_dt( ["09/10/2022", "12/24/2021", "09/25/2010", "11/05/2025", "02/21/2000"] ), } @@ -273,16 +257,12 @@ def test_datetime_column(df_module): fj_time_expected = df_module.make_dataframe( { "str1": ["aa", "a", "bb"], - "date": df_module.module.to_datetime( - ["10/10/2022", "12/11/2021", "09/25/2011"] - ), + "date": to_dt(["10/10/2022", "12/11/2021", "09/25/2011"]), "str2r": ["aa", "bb", "a"], - "dater": df_module.module.to_datetime( - ["09/10/2022", "12/24/2021", "09/25/2010"] - ), + "dater": to_dt(["09/10/2022", "12/24/2021", "09/25/2010"]), } ) - assert_frame_equal(fj_time, fj_time_expected) + df_module.assert_frame_equal(fj_time, fj_time_expected) n_cols = ns.shape(left)[1] + ns.shape(right)[1] n_samples = len(left) @@ -308,17 +288,17 @@ def test_mixed_joins(df_module): """ Test fuzzy joining on mixed and multiple column types. """ - if df_module.name == "polars": - pytest.xfail(reason="Module 'polars' has no attribute 'to_datetime'") + + def to_dt(lst): + return ToDatetime().fit_transform(df_module.make_column("", lst)) + left = df_module.make_dataframe( { "str1": ["Paris", "Paris", "Paris"], "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["10/10/2022", "12/11/2021", "09/25/2011"] - ), + "date": to_dt(["10/10/2022", "12/11/2021", "09/25/2011"]), } ) right = df_module.make_dataframe( @@ -327,7 +307,7 @@ def test_mixed_joins(df_module): "str_2": ["TX", "FR", "GR Mytho", "cc", "dd"], "int1": [55, 6, 2, 15, 6], "int2": [554, 146, 32, 215, 612], - "date": df_module.module.to_datetime( + "date": to_dt( ["09/10/2022", "12/24/2021", "09/25/2010", "11/05/2025", "02/21/2000"] ), } @@ -344,19 +324,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["10/10/2022", "12/11/2021", "09/25/2011"] - ), + "date": to_dt(["10/10/2022", "12/11/2021", "09/25/2011"]), "str_1r": ["Paris", "Paris", "dd"], "str_2r": ["FR", "FR", "dd"], "int1r": [6, 6, 6], "int2r": [146, 146, 612], - "dater": df_module.module.to_datetime( - ["12/24/2021", "12/24/2021", "02/21/2000"] - ), + "dater": to_dt(["12/24/2021", "12/24/2021", "02/21/2000"]), } ) - assert_frame_equal(fj_num, expected_fj_num) + df_module.assert_frame_equal(fj_num, expected_fj_num) assert ns.shape(fj_num) == (3, 10) # On multiple string keys @@ -375,19 +351,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["2022-10-10", "2021-12-11", "2011-09-25"] - ), + "date": to_dt(["2022-10-10", "2021-12-11", "2011-09-25"]), "str_1r": ["Paris", "Paris", "Paris"], "str_2r": ["TX", "FR", "GR Mytho"], "int1r": [55, 6, 2], "int2r": [554, 146, 32], - "dater": df_module.module.to_datetime( - ["2022-09-10", "2021-12-24", "2010-09-25"] - ), + "dater": to_dt(["2022-09-10", "2021-12-24", "2010-09-25"]), } ) - assert_frame_equal(fj_str, expected_fj_str) + df_module.assert_frame_equal(fj_str, expected_fj_str) assert ns.shape(fj_str) == (3, 10) # On mixed, numeric and string keys @@ -405,19 +377,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["2022-10-10", "2021-12-11", "2011-09-25"] - ), + "date": to_dt(["2022-10-10", "2021-12-11", "2011-09-25"]), "str_1r": ["Paris", "Paris", "Paris"], "str_2r": ["FR", "FR", "TX"], "int1r": [6, 6, 55], "int2r": [146, 146, 554], - "dater": df_module.module.to_datetime( - ["2021-12-24", "2021-12-24", "2022-09-10"] - ), + "dater": to_dt(["2021-12-24", "2021-12-24", "2022-09-10"]), } ) - assert_frame_equal(fj_mixed, expected_fj_mixed) + df_module.assert_frame_equal(fj_mixed, expected_fj_mixed) assert ns.shape(fj_mixed) == (3, 10) # On mixed time and string keys @@ -435,19 +403,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["2022-10-10", "2021-12-11", "2011-09-25"] - ), + "date": to_dt(["2022-10-10", "2021-12-11", "2011-09-25"]), "str_1r": ["Paris", "Paris", "Paris"], "str_2r": ["TX", "FR", "GR Mytho"], "int1r": [55, 6, 2], "int2r": [554, 146, 32], - "dater": df_module.module.to_datetime( - ["2022-09-10", "2021-12-24", "2010-09-25"] - ), + "dater": to_dt(["2022-09-10", "2021-12-24", "2010-09-25"]), } ) - assert_frame_equal(fj_mixed2, expected_fj_mixed2) + df_module.assert_frame_equal(fj_mixed2, expected_fj_mixed2) assert ns.shape(fj_mixed2) == (3, 10) # On mixed time and numbers keys @@ -465,19 +429,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["2022-10-10", "2021-12-11", "2011-09-25"] - ), + "date": to_dt(["2022-10-10", "2021-12-11", "2011-09-25"]), "str_1r": ["Paris", "Paris", "Paris"], "str_2r": ["FR", "FR", "GR Mytho"], "int1r": [6, 6, 2], "int2r": [146, 146, 32], - "dater": df_module.module.to_datetime( - ["2021-12-24", "2021-12-24", "2010-09-25"] - ), + "dater": to_dt(["2021-12-24", "2021-12-24", "2010-09-25"]), } ) - assert_frame_equal(fj_mixed3, expected_fj_mixed3) + df_module.assert_frame_equal(fj_mixed3, expected_fj_mixed3) assert ns.shape(fj_mixed3) == (3, 10) @@ -485,8 +445,6 @@ def test_iterable_input(df_module): """ Test if iterable inputs (list, set, dictionary or tuple) work. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") df1 = df_module.make_dataframe( {"a": ["ana", "lala", "nana"], "str2": ["Texas", "France", "Greek God"]} ) @@ -529,20 +487,15 @@ def test_iterable_input(df_module): ) == (3, 4) -@pytest.mark.xfail def test_missing_values(df_module): """ Test fuzzy joining on missing values. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") a = df_module.make_dataframe({"col1": ["aaaa", "bbb", "ddd dd"], "col2": [1, 2, 3]}) b = df_module.make_dataframe({"col3": [np.nan, "bbb", "ddd dd"], "col4": [1, 2, 3]}) - with pytest.warns(UserWarning, match=r"merging on missing values"): - c = fuzzy_join(a, b, left_on="col1", right_on="col3", add_match_info=False) + c = fuzzy_join(a, b, left_on="col1", right_on="col3", add_match_info=False) assert ns.shape(c)[0] == len(b) - with pytest.warns(UserWarning, match=r"merging on missing values"): - c = fuzzy_join(b, a, left_on="col3", right_on="col1", add_match_info=True) + c = fuzzy_join(b, a, left_on="col3", right_on="col1", add_match_info=True) assert ns.shape(c)[0] == len(b) From 3e4a597a433409a14eed874010d1d4ea4816d540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 20 Jun 2024 17:33:38 +0200 Subject: [PATCH 47/53] Next steps --- skrub/_joiner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 3616752b7..5cf06b569 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -26,7 +26,6 @@ FunctionTransformer(partial(sbd.fill_nulls, value="")), ToStr(), HashingVectorizer(analyzer="char_wb", ngram_range=(2, 4)), - # TODO: Remove sparse output from Tfidf to work with TableVectorizer TfidfTransformer(), ) _DATETIME_ENCODER = DatetimeEncoder(resolution=None, add_total_seconds=True) @@ -55,7 +54,8 @@ def _make_vectorizer(table, string_encoder, rescale): In addition if `rescale` is `True`, a StandardScaler is applied to numeric and datetime columns. """ - # TODO remove use of ColumnTransformer, select_dtypes & pandas-specific code + # TODO: add Skrubber before ColumnTransformer + # TODO: remove use of ColumnTransformer transformers = [ (clone(string_encoder), c) for c in (s.string() | s.categorical()).expand(table) ] From d139fd8879ca6cd1b1abf85db29e1f6e97c06760 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Thu, 20 Jun 2024 17:42:24 +0200 Subject: [PATCH 48/53] Format tests --- skrub/_joiner.py | 9 ++------- skrub/tests/test_joiner.py | 7 ++++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 5cf06b569..2b6eb9ea8 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -334,13 +334,8 @@ def transform(self, X, y=None): _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) - main = self.vectorizer_.transform( - _compat_df( - sbd.set_column_names( - s.select(X, s.cols(*self._main_key)), self._aux_key - ) - ) - ) + main = sbd.set_column_names(s.select(X, s.cols(*self._main_key)), self._aux_key) + main = self.vectorizer_.transform(_compat_df(main)) match_result = self._matching.match(main, self.max_dist_) matching_col = match_result["index"].copy() matching_col[~match_result["match_accepted"]] = -1 diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index 7b14c583d..368ffe252 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -131,14 +131,15 @@ def test_max_dist(main_table, aux_table, max_dist): def test_missing_values(df_module): df = df_module.make_dataframe({"A": [None, "hollywood", "beverly"]}) joiner = Joiner(df, key="A", suffix="_") - with pytest.xfail(): - joiner.fit_transform(df) + out = joiner.fit_transform(df) + assert ns.shape(out) == (3, 5) def test_fit_transform_numeric(df_module): df = df_module.make_dataframe({"A": [4.5, 0.5, 1, -1.5]}) joiner = Joiner(df, key="A", suffix="_") - joiner.fit_transform(df) + out = joiner.fit_transform(df) + assert ns.shape(out) == (4, 5) def test_fit_transform_datetimes(df_module): From 3f4d4bbc1c79ccc7d9a084dec42c1a9b9f06e96e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= <57430673+TheooJ@users.noreply.github.com> Date: Fri, 21 Jun 2024 16:21:33 +0200 Subject: [PATCH 49/53] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérôme Dockès --- skrub/_joiner.py | 5 +++++ skrub/tests/test_fuzzy_join.py | 2 +- skrub/tests/test_joiner.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 2b6eb9ea8..4bd230983 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -41,6 +41,11 @@ def _compat_df(df): + # In scikit-learn versions older than 1.4, the ColumnTransformer fails on + # polars dataframes. Here it is only applied as an internal step on the + # joining columns, and we get the output as a numpy array or sparse matrix. + # Therefore on old scikit-learn versions we convert the joining columns to + # pandas before vectorizing them. if parse_version(sklearn.__version__) < parse_version("1.4"): return sbd.to_pandas(df) return df diff --git a/skrub/tests/test_fuzzy_join.py b/skrub/tests/test_fuzzy_join.py index 70114e4a3..9486a1cb5 100644 --- a/skrub/tests/test_fuzzy_join.py +++ b/skrub/tests/test_fuzzy_join.py @@ -82,7 +82,7 @@ def test_fuzzy_join_dtypes(df_module): a = df_module.make_dataframe({"col1": ["aaa", "bbb"], "col2": [1, 2]}) b = df_module.make_dataframe({"col1": ["aaa_", "bbb_"], "col3": [1, 2]}) c = fuzzy_join(a, b, on="col1", suffix="r") - ns.is_integer(ns.col(a, "col2")) + assert ns.is_integer(ns.col(a, "col2")) assert ns.dtype(ns.col(c, "col2")) == ns.dtype(ns.col(a, "col2")) assert ns.dtype(ns.col(c, "col3r")) == ns.dtype(ns.col(b, "col3")) diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index 368ffe252..7dd6bf86e 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -115,7 +115,7 @@ def test_wrong_ref_dist(main_table, aux_table): aux_table, main_key="Country", aux_key="country", ref_dist="wrong_ref_dist" ) with pytest.raises( - ValueError, match=r"('ref_dist' should be one of)*(Got 'wrong_ref_dist')" + ValueError, match=r"'ref_dist' should be one of.* Got 'wrong_ref_dist'" ): joiner.fit(main_table) From 79f11927b360e89ef3248c37097f390d0592879b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Fri, 21 Jun 2024 16:40:50 +0200 Subject: [PATCH 50/53] More suggestions from code review --- skrub/tests/test_fuzzy_join.py | 18 +++++++++++++----- skrub/tests/test_join_utils.py | 31 ++++++++++++++++++++----------- skrub/tests/test_joiner.py | 9 +++++---- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/skrub/tests/test_fuzzy_join.py b/skrub/tests/test_fuzzy_join.py index 9486a1cb5..37f5673e6 100644 --- a/skrub/tests/test_fuzzy_join.py +++ b/skrub/tests/test_fuzzy_join.py @@ -123,14 +123,14 @@ def test_drop_unmatched(df_module): assert sum(ns.is_null(ns.col(c2, "col3r"))) > 0 -def test_fuzzy_join_pandas_comparison(df_module): +def test_fuzzy_join_exact_matches(df_module): """ - Tests if fuzzy_join's output is as similar as - possible with `pandas.merge`. + Tests if fuzzy_join's output is the same as a normal left-join when there + are exact matches for all rows. """ left = df_module.make_dataframe( { - "key": ["K0", "K1", "K2", "K3"], + "key": ["K2", "K2", "K3", "K1"], "A": ["A0", "A1", "A2", "A3"], "B": ["B0", "B1", "B2", "B3"], } @@ -148,7 +148,15 @@ def test_fuzzy_join_pandas_comparison(df_module): result_fj = fuzzy_join( left, right, left_on="key", right_on="key_", add_match_info=False ) - # `fuzzy_join`` keeps the vectorized col, so we must drop it + df_module.assert_column_equal( + ns.col(result_fj, "key_"), ns.rename(ns.col(result_fj, "key"), "key_") + ) + # `_left_join` does a (non-fuzzy, regular) equijoin so it only keeps one of the + # join columns (keeping both would be redundant as they are identical due to exact + # matching) -- same as the default behavior of polars (coalesce=True) and pandas. + # `fuzzy_join` keeps both columns because they are not identical, only up to + # fuzziness, so keeping both is informative. So here we drop `key_` to compare the + # 2 resulting dataframes. result_fj = s.select(result_fj, ~s.cols("key_")) df_module.assert_frame_equal(result, result_fj) diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index f42eeefa8..18788265e 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -75,11 +75,14 @@ def test_add_column_name_suffix(): assert list(df.columns) == ["one_y", "two three_y", "x_y"] -def test_left_join(df_module): - left = df_module.make_dataframe({"left_key": [1, 2, 2], "left_col": [10, 20, 30]}) +@pytest.fixture +def left(df_module): + return df_module.make_dataframe({"left_key": [1, 2, 2], "left_col": [10, 20, 30]}) + +def test_left_join_all_keys_in_right_dataframe(df_module, left): # All left keys in right dataframe - right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) + right = df_module.make_dataframe({"right_key": [2, 1], "right_col": ["b", "a"]}) joined = _join_utils.left_join( left, right=right, left_on="left_key", right_on="right_key" ) @@ -92,7 +95,8 @@ def test_left_join(df_module): ) df_module.assert_frame_equal(joined, expected) - # Some left keys not it right dataframe + +def test_left_join_some_keys_not_in_right_dataframe(df_module, left): right = df_module.make_dataframe({"right_key": [2, 3], "right_col": ["a", "c"]}) joined = _join_utils.left_join( left, right=right, left_on="left_key", right_on="right_key" @@ -106,7 +110,8 @@ def test_left_join(df_module): ) df_module.assert_frame_equal(joined, expected) - # Renaming right col + +def test_left_join_renaming_right_cols(df_module, left): right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) joined = _join_utils.left_join( left, @@ -124,29 +129,33 @@ def test_left_join(df_module): ) df_module.assert_frame_equal(joined, expected) - # Left not a df raises TypeError + +def test_left_join_wrong_left_type(df_module): + right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) with pytest.raises( TypeError, match=( "`left` must be a pandas or polars dataframe, got ." ), ): - joined = _join_utils.left_join( + _join_utils.left_join( np.array([1, 2]), right=right, left_on="left_key", right_on="right_key" ) - # Right not a df raises TypeError + +def test_left_join_wrong_right_type(df_module, left): with pytest.raises( TypeError, match=( "`right` must be a pandas or polars dataframe, got ." ), ): - joined = _join_utils.left_join( + _join_utils.left_join( left, right=np.array([1, 2]), left_on="left_key", right_on="right_key" ) - # Joining on different types raises TypeError + +def test_left_join_types_not_equal(df_module, left): try: import polars as pl except ImportError: @@ -158,6 +167,6 @@ def test_left_join(df_module): with pytest.raises( TypeError, match=r"`left` and `right` must be of the same dataframe type" ): - joined = _join_utils.left_join( + _join_utils.left_join( left, right=right, left_on="left_key", right_on="right_key" ) diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index 7dd6bf86e..c7fdc778d 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -1,6 +1,7 @@ import datetime import numpy as np +import pandas as pd import pytest from numpy.testing import assert_array_equal @@ -87,9 +88,9 @@ def test_multiple_keys(df_module): df_module.assert_frame_equal(result, expected) -def test_pandas_aux_table_index(df_module): - main_table = df_module.make_dataframe({"Country": ["France", "Italia", "Georgia"]}) - aux_table = df_module.make_dataframe( +def test_pandas_aux_table_index(): + main_table = pd.DataFrame({"Country": ["France", "Italia", "Georgia"]}) + aux_table = pd.DataFrame( { "Country": ["Germany", "France", "Italy"], "Capital": ["Berlin", "Paris", "Rome"], @@ -103,7 +104,7 @@ def test_pandas_aux_table_index(df_module): suffix="_capitals", ) join = joiner.fit_transform(main_table) - assert ns.to_list(ns.col(join, "Country_capitals")) == [ + assert join["Country_capitals"].to_list() == [ "France", "Italy", "Germany", From 36d81ff3fe11e7cb191786d4c5c0e30f7306c5fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Fri, 21 Jun 2024 17:09:39 +0200 Subject: [PATCH 51/53] Dispatch test_join_utils --- skrub/tests/test_join_utils.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index 18788265e..a44f12de8 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -45,30 +45,27 @@ def test_check_key_length_mismatch(): ) -def test_check_column_name_duplicates(): - left = pd.DataFrame(columns=["A", "B"]) - right = pd.DataFrame(columns=["C"]) +def test_check_no_column_name_duplicates_with_no_suffix(df_module): + left = df_module.make_dataframe({"A": [], "B": []}) + right = df_module.make_dataframe({"C": []}) _join_utils.check_column_name_duplicates(left, right, "") - left = pd.DataFrame(columns=["A", "B"]) - right = pd.DataFrame(columns=["B"]) + +def test_check_no_column_name_duplicates_after_adding_a_suffix(df_module): + left = df_module.make_dataframe({"A": [], "B": []}) + right = df_module.make_dataframe({"B": []}) _join_utils.check_column_name_duplicates(left, right, "_right") - left = pd.DataFrame(columns=["A", "B_right"]) - right = pd.DataFrame(columns=["B"]) + +def test_check_column_name_duplicates_after_adding_a_suffix(df_module): + left = df_module.make_dataframe({"A": [], "B_right": []}) + right = df_module.make_dataframe({"B": []}) with pytest.raises(ValueError, match=".*suffix '_right'.*['B_right']"): _join_utils.check_column_name_duplicates(left, right, "_right") - left = pd.DataFrame(columns=["A", "A"]) - right = pd.DataFrame(columns=["B"]) - with pytest.raises(ValueError, match="Table 'left' has duplicate"): - _join_utils.check_column_name_duplicates( - left, right, "", main_table_name="left" - ) - -def test_add_column_name_suffix(): - df = pd.DataFrame(columns=["one", "two three", "x"]) +def test_add_column_name_suffix(df_module): + df = df_module.make_dataframe({"one": [], "two three": [], "x": []}) df = _join_utils.add_column_name_suffix(df, "") assert list(df.columns) == ["one", "two three", "x"] df = _join_utils.add_column_name_suffix(df, "_y") @@ -81,7 +78,6 @@ def left(df_module): def test_left_join_all_keys_in_right_dataframe(df_module, left): - # All left keys in right dataframe right = df_module.make_dataframe({"right_key": [2, 1], "right_col": ["b", "a"]}) joined = _join_utils.left_join( left, right=right, left_on="left_key", right_on="right_key" From f2717c2b133f6201ee33d065c89bd343ab4ef173 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Fri, 21 Jun 2024 17:15:28 +0200 Subject: [PATCH 52/53] More suggestions from code review --- skrub/_joiner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 4bd230983..7438e1f93 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -305,7 +305,7 @@ def fit(self, X, y=None): _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) - self._suffix = f"{{}}{self.suffix}".format + self._right_cols_renaming = f"{{}}{self.suffix}".format self.vectorizer_ = _make_vectorizer( s.select(self._aux_table, s.cols(*self._aux_key)), self.string_encoder, @@ -357,7 +357,7 @@ def transform(self, X, y=None): right, left_on=left_key_name, right_on=right_key_name, - rename_right_cols=self._suffix, + rename_right_cols=self._right_cols_renaming, ) join = s.select(join, ~s.cols(left_key_name)) if self.add_match_info: From 557a32679b313a5166a0918947241f935a936c0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Fri, 21 Jun 2024 17:58:16 +0200 Subject: [PATCH 53/53] Test duplicated key and col names --- skrub/tests/test_join_utils.py | 40 ++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index a44f12de8..31f7976f7 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -1,7 +1,10 @@ +import re + import numpy as np import pandas as pd import pytest +from skrub import _dataframe as sbd from skrub import _join_utils @@ -107,6 +110,43 @@ def test_left_join_some_keys_not_in_right_dataframe(df_module, left): df_module.assert_frame_equal(joined, expected) +def test_left_join_same_key_name(df_module, left): + right = df_module.make_dataframe({"left_key": [2, 1], "right_col": ["b", "a"]}) + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="left_key" + ) + expected = df_module.make_dataframe( + { + "left_key": [1, 2, 2], + "left_col": [10, 20, 30], + "right_col": ["a", "b", "b"], + } + ) + df_module.assert_frame_equal(joined, expected) + + +def test_left_join_same_col_name(df_module, left): + right = df_module.make_dataframe({"right_key": [2, 1], "left_col": ["b", "a"]}) + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="right_key" + ) + + cols = sbd.column_names(joined) + assert cols[:2] == ["left_key", "left_col"] + assert re.match("left_col__skrub_.*__", cols[2]) is not None + + expected = df_module.make_dataframe( + { + "a": [1, 2, 2], + "b": [10, 20, 30], + "c": ["a", "b", "b"], + } + ) + # Renaming is necessary because a random tag has been added + expected = sbd.set_column_names(expected, cols) + df_module.assert_frame_equal(joined, expected) + + def test_left_join_renaming_right_cols(df_module, left): right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) joined = _join_utils.left_join(