From f18caee0e3a6de43741da19f71859e622e86ab81 Mon Sep 17 00:00:00 2001 From: Jules Marecaille Date: Fri, 17 Mar 2023 16:54:58 +0100 Subject: [PATCH] allow parameter `min_sup` between 0 and 1 --- sliceline/slicefinder.py | 16 ++++++++++++---- tests/conftest.py | 11 +++++++++++ tests/experiment.py | 3 ++- tests/test_slicefinder.py | 1 + 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/sliceline/slicefinder.py b/sliceline/slicefinder.py index 640a7dd..2024e04 100644 --- a/sliceline/slicefinder.py +++ b/sliceline/slicefinder.py @@ -2,7 +2,7 @@ The slicefinder module implements the Slicefinder class. """ import logging -from typing import Tuple +from typing import Tuple, Union import numpy as np from scipy import sparse as sp @@ -52,9 +52,10 @@ class Slicefinder(BaseEstimator, TransformerMixin): Maximum lattice level. In other words: the maximum number of predicate to define a slice. - min_sup: int, default=10 + min_sup: int or float, default=10 Minimum support threshold. Inspired by frequent itemset mining, it ensures statistical significance. + If `min_sup` is a float (0 < `min_sup` < 1), it represents the faction of the input dataset (`X`) verbose: bool, default=True Controls the verbosity. @@ -79,7 +80,7 @@ def __init__( alpha: float = 0.6, k: int = 1, max_l: int = 4, - min_sup: int = 10, + min_sup: Union[int, float] = 10, verbose: bool = True, ): self.alpha = alpha @@ -107,7 +108,10 @@ def _check_params(self): if self.max_l <= 0: raise ValueError(f"Invalid 'max_l' parameter: {self.max_l}") - if self.min_sup < 0: + if ( + self.min_sup < 0 or + (isinstance(self.min_sup, float) and self.min_sup >= 1) + ): raise ValueError(f"Invalid 'min_sup' parameter: {self.min_sup}") def _check_top_slices(self): @@ -138,6 +142,10 @@ def fit(self, X, errors): """ self._check_params() + # Update min_sup for a fraction of the input dataset size + if 0 < self.min_sup < 1: + self.min_sup = int(self.min_sup * len(X)) + # Check that X and e have correct shape X_array, errors = check_X_e(X, errors) diff --git a/tests/conftest.py b/tests/conftest.py index 4cefdb0..66212b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -321,6 +321,16 @@ def experiments(): X_16, errors_16, expected_top_slices_16, alpha=0.01, max_l=3, min_sup=7 ) + # Experiment 17: Experiment 4 w/ min_sup=0.1 + expected_top_slices_17 = np.array( + [ + [1.0, 1.0, None, None, None, None], + [1.0, None, None, None, None, None], + [None, 1.0, None, None, None, None] + ] + ) + experiment_17 = Experiment(X_4, errors_4, expected_top_slices_17, min_sup=0.5) + return { "experiment_1": experiment_1, "experiment_2": experiment_2, @@ -338,4 +348,5 @@ def experiments(): "experiment_14": experiment_14, "experiment_15": experiment_15, "experiment_16": experiment_16, + "experiment_17": experiment_17 } diff --git a/tests/experiment.py b/tests/experiment.py index ea087f9..53df664 100644 --- a/tests/experiment.py +++ b/tests/experiment.py @@ -36,9 +36,10 @@ class Experiment: Maximum lattice level. In other words: the maximum number of predicate to define a slice. - min_sup: int, default=10 + min_sup: int or float, default=10 Minimum support threshold. Inspired by frequent itemset mining, it ensures statistical significance. + If `min_sup` is a float (0 < `min_sup` < 1), it represents the faction of the input dataset (`X`) verbose: bool, default=True Controls the verbosity. diff --git a/tests/test_slicefinder.py b/tests/test_slicefinder.py index 9ae0627..74ec5c7 100644 --- a/tests/test_slicefinder.py +++ b/tests/test_slicefinder.py @@ -303,6 +303,7 @@ def test_search_slices(benchmark, basic_test_data): "experiment_14", "experiment_15", "experiment_16", + "experiment_17", ], ) def test_experiments(benchmark, experiments, experiment_name):