diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index d892dd598e..421546a203 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -83,6 +83,17 @@ def __init__( def fit(self, X, y, sample_weight=None): if sklearn_check_version("1.2"): self._validate_params() + elif self.nu <= 0 or self.nu > 1: + # else if added to correct issues with + # sklearn tests: + # svm/tests/test_sparse.py::test_error + # svm/tests/test_svm.py::test_bad_input + # for sklearn versions < 1.2 (i.e. without + # validate_params parameter checking) + # Without this, a segmentation fault with + # Windows fatal exception: access violation + # occurs + raise ValueError("nu <= 0 or nu > 1") if sklearn_check_version("1.0"): self._check_feature_names(X, reset=True) dispatch( diff --git a/sklearnex/svm/nusvr.py b/sklearnex/svm/nusvr.py index 5c2c1a1dee..2945175398 100644 --- a/sklearnex/svm/nusvr.py +++ b/sklearnex/svm/nusvr.py @@ -65,6 +65,17 @@ def __init__( def fit(self, X, y, sample_weight=None): if sklearn_check_version("1.2"): self._validate_params() + elif self.nu <= 0 or self.nu > 1: + # else if added to correct issues with + # sklearn tests: + # svm/tests/test_sparse.py::test_error + # svm/tests/test_svm.py::test_bad_input + # for sklearn versions < 1.2 (i.e. without + # validate_params parameter checking) + # Without this, a segmentation fault with + # Windows fatal exception: access violation + # occurs + raise ValueError("nu <= 0 or nu > 1") if sklearn_check_version("1.0"): self._check_feature_names(X, reset=True) dispatch( diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py index b0e44a5bb1..337f44ba4b 100644 --- a/sklearnex/svm/svc.py +++ b/sklearnex/svm/svc.py @@ -85,6 +85,17 @@ def __init__( def fit(self, X, y, sample_weight=None): if sklearn_check_version("1.2"): self._validate_params() + elif self.C <= 0: + # else if added to correct issues with + # sklearn tests: + # svm/tests/test_sparse.py::test_error + # svm/tests/test_svm.py::test_bad_input + # for sklearn versions < 1.2 (i.e. without + # validate_params parameter checking) + # Without this, a segmentation fault with + # Windows fatal exception: access violation + # occurs + raise ValueError("C <= 0") if sklearn_check_version("1.0"): self._check_feature_names(X, reset=True) dispatch( diff --git a/sklearnex/svm/svr.py b/sklearnex/svm/svr.py index ed6c5baa23..1b16a5aa7e 100644 --- a/sklearnex/svm/svr.py +++ b/sklearnex/svm/svr.py @@ -65,6 +65,17 @@ def __init__( def fit(self, X, y, sample_weight=None): if sklearn_check_version("1.2"): self._validate_params() + elif self.C <= 0: + # else if added to correct issues with + # sklearn tests: + # svm/tests/test_sparse.py::test_error + # svm/tests/test_svm.py::test_bad_input + # for sklearn versions < 1.2 (i.e. without + # validate_params parameter checking) + # Without this, a segmentation fault with + # Windows fatal exception: access violation + # occurs + raise ValueError("C <= 0") if sklearn_check_version("1.0"): self._check_feature_names(X, reset=True) dispatch(