From 66f6c41a44e8707215c593d45da3344305072f4a Mon Sep 17 00:00:00 2001 From: kcelia Date: Fri, 17 Nov 2023 15:02:46 +0100 Subject: [PATCH] chore: restore test_sklearn --- src/concrete/ml/onnx/onnx_impl_utils.py | 3 ++- tests/sklearn/test_sklearn_models.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/concrete/ml/onnx/onnx_impl_utils.py b/src/concrete/ml/onnx/onnx_impl_utils.py index 748700cf6..0f9653704 100644 --- a/src/concrete/ml/onnx/onnx_impl_utils.py +++ b/src/concrete/ml/onnx/onnx_impl_utils.py @@ -244,7 +244,8 @@ def rounded_comparison( """ assert isinstance(_auto_rounder, AutoRounder) - half = 1 << (_auto_rounder.lsbs_to_remove - 1) if _auto_rounder.is_adjusted else 0 + + half = 1 << (_auto_rounder.lsbs_to_remove - 1) if _auto_rounder.lsbs_to_remove > 0 else 0 rounded_subtraction = round_bit_pattern((x - y) - half, lsbs_to_remove=_auto_rounder) return (operation(rounded_subtraction),) diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py index a50376e04..33cfb7b34 100644 --- a/tests/sklearn/test_sklearn_models.py +++ b/tests/sklearn/test_sklearn_models.py @@ -107,11 +107,11 @@ def get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option): """Prepare the the (x, y) data-set.""" - # if not is_model_class_in_a_list( - # model_class, _get_sklearn_linear_models() + _get_sklearn_neighbors_models() - # ): - # if n_bits in N_BITS_WEEKLY_ONLY_BUILDS and not is_weekly_option: - # pytest.skip("Skipping some tests in non-weekly builds") + if not is_model_class_in_a_list( + model_class, _get_sklearn_linear_models() + _get_sklearn_neighbors_models() + ): + if n_bits in N_BITS_WEEKLY_ONLY_BUILDS and not is_weekly_option: + pytest.skip("Skipping some tests in non-weekly builds") # Get the data-set. The data generation is seeded in load_data. x, y = load_data(model_class, **parameters)