diff --git a/CHANGES.rst b/CHANGES.rst index 06c7a5c59..1bd681d15 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,6 +14,8 @@ It is currently undergoing fast development and backward compatibility is not en Major changes ------------- +* The :class:`Joiner` has been adapted to support polars dataframes. :pr:`945` by :user:`Théo Jolivet `. + * The :class:`TableVectorizer` now consistently applies the same transformation across different calls to `transform`. There also have been some breaking changes to its functionality: (i) all transformations are now applied diff --git a/skrub/_agg_joiner.py b/skrub/_agg_joiner.py index 7044aa571..88ad42330 100644 --- a/skrub/_agg_joiner.py +++ b/skrub/_agg_joiner.py @@ -289,9 +289,8 @@ def transform(self, X): X, _ = self._check_dataframes(X, self.aux_table_) _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") - skrub_px, _ = get_df_namespace(self.aux_table_) - X = skrub_px.join( - left=X, + X = _join_utils.left_join( + X, right=self.aux_table_, left_on=self._main_key, right_on=self._aux_key, @@ -439,10 +438,9 @@ def transform(self, X): The augmented input. """ check_is_fitted(self, "y_") - skrub_px, _ = get_df_namespace(X) - return skrub_px.join( - left=X, + return _join_utils.left_join( + X, right=self.y_, left_on=self.main_key_, right_on=self.main_key_, diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py index 53c9113de..a66b08ad6 100644 --- a/skrub/_dataframe/_common.py +++ b/skrub/_dataframe/_common.py @@ -85,6 +85,7 @@ "sample", "head", "replace", + "with_columns", ] # @@ -1007,3 +1008,9 @@ def _replace_pandas(col, old, new): @replace.specialize("polars", argument_type="Column") def _replace_polars(col, old, new): return col.replace(old, new) + + +def with_columns(df, **new_cols): + cols = {col_name: col(df, col_name) for col_name in column_names(df)} + cols.update({n: make_column_like(df, c, n) for n, c in new_cols.items()}) + return make_dataframe_like(df, cols) diff --git a/skrub/_dataframe/_pandas.py b/skrub/_dataframe/_pandas.py index 5e9f369bf..db21a8c28 100644 --- a/skrub/_dataframe/_pandas.py +++ b/skrub/_dataframe/_pandas.py @@ -104,49 +104,7 @@ def aggregate( ] sorted_cols = sorted(base_group.columns) - return base_group[sorted_cols] - - -def join( - left, - right, - left_on, - right_on, -): - """Left join two :obj:`pandas.DataFrame`. - - This function uses the ``dataframe.merge`` method from Pandas. - - Parameters - ---------- - left : pd.DataFrame - The left dataframe to left-join. - - right : pd.DataFrame - The right dataframe to left-join. - - left_on : str or Iterable[str] - Left keys to merge on. - - right_on : str or Iterable[str] - Right keys to merge on. - - Returns - ------- - merged : pd.DataFrame, - The merged output. - """ - if not (isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame)): - raise TypeError( - "'left' and 'right' must be pandas dataframes, " - f"got {type(left)!r} and {type(right)!r}." - ) - return left.merge( - right, - how="left", - left_on=left_on, - right_on=right_on, - ) + return base_group[sorted_cols].reset_index(drop=False) def get_named_agg(table, cols, operations): diff --git a/skrub/_dataframe/_polars.py b/skrub/_dataframe/_polars.py index 8f1db8656..04af24156 100644 --- a/skrub/_dataframe/_polars.py +++ b/skrub/_dataframe/_polars.py @@ -1,8 +1,6 @@ """ Polars specialization of the aggregate and join operations. 
""" -import inspect - try: import polars as pl import polars.selectors as cs @@ -91,50 +89,6 @@ def aggregate( return table.select(sorted_cols) -def join(left, right, left_on, right_on): - """Left join two :obj:`polars.DataFrame` or :obj:`polars.LazyFrame`. - - This function uses the ``dataframe.join`` method from Polars. - - Note that the input dataframes type must agree: either both - Polars dataframes or both Polars lazyframes. - - Mixing polars dataframe with lazyframe will raise an error. - - Parameters - ---------- - left : pl.DataFrame or pl.LazyFrame - The left dataframe of the left-join. - - right : pl.DataFrame or pl.LazyFrame - The right dataframe of the left-join. - - left_on : str or Iterable[str] - Left keys to merge on. - - right_on : str or Iterable[str] - Right keys to merge on. - - Returns - ------- - merged : pl.DataFrame or pl.LazyFrame - The merged output. - """ - is_dataframe = isinstance(left, pl.DataFrame) and isinstance(right, pl.DataFrame) - is_lazyframe = isinstance(left, pl.LazyFrame) and isinstance(right, pl.LazyFrame) - if is_dataframe or is_lazyframe: - if "coalesce" in inspect.signature(left.join).parameters: - kw = {"coalesce": True} - else: - kw = {} - return left.join(right, how="left", left_on=left_on, right_on=right_on, **kw) - else: - raise TypeError( - "'left' and 'right' must be polars dataframes or lazyframes, " - f"got {type(left)!r} and {type(right)!r}." - ) - - def get_aggfuncs(cols, operations): """List Polars aggregation functions. diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index a065371b1..53d02fe5b 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -31,6 +31,7 @@ def test_not_implemented(): "reset_index", "copy_index", "index", + "with_columns", } for func_name in sorted(set(ns.__all__) - has_default_impl): func = getattr(ns, func_name) @@ -147,6 +148,10 @@ def test_make_column_like(df_module, example_data_dict): ) assert ns.dataframe_module_name(col) == df_module.name + col = df_module.make_column("old_name", [1, 2, 3]) + expected = df_module.make_column("new_name", [1, 2, 3]) + df_module.assert_column_equal(ns.make_column_like(col, col, "new_name"), expected) + def test_null_value_for(df_module): assert ns.null_value_for(df_module.example_dataframe) is None @@ -645,3 +650,38 @@ def same(c1, c2): same(ns.drop_nulls(s), col([1.1, 2.2, float("inf")])) same(ns.fill_nulls(s, -1.0), col([1.1, -1.0, 2.2, -1.0, float("inf")])) + + +def test_with_columns(df_module): + df = df_module.make_dataframe({"a": [1, 2], "b": [3, 4]}) + + # Add one new col + out = ns.with_columns(df, **{"c": [5, 6]}) + if df_module.description == "pandas-nullable-dtypes": + # for pandas, make_column_like will return an old-style / numpy dtypes Series + out = ns.pandas_convert_dtypes(out) + expected = df_module.make_dataframe({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + df_module.assert_frame_equal(out, expected) + + # Add multiple new cols + out = ns.with_columns(df, **{"c": [5, 6], "d": [7, 8]}) + if df_module.description == "pandas-nullable-dtypes": + out = ns.pandas_convert_dtypes(out) + expected = df_module.make_dataframe( + {"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]} + ) + df_module.assert_frame_equal(out, expected) + + # Pass a col instead of an array + out = ns.with_columns(df, **{"c": df_module.make_column("c", [5, 6])}) + if df_module.description == "pandas-nullable-dtypes": + out = ns.pandas_convert_dtypes(out) + expected = df_module.make_dataframe({"a": [1, 2], 
"b": [3, 4], "c": [5, 6]}) + df_module.assert_frame_equal(out, expected) + + # Replace col + out = ns.with_columns(df, **{"a": [5, 6]}) + if df_module.description == "pandas-nullable-dtypes": + out = ns.pandas_convert_dtypes(out) + expected = df_module.make_dataframe({"a": [5, 6], "b": [3, 4]}) + df_module.assert_frame_equal(out, expected) diff --git a/skrub/_dataframe/tests/test_pandas.py b/skrub/_dataframe/tests/test_pandas.py index 7f4d6141b..25ed23223 100644 --- a/skrub/_dataframe/tests/test_pandas.py +++ b/skrub/_dataframe/tests/test_pandas.py @@ -4,7 +4,6 @@ from skrub._dataframe._pandas import ( aggregate, - join, rename_columns, ) @@ -18,12 +17,6 @@ ) -def test_join(): - joined = join(left=main, right=main, left_on="movieId", right_on="movieId") - expected = main.merge(main, on="movieId", how="left") - assert_frame_equal(joined, expected) - - def test_simple_agg(): aggregated = aggregate( table=main, @@ -36,7 +29,7 @@ def test_simple_agg(): "genre_mode": ("genre", pd.Series.mode), "rating_mean": ("rating", "mean"), } - expected = main.groupby("movieId").agg(**aggfunc) + expected = main.groupby("movieId").agg(**aggfunc).reset_index() assert_frame_equal(aggregated, expected) @@ -56,7 +49,7 @@ def test_value_counts_agg(): "rating_4.0_user": [3.0, 1.0], "userId": [1, 2], } - ) + ).reset_index(drop=False) assert_frame_equal(aggregated, expected) aggregated = aggregate( @@ -73,14 +66,11 @@ def test_value_counts_agg(): "rating_(3.0, 4.0]_user": [3, 1], "userId": [1, 2], } - ) + ).reset_index(drop=False) assert_frame_equal(aggregated, expected) def test_incorrect_dataframe_inputs(): - with pytest.raises(TypeError, match=r"(?=.*pandas dataframes)(?=.*array)"): - join(left=main.values, right=main, left_on="movieId", right_on="movieId") - with pytest.raises(TypeError, match=r"(?=.*pandas dataframe)(?=.*array)"): aggregate( table=main.values, diff --git a/skrub/_dataframe/tests/test_polars.py b/skrub/_dataframe/tests/test_polars.py index 8875e9c22..2a2a2fa16 100644 --- a/skrub/_dataframe/tests/test_polars.py +++ b/skrub/_dataframe/tests/test_polars.py @@ -1,11 +1,8 @@ -import inspect - import pandas as pd import pytest from skrub._dataframe._polars import ( aggregate, - join, rename_columns, ) from skrub.conftest import _POLARS_INSTALLED @@ -27,16 +24,6 @@ pytest.skip(reason=POLARS_MISSING_MSG, allow_module_level=True) -def test_join(): - joined = join(left=main, right=main, left_on="movieId", right_on="movieId") - if "coalesce" in inspect.signature(main.join).parameters: - kw = {"coalesce": True} - else: - kw = {} - expected = main.join(main, on="movieId", how="left", **kw) - assert_frame_equal(joined, expected) - - def test_simple_agg(): aggregated = aggregate( table=main, @@ -68,9 +55,6 @@ def test_mode_agg(): def test_incorrect_dataframe_inputs(): - with pytest.raises(TypeError, match=r"(?=.*polars dataframes)(?=.*pandas)"): - join(left=pd.DataFrame(main), right=main, left_on="movieId", right_on="movieId") - with pytest.raises(TypeError, match=r"(?=.*polars dataframe)(?=.*pandas)"): aggregate( table=pd.DataFrame(main), diff --git a/skrub/_fuzzy_join.py b/skrub/_fuzzy_join.py index 86e5749ad..9b3ece456 100644 --- a/skrub/_fuzzy_join.py +++ b/skrub/_fuzzy_join.py @@ -3,8 +3,10 @@ """ import numpy as np -from skrub import _join_utils -from skrub._joiner import DEFAULT_REF_DIST, DEFAULT_STRING_ENCODER, Joiner +from . import _dataframe as sbd +from . import _join_utils +from . 
import _selectors as s
+from ._joiner import DEFAULT_REF_DIST, DEFAULT_STRING_ENCODER, Joiner
 
 
 def fuzzy_join(
@@ -210,7 +212,7 @@ def fuzzy_join(
         add_match_info=True,
     ).fit_transform(left)
     if drop_unmatched:
-        join = join[join["skrub_Joiner_match_accepted"]]
+        join = sbd.filter(join, sbd.col(join, "skrub_Joiner_match_accepted"))
     if not add_match_info:
-        join = join.drop(Joiner.match_info_columns, axis=1)
+        join = s.select(join, ~s.cols(*Joiner.match_info_columns))
     return join
diff --git a/skrub/_join_utils.py b/skrub/_join_utils.py
index bd1e62607..1333699e0 100644
--- a/skrub/_join_utils.py
+++ b/skrub/_join_utils.py
@@ -1,9 +1,13 @@
 """Utilities specific to the JOIN operations."""
+import inspect
 import re
 
+from skrub import _dataframe as sbd
+from skrub import _selectors as s
 from skrub import _utils
 from skrub._dataframe._namespace import get_df_namespace
+from skrub._dispatch import dispatch
 
 
 def check_key(
@@ -218,3 +222,95 @@ def _get_new_name(suggested_name, forbidden_names):
         return suggested_name
     token = _utils.random_string()
     return f"{untagged_name}__skrub_{token}__"
+
+
+def left_join(left, right, left_on, right_on, rename_right_cols="{}"):
+    """Left join two dataframes of the same type.
+
+    The input dataframe types must agree: both `left` and `right` need to be
+    pandas or polars dataframes. Mixing types will raise an error.
+
+    `rename_right_cols` can be used to format the right dataframe's column names,
+    e.g. "right_.{}" prepends "right_." to every right column name.
+
+    If renaming the right columns would create duplicates with the left columns,
+    a `__skrub_<random string>__` tag is appended to the columns that would
+    otherwise be duplicated.
+
+    Parameters
+    ----------
+    left : dataframe
+        The left dataframe of the left-join.
+    right : dataframe
+        The right dataframe of the left-join.
+    left_on : str or list of str
+        Left keys to merge on.
+    right_on : str or list of str
+        Right keys to merge on.
+    rename_right_cols : str or callable, default="{}"
+        Formatting used to rename the right columns. If it is a callable, it
+        should accept a string argument. By default, no renaming is applied.
+
+    Returns
+    -------
+    dataframe
+        The joined output.
+
+    Raises
+    ------
+    TypeError
+        If either `left` or `right` is not a dataframe, or if their types
+        differ.
+    """
+    if not sbd.is_dataframe(left):
+        raise TypeError(
+            f"`left` must be a pandas or polars dataframe, got {type(left)}."
+        )
+    if not sbd.is_dataframe(right):
+        raise TypeError(
+            f"`right` must be a pandas or polars dataframe, got {type(right)}."
+        )
+    if sbd.dataframe_module_name(left) != sbd.dataframe_module_name(right):
+        raise TypeError(
+            "`left` and `right` must be of the same dataframe type, got"
+            f" {type(left)} and {type(right)}."
+ ) + + left_cols = sbd.column_names(left) + original_right_cols = sbd.column_names(right) + right_cols = map(_utils.renaming_func(rename_right_cols), original_right_cols) + right_cols = pick_column_names(right_cols, forbidden_names=left_cols) + renaming = dict(zip(original_right_cols, right_cols)) + right = sbd.set_column_names(right, right_cols) + if isinstance(right_on, str): + right_on = renaming[right_on] + right_on_selector = s.cols(right_on) + else: + right_on = tuple(renaming[c] for c in right_on) + right_on_selector = s.cols(*right_on) + joined = _do_left_join(left, right, left_on, right_on) + joined = s.select(joined, ~right_on_selector) + return joined + + +@dispatch +def _do_left_join(left, right, left_on, right_on): + raise NotImplementedError() + + +@_do_left_join.specialize("pandas", argument_type="DataFrame") +def _do_left_join_pandas(left, right, left_on, right_on): + return left.merge( + right, left_on=left_on, right_on=right_on, how="left", suffixes=("", "") + ) + + +@_do_left_join.specialize("polars", argument_type="DataFrame") +def _do_left_join_polars(left, right, left_on, right_on): + if "coalesce" in inspect.signature(left.join).parameters: + kw = {"coalesce": True} + else: + kw = {} + return left.join( + right, left_on=left_on, right_on=right_on, how="left", suffix="", **kw + ) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 54d7ae956..7438e1f93 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -2,29 +2,29 @@ The Joiner provides fuzzy joining as a scikit-learn transformer. """ +from functools import partial + import numpy as np -import pandas as pd +import sklearn from sklearn.base import BaseEstimator, TransformerMixin, clone from sklearn.compose import make_column_transformer from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer from sklearn.pipeline import make_pipeline from sklearn.preprocessing import FunctionTransformer, StandardScaler +from sklearn.utils.fixes import parse_version from sklearn.utils.validation import check_is_fitted -from skrub import _join_utils, _matching, _utils -from skrub._dataframe._namespace import is_pandas, is_polars -from skrub._datetime_encoder import DatetimeEncoder - +from . import _dataframe as sbd +from . import _join_utils, _matching, _utils from . import _selectors as s +from ._check_input import CheckInputDataFrame +from ._datetime_encoder import DatetimeEncoder +from ._to_str import ToStr from ._wrap_transformer import wrap_transformer - -def _as_str(column): - return column.fillna("").astype(str) - - DEFAULT_STRING_ENCODER = make_pipeline( - FunctionTransformer(_as_str), + FunctionTransformer(partial(sbd.fill_nulls, value="")), + ToStr(), HashingVectorizer(analyzer="char_wb", ngram_range=(2, 4)), TfidfTransformer(), ) @@ -40,6 +40,17 @@ def _as_str(column): DEFAULT_REF_DIST = "random_pairs" +def _compat_df(df): + # In scikit-learn versions older than 1.4, the ColumnTransformer fails on + # polars dataframes. Here it is only applied as an internal step on the + # joining columns, and we get the output as a numpy array or sparse matrix. + # Therefore on old scikit-learn versions we convert the joining columns to + # pandas before vectorizing them. + if parse_version(sklearn.__version__) < parse_version("1.4"): + return sbd.to_pandas(df) + return df + + def _make_vectorizer(table, string_encoder, rescale): """Construct the transformer used to vectorize joining columns. 
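
For context, here is a minimal usage sketch of the new `left_join` helper added in `skrub/_join_utils.py` above. It is not part of the patch; it is shown with pandas and assumes the behavior documented in the helper's docstring.

# Sketch only (not part of the patch): exercising _join_utils.left_join,
# assuming it behaves as its docstring above describes.
import pandas as pd

from skrub import _join_utils

left = pd.DataFrame({"k": [1, 2, 2], "v": [10, 20, 30]})
right = pd.DataFrame({"k2": [1, 2], "w": ["a", "b"]})

# The right join key ("k2") is dropped after the join, so only the left
# key column remains in the output.
joined = _join_utils.left_join(left, right, left_on="k", right_on="k2")
print(list(joined.columns))  # ['k', 'v', 'w']

# rename_right_cols formats every right column name before joining; the
# Joiner uses this hook to implement its `suffix` parameter.
joined = _join_utils.left_join(
    left, right, left_on="k", right_on="k2", rename_right_cols="{}_aux"
)
print(list(joined.columns))  # ['k', 'v', 'w_aux']
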
@@ -48,18 +59,18 @@ def _make_vectorizer(table, string_encoder, rescale): In addition if `rescale` is `True`, a StandardScaler is applied to numeric and datetime columns. """ - # TODO remove use of ColumnTransformer, select_dtypes & pandas-specific code + # TODO: add Skrubber before ColumnTransformer + # TODO: remove use of ColumnTransformer transformers = [ - (clone(string_encoder), c) - for c in table.select_dtypes(include=["string", "category", "object"]).columns + (clone(string_encoder), c) for c in (s.string() | s.categorical()).expand(table) ] - num_columns = table.select_dtypes(include="number").columns - if not num_columns.empty: + num_columns = s.numeric().expand(table) + if num_columns: transformers.append( (StandardScaler() if rescale else "passthrough", num_columns) ) - dt_columns = table.select_dtypes(["datetime", "datetimetz"]).columns - if not dt_columns.empty: + dt_columns = s.any_date().expand(table) + if dt_columns: transformers.append( ( make_pipeline( @@ -120,24 +131,24 @@ class Joiner(TransformerMixin, BaseEstimator): Parameters ---------- - aux_table : :obj:`~pandas.DataFrame` + aux_table : dataframe The auxiliary table, which will be fuzzy-joined to the main table when calling `transform`. - key : str or iterable of str, default=None + key : str or list of str, default=None The column names to use for both `main_key` and `aux_key` when they are the same. Provide either `key` or both `main_key` and `aux_key`. - main_key : str or iterable of str, default=None + main_key : str or list of str, default=None The column names in the main table on which the join will be performed. Can be a string if joining on a single column. If `None`, `aux_key` must also be `None` and `key` must be provided. - aux_key : str or iterable of str, default=None + aux_key : str or list of str, default=None The column names in the auxiliary table on which the join will be performed. Can be a string if joining on a single column. If `None`, `main_key` must also be `None` and `key` must be provided. suffix : str, default="" Suffix to append to the `aux_table`'s column names. You can use it to avoid duplicate column names in the join. - max_dist : float, default=np.inf + max_dist : int, float, `None` or `np.inf`, default=`np.inf` Maximum acceptable (rescaled) distance between a row in the `main_table` and its nearest neighbor in the `aux_table`. Rows that are farther apart are not considered to match. By default, the distance @@ -145,7 +156,7 @@ class Joiner(TransformerMixin, BaseEstimator): although rescaled distances can be greater than 1 for some choices of `ref_dist`. `None`, `"inf"`, `float("inf")` or `numpy.inf` mean that no matches are rejected. - ref_dist : reference distance for rescaling, default = 'random_pairs' + ref_dist : reference distance for rescaling, default='random_pairs' Options are {"random_pairs", "second_neighbor", "self_join_neighbor", "no_rescaling"}. See above for a description of each option. To facilitate the choice of `max_dist`, distances between rows in @@ -178,7 +189,7 @@ class Joiner(TransformerMixin, BaseEstimator): See Also -------- AggJoiner : - Aggregate auxiliary dataframes before joining them on a base dataframe. + Aggregate an auxiliary dataframe before joining it on a base dataframe. fuzzy_join : Join two tables (dataframes) based on approximate column matching. 
This @@ -247,17 +258,6 @@ def __init__( ) self.add_match_info = add_match_info - def _check_dataframe(self, dataframe): - # TODO: add support for polars, ATM we just convert to pandas - if is_polars(dataframe): - return dataframe.to_pandas() - if is_pandas(dataframe): - return dataframe - raise TypeError( - f"{self.__class__.__qualname__} only operates on Pandas or Polars" - " dataframes." - ) - def _check_max_dist(self): if ( self.max_dist is None @@ -271,7 +271,7 @@ def _check_max_dist(self): def _check_ref_dist(self): if self.ref_dist not in _MATCHERS: raise ValueError( - f"ref_dist should be one of {list(_MATCHERS.keys())}, got" + f"'ref_dist' should be one of {list(_MATCHERS.keys())}. Got" f" {self.ref_dist!r}" ) self._matching = _MATCHERS[self.ref_dist]() @@ -281,7 +281,7 @@ def fit(self, X, y=None): Parameters ---------- - X : :obj:`~pandas.DataFrame`, shape [n_samples, n_features] + X : dataframe The main table, to be joined to the auxiliary ones. y : None Unused, only here for compatibility. @@ -292,8 +292,9 @@ def fit(self, X, y=None): Fitted Joiner instance (self). """ del y - X = self._check_dataframe(X) - self._aux_table = self._check_dataframe(self.aux_table) + self._aux_table = CheckInputDataFrame().fit_transform(self.aux_table) + self._main_check_input = CheckInputDataFrame() + X = self._main_check_input.fit_transform(X) self._check_ref_dist() self._check_max_dist() self._main_key, self._aux_key = _join_utils.check_key( @@ -304,12 +305,15 @@ def fit(self, X, y=None): _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) + self._right_cols_renaming = f"{{}}{self.suffix}".format self.vectorizer_ = _make_vectorizer( - self._aux_table[self._aux_key], + s.select(self._aux_table, s.cols(*self._aux_key)), self.string_encoder, rescale=self.ref_dist != "no_rescaling", ) - aux = self.vectorizer_.fit_transform(self._aux_table[self._aux_key]) + aux = self.vectorizer_.fit_transform( + _compat_df(s.select(self._aux_table, s.cols(*self._aux_key))) + ) self._matching.fit(aux) return self @@ -318,52 +322,47 @@ def transform(self, X, y=None): Parameters ---------- - X : :obj:`~pandas.DataFrame`, shape [n_samples, n_features] + X : dataframe The main table, to be joined to the auxiliary ones. y : None Unused, only here for compatibility. Returns ------- - :obj:`~pandas.DataFrame` + dataframe The final joined table. 
""" del y check_is_fitted(self, "vectorizer_") - input_is_polars = is_polars(X) - X = self._check_dataframe(X) + X = self._main_check_input.transform(X) _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) - main = self.vectorizer_.transform( - X[self._main_key].set_axis(self._aux_key, axis="columns") - ) + main = sbd.set_column_names(s.select(X, s.cols(*self._main_key)), self._aux_key) + main = self.vectorizer_.transform(_compat_df(main)) match_result = self._matching.match(main, self.max_dist_) - aux_table = _join_utils.add_column_name_suffix( - self._aux_table, self.suffix - ).reset_index(drop=True) matching_col = match_result["index"].copy() matching_col[~match_result["match_accepted"]] = -1 token = _utils.random_string() left_key_name = f"skrub_left_key_{token}" right_key_name = f"skrub_right_key_{token}" - left = X.assign(**{left_key_name: matching_col}) - right = aux_table.assign(**{right_key_name: np.arange(aux_table.shape[0])}) - join = pd.merge( + left = sbd.with_columns(X, **{left_key_name: matching_col}) + right = sbd.with_columns( + self._aux_table, + **{right_key_name: np.arange(sbd.shape(self._aux_table)[0], dtype="int64")}, + ) + join = _join_utils.left_join( left, right, left_on=left_key_name, right_on=right_key_name, - suffixes=("", ""), - how="left", + rename_right_cols=self._right_cols_renaming, ) - join = join.drop([left_key_name, right_key_name], axis=1) + join = s.select(join, ~s.cols(left_key_name)) if self.add_match_info: + match_info_dict = {} for info_key, info_col_name in self._match_info_key_renaming.items(): - join[info_col_name] = match_result[info_key] - if input_is_polars: - import polars as pl - - join = pl.from_pandas(join) + match_info_dict[info_col_name] = match_result[info_key] + join = sbd.with_columns(join, **match_info_dict) return join diff --git a/skrub/tests/test_agg_joiner.py b/skrub/tests/test_agg_joiner.py index 402a7009b..53340ae38 100644 --- a/skrub/tests/test_agg_joiner.py +++ b/skrub/tests/test_agg_joiner.py @@ -401,6 +401,7 @@ def test_agg_target(main_table, y, col_name): "movieId": [1, 3, 6, 318, 6, 1704], "rating": [4.0, 4.0, 4.0, 3.0, 2.0, 4.0], "genre": ["drama", "drama", "comedy", "sf", "comedy", "sf"], + "index": [0, 0, 0, 1, 1, 1], f"{col_name}_(1.999, 3.0]_user": [0, 0, 0, 2, 2, 2], f"{col_name}_(3.0, 4.0]_user": [3, 3, 3, 1, 1, 1], f"{col_name}_2.0_user": [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], diff --git a/skrub/tests/test_fuzzy_join.py b/skrub/tests/test_fuzzy_join.py index 528516ae6..37f5673e6 100644 --- a/skrub/tests/test_fuzzy_join.py +++ b/skrub/tests/test_fuzzy_join.py @@ -1,14 +1,13 @@ import warnings import numpy as np -import pandas as pd import pytest from numpy.testing import assert_array_equal from sklearn.feature_extraction.text import HashingVectorizer -from skrub import fuzzy_join +from skrub import ToDatetime, _join_utils, fuzzy_join +from skrub import _selectors as s from skrub._dataframe import _common as ns -from skrub._dataframe._testing_utils import assert_frame_equal @pytest.mark.parametrize( @@ -19,8 +18,6 @@ def test_fuzzy_join(df_module, analyzer): """ Testing if ``fuzzy_join`` results are as expected. 
""" - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") df1 = df_module.make_dataframe({"a1": ["ana", "lala", "nana et sana", np.nan]}) df2 = df_module.make_dataframe( {"a2": ["anna", "lala et nana", "lana", "sana", np.nan]} @@ -50,16 +47,13 @@ def test_fuzzy_join(df_module, analyzer): # Joining is always done on the left table and thus takes it shape: assert ns.shape(df_joined2) == (len(df2), n_cols) - # TODO: dispatch ``with_columns`` - df1["a2"] = 1 + df1 = ns.with_columns(df1, **{"a2": [1] * ns.shape(df1)[0]}) df_on = fuzzy_join(df_joined, df1, on="a1", suffix="2") assert "a12" in ns.column_names(df_on) def test_max_dist(df_module): - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") left = df_module.make_dataframe({"A": ["aa", "bb"]}) right = df_module.make_dataframe({"A": ["aa", "ba"], "B": [1, 2]}) join = fuzzy_join(left, right, on="A", suffix="r") @@ -69,8 +63,6 @@ def test_max_dist(df_module): def test_perfect_matches(df_module): - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") # non-regression test for https://github.com/skrub-data/skrub/issues/764 # fuzzy_join when all rows had a perfect match used to trigger a division by 0 df = df_module.make_dataframe({"A": [0, 1]}) @@ -87,19 +79,15 @@ def test_fuzzy_join_dtypes(df_module): """ Test that the dtypes of dataframes are maintained after join. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") a = df_module.make_dataframe({"col1": ["aaa", "bbb"], "col2": [1, 2]}) b = df_module.make_dataframe({"col1": ["aaa_", "bbb_"], "col3": [1, 2]}) c = fuzzy_join(a, b, on="col1", suffix="r") - assert ns.dtype(ns.col(a, "col2")).kind == "i" + assert ns.is_integer(ns.col(a, "col2")) assert ns.dtype(ns.col(c, "col2")) == ns.dtype(ns.col(a, "col2")) assert ns.dtype(ns.col(c, "col3r")) == ns.dtype(ns.col(b, "col3")) def test_missing_keys(df_module): - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") a = df_module.make_dataframe({"col1": ["aaa", "bbb"], "col2": [1, 2]}) b = df_module.make_dataframe({"col1": ["aaa_", "bbb_"], "col3": [1, 2]}) with pytest.raises( @@ -116,8 +104,6 @@ def test_missing_keys(df_module): def test_drop_unmatched(df_module): - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") a = df_module.make_dataframe({"col1": ["aaaa", "bbb", "ddd dd"], "col2": [1, 2, 3]}) b = df_module.make_dataframe( {"col1": ["aaa_", "bbb_", "cc ccc"], "col3": [1, 2, 3]} @@ -137,20 +123,20 @@ def test_drop_unmatched(df_module): assert sum(ns.is_null(ns.col(c2, "col3r"))) > 0 -def test_fuzzy_join_pandas_comparison(): +def test_fuzzy_join_exact_matches(df_module): """ - Tests if fuzzy_join's output is as similar as - possible with `pandas.merge`. + Tests if fuzzy_join's output is the same as a normal left-join when there + are exact matches for all rows. 
""" - left = pd.DataFrame( + left = df_module.make_dataframe( { - "key": ["K0", "K1", "K2", "K3"], + "key": ["K2", "K2", "K3", "K1"], "A": ["A0", "A1", "A2", "A3"], "B": ["B0", "B1", "B2", "B3"], } ) - right = pd.DataFrame( + right = df_module.make_dataframe( { "key_": ["K0", "K1", "K2", "K3"], "C": ["C0", "C1", "C2", "C3"], @@ -158,20 +144,28 @@ def test_fuzzy_join_pandas_comparison(): } ) - result = pd.merge(left, right, left_on="key", right_on="key_") + result = _join_utils.left_join(left, right, left_on="key", right_on="key_") result_fj = fuzzy_join( left, right, left_on="key", right_on="key_", add_match_info=False ) + df_module.assert_column_equal( + ns.col(result_fj, "key_"), ns.rename(ns.col(result_fj, "key"), "key_") + ) + # `_left_join` does a (non-fuzzy, regular) equijoin so it only keeps one of the + # join columns (keeping both would be redundant as they are identical due to exact + # matching) -- same as the default behavior of polars (coalesce=True) and pandas. + # `fuzzy_join` keeps both columns because they are not identical, only up to + # fuzziness, so keeping both is informative. So here we drop `key_` to compare the + # 2 resulting dataframes. + result_fj = s.select(result_fj, ~s.cols("key_")) - assert_frame_equal(result, result_fj) + df_module.assert_frame_equal(result, result_fj) def test_correct_encoder(df_module): """ Test that the encoder error checking is working as intended. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") class TestVectorizer(HashingVectorizer): """ @@ -213,8 +207,6 @@ def test_numerical_column(df_module): """ Testing that ``fuzzy_join`` works with numerical columns. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") left = df_module.make_dataframe({"str1": ["aa", "a", "bb"], "int": [10, 2, 5]}) right = df_module.make_dataframe( { @@ -249,20 +241,20 @@ def test_datetime_column(df_module): """ Testing that ``fuzzy_join`` works with datetime columns. """ - if df_module.name == "polars": - pytest.xfail(reason="Module 'polars' has no attribute 'to_datetime'.") + + def to_dt(lst): + return ToDatetime().fit_transform(df_module.make_column("", lst)) + left = df_module.make_dataframe( { "str1": ["aa", "a", "bb"], - "date": df_module.module.to_datetime( - ["10/10/2022", "12/11/2021", "09/25/2011"] - ), + "date": to_dt(["10/10/2022", "12/11/2021", "09/25/2011"]), } ) right = df_module.make_dataframe( { "str2": ["aa", "bb", "a", "cc", "dd"], - "date": df_module.module.to_datetime( + "date": to_dt( ["09/10/2022", "12/24/2021", "09/25/2010", "11/05/2025", "02/21/2000"] ), } @@ -273,16 +265,12 @@ def test_datetime_column(df_module): fj_time_expected = df_module.make_dataframe( { "str1": ["aa", "a", "bb"], - "date": df_module.module.to_datetime( - ["10/10/2022", "12/11/2021", "09/25/2011"] - ), + "date": to_dt(["10/10/2022", "12/11/2021", "09/25/2011"]), "str2r": ["aa", "bb", "a"], - "dater": df_module.module.to_datetime( - ["09/10/2022", "12/24/2021", "09/25/2010"] - ), + "dater": to_dt(["09/10/2022", "12/24/2021", "09/25/2010"]), } ) - assert_frame_equal(fj_time, fj_time_expected) + df_module.assert_frame_equal(fj_time, fj_time_expected) n_cols = ns.shape(left)[1] + ns.shape(right)[1] n_samples = len(left) @@ -308,17 +296,17 @@ def test_mixed_joins(df_module): """ Test fuzzy joining on mixed and multiple column types. 
""" - if df_module.name == "polars": - pytest.xfail(reason="Module 'polars' has no attribute 'to_datetime'") + + def to_dt(lst): + return ToDatetime().fit_transform(df_module.make_column("", lst)) + left = df_module.make_dataframe( { "str1": ["Paris", "Paris", "Paris"], "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["10/10/2022", "12/11/2021", "09/25/2011"] - ), + "date": to_dt(["10/10/2022", "12/11/2021", "09/25/2011"]), } ) right = df_module.make_dataframe( @@ -327,7 +315,7 @@ def test_mixed_joins(df_module): "str_2": ["TX", "FR", "GR Mytho", "cc", "dd"], "int1": [55, 6, 2, 15, 6], "int2": [554, 146, 32, 215, 612], - "date": df_module.module.to_datetime( + "date": to_dt( ["09/10/2022", "12/24/2021", "09/25/2010", "11/05/2025", "02/21/2000"] ), } @@ -344,19 +332,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["10/10/2022", "12/11/2021", "09/25/2011"] - ), + "date": to_dt(["10/10/2022", "12/11/2021", "09/25/2011"]), "str_1r": ["Paris", "Paris", "dd"], "str_2r": ["FR", "FR", "dd"], "int1r": [6, 6, 6], "int2r": [146, 146, 612], - "dater": df_module.module.to_datetime( - ["12/24/2021", "12/24/2021", "02/21/2000"] - ), + "dater": to_dt(["12/24/2021", "12/24/2021", "02/21/2000"]), } ) - assert_frame_equal(fj_num, expected_fj_num) + df_module.assert_frame_equal(fj_num, expected_fj_num) assert ns.shape(fj_num) == (3, 10) # On multiple string keys @@ -375,19 +359,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["2022-10-10", "2021-12-11", "2011-09-25"] - ), + "date": to_dt(["2022-10-10", "2021-12-11", "2011-09-25"]), "str_1r": ["Paris", "Paris", "Paris"], "str_2r": ["TX", "FR", "GR Mytho"], "int1r": [55, 6, 2], "int2r": [554, 146, 32], - "dater": df_module.module.to_datetime( - ["2022-09-10", "2021-12-24", "2010-09-25"] - ), + "dater": to_dt(["2022-09-10", "2021-12-24", "2010-09-25"]), } ) - assert_frame_equal(fj_str, expected_fj_str) + df_module.assert_frame_equal(fj_str, expected_fj_str) assert ns.shape(fj_str) == (3, 10) # On mixed, numeric and string keys @@ -405,19 +385,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["2022-10-10", "2021-12-11", "2011-09-25"] - ), + "date": to_dt(["2022-10-10", "2021-12-11", "2011-09-25"]), "str_1r": ["Paris", "Paris", "Paris"], "str_2r": ["FR", "FR", "TX"], "int1r": [6, 6, 55], "int2r": [146, 146, 554], - "dater": df_module.module.to_datetime( - ["2021-12-24", "2021-12-24", "2022-09-10"] - ), + "dater": to_dt(["2021-12-24", "2021-12-24", "2022-09-10"]), } ) - assert_frame_equal(fj_mixed, expected_fj_mixed) + df_module.assert_frame_equal(fj_mixed, expected_fj_mixed) assert ns.shape(fj_mixed) == (3, 10) # On mixed time and string keys @@ -435,19 +411,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["2022-10-10", "2021-12-11", "2011-09-25"] - ), + "date": to_dt(["2022-10-10", "2021-12-11", "2011-09-25"]), "str_1r": ["Paris", "Paris", "Paris"], "str_2r": ["TX", "FR", "GR Mytho"], "int1r": [55, 6, 2], "int2r": [554, 146, 32], - "dater": df_module.module.to_datetime( - ["2022-09-10", "2021-12-24", "2010-09-25"] 
- ), + "dater": to_dt(["2022-09-10", "2021-12-24", "2010-09-25"]), } ) - assert_frame_equal(fj_mixed2, expected_fj_mixed2) + df_module.assert_frame_equal(fj_mixed2, expected_fj_mixed2) assert ns.shape(fj_mixed2) == (3, 10) # On mixed time and numbers keys @@ -465,19 +437,15 @@ def test_mixed_joins(df_module): "str2": ["Texas", "France", "Greek God"], "int1": [10, 2, 5], "int2": [103, 250, 532], - "date": df_module.module.to_datetime( - ["2022-10-10", "2021-12-11", "2011-09-25"] - ), + "date": to_dt(["2022-10-10", "2021-12-11", "2011-09-25"]), "str_1r": ["Paris", "Paris", "Paris"], "str_2r": ["FR", "FR", "GR Mytho"], "int1r": [6, 6, 2], "int2r": [146, 146, 32], - "dater": df_module.module.to_datetime( - ["2021-12-24", "2021-12-24", "2010-09-25"] - ), + "dater": to_dt(["2021-12-24", "2021-12-24", "2010-09-25"]), } ) - assert_frame_equal(fj_mixed3, expected_fj_mixed3) + df_module.assert_frame_equal(fj_mixed3, expected_fj_mixed3) assert ns.shape(fj_mixed3) == (3, 10) @@ -485,8 +453,6 @@ def test_iterable_input(df_module): """ Test if iterable inputs (list, set, dictionary or tuple) work. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") df1 = df_module.make_dataframe( {"a": ["ana", "lala", "nana"], "str2": ["Texas", "France", "Greek God"]} ) @@ -529,20 +495,15 @@ def test_iterable_input(df_module): ) == (3, 4) -@pytest.mark.xfail def test_missing_values(df_module): """ Test fuzzy joining on missing values. """ - if df_module.name == "polars": - pytest.xfail(reason="Polars DataFrame object has no attribute 'reset_index'.") a = df_module.make_dataframe({"col1": ["aaaa", "bbb", "ddd dd"], "col2": [1, 2, 3]}) b = df_module.make_dataframe({"col3": [np.nan, "bbb", "ddd dd"], "col4": [1, 2, 3]}) - with pytest.warns(UserWarning, match=r"merging on missing values"): - c = fuzzy_join(a, b, left_on="col1", right_on="col3", add_match_info=False) + c = fuzzy_join(a, b, left_on="col1", right_on="col3", add_match_info=False) assert ns.shape(c)[0] == len(b) - with pytest.warns(UserWarning, match=r"merging on missing values"): - c = fuzzy_join(b, a, left_on="col3", right_on="col1", add_match_info=True) + c = fuzzy_join(b, a, left_on="col3", right_on="col1", add_match_info=True) assert ns.shape(c)[0] == len(b) diff --git a/skrub/tests/test_join_utils.py b/skrub/tests/test_join_utils.py index 678754b1a..31f7976f7 100644 --- a/skrub/tests/test_join_utils.py +++ b/skrub/tests/test_join_utils.py @@ -1,6 +1,10 @@ +import re + +import numpy as np import pandas as pd import pytest +from skrub import _dataframe as sbd from skrub import _join_utils @@ -44,31 +48,161 @@ def test_check_key_length_mismatch(): ) -def test_check_column_name_duplicates(): - left = pd.DataFrame(columns=["A", "B"]) - right = pd.DataFrame(columns=["C"]) +def test_check_no_column_name_duplicates_with_no_suffix(df_module): + left = df_module.make_dataframe({"A": [], "B": []}) + right = df_module.make_dataframe({"C": []}) _join_utils.check_column_name_duplicates(left, right, "") - left = pd.DataFrame(columns=["A", "B"]) - right = pd.DataFrame(columns=["B"]) + +def test_check_no_column_name_duplicates_after_adding_a_suffix(df_module): + left = df_module.make_dataframe({"A": [], "B": []}) + right = df_module.make_dataframe({"B": []}) _join_utils.check_column_name_duplicates(left, right, "_right") - left = pd.DataFrame(columns=["A", "B_right"]) - right = pd.DataFrame(columns=["B"]) + +def test_check_column_name_duplicates_after_adding_a_suffix(df_module): + left = 
df_module.make_dataframe({"A": [], "B_right": []}) + right = df_module.make_dataframe({"B": []}) with pytest.raises(ValueError, match=".*suffix '_right'.*['B_right']"): _join_utils.check_column_name_duplicates(left, right, "_right") - left = pd.DataFrame(columns=["A", "A"]) - right = pd.DataFrame(columns=["B"]) - with pytest.raises(ValueError, match="Table 'left' has duplicate"): - _join_utils.check_column_name_duplicates( - left, right, "", main_table_name="left" - ) - -def test_add_column_name_suffix(): - df = pd.DataFrame(columns=["one", "two three", "x"]) +def test_add_column_name_suffix(df_module): + df = df_module.make_dataframe({"one": [], "two three": [], "x": []}) df = _join_utils.add_column_name_suffix(df, "") assert list(df.columns) == ["one", "two three", "x"] df = _join_utils.add_column_name_suffix(df, "_y") assert list(df.columns) == ["one_y", "two three_y", "x_y"] + + +@pytest.fixture +def left(df_module): + return df_module.make_dataframe({"left_key": [1, 2, 2], "left_col": [10, 20, 30]}) + + +def test_left_join_all_keys_in_right_dataframe(df_module, left): + right = df_module.make_dataframe({"right_key": [2, 1], "right_col": ["b", "a"]}) + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="right_key" + ) + expected = df_module.make_dataframe( + { + "left_key": [1, 2, 2], + "left_col": [10, 20, 30], + "right_col": ["a", "b", "b"], + } + ) + df_module.assert_frame_equal(joined, expected) + + +def test_left_join_some_keys_not_in_right_dataframe(df_module, left): + right = df_module.make_dataframe({"right_key": [2, 3], "right_col": ["a", "c"]}) + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="right_key" + ) + expected = df_module.make_dataframe( + { + "left_key": [1, 2, 2], + "left_col": [10, 20, 30], + "right_col": [np.nan, "a", "a"], + } + ) + df_module.assert_frame_equal(joined, expected) + + +def test_left_join_same_key_name(df_module, left): + right = df_module.make_dataframe({"left_key": [2, 1], "right_col": ["b", "a"]}) + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="left_key" + ) + expected = df_module.make_dataframe( + { + "left_key": [1, 2, 2], + "left_col": [10, 20, 30], + "right_col": ["a", "b", "b"], + } + ) + df_module.assert_frame_equal(joined, expected) + + +def test_left_join_same_col_name(df_module, left): + right = df_module.make_dataframe({"right_key": [2, 1], "left_col": ["b", "a"]}) + joined = _join_utils.left_join( + left, right=right, left_on="left_key", right_on="right_key" + ) + + cols = sbd.column_names(joined) + assert cols[:2] == ["left_key", "left_col"] + assert re.match("left_col__skrub_.*__", cols[2]) is not None + + expected = df_module.make_dataframe( + { + "a": [1, 2, 2], + "b": [10, 20, 30], + "c": ["a", "b", "b"], + } + ) + # Renaming is necessary because a random tag has been added + expected = sbd.set_column_names(expected, cols) + df_module.assert_frame_equal(joined, expected) + + +def test_left_join_renaming_right_cols(df_module, left): + right = df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) + joined = _join_utils.left_join( + left, + right=right, + left_on="left_key", + right_on="right_key", + rename_right_cols="right.{}", + ) + expected = df_module.make_dataframe( + { + "left_key": [1, 2, 2], + "left_col": [10, 20, 30], + "right.right_col": ["a", "b", "b"], + } + ) + df_module.assert_frame_equal(joined, expected) + + +def test_left_join_wrong_left_type(df_module): + right = 
df_module.make_dataframe({"right_key": [1, 2], "right_col": ["a", "b"]}) + with pytest.raises( + TypeError, + match=( + "`left` must be a pandas or polars dataframe, got ." + ), + ): + _join_utils.left_join( + np.array([1, 2]), right=right, left_on="left_key", right_on="right_key" + ) + + +def test_left_join_wrong_right_type(df_module, left): + with pytest.raises( + TypeError, + match=( + "`right` must be a pandas or polars dataframe, got ." + ), + ): + _join_utils.left_join( + left, right=np.array([1, 2]), left_on="left_key", right_on="right_key" + ) + + +def test_left_join_types_not_equal(df_module, left): + try: + import polars as pl + except ImportError: + pytest.skip(reason="Polars not available.") + + other_px = pd if df_module.module is pl else pl + right = other_px.DataFrame(left) + + with pytest.raises( + TypeError, match=r"`left` and `right` must be of the same dataframe type" + ): + _join_utils.left_join( + left, right=right, left_on="left_key", right_on="right_key" + ) diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index d28715486..c7fdc778d 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -1,3 +1,5 @@ +import datetime + import numpy as np import pandas as pd import pytest @@ -5,7 +7,6 @@ from skrub import Joiner from skrub._dataframe import _common as ns -from skrub._dataframe._testing_utils import assert_frame_equal @pytest.fixture @@ -49,22 +50,17 @@ def test_fit_transform(main_table, aux_table): def test_wrong_main_key(main_table, aux_table): - false_joiner = Joiner(aux_table=aux_table, main_key="Countryy", aux_key="country") - - with pytest.raises( - ValueError, - match="do not exist in 'X'", - ): - false_joiner.fit(main_table) + wrong_joiner = Joiner(aux_table=aux_table, main_key="wrong_key", aux_key="country") + with pytest.raises(ValueError, match="do not exist in 'X'"): + wrong_joiner.fit(main_table) def test_wrong_aux_key(main_table, aux_table): - false_joiner2 = Joiner(aux_table=aux_table, main_key="Country", aux_key="bad") - with pytest.raises( - ValueError, - match="do not exist in 'aux_table'", - ): - false_joiner2.fit(main_table) + wrong_joiner_2 = Joiner( + aux_table=aux_table, main_key="Country", aux_key="wrong_key" + ) + with pytest.raises(ValueError, match="do not exist in 'aux_table'"): + wrong_joiner_2.fit(main_table) def test_multiple_keys(df_module): @@ -76,20 +72,20 @@ def test_multiple_keys(df_module): ) joiner_list = Joiner( aux_table=df2, - aux_key=["CO", "CA"], main_key=["Co", "Ca"], + aux_key=["CO", "CA"], add_match_info=False, ) result = joiner_list.fit_transform(df) expected = ns.concat_horizontal(df, df2) - assert_frame_equal(result, expected) + df_module.assert_frame_equal(result, expected) joiner_list = Joiner( - aux_table=df2, aux_key="CA", main_key="Ca", add_match_info=False + aux_table=df2, main_key="Ca", aux_key="CA", add_match_info=False ) result = joiner_list.fit_transform(df) - assert_frame_equal(result, expected) + df_module.assert_frame_equal(result, expected) def test_pandas_aux_table_index(): @@ -108,22 +104,54 @@ def test_pandas_aux_table_index(): suffix="_capitals", ) join = joiner.fit_transform(main_table) - assert ns.to_list(ns.col(join, "Country_capitals")) == [ + assert join["Country_capitals"].to_list() == [ "France", "Italy", "Germany", ] -def test_bad_ref_dist(): - table = pd.DataFrame({"A": [1, 2]}) - joiner = Joiner(table, key="A", ref_dist="bad") - with pytest.raises(ValueError, match="got 'bad'"): - joiner.fit(table) +def test_wrong_ref_dist(main_table, aux_table): + 
joiner = Joiner( + aux_table, main_key="Country", aux_key="country", ref_dist="wrong_ref_dist" + ) + with pytest.raises( + ValueError, match=r"'ref_dist' should be one of.* Got 'wrong_ref_dist'" + ): + joiner.fit(main_table) @pytest.mark.parametrize("max_dist", [np.inf, float("inf"), "inf", None]) -def test_max_dist(max_dist): - table = pd.DataFrame({"A": [1, 2]}) - joiner = Joiner(table, key="A", max_dist=max_dist, suffix="_").fit(table) +def test_max_dist(main_table, aux_table, max_dist): + joiner = Joiner( + aux_table, main_key="Country", aux_key="country", max_dist=max_dist + ).fit(main_table) assert joiner.max_dist_ == np.inf + + +def test_missing_values(df_module): + df = df_module.make_dataframe({"A": [None, "hollywood", "beverly"]}) + joiner = Joiner(df, key="A", suffix="_") + out = joiner.fit_transform(df) + assert ns.shape(out) == (3, 5) + + +def test_fit_transform_numeric(df_module): + df = df_module.make_dataframe({"A": [4.5, 0.5, 1, -1.5]}) + joiner = Joiner(df, key="A", suffix="_") + out = joiner.fit_transform(df) + assert ns.shape(out) == (4, 5) + + +def test_fit_transform_datetimes(df_module): + values = [ + datetime.datetime.fromisoformat(dt) + for dt in [ + "2020-02-03T12:30:05", + "2021-03-15T00:37:15", + "2022-02-13T17:03:25", + ] + ] + df = df_module.make_dataframe({"A": values}) + joiner = Joiner(df, key="A", suffix="_") + joiner.fit_transform(df)
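
Taken together, these changes let the Joiner run natively on polars, as announced in the CHANGES.rst entry. Below is a minimal end-to-end sketch — not part of the patch; the data is invented and polars is assumed to be installed.

# Sketch only (not part of the patch): fuzzy-joining two polars dataframes
# with the adapted Joiner, without a round-trip through pandas.
import polars as pl

from skrub import Joiner

main = pl.DataFrame({"Country": ["France", "Germnay", "Italia"]})
aux = pl.DataFrame(
    {"country": ["France", "Germany", "Italy"], "capital": ["Paris", "Berlin", "Rome"]}
)

joiner = Joiner(aux, main_key="Country", aux_key="country", add_match_info=False)
out = joiner.fit_transform(main)

# The output stays a polars dataframe; the Joiner previously converted polars
# inputs to pandas and back. As with fuzzy_join, both key columns are kept,
# since they only match approximately ("Germnay" ~ "Germany").
assert isinstance(out, pl.DataFrame)
print(out.columns)  # ['Country', 'country', 'capital']
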