Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SHAP calculation to GBT regression #1399

Merged
merged 64 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
e991dae
rename gbt_convertors.pyx -> *.py
ahuber21 Jul 27, 2023
128aff4
use dataclasses for Node
ahuber21 Jul 27, 2023
db445ed
isort/black
ahuber21 Jul 27, 2023
5025f60
refactor get_gbt_model_from_xgboost() with improved Node classes
ahuber21 Jul 27, 2023
a5695ae
refactor: put new NodeList and related classes in module namespace
ahuber21 Jul 27, 2023
9b72553
add cover to gbt regression nodes
ahuber21 Jul 28, 2023
e6e1984
simplify xgboost tree parser
ahuber21 Aug 1, 2023
87dc4ca
Refactor gbt model parser for speed and add tests
ahuber21 Aug 9, 2023
0aaa508
feat: provide pred_contribs/pred_interactions kwargs in GBT _predict_…
ahuber21 Aug 11, 2023
caab075
re-enable mb tests
ahuber21 Aug 11, 2023
b457e11
Return pred_interactions in correct shape
ahuber21 Sep 12, 2023
b2edef4
clean up inference APIs and versioning
ahuber21 Sep 27, 2023
be077cb
Fix SHAP interaction output shape
ahuber21 Sep 27, 2023
bf28b08
align tree clf/reg APIs
ahuber21 Sep 27, 2023
fdec263
update copyright
ahuber21 Sep 28, 2023
26333a4
fix: remove loading xgb only for a type hint
ahuber21 Sep 29, 2023
c3f463a
Update LightGBM Model Builder for TreeView
ahuber21 Sep 29, 2023
b4206d9
chore: rename model builders test file and remove ancient version check
ahuber21 Oct 2, 2023
f86f72d
Start cleaning up model builder tests, fix some failing tests
ahuber21 Oct 2, 2023
4b4620f
Add exhaustive model builder testing
ahuber21 Oct 4, 2023
d5e0c0c
chore: merge test_xgboost_mb.py and test_model_builders.py
ahuber21 Oct 4, 2023
bdde096
fix: support XGBoost models trained with early stopping
ahuber21 Oct 4, 2023
89fde70
refactor: simplify early stopping test case
ahuber21 Oct 4, 2023
fb7b209
fix: add SHAP to requirements-test
ahuber21 Oct 10, 2023
7a2e892
chore: update oneDAL version for _gbt_inference_api_versision 2
ahuber21 Oct 10, 2023
697622e
Add GBT model builder API version descriptions
ahuber21 Oct 10, 2023
89b37b3
Fix typo in pred_interactions test
ahuber21 Oct 10, 2023
e4ab316
fix: remove local backup file
ahuber21 Oct 10, 2023
0859330
fix: remove local backup file
ahuber21 Oct 10, 2023
bdac1f5
Start work on fixing LightGBM model builder test cases
ahuber21 Oct 10, 2023
cb0d895
Properly use XGBoost's base_score parameter
ahuber21 Oct 12, 2023
4ef7712
fix: parse enums declared with bit shifting
ahuber21 Oct 19, 2023
bba8beb
refactor: SHAP prediction replace boolean parameters with DAAL_UINT64…
ahuber21 Oct 19, 2023
8298633
chore: fix typos and add another classification test
ahuber21 Oct 19, 2023
5765523
feat: add more tests for LightGBM models
ahuber21 Oct 19, 2023
30fbafa
fix LightGBM model conversion
ahuber21 Oct 23, 2023
65fbe16
feat: provide XGBoost SHAP example
ahuber21 Oct 23, 2023
8e89e0a
clean imports
ahuber21 Oct 23, 2023
4a14e86
Include SHAP description
ahuber21 Oct 23, 2023
8c3e0e7
typos
ahuber21 Oct 23, 2023
38559ed
chore: move model builder examples to dedicated directory
ahuber21 Oct 23, 2023
bd78635
rename model_builders -> mb
ahuber21 Oct 24, 2023
53afa0f
Apply suggestions from code review
ahuber21 Oct 24, 2023
11b492e
add reg/clf leaf node wrappers for backwards compatibility
ahuber21 Oct 24, 2023
43b4d2a
fix: model retrieve API
ahuber21 Oct 25, 2023
da24ef9
chore: remove requirements-test-optional.txt
ahuber21 Oct 25, 2023
f311e27
Update CODEOWNERS after removing requirements-test-optional.txt
ahuber21 Oct 25, 2023
523d160
fix: add new mb path to test_examples sys.path
ahuber21 Oct 25, 2023
3c58610
feat: add xgboost_shap example to testing for 2024.0.1
ahuber21 Oct 25, 2023
a7170af
fix: add shap to test requirements
ahuber21 Oct 25, 2023
6a01c63
Skip SHAP checks for older versions
ahuber21 Oct 25, 2023
f7031a1
fixup: skip shap tests if *not* daal_check_version(...)
ahuber21 Oct 25, 2023
e029fc8
Let main() accept args and kwargs
ahuber21 Oct 25, 2023
02aaf33
fix: only request resultsToCompute with compatible versions
ahuber21 Oct 25, 2023
0143067
fixup: better error reporting
ahuber21 Oct 25, 2023
41cda26
use pytest for main()
ahuber21 Oct 25, 2023
dcba3af
fix: use unittest.skipIf
ahuber21 Oct 25, 2023
c8b2a69
fix: typo 2023 -> 2024
ahuber21 Oct 25, 2023
8bc6d7c
Drop 3.12 requirement
ahuber21 Oct 26, 2023
cfe0607
cleanup after rebase
ahuber21 Oct 26, 2023
701a2ff
Skip SHAP install & tests on 3.12
ahuber21 Oct 26, 2023
2a94bed
Install catboost on all python versions
ahuber21 Oct 26, 2023
1142630
Skip catboost install & tests on 3.12
ahuber21 Oct 26, 2023
4f490ea
chore: add fixmes for catboost and shap support on 3.12
ahuber21 Oct 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/pipeline/build-and-test-lnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ steps:
. /usr/share/miniconda/etc/profile.d/conda.sh
conda activate CB
bash .ci/scripts/setup_sklearn.sh $(SKLEARN_VERSION)
pip install --upgrade -r requirements-test.txt -r requirements-test-optional.txt
pip install --upgrade -r requirements-test.txt
pip install $(python .ci/scripts/get_compatible_scipy_version.py)
if [ $(echo $(PYTHON_VERSION) | grep '3.8\|3.9\|3.10') ]; then conda install -q -y -c intel dpnp; fi
pip list
Expand Down
2 changes: 1 addition & 1 deletion .ci/pipeline/build-and-test-mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ steps:
- script: |
source activate CB
bash .ci/scripts/setup_sklearn.sh $(SKLEARN_VERSION)
pip install --upgrade -r requirements-test.txt -r requirements-test-optional.txt
pip install --upgrade -r requirements-test.txt
pip install $(python .ci/scripts/get_compatible_scipy_version.py)
pip list
displayName: 'Install testing requirements'
Expand Down
2 changes: 1 addition & 1 deletion .ci/pipeline/build-and-test-win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ steps:
set PATH=C:\msys64\usr\bin;%PATH%
call activate CB
bash .ci/scripts/setup_sklearn.sh $(SKLEARN_VERSION)
pip install --upgrade -r requirements-test.txt -r requirements-test-optional.txt
pip install --upgrade -r requirements-test.txt
cd ..
for /f "delims=" %%c in ('python s\.ci\scripts\get_compatible_scipy_version.py') do set SCIPY_VERSION=%%c
pip install %SCIPY_VERSION%
Expand Down
2 changes: 1 addition & 1 deletion .ci/pipeline/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ jobs:
conda activate CB
pip install -r dependencies-dev
pip install -r requirements-doc.txt
pip install -r requirements-test.txt -r requirements-test-optional.txt
pip install -r requirements-test.txt
pip install jupyter matplotlib requests
displayName: 'Install requirements'
- script: |
Expand Down
5 changes: 2 additions & 3 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@ requirements-doc.txt @maria-Petrova @napetrov @aepanchi @Alexsandruss
onedal/ @Alexsandruss @samir-nasibli @KulikovNikita
sklearnex/ @Alexsandruss @samir-nasibli @KulikovNikita

