Skip to content

Commit

Permalink
Merge pull request #33 from DataDome/allow-min_sup-parameter-between-…
Browse files Browse the repository at this point in the history
…0-and-1

allow parameter `min_sup` between 0 and 1
  • Loading branch information
jmarecaille committed Mar 23, 2023
2 parents ddf2ac7 + f18caee commit fa486d7
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
16 changes: 12 additions & 4 deletions sliceline/slicefinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
The slicefinder module implements the Slicefinder class.
"""
import logging
from typing import Tuple
from typing import Tuple, Union

import numpy as np
from scipy import sparse as sp
Expand Down Expand Up @@ -52,9 +52,10 @@ class Slicefinder(BaseEstimator, TransformerMixin):
Maximum lattice level.
In other words: the maximum number of predicate to define a slice.
min_sup: int, default=10
min_sup: int or float, default=10
Minimum support threshold.
Inspired by frequent itemset mining, it ensures statistical significance.
If `min_sup` is a float (0 < `min_sup` < 1), it represents the minimum support as a fraction of the number of samples in the input dataset (`X`).
verbose: bool, default=True
Controls the verbosity.
Expand All @@ -79,7 +80,7 @@ def __init__(
alpha: float = 0.6,
k: int = 1,
max_l: int = 4,
min_sup: int = 10,
min_sup: Union[int, float] = 10,
verbose: bool = True,
):
self.alpha = alpha
Expand Down Expand Up @@ -107,7 +108,10 @@ def _check_params(self):
if self.max_l <= 0:
raise ValueError(f"Invalid 'max_l' parameter: {self.max_l}")

if self.min_sup < 0:
if (
self.min_sup < 0 or
(isinstance(self.min_sup, float) and self.min_sup >= 1)
):
raise ValueError(f"Invalid 'min_sup' parameter: {self.min_sup}")

def _check_top_slices(self):
Expand Down Expand Up @@ -138,6 +142,10 @@ def fit(self, X, errors):
"""
self._check_params()

# Update min_sup for a fraction of the input dataset size
if 0 < self.min_sup < 1:
self.min_sup = int(self.min_sup * len(X))

# Check that X and e have correct shape
X_array, errors = check_X_e(X, errors)

Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ def experiments():
X_16, errors_16, expected_top_slices_16, alpha=0.01, max_l=3, min_sup=7
)

# Experiment 17: Experiment 4 w/ min_sup=0.5
expected_top_slices_17 = np.array(
[
[1.0, 1.0, None, None, None, None],
[1.0, None, None, None, None, None],
[None, 1.0, None, None, None, None]
]
)
experiment_17 = Experiment(X_4, errors_4, expected_top_slices_17, min_sup=0.5)

return {
"experiment_1": experiment_1,
"experiment_2": experiment_2,
Expand All @@ -338,4 +348,5 @@ def experiments():
"experiment_14": experiment_14,
"experiment_15": experiment_15,
"experiment_16": experiment_16,
"experiment_17": experiment_17
}
3 changes: 2 additions & 1 deletion tests/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ class Experiment:
Maximum lattice level.
In other words: the maximum number of predicate to define a slice.
min_sup: int, default=10
min_sup: int or float, default=10
Minimum support threshold.
Inspired by frequent itemset mining, it ensures statistical significance.
If `min_sup` is a float (0 < `min_sup` < 1), it represents the minimum support as a fraction of the number of samples in the input dataset (`X`).
verbose: bool, default=True
Controls the verbosity.
Expand Down
1 change: 1 addition & 0 deletions tests/test_slicefinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def test_search_slices(benchmark, basic_test_data):
"experiment_14",
"experiment_15",
"experiment_16",
"experiment_17",
],
)
def test_experiments(benchmark, experiments, experiment_name):
Expand Down

0 comments on commit fa486d7

Please sign in to comment.