Skip to content

Commit

Permalink
fix: SGDClassifier post-processing with multi-class
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Apr 3, 2024
1 parent df3b5b6 commit 5ffaf9a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 43 deletions.
21 changes: 0 additions & 21 deletions src/concrete/ml/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,27 +254,6 @@ def __init__(
f"({fit_encrypted=}). Got {parameters_range=}"
)

def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
    """Turn raw prediction scores into per-class probabilities.

    Args:
        y_preds (numpy.ndarray): The raw scores. A 1D array, or a 2D array with a single
            column, is treated as a binary classification output. A 2D array with more
            than two columns is treated as one score column per class.

    Returns:
        numpy.ndarray: A 2D array of probabilities. For a 1D or single-column input, the
            two columns are [1 - p, p], with p the sigmoid of the input scores. For more
            than two columns, each row is the element-wise sigmoid of the scores divided
            by its row sum.
    """
    # If the prediction array is 1D, which happens with some models such as XGBClassifier or
    # LogisticRegression models, we have a binary classification problem
    n_classes = y_preds.shape[1] if y_preds.ndim > 1 and y_preds.shape[1] > 1 else 2

    # Both the binary and the multi-class cases start with an element-wise sigmoid. The `[0]`
    # picks the array out of the value returned by the helper
    y_preds = numpy_sigmoid(y_preds)[0]

    if n_classes == 2:
        # A 1D array cannot be concatenated along axis 1, so make it a single column first
        if y_preds.ndim == 1:
            y_preds = y_preds.reshape(-1, 1)

        # If there is a single column of probabilities p, transform the output into a 2D
        # array [1-p, p]
        if y_preds.shape[1] == 1:
            y_preds = numpy.concatenate((1 - y_preds, y_preds), axis=1)

        # NOTE(review): a two-column input is returned as plain sigmoid values, without
        # normalization, exactly as before - confirm this case is expected by callers

    # Else, normalize the sigmoid values so that each row sums to one (this is not a softmax)
    else:
        # `keepdims=True` keeps the row sums as a column, so that the division broadcasts
        # row-wise: dividing a (n, k) array by a (n,) vector either fails or, when n == k,
        # silently normalizes along the wrong axis
        y_preds = y_preds / y_preds.sum(axis=1, keepdims=True)

    return y_preds

def get_sklearn_params(self, deep: bool = True) -> dict:
# Here, the `get_params` method is the `BaseEstimator.get_params` method from scikit-learn
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3373
Expand Down
26 changes: 4 additions & 22 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,31 +1605,11 @@ def test_predict_correctness(


@pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS)
@pytest.mark.parametrize(
"simulate",
[
pytest.param(False, id="fhe"),
],
)
# N_BITS_LINEAR_MODEL_CRYPTO_PARAMETERS bits is currently the
# limit for finding crypto parameters for linear models, so
# make sure we only compile below that bit-width.
# Additionally, prevent computations in FHE with too many bits.
@pytest.mark.parametrize(
"n_bits",
[
n_bits
for n_bits in N_BITS_WEEKLY_ONLY_BUILDS + N_BITS_REGULAR_BUILDS
if n_bits
< min(N_BITS_LINEAR_MODEL_CRYPTO_PARAMETERS, N_BITS_THRESHOLD_TO_FORCE_EXECUTION_NOT_IN_FHE)
],
)

# pylint: disable=too-many-branches
def test_separated_inference(
model_class,
parameters,
simulate,
n_bits,
load_data,
default_configuration,
is_weekly_option,
Expand All @@ -1638,6 +1618,8 @@ def test_separated_inference(
):
"""Test prediction correctness between clear quantized and FHE simulation or execution."""

n_bits = min(N_BITS_REGULAR_BUILDS)

# KNN can only be compiled with small quantization bit numbers for now
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979
if n_bits > 5 and get_model_name(model_class) == "KNeighborsClassifier":
Expand All @@ -1646,7 +1628,7 @@ def test_separated_inference(
model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option)

# Run the test with more samples during weekly CIs or when using FHE simulation
if is_weekly_option or simulate:
if is_weekly_option:
fhe_samples = 5
else:
fhe_samples = 1
Expand Down

0 comments on commit 5ffaf9a

Please sign in to comment.