diff --git a/skrub/_joiner.py b/skrub/_joiner.py index 5cf06b569..2b6eb9ea8 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -334,13 +334,8 @@ def transform(self, X, y=None): _join_utils.check_column_name_duplicates( X, self._aux_table, self.suffix, main_table_name="X" ) - main = self.vectorizer_.transform( - _compat_df( - sbd.set_column_names( - s.select(X, s.cols(*self._main_key)), self._aux_key - ) - ) - ) + main = sbd.set_column_names(s.select(X, s.cols(*self._main_key)), self._aux_key) + main = self.vectorizer_.transform(_compat_df(main)) match_result = self._matching.match(main, self.max_dist_) matching_col = match_result["index"].copy() matching_col[~match_result["match_accepted"]] = -1 diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index 7b14c583d..368ffe252 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -131,14 +131,15 @@ def test_max_dist(main_table, aux_table, max_dist): def test_missing_values(df_module): df = df_module.make_dataframe({"A": [None, "hollywood", "beverly"]}) joiner = Joiner(df, key="A", suffix="_") - with pytest.xfail(): - joiner.fit_transform(df) + out = joiner.fit_transform(df) + assert ns.shape(out) == (3, 5) def test_fit_transform_numeric(df_module): df = df_module.make_dataframe({"A": [4.5, 0.5, 1, -1.5]}) joiner = Joiner(df, key="A", suffix="_") - joiner.fit_transform(df) + out = joiner.fit_transform(df) + assert ns.shape(out) == (4, 5) def test_fit_transform_datetimes(df_module):