From dcba3afda5161494b477d415edc78d3644c1719e Mon Sep 17 00:00:00 2001 From: Andreas Huber Date: Wed, 25 Oct 2023 07:28:28 -0700 Subject: [PATCH] fix: use unittest.skipIf --- tests/test_model_builders.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/test_model_builders.py b/tests/test_model_builders.py index 7dca304887..905299f58a 100644 --- a/tests/test_model_builders.py +++ b/tests/test_model_builders.py @@ -19,7 +19,6 @@ import catboost as cb import lightgbm as lgbm import numpy as np -import pytest import shap import xgboost as xgb from sklearn.datasets import ( @@ -111,7 +110,7 @@ def test_breast_cancer_without_intercept(self): self.assertTrue(np.allclose(pred_daal, pred_sklearn)) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostRegressionModelBuilder(unittest.TestCase): @classmethod def setUpClass(cls, base_score=0.5): @@ -189,7 +188,7 @@ def test_model_predict_shap_contribs_missing_values(self): # duplicate all tests for bae_score=0.0 -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostRegressionModelBuilder_base_score0(XGBoostRegressionModelBuilder): @classmethod def setUpClass(cls): @@ -197,14 +196,14 @@ def setUpClass(cls): # duplicate all tests for bae_score=100 -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostRegressionModelBuilder_base_score100(XGBoostRegressionModelBuilder): @classmethod def setUpClass(cls): XGBoostRegressionModelBuilder.setUpClass(100) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostClassificationModelBuilder(unittest.TestCase): @classmethod def setUpClass(cls, base_score=0.5, n_classes=2, objective="binary:logistic"): @@ -272,7 +271,7 @@ def test_model_predict_shap_interactions(self): # duplicate all tests for bae_score=0.3 -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostClassificationModelBuilder_base_score03(XGBoostClassificationModelBuilder): @classmethod def setUpClass(cls): @@ -280,21 +279,21 @@ def setUpClass(cls): # duplicate all tests for bae_score=0.7 -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostClassificationModelBuilder_base_score07(XGBoostClassificationModelBuilder): @classmethod def setUpClass(cls): XGBoostClassificationModelBuilder.setUpClass(base_score=0.7) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostClassificationModelBuilder_n_classes5(XGBoostClassificationModelBuilder): @classmethod def setUpClass(cls): XGBoostClassificationModelBuilder.setUpClass(n_classes=5) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostClassificationModelBuilder_n_classes5_base_score03( XGBoostClassificationModelBuilder ): @@ -303,7 +302,7 @@ def setUpClass(cls): XGBoostClassificationModelBuilder.setUpClass(n_classes=5, base_score=0.3) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostClassificationModelBuilder_objective_logitraw( XGBoostClassificationModelBuilder ): @@ -332,7 +331,7 @@ def test_model_predict_proba(self): np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=1e-5) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class LightGBMRegressionModelBuilder(unittest.TestCase): @classmethod def setUpClass(cls): @@ -410,7 +409,7 @@ def test_model_predict_shap_contribs_missing_values(self): np.testing.assert_allclose(d4p_pred, lgbm_pred, rtol=1e-6) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class LightGBMClassificationModelBuilder(unittest.TestCase): @classmethod def setUpClass(cls): @@ -472,7 +471,7 @@ def test_model_predict_shap_contribs_missing_values(self): m.predict(self.X_nan, pred_contribs=True) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class LightGBMClassificationModelBuilder_binaryClassification(unittest.TestCase): @classmethod def setUpClass(cls): @@ -542,7 +541,7 @@ def test_model_predict_shap_contribs_missing_values(self): m.predict(self.X_nan, pred_contribs=True) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class CatBoostRegressionModelBuilder(unittest.TestCase): @classmethod def setUpClass(cls): @@ -591,7 +590,7 @@ def test_model_predict_shap_contribs(self): d4p.mb.convert_model(self.cb_model) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class CatBoostClassificationModelBuilder(unittest.TestCase): @classmethod def setUpClass(cls): @@ -642,7 +641,7 @@ def test_model_predict_shap_contribs(self): d4p.mb.convert_model(self.cb_model) -@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str) +@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str) class XGBoostEarlyStopping(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -759,4 +758,4 @@ def get_dump(self, *_, **kwargs): if __name__ == "__main__": - pytest.main([__file__]) + unittest.main()