Skip to content

Commit

Permalink
chore: clean and speed up fhe training tests (#724)
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft authored Jun 17, 2024
1 parent e44df6e commit b379903
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/concrete/ml/deployment/fhe_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,6 @@ def deserialize_decrypt_dequantize(

# In training mode, note that this step does not make much sense for now. Still, nothing
# breaks since QuantizedModule don't do anything in post-processing
result = self.model.post_processing(*result)
result_post_processed = self.model.post_processing(*result)

return result
return result_post_processed
2 changes: 0 additions & 2 deletions src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,8 +887,6 @@ def compile(
global_p_error=global_p_error,
verbose=verbose,
single_precision=False,
fhe_simulation=False,
fhe_execution=True,
compress_input_ciphertexts=enable_input_compression,
compress_evaluation_keys=enable_key_compression,
)
Expand Down
2 changes: 0 additions & 2 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,8 +581,6 @@ def compile(
global_p_error=global_p_error,
verbose=verbose,
single_precision=False,
fhe_simulation=False,
fhe_execution=True,
compress_input_ciphertexts=enable_input_compression,
compress_evaluation_keys=enable_key_compression,
)
Expand Down
8 changes: 7 additions & 1 deletion src/concrete/ml/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def __init__(
self.learning_rate_value = 1.0
self.batch_size = 8
self.training_p_error = 0.01
self.training_fhe_configuration = None

self.fit_encrypted = fit_encrypted
self.parameters_range = parameters_range
Expand Down Expand Up @@ -344,10 +345,15 @@ def _get_training_quantized_module(
fit_bias=self.fit_intercept,
)

if self.training_fhe_configuration is None:
configuration = Configuration()
else:
configuration = self.training_fhe_configuration

# Enable the underlying FHE circuit to be composed with itself
# This feature is used in order to be able to iterate in the clear n times without having
# to encrypt/decrypt the weight/bias values between each loop
configuration = Configuration(composable=True, compress_evaluation_keys=True)
configuration.composable = True

composition_mapping = {0: 2, 1: 3}

Expand Down
2 changes: 1 addition & 1 deletion tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def dev_send_clientspecs_and_modelspecs_to_client(self):
@pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS)
@pytest.mark.parametrize("n_bits", [2])
def test_client_server_sklearn_inference(
default_configuration,
model_class,
parameters,
n_bits,
load_data,
default_configuration,
check_is_good_execution_for_cml_vs_circuit,
check_array_equal,
check_float_array_equal,
Expand Down
22 changes: 20 additions & 2 deletions tests/sklearn/test_fhe_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def check_encrypted_fit(
parameters_range,
max_iter,
fit_intercept,
configuration,
check_accuracy=None,
fhe=None,
partial_fit=False,
Expand Down Expand Up @@ -356,6 +357,8 @@ def check_encrypted_fit(
# We need to lower the p-error to make sure that the test passes
model.training_p_error = 1e-15

model.training_fhe_configuration = configuration

if partial_fit:
# Check that we can swap between disable and simulation modes without any impact on the
# final training performance
Expand Down Expand Up @@ -418,7 +421,13 @@ def check_encrypted_fit(
@pytest.mark.parametrize("label_offset", [0, 1])
@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 30, 1.0)])
def test_encrypted_fit_coherence(
fit_intercept, label_offset, n_bits, max_iter, parameter_min_max, check_accuracy
fit_intercept,
label_offset,
n_bits,
max_iter,
parameter_min_max,
check_accuracy,
simulation_configuration,
):
"""Test that encrypted fitting works properly."""

Expand All @@ -439,6 +448,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="disable",
)
Expand All @@ -453,6 +463,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="simulate",
)
Expand All @@ -474,6 +485,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
partial_fit=True,
)
Expand All @@ -496,6 +508,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
warm_fit=True,
init_kwargs=warm_fit_init_kwargs,
Expand All @@ -519,6 +532,7 @@ def test_encrypted_fit_coherence(
parameters_range,
first_iterations,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="simulate",
)
Expand All @@ -542,6 +556,7 @@ def test_encrypted_fit_coherence(
parameters_range,
last_iterations,
fit_intercept,
simulation_configuration,
check_accuracy=check_accuracy,
fhe="simulate",
random_number_generator=rng_coef_init,
Expand Down Expand Up @@ -569,14 +584,15 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
simulation_configuration,
check_accuracy=None,
fhe="simulate",
init_kwargs=early_break_kwargs,
)


@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 2, 1.0)])
def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max):
def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max, default_configuration):
"""Test that encrypted fitting works properly when executed in FHE."""

# Model parameters
Expand All @@ -600,6 +616,7 @@ def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max):
parameters_range,
max_iter,
fit_intercept,
default_configuration,
fhe="disable",
)
)
Expand All @@ -613,6 +630,7 @@ def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max):
parameters_range,
max_iter,
fit_intercept,
default_configuration,
fhe="execute",
)

Expand Down

0 comments on commit b379903

Please sign in to comment.