# Examples
# Examples
examples/ @maria-Petrova @Alexsandruss @samir-nasibli @napetrov

# Dependencies
setup.py @napetrov @Alexsandruss @samir-nasibli
requirements* @napetrov @Alexsandruss @samir-nasibli @homksei @ahuber21 @ethanglaser
conda-recipe/ @napetrov @Alexsandruss
conda-recipe/ @napetrov @Alexsandruss

# Model builders
*model_builders* @razdoburdin @ahuber21 @avolkov-intel
requirements-test-optional.txt @razdoburdin @ahuber21 @avolkov-intel

# Forests
*ensemble* @ahuber21 @icfaust
Expand Down
54 changes: 49 additions & 5 deletions daal4py/mb/model_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ def _predict_classification(self, X, fptype, resultsToEvaluate):
else:
return predict_result.probabilities

def _predict_regression(self, X, fptype):
def _predict_regression(
self, X, fptype, pred_contribs=False, pred_interactions=False
):
if X.shape[1] != self.n_features_in_:
raise ValueError("Shape of input is different from what was seen in `fit`")

Expand All @@ -212,22 +214,64 @@ def _predict_regression(self, X, fptype):
).format(type(self).__name__)
)

# Prediction
try:
return self._predict_regression_with_results_to_compute(
X, fptype, pred_contribs, pred_interactions
)
except TypeError as e:
if "unexpected keyword argument 'resultsToCompute'" in str(e):
if pred_contribs or pred_interactions:
# SHAP values requested, but not supported by this version
raise TypeError(
f"{'pred_contribs' if pred_contribs else 'pred_interactions'} not supported by this version of daalp4y"
) from e
else:
# unknown type error
raise

