Skip to content

Commit

Permalink
chore: improve binary classification check in encrypted training
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed May 6, 2024
1 parent 8991ce2 commit 4f5be4f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
5 changes: 4 additions & 1 deletion src/concrete/ml/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ def _fit_encrypted(

assert isinstance(self.classes_, numpy.ndarray)

if len(self.classes_) != 2:
# Allow the training set to only provide a single class. This can happen, for example,
# when running 'partial_fit' on a small batch of values. Even with a single class, the
# model remains binary
if len(self.classes_) not in [1, 2]:
raise NotImplementedError(
f"Only binary classification is currently supported when FHE training is "
f"enabled. Got {len(self.classes_)} labels: {self.classes_}."
Expand Down
43 changes: 36 additions & 7 deletions tests/sklearn/test_fhe_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,15 @@
from concrete.ml.sklearn import SGDClassifier


def get_blob_data(binary_targets=True, scale_input=False, parameters_range=None):
def get_blob_data(n_classes=2, scale_input=False, parameters_range=None):
"""Get the training data."""

n_samples = 1000
n_features = 8

# Determine the number of target classes to generate
centers = 2 if binary_targets else 3

# Generate the input and target values
# pylint: disable-next=unbalanced-tuple-unpacking
x, y = make_blobs(n_samples=n_samples, centers=centers, n_features=n_features)
x, y = make_blobs(n_samples=n_samples, centers=n_classes, n_features=n_features)

# Scale the input values if needed
if scale_input:
Expand Down Expand Up @@ -107,7 +104,7 @@ def test_fit_error_if_non_binary_targets(n_bits, max_iter, parameter_min_max):
parameters_range = (-parameter_min_max, parameter_min_max)

# Generate a data-set with three target classes
x, y = get_blob_data(binary_targets=False)
x, y = get_blob_data(n_classes=3)

with warnings.catch_warnings():

Expand Down Expand Up @@ -136,6 +133,37 @@ def test_fit_error_if_non_binary_targets(n_bits, max_iter, parameter_min_max):
model.partial_fit(x, y, fhe="disable")


@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 30, 1.0)])
def test_fit_single_target_class(n_bits, max_iter, parameter_min_max):
    """Check that encrypted training accepts a data-set made of a single target class.

    A model built with ``fit_encrypted=True`` is trained with ``fit`` and then ``partial_fit``,
    both called with ``fhe="disable"``, on blob data generated with ``n_classes=1``. The test
    succeeds when ``fit`` emits the expected ONNX preprocessing warning and neither call raises.
    """

    # Random seed for the model, and the symmetric (-bound, bound) range handed to it
    seed = numpy.random.randint(0, 2**15)
    bounds = (-parameter_min_max, parameter_min_max)

    # Inputs and targets where every target belongs to the same class
    x, y = get_blob_data(n_classes=1)

    model_kwargs = {
        "n_bits": n_bits,
        "fit_encrypted": True,
        "random_state": seed,
        "parameters_range": bounds,
        "max_iter": max_iter,
    }

    with warnings.catch_warnings():
        # Building a model with `fit_encrypted=True` raises an "experimental feature" warning
        # every time, so it is silenced for the duration of the instantiation
        warnings.filterwarnings("ignore", message="FHE training is an experimental feature.*")
        model = SGDClassifier(**model_kwargs)

    # The first training call is expected to emit the ONNX preprocessing warning
    expected_warning = "ONNX Preprocess - Removing mutation from node .*"
    with pytest.warns(UserWarning, match=expected_warning):
        model.fit(x, y, fhe="disable")

    # A follow-up incremental training step on the same single-class data must also go through
    model.partial_fit(x, y, fhe="disable")


def test_clear_fit_error_raises():
"""Test that training in clear using wrong parameters raises proper errors."""

Expand Down Expand Up @@ -285,9 +313,10 @@ def test_clear_fit(
# Model parameters
random_state = numpy.random.randint(0, 2**15)
parameters_range = (-parameter_min_max, parameter_min_max)
n_classes = 2 if binary else 3

# Generate a data-set
x, y = get_blob_data(binary_targets=binary, scale_input=True, parameters_range=parameters_range)
x, y = get_blob_data(n_classes=n_classes, scale_input=True, parameters_range=parameters_range)

random_state = numpy.random.randint(0, 2**15)

Expand Down

0 comments on commit 4f5be4f

Please sign in to comment.