From 31799d729234d2d5b67c6043f8aef471357c63dd Mon Sep 17 00:00:00 2001
From: Jerome Dockes
Date: Wed, 19 Jun 2024 15:16:28 +0200
Subject: [PATCH] add _dataframe.filter function

---
 skrub/_dataframe/_common.py           | 16 ++++++++++++++++
 skrub/_dataframe/tests/test_common.py | 10 ++++++++++
 2 files changed, 26 insertions(+)

diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py
index 5a5ee94a5..53c9113de 100644
--- a/skrub/_dataframe/_common.py
+++ b/skrub/_dataframe/_common.py
@@ -80,6 +80,7 @@
     "fill_nulls",
     "n_unique",
     "unique",
+    "filter",
     "where",
     "sample",
     "head",
@@ -933,6 +934,21 @@ def _unique_polars(col):
     return col.unique().drop_nulls()
 
 
+@dispatch
+def filter(obj, predicate):
+    raise NotImplementedError()
+
+
+@filter.specialize("pandas")
+def _filter_pandas(obj, predicate):
+    return obj[predicate]
+
+
+@filter.specialize("polars")
+def _filter_polars(obj, predicate):
+    return obj.filter(predicate)
+
+
 @dispatch
 def where(col, mask, other):
     raise NotImplementedError()
diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py
index 99c60ee1d..a065371b1 100644
--- a/skrub/_dataframe/tests/test_common.py
+++ b/skrub/_dataframe/tests/test_common.py
@@ -585,6 +585,16 @@ def test_unique(df_module):
     )
 
 
+def test_filter(df_module):
+    df = df_module.example_dataframe
+    pred = ns.col(df, "int-not-null-col") > 1
+    filtered_df = ns.filter(df, pred)
+    assert ns.shape(filtered_df) == (2, 8)
+    assert ns.to_list(ns.col(filtered_df, "int-not-null-col")) == [4, 10]
+    filtered_col = ns.filter(df_module.example_column, pred)
+    assert ns.to_list(filtered_col) == [4.5, -1.5]
+
+
 def test_where(df_module):
     s = ns.pandas_convert_dtypes(df_module.make_column("", [0, 1, 2]))
     out = ns.where(