diff --git a/tests/sklearn/test_fhe_training.py b/tests/sklearn/test_fhe_training.py index b8bd98baa..7b1a89baf 100644 --- a/tests/sklearn/test_fhe_training.py +++ b/tests/sklearn/test_fhe_training.py @@ -134,7 +134,8 @@ def test_fit_error_if_non_binary_targets(n_bits, max_iter, parameter_min_max): @pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 30, 1.0)]) -def test_fit_single_target_class(n_bits, max_iter, parameter_min_max): +@pytest.mark.parametrize("use_partial", [True, False]) +def test_fit_single_target_class(n_bits, max_iter, parameter_min_max, use_partial): """Test that training in FHE on a data-set with a single target class works properly.""" # Model parameters @@ -159,9 +160,14 @@ def test_fit_single_target_class(n_bits, max_iter, parameter_min_max): ) with pytest.warns(UserWarning, match="ONNX Preprocess - Removing mutation from node .*"): - model.fit(x, y, fhe="disable") + if use_partial: + model.partial_fit(x, y, fhe="disable") + else: + model.fit(x, y, fhe="disable") - model.partial_fit(x, y, fhe="disable") + model.compile(x) + + model.predict(x) def test_clear_fit_error_raises():