Skip to content

Commit

Permalink
refactor: simplify early stopping test case
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Oct 10, 2023
1 parent f5cebb6 commit fc731d2
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tests/test_model_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,20 +495,22 @@ def setUpClass(cls) -> None:
"learning_rate": 0.3,
"num_class": num_classes,
"early_stopping_rounds": 5,
"verbose_eval": False,
}

cls.xgb_clf = xgb.XGBClassifier(**params)
cls.xgb_clf.fit(X_train, y_train, eval_set=[(cls.X_test, cls.y_test)])
cls.xgb_clf.fit(
X_train, y_train, eval_set=[(cls.X_test, cls.y_test)], verbose=False
)
cls.daal_model = d4p.mb.convert_model(cls.xgb_clf.get_booster())

def test_early_stopping(self):
xgb_prediction = self.xgb_clf.predict(self.X_test)
xgb_proba = self.xgb_clf.predict_proba(self.X_test)
xgb_errors_count = np.count_nonzero(xgb_prediction - np.ravel(self.y_test))

booster = self.xgb_clf.get_booster()
daal_model = d4p.mb.convert_model(booster)
daal_prediction = daal_model.predict(self.X_test)
daal_proba = daal_model.predict_proba(self.X_test)
daal_prediction = self.daal_model.predict(self.X_test)
daal_proba = self.daal_model.predict_proba(self.X_test)
daal_errors_count = np.count_nonzero(daal_prediction - np.ravel(self.y_test))

self.assertTrue(np.absolute(xgb_errors_count - daal_errors_count) == 0)
Expand Down

0 comments on commit fc731d2

Please sign in to comment.