# fallback to calculation without `resultsToCompute`
predict_algo = d4p.gbt_regression_prediction(fptype=fptype)
predict_result = predict_algo.compute(X, self.daal_model_)

return predict_result.prediction.ravel()

def _predict_regression_with_results_to_compute(
self, X, fptype, pred_contribs=False, pred_interactions=False
):
"""Assume daal4py supports the resultsToCompute kwarg"""
resultsToCompute = ""
if pred_contribs:
resultsToCompute = "shapContributions"
elif pred_interactions:
resultsToCompute = "shapInteractions"

predict_algo = d4p.gbt_regression_prediction(
fptype=fptype, resultsToCompute=resultsToCompute
)
predict_result = predict_algo.compute(X, self.daal_model_)

if pred_contribs:
return predict_result.prediction.ravel().reshape((-1, X.shape[1] + 1))
elif pred_interactions:
return predict_result.prediction.ravel().reshape(
(-1, X.shape[1] + 1, X.shape[1] + 1)
)
else:
return predict_result.prediction.ravel()


class GBTDAALModel(GBTDAALBaseModel):
def __init__(self):
pass

def predict(self, X):
def predict(self, X, pred_contribs=False, pred_interactions=False):
fptype = getFPType(X)
if self._is_regression:
return self._predict_regression(X, fptype)
return self._predict_regression(X, fptype, pred_contribs, pred_interactions)
else:
if pred_contribs or pred_interactions:
raise NotImplementedError(
f"{'pred_contribs' if pred_contribs else 'pred_interactions'} is not implemented for classification models"
)
return self._predict_classification(X, fptype, "computeClassLabels")

def predict_proba(self, X):
Expand Down
56 changes: 42 additions & 14 deletions doc/daal4py/model-builders.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ Model Builders for the Gradient Boosting Frameworks

