From 52f27a54109d22629c9f4834c056275f5902673e Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 18:35:52 -0700 Subject: [PATCH] [CI, testing] add SVM C and nu parameter check prevents seg fault sklearn < 1.2 (#1930) (#1947) * Update svc.py * Update svr.py * Update nusvc.py * Update nusvc.py * Update nusvr.py * Update svc.py * match sklearn error messages * add comments * formatting * Update sklearnex/svm/svc.py Co-authored-by: ethanglaser <42726565+ethanglaser@users.noreply.github.com> * Update sklearnex/svm/nusvr.py Co-authored-by: ethanglaser <42726565+ethanglaser@users.noreply.github.com> --------- Co-authored-by: ethanglaser <42726565+ethanglaser@users.noreply.github.com> (cherry picked from commit 04dbc5e1a412ac76f3578f76508df87e46487910) Co-authored-by: Ian Faust --- sklearnex/svm/nusvc.py | 11 +++++++++++ sklearnex/svm/nusvr.py | 11 +++++++++++ sklearnex/svm/svc.py | 11 +++++++++++ sklearnex/svm/svr.py | 11 +++++++++++ 4 files changed, 44 insertions(+) diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index d892dd598e..421546a203 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -83,6 +83,17 @@ def __init__( def fit(self, X, y, sample_weight=None): if sklearn_check_version("1.2"): self._validate_params() + elif self.nu <= 0 or self.nu > 1: + # else if added to correct issues with + # sklearn tests: + # svm/tests/test_sparse.py::test_error + # svm/tests/test_svm.py::test_bad_input + # for sklearn versions < 1.2 (i.e. without + # validate_params parameter checking) + # Without this, a segmentation fault with + # Windows fatal exception: access violation + # occurs + raise ValueError("nu <= 0 or nu > 1") if sklearn_check_version("1.0"): self._check_feature_names(X, reset=True) dispatch( diff --git a/sklearnex/svm/nusvr.py b/sklearnex/svm/nusvr.py index 5c2c1a1dee..2945175398 100644 --- a/sklearnex/svm/nusvr.py +++ b/sklearnex/svm/nusvr.py @@ -65,6 +65,17 @@ def __init__( def fit(self, X, y, sample_weight=None): if sklearn_check_version("1.2"): self._validate_params() + elif self.nu <= 0 or self.nu > 1: + # else if added to correct issues with + # sklearn tests: + # svm/tests/test_sparse.py::test_error + # svm/tests/test_svm.py::test_bad_input + # for sklearn versions < 1.2 (i.e. without + # validate_params parameter checking) + # Without this, a segmentation fault with + # Windows fatal exception: access violation + # occurs + raise ValueError("nu <= 0 or nu > 1") if sklearn_check_version("1.0"): self._check_feature_names(X, reset=True) dispatch( diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py index b0e44a5bb1..337f44ba4b 100644 --- a/sklearnex/svm/svc.py +++ b/sklearnex/svm/svc.py @@ -85,6 +85,17 @@ def __init__( def fit(self, X, y, sample_weight=None): if sklearn_check_version("1.2"): self._validate_params() + elif self.C <= 0: + # else if added to correct issues with + # sklearn tests: + # svm/tests/test_sparse.py::test_error + # svm/tests/test_svm.py::test_bad_input + # for sklearn versions < 1.2 (i.e. without + # validate_params parameter checking) + # Without this, a segmentation fault with + # Windows fatal exception: access violation + # occurs + raise ValueError("C <= 0") if sklearn_check_version("1.0"): self._check_feature_names(X, reset=True) dispatch( diff --git a/sklearnex/svm/svr.py b/sklearnex/svm/svr.py index ed6c5baa23..1b16a5aa7e 100644 --- a/sklearnex/svm/svr.py +++ b/sklearnex/svm/svr.py @@ -65,6 +65,17 @@ def __init__( def fit(self, X, y, sample_weight=None): if sklearn_check_version("1.2"): self._validate_params() + elif self.C <= 0: + # else if added to correct issues with + # sklearn tests: + # svm/tests/test_sparse.py::test_error + # svm/tests/test_svm.py::test_bad_input + # for sklearn versions < 1.2 (i.e. without + # validate_params parameter checking) + # Without this, a segmentation fault with + # Windows fatal exception: access violation + # occurs + raise ValueError("C <= 0") if sklearn_check_version("1.0"): self._check_feature_names(X, reset=True) dispatch(