Skip to content

Commit

Permalink
fix: use unittest.skipIf
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Oct 26, 2023
1 parent 41cda26 commit dcba3af
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions tests/test_model_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import catboost as cb
import lightgbm as lgbm
import numpy as np
import pytest
import shap
import xgboost as xgb
from sklearn.datasets import (
Expand Down Expand Up @@ -111,7 +110,7 @@ def test_breast_cancer_without_intercept(self):
self.assertTrue(np.allclose(pred_daal, pred_sklearn))


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostRegressionModelBuilder(unittest.TestCase):
@classmethod
def setUpClass(cls, base_score=0.5):
Expand Down Expand Up @@ -189,22 +188,22 @@ def test_model_predict_shap_contribs_missing_values(self):


# duplicate all tests for bae_score=0.0
@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostRegressionModelBuilder_base_score0(XGBoostRegressionModelBuilder):
@classmethod
def setUpClass(cls):
XGBoostRegressionModelBuilder.setUpClass(0)


# duplicate all tests for bae_score=100
@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostRegressionModelBuilder_base_score100(XGBoostRegressionModelBuilder):
@classmethod
def setUpClass(cls):
XGBoostRegressionModelBuilder.setUpClass(100)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostClassificationModelBuilder(unittest.TestCase):
@classmethod
def setUpClass(cls, base_score=0.5, n_classes=2, objective="binary:logistic"):
Expand Down Expand Up @@ -272,29 +271,29 @@ def test_model_predict_shap_interactions(self):


# duplicate all tests for bae_score=0.3
@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostClassificationModelBuilder_base_score03(XGBoostClassificationModelBuilder):
@classmethod
def setUpClass(cls):
XGBoostClassificationModelBuilder.setUpClass(base_score=0.3)


# duplicate all tests for bae_score=0.7
@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostClassificationModelBuilder_base_score07(XGBoostClassificationModelBuilder):
@classmethod
def setUpClass(cls):
XGBoostClassificationModelBuilder.setUpClass(base_score=0.7)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostClassificationModelBuilder_n_classes5(XGBoostClassificationModelBuilder):
@classmethod
def setUpClass(cls):
XGBoostClassificationModelBuilder.setUpClass(n_classes=5)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostClassificationModelBuilder_n_classes5_base_score03(
XGBoostClassificationModelBuilder
):
Expand All @@ -303,7 +302,7 @@ def setUpClass(cls):
XGBoostClassificationModelBuilder.setUpClass(n_classes=5, base_score=0.3)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostClassificationModelBuilder_objective_logitraw(
XGBoostClassificationModelBuilder
):
Expand Down Expand Up @@ -332,7 +331,7 @@ def test_model_predict_proba(self):
np.testing.assert_allclose(d4p_pred, xgboost_pred, rtol=1e-5)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class LightGBMRegressionModelBuilder(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -410,7 +409,7 @@ def test_model_predict_shap_contribs_missing_values(self):
np.testing.assert_allclose(d4p_pred, lgbm_pred, rtol=1e-6)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class LightGBMClassificationModelBuilder(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -472,7 +471,7 @@ def test_model_predict_shap_contribs_missing_values(self):
m.predict(self.X_nan, pred_contribs=True)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class LightGBMClassificationModelBuilder_binaryClassification(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -542,7 +541,7 @@ def test_model_predict_shap_contribs_missing_values(self):
m.predict(self.X_nan, pred_contribs=True)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class CatBoostRegressionModelBuilder(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -591,7 +590,7 @@ def test_model_predict_shap_contribs(self):
d4p.mb.convert_model(self.cb_model)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class CatBoostClassificationModelBuilder(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -642,7 +641,7 @@ def test_model_predict_shap_contribs(self):
d4p.mb.convert_model(self.cb_model)


@pytest.mark.skipif(shap_not_supported, reason=shap_not_supported_str)
@unittest.skipIf(shap_not_supported, reason=shap_not_supported_str)
class XGBoostEarlyStopping(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
Expand Down Expand Up @@ -759,4 +758,4 @@ def get_dump(self, *_, **kwargs):


if __name__ == "__main__":
pytest.main([__file__])
unittest.main()

0 comments on commit dcba3af

Please sign in to comment.