Introduction
------------------
Gradient boosting on decision trees is one of the most accurate and efficient
machine learning algorithms for classification and regression.
The most popular implementations of it are:
Gradient boosting on decision trees is one of the most accurate and efficient
machine learning algorithms for classification and regression.
The most popular implementations of it are:

* XGBoost*
* LightGBM*
* CatBoost*

daal4py Model Builders deliver the accelerated
models inference of those frameworks. The inference is performed by the oneDAL GBT implementation tuned
for the best performance on the Intel(R) Architecture.
models inference of those frameworks. The inference is performed by the oneDAL GBT implementation tuned
for the best performance on the Intel(R) Architecture.

Conversion
---------
Expand All @@ -61,22 +61,49 @@ CatBoost::
Classification and Regression Inference
----------------------------------------

The API is the same for classification and regression inference.
Based on the original model passed to the ``convert_model``, ``d4p_prediction`` is either the classification or regression output.
The API is the same for classification and regression inference.
Based on the original model passed to the ``convert_model()``, ``d4p_prediction`` is either the classification or regression output.

::

d4p_prediction = d4p_model.predict(test_data)

Here, the ``predict()`` method of ``d4p_model`` is being used to make predictions on the ``test_data`` dataset.
The ``d4p_prediction`` variable stores the predictions made by the ``predict()`` method.
The ``d4p_prediction`` variable stores the predictions made by the ``predict()`` method.

SHAP Value Calculation for Regression Models
------------------------------------------------------------

SHAP contribution and interaction value calculation are natively supported by models created with daal4py Model Builders.
For these models, the ``predict()`` method takes additional keyword arguments:

::

d4p_model.predict(test_data, pred_contribs=True) # for SHAP contributions
d4p_model.predict(test_data, pred_interactions=True) # for SHAP interactions

The returned prediction has the shape:

* ``(n_rows, n_features + 1)`` for SHAP contributions
* ``(n_rows, n_features + 1, n_features + 1)`` for SHAP interactions
Here, ``n_rows`` is the number of rows (i.e., observations) in
``test_data``, and ``n_features`` is the number of features in the dataset.

The prediction result for SHAP contributions includes a feature attribution value for each feature and a bias term for each observation.

The prediction result for SHAP interactions comprises ``(n_features + 1) x (n_features + 1)`` values for all possible
feature combinations, along with their corresponding bias terms.

.. note:: The shapes of SHAP contributions and interactions are consistent with the XGBoost results.
In contrast, the `SHAP Python package <https://shap.readthedocs.io/en/latest/>`_ drops bias terms, resulting
in SHAP contributions (SHAP interactions) with one fewer column (one fewer column and row) per observation.

Scikit-learn-style Estimators
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can also use the scikit-learn-style classes ``GBTDAALClassifier`` and ``GBTDAALRegressor`` to convert and infer your models. For example:

::
::

from daal4py.sklearn.ensemble import GBTDAALRegressor
reg = xgb.XGBRegressor()
Expand All @@ -88,16 +115,17 @@ Limitations
------------------
Model Builders support only base inference with prediction and probabilities prediction. The functionality is to be extended.
Therefore, there are the following limitations:
- The categorical features are not supported for conversion and prediction.
- The categorical features are not supported for conversion and prediction.
- The multioutput models are not supported for conversion and prediction.
- The tree SHAP calculations are not supported.
- SHAP values can be calculated for regression models only.


Examples
---------------------------------
Model Builders models conversion

- `XGBoost model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/model_builders_xgboost.py>`_
- `SHAP value prediction from an XGBoost model <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/model_builders_xgboost_shap.py>`_
- `LightGBM model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/model_builders_lightgbm.py>`_
- `CatBoost model conversion <https://github.com/intel/scikit-learn-intelex/blob/master/examples/daal4py/model_builders_catboost.py>`_

Expand Down
80 changes: 80 additions & 0 deletions examples/mb/model_builders_xgboost_shap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# ==============================================================================
ahuber21 marked this conversation as resolved.
Show resolved Hide resolved
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# daal4py Gradient Boosting Classification model creation and SHAP value
# prediction example

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

import daal4py as d4p


def main(*ars, **kwargs):
# create data
X, y = make_regression(n_samples=10000, n_features=10, random_state=42)
X_train, X_test, y_train, _ = train_test_split(X, y, random_state=42)

# train the model
xgb_model = xgb.XGBRegressor(
max_depth=6, n_estimators=100, random_state=42, base_score=0.5
)
xgb_model.fit(X_train, y_train)

# Conversion to daal4py
daal_model = d4p.mb.convert_model(xgb_model.get_booster())

# SHAP contributions
daal_contribs = daal_model.predict(X_test, pred_contribs=True)

# SHAP interactions
daal_interactions = daal_model.predict(X_test, pred_interactions=True)

# XGBoost reference values
xgb_contribs = xgb_model.get_booster().predict(
xgb.DMatrix(X_test), pred_contribs=True, validate_features=False
)
xgb_interactions = xgb_model.get_booster().predict(
xgb.DMatrix(X_test), pred_interactions=True, validate_features=False
)

return (
daal_contribs,
daal_interactions,
xgb_contribs,
xgb_interactions,
)


if __name__ == "__main__":
daal_contribs, daal_interactions, xgb_contribs, xgb_interactions = main()
print(f"XGBoost SHAP contributions shape: {xgb_contribs.shape}")
print(f"daal4py SHAP contributions shape: {daal_contribs.shape}")

print(f"XGBoost SHAP interactions shape: {xgb_interactions.shape}")
print(f"daal4py SHAP interactions shape: {daal_interactions.shape}")

contribution_rmse = np.sqrt(
np.mean((daal_contribs.reshape(-1, 1) - xgb_contribs.reshape(-1, 1)) ** 2)
)
print(f"SHAP contributions RMSE: {contribution_rmse:.2e}")

interaction_rmse = np.sqrt(
np.mean((daal_interactions.reshape(-1, 1) - xgb_interactions.reshape(-1, 1)) ** 2)
)
print(f"SHAP interactions RMSE: {interaction_rmse:.2e}")
10 changes: 8 additions & 2 deletions generator/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,14 @@ def parse(self, elem, ctxt):
ctxt.enum = False
return True
regex = (
r"^\s*(\w+)(?:\s*=\s*((\(int\))?\w(\w|:|\s|\+)*))?"
+ r"(\s*,)?\s*((/\*|//).*)?$"
# capture group for value name
r"^\s*(\w+)"
# capture group for value (different possible formats, 123, 0x1, (1 << 5), etc.)
+ r"(?:\s*=\s*((\(int\))?(\w|:|\s|\+|\(?\d+\s*<<\s*\d+\)?)*))?"
# comma after the value, plus possible comments
+ r"(\s*,)?\s*((/\*|//).*)?"
# EOL
+ r"$"
)
me = re.match(regex, elem)
if me and not me.group(1).startswith("last"):
Expand Down
4 changes: 0 additions & 4 deletions requirements-test-optional.txt

This file was deleted.

5 changes: 5 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@ scikit-learn==1.2.2 ; python_version == '3.8'
scikit-learn==1.3.1 ; python_version >= '3.9'
pandas==2.0.1 ; python_version == '3.8'
pandas==2.1.1 ; python_version >= '3.9'
xgboost==1.7.6; python_version <= '3.9'
ahuber21 marked this conversation as resolved.
Show resolved Hide resolved
xgboost==2.0.0; python_version >= '3.10'
lightgbm==4.1.0
catboost==1.2.2; python_version <= '3.11' # FIXME: Add as soon as 3.12 is supported
shap==0.42.1; python_version <= '3.11' # FIXME: Add as soon as 3.12 is supported
Loading
Loading