
fix: SGDClassifier post-processing with multi-class and improve linear models' predict method #585

Merged: 5 commits, Apr 10, 2024

Conversation

@RomanBredehoft (Collaborator) commented Apr 3, 2024:

First step towards a green weekly CI.

There are a few things to note in this PR:

  • The weekly CI was failing because the SGD classifier's post_processing method was not properly integrated: instead, everything was done in predict_proba, which interferes with the way we build and test our API (notably test_separated_inference).
  • I found a more general issue related to linear models' predict method: while trees / QNNs usually do predict_proba + argmax, for linear models scikit-learn does decision_function + argmax. In theory this should not change anything, as predict_proba basically applies a sigmoid/softmax or a normalization, but in practice it made the argmax behave differently because of slight floating-point errors (basically the same as in https://github.com/zama-ai/concrete-ml-internal/issues/3369).
  • While debugging this, I also encountered https://github.com/zama-ai/concrete-ml-internal/issues/4029 and decided to fix it once and for all, otherwise I more or less had to make linear classifiers' output shapes inconsistent with scikit-learn, which seemed like going backward.

Why are we discovering this only now? A few reasons:

  • For linear classifiers: for quite a long time now we have almost never tested our classifiers' predict method because of the float issues mentioned above: we consider that the most important thing is making sure that quantized == quantized. The only place where this seems to happen is the hyper-parameter test, which is where @jfrery detected the issue.
  • For SGD's post-processing: the way we were doing post-processing was fine; the issue only appeared when running the inference in separated steps, as in test_separated_inference. Still, our custom post_processing function was only correct for 1D arrays (tested in regular CIs) with the "log_loss" loss (the default). Only the weekly CI tests 2D arrays, which is where we found the error.

I also took the liberty of cleaning up some parts/tests (no breaking changes).

refs https://github.com/zama-ai/concrete-ml-internal/issues/4030
closes https://github.com/zama-ai/concrete-ml-internal/issues/4344
closes https://github.com/zama-ai/concrete-ml-internal/issues/4252
closes https://github.com/zama-ai/concrete-ml-internal/issues/4029
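The floating-point argmax mismatch described above can be reproduced in miniature (the function names here are hypothetical, not Concrete ML API):

```python
import numpy as np

def predict_from_scores(scores):
    # scikit-learn's approach for linear classifiers: argmax on the raw
    # decision scores
    return np.argmax(scores, axis=-1)

def predict_from_probas(scores):
    # Alternative: map scores through a softmax first, then argmax
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probas = exp / exp.sum(axis=-1, keepdims=True)
    return np.argmax(probas, axis=-1)

# The softmax is monotonic, so both agree in exact arithmetic...
scores = np.array([[0.3, -1.2, 2.5]])
assert predict_from_scores(scores)[0] == predict_from_probas(scores)[0] == 2
# ...but for near-tied scores, rounding inside the softmax can flip the
# argmax, which is why matching scikit-learn (argmax on scores) is safer.
```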

@RomanBredehoft RomanBredehoft requested a review from a team as a code owner April 3, 2024 16:31
@cla-bot cla-bot bot added the cla-signed label Apr 3, 2024
@@ -1605,31 +1605,11 @@ def test_predict_correctness(


@pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS)
@pytest.mark.parametrize(
RomanBredehoft (Collaborator Author): There's no point in simulating here (we are testing encrypt + run + decrypt).

# make sure we only compile below that bit-width.
# Additionally, prevent computations in FHE with too many bits
@pytest.mark.parametrize(
"n_bits",
RomanBredehoft (Collaborator Author): There's no real point in testing multiple n_bits values, as we are only testing the API and comparing FHE vs simulation.

@RomanBredehoft RomanBredehoft force-pushed the fix/sgd_classifier_post_processing_4252 branch from 5ffaf9a to 8baa5e9 Compare April 4, 2024 08:39
@RomanBredehoft changed the title from "fix: SGDClassifier post-processing with multi-class" to "fix: SGDClassifier post-processing with multi-class and improve linear models' predict method" Apr 4, 2024
@@ -1735,38 +1735,6 @@ class SklearnLinearRegressorMixin(SklearnLinearModelMixin, sklearn.base.Regresso
"""


class SklearnSGDRegressorMixin(SklearnLinearRegressorMixin):
RomanBredehoft (Collaborator Author): Just moved it for better readability.

@@ -1815,6 +1783,48 @@ def predict_proba(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) ->
y_proba = self.post_processing(y_logits)
return y_proba

# In scikit-learn, the argmax is done on the scores directly, not the probabilities
def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:
RomanBredehoft (Collaborator Author): This fixes https://github.com/zama-ai/concrete-ml-internal/issues/4344 (the reason is explained in the main comment, but basically this is how scikit-learn does it).
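A minimal sketch of the argmax-on-scores predict logic, distilled into a standalone function (linear_predict is a hypothetical name; the actual method works on self.decision_function and self.classes_ as in scikit-learn):

```python
import numpy

def linear_predict(y_scores, classes):
    """Sketch of scikit-learn's predict for linear classifiers:
    argmax on the decision scores, not on the probabilities."""
    if y_scores.ndim == 1:
        # Binary case: a single score per sample, thresholded at 0
        indices = (y_scores > 0).astype(int)
    else:
        indices = numpy.argmax(y_scores, axis=1)
    return classes[indices]

classes = numpy.array([0, 1, 2])
scores = numpy.array([[0.1, 0.4, -0.2], [2.0, -1.0, 3.0]])
print(linear_predict(scores, classes))  # [1 2]
```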

@@ -253,27 +252,12 @@ def __init__(
"Setting 'parameter_range' is mandatory if FHE training is enabled "
f"({fit_encrypted=}). Got {parameters_range=}"
)

def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
RomanBredehoft (Collaborator Author): This post_processing was wrong, and I moved it below.

@@ -835,61 +819,24 @@ def partial_fit(
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4184
raise NotImplementedError("Partial fit is not currently supported for clear training.")

# This method is taken directly from scikit-learn
def _predict_proba_lr(self, X: Data, fhe: Union[FheMode, str]) -> numpy.ndarray:
RomanBredehoft (Collaborator Author): This part should be included in our post_processing method.


def predict_proba(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:
"""Probability estimates.
def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
RomanBredehoft (Collaborator Author): Basically, we don't need to redefine the predict_proba method as it is in scikit-learn; we only need to redefine post_processing with scikit-learn's implementation.
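For reference, scikit-learn's _predict_proba_lr applies a sigmoid to the decision scores, then stacks [1 - p, p] in the binary case or normalizes each row in the multi-class (one-vs-rest) case. A standalone sketch of that logic as a post-processing step, under the "log_loss" assumption (post_processing_log_loss is a hypothetical name):

```python
import numpy

def sigmoid(x):
    return 1.0 / (1.0 + numpy.exp(-x))

def post_processing_log_loss(y_logits):
    """Sketch of scikit-learn's _predict_proba_lr as a post-processing step."""
    probs = sigmoid(y_logits)
    if probs.ndim == 1 or probs.shape[1] == 1:
        # Binary case: one score per sample, stack P(class 0) and P(class 1)
        probs = probs.ravel()
        return numpy.vstack([1 - probs, probs]).T
    # Multi-class (OvR): normalize the per-class sigmoids so rows sum to 1
    return probs / probs.sum(axis=1, keepdims=True)
```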

@@ -133,27 +133,23 @@ def test_fit_error_if_non_binary_targets(n_bits, max_iter, parameter_min_max):
model.partial_fit(x, y, fhe="disable")


@pytest.mark.parametrize("loss", ["log_loss", "modified_huber"])
RomanBredehoft (Collaborator Author): Here we are testing that errors are raised for specific arguments; we don't need all these inputs.

@@ -651,7 +648,11 @@ def check_separated_inference(model, fhe_circuit, x, check_float_array_equal):
is_classifier_or_partial_classifier(model)
and get_model_name(model) != "KNeighborsClassifier"
):
y_pred = numpy.argmax(y_pred, axis=-1)
# For linear classifiers, the argmax is done on the scores directly, not the probabilities
if is_model_class_in_a_list(model, _get_sklearn_linear_models()):
RomanBredehoft (Collaborator Author): As mentioned in the main comment, let's do it like scikit-learn.

@jfrery (Collaborator) previously approved these changes Apr 4, 2024:

Awesome! Thanks a lot for the fixes + cleaning. Great explanations as well.

@RomanBredehoft RomanBredehoft force-pushed the fix/sgd_classifier_post_processing_4252 branch from 9e8b111 to a7b606c Compare April 9, 2024 09:57
@@ -420,7 +421,7 @@ def check_accuracy():
"""Fixture to check the accuracy."""

def check_accuracy_impl(expected, actual, threshold=0.9):
accuracy = numpy.mean(expected == actual)
accuracy = accuracy_score(expected, actual)
RomanBredehoft (Collaborator Author): Better, in my opinion.
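One concrete reason accuracy_score is safer than a bare numpy.mean on an equality comparison: with mismatched shapes, the comparison silently broadcasts instead of failing. A small pure-numpy illustration of the pitfall:

```python
import numpy

expected = numpy.array([0, 1, 1, 0])
actual = numpy.array([0, 1, 0, 0])

# The mean of the elementwise comparison gives plain accuracy: 3/4 here
acc = numpy.mean(expected == actual)
assert acc == 0.75

# But with mismatched shapes, broadcasting silently produces a wrong value:
# comparing (4,) against (4, 1) yields a 4x4 matrix, whose mean is averaged
# over 16 pairwise comparisons instead of 4 samples.
bad = numpy.mean(expected == actual.reshape(-1, 1))
assert bad == 0.5 and bad != acc
# sklearn.metrics.accuracy_score validates its inputs (checking lengths,
# raveling column vectors) rather than silently broadcasting.
```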

y_preds = self.output_quantizers[0].dequant(q_y_preds)

# If the preds have shape (n, 1), squeeze it to shape (n,) like in scikit-learn
if y_preds.ndim == 2 and y_preds.shape[1] == 1:
RomanBredehoft (Collaborator Author): Like in scikit-learn (same for the following ones).
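The squeeze logic shown in the diff can be exercised standalone (squeeze_like_sklearn is a hypothetical helper name for illustration):

```python
import numpy

def squeeze_like_sklearn(y_preds):
    # scikit-learn returns 1D outputs for single-target models, so
    # predictions of shape (n, 1) are squeezed to shape (n,)
    if y_preds.ndim == 2 and y_preds.shape[1] == 1:
        return y_preds.ravel()
    return y_preds

assert squeeze_like_sklearn(numpy.zeros((5, 1))).shape == (5,)
# Multi-class / multi-output shapes are left untouched
assert squeeze_like_sklearn(numpy.zeros((5, 3))).shape == (5, 3)
```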

# "Method 'decision_function' outputs different shapes between scikit-learn and "
# f"Concrete ML in FHE (fhe={fhe})"
# )
assert y_scores_sklearn.shape == y_scores_fhe.shape, (
RomanBredehoft (Collaborator Author): Put back the assert here and below.

@@ -1912,7 +1892,7 @@ def test_rounding_consistency_for_regular_models(
else:
# Check `predict` for regressors
predict_method = model.predict
metric = check_accuracy
metric = check_r2_score
RomanBredehoft (Collaborator Author): Regressors should be tested with the R² score.

@RomanBredehoft RomanBredehoft force-pushed the fix/sgd_classifier_post_processing_4252 branch from 93aaec4 to a7dc084 Compare April 9, 2024 12:23
github-actions bot commented Apr 9, 2024:

Coverage passed ✅

Coverage details

---------- coverage: platform linux, python 3.8.18-final-0 -----------
Name    Stmts   Miss  Cover   Missing
-------------------------------------
TOTAL    7548      0   100%

59 files skipped due to complete coverage.

@RomanBredehoft (Collaborator Author): (For info, the last commits are about fixing https://github.com/zama-ai/concrete-ml-internal/issues/4029.)

@RomanBredehoft RomanBredehoft requested a review from jfrery April 9, 2024 14:53
@jfrery (Collaborator) left a comment:

Thanks!

@RomanBredehoft RomanBredehoft merged commit b097022 into main Apr 10, 2024
11 checks passed
@RomanBredehoft RomanBredehoft deleted the fix/sgd_classifier_post_processing_4252 branch April 10, 2024 12:06