Commit
Add SHAP calculation to GBT regression (#1399)
* rename gbt_convertors.pyx -> *.py
* use dataclasses for Node
* isort/black
* refactor get_gbt_model_from_xgboost() with improved Node classes
* refactor: put new NodeList and related classes in module namespace
* add cover to gbt regression nodes
* simplify xgboost tree parser
* Refactor gbt model parser for speed and add tests
* feat: provide pred_contribs/pred_interactions kwargs in GBT _predict_regression
* re-enable mb tests
* Return pred_interactions in correct shape
* clean up inference APIs and versioning
* Fix SHAP interaction output shape
* align tree clf/reg APIs
* update copyright
* fix: remove loading xgb only for a type hint
* Update LightGBM Model Builder for TreeView
* chore: rename model builders test file and remove ancient version check
* Start cleaning up model builder tests, fix some failing tests
* Add exhaustive model builder testing
* chore: merge test_xgboost_mb.py and test_model_builders.py
* fix: support XGBoost models trained with early stopping
* refactor: simplify early stopping test case
* fix: add SHAP to requirements-test
* chore: update oneDAL version for _gbt_inference_api_versision 2
* Add GBT model builder API version descriptions
* Fix typo in pred_interactions test
* fix: remove local backup file
* fix: remove local backup file
* Start work on fixing LightGBM model builder test cases
* Properly use XGBoost's base_score parameter
* fix: parse enums declared with bit shifting
* refactor: SHAP prediction replace boolean parameters with DAAL_UINT64 flag
* chore: fix typos and add another classification test
* feat: add more tests for LightGBM models
* fix LightGBM model conversion
* feat: provide XGBoost SHAP example
* clean imports
* Include SHAP description
* typos
* chore: move model builder examples to dedicated directory
* rename model_builders -> mb
* Apply suggestions from code review
  Co-authored-by: Alexandra <alexandra.epanchinzeva@intel.com>
* add reg/clf leaf node wrappers for backwards compatibility
* fix: model retrieve API
* chore: remove requirements-test-optional.txt
* Update CODEOWNERS after removing requirements-test-optional.txt
* fix: add new mb path to test_examples sys.path
* feat: add xgboost_shap example to testing for 2024.0.1
* fix: add shap to test requirements
* Skip SHAP checks for older versions
* fixup: skip shap tests if *not* daal_check_version(...)
* Let main() accept args and kwargs
* fix: only request resultsToCompute with compatible versions
* fixup: better error reporting
* use pytest for main()
* fix: use unittest.skipIf
* fix: typo 2023 -> 2024
* Drop 3.12 requirement
  Co-authored-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
* cleanup after rebase
* Skip SHAP install & tests on 3.12
* Install catboost on all python versions
* Skip catboost install & tests on 3.12
* chore: add fixmes for catboost and shap support on 3.12

---------

Co-authored-by: Alexandra <alexandra.epanchinzeva@intel.com>
Co-authored-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
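Several of the commits above gate the SHAP tests on the installed library version ("Skip SHAP checks for older versions", "fix: use unittest.skipIf"). The pattern can be sketched with the standard library alone. This is only an illustration: `version_at_least`, `INSTALLED_VERSION`, and `SHAP_MIN_VERSION` are hypothetical stand-ins, not the `daal_check_version(...)` helper the commit message mentions, and the version numbers are assumed values.

```python
import unittest

# Hypothetical values for illustration only; a real test would query the
# installed library instead of hard-coding these tuples.
INSTALLED_VERSION = (2023, 2)
SHAP_MIN_VERSION = (2024, 0)


def version_at_least(required, installed=INSTALLED_VERSION):
    """Return True if the installed version tuple is >= the required one."""
    return tuple(installed) >= tuple(required)


class TestShapExample(unittest.TestCase):
    # The condition is evaluated once, when the class body is executed.
    @unittest.skipIf(
        not version_at_least(SHAP_MIN_VERSION),
        "SHAP prediction requires a newer library version",
    )
    def test_shap_contribs(self):
        # On a supported version, the SHAP comparison would go here.
        self.assertTrue(True)

    def test_unrelated_check(self):
        # Tests without the decorator run regardless of the version.
        self.assertEqual(1 + 1, 2)


def run_suite():
    """Run the test case and return the collected unittest.TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestShapExample)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

With the assumed versions above, `run_suite()` reports one skipped test and no failures. Skipping, rather than failing, keeps the suite green on older installations that lack the feature.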
1 parent 1fe0df1, commit 6d95372
Showing 22 changed files with 1,776 additions and 815 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,80 @@
# ==============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# daal4py Gradient Boosting Regression model creation and SHAP value
# prediction example

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

import daal4py as d4p


def main(*args, **kwargs):
    # create data
    X, y = make_regression(n_samples=10000, n_features=10, random_state=42)
    X_train, X_test, y_train, _ = train_test_split(X, y, random_state=42)

    # train the model
    xgb_model = xgb.XGBRegressor(
        max_depth=6, n_estimators=100, random_state=42, base_score=0.5
    )
    xgb_model.fit(X_train, y_train)

    # Conversion to daal4py
    daal_model = d4p.mb.convert_model(xgb_model.get_booster())

    # SHAP contributions
    daal_contribs = daal_model.predict(X_test, pred_contribs=True)

    # SHAP interactions
    daal_interactions = daal_model.predict(X_test, pred_interactions=True)

    # XGBoost reference values
    xgb_contribs = xgb_model.get_booster().predict(
        xgb.DMatrix(X_test), pred_contribs=True, validate_features=False
    )
    xgb_interactions = xgb_model.get_booster().predict(
        xgb.DMatrix(X_test), pred_interactions=True, validate_features=False
    )

    return (
        daal_contribs,
        daal_interactions,
        xgb_contribs,
        xgb_interactions,
    )


if __name__ == "__main__":
    daal_contribs, daal_interactions, xgb_contribs, xgb_interactions = main()
    print(f"XGBoost SHAP contributions shape: {xgb_contribs.shape}")
    print(f"daal4py SHAP contributions shape: {daal_contribs.shape}")

    print(f"XGBoost SHAP interactions shape: {xgb_interactions.shape}")
    print(f"daal4py SHAP interactions shape: {daal_interactions.shape}")

    contribution_rmse = np.sqrt(
        np.mean((daal_contribs.reshape(-1, 1) - xgb_contribs.reshape(-1, 1)) ** 2)
    )
    print(f"SHAP contributions RMSE: {contribution_rmse:.2e}")

    interaction_rmse = np.sqrt(
        np.mean((daal_interactions.reshape(-1, 1) - xgb_interactions.reshape(-1, 1)) ** 2)
    )
    print(f"SHAP interactions RMSE: {interaction_rmse:.2e}")
This file was deleted.