Skip to content

Commit

Permalink
chore: fix pcc
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Apr 9, 2024
1 parent 833eb09 commit a7b606c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
3 changes: 2 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from concrete.fhe.compilation import Circuit, Configuration
from concrete.fhe.mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from sklearn.datasets import make_classification, make_regression
from sklearn.metrics import accuracy_score

from concrete.ml.common.utils import (
SUPPORTED_FLOAT_TYPES,
Expand Down Expand Up @@ -420,7 +421,7 @@ def check_accuracy():
"""Fixture to check the accuracy."""

def check_accuracy_impl(expected, actual, threshold=0.9):
accuracy = numpy.mean(expected == actual)
accuracy = accuracy_score(expected, actual)
assert accuracy >= threshold, f"Accuracy of {accuracy} is not high enough ({threshold})."

return check_accuracy_impl
Expand Down
10 changes: 7 additions & 3 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,13 +1783,17 @@ def predict_proba(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) ->
y_proba = self.post_processing(y_logits)
return y_proba

# In scikit-learn, the argmax is done on the scores directly, not the probabilities
# In scikit-learn, the argmax is done on the logits directly, not the probabilities
def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:
# Compute the predicted scores
y_proba = self.decision_function(X, fhe=fhe)
y_logits = self.decision_function(X, fhe=fhe)

# Retrieve the class with the highest score
y_preds = numpy.argmax(y_proba, axis=1)
# If there is a single dimension, only compare the scores to 0
if y_logits.ndim == 1 or y_logits.shape[1] == 1:
y_preds = (y_logits > 0).astype(int)
else:
y_preds = numpy.argmax(y_logits, axis=1)

return self.classes_[y_preds]

Expand Down

0 comments on commit a7b606c

Please sign in to comment.