Skip to content

Commit

Permalink
Add SHAP calculation to GBT regression (oneapi-src#2460)
Browse files Browse the repository at this point in the history
* WIP: Add SHAP contributions and interactions

add weights to GbtDecisionTree

Include TreeShap recursion steps

fix buffer overflow in memcpy

Add cover to GbtDecisionTree from model builder

fix some index offsets, correct results for trees up to depth=5

fix: nodeIsDummyLeaf is supposed to check left child

remove some debug statements

chore: apply oneDAL code style

predictContribution wrapper with template dispatching

increase speed by reducing number of cache misses

use thread-local result accessor

backup commit with 13% speedup wrt xgboost

add preShapContributions/predShapInteractions as function parameter

Revert "introduce pred_contribs and pred_interactions SHAP options"

This reverts commit 483aa5b.

remove some debug content

reset env_detect.cpp to origin/master

remove std::vector<float> test by introducing thread-local NumericTable

Move treeshap into separate translation unit - caution: treeShap undefined in libonedal

builds but segfaults

Fix function arguments

respect predShapContributions and predShapInteractions options and check for legal combinations

tmp: work on pred_interactions

* no more segfaults

* fix pred_interactions

* add fast treeshap v1

* Add combinationSum calculation for Fast TreeSHAP v2

* daal_calloc -> daal_malloc

* support shap contribution calculation with Fast TreeSHAP v2

* Consistently add cover to daaal APIs, add output parameters to end of function arguments

* align tree cfl/reg APIs

* restore .gitignore from master

* cleanup for review

* add newline

* remove defaultLeft value that's not needed

* Update model builder examples

* Add backwards-compatible model builder API & deprecate decls

* fix: remove dead code

* fix: simplify number of nodes calculation

* chore: typos and code style

* Fix bazel build

* fix: remove dead member variable in GbtDecisionTree

* feat: add first unit tests for model builders

* revert dal_module back to daal_module

* feat: execute dal unit tests in CI

* reorganize how tests are executed

* add license

* Fix new_ts: nodeIsLeaf/nodeIsDummyLeaf internal usage & classification Parameter

* Update TreeVisitor with node cover value

* remove deprecation version in comment

* remove skipping of XGBoost base_score tree

* feature: proper support for XGBoost's base_score value

* Update code attributions / cite / license

* typo

* chore: remove resIncrement from GBT predict

* Document functions and separate declarations and implementations

* review comments oneapi-src#1

* review comments oneapi-src#2 - fix pImpl idiom

* refactor: replace boolean parameters with DAAL_UINT64 flag

* fix: usage of bias/margin for LightGBM models

* review comments oneapi-src#2

* fixup endless for loop

* use TArray, introduce TreeShapVersion enum

* use TArray where possible

* fix: move data field to implementation class

* Update cpp/daal/include/algorithms/tree_utils/tree_utils.h

Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>

* add typedef to shorten statements

* provide doxygen description of gbt classification funtions

* fix some typos

* consistently use size_t for node indexing; unsigned -> uint32_t

* fix: don't include test in release

* fix multiline comments

---------

Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>
Co-authored-by: Dmitry Razdoburdin <>
  • Loading branch information
ahuber21 and Vika-F authored Oct 27, 2023
1 parent 2850633 commit 25602a5
Show file tree
Hide file tree
Showing 41 changed files with 2,250 additions and 449 deletions.
4 changes: 4 additions & 0 deletions .ci/pipeline/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ jobs:
--test_thread_mode=par
displayName: 'cpp-examples-thread-release-dynamic'
- script: |
bazel test //cpp/daal:tests
displayName: 'daal-tests-algorithms'
- script: |
bazel test //cpp/oneapi/dal:tests \
--config=host \
Expand Down
37 changes: 30 additions & 7 deletions cpp/daal/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
package(default_visibility = ["//visibility:public"])
load("@onedal//dev/bazel:dal.bzl",
"dal_test_suite",
"dal_collect_test_suites",
)
load("@onedal//dev/bazel:daal.bzl",
"daal_module",
"daal_static_lib",
Expand Down Expand Up @@ -28,7 +32,7 @@ daal_module(
deps = select({
"@config//:backend_ref": [ "@openblas//:openblas",
],
"//conditions:default": [ "@micromkl//:mkl_thr",
"//conditions:default": [ "@micromkl//:mkl_thr",
],
}),
)
Expand All @@ -54,7 +58,7 @@ daal_module(
"DAAL_HIDE_DEPRECATED",
],
deps = select({
"@config//:backend_ref": [
"@config//:backend_ref": [
":public_includes",
"@openblas//:headers",
],
Expand Down Expand Up @@ -123,11 +127,11 @@ daal_module(
hdrs = glob(["src/sycl/**/*.h", "src/sycl/**/*.cl"]),
srcs = glob(["src/sycl/**/*.cpp"]),
deps = select({
"@config//:backend_ref": [
"@config//:backend_ref": [
":services",
"@onedal//cpp/daal/src/algorithms/engines:kernel",
],
"//conditions:default": [
"//conditions:default": [
":services",
"@onedal//cpp/daal/src/algorithms/engines:kernel",
"@micromkl_dpc//:headers",
Expand All @@ -146,13 +150,13 @@ daal_module(
"TBB_USE_ASSERT=0",
],
deps = select({
"@config//:backend_ref": [
"@config//:backend_ref": [
":threading_headers",
":mathbackend_thread",
"@tbb//:tbb",
"@tbb//:tbbmalloc",
],
"//conditions:default": [
"//conditions:default": [
":threading_headers",
":mathbackend_thread",
"@tbb//:tbb",
Expand Down Expand Up @@ -269,7 +273,7 @@ daal_dynamic_lib(
],
def_file = select({
"@config//:backend_ref": "src/threading/export_lnx32e.ref.def",
"//conditions:default": "src/threading/export_lnx32e.mkl.def",
"//conditions:default": "src/threading/export_lnx32e.mkl.def",
}),
)

Expand Down Expand Up @@ -316,3 +320,22 @@ filegroup(
":thread_static",
],
)

dal_test_suite(
name = "unit_tests",
framework = "catch2",
srcs = glob([
"test/*.cpp",
]),
)

dal_collect_test_suites(
name = "tests",
root = "@onedal//cpp/daal/src/algorithms",
modules = [
"dtrees/gbt/regression"
],
tests = [
":unit_tests",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -107,49 +107,79 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] classLabel Class label to be predicted
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel)
NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, const double cover)
{
NodeId resId;
_status |= addLeafNodeInternal(treeId, parentId, position, classLabel, resId);
_status |= addLeafNodeInternal(treeId, parentId, position, classLabel, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel)
{
return addLeafNode(treeId, parentId, position, classLabel, 0);
}

/**
* Create Leaf node and add it to certain tree
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] proba Array with probability values for each class
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba)
NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba, const double cover)
{
NodeId resId;
_status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, resId);
_status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba)
{
return addLeafNodeByProba(treeId, parentId, position, proba, 0);
}

/**
* Create Split node and add it to certain tree
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] defaultLeft Behaviour in case of missing values
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue)
NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue,
const int defaultLeft, const double cover)
{
NodeId resId;
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId);
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
const double featureValue)
{
return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0);
}

void setNFeatures(size_t nFeatures)
{
if (!_model.get())
Expand Down Expand Up @@ -184,11 +214,12 @@ class DAAL_EXPORT ModelBuilder
services::Status _status;
services::Status initialize(const size_t nClasses, const size_t nTrees);
services::Status createTreeInternal(const size_t nNodes, TreeId & resId);
services::Status addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, NodeId & res);
services::Status addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel,
const double cover, NodeId & res);
services::Status addLeafNodeByProbaInternal(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba,
NodeId & res);
const double cover, NodeId & res);
services::Status addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
const double featureValue, NodeId & res);
const double featureValue, const int defaultLeft, const double cover, NodeId & res);

private:
size_t _nClasses;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,20 @@ class DAAL_EXPORT Model : public classifier::Model
*/
virtual size_t getNumberOfTrees() const = 0;

/**
* \brief Set the Prediction Bias term
*
* \param value global prediction bias
*/
virtual void setPredictionBias(double value) = 0;

/**
* \brief Get the Prediction Bias term
*
* \return double prediction bias
*/
virtual double getPredictionBias() const = 0;

protected:
Model() : classifier::Model() {}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,25 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] response Response value for leaf node to be predicted
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response, double cover)
{
NodeId resId;
_status |= addLeafNodeInternal(treeId, parentId, position, response, resId);
_status |= addLeafNodeInternal(treeId, parentId, position, response, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
{
return addLeafNode(treeId, parentId, position, response, 0);
}

/**
* Create Split node and add it to certain tree
* \param[in] treeId Tree to which new node is added
Expand All @@ -127,16 +136,25 @@ class DAAL_EXPORT ModelBuilder
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] defaultLeft Behaviour in case of missing values
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft = 0)
NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
{
NodeId resId;
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId, defaultLeft);
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}

/**
* \DAAL_DEPRECATED
*/
DAAL_DEPRECATED NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue)
{
return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0);
}

/**
* Get built model
* \return Model pointer
Expand All @@ -159,9 +177,9 @@ class DAAL_EXPORT ModelBuilder
services::Status _status;
services::Status initialize(size_t nFeatures, size_t nIterations, size_t nClasses);
services::Status createTreeInternal(size_t nNodes, size_t classLabel, TreeId & resId);
services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res);
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, NodeId & res,
int defaultLeft);
services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, const double cover, NodeId & res);
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft,
const double cover, NodeId & res);
services::Status convertModelInternal();
size_t _nClasses;
size_t _nIterations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ enum Method
defaultDense = 0 /*!< Default method */
};

/**
* <a name="DAAL-ENUM-ALGORITHMS__GBT__CLASSIFICATION__PREDICTION__RESULTTOCOMPUTEID"></a>
* Available identifiers to specify the result to compute - results are mutually exclusive
*/
enum ResultToComputeId
{
predictionResult = (1 << 0), /*!< Compute the regular prediction */
shapContributions = (1 << 1), /*!< Compute SHAP contribution values */
shapInteractions = (1 << 2) /*!< Compute SHAP interaction values */
};

/**
* \brief Contains version 2.0 of the Intel(R) oneAPI Data Analytics Library interface.
*/
Expand All @@ -70,9 +81,12 @@ namespace interface2
/* [Parameter source code] */
struct DAAL_EXPORT Parameter : public daal::algorithms::classifier::Parameter
{
Parameter(size_t nClasses = 2) : daal::algorithms::classifier::Parameter(nClasses), nIterations(0) {}
Parameter(const Parameter & o) : daal::algorithms::classifier::Parameter(o), nIterations(o.nIterations) {}
size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */
typedef daal::algorithms::classifier::Parameter super;

Parameter(size_t nClasses = 2) : super(nClasses), nIterations(0), resultsToCompute(predictionResult) {}
Parameter(const Parameter & o) : super(o), nIterations(o.nIterations), resultsToCompute(o.resultsToCompute) {}
size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */
DAAL_UINT64 resultsToCompute; /*!< 64 bit integer flag that indicates the results to compute */
};
/* [Parameter source code] */
} // namespace interface2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,20 @@ class DAAL_EXPORT Model : public algorithms::regression::Model
*/
virtual size_t getNumberOfTrees() const = 0;

/**
* \brief Set the Prediction Bias term
*
* \param value global prediction bias
*/
virtual void setPredictionBias(double value) = 0;

/**
* \brief Get the Prediction Bias term
*
* \return double prediction bias
*/
virtual double getPredictionBias() const = 0;

protected:
Model();
};
Expand Down
Loading

0 comments on commit 25602a5

Please sign in to comment.