Skip to content

Commit

Permalink
chore: improve test for single class in FHE training
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed May 7, 2024
1 parent d140407 commit b839c6d
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/sklearn/test_fhe_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def test_fit_error_if_non_binary_targets(n_bits, max_iter, parameter_min_max):


@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 30, 1.0)])
def test_fit_single_target_class(n_bits, max_iter, parameter_min_max):
@pytest.mark.parametrize("use_partial", [True, False])
def test_fit_single_target_class(n_bits, max_iter, parameter_min_max, use_partial):
"""Test that training in FHE on a data-set with a single target class works properly."""

# Model parameters
Expand All @@ -159,9 +160,14 @@ def test_fit_single_target_class(n_bits, max_iter, parameter_min_max):
)

with pytest.warns(UserWarning, match="ONNX Preprocess - Removing mutation from node .*"):
model.fit(x, y, fhe="disable")
if use_partial:
model.partial_fit(x, y, fhe="disable")
else:
model.fit(x, y, fhe="disable")

model.partial_fit(x, y, fhe="disable")
model.compile(x)

model.predict(x)


def test_clear_fit_error_raises():
Expand Down

0 comments on commit b839c6d

Please sign in to comment.