diff --git a/.ci/pipeline/ci.yml b/.ci/pipeline/ci.yml index 811a046f10b..5b1412635fc 100755 --- a/.ci/pipeline/ci.yml +++ b/.ci/pipeline/ci.yml @@ -249,6 +249,10 @@ jobs: --test_thread_mode=par displayName: 'cpp-examples-thread-release-dynamic' + - script: | + bazel test //cpp/daal:tests + displayName: 'daal-tests-algorithms' + - script: | bazel test //cpp/oneapi/dal:tests \ --config=host \ diff --git a/cpp/daal/BUILD b/cpp/daal/BUILD index 5e7f640c782..cea370c9b2d 100644 --- a/cpp/daal/BUILD +++ b/cpp/daal/BUILD @@ -1,4 +1,8 @@ package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_test_suite", + "dal_collect_test_suites", +) load("@onedal//dev/bazel:daal.bzl", "daal_module", "daal_static_lib", @@ -28,7 +32,7 @@ daal_module( deps = select({ "@config//:backend_ref": [ "@openblas//:openblas", ], - "//conditions:default": [ "@micromkl//:mkl_thr", + "//conditions:default": [ "@micromkl//:mkl_thr", ], }), ) @@ -54,7 +58,7 @@ daal_module( "DAAL_HIDE_DEPRECATED", ], deps = select({ - "@config//:backend_ref": [ + "@config//:backend_ref": [ ":public_includes", "@openblas//:headers", ], @@ -123,11 +127,11 @@ daal_module( hdrs = glob(["src/sycl/**/*.h", "src/sycl/**/*.cl"]), srcs = glob(["src/sycl/**/*.cpp"]), deps = select({ - "@config//:backend_ref": [ + "@config//:backend_ref": [ ":services", "@onedal//cpp/daal/src/algorithms/engines:kernel", ], - "//conditions:default": [ + "//conditions:default": [ ":services", "@onedal//cpp/daal/src/algorithms/engines:kernel", "@micromkl_dpc//:headers", @@ -146,13 +150,13 @@ daal_module( "TBB_USE_ASSERT=0", ], deps = select({ - "@config//:backend_ref": [ + "@config//:backend_ref": [ ":threading_headers", ":mathbackend_thread", "@tbb//:tbb", "@tbb//:tbbmalloc", ], - "//conditions:default": [ + "//conditions:default": [ ":threading_headers", ":mathbackend_thread", "@tbb//:tbb", @@ -269,7 +273,7 @@ daal_dynamic_lib( ], def_file = select({ "@config//:backend_ref": 
"src/threading/export_lnx32e.ref.def", - "//conditions:default": "src/threading/export_lnx32e.mkl.def", + "//conditions:default": "src/threading/export_lnx32e.mkl.def", }), ) @@ -316,3 +320,22 @@ filegroup( ":thread_static", ], ) + +dal_test_suite( + name = "unit_tests", + framework = "catch2", + srcs = glob([ + "test/*.cpp", + ]), +) + +dal_collect_test_suites( + name = "tests", + root = "@onedal//cpp/daal/src/algorithms", + modules = [ + "dtrees/gbt/regression" + ], + tests = [ + ":unit_tests", + ], +) diff --git a/cpp/daal/include/algorithms/decision_forest/decision_forest_classification_model_builder.h b/cpp/daal/include/algorithms/decision_forest/decision_forest_classification_model_builder.h index cef25464418..46b5aaa804b 100644 --- a/cpp/daal/include/algorithms/decision_forest/decision_forest_classification_model_builder.h +++ b/cpp/daal/include/algorithms/decision_forest/decision_forest_classification_model_builder.h @@ -107,32 +107,50 @@ class DAAL_EXPORT ModelBuilder * \param[in] parentId Parent node to which new node is added (use noParent for root node) * \param[in] position Position in parent (e.g. 
0 for left and 1 for right child in a binary tree) * \param[in] classLabel Class label to be predicted + * \param[in] cover Cover (Hessian sum) of the node * \return Node identifier */ - NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel) + NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, const double cover) { NodeId resId; - _status |= addLeafNodeInternal(treeId, parentId, position, classLabel, resId); + _status |= addLeafNodeInternal(treeId, parentId, position, classLabel, cover, resId); services::throwIfPossible(_status); return resId; } + /** + * \DAAL_DEPRECATED + */ + DAAL_DEPRECATED NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel) + { + return addLeafNode(treeId, parentId, position, classLabel, 0); + } + /** * Create Leaf node and add it to certain tree * \param[in] treeId Tree to which new node is added * \param[in] parentId Parent node to which new node is added (use noParent for root node) * \param[in] position Position in parent (e.g. 
0 for left and 1 for right child in a binary tree) * \param[in] proba Array with probability values for each class + * \param[in] cover Cover (Hessian sum) of the node * \return Node identifier */ - NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba) + NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba, const double cover) { NodeId resId; - _status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, resId); + _status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, cover, resId); services::throwIfPossible(_status); return resId; } + /** + * \DAAL_DEPRECATED + */ + DAAL_DEPRECATED NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba) + { + return addLeafNodeByProba(treeId, parentId, position, proba, 0); + } + /** * Create Split node and add it to certain tree * \param[in] treeId Tree to which new node is added @@ -140,16 +158,28 @@ class DAAL_EXPORT ModelBuilder * \param[in] position Position in parent (e.g. 
0 for left and 1 for right child in a binary tree) * \param[in] featureIndex Feature index for splitting * \param[in] featureValue Feature value for splitting + * \param[in] defaultLeft Behaviour in case of missing values + * \param[in] cover Cover (Hessian sum) of the node * \return Node identifier */ - NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue) + NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue, + const int defaultLeft, const double cover) { NodeId resId; - _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId); + _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId); services::throwIfPossible(_status); return resId; } + /** + * \DAAL_DEPRECATED + */ + DAAL_DEPRECATED NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, + const double featureValue) + { + return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0); + } + void setNFeatures(size_t nFeatures) { if (!_model.get()) @@ -184,11 +214,12 @@ class DAAL_EXPORT ModelBuilder services::Status _status; services::Status initialize(const size_t nClasses, const size_t nTrees); services::Status createTreeInternal(const size_t nNodes, TreeId & resId); - services::Status addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, NodeId & res); + services::Status addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, + const double cover, NodeId & res); services::Status addLeafNodeByProbaInternal(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba, - NodeId & res); + const double cover, NodeId & res); 
services::Status addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, - const double featureValue, NodeId & res); + const double featureValue, const int defaultLeft, const double cover, NodeId & res); private: size_t _nClasses; diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model.h index 69f348c2baf..f6e0a7e8188 100644 --- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model.h +++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model.h @@ -123,6 +123,20 @@ class DAAL_EXPORT Model : public classifier::Model */ virtual size_t getNumberOfTrees() const = 0; + /** + * \brief Set the Prediction Bias term + * + * \param value global prediction bias + */ + virtual void setPredictionBias(double value) = 0; + + /** + * \brief Get the Prediction Bias term + * + * \return double prediction bias + */ + virtual double getPredictionBias() const = 0; + protected: Model() : classifier::Model() {} }; diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model_builder.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model_builder.h index d72cff6e9f4..b8d1c5da47d 100644 --- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model_builder.h +++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model_builder.h @@ -109,16 +109,25 @@ class DAAL_EXPORT ModelBuilder * \param[in] parentId Parent node to which new node is added (use noParent for root node) * \param[in] position Position in parent (e.g. 
0 for left and 1 for right child in a binary tree) * \param[in] response Response value for leaf node to be predicted + * \param[in] cover Cover (Hessian sum) of the node * \return Node identifier */ - NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response) + NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response, double cover) { NodeId resId; - _status |= addLeafNodeInternal(treeId, parentId, position, response, resId); + _status |= addLeafNodeInternal(treeId, parentId, position, response, cover, resId); services::throwIfPossible(_status); return resId; } + /** + * \DAAL_DEPRECATED + */ + DAAL_DEPRECATED NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response) + { + return addLeafNode(treeId, parentId, position, response, 0); + } + /** * Create Split node and add it to certain tree * \param[in] treeId Tree to which new node is added @@ -127,16 +136,25 @@ class DAAL_EXPORT ModelBuilder * \param[in] featureIndex Feature index for splitting * \param[in] featureValue Feature value for splitting * \param[in] defaultLeft Behaviour in case of missing values + * \param[in] cover Cover (Hessian sum) of the node * \return Node identifier */ - NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft = 0) + NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover) { NodeId resId; - _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId, defaultLeft); + _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId); services::throwIfPossible(_status); return resId; } + /** + * \DAAL_DEPRECATED + */ + DAAL_DEPRECATED NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue) + { + return 
addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0); + } + /** * Get built model * \return Model pointer @@ -159,9 +177,9 @@ class DAAL_EXPORT ModelBuilder services::Status _status; services::Status initialize(size_t nFeatures, size_t nIterations, size_t nClasses); services::Status createTreeInternal(size_t nNodes, size_t classLabel, TreeId & resId); - services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res); - services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, NodeId & res, - int defaultLeft); + services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, const double cover, NodeId & res); + services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, + const double cover, NodeId & res); services::Status convertModelInternal(); size_t _nClasses; size_t _nIterations; diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h index ec6fe54ca9d..58aeea58a3a 100755 --- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h +++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h @@ -56,6 +56,17 @@ enum Method defaultDense = 0 /*!< Default method */ }; +/** + * + * Available identifiers to specify the result to compute - results are mutually exclusive + */ +enum ResultToComputeId +{ + predictionResult = (1 << 0), /*!< Compute the regular prediction */ + shapContributions = (1 << 1), /*!< Compute SHAP contribution values */ + shapInteractions = (1 << 2) /*!< Compute SHAP interaction values */ +}; + /** * \brief Contains version 2.0 of the Intel(R) oneAPI Data Analytics Library interface. 
*/ @@ -70,9 +81,12 @@ namespace interface2 /* [Parameter source code] */ struct DAAL_EXPORT Parameter : public daal::algorithms::classifier::Parameter { - Parameter(size_t nClasses = 2) : daal::algorithms::classifier::Parameter(nClasses), nIterations(0) {} - Parameter(const Parameter & o) : daal::algorithms::classifier::Parameter(o), nIterations(o.nIterations) {} - size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */ + typedef daal::algorithms::classifier::Parameter super; + + Parameter(size_t nClasses = 2) : super(nClasses), nIterations(0), resultsToCompute(predictionResult) {} + Parameter(const Parameter & o) : super(o), nIterations(o.nIterations), resultsToCompute(o.resultsToCompute) {} + size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */ + DAAL_UINT64 resultsToCompute; /*!< 64 bit integer flag that indicates the results to compute */ }; /* [Parameter source code] */ } // namespace interface2 diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model.h index 99607a6d16f..6e5e2b33027 100644 --- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model.h +++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model.h @@ -123,6 +123,20 @@ class DAAL_EXPORT Model : public algorithms::regression::Model */ virtual size_t getNumberOfTrees() const = 0; + /** + * \brief Set the Prediction Bias term + * + * \param value global prediction bias + */ + virtual void setPredictionBias(double value) = 0; + + /** + * \brief Get the Prediction Bias term + * + * \return double prediction bias + */ + virtual double getPredictionBias() const = 0; + protected: Model(); }; diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model_builder.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model_builder.h index 
2a2dd4bcfe9..c9343f03d8d 100644 --- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model_builder.h +++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model_builder.h @@ -107,16 +107,25 @@ class DAAL_EXPORT ModelBuilder * \param[in] parentId Parent node to which new node is added (use noParent for root node) * \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree) * \param[in] response Response value for leaf node to be predicted + * \param[in] cover Cover of the node (sum_hess) * \return Node identifier */ - NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response) + NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response, double cover) { NodeId resId; - _status |= addLeafNodeInternal(treeId, parentId, position, response, resId); + _status |= addLeafNodeInternal(treeId, parentId, position, response, cover, resId); services::throwIfPossible(_status); return resId; } + /** + * \DAAL_DEPRECATED + */ + DAAL_DEPRECATED NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response) + { + return addLeafNode(treeId, parentId, position, response, 0); + } + /** * Create Split node and add it to certain tree * \param[in] treeId Tree to which new node is added @@ -125,16 +134,25 @@ class DAAL_EXPORT ModelBuilder * \param[in] featureIndex Feature index for splitting * \param[in] featureValue Feature value for splitting * \param[in] defaultLeft Behaviour in case of missing values + * \param[in] cover Cover of the node (sum_hess) * \return Node identifier */ - NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft = 0) + NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover) { NodeId resId; - _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, 
featureValue, resId, defaultLeft); + _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId); services::throwIfPossible(_status); return resId; } + /** + * \DAAL_DEPRECATED + */ + DAAL_DEPRECATED NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue) + { + return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0); + } + /** * Get built model * \return Model pointer @@ -157,9 +175,9 @@ class DAAL_EXPORT ModelBuilder services::Status _status; services::Status initialize(size_t nFeatures, size_t nIterations); services::Status createTreeInternal(size_t nNodes, TreeId & resId); - services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res); - services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, NodeId & res, - int defaultLeft); + services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res); + services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, + double cover, NodeId & res); services::Status convertModelInternal(); }; /** @} */ diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_predict_types.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_predict_types.h index 8a9cc5da552..f69591a9a6e 100644 --- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_predict_types.h +++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_predict_types.h @@ -90,6 +90,17 @@ enum ResultId lastResultId = prediction }; +/** + * + * Available identifiers to specify the result to compute - results are mutually exclusive + */ +enum ResultToComputeId +{ + predictionResult = (1 << 0), /*!< 
Compute the regular prediction */ + shapContributions = (1 << 1), /*!< Compute SHAP contribution values */ + shapInteractions = (1 << 2) /*!< Compute SHAP interaction values */ +}; + /** * \brief Contains version 1.0 of the Intel(R) oneAPI Data Analytics Library interface */ @@ -104,9 +115,12 @@ namespace interface1 /* [Parameter source code] */ struct DAAL_EXPORT Parameter : public daal::algorithms::Parameter { - Parameter() : daal::algorithms::Parameter(), nIterations(0) {} - Parameter(const Parameter & o) : daal::algorithms::Parameter(o), nIterations(o.nIterations) {} - size_t nIterations; /*!< Number of iterations of the trained model to be uses for prediction*/ + typedef daal::algorithms::Parameter super; + + Parameter() : super(), nIterations(0), resultsToCompute(predictionResult) {} + Parameter(const Parameter & o) : super(o), nIterations(o.nIterations), resultsToCompute(o.resultsToCompute) {} + size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction*/ + DAAL_UINT64 resultsToCompute; /*!< 64 bit integer flag that indicates the results to compute */ }; /* [Parameter source code] */ diff --git a/cpp/daal/include/algorithms/tree_utils/tree_utils.h b/cpp/daal/include/algorithms/tree_utils/tree_utils.h index c1e98513b5b..d98ef4779cc 100644 --- a/cpp/daal/include/algorithms/tree_utils/tree_utils.h +++ b/cpp/daal/include/algorithms/tree_utils/tree_utils.h @@ -59,6 +59,7 @@ struct DAAL_EXPORT SplitNodeDescriptor : public NodeDescriptor { size_t featureIndex; /*!< Feature used for splitting the node */ double featureValue; /*!< Threshold value at the node */ + double coverValue; /*!< Cover, a sum of the Hessian values of the loss function evaluated at the points flowing through the node */ }; /** diff --git a/cpp/daal/include/algorithms/tree_utils/tree_utils_classification.h b/cpp/daal/include/algorithms/tree_utils/tree_utils_classification.h index a776a5412d9..f8d13c56a9f 100644 ---
a/cpp/daal/include/algorithms/tree_utils/tree_utils_classification.h +++ b/cpp/daal/include/algorithms/tree_utils/tree_utils_classification.h @@ -50,8 +50,9 @@ namespace interface2 */ struct DAAL_EXPORT LeafNodeDescriptor : public NodeDescriptor { - size_t label; /*!< Label to be predicted when reaching the leaf */ - const double * prob; /*!< Probabilities estimation for the leaf */ + size_t label; /*!< Label to be predicted when reaching the leaf */ + const double * prob; /*!< Probabilities estimation for the leaf */ + const double * cover; /*!< Cover (sum_hess) for the leaf */ }; typedef daal::algorithms::tree_utils::TreeNodeVisitor<LeafNodeDescriptor, SplitNodeDescriptor> TreeNodeVisitor; diff --git a/cpp/daal/include/algorithms/tree_utils/tree_utils_regression.h b/cpp/daal/include/algorithms/tree_utils/tree_utils_regression.h index 9890556ea12..8e587f984ad 100644 --- a/cpp/daal/include/algorithms/tree_utils/tree_utils_regression.h +++ b/cpp/daal/include/algorithms/tree_utils/tree_utils_regression.h @@ -50,7 +50,8 @@ namespace interface1 */ struct DAAL_EXPORT LeafNodeDescriptor : public NodeDescriptor { - double response; /*!< Value to be predicted when reaching the leaf */ + double response; /*!< Value to be predicted when reaching the leaf */ + double coverValue; /*!< Cover (sum_hess) for the leaf */ }; typedef daal::algorithms::tree_utils::TreeNodeVisitor<LeafNodeDescriptor, SplitNodeDescriptor> TreeNodeVisitor; diff --git a/cpp/daal/include/services/error_indexes.h b/cpp/daal/include/services/error_indexes.h index 8d5ca7c79e7..9c9c634cd08 100644 --- a/cpp/daal/include/services/error_indexes.h +++ b/cpp/daal/include/services/error_indexes.h @@ -378,6 +378,7 @@ enum ErrorID // GBT error: -30000..-30099 ErrorGbtIncorrectNumberOfTrees = -30000, /*!< Number of trees in the model is not consistent with the number of classes */ ErrorGbtPredictIncorrectNumberOfIterations = -30001, /*!< Number of iterations value in GBT parameter is not consistent with the model */ + ErrorGbtPredictShapOptions = -30002, /*!< For SHAP values, calculate either
contributions or interactions, not both */ // Data management errors: -80001.. ErrorUserAllocatedMemory = -80001, /*!< Couldn't free memory allocated by user */ diff --git a/cpp/daal/src/algorithms/dtrees/dtrees_model.cpp b/cpp/daal/src/algorithms/dtrees/dtrees_model.cpp index 37a06dfe111..32a298dc34b 100644 --- a/cpp/daal/src/algorithms/dtrees/dtrees_model.cpp +++ b/cpp/daal/src/algorithms/dtrees/dtrees_model.cpp @@ -180,17 +180,19 @@ services::Status createTreeInternal(data_management::DataCollectionPtr & seriali return s; } -void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel) +void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel, double cover) { node.featureIndex = featureIndex; node.leftIndexOrClass = classLabel; + node.cover = cover; node.featureValueOrResponse = 0; } -void setNode(DecisionTreeNode & node, int featureIndex, double response) +void setNode(DecisionTreeNode & node, int featureIndex, double response, double cover) { node.featureIndex = featureIndex; node.leftIndexOrClass = 0; + node.cover = cover; node.featureValueOrResponse = response; } @@ -222,7 +224,7 @@ void setProbabilities(const size_t treeId, const size_t nodeId, const size_t res } services::Status addSplitNodeInternal(data_management::DataCollectionPtr & serializationData, size_t treeId, size_t parentId, size_t position, - size_t featureIndex, double featureValue, size_t & res, int defaultLeft) + size_t featureIndex, double featureValue, int defaultLeft, double cover, size_t & res) { const size_t noParent = static_cast<size_t>(-1); services::Status s; @@ -243,6 +245,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria aNode[0].defaultLeft = defaultLeft; aNode[0].leftIndexOrClass = 0; aNode[0].featureValueOrResponse = featureValue; + aNode[0].cover = cover; nodeId = 0; } else if (aNode[parentId].featureIndex < 0) @@ -262,6 +265,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria
aNode[nodeId].defaultLeft = defaultLeft; aNode[nodeId].leftIndexOrClass = 0; aNode[nodeId].featureValueOrResponse = featureValue; + aNode[nodeId].cover = cover; } } if ((aNode[parentId].leftIndexOrClass > 0) && (position == 0)) @@ -274,6 +278,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria aNode[nodeId].defaultLeft = defaultLeft; aNode[nodeId].leftIndexOrClass = 0; aNode[nodeId].featureValueOrResponse = featureValue; + aNode[nodeId].cover = cover; } } if ((aNode[parentId].leftIndexOrClass == 0) && (position == 0)) @@ -296,6 +301,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria aNode[nodeId].defaultLeft = defaultLeft; aNode[nodeId].leftIndexOrClass = 0; aNode[nodeId].featureValueOrResponse = featureValue; + aNode[nodeId].cover = cover; aNode[parentId].leftIndexOrClass = nodeId; if (((nodeId + 1) < nRows) && (aNode[nodeId + 1].featureIndex == __NODE_FREE_ID)) { @@ -332,6 +338,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria aNode[nodeId].defaultLeft = defaultLeft; aNode[nodeId].leftIndexOrClass = 0; aNode[nodeId].featureValueOrResponse = featureValue; + aNode[nodeId].cover = cover; } else { diff --git a/cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h b/cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h index 0eee0e757f3..08c04d677a8 100644 --- a/cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h +++ b/cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h @@ -53,6 +53,7 @@ struct DecisionTreeNode ClassIndexType leftIndexOrClass; //split: left node index, classification leaf: class index ModelFPType featureValueOrResponse; //split: feature value, regression tree leaf: response int defaultLeft; //split: if 1: go to the yes branch for missing value + ModelFPType cover; //split: cover (sum_hess) of the node DAAL_FORCEINLINE bool isSplit() const { return featureIndex != -1; } ModelFPType featureValue() const { return featureValueOrResponse; } }; @@ -60,12 +61,13 
@@ struct DecisionTreeNode class DecisionTreeTable : public data_management::AOSNumericTable { public: - DecisionTreeTable(size_t rowCount = 0) : data_management::AOSNumericTable(sizeof(DecisionTreeNode), 4, rowCount) + DecisionTreeTable(size_t rowCount = 0) : data_management::AOSNumericTable(sizeof(DecisionTreeNode), 5, rowCount) { setFeature(0, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, featureIndex)); setFeature(1, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, leftIndexOrClass)); setFeature(2, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, featureValueOrResponse)); setFeature(3, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, defaultLeft)); + setFeature(4, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, cover)); allocateDataMemory(); } }; @@ -339,19 +341,19 @@ ModelImplType & getModelRef(ModelTypePtr & modelPtr) services::Status createTreeInternal(data_management::DataCollectionPtr & serializationData, size_t nNodes, size_t & resId); -void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel); +void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel, double cover); -void setNode(DecisionTreeNode & node, int featureIndex, double response); +void setNode(DecisionTreeNode & node, int featureIndex, double response, double cover); services::Status addSplitNodeInternal(data_management::DataCollectionPtr & serializationData, size_t treeId, size_t parentId, size_t position, - size_t featureIndex, double featureValue, size_t & res, int defaultLeft = 0); + size_t featureIndex, double featureValue, int defaultLeft, double cover, size_t & res); void setProbabilities(const size_t treeId, const size_t nodeId, const size_t response, const data_management::DataCollectionPtr probTbl, const double * const prob); template static services::Status addLeafNodeInternal(const data_management::DataCollectionPtr & serializationData, const size_t treeId, const size_t parentId, - const size_t position, ClassOrResponseType response, size_t & res, + const size_t 
position, ClassOrResponseType response, double cover, size_t & res, const data_management::DataCollectionPtr probTbl = data_management::DataCollectionPtr(), const double * const prob = nullptr, const size_t nClasses = 0) { @@ -373,7 +375,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio size_t nodeId = 0; if (parentId == noParent) { - setNode(aNode[0], -1, response); + setNode(aNode[0], -1, response, cover); setProbabilities(treeId, 0, response, probTbl, prob); nodeId = 0; } @@ -390,7 +392,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio nodeId = reservedId; if (aNode[reservedId].featureIndex == __NODE_RESERVED_ID) { - setNode(aNode[nodeId], -1, response); + setNode(aNode[nodeId], -1, response, cover); setProbabilities(treeId, nodeId, response, probTbl, prob); } } @@ -400,7 +402,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio nodeId = reservedId; if (aNode[reservedId].featureIndex == __NODE_RESERVED_ID) { - setNode(aNode[nodeId], -1, response); + setNode(aNode[nodeId], -1, response, cover); setProbabilities(treeId, nodeId, response, probTbl, prob); } } @@ -420,7 +422,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio { return services::Status(services::ErrorID::ErrorIncorrectParameter); } - setNode(aNode[nodeId], -1, response); + setNode(aNode[nodeId], -1, response, cover); setProbabilities(treeId, nodeId, response, probTbl, prob); aNode[parentId].leftIndexOrClass = nodeId; if (((nodeId + 1) < nRows) && (aNode[nodeId + 1].featureIndex == __NODE_FREE_ID)) @@ -454,7 +456,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio nodeId = leftEmptyId + 1; if (nodeId < nRows) { - setNode(aNode[nodeId], -1, response); + setNode(aNode[nodeId], -1, response, cover); setProbabilities(treeId, nodeId, response, probTbl, prob); } else diff --git 
a/cpp/daal/src/algorithms/dtrees/forest/classification/df_classification_model_builder.cpp b/cpp/daal/src/algorithms/dtrees/forest/classification/df_classification_model_builder.cpp index c801da2c4cb..ad53f087081 100644 --- a/cpp/daal/src/algorithms/dtrees/forest/classification/df_classification_model_builder.cpp +++ b/cpp/daal/src/algorithms/dtrees/forest/classification/df_classification_model_builder.cpp @@ -72,7 +72,7 @@ services::Status ModelBuilder::createTreeInternal(const size_t nNodes, TreeId & } services::Status ModelBuilder::addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, - NodeId & res) + const double cover, NodeId & res) { decision_forest::classification::internal::ModelImpl & modelImplRef = daal::algorithms::dtrees::internal::getModelRef(_model); @@ -81,7 +81,7 @@ services::Status ModelBuilder::addLeafNodeInternal(const TreeId treeId, const No return services::Status(services::ErrorID::ErrorIncorrectParameter); } return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, classLabel, - res, modelImplRef._probTbl); + cover, res, modelImplRef._probTbl); } bool checkProba(const double * const proba, const size_t nClasses) @@ -104,7 +104,7 @@ bool checkProba(const double * const proba, const size_t nClasses) } services::Status ModelBuilder::addLeafNodeByProbaInternal(const TreeId treeId, const NodeId parentId, const size_t position, - const double * const proba, NodeId & res) + const double * const proba, const double cover, NodeId & res) { decision_forest::classification::internal::ModelImpl & modelImplRef = daal::algorithms::dtrees::internal::getModelRef(_model); @@ -112,17 +112,17 @@ services::Status ModelBuilder::addLeafNodeByProbaInternal(const TreeId treeId, c { return services::Status(services::ErrorID::ErrorIncorrectParameter); } - return 
daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, 0, res, + return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, 0, cover, res, modelImplRef._probTbl, proba, _nClasses); } services::Status ModelBuilder::addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, - const double featureValue, NodeId & res) + const double featureValue, const int defaultLeft, const double cover, NodeId & res) { decision_forest::classification::internal::ModelImpl & modelImplRef = daal::algorithms::dtrees::internal::getModelRef(_model); return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex, - featureValue, res); + featureValue, defaultLeft, cover, res); } } // namespace interface2 diff --git a/cpp/daal/src/algorithms/dtrees/gbt/BUILD b/cpp/daal/src/algorithms/dtrees/gbt/BUILD index 712d778e8ae..9d717dae310 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/BUILD +++ b/cpp/daal/src/algorithms/dtrees/gbt/BUILD @@ -1,4 +1,7 @@ package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_collect_test_suites", +) load("@onedal//dev/bazel:daal.bzl", "daal_module") daal_module( @@ -11,3 +14,11 @@ daal_module( "@onedal//cpp/daal/src/algorithms/dtrees:kernel", ], ) + +dal_collect_test_suites( + name = "tests", + root = "@onedal//cpp/oneapi/dal/algo", + modules = [ + "regression" + ], +) diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model.cpp b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model.cpp index 1e4c6b71234..afb01125ef6 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model.cpp @@ -84,10 +84,21 @@ void 
ModelImpl::traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisito ImplType::traverseBFS(iTree, visitor); } +void ModelImpl::setPredictionBias(double value) +{ + _predictionBias = value; +} + +double ModelImpl::getPredictionBias() const +{ + return _predictionBias; +} + services::Status ModelImpl::serializeImpl(data_management::InputDataArchive * arch) { auto s = algorithms::classifier::Model::serialImpl(arch); arch->set(this->_nFeatures); //algorithms::classifier::internal::ModelInternal + arch->set(this->_predictionBias); return s.add(ImplType::serialImpl(arch)); } @@ -95,6 +106,7 @@ services::Status ModelImpl::deserializeImpl(const data_management::OutputDataArc { auto s = algorithms::classifier::Model::serialImpl(arch); arch->set(this->_nFeatures); //algorithms::classifier::internal::ModelInternal + arch->set(this->_predictionBias); return s.add(ImplType::serialImpl( arch, COMPUTE_DAAL_VERSION(arch->getMajorVersion(), arch->getMinorVersion(), arch->getUpdateVersion()))); } diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_builder.cpp b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_builder.cpp index 1c0462dd7a1..c4c195e400f 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_builder.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_builder.cpp @@ -68,7 +68,7 @@ services::Status ModelBuilder::initialize(size_t nFeatures, size_t nIterations, return s; } -services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t clasLabel, TreeId & resId) +services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t classLabel, TreeId & resId) { gbt::classification::internal::ModelImpl & modelImplRef = daal::algorithms::dtrees::internal::getModelRef(_model); @@ -83,19 +83,22 @@ services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t clasLabe { return Status(ErrorID::ErrorIncorrectParameter); } - 
if (clasLabel > (_nClasses - 1)) + if (_nClasses <= classLabel) { return Status(ErrorID::ErrorIncorrectParameter); } - TreeId treeId = clasLabel * _nIterations; + TreeId treeId = classLabel * _nIterations; const SerializationIface * isEmptyTreeTable = (*(modelImplRef._serializationData))[treeId].get(); - const size_t nTrees = (clasLabel + 1) * _nIterations; + const size_t nTrees = (classLabel + 1) * _nIterations; while (isEmptyTreeTable && treeId < nTrees) { treeId++; isEmptyTreeTable = (*(modelImplRef._serializationData))[treeId].get(); } - if (treeId == nTrees) return Status(ErrorID::ErrorIncorrectParameter); + if (treeId == nTrees) + { + return Status(ErrorID::ErrorIncorrectParameter); + } services::SharedPtr treeTablePtr( new DecisionTreeTable(nNodes)); //DecisionTreeTable* const treeTable = new DecisionTreeTable(nNodes); @@ -120,22 +123,21 @@ services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t clasLabe } } -services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res) +services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res) { gbt::classification::internal::ModelImpl & modelImplRef = daal::algorithms::dtrees::internal::getModelRef(_model); return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, response, - res); - ; + cover, res); } services::Status ModelBuilder::addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, - NodeId & res, int defaultLeft) + int defaultLeft, const double cover, NodeId & res) { gbt::classification::internal::ModelImpl & modelImplRef = daal::algorithms::dtrees::internal::getModelRef(_model); return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex, - 
featureValue, res, defaultLeft); + featureValue, defaultLeft, cover, res); } } // namespace interface1 diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h index 8e9fbbc216e..931780a5f1c 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h @@ -61,10 +61,17 @@ class ModelImpl : public daal::algorithms::gbt::classification::Model, virtual void traverseDFS(size_t iTree, tree_utils::regression::TreeNodeVisitor & visitor) const DAAL_C11_OVERRIDE; virtual void traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisitor & visitor) const DAAL_C11_OVERRIDE; + virtual void setPredictionBias(double value) DAAL_C11_OVERRIDE; + virtual double getPredictionBias() const DAAL_C11_OVERRIDE; + virtual services::Status serializeImpl(data_management::InputDataArchive * arch) DAAL_C11_OVERRIDE; virtual services::Status deserializeImpl(const data_management::OutputDataArchive * arch) DAAL_C11_OVERRIDE; virtual size_t getNumberOfTrees() const DAAL_C11_OVERRIDE; + +private: + /* global bias applied to predictions*/ + double _predictionBias; }; } // namespace internal diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i index 0a79b9d040b..00a9ee5884d 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i @@ -38,6 +38,7 @@ #include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i" #include 
"src/algorithms/objective_function/cross_entropy_loss/cross_entropy_loss_dense_default_batch_kernel.h" #include "src/services/service_algo_utils.h" +#include using namespace daal::internal; using namespace daal::services::internal; @@ -56,111 +57,49 @@ namespace internal { ////////////////////////////////////////////////////////////////////////////////////////// -// PredictBinaryClassificationTask +// PredictBinaryClassificationTask - declaration ////////////////////////////////////////////////////////////////////////////////////////// template class PredictBinaryClassificationTask : public gbt::regression::prediction::internal::PredictRegressionTask { public: typedef gbt::regression::prediction::internal::PredictRegressionTask super; + +public: + /** + * \brief Construct a new Predict Binary Classification Task object + * + * \param x NumericTable observation data + * \param y NumericTable prediction data + * \param prob NumericTable probability data + */ PredictBinaryClassificationTask(const NumericTable * x, NumericTable * y, NumericTable * prob) : super(x, y), _prob(prob) {} - services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp) - { - DAAL_ASSERT(!nIterations || nIterations <= m->size()); - DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data)); - const auto nTreesTotal = (nIterations ? nIterations : m->size()); - this->_aTree.reset(nTreesTotal); - DAAL_CHECK_MALLOC(this->_aTree.get()); - for (size_t i = 0; i < nTreesTotal; ++i) this->_aTree[i] = m->at(i); - const auto nRows = this->_data->getNumberOfRows(); - services::Status s; - DAAL_OVERFLOW_CHECK_BY_MULTIPLICATION(size_t, nRows, sizeof(algorithmFPType)); - //compute raw boosted values - if (this->_res && _prob) - { - WriteOnlyRows resBD(this->_res, 0, nRows); - DAAL_CHECK_BLOCK_STATUS(resBD); - const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) 
}; - algorithmFPType * res = resBD.get(); - WriteOnlyRows probBD(_prob, 0, nRows); - DAAL_CHECK_BLOCK_STATUS(probBD); - algorithmFPType * prob_pred = probBD.get(); - TArray expValPtr(nRows); - algorithmFPType * expVal = expValPtr.get(); - DAAL_CHECK_MALLOC(expVal); - s = super::runInternal(pHostApp, this->_res); - if (!s) return s; - - auto nBlocks = daal::threader_get_threads_number(); - const size_t blockSize = nRows / nBlocks; - nBlocks += (nBlocks * blockSize != nRows); - - daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) { - const size_t startRow = iBlock * blockSize; - const size_t finishRow = (((iBlock + 1) == nBlocks) ? nRows : (iBlock + 1) * blockSize); - daal::internal::MathInst::vExp(finishRow - startRow, res + startRow, expVal + startRow); - - PRAGMA_IVDEP - PRAGMA_VECTOR_ALWAYS - for (size_t iRow = startRow; iRow < finishRow; ++iRow) - { - res[iRow] = label[services::internal::SignBit::get(res[iRow])]; - prob_pred[2 * iRow + 1] = expVal[iRow] / (algorithmFPType(1.) + expVal[iRow]); - prob_pred[2 * iRow] = algorithmFPType(1.) - prob_pred[2 * iRow + 1]; - } - }); - } - else if ((!this->_res) && _prob) - { - WriteOnlyRows probBD(_prob, 0, nRows); - DAAL_CHECK_BLOCK_STATUS(probBD); - algorithmFPType * prob_pred = probBD.get(); - TArray expValPtr(nRows); - algorithmFPType * expVal = expValPtr.get(); - NumericTablePtr expNT = HomogenNumericTableCPU::create(expVal, 1, nRows, &s); - DAAL_CHECK_MALLOC(expVal); - s = super::runInternal(pHostApp, expNT.get()); - if (!s) return s; - - auto nBlocks = daal::threader_get_threads_number(); - const size_t blockSize = nRows / nBlocks; - nBlocks += (nBlocks * blockSize != nRows); - daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) { - const size_t startRow = iBlock * blockSize; - const size_t finishRow = (((iBlock + 1) == nBlocks) ? 
nRows : (iBlock + 1) * blockSize); - daal::internal::MathInst::vExp(finishRow - startRow, expVal + startRow, expVal + startRow); - for (size_t iRow = startRow; iRow < finishRow; ++iRow) - { - prob_pred[2 * iRow + 1] = expVal[iRow] / (algorithmFPType(1.) + expVal[iRow]); - prob_pred[2 * iRow] = algorithmFPType(1.) - prob_pred[2 * iRow + 1]; - } - }); - } - else if (this->_res && (!_prob)) - { - WriteOnlyRows resBD(this->_res, 0, nRows); - DAAL_CHECK_BLOCK_STATUS(resBD); - const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) }; - algorithmFPType * res = resBD.get(); - s = super::runInternal(pHostApp, this->_res); - if (!s) return s; - - for (size_t iRow = 0; iRow < nRows; ++iRow) - { - //probablity is a sigmoid(f) hence sign(f) can be checked - res[iRow] = label[services::internal::SignBit::get(res[iRow])]; - } - } - return s; - } + /** + * \brief Run prediction for the given model + * + * \param m The model for which to run prediction + * \param nIterations Number of iterations + * \param pHostApp HostAppInterface + * \return services::Status + */ + services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp); + +protected: + /** + * \brief Convert the model bias to a margin, considering the softmax activation + * + * \param bias Bias in class score units + * \return algorithmFPType Bias in softmax offset units + */ + algorithmFPType getMarginFromModelBias(algorithmFPType bias) const; protected: NumericTable * _prob; }; ////////////////////////////////////////////////////////////////////////////////////////// -// PredictMulticlassTask +// PredictMulticlassTask - declaration ////////////////////////////////////////////////////////////////////////////////////////// template class PredictMulticlassTask @@ -171,200 +110,603 @@ public: typedef daal::tls ClassesRawBoostedTlsBase; typedef daal::TlsMem ClassesRawBoostedTls; + /** + * \brief Construct a new Predict Multiclass Task 
object + * + * \param x NumericTable observation data + * \param y NumericTable prediction data + * \param prob NumericTable probability data + */ PredictMulticlassTask(const NumericTable * x, NumericTable * y, NumericTable * prob) : _data(x), _res(y), _prob(prob) {} + + /** + * \brief Run prediction for the given model + * + * \param m The model for which to run prediction + * \param nClasses Number of data classes + * \param nIterations Number of iterations + * \param pHostApp HostAppInterface + * \return services::Status + */ services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nClasses, size_t nIterations, services::HostAppIface * pHostApp); protected: - services::Status predictByAllTrees(size_t nTreesTotal, size_t nClasses, const DimType & dim); - + /** Dispatcher type for template dispatching */ template using dispatcher_t = gbt::prediction::internal::PredictDispatcher; + + /** + * \brief Helper boolean constant to populate template dispatcher + * + * \param val A boolean value, known at compile time + */ + template + struct BooleanConstant + { + typedef BooleanConstant type; + }; + + /** + * \brief Run prediction for all trees + * + * \param nTreesTotal Total number of trees in model + * \param nClasses Number of data classes + * \param bias Global prediction bias (e.g. 
base_score in XGBoost) + * \param dim DimType helper + * \return services::Status + */ + services::Status predictByAllTrees(size_t nTreesTotal, size_t nClasses, double bias, const DimType & dim); + + /** + * \brief Make prediction for a number of trees + * + * \param hasUnorderedFeatures Data has unordered features yes/no + * \param hasAnyMissing Data has missing values yes/no + * \param val Output prediction + * \param iFirstTree Index of first ree + * \param nTrees Number of trees included in prediction + * \param nClasses Number of data classes + * \param x Input observation data + * \param dispatcher Template dispatcher helper + */ template void predictByTrees(algorithmFPType * val, size_t iFirstTree, size_t nTrees, size_t nClasses, const algorithmFPType * x, const dispatcher_t & dispatcher); + + /** + * \brief Make prediction for a number of trees leveraging vector instructions + * + * \param hasUnorderedFeatures Data has unordered features yes/no + * \param hasAnyMissing Data has missing values yes/no + * \param vectorBlockSize Vector instruction block size + * \param val Output prediction + * \param iFirstTree Index of first ree + * \param nTrees Number of trees included in prediction + * \param nClasses Number of data classes + * \param x Input observation data + * \param dispatcher Template dispatcher helper + */ template void predictByTreesVector(algorithmFPType * val, size_t iFirstTree, size_t nTrees, size_t nClasses, const algorithmFPType * x, const dispatcher_t & dispatcher); - template - struct BooleanConstant - { - typedef BooleanConstant type; - }; + /** + * \brief Assign a class index to the result + * + * \param res Pointer to result array + * \param val Value of current prediction + * \param iRow + * \param i + * \param nClasses Number of data classes + * \param dispatcher Template dispatcher helper + */ + inline void updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, BooleanConstant dispatcher); + 
+ /** + * \brief Empty function if results assigning is not required. + * + * \param res Pointer to result array + * \param val Value of current prediction + * \param iRow + * \param i + * \param nClasses Number of data classes + * \param dispatcher Template dispatcher helper + */ + inline void updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, BooleanConstant dispatcher); + + /** + * \brief Prepare buff pointer for the next using. All steps reuse the same memory. + * + * \param buff Pointer to a buffer + * \param buf_shift + * \param buf_size + * \param dispatcher + * \return algorithmFPType* Pointer to the input buffer + */ + inline algorithmFPType * updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, BooleanConstant dispatcher); + + /** + * \brief Prepare buff pointer for the next using. Steps have own memory. + * + * \param buff + * \param buf_shift + * \param buf_size + * \param dispatcher + * \return algorithmFPType* + */ + inline algorithmFPType * updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, BooleanConstant dispatcher); + + /** + * \brief Get the total number of nodes in all trees for tree number [1, 2, ... 
nTrees] + * + * \param nTrees Number of trees that contribute to the sum + * \return size_t Number of nodes in all contributing trees + */ + inline size_t getNumberOfNodes(size_t nTrees); + + /** + * \brief Check for missing data + * + * \param x Input observation data + * \param nTrees Number of contributing trees + * \param nRows Number of rows in input observation data to be considered + * \param nColumns Number of columns in input observation data to be considered + * \return true If runtime check for missing is required + * \return false If runtime check for missing is not required + */ + inline bool checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns); + + /** + * \brief Traverse a number of trees to get prediction results + * + * \param hasUnorderedFeatures Data has unordered features yes/no + * \param hasAnyMissing Data has missing values yes/no + * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no) + * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no) + * \param vectorBlockSize Vector instruction block size + * \param nTrees Number of trees contributing to prediction + * \param nClasses Number of data classes + * \param nRows Number of rows in observation data for which prediction is run + * \param nColumns Number of columns in observation data + * \param x Input observation data + * \param buff A pre-allocated buffer for computations + * \param[out] res Output prediction result + */ + template + inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, + algorithmFPType * res); + + /** + * \brief Traverse a number of trees to get prediction results + * + * \param hasAnyMissing Data has missing values yes/no + * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no) + * \param reuseBuffer Re-use 
buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no) + * \param vectorBlockSize Vector instruction block size + * \param nTrees Number of trees contributing to prediction + * \param nClasses Number of data classes + * \param nRows Number of rows in observation data for which prediction is run + * \param nColumns Number of columns in observation data + * \param x Input observation data + * \param buff A pre-allocated buffer for computations + * \param[out] res Output prediction result + */ + template + inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, + algorithmFPType * res); + + /** + * \brief Traverse a number of trees to get prediction results + * + * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no) + * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no) + * \param vectorBlockSize Vector instruction block size + * \param nTrees Number of trees contributing to prediction + * \param nClasses Number of data classes + * \param nRows Number of rows in observation data for which prediction is run + * \param nColumns Number of columns in observation data + * \param x Input observation data + * \param buff A pre-allocated buffer for computations + * \param[out] res Output prediction result + */ + template + inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, + algorithmFPType * res); + + /** + * \brief Traverse a number of trees to get prediction results + * + * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no) + * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no) + * \param vectorBlockSizeFactor Vector instruction block size - recursively decremented until it becomes equal to 
dim.vectorBlockSizeFactor or equal to DimType::minVectorBlockSizeFactor + * \param nTrees Number of trees contributing to prediction + * \param nClasses Number of data classes + * \param nRows Number of rows in observation data for which prediction is run + * \param nColumns Number of columns in observation data + * \param x Input observation data + * \param buff A pre-allocated buffer for computations + * \param[out] res Output prediction result + */ + template + inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, + algorithmFPType * res, const DimType & dim, BooleanConstant keepLooking); + + /** + * \brief Traverse a number of trees to get prediction results + * + * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no) + * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no) + * \param vectorBlockSizeFactor Vector instruction block size - recursively decremented until it becomes equal to dim.vectorBlockSizeFactor or equal to DimType::minVectorBlockSizeFactor + * \param nTrees Number of trees contributing to prediction + * \param nClasses Number of data classes + * \param nRows Number of rows in observation data for which prediction is run + * \param nColumns Number of columns in observation data + * \param x Input observation data + * \param buff A pre-allocated buffer for computations + * \param[out] res Output prediction result + * \param dim DimType helper + * \param keepLooking + */ + template + inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, + algorithmFPType * res, const DimType & dim, BooleanConstant keepLooking); + + /** + * \brief Traverse a number of trees to get prediction results + * + * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no) + * 
\param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no) + * \param nTrees Number of trees contributing to prediction + * \param nClasses Number of data classes + * \param nRows Number of rows in observation data for which prediction is run + * \param nColumns Number of columns in observation data + * \param x Input observation data + * \param buff A pre-allocated buffer for computations + * \param[out] res Output prediction result + * \param dim DimType helper + */ + template + inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, + algorithmFPType * res, const DimType & dim); + + /** + * \brief Traverse a number of trees to get prediction results + * + * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no) + * \param nTrees Number of trees contributing to prediction + * \param nClasses Number of data classes + * \param nRows Number of rows in observation data for which prediction is run + * \param nColumns Number of columns in observation data + * \param x Input observation data + * \param buff A pre-allocated buffer for computations + * \param[out] res Output prediction result + * \param dim DimType helper + */ + template + inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, + algorithmFPType * res, const DimType & dim); - inline void updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, BooleanConstant dispatcher) - { - res[iRow + i] = getMaxClass(val + i * nClasses, nClasses); - } + /** + * \brief Get index of element with maximum value / activation + * + * \param val Pointer to values + * \param nClasses Number of columns per value + * \return size_t The column index with the maximum value + */ + size_t getMaxClass(const algorithmFPType * val, size_t nClasses) 
const; - inline void updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, BooleanConstant dispatcher) - {} +protected: + const NumericTable * _data; + NumericTable * _res; + NumericTable * _prob; + dtrees::internal::FeatureTypes _featHelper; + TArray _aTree; +}; - inline algorithmFPType * updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, BooleanConstant dispatcher) - { - services::internal::service_memset_seq(buff, algorithmFPType(0), buf_size); - return buff; - } +////////////////////////////////////////////////////////////////////////////////////////// +// PredictBinaryClassificationTask - implementation +////////////////////////////////////////////////////////////////////////////////////////// +template +services::Status PredictBinaryClassificationTask::run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, + services::HostAppIface * pHostApp) +{ + DAAL_ASSERT(!nIterations || nIterations <= m->size()); + DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data)); + const auto nTreesTotal = (nIterations ? 
nIterations : m->size()); + this->_aTree.reset(nTreesTotal); + DAAL_CHECK_MALLOC(this->_aTree.get()); + for (size_t i = 0; i < nTreesTotal; ++i) this->_aTree[i] = m->at(i); - inline algorithmFPType * updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, BooleanConstant dispatcher) - { - return buff + buf_shift; - } + const auto nRows = this->_data->getNumberOfRows(); + services::Status s; + DAAL_OVERFLOW_CHECK_BY_MULTIPLICATION(size_t, nRows, sizeof(algorithmFPType)); - inline size_t getNumberOfNodes(size_t nTrees) + // we convert the bias to a margin if it's > 0 + // otherwise the margin is 0 + algorithmFPType margin(0); + if (m->getPredictionBias() > FLT_EPSILON) { - size_t nNodesTotal = 0; - for (size_t iTree = 0; iTree < nTrees; ++iTree) - { - nNodesTotal += this->_aTree[iTree]->getNumberOfNodes(); - } - return nNodesTotal; + margin = getMarginFromModelBias(m->getPredictionBias()); } - inline bool checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns) + // compute raw boosted values + if (this->_res && _prob) { - size_t nLvlTotal = 0; - for (size_t iTree = 0; iTree < nTrees; ++iTree) - { - nLvlTotal += this->_aTree[iTree]->getMaxLvl(); - } - if (nLvlTotal <= nColumns) - { - // Checking is compicated. Better to do it during inferense. - return true; - } - else - { - for (size_t idx = 0; idx < nRows * nColumns; ++idx) + WriteOnlyRows resBD(this->_res, 0, nRows); + DAAL_CHECK_BLOCK_STATUS(resBD); + const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) 
}; + algorithmFPType * res = resBD.get(); + WriteOnlyRows probBD(_prob, 0, nRows); + DAAL_CHECK_BLOCK_STATUS(probBD); + algorithmFPType * prob_pred = probBD.get(); + TArray expValPtr(nRows); + algorithmFPType * expVal = expValPtr.get(); + DAAL_CHECK_MALLOC(expVal); + s = super::runInternal(pHostApp, this->_res, margin, false, false); + if (!s) return s; + + auto nBlocks = daal::threader_get_threads_number(); + const size_t blockSize = nRows / nBlocks; + nBlocks += (nBlocks * blockSize != nRows); + + daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) { + const size_t startRow = iBlock * blockSize; + const size_t finishRow = (((iBlock + 1) == nBlocks) ? nRows : (iBlock + 1) * blockSize); + daal::internal::MathInst::vExp(finishRow - startRow, res + startRow, expVal + startRow); + + PRAGMA_IVDEP + PRAGMA_VECTOR_ALWAYS + for (size_t iRow = startRow; iRow < finishRow; ++iRow) { - if (checkFinitenessByComparison(x[idx])) return true; + res[iRow] = label[services::internal::SignBit::get(res[iRow])]; + prob_pred[2 * iRow + 1] = expVal[iRow] / (algorithmFPType(1.) + expVal[iRow]); + prob_pred[2 * iRow] = algorithmFPType(1.) 
- prob_pred[2 * iRow + 1]; } - } - return false; + }); } - template - inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, - algorithmFPType * res) + else if ((!this->_res) && _prob) { - dispatcher_t dispatcher; - size_t iRow = 0; - for (; iRow + vectorBlockSize <= nRows; iRow += vectorBlockSize) - { - algorithmFPType * val = updateBuffer(buff, iRow * nClasses, nClasses * vectorBlockSize, BooleanConstant()); - predictByTreesVector(val, 0, nTrees, nClasses, x + iRow * nColumns, dispatcher); - for (size_t i = 0; i < vectorBlockSize; ++i) + WriteOnlyRows probBD(_prob, 0, nRows); + DAAL_CHECK_BLOCK_STATUS(probBD); + algorithmFPType * prob_pred = probBD.get(); + TArray expValPtr(nRows); + algorithmFPType * expVal = expValPtr.get(); + NumericTablePtr expNT = HomogenNumericTableCPU::create(expVal, 1, nRows, &s); + DAAL_CHECK_MALLOC(expVal); + s = super::runInternal(pHostApp, expNT.get(), margin, false, false); + if (!s) return s; + + auto nBlocks = daal::threader_get_threads_number(); + const size_t blockSize = nRows / nBlocks; + nBlocks += (nBlocks * blockSize != nRows); + daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) { + const size_t startRow = iBlock * blockSize; + const size_t finishRow = (((iBlock + 1) == nBlocks) ? nRows : (iBlock + 1) * blockSize); + daal::internal::MathInst::vExp(finishRow - startRow, expVal + startRow, expVal + startRow); + for (size_t iRow = startRow; iRow < finishRow; ++iRow) { - updateResult(res, val, iRow, i, nClasses, BooleanConstant()); + prob_pred[2 * iRow + 1] = expVal[iRow] / (algorithmFPType(1.) + expVal[iRow]); + prob_pred[2 * iRow] = algorithmFPType(1.) - prob_pred[2 * iRow + 1]; } - } - for (; iRow < nRows; ++iRow) + }); + } + else if (this->_res && (!_prob)) + { + WriteOnlyRows resBD(this->_res, 0, nRows); + DAAL_CHECK_BLOCK_STATUS(resBD); + const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) 
}; + algorithmFPType * res = resBD.get(); + s = super::runInternal(pHostApp, this->_res, margin, false, false); + if (!s) return s; + + typedef services::internal::SignBit SignBit; + + PRAGMA_IVDEP + for (size_t iRow = 0; iRow < nRows; ++iRow) { - algorithmFPType * val = updateBuffer(buff, iRow * nClasses, nClasses, BooleanConstant()); - predictByTrees(val, 0, nTrees, nClasses, x + iRow * nColumns, dispatcher); - updateResult(res, val, iRow, 0, nClasses, BooleanConstant()); + // probability is a sigmoid(f) hence sign(f) can be checked + const algorithmFPType initial = res[iRow]; + const int sign = SignBit::get(initial); + res[iRow] = label[sign]; } } + return s; +} - template - inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, - algorithmFPType * res) +template +algorithmFPType PredictBinaryClassificationTask::getMarginFromModelBias(algorithmFPType bias) const +{ + DAAL_ASSERT((0.0 < bias) && (bias < 1.0)); + constexpr algorithmFPType one(1); + // convert bias to margin + return -one * daal::internal::MathInst::sLog(one / bias - one); +} + +////////////////////////////////////////////////////////////////////////////////////////// +// PredictMulticlassTask - implementation +////////////////////////////////////////////////////////////////////////////////////////// + +template +void PredictMulticlassTask::updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, + BooleanConstant dispatcher) +{ + res[iRow + i] = getMaxClass(val + i * nClasses, nClasses); +} + +template +void PredictMulticlassTask::updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, + BooleanConstant dispatcher) +{} + +template +algorithmFPType * PredictMulticlassTask::updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, + BooleanConstant dispatcher) +{ + services::internal::service_memset_seq(buff, 
algorithmFPType(0), buf_size); + return buff; +} + +template +algorithmFPType * PredictMulticlassTask::updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, + BooleanConstant dispatcher) +{ + return buff + buf_shift; +} + +template +size_t PredictMulticlassTask::getNumberOfNodes(size_t nTrees) +{ + size_t nNodesTotal = 0; + for (size_t iTree = 0; iTree < nTrees; ++iTree) { - if (this->_featHelper.hasUnorderedFeatures()) - { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res); - } - else - { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res); - } + nNodesTotal += this->_aTree[iTree]->getNumberOfNodes(); } + return nNodesTotal; +} - template - inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, - algorithmFPType * res) +template +bool PredictMulticlassTask::checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns) +{ + size_t nLvlTotal = 0; + for (size_t iTree = 0; iTree < nTrees; ++iTree) { - const bool hasAnyMissing = checkForMissing(x, nTrees, nRows, nColumns); - if (hasAnyMissing) - { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res); - } - else + nLvlTotal += this->_aTree[iTree]->getMaxLvl(); + } + if (nLvlTotal <= nColumns) + { + // Checking is complicated. Better to do it during inference. + return true; + } + else + { + for (size_t idx = 0; idx < nRows * nColumns; ++idx) { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res); + if (checkFinitenessByComparison(x[idx])) return true; } } + return false; +} - // Recursivelly checking template parameter until it becomes equal to dim.vectorBlockSizeFactor or equal to DimType::minVectorBlockSizeFactor. 
- template - inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, - algorithmFPType * res, const DimType & dim, BooleanConstant keepLooking) +template +template +void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, + algorithmFPType * buff, algorithmFPType * res) +{ + dispatcher_t dispatcher; + size_t iRow = 0; + for (; iRow + vectorBlockSize <= nRows; iRow += vectorBlockSize) { - constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep; - if (dim.vectorBlockSizeFactor == vectorBlockSizeFactor) - { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res); - } - else + algorithmFPType * val = updateBuffer(buff, iRow * nClasses, nClasses * vectorBlockSize, BooleanConstant()); + predictByTreesVector(val, 0, nTrees, nClasses, x + iRow * nColumns, dispatcher); + for (size_t i = 0; i < vectorBlockSize; ++i) { - predict( - nTrees, nClasses, nRows, nColumns, x, buff, res, dim, BooleanConstant()); + updateResult(res, val, iRow, i, nClasses, BooleanConstant()); } } + for (; iRow < nRows; ++iRow) + { + algorithmFPType * val = updateBuffer(buff, iRow * nClasses, nClasses, BooleanConstant()); + predictByTrees(val, 0, nTrees, nClasses, x + iRow * nColumns, dispatcher); + updateResult(res, val, iRow, 0, nClasses, BooleanConstant()); + } +} - template - inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, - algorithmFPType * res, const DimType & dim, BooleanConstant keepLooking) +template +template +void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, + algorithmFPType * buff, algorithmFPType * res) +{ + if (this->_featHelper.hasUnorderedFeatures()) { - constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep; - predict(nTrees, nClasses, nRows, nColumns, 
x, buff, res); + predict(nTrees, nClasses, nRows, nColumns, x, buff, res); + } + else + { + predict(nTrees, nClasses, nRows, nColumns, x, buff, res); } +} - template - inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, - algorithmFPType * res, const DimType & dim) +template +template +void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, + algorithmFPType * buff, algorithmFPType * res) +{ + const bool hasAnyMissing = checkForMissing(x, nTrees, nRows, nColumns); + if (hasAnyMissing) { - constexpr size_t maxVectorBlockSizeFactor = DimType::maxVectorBlockSizeFactor; - if (maxVectorBlockSizeFactor > 1) - { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim, - BooleanConstant()); - } - else - { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim, - BooleanConstant()); - } + predict(nTrees, nClasses, nRows, nColumns, x, buff, res); + } + else + { + predict(nTrees, nClasses, nRows, nColumns, x, buff, res); } +} - template - inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff, - algorithmFPType * res, const DimType & dim) +template +template +void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, + algorithmFPType * buff, algorithmFPType * res, const DimType & dim, + BooleanConstant keepLooking) +{ + constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep; + if (dim.vectorBlockSizeFactor == vectorBlockSizeFactor) { - if (res) - { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim); - } - else - { - predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim); - } + predict(nTrees, nClasses, nRows, nColumns, x, buff, res); + } + else + { + predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim, + BooleanConstant()); } 
+} + +template +template +void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, + algorithmFPType * buff, algorithmFPType * res, const DimType & dim, + BooleanConstant keepLooking) +{ + constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep; + predict(nTrees, nClasses, nRows, nColumns, x, buff, res); +} - void softmax(algorithmFPType * Input, algorithmFPType * Output, size_t nRows, size_t nCols); +template +template +void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, + algorithmFPType * buff, algorithmFPType * res, const DimType & dim) +{ + constexpr size_t maxVectorBlockSizeFactor = DimType::maxVectorBlockSizeFactor; + if (maxVectorBlockSizeFactor > 1) + { + predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim, BooleanConstant()); + } + else + { + predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim, BooleanConstant()); + } +} - size_t getMaxClass(const algorithmFPType * val, size_t nClasses) const +template +template +void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, + algorithmFPType * buff, algorithmFPType * res, const DimType & dim) +{ + if (res) + { + predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim); + } + else { - return services::internal::getMaxElementIndex(val, nClasses); + predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim); } +} -protected: - const NumericTable * _data; - NumericTable * _res; - NumericTable * _prob; - dtrees::internal::FeatureTypes _featHelper; - TArray _aTree; -}; +template +size_t PredictMulticlassTask::getMaxClass(const algorithmFPType * val, size_t nClasses) const +{ + return services::internal::getMaxElementIndex(val, nClasses); +} ////////////////////////////////////////////////////////////////////////////////////////// // PredictKernel @@ -398,7 +740,7 @@ 
services::Status PredictMulticlassTask::run(const gbt::cla DimType dim(*_data, nTreesTotal, getNumberOfNodes(nTreesTotal)); - return predictByAllTrees(nTreesTotal, nClasses, dim); + return predictByAllTrees(nTreesTotal, nClasses, m->getPredictionBias(), dim); } template @@ -433,7 +775,7 @@ void PredictMulticlassTask::predictByTreesVector(algorithm } template -services::Status PredictMulticlassTask::predictByAllTrees(size_t nTreesTotal, size_t nClasses, const DimType & dim) +services::Status PredictMulticlassTask::predictByAllTrees(size_t nTreesTotal, size_t nClasses, double bias, const DimType & dim) { WriteOnlyRows resBD(_res, 0, dim.nRowsTotal); DAAL_CHECK_BLOCK_STATUS(resBD); @@ -449,7 +791,7 @@ services::Status PredictMulticlassTask::predictByAllTrees( DAAL_OVERFLOW_CHECK_BY_MULTIPLICATION(size_t, nRows * nClasses, sizeof(algorithmFPType)); TArray valPtr(nRows * nClasses); algorithmFPType * valFull = valPtr.get(); - services::internal::service_memset(valFull, algorithmFPType(0), nRows * nClasses); + services::internal::service_memset(valFull, algorithmFPType(bias), nRows * nClasses); daal::threader_for(dim.nDataBlocks, dim.nDataBlocks, [&](size_t iBlock) { const size_t iStartRow = iBlock * dim.nRowsInBlock; diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h index 1a32131f8b8..0efa3707206 100755 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h @@ -51,9 +51,10 @@ class PredictKernel : public daal::algorithms::Kernel /** * \brief Compute gradient boosted trees prediction results. 
* - * \param a[in] Matrix of input variables X - * \param m[in] Gradient boosted trees model obtained on training stage - * \param r[out] Prediction results + * \param a[in] Matrix of input variables X + * \param m[in] Gradient boosted trees model obtained on training stage + * \param r[out] Prediction results + * \param prob[out] Prediction class probabilities * \param nClasses[in] Number of classes in gradient boosted trees algorithm parameter * \param nIterations[in] Number of iterations to predict in gradient boosted trees algorithm parameter */ diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp index f3ac56aa5f5..58e3da1a52b 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp @@ -96,20 +96,27 @@ services::Status Input::check(const daal::algorithms::Parameter * parameter, int size_t nClasses = 0, nIterations = 0; - const gbt::classification::prediction::interface2::Parameter * pPrm2 = + const gbt::classification::prediction::interface2::Parameter * pPrm = dynamic_cast(parameter); - if (pPrm2) + if (pPrm) { - nClasses = pPrm2->nClasses; - nIterations = pPrm2->nIterations; + nClasses = pPrm->nClasses; + nIterations = pPrm->nIterations; } else + { return services::ErrorNullParameterNotSupported; + } auto maxNIterations = pModel->getNumberOfTrees(); if (nClasses > 2) maxNIterations /= nClasses; DAAL_CHECK((nClasses < 3) || (pModel->getNumberOfTrees() % nClasses == 0), services::ErrorGbtIncorrectNumberOfTrees); DAAL_CHECK((nIterations == 0) || (nIterations <= maxNIterations), services::ErrorGbtPredictIncorrectNumberOfIterations); + + const bool predictContribs = pPrm->resultsToCompute & shapContributions; + const bool predictInteractions = pPrm->resultsToCompute & shapInteractions; + 
DAAL_CHECK(!(predictContribs || predictInteractions), services::ErrorMethodNotImplemented); + return s; } diff --git a/cpp/daal/src/algorithms/dtrees/gbt/gbt_model.cpp b/cpp/daal/src/algorithms/dtrees/gbt/gbt_model.cpp index 05974244d41..391a267f62c 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/gbt_model.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/gbt_model.cpp @@ -23,7 +23,6 @@ #include "services/daal_defines.h" #include "src/algorithms/dtrees/gbt/gbt_model_impl.h" -#include "src/algorithms/dtrees/dtrees_model_impl_common.h" using namespace daal::data_management; using namespace daal::services; @@ -63,8 +62,8 @@ void ModelImpl::traverseDF(size_t iTree, algorithms::regression::TreeNodeVisitor const GbtDecisionTree & gbtTree = *at(iTree); - const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints(); - const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); + const ModelFPType * splitPoints = gbtTree.getSplitPoints(); + const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); auto onSplitNodeFunc = [&splitPoints, &splitFeatures, &visitor](size_t iRowInTable, size_t level) -> bool { return visitor.onSplitNode(level, splitFeatures[iRowInTable], splitPoints[iRowInTable]); @@ -83,8 +82,8 @@ void ModelImpl::traverseBF(size_t iTree, algorithms::regression::TreeNodeVisitor const GbtDecisionTree & gbtTree = *at(iTree); - const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints(); - const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); + const ModelFPType * splitPoints = gbtTree.getSplitPoints(); + const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); auto onSplitNodeFunc = [&splitFeatures, &splitPoints, &visitor](size_t iRowInTable, size_t level) -> bool { return visitor.onSplitNode(level, splitFeatures[iRowInTable], splitPoints[iRowInTable]); @@ -108,10 +107,10 @@ void 
ModelImpl::traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisito const GbtDecisionTree & gbtTree = *at(iTree); - const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints(); - const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); - const int * nodeSamplesCount = getNodeSampleCount(iTree); - const double * imp = getImpVals(iTree); + const ModelFPType * splitPoints = gbtTree.getSplitPoints(); + const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); + const int * nodeSamplesCount = getNodeSampleCount(iTree); + const double * imp = getImpVals(iTree); auto onSplitNodeFunc = [&splitFeatures, &splitPoints, &nodeSamplesCount, &imp, &visitor](size_t iRowInTable, size_t level) -> bool { tree_utils::SplitNodeDescriptor descSplit; @@ -148,10 +147,10 @@ void ModelImpl::traverseDFS(size_t iTree, tree_utils::regression::TreeNodeVisito const GbtDecisionTree & gbtTree = *at(iTree); - const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints(); - const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); - const int * nodeSamplesCount = getNodeSampleCount(iTree); - const double * imp = getImpVals(iTree); + const ModelFPType * splitPoints = gbtTree.getSplitPoints(); + const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); + const int * nodeSamplesCount = getNodeSampleCount(iTree); + const double * imp = getImpVals(iTree); auto onSplitNodeFunc = [&splitFeatures, &splitPoints, &nodeSamplesCount, &imp, &visitor](size_t iRowInTable, size_t level) -> bool { tree_utils::SplitNodeDescriptor descSplit; @@ -226,15 +225,18 @@ void ModelImpl::destroy() super::destroy(); } -bool ModelImpl::nodeIsDummyLeaf(size_t idx, const GbtDecisionTree & gbtTree) +bool ModelImpl::nodeIsDummyLeaf(size_t nodeIndex, const GbtDecisionTree & gbtTree) { - const gbt::prediction::internal::ModelFPType * splitPoints 
= gbtTree.getSplitPoints(); - const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); + const size_t childArrayIndex = nodeIndex - 1; + const ModelFPType * splitPoints = gbtTree.getSplitPoints(); + const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit(); - if (idx) + if (childArrayIndex) { - const size_t parent = getIdxOfParent(idx); - return splitPoints[parent] == splitPoints[idx] && splitFeatures[parent] == splitFeatures[idx]; + // check if child node has same split feature and split value as parent + const size_t parent = getIdxOfParent(nodeIndex); + const size_t parentArrayIndex = parent - 1; + return splitPoints[parentArrayIndex] == splitPoints[childArrayIndex] && splitFeatures[parentArrayIndex] == splitFeatures[childArrayIndex]; } return false; } @@ -245,16 +247,16 @@ bool ModelImpl::nodeIsLeaf(size_t idx, const GbtDecisionTree & gbtTree, const si { return true; } - else if (nodeIsDummyLeaf(2 * idx + 1, gbtTree)) // check, that left son is dummy + else if (nodeIsDummyLeaf(2 * idx, gbtTree)) // check, that left son is dummy { return true; } return false; } -size_t ModelImpl::getIdxOfParent(const size_t sonIdx) +size_t ModelImpl::getIdxOfParent(const size_t childIdx) { - return sonIdx ? 
(sonIdx - 1) / 2 : 0; + return childIdx / 2; } void ModelImpl::decisionTreeToGbtTree(const DecisionTreeTable & tree, GbtDecisionTree & newTree) @@ -270,9 +272,10 @@ void ModelImpl::decisionTreeToGbtTree(const DecisionTreeTable & tree, GbtDecisio NodeType * sons = sonsArr.data(); NodeType * parents = parentsArr.data(); - gbt::prediction::internal::ModelFPType * const spitPoints = newTree.getSplitPoints(); - gbt::prediction::internal::FeatureIndexType * const featureIndexes = newTree.getFeatureIndexesForSplit(); - int * const defaultLeft = newTree.getdefaultLeftForSplit(); + ModelFPType * const splitPoints = newTree.getSplitPoints(); + FeatureIndexType * const featureIndexes = newTree.getFeatureIndexesForSplit(); + ModelFPType * const nodeCoverValues = newTree.getNodeCoverValues(); + int * const defaultLeft = newTree.getDefaultLeftForSplit(); for (size_t i = 0; i < nSourceNodes; ++i) { @@ -293,21 +296,22 @@ void ModelImpl::decisionTreeToGbtTree(const DecisionTreeTable & tree, GbtDecisio if (p->isSplit()) { - sons[nSons++] = arr + p->leftIndexOrClass; - sons[nSons++] = arr + p->leftIndexOrClass + 1; - featureIndexes[idxInTable] = p->featureIndex; - defaultLeft[idxInTable] = p->defaultLeft; + sons[nSons++] = arr + p->leftIndexOrClass; + sons[nSons++] = arr + p->leftIndexOrClass + 1; + featureIndexes[idxInTable] = p->featureIndex; + nodeCoverValues[idxInTable] = p->cover; + defaultLeft[idxInTable] = p->defaultLeft; DAAL_ASSERT(featureIndexes[idxInTable] >= 0); - spitPoints[idxInTable] = p->featureValueOrResponse; + splitPoints[idxInTable] = p->featureValueOrResponse; } else { - sons[nSons++] = p; - sons[nSons++] = p; - featureIndexes[idxInTable] = 0; - defaultLeft[idxInTable] = 0; - DAAL_ASSERT(featureIndexes[idxInTable] >= 0); - spitPoints[idxInTable] = p->featureValueOrResponse; + sons[nSons++] = p; + sons[nSons++] = p; + featureIndexes[idxInTable] = 0; + nodeCoverValues[idxInTable] = p->cover; + defaultLeft[idxInTable] = 0; + splitPoints[idxInTable] = 
p->featureValueOrResponse; } idxInTable++; @@ -351,7 +355,8 @@ void ModelImpl::getMaxLvl(const dtrees::internal::DecisionTreeNode * const arr, const GbtDecisionTree * ModelImpl::at(const size_t idx) const { - return (const GbtDecisionTree *)(*super::_serializationData)[idx].get(); + auto * const rawTree = (*super::_serializationData)[idx].get(); + return static_cast(rawTree); } } // namespace internal diff --git a/cpp/daal/src/algorithms/dtrees/gbt/gbt_model_impl.h b/cpp/daal/src/algorithms/dtrees/gbt/gbt_model_impl.h index f56b4fc3126..5ae2a8504a0 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/gbt_model_impl.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/gbt_model_impl.h @@ -26,7 +26,6 @@ #include "src/algorithms/dtrees/dtrees_model_impl.h" #include "algorithms/regression/tree_traverse.h" -#include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i" #include "algorithms/tree_utils/tree_utils_regression.h" #include "src/algorithms/dtrees/dtrees_model_impl_common.h" #include "src/services/service_arrays.h" @@ -41,12 +40,13 @@ namespace gbt { namespace internal { +typedef uint32_t FeatureIndexType; +typedef float ModelFPType; +typedef services::Collection NodeIdxArray; + static inline size_t getNumberOfNodesByLvls(const size_t nLvls) { - size_t nNodes = 2; // nNodes = pow(2, nLvl+1) - 1 - for (size_t i = 0; i < nLvls; ++i) nNodes *= 2; - nNodes--; - return nNodes; + return (1 << (nLvls + 1)) - 1; } template @@ -61,43 +61,61 @@ class GbtDecisionTree : public SerializationIface { public: DECLARE_SERIALIZABLE(); - using SplitPointType = HomogenNumericTable; - using FeatureIndexesForSplitType = HomogenNumericTable; + using SplitPointType = HomogenNumericTable; + using NodeCoverType = HomogenNumericTable; + using FeatureIndexesForSplitType = HomogenNumericTable; using defaultLeftForSplitType = HomogenNumericTable; - GbtDecisionTree(const size_t nNodes, const size_t maxLvl, const size_t sourceNumOfNodes) + GbtDecisionTree(const size_t nNodes, const size_t maxLvl) : 
_nNodes(nNodes), _maxLvl(maxLvl), - _sourceNumOfNodes(sourceNumOfNodes), _splitPoints(SplitPointType::create(1, nNodes, NumericTableIface::doAllocate)), _featureIndexes(FeatureIndexesForSplitType::create(1, nNodes, NumericTableIface::doAllocate)), - _defaultLeft(defaultLeftForSplitType::create(1, nNodes, NumericTableIface::doAllocate)) + _nodeCoverValues(NodeCoverType::create(1, nNodes, NumericTableIface::doAllocate)), + _defaultLeft(defaultLeftForSplitType::create(1, nNodes, NumericTableIface::doAllocate)), + nNodeSplitFeature(), + CoverFeature(), + GainFeature() {} - // for serailization only - GbtDecisionTree() : _nNodes(0), _maxLvl(0), _sourceNumOfNodes(0) {} + // for serialization only + GbtDecisionTree() : _nNodes(0), _maxLvl(0) {} + + ModelFPType * getSplitPoints() { return _splitPoints->getArray(); } + + FeatureIndexType * getFeatureIndexesForSplit() { return _featureIndexes->getArray(); } - gbt::prediction::internal::ModelFPType * getSplitPoints() { return _splitPoints->getArray(); } + int * getDefaultLeftForSplit() { return _defaultLeft->getArray(); } - gbt::prediction::internal::FeatureIndexType * getFeatureIndexesForSplit() { return _featureIndexes->getArray(); } + const ModelFPType * getSplitPoints() const { return _splitPoints->getArray(); } - int * getdefaultLeftForSplit() { return _defaultLeft->getArray(); } + const FeatureIndexType * getFeatureIndexesForSplit() const { return _featureIndexes->getArray(); } - const gbt::prediction::internal::ModelFPType * getSplitPoints() const { return _splitPoints->getArray(); } + ModelFPType * getNodeCoverValues() { return _nodeCoverValues->getArray(); } - const gbt::prediction::internal::FeatureIndexType * getFeatureIndexesForSplit() const { return _featureIndexes->getArray(); } + const ModelFPType * getNodeCoverValues() const { return _nodeCoverValues->getArray(); } - const int * getdefaultLeftForSplit() const { return _defaultLeft->getArray(); } + const int * getDefaultLeftForSplit() const { return 
_defaultLeft->getArray(); } size_t getNumberOfNodes() const { return _nNodes; } size_t * getArrayNumSplitFeature() { return nNodeSplitFeature.data(); } + const size_t * getArrayNumSplitFeature() const { return nNodeSplitFeature.data(); } + size_t * getArrayCoverFeature() { return CoverFeature.data(); } + const size_t * getArrayCoverFeature() const { return CoverFeature.data(); } + + services::Collection getCoverFeature() { return CoverFeature; } + + const services::Collection & getCoverFeature() const { return CoverFeature; } + double * getArrayGainFeature() { return GainFeature.data(); } - gbt::prediction::internal::FeatureIndexType getMaxLvl() const { return _maxLvl; } + const double * getArrayGainFeature() const { return GainFeature.data(); } + + FeatureIndexType getMaxLvl() const { return _maxLvl; } // recursive build of tree (breadth-first) template @@ -113,8 +131,8 @@ class GbtDecisionTree : public SerializationIface int result = 0; - gbt::prediction::internal::ModelFPType * const spitPoints = tree->getSplitPoints(); - gbt::prediction::internal::FeatureIndexType * const featureIndexes = tree->getFeatureIndexesForSplit(); + ModelFPType * const splitPoints = tree->getSplitPoints(); + FeatureIndexType * const featureIndexes = tree->getFeatureIndexesForSplit(); for (size_t i = 0; i < nNodes; ++i) { @@ -163,7 +181,7 @@ class GbtDecisionTree : public SerializationIface DAAL_ASSERT(featureIndexes[idxInTable] >= 0); nNodeSamplesVals[idxInTable] = (int)p->count; impVals[idxInTable] = p->impurity; - spitPoints[idxInTable] = p->featureValue; + splitPoints[idxInTable] = p->featureValue; idxInTable++; } @@ -183,10 +201,10 @@ class GbtDecisionTree : public SerializationIface { arch->set(_nNodes); arch->set(_maxLvl); - arch->set(_sourceNumOfNodes); arch->setSharedPtrObj(_splitPoints); arch->setSharedPtrObj(_featureIndexes); + arch->setSharedPtrObj(_nodeCoverValues); arch->setSharedPtrObj(_defaultLeft); return services::Status(); @@ -194,10 +212,10 @@ class GbtDecisionTree : 
public SerializationIface protected: size_t _nNodes; - gbt::prediction::internal::FeatureIndexType _maxLvl; - size_t _sourceNumOfNodes; + FeatureIndexType _maxLvl; services::SharedPtr _splitPoints; services::SharedPtr _featureIndexes; + services::SharedPtr _nodeCoverValues; services::SharedPtr _defaultLeft; services::Collection nNodeSplitFeature; services::Collection CoverFeature; @@ -222,7 +240,7 @@ class GbtTreeImpl : public dtrees::internal::TreeImpl getMaxLvl(*super::top(), nLvls, static_cast(-1)); const size_t nNodes = getNumberOfNodesByLvls(nLvls); - *pTbl = new GbtDecisionTree(nNodes, nLvls, super::top()->numChildren() + 1); + *pTbl = new GbtDecisionTree(nNodes, nLvls); *pTblImp = new HomogenNumericTable(1, nNodes, NumericTable::doAllocate); *pTblSmplCnt = new HomogenNumericTable(1, nNodes, NumericTable::doAllocate); @@ -264,6 +282,29 @@ using TreeImpRegression = GbtTreeImpl > > using TreeImpClassification = GbtTreeImpl, Allocator>; +struct DecisionTreeNode +{ + size_t dimension; + size_t leftIndexOrClass; + double cutPointOrDependantVariable; +}; + +class DecisionTreeTable : public data_management::AOSNumericTable +{ +public: + DecisionTreeTable(size_t rowCount, services::Status & st) : data_management::AOSNumericTable(sizeof(DecisionTreeNode), 3, rowCount, st) + { + setFeature(0, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, dimension)); + setFeature(1, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, leftIndexOrClass)); + setFeature(2, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, cutPointOrDependantVariable)); + st |= allocateDataMemory(); + } + DecisionTreeTable(services::Status & st) : DecisionTreeTable(0, st) {} +}; + +typedef services::SharedPtr DecisionTreeTablePtr; +typedef services::SharedPtr DecisionTreeTableConstPtr; + class ModelImpl : protected dtrees::internal::ModelImpl { public: @@ -292,9 +333,33 @@ class ModelImpl : protected dtrees::internal::ModelImpl static services::Status treeToTable(TreeType & t, gbt::internal::GbtDecisionTree ** pTbl, 
HomogenNumericTable ** pTblImp, HomogenNumericTable ** pTblSmplCnt, size_t nFeature); -protected: + /** + * \brief Returns true if a node is a dummy leaf. A dummy leaf contains the same split feature & value as the parent + * + * \param nodeIndex 1-based index to the node array + * \param gbtTree tree containing nodes + * \param lvl current level in the tree + * \return true if the node is a dummy leaf, false otherwise + */ static bool nodeIsDummyLeaf(size_t idx, const GbtDecisionTree & gbtTree); + + /** + * \brief Return true if a node is leaf + * + * \param idx 1-based index to the node array + * \param gbtTree tree containing nodes + * \param lvl current level in the tree + * \return true if the node is a leaf, false otherwise + */ static bool nodeIsLeaf(size_t idx, const GbtDecisionTree & gbtTree, const size_t lvl); + +protected: + /** + * \brief Return the node index of the provided node's parent + * + * \param childIdx 1-based node index of the child + * \return size_t 1-based node index of the parent + */ static size_t getIdxOfParent(const size_t sonIdx); static void getMaxLvl(const dtrees::internal::DecisionTreeNode * const arr, const size_t idx, size_t & maxLvl, size_t curLvl = 0); @@ -306,21 +371,22 @@ class ModelImpl : protected dtrees::internal::ModelImpl getMaxLvl(arr, 0, nLvls, static_cast(-1)); const size_t nNodes = getNumberOfNodesByLvls(nLvls); - return new GbtDecisionTree(nNodes, nLvls, tree.getNumberOfRows()); + return new GbtDecisionTree(nNodes, nLvls); } template static void traverseGbtDF(size_t level, size_t iRowInTable, const GbtDecisionTree & gbtTree, OnSplitFunctor & visitSplit, OnLeafFunctor & visitLeaf) { - if (!nodeIsLeaf(iRowInTable, gbtTree, level)) + const size_t oneBasedNodeIndex = iRowInTable + 1; + if (!nodeIsLeaf(oneBasedNodeIndex, gbtTree, level)) { if (!visitSplit(iRowInTable, level)) return; //do not continue traversing traverseGbtDF(level + 1, iRowInTable * 2 + 1, gbtTree, visitSplit, visitLeaf); traverseGbtDF(level + 1, 
iRowInTable * 2 + 2, gbtTree, visitSplit, visitLeaf); } - else if (!nodeIsDummyLeaf(iRowInTable, gbtTree)) + else if (!nodeIsDummyLeaf(oneBasedNodeIndex, gbtTree)) { if (!visitLeaf(iRowInTable, level)) return; //do not continue traversing } @@ -334,14 +400,15 @@ class ModelImpl : protected dtrees::internal::ModelImpl { for (size_t j = 0; j < (level ? 2 : 1); ++j) { - size_t iRowInTable = aCur[i] + j; - if (!nodeIsLeaf(iRowInTable, gbtTree, level)) + const size_t iRowInTable = aCur[i] + j; + const size_t oneBasedNodeIndex = iRowInTable + 1; + if (!nodeIsLeaf(oneBasedNodeIndex, gbtTree, level)) { if (!visitSplit(iRowInTable, level)) return; //do not continue traversing aNext.push_back(iRowInTable * 2 + 1); } - else if (!nodeIsDummyLeaf(iRowInTable, gbtTree)) + else if (!nodeIsDummyLeaf(oneBasedNodeIndex, gbtTree)) { if (!visitLeaf(iRowInTable, level)) return; //do not continue traversing } diff --git a/cpp/daal/src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i index 18976987b37..fd76f8da721 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i +++ b/cpp/daal/src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i @@ -25,14 +25,15 @@ #ifndef __GBT_PREDICT_DENSE_DEFAULT_IMPL_I__ #define __GBT_PREDICT_DENSE_DEFAULT_IMPL_I__ +#include "data_management/data/internal/finiteness_checker.h" +#include "src/algorithms/dtrees/dtrees_feature_type_helper.h" #include "src/algorithms/dtrees/dtrees_model_impl.h" -#include "src/algorithms/dtrees/dtrees_train_data_helper.i" #include "src/algorithms/dtrees/dtrees_predict_dense_default_impl.i" -#include "src/algorithms/dtrees/dtrees_feature_type_helper.h" +#include "src/algorithms/dtrees/dtrees_train_data_helper.i" #include "src/algorithms/dtrees/gbt/gbt_internal.h" -#include "src/services/service_environment.h" +#include "src/algorithms/dtrees/gbt/gbt_model_impl.h" #include "src/services/service_defines.h" -#include 
"data_management/data/internal/finiteness_checker.h" +#include "src/services/service_environment.h" namespace daal { @@ -44,8 +45,8 @@ namespace prediction { namespace internal { -typedef float ModelFPType; -typedef uint32_t FeatureIndexType; +using gbt::internal::ModelFPType; +using gbt::internal::FeatureIndexType; template struct PredictDispatcher @@ -101,7 +102,7 @@ inline void predictForTreeVector(const DecisionTreeType & t, const FeatureTypes { const ModelFPType * const values = t.getSplitPoints() - 1; const FeatureIndexType * const fIndexes = t.getFeatureIndexesForSplit() - 1; - const int * const defaultLeft = t.getdefaultLeftForSplit() - 1; + const int * const defaultLeft = t.getDefaultLeftForSplit() - 1; const FeatureIndexType nFeat = featTypes.getNumberOfFeatures(); FeatureIndexType i[vectorBlockSize]; @@ -135,7 +136,7 @@ inline algorithmFPType predictForTree(const DecisionTreeType & t, const FeatureT { const ModelFPType * const values = (const ModelFPType *)t.getSplitPoints() - 1; const FeatureIndexType * const fIndexes = t.getFeatureIndexesForSplit() - 1; - const int * const defaultLeft = t.getdefaultLeftForSplit() - 1; + const int * const defaultLeft = t.getDefaultLeftForSplit() - 1; const FeatureIndexType maxLvl = t.getMaxLvl(); @@ -167,7 +168,7 @@ struct TileDimensions static constexpr size_t minVectorBlockSizeFactor = 2; static constexpr size_t vectorBlockSizeStep = 16; // optimalBlockSizeFactor is selected from benchmarking - static constexpr size_t optimalBlockSizeFactor = 3; + static constexpr size_t optimalBlockSizeFactor = 5; TileDimensions(const NumericTable & data, size_t nTrees, size_t nNodes) : nTreesTotal(nTrees), nRowsTotal(data.getNumberOfRows()), nCols(data.getNumberOfColumns()) diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/BUILD b/cpp/daal/src/algorithms/dtrees/gbt/regression/BUILD index 7780e5acaa6..8e85e233059 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/BUILD +++ 
b/cpp/daal/src/algorithms/dtrees/gbt/regression/BUILD @@ -1,10 +1,12 @@ package(default_visibility = ["//visibility:public"]) load("@onedal//dev/bazel:daal.bzl", "daal_module") +load("@onedal//dev/bazel:dal.bzl", "dal_test_suite") daal_module( name = "kernel", - auto = True, - opencl = True, + auto = False, + hdrs = glob(["**/*.h", "**/*.i", "**/*.cl"]), + srcs = glob(["*.cpp"]), deps = [ "@onedal//cpp/daal:core", "@onedal//cpp/daal:sycl", @@ -12,3 +14,16 @@ daal_module( "@onedal//cpp/daal/src/algorithms/dtrees/gbt:kernel", ], ) + +dal_test_suite( + name = "tests", + compile_as = [ "c++" ], + private = True, + srcs = glob([ + "test/*unit.cpp", + ]), + dal_deps = [ + "@onedal//cpp/daal/src/algorithms/regression:kernel", + "@onedal//cpp/daal/src/algorithms/dtrees/gbt:kernel", + ] +) diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model.cpp b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model.cpp index 2c2818b3104..0012688dc1e 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model.cpp @@ -85,9 +85,20 @@ void ModelImpl::traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisito ImplType::traverseBFS(iTree, visitor); } +void ModelImpl::setPredictionBias(double value) +{ + _predictionBias = value; +} + +double ModelImpl::getPredictionBias() const +{ + return _predictionBias; +} + services::Status ModelImpl::serializeImpl(data_management::InputDataArchive * arch) { auto s = algorithms::regression::Model::serialImpl(arch); + arch->set(this->_predictionBias); s.add(algorithms::regression::internal::ModelInternal::serialImpl(arch)); return s.add(ImplType::serialImpl(arch)); } @@ -95,6 +106,7 @@ services::Status ModelImpl::serializeImpl(data_management::InputDataArchive * ar services::Status ModelImpl::deserializeImpl(const data_management::OutputDataArchive * arch) { auto s = algorithms::regression::Model::serialImpl(arch); + 
arch->set(this->_predictionBias); s.add(algorithms::regression::internal::ModelInternal::serialImpl(arch)); return s.add(ImplType::serialImpl(arch)); } diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_builder.cpp b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_builder.cpp index aefc88d2ab0..7c1a1d10bf2 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_builder.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_builder.cpp @@ -71,21 +71,21 @@ services::Status ModelBuilder::createTreeInternal(size_t nNodes, TreeId & resId) return daal::algorithms::dtrees::internal::createTreeInternal(modelImplRef._serializationData, nNodes, resId); } -services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res) +services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res) { gbt::regression::internal::ModelImpl & modelImplRef = daal::algorithms::dtrees::internal::getModelRef(_model); return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, response, - res); + cover, res); } services::Status ModelBuilder::addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, - NodeId & res, int defaultLeft) + int defaultLeft, double cover, NodeId & res) { gbt::regression::internal::ModelImpl & modelImplRef = daal::algorithms::dtrees::internal::getModelRef(_model); return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex, - featureValue, res, defaultLeft); + featureValue, defaultLeft, cover, res); } } // namespace interface1 diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h 
b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h index a711f0d747c..13409fb1790 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h @@ -61,10 +61,17 @@ class ModelImpl : public daal::algorithms::gbt::regression::Model, virtual void traverseDFS(size_t iTree, tree_utils::regression::TreeNodeVisitor & visitor) const DAAL_C11_OVERRIDE; virtual void traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisitor & visitor) const DAAL_C11_OVERRIDE; + virtual void setPredictionBias(double value) DAAL_C11_OVERRIDE; + virtual double getPredictionBias() const DAAL_C11_OVERRIDE; + virtual services::Status serializeImpl(data_management::InputDataArchive * arch) DAAL_C11_OVERRIDE; virtual services::Status deserializeImpl(const data_management::OutputDataArchive * arch) DAAL_C11_OVERRIDE; virtual size_t getNumberOfTrees() const DAAL_C11_OVERRIDE; + +private: + /* global bias applied to predictions*/ + double _predictionBias; }; } // namespace internal diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_container.h b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_container.h index 8e3a231e1eb..d2e4bd4bdf7 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_container.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_container.h @@ -61,8 +61,10 @@ services::Status BatchContainer::compute() const gbt::regression::prediction::Parameter * par = static_cast(_par); daal::services::Environment::env & env = *_env; + const bool predShapContributions = par->resultsToCompute & shapContributions; + const bool predShapInteractions = par->resultsToCompute & shapInteractions; __DAAL_CALL_KERNEL(env, internal::PredictKernel, __DAAL_KERNEL_ARGUMENTS(algorithmFPType, method), compute, - daal::services::internal::hostApp(*input), a, m, r, 
par->nIterations); + daal::services::internal::hostApp(*input), a, m, r, par->nIterations, predShapContributions, predShapInteractions); } } // namespace prediction diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i index cd2410e1692..7b506fe36b5 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i @@ -26,15 +26,17 @@ #include "algorithms/algorithm.h" #include "data_management/data/numeric_table.h" -#include "src/algorithms/dtrees/gbt/regression/gbt_regression_predict_kernel.h" -#include "src/threading/threading.h" #include "services/daal_defines.h" +#include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i" #include "src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h" -#include "src/data_management/service_numeric_table.h" +#include "src/algorithms/dtrees/gbt/regression/gbt_regression_predict_kernel.h" +#include "src/algorithms/dtrees/gbt/treeshap.h" +#include "src/algorithms/dtrees/regression/dtrees_regression_predict_dense_default_impl.i" #include "src/algorithms/service_error_handling.h" +#include "src/data_management/service_numeric_table.h" #include "src/externals/service_memory.h" -#include "src/algorithms/dtrees/regression/dtrees_regression_predict_dense_default_impl.i" -#include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i" +#include "src/threading/threading.h" +#include // DBL_EPSILON using namespace daal::internal; using namespace daal::services::internal; @@ -51,6 +53,7 @@ namespace prediction { namespace internal { +using gbt::internal::FeatureIndexType; ////////////////////////////////////////////////////////////////////////////////////////// // PredictRegressionTask @@ -62,13 +65,16 @@ public: typedef 
gbt::internal::GbtDecisionTree TreeType; typedef gbt::prediction::internal::TileDimensions DimType; PredictRegressionTask(const NumericTable * x, NumericTable * y) : _data(x), _res(y) {} - services::Status run(const gbt::regression::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp); + + services::Status run(const gbt::regression::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp, + bool predShapContributions, bool predShapInteractions); protected: template using dispatcher_t = gbt::prediction::internal::PredictDispatcher; - services::Status runInternal(services::HostAppIface * pHostApp, NumericTable * result); + services::Status runInternal(services::HostAppIface * pHostApp, NumericTable * result, double predictionBias, bool predShapContributions, + bool predShapInteractions); template algorithmFPType predictByTrees(size_t iFirstTree, size_t nTrees, const algorithmFPType * x, const dispatcher_t & dispatcher); @@ -78,29 +84,29 @@ protected: inline size_t getNumberOfNodes(size_t nTrees) { - size_t nNodesTotal = 0; - for (size_t iTree = 0; iTree < nTrees; ++iTree) + size_t nNodesTotal = 0ul; + for (size_t iTree = 0ul; iTree < nTrees; ++iTree) { - nNodesTotal += this->_aTree[iTree]->getNumberOfNodes(); + nNodesTotal += _aTree[iTree]->getNumberOfNodes(); } return nNodesTotal; } - inline bool checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns) + inline bool checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns) const { - size_t nLvlTotal = 0; - for (size_t iTree = 0; iTree < nTrees; ++iTree) + size_t nLvlTotal = 0ul; + for (size_t iTree = 0ul; iTree < nTrees; ++iTree) { - nLvlTotal += this->_aTree[iTree]->getMaxLvl(); + nLvlTotal += _aTree[iTree]->getMaxLvl(); } if (nLvlTotal <= nColumns) { - // Checking is compicated. Better to do it during inferense. + // Checking is complicated. 
Better to do it during inference return true; } else { - for (size_t idx = 0; idx < nRows * nColumns; ++idx) + for (size_t idx = 0ul; idx < nRows * nColumns; ++idx) { if (checkFinitenessByComparison(x[idx])) return true; } @@ -126,7 +132,7 @@ protected: template inline void predict(size_t iTree, size_t nTrees, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * res) { - if (this->_featHelper.hasUnorderedFeatures()) + if (_featHelper.hasUnorderedFeatures()) { predict(iTree, nTrees, nRows, nColumns, x, res); } @@ -150,17 +156,86 @@ protected: } } + template + inline services::Status predictContributions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, algorithmFPType * res, + int condition, FeatureIndexType conditionFeature, const DimType & dim); + + template + inline services::Status predictContributions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, algorithmFPType * res, + int condition, FeatureIndexType conditionFeature, const DimType & dim) + { + if (_featHelper.hasUnorderedFeatures()) + { + return predictContributions(iTree, nTrees, nRowsData, x, res, condition, conditionFeature, dim); + } + else + { + return predictContributions(iTree, nTrees, nRowsData, x, res, condition, conditionFeature, dim); + } + } + + // TODO: Add vectorBlockSize templates, similar to predict + // template + inline services::Status predictContributions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, algorithmFPType * res, + int condition, FeatureIndexType conditionFeature, const DimType & dim) + { + const bool hasAnyMissing = checkForMissing(x, nTrees, nRowsData, dim.nCols); + if (hasAnyMissing) + { + return predictContributions(iTree, nTrees, nRowsData, x, res, condition, conditionFeature, dim); + } + else + { + return predictContributions(iTree, nTrees, nRowsData, x, res, condition, conditionFeature, dim); + } + } + + template + inline services::Status 
predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, + const algorithmFPType * nominal, algorithmFPType * res, const DimType & dim); + + template + inline services::Status predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, + const algorithmFPType * nominal, algorithmFPType * res, const DimType & dim) + { + if (_featHelper.hasUnorderedFeatures()) + { + return predictContributionInteractions(iTree, nTrees, nRowsData, x, nominal, res, dim); + } + else + { + return predictContributionInteractions(iTree, nTrees, nRowsData, x, nominal, res, dim); + } + } + + // TODO: Add vectorBlockSize templates, similar to predict + // template + inline services::Status predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, + const algorithmFPType * nominal, algorithmFPType * res, const DimType & dim) + { + const bool hasAnyMissing = checkForMissing(x, nTrees, nRowsData, dim.nCols); + if (hasAnyMissing) + { + return predictContributionInteractions(iTree, nTrees, nRowsData, x, nominal, res, dim); + } + else + { + return predictContributionInteractions(iTree, nTrees, nRowsData, x, nominal, res, dim); + } + } + template struct BooleanConstant { typedef BooleanConstant type; }; - // Recursivelly checking template parameter until it becomes equal to dim.vectorBlockSizeFactor or equal to DimType::minVectorBlockSizeFactor. + // Recursively checking template parameter until it becomes equal to dim.vectorBlockSizeFactor or equal to DimType::minVectorBlockSizeFactor. 
template - inline void predict(size_t iTree, size_t nTrees, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * res, - const DimType & dim, BooleanConstant keepLooking) + inline void predict(size_t iTree, size_t nTrees, size_t nRows, const algorithmFPType * x, algorithmFPType * res, const DimType & dim, + BooleanConstant keepLooking) { + const size_t nColumns = dim.nCols; constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep; if (dim.vectorBlockSizeFactor == vectorBlockSizeFactor) { @@ -168,30 +243,30 @@ protected: } else { - predict(iTree, nTrees, nRows, nColumns, x, res, dim, + predict(iTree, nTrees, nRows, x, res, dim, BooleanConstant()); } } template - inline void predict(size_t iTree, size_t nTrees, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * res, - const DimType & dim, BooleanConstant keepLooking) + inline void predict(size_t iTree, size_t nTrees, size_t nRows, const algorithmFPType * x, algorithmFPType * res, const DimType & dim, + BooleanConstant keepLooking) { constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep; - predict(iTree, nTrees, nRows, nColumns, x, res); + predict(iTree, nTrees, nRows, dim.nCols, x, res); } - inline void predict(size_t iTree, size_t nTrees, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * res, - const DimType & dim) + inline void predict(size_t iTree, size_t nTrees, size_t nRows, const algorithmFPType * x, algorithmFPType * res, const DimType & dim) { + const size_t nColumns = dim.nCols; constexpr size_t maxVectorBlockSizeFactor = DimType::maxVectorBlockSizeFactor; if (maxVectorBlockSizeFactor > 1) { - predict(iTree, nTrees, nRows, nColumns, x, res, dim, BooleanConstant()); + predict(iTree, nTrees, nRows, x, res, dim, BooleanConstant()); } else { - predict(iTree, nTrees, nRows, nColumns, x, res, dim, BooleanConstant()); + predict(iTree, nTrees, nRows, x, res, dim, BooleanConstant()); } } @@ -207,40 +282,187 @@ protected: 
////////////////////////////////////////////////////////////////////////////////////////// template services::Status PredictKernel::compute(services::HostAppIface * pHostApp, const NumericTable * x, - const regression::Model * m, NumericTable * r, size_t nIterations) + const regression::Model * m, NumericTable * r, size_t nIterations, + bool predShapContributions, bool predShapInteractions) { const daal::algorithms::gbt::regression::internal::ModelImpl * pModel = static_cast(m); PredictRegressionTask task(x, r); - return task.run(pModel, nIterations, pHostApp); + return task.run(pModel, nIterations, pHostApp, predShapContributions, predShapInteractions); } template services::Status PredictRegressionTask::run(const gbt::regression::internal::ModelImpl * m, size_t nIterations, - services::HostAppIface * pHostApp) + services::HostAppIface * pHostApp, bool predShapContributions, + bool predShapInteractions) { DAAL_ASSERT(nIterations || nIterations <= m->size()); - DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data)); + DAAL_CHECK_MALLOC(_featHelper.init(*_data)); const auto nTreesTotal = (nIterations ? 
nIterations : m->size()); - this->_aTree.reset(nTreesTotal); - DAAL_CHECK_MALLOC(this->_aTree.get()); - for (size_t i = 0; i < nTreesTotal; ++i) this->_aTree[i] = m->at(i); - return runInternal(pHostApp, this->_res); + _aTree.reset(nTreesTotal); + DAAL_CHECK_MALLOC(_aTree.get()); + + PRAGMA_VECTOR_ALWAYS + for (size_t i = 0ul; i < nTreesTotal; ++i) _aTree[i] = m->at(i); + + return runInternal(pHostApp, this->_res, m->getPredictionBias(), predShapContributions, predShapInteractions); } +/** + * Helper to predict SHAP contribution values + * \param[in] iTree index of start tree for the calculation + * \param[in] nTrees number of trees in block included in calculation + * \param[in] nRowsData number of rows to process + * \param[in] x pointer to the start of observation data + * \param[out] res pointer to the start of memory where results are written to + * \param[in] condition condition the feature to on (1) off (-1) or unconditioned (0) + * \param[in] conditionFeature index of feature that should be conditioned + * \param[in] dim DimType helper +*/ template -services::Status PredictRegressionTask::runInternal(services::HostAppIface * pHostApp, NumericTable * result) +template +services::Status PredictRegressionTask::predictContributions(size_t iTree, size_t nTrees, size_t nRowsData, + const algorithmFPType * x, algorithmFPType * res, int condition, + FeatureIndexType conditionFeature, const DimType & dim) { - const auto nTreesTotal = this->_aTree.size(); + // TODO: Make use of vectorBlockSize, similar to predictByTreesVector + Status st; - DimType dim(*this->_data, nTreesTotal, getNumberOfNodes(nTreesTotal)); - WriteOnlyRows resBD(result, 0, 1); - DAAL_CHECK_BLOCK_STATUS(resBD); - services::internal::service_memset(resBD.get(), 0, dim.nRowsTotal); + const size_t nColumnsData = dim.nCols; + const size_t nColumnsPhi = nColumnsData + 1; + const size_t biasTermIndex = nColumnsPhi - 1; + + for (size_t iRow = 0ul; iRow < nRowsData; ++iRow) + { + const algorithmFPType * 
const currentX = x + (iRow * nColumnsData); + algorithmFPType * const phi = res + (iRow * nColumnsPhi); + for (size_t currentTreeIndex = iTree; currentTreeIndex < iTree + nTrees; ++currentTreeIndex) + { + const gbt::internal::GbtDecisionTree * currentTree = _aTree[currentTreeIndex]; + st |= gbt::treeshap::treeShap(currentTree, currentX, phi, &_featHelper, + condition, conditionFeature); + } + + if (condition == 0) + { + // find bias term by leveraging bias = nominal - sum_i phi_i + for (int iFeature = 0; iFeature < nColumnsData; ++iFeature) + { + phi[biasTermIndex] -= phi[iFeature]; + } + } + } + + return st; +} + +/** + * Helper to predict SHAP contribution interactions + * \param[in] iTree index of start tree for the calculation + * \param[in] nTrees number of trees in block included in calculation + * \param[in] nRowsData number of rows to process + * \param[in] nColumnsData number of columns in data, i.e. features + * note: 1 SHAP value per feature + bias term + * \param[in] x pointer to the start of observation data + * \param[out] res pointer to the start of memory where results are written to + * \param[in] dim DimType helper +*/ +template +template +services::Status PredictRegressionTask::predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData, + const algorithmFPType * x, + const algorithmFPType * nominal, algorithmFPType * res, + const DimType & dim) +{ + Status st; + const size_t nColumnsData = dim.nCols; + const size_t nColumnsPhi = nColumnsData + 1; + const size_t biasTermIndex = nColumnsPhi - 1; + + const size_t interactionMatrixSize = nColumnsPhi * nColumnsPhi; + + // Allocate buffer for 3 matrices for algorithmFPType of size (nRowsData, nColumnsData) + const size_t elementsInMatrix = nRowsData * nColumnsPhi; + const size_t bufferSize = 3 * sizeof(algorithmFPType) * elementsInMatrix; + algorithmFPType * buffer = static_cast(daal_calloc(bufferSize)); + if (!buffer) + { + st.add(ErrorMemoryAllocationFailed); + return st; + } + + 
// Get pointers into the buffer for our three matrices + algorithmFPType * contribsDiag = buffer + 0 * elementsInMatrix; + algorithmFPType * contribsOff = buffer + 1 * elementsInMatrix; + algorithmFPType * contribsOn = buffer + 2 * elementsInMatrix; + + // Copy nominal values (for bias term) to the condition = 0 buffer + PRAGMA_IVDEP + PRAGMA_VECTOR_ALWAYS + for (size_t i = 0ul; i < nRowsData; ++i) + { + contribsDiag[i * nColumnsPhi + biasTermIndex] = nominal[i]; + } + + predictContributions(iTree, nTrees, nRowsData, x, contribsDiag, 0, 0, dim); + for (size_t i = 0ul; i < nColumnsPhi; ++i) + { + // initialize/reset the on/off buffers + service_memset_seq(contribsOff, algorithmFPType(0), 2 * elementsInMatrix); + + predictContributions(iTree, nTrees, nRowsData, x, contribsOff, -1, i, dim); + predictContributions(iTree, nTrees, nRowsData, x, contribsOn, 1, i, dim); + + for (size_t j = 0ul; j < nRowsData; ++j) + { + const unsigned dataRowOffset = j * interactionMatrixSize + i * nColumnsPhi; + const unsigned columnOffset = j * nColumnsPhi; + res[dataRowOffset + i] = 0; + for (size_t k = 0ul; k < nColumnsPhi; ++k) + { + // fill in the diagonal with additive effects, and off-diagonal with the interactions + if (k == i) + { + res[dataRowOffset + i] += contribsDiag[columnOffset + k]; + } + else + { + constexpr algorithmFPType half(0.5); + res[dataRowOffset + k] = (contribsOn[columnOffset + k] - contribsOff[columnOffset + k]) * half; + res[dataRowOffset + i] -= res[dataRowOffset + k]; + } + } + } + } + + daal_free(buffer); + + return st; +} + +template +services::Status PredictRegressionTask::runInternal(services::HostAppIface * pHostApp, NumericTable * result, + double predictionBias, bool predShapContributions, + bool predShapInteractions) +{ + // assert we're not requesting both contributions and interactions + DAAL_ASSERT(!(predShapContributions && predShapInteractions)); + + const size_t nTreesTotal = _aTree.size(); + const int dataNColumns = 
_data->getNumberOfColumns(); + const size_t resultNColumns = result->getNumberOfColumns(); + const size_t resultNRows = result->getNumberOfRows(); + + DimType dim(*_data, nTreesTotal, getNumberOfNodes(nTreesTotal)); + WriteOnlyRows resMatrix(result, 0, resultNRows); // select all rows for writing + DAAL_CHECK_BLOCK_STATUS(resMatrix); + services::internal::service_memset(resMatrix.get(), 0, resultNRows * resultNColumns); // set nRows * nCols to 0 SafeStatus safeStat; services::Status s; HostAppHelper host(pHostApp, 100); - for (size_t iTree = 0; iTree < nTreesTotal; iTree += dim.nTreesInBlock) + + const size_t predictionIndex = resultNColumns - 1; + for (size_t iTree = 0ul; iTree < nTreesTotal; iTree += dim.nTreesInBlock) { if (!s || host.isCancelled(s, 1)) return s; size_t nTreesToUse = ((iTree + dim.nTreesInBlock) < nTreesTotal ? dim.nTreesInBlock : (nTreesTotal - iTree)); @@ -248,11 +470,62 @@ services::Status PredictRegressionTask::runInternal(servic daal::threader_for(dim.nDataBlocks, dim.nDataBlocks, [&](size_t iBlock) { const size_t iStartRow = iBlock * dim.nRowsInBlock; const size_t nRowsToProcess = (iBlock == dim.nDataBlocks - 1) ? 
dim.nRowsTotal - iBlock * dim.nRowsInBlock : dim.nRowsInBlock; - ReadRows xBD(const_cast(this->_data), iStartRow, nRowsToProcess); + ReadRows xBD(const_cast(_data), iStartRow, nRowsToProcess); DAAL_CHECK_BLOCK_STATUS_THR(xBD); - algorithmFPType * res = resBD.get() + iStartRow; - predict(iTree, nTreesTotal, nRowsToProcess, dim.nCols, xBD.get(), res, dim); + if (predShapContributions) + { + // nominal values are required to calculate the correct bias term + TArray nominal(nRowsToProcess); + algorithmFPType * nominalPtr = nominal.get(); + DAAL_CHECK_MALLOC_THR(nominalPtr); + service_memset(nominalPtr, algorithmFPType(predictionBias), nRowsToProcess); + + // bias term: prediction - sum_i phi_i (subtraction in predictContributions) + predict(iTree, nTreesToUse, nRowsToProcess, xBD.get(), nominalPtr, dim); + + // thread-local write rows into global result buffer + WriteOnlyRows resRow(result, iStartRow, nRowsToProcess); + DAAL_CHECK_BLOCK_STATUS_THR(resRow); + + // copy nominal predictions for bias term to correct spot in result array + auto resRowPtr = resRow.get(); + for (size_t i = 0ul; i < nRowsToProcess; ++i) + { + resRowPtr[i * resultNColumns + predictionIndex] = nominalPtr[i]; + } + + // TODO: support tree weights + safeStat |= predictContributions(iTree, nTreesToUse, nRowsToProcess, xBD.get(), resRowPtr, 0, 0, dim); + } + else if (predShapInteractions) + { + // thread-local write rows into global result buffer + WriteOnlyRows resRow(result, iStartRow, nRowsToProcess); + DAAL_CHECK_BLOCK_STATUS_THR(resRow); + + // nominal values are required to calculate the correct bias term + TArray nominal(nRowsToProcess); + algorithmFPType * nominalPtr = nominal.get(); + DAAL_CHECK_MALLOC_THR(nominalPtr); + service_memset(nominalPtr, algorithmFPType(predictionBias), nRowsToProcess); + + predict(iTree, nTreesToUse, nRowsToProcess, xBD.get(), nominalPtr, dim); + + // TODO: support tree weights + safeStat |= predictContributionInteractions(iTree, nTreesToUse, nRowsToProcess, 
xBD.get(), nominalPtr, resRow.get(), dim); + } + else + { + algorithmFPType * res = resMatrix.get() + iStartRow; + if ((predictionBias < 0 && predictionBias < -DBL_EPSILON) || (0 < predictionBias && DBL_EPSILON < predictionBias)) + { + // memory is already initialized to 0 + // only set it to the bias term if it's != 0 + service_memset(res, algorithmFPType(predictionBias), nRowsToProcess); + } + predict(iTree, nTreesToUse, nRowsToProcess, xBD.get(), res, dim); + } }); s = safeStat.detach(); @@ -268,7 +541,7 @@ algorithmFPType PredictRegressionTask::predictByTrees(size { algorithmFPType val = 0; for (size_t iTree = iFirstTree, iLastTree = iFirstTree + nTrees; iTree < iLastTree; ++iTree) - val += gbt::prediction::internal::predictForTree(*this->_aTree[iTree], this->_featHelper, x, dispatcher); + val += gbt::prediction::internal::predictForTree(*_aTree[iTree], _featHelper, x, dispatcher); return val; } @@ -282,11 +555,14 @@ void PredictRegressionTask::predictByTreesVector(size_t iF for (size_t iTree = iFirstTree, iLastTree = iFirstTree + nTrees; iTree < iLastTree; ++iTree) { gbt::prediction::internal::predictForTreeVector( - *this->_aTree[iTree], this->_featHelper, x, v, dispatcher); + *_aTree[iTree], _featHelper, x, v, dispatcher); PRAGMA_IVDEP PRAGMA_VECTOR_ALWAYS - for (size_t j = 0; j < vectorBlockSize; ++j) res[j] += v[j]; + for (size_t row = 0ul; row < vectorBlockSize; ++row) + { + res[row] += v[row]; + } } } diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_kernel.h b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_kernel.h index 7feb478f5a2..6c2584f33f9 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_kernel.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_kernel.h @@ -55,9 +55,11 @@ class PredictKernel : public daal::algorithms::Kernel * \param m[in] gradient boosted trees model obtained on training stage * \param r[out] Prediction results * 
\param nIterations[in] Number of iterations to predict in gradient boosted trees algorithm parameter + * \param predShapContributions[in] Predict SHAP contributions + * \param predShapInteractions[in] Predict SHAP interactions */ services::Status compute(services::HostAppIface * pHostApp, const NumericTable * a, const regression::Model * m, NumericTable * r, - size_t nIterations); + size_t nIterations, bool predShapContributions, bool predShapInteractions); }; } // namespace internal diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_result_fpt.cpp b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_result_fpt.cpp index 16dc5c51808..e63ff6e91a3 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_result_fpt.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_result_fpt.cpp @@ -45,8 +45,22 @@ DAAL_EXPORT services::Status Result::allocate(const daal::algorithms::Input * in DAAL_CHECK_EX(dataPtr.get(), ErrorNullInputNumericTable, ArgumentName, dataStr()); services::Status s; const size_t nVectors = dataPtr->getNumberOfRows(); - Argument::set(prediction, - data_management::HomogenNumericTable::create(1, nVectors, data_management::NumericTableIface::doAllocate, &s)); + + size_t nColumnsToAllocate = 1; + const Parameter * regressionParameter = static_cast(par); + if (regressionParameter->resultsToCompute & shapContributions) + { + const size_t nColumns = dataPtr->getNumberOfColumns(); + nColumnsToAllocate = nColumns + 1; + } + else if (regressionParameter->resultsToCompute & shapInteractions) + { + const size_t nColumns = dataPtr->getNumberOfColumns(); + nColumnsToAllocate = (nColumns + 1) * (nColumns + 1); + } + + Argument::set(prediction, data_management::HomogenNumericTable::create(nColumnsToAllocate, nVectors, + data_management::NumericTableIface::doAllocate, &s)); return s; } diff --git 
a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_types.cpp b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_types.cpp index 1f47216b2da..82e6492e309 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_types.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_types.cpp @@ -107,6 +107,9 @@ services::Status Input::check(const daal::algorithms::Parameter * parameter, int size_t nIterations = pPrm->nIterations; DAAL_CHECK((nIterations == 0) || (nIterations <= maxNIterations), services::ErrorGbtPredictIncorrectNumberOfIterations); + const bool predictContribs = pPrm->resultsToCompute & shapContributions; + const bool predictInteractions = pPrm->resultsToCompute & shapInteractions; + DAAL_CHECK(!(predictContribs && predictInteractions), services::ErrorGbtPredictShapOptions); return s; } @@ -142,7 +145,20 @@ services::Status Result::check(const daal::algorithms::Input * input, const daal { Status s; DAAL_CHECK_STATUS(s, algorithms::regression::prediction::Result::check(input, par, method)); - DAAL_CHECK_EX(get(prediction)->getNumberOfColumns() == 1, ErrorIncorrectNumberOfColumns, ArgumentName, predictionStr()); + const auto inputCast = static_cast(input); + const prediction::Parameter * regressionParameter = static_cast(par); + size_t expectedNColumns = 1; + if (regressionParameter->resultsToCompute & shapContributions) + { + const size_t nColumns = inputCast->get(data)->getNumberOfColumns(); + expectedNColumns = nColumns + 1; + } + else if (regressionParameter->resultsToCompute & shapInteractions) + { + const size_t nColumns = inputCast->get(data)->getNumberOfColumns(); + expectedNColumns = (nColumns + 1) * (nColumns + 1); + } + DAAL_CHECK_EX(get(prediction)->getNumberOfColumns() == expectedNColumns, ErrorIncorrectNumberOfColumns, ArgumentName, predictionStr()); return s; } diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_tree_impl.h 
b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_tree_impl.h index fe48d017a82..7f284b59749 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_tree_impl.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_tree_impl.h @@ -345,8 +345,8 @@ class TreeTableConnector TableRecordType ** sons = sonsArr.data(); TableRecordType ** parents = parentsArr.data(); - gbt::prediction::internal::ModelFPType * const splitPoints = tree->getSplitPoints(); - gbt::prediction::internal::FeatureIndexType * const featureIndexes = tree->getFeatureIndexesForSplit(); + gbt::internal::ModelFPType * const splitPoints = tree->getSplitPoints(); + gbt::internal::FeatureIndexType * const featureIndexes = tree->getFeatureIndexesForSplit(); for (size_t i = 0; i < nNodes; ++i) { diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/oneapi/gbt_regression_train_dense_default_oneapi_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/regression/oneapi/gbt_regression_train_dense_default_oneapi_impl.i index eac996f8796..35f68ee6f58 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/oneapi/gbt_regression_train_dense_default_oneapi_impl.i +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/oneapi/gbt_regression_train_dense_default_oneapi_impl.i @@ -41,7 +41,6 @@ #include "src/services/service_algo_utils.h" #include "services/internal/sycl/types.h" -using namespace daal::algorithms::dtrees::training::internal; using namespace daal::algorithms::gbt::internal; using namespace daal::algorithms::gbt::regression::internal; @@ -1108,10 +1107,9 @@ services::Status RegressionTrainBatchKernelOneAPI::comp connector.getMaxLevel(0, maxLevel); DAAL_ASSERT(maxLevel + 1 <= 63); DAAL_ASSERT(((size_t)1 << (maxLevel + 1)) > 0 && ((size_t)1 << (maxLevel + 1)) < static_cast(UINT_MAX)); - const uint32_t nNodes = ((size_t)1 << (maxLevel + 1)) - 1; - const uint32_t nNodesPresent = connector.getNNodes(0); + const uint32_t nNodes = ((size_t)1 << (maxLevel + 1)) - 1; - 
gbt::internal::GbtDecisionTree * pTbl = new gbt::internal::GbtDecisionTree(nNodes, maxLevel, nNodesPresent); + gbt::internal::GbtDecisionTree * pTbl = new gbt::internal::GbtDecisionTree(nNodes, maxLevel); DAAL_CHECK_MALLOC(pTbl); HomogenNumericTable * pTblImp = new HomogenNumericTable(1, nNodes, NumericTable::doAllocate); diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/test/gbt_regression_model_builder_unit.cpp b/cpp/daal/src/algorithms/dtrees/gbt/regression/test/gbt_regression_model_builder_unit.cpp new file mode 100644 index 00000000000..8228a29cace --- /dev/null +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/test/gbt_regression_model_builder_unit.cpp @@ -0,0 +1,152 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "oneapi/dal/test/engine/common.hpp" +#include "src/algorithms/dtrees/gbt/gbt_model_impl.h" + +namespace daal::algorithms::gbt::internal +{ +GbtDecisionTree prepareThreeNodeTree() +{ + /** + * create a tree with 3 nodes, 1 root (split), 2 leaves + * ROOT (level 1) + * / \ + * L L (level 2) + */ + GbtDecisionTree tree = GbtDecisionTree(3, 2); + + ModelFPType * splitPoints = tree.getSplitPoints(); + FeatureIndexType * splitIndices = tree.getFeatureIndexesForSplit(); + int * defaultLeft = tree.getDefaultLeftForSplit(); + ModelFPType * coverValues = tree.getNodeCoverValues(); + + splitPoints[0] = 1; + splitIndices[0] = 0; + defaultLeft[0] = 1; + coverValues[0] = 1; + + splitPoints[1] = 10; + splitIndices[1] = 0; + defaultLeft[1] = 0; + coverValues[1] = 0.5; + + splitPoints[2] = 11; + splitIndices[2] = 0; + defaultLeft[2] = 0; + coverValues[2] = 0.5; + + return tree; +} + +GbtDecisionTree prepareFiveNodeTree() +{ + /** create a tree with 5 nodes + * ROOT (1) (level 1) + * / \ + * L (2) S (3) (level 2) + * / \ + * L (6) L (7) (level 3) + * (note: on level 3, nodes 4 and 5 do not exist and will be created as "dummy leaf") + */ + GbtDecisionTree tree = GbtDecisionTree(5, 3); + + ModelFPType * splitPoints = tree.getSplitPoints(); + FeatureIndexType * splitIndices = tree.getFeatureIndexesForSplit(); + int * defaultLeft = tree.getDefaultLeftForSplit(); + ModelFPType * coverValues = tree.getNodeCoverValues(); + + // node idx 1 + splitPoints[0] = 1; + splitIndices[0] = 0; + defaultLeft[0] = 1; + coverValues[0] = 10; + + // node idx 2 + // the node with dummy leaf children + splitPoints[1] = 10; + splitIndices[1] = 20; + defaultLeft[1] = 0; + coverValues[1] = 4; + + // node idx 3 + splitPoints[2] = 11; + splitIndices[2] = 0; + defaultLeft[2] = 0; + coverValues[2] = 6; + + // node idx 4 (dummy leaf) + // split point and value equal to parent node + splitPoints[3] = splitPoints[1]; + 
splitIndices[3] = splitIndices[1]; + defaultLeft[3] = 0; + coverValues[3] = 0; + + // node idx 5 (dummy leaf) + // split point and value equal to parent node + splitPoints[4] = splitPoints[1]; + splitIndices[4] = splitIndices[1]; + defaultLeft[4] = 0; + coverValues[4] = 0; + + // node idx 6 + splitPoints[5] = 12; + splitIndices[5] = 22; + defaultLeft[5] = 0; + coverValues[5] = 4; + + // node idx 7 + splitPoints[6] = 13; + splitIndices[6] = 23; + defaultLeft[6] = 0; + coverValues[6] = 2; + + return tree; +} + +TEST("nodeIsLeafThreeNodes", "[unit]") +{ + GbtDecisionTree tree = prepareThreeNodeTree(); + + REQUIRE(!ModelImpl::nodeIsLeaf(1, tree, 1)); + REQUIRE(ModelImpl::nodeIsLeaf(2, tree, 2)); + REQUIRE(ModelImpl::nodeIsLeaf(3, tree, 2)); +} + +TEST("nodeIsDummyLeafFiveNodes", "[unit]") +{ + GbtDecisionTree tree = prepareFiveNodeTree(); + + REQUIRE(!ModelImpl::nodeIsDummyLeaf(1, tree)); + REQUIRE(!ModelImpl::nodeIsDummyLeaf(2, tree)); + REQUIRE(!ModelImpl::nodeIsDummyLeaf(3, tree)); + REQUIRE(ModelImpl::nodeIsDummyLeaf(4, tree)); + REQUIRE(ModelImpl::nodeIsDummyLeaf(5, tree)); + REQUIRE(!ModelImpl::nodeIsDummyLeaf(6, tree)); + REQUIRE(!ModelImpl::nodeIsDummyLeaf(7, tree)); +} + +TEST("nodeIsLeafFiveNodes", "[unit]") +{ + GbtDecisionTree tree = prepareFiveNodeTree(); + + REQUIRE(!ModelImpl::nodeIsLeaf(1, tree, 1)); + REQUIRE(ModelImpl::nodeIsLeaf(2, tree, 2)); + REQUIRE(!ModelImpl::nodeIsLeaf(3, tree, 2)); + REQUIRE(ModelImpl::nodeIsLeaf(6, tree, 3)); + REQUIRE(ModelImpl::nodeIsLeaf(7, tree, 3)); +} +} // namespace daal::algorithms::gbt::internal diff --git a/cpp/daal/src/algorithms/dtrees/gbt/treeshap.cpp b/cpp/daal/src/algorithms/dtrees/gbt/treeshap.cpp new file mode 100644 index 00000000000..07d8a0df628 --- /dev/null +++ b/cpp/daal/src/algorithms/dtrees/gbt/treeshap.cpp @@ -0,0 +1,218 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the 
"License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "src/algorithms/dtrees/gbt/treeshap.h" + +namespace daal +{ +namespace algorithms +{ +namespace gbt +{ +namespace treeshap +{ + +namespace internal +{ + +namespace v0 +{ + +// extend our decision path with a fraction of one and zero extensions +void extendPath(PathElement * uniquePath, size_t uniqueDepth, float zeroFraction, float oneFraction, int featureIndex) +{ + uniquePath[uniqueDepth].featureIndex = featureIndex; + uniquePath[uniqueDepth].zeroFraction = zeroFraction; + uniquePath[uniqueDepth].oneFraction = oneFraction; + uniquePath[uniqueDepth].partialWeight = (uniqueDepth == 0 ? 
1.0f : 0.0f); + + const float constant = 1.0f / static_cast(uniqueDepth + 1); + for (int i = uniqueDepth - 1; i >= 0; --i) + { + uniquePath[i + 1].partialWeight += oneFraction * uniquePath[i].partialWeight * (i + 1) * constant; + uniquePath[i].partialWeight = zeroFraction * uniquePath[i].partialWeight * (uniqueDepth - i) * constant; + } +} + +// undo a previous extension of the decision path +void unwindPath(PathElement * uniquePath, size_t uniqueDepth, size_t pathIndex) +{ + const float oneFraction = uniquePath[pathIndex].oneFraction; + const float zeroFraction = uniquePath[pathIndex].zeroFraction; + float nextOnePortion = uniquePath[uniqueDepth].partialWeight; + + if (oneFraction != 0) + { + for (int i = uniqueDepth - 1; i >= 0; --i) + { + const float tmp = uniquePath[i].partialWeight; + uniquePath[i].partialWeight = nextOnePortion * (uniqueDepth + 1) / static_cast((i + 1) * oneFraction); + nextOnePortion = tmp - uniquePath[i].partialWeight * zeroFraction * (uniqueDepth - i) / static_cast(uniqueDepth + 1); + } + } + else + { + for (int i = 0; i < uniqueDepth; ++i) + { + uniquePath[i].partialWeight = (uniquePath[i].partialWeight * (uniqueDepth + 1)) / static_cast(zeroFraction * (uniqueDepth - i)); + } + } + + for (size_t i = pathIndex; i < uniqueDepth; ++i) + { + uniquePath[i].featureIndex = uniquePath[i + 1].featureIndex; + uniquePath[i].zeroFraction = uniquePath[i + 1].zeroFraction; + uniquePath[i].oneFraction = uniquePath[i + 1].oneFraction; + } +} + +// determine what the total permutation weight would be if we unwound a previous extension in the decision path +float unwoundPathSum(const PathElement * uniquePath, size_t uniqueDepth, size_t pathIndex) +{ + const float oneFraction = uniquePath[pathIndex].oneFraction; + const float zeroFraction = uniquePath[pathIndex].zeroFraction; + + float nextOnePortion = uniquePath[uniqueDepth].partialWeight; + float total = 0; + + if (oneFraction != 0) + { + const float frac = zeroFraction / oneFraction; + for (int i = 
uniqueDepth - 1; i >= 0; --i) + { + const float tmp = nextOnePortion / (i + 1); + total += tmp; + nextOnePortion = uniquePath[i].partialWeight - tmp * frac * (uniqueDepth - i); + } + total *= (uniqueDepth + 1) / oneFraction; + } + else if (zeroFraction != 0) + { + for (int i = 0; i < uniqueDepth; ++i) + { + total += uniquePath[i].partialWeight / (uniqueDepth - i); + } + total *= (uniqueDepth + 1) / zeroFraction; + } + else + { + for (int i = 0; i < uniqueDepth; ++i) + { + DAAL_ASSERT(uniquePath[i].partialWeight == 0); + } + } + + return total; +} + +} // namespace v0 + +namespace v1 +{ +void extendPath(PathElement * uniquePath, float * partialWeights, uint32_t uniqueDepth, uint32_t uniqueDepthPartialWeights, float zeroFraction, + float oneFraction, int featureIndex) +{ + uniquePath[uniqueDepth].featureIndex = featureIndex; + uniquePath[uniqueDepth].zeroFraction = zeroFraction; + uniquePath[uniqueDepth].oneFraction = oneFraction; + if (oneFraction != 0) + { + // extend partialWeights iff the feature of the last split satisfies the threshold + partialWeights[uniqueDepthPartialWeights] = (uniqueDepthPartialWeights == 0 ? 
1.0f : 0.0f); + for (int i = uniqueDepthPartialWeights - 1; i >= 0; i--) + { + partialWeights[i + 1] += partialWeights[i] * (i + 1) / static_cast(uniqueDepth + 1); + partialWeights[i] *= zeroFraction * (uniqueDepth - i) / static_cast(uniqueDepth + 1); + } + } + else + { + for (int i = uniqueDepthPartialWeights - 1; i >= 0; i--) + { + partialWeights[i] *= (uniqueDepth - i) / static_cast(uniqueDepth + 1); + } + } +} + +void unwindPath(PathElement * uniquePath, float * partialWeights, uint32_t uniqueDepth, uint32_t uniqueDepthPartialWeights, uint32_t pathIndex) +{ + const float oneFraction = uniquePath[pathIndex].oneFraction; + const float zeroFraction = uniquePath[pathIndex].zeroFraction; + float nextOnePortion = partialWeights[uniqueDepthPartialWeights]; + + if (oneFraction != 0) + { + // shrink partialWeights iff the feature satisfies the threshold + for (uint32_t i = uniqueDepthPartialWeights - 1;; --i) + { + const float tmp = partialWeights[i]; + partialWeights[i] = nextOnePortion * (uniqueDepth + 1) / static_cast(i + 1); + nextOnePortion = tmp - partialWeights[i] * zeroFraction * (uniqueDepth - i) / static_cast(uniqueDepth + 1); + if (i == 0) break; + } + } + else + { + for (uint32_t i = 0; i <= uniqueDepthPartialWeights; ++i) + { + partialWeights[i] *= (uniqueDepth + 1) / static_cast(uniqueDepth - i); + } + } + + for (uint32_t i = pathIndex; i < uniqueDepth; ++i) + { + uniquePath[i].featureIndex = uniquePath[i + 1].featureIndex; + uniquePath[i].zeroFraction = uniquePath[i + 1].zeroFraction; + uniquePath[i].oneFraction = uniquePath[i + 1].oneFraction; + } +} + +// determine what the total permuation weight would be if +// we unwound a previous extension in the decision path (for feature satisfying the threshold) +float unwoundPathSum(const PathElement * uniquePath, const float * partialWeights, uint32_t uniqueDepth, uint32_t uniqueDepthPartialWeights, + uint32_t pathIndex) +{ + float total = 0; + const float zeroFraction = uniquePath[pathIndex].zeroFraction; + 
float nextOnePortion = partialWeights[uniqueDepthPartialWeights]; + for (int i = uniqueDepthPartialWeights - 1; i >= 0; --i) + { + const float tmp = nextOnePortion / static_cast(i + 1); + total += tmp; + nextOnePortion = partialWeights[i] - tmp * zeroFraction * (uniqueDepth - i); + } + return total * (uniqueDepth + 1); +} + +float unwoundPathSumZero(const float * partialWeights, uint32_t uniqueDepth, uint32_t uniqueDepthPartialWeights) +{ + float total = 0; + if (uniqueDepth > uniqueDepthPartialWeights) + { + for (uint32_t i = 0; i <= uniqueDepthPartialWeights; ++i) + { + total += partialWeights[i] / static_cast(uniqueDepth - i); + } + } + return total * (uniqueDepth + 1); +} +} // namespace v1 + +} // namespace internal +} // namespace treeshap +} // namespace gbt +} // namespace algorithms +} // namespace daal diff --git a/cpp/daal/src/algorithms/dtrees/gbt/treeshap.h b/cpp/daal/src/algorithms/dtrees/gbt/treeshap.h new file mode 100644 index 00000000000..6a701cd97d8 --- /dev/null +++ b/cpp/daal/src/algorithms/dtrees/gbt/treeshap.h @@ -0,0 +1,470 @@ +/* file: treeshap.h */ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/** + * Original TreeSHAP algorithm by Scott Lundberg, 2018 + * https://arxiv.org/abs/1802.03888 + * Originally contributed to XGBoost in + * - https://github.com/dmlc/xgboost/pull/2438 + * - https://github.com/dmlc/xgboost/pull/3043 + * XGBoost is licensed under Apache-2 (https://github.com/dmlc/xgboost/blob/master/LICENSE) + * + * Fast TreeSHAP algorithm v1 and v2 by Jilei Yang, 2021 + * https://arxiv.org/abs/2109.09847 + * C code available at https://github.com/linkedin/FastTreeSHAP/blob/master/fasttreeshap/cext/_cext.cc + * Fast TreeSHAP is licensed under BSD-2 (https://github.com/linkedin/FastTreeSHAP/blob/master/LICENSE) + */ + +/* +//++ +// Implementation of the treeShap algorithm +//-- +*/ + +#ifndef __TREESHAP_H__ +#define __TREESHAP_H__ + +#include "services/daal_defines.h" +#include "services/error_handling.h" +#include "src/algorithms/dtrees/dtrees_feature_type_helper.h" +#include "src/algorithms/dtrees/gbt/gbt_model_impl.h" +#include "src/services/service_arrays.h" +#include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i" +#include // FLT_EPSILON + +namespace daal +{ +namespace algorithms +{ +namespace gbt +{ +namespace treeshap +{ +using gbt::internal::FeatureIndexType; +using gbt::internal::ModelFPType; +using FeatureTypes = algorithms::dtrees::internal::FeatureTypes; + +/** + * Determine the requested version of the TreeSHAP algorithm set in the + * environment variable SHAP_VERSION. + * Returns fallback if SHAP_VERSION is not set. 
+*/ +uint8_t getRequestedAlgorithmVersion(uint8_t fallback); + +/** + * Decision Path context +*/ +struct PathElement +{ + int featureIndex = 0; + float zeroFraction = 0; + float oneFraction = 0; + float partialWeight = 0; + PathElement() = default; + PathElement(const PathElement &) = default; +}; + +namespace internal +{ + +namespace v0 +{ + +void extendPath(PathElement * uniquePath, size_t uniqueDepth, float zeroFraction, float oneFraction, int featureIndex); +void unwindPath(PathElement * uniquePath, size_t uniqueDepth, size_t pathIndex); +float unwoundPathSum(const PathElement * uniquePath, size_t uniqueDepth, size_t pathIndex); + +/** Recursive treeShap function + * \param nodeIndex the index of the current node in the tree, counted from 1 + * \param depth how deep are we in the tree + * \param uniqueDepth how many unique features are above the current node in the tree + * \param parentUniquePath a vector of statistics about our current path through the tree + * \param parentZeroFraction what fraction of the parent path weight is coming as 0 (integrated) + * \param parentOneFraction what fraction of the parent path weight is coming as 1 (fixed) + * \param parentFeatureIndex what feature the parent node used to split + * \param conditionFraction what fraction of the current weight matches our conditioning feature + */ +template +inline void treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, + const FeatureTypes * featureHelper, size_t nodeIndex, size_t depth, size_t uniqueDepth, PathElement * parentUniquePath, + float parentZeroFraction, float parentOneFraction, int parentFeatureIndex, int condition, FeatureIndexType conditionFeature, + float conditionFraction) +{ + DAAL_ASSERT(parentUniquePath); + + // stop if we have no weight coming down to us + if (conditionFraction < FLT_EPSILON) return; + + const ModelFPType * const splitValues = tree->getSplitPoints() - 1; + const FeatureIndexType * const fIndexes = 
tree->getFeatureIndexesForSplit() - 1; + const ModelFPType * const nodeCoverValues = tree->getNodeCoverValues() - 1; + const int * const defaultLeft = tree->getDefaultLeftForSplit() - 1; + + PathElement * uniquePath = parentUniquePath + uniqueDepth + 1; + const size_t nBytes = (uniqueDepth + 1) * sizeof(PathElement); + const int copyStatus = daal::services::internal::daal_memcpy_s(uniquePath, nBytes, parentUniquePath, nBytes); + DAAL_ASSERT(copyStatus == 0); + + if (condition == 0 || conditionFeature != static_cast(parentFeatureIndex)) + { + extendPath(uniquePath, uniqueDepth, parentZeroFraction, parentOneFraction, parentFeatureIndex); + } + + const bool isLeaf = gbt::internal::ModelImpl::nodeIsLeaf(nodeIndex, *tree, depth); + + // leaf node + if (isLeaf) + { + for (size_t i = 1; i <= uniqueDepth; ++i) + { + const float w = unwoundPathSum(uniquePath, uniqueDepth, i); + const PathElement & el = uniquePath[i]; + phi[el.featureIndex] += w * (el.oneFraction - el.zeroFraction) * splitValues[nodeIndex] * conditionFraction; + } + + return; + } + + const FeatureIndexType splitIndex = fIndexes[nodeIndex]; + const algorithmFPType dataValue = x[splitIndex]; + + gbt::prediction::internal::PredictDispatcher dispatcher; + size_t hotIndex = updateIndex(nodeIndex, dataValue, splitValues, defaultLeft, *featureHelper, splitIndex, dispatcher); + const size_t coldIndex = 2 * nodeIndex + (hotIndex == (2 * nodeIndex)); + + const float w = nodeCoverValues[nodeIndex]; + DAAL_ASSERT(w > 0); + const float hotZeroFraction = nodeCoverValues[hotIndex] / w; + const float coldZeroFraction = nodeCoverValues[coldIndex] / w; + float incomingZeroFraction = 1.0f; + float incomingOneFraction = 1.0f; + + DAAL_ASSERT(hotZeroFraction < 1.0f); + DAAL_ASSERT(coldZeroFraction < 1.0f); + + // see if we have already split on this feature, + // if so we undo that split so we can redo it for this node + size_t previousSplitPathIndex = 0ul; + for (; previousSplitPathIndex <= uniqueDepth; 
++previousSplitPathIndex) + { + const FeatureIndexType castIndex = static_cast(uniquePath[previousSplitPathIndex].featureIndex); + + // It cannot be that a feature that is ignored is in the uniquePath + DAAL_ASSERT((condition == 0) || (castIndex != conditionFeature)); + + if (castIndex == splitIndex) + { + break; + } + } + if (previousSplitPathIndex != uniqueDepth + 1) + { + incomingZeroFraction = uniquePath[previousSplitPathIndex].zeroFraction; + incomingOneFraction = uniquePath[previousSplitPathIndex].oneFraction; + unwindPath(uniquePath, uniqueDepth, previousSplitPathIndex); + uniqueDepth -= 1; + } + + // divide up the conditionFraction among the recursive calls + float hotConditionFraction = conditionFraction; + float coldConditionFraction = conditionFraction; + if (condition > 0 && splitIndex == conditionFeature) + { + coldConditionFraction = 0; + uniqueDepth -= 1; + } + else if (condition < 0 && splitIndex == conditionFeature) + { + hotConditionFraction *= hotZeroFraction; + coldConditionFraction *= coldZeroFraction; + uniqueDepth -= 1; + } + + treeShap(tree, x, phi, featureHelper, hotIndex, depth + 1, uniqueDepth + 1, uniquePath, + hotZeroFraction * incomingZeroFraction, incomingOneFraction, splitIndex, condition, + conditionFeature, hotConditionFraction); + treeShap(tree, x, phi, featureHelper, coldIndex, depth + 1, uniqueDepth + 1, uniquePath, + coldZeroFraction * incomingZeroFraction, 0, splitIndex, condition, + conditionFeature, coldConditionFraction); +} + +/** + * \brief Version 0, i.e. 
the original TreeSHAP algorithm to compute feature attributions for a single tree + * \param tree current tree + * \param x dense data matrix + * \param phi dense output matrix of feature attributions + * \param featureHelper pointer to a FeatureTypes object (required to traverse tree) + * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) + * \param conditionFeature the index of the feature to fix + */ +template +inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, + const FeatureTypes * featureHelper, int condition, FeatureIndexType conditionFeature) +{ + services::Status st; + const int depth = tree->getMaxLvl() + 2; + const size_t nUniquePath = ((depth * (depth + 1)) / 2); + + TArray uniquePathData(nUniquePath); + DAAL_CHECK_MALLOC(uniquePathData.get()); + + treeShap(tree, x, phi, featureHelper, 1, 0, 0, uniquePathData.get(), 1, 1, -1, condition, + conditionFeature, 1); + + return st; +} + +} // namespace v0 + +namespace v1 +{ + +void extendPath(PathElement * uniquePath, float * pWeights, uint32_t uniqueDepth, uint32_t uniqueDepthPWeights, float zeroFraction, float oneFraction, + int featureIndex); +void unwindPath(PathElement * uniquePath, float * pWeights, uint32_t uniqueDepth, uint32_t uniqueDepthPWeights, uint32_t pathIndex); +float unwoundPathSum(const PathElement * uniquePath, const float * pWeights, uint32_t uniqueDepth, uint32_t uniqueDepthPWeights, uint32_t pathIndex); +float unwoundPathSumZero(const float * pWeights, uint32_t uniqueDepth, uint32_t uniqueDepthPWeights); + +/** + * Recursive Fast TreeSHAP version 1 + * Important: nodeIndex is counted from 0 here! 
+*/ +template +inline void treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, + const FeatureTypes * featureHelper, size_t nodeIndex, size_t depth, size_t uniqueDepth, size_t uniqueDepthPWeights, + PathElement * parentUniquePath, float * parentPWeights, algorithmFPType pWeightsResidual, float parentZeroFraction, + float parentOneFraction, int parentFeatureIndex, int condition, FeatureIndexType conditionFeature, float conditionFraction) +{ + // stop if we have no weight coming down to us + if (conditionFraction < FLT_EPSILON) return; + + const size_t numOutputs = 1; // currently only support single-output models + const ModelFPType * const splitValues = tree->getSplitPoints() - 1; + const int * const defaultLeft = tree->getDefaultLeftForSplit() - 1; + const FeatureIndexType * const fIndexes = tree->getFeatureIndexesForSplit() - 1; + const ModelFPType * const nodeCoverValues = tree->getNodeCoverValues() - 1; + + // extend the unique path + PathElement * uniquePath = parentUniquePath + uniqueDepth + 1; + size_t nBytes = (uniqueDepth + 1) * sizeof(PathElement); + int copyStatus = daal::services::internal::daal_memcpy_s(uniquePath, nBytes, parentUniquePath, nBytes); + DAAL_ASSERT(copyStatus == 0); + // extend pWeights + float * pWeights = parentPWeights + uniqueDepthPWeights + 1; + nBytes = (uniqueDepthPWeights + 1) * sizeof(float); + copyStatus = daal::services::internal::daal_memcpy_s(pWeights, nBytes, parentPWeights, nBytes); + DAAL_ASSERT(copyStatus == 0); + + if (condition == 0 || conditionFeature != static_cast(parentFeatureIndex)) + { + extendPath(uniquePath, pWeights, uniqueDepth, uniqueDepthPWeights, parentZeroFraction, parentOneFraction, parentFeatureIndex); + // update pWeightsResidual if the feature of the last split does not satisfy the threshold + if (parentOneFraction != 1) + { + pWeightsResidual *= parentZeroFraction; + uniqueDepthPWeights -= 1; + } + } + + const bool isLeaf = 
gbt::internal::ModelImpl::nodeIsLeaf(nodeIndex, *tree, depth); + + if (isLeaf) + { + const size_t valuesOffset = nodeIndex * numOutputs; + uint32_t valuesNonZeroInd = 0; + uint32_t valuesNonZeroCount = 0; + for (uint32_t j = 0; j < numOutputs; ++j) + { + if (splitValues[valuesOffset + j] != 0) + { + valuesNonZeroInd = j; + valuesNonZeroCount++; + } + } + // pre-calculate wZero for all features not satisfying the thresholds + const algorithmFPType wZero = unwoundPathSumZero(pWeights, uniqueDepth, uniqueDepthPWeights); + const algorithmFPType scaleZero = -wZero * pWeightsResidual * conditionFraction; + algorithmFPType scale; + for (uint32_t i = 1; i <= uniqueDepth; ++i) + { + const PathElement & el = uniquePath[i]; + const uint32_t phiOffset = el.featureIndex * numOutputs; + // update contributions to SHAP values for features satisfying the thresholds and not satisfying the thresholds separately + if (el.oneFraction != 0) + { + const algorithmFPType w = unwoundPathSum(uniquePath, pWeights, uniqueDepth, uniqueDepthPWeights, i); + scale = w * pWeightsResidual * (1 - el.zeroFraction) * conditionFraction; + } + else + { + scale = scaleZero; + } + if (valuesNonZeroCount == 1) + { + phi[phiOffset + valuesNonZeroInd] += scale * splitValues[valuesOffset + valuesNonZeroInd]; + } + else + { + for (uint32_t j = 0; j < numOutputs; ++j) + { + phi[phiOffset + j] += scale * splitValues[valuesOffset + j]; + } + } + } + + return; + } + + const FeatureIndexType splitIndex = fIndexes[nodeIndex]; + const algorithmFPType dataValue = x[splitIndex]; + + gbt::prediction::internal::PredictDispatcher dispatcher; + size_t hotIndex = updateIndex(nodeIndex, dataValue, splitValues, defaultLeft, *featureHelper, splitIndex, dispatcher); + const size_t coldIndex = 2 * nodeIndex + (hotIndex == (2 * nodeIndex)); + + const algorithmFPType w = nodeCoverValues[nodeIndex]; + const algorithmFPType hotZeroFraction = nodeCoverValues[hotIndex] / w; + const algorithmFPType coldZeroFraction = 
nodeCoverValues[coldIndex] / w; + algorithmFPType incomingZeroFraction = 1; + algorithmFPType incomingOneFraction = 1; + + // see if we have already split on this feature, + // if so we undo that split so we can redo it for this node + uint32_t pathIndex = 0; + for (; pathIndex <= uniqueDepth; ++pathIndex) + { + if (uniquePath[pathIndex].featureIndex == splitIndex) break; + } + if (pathIndex != uniqueDepth + 1) + { + incomingZeroFraction = uniquePath[pathIndex].zeroFraction; + incomingOneFraction = uniquePath[pathIndex].oneFraction; + unwindPath(uniquePath, pWeights, uniqueDepth, uniqueDepthPWeights, pathIndex); + --uniqueDepth; + // update pWeightsResidual iff the duplicated feature does not satisfy the threshold + if (incomingOneFraction != 0.) + { + uniqueDepthPWeights -= 1; + } + else + { + pWeightsResidual /= incomingZeroFraction; + } + } + + // divide up the conditionFraction among the recursive calls + algorithmFPType hotConditionFraction = conditionFraction; + algorithmFPType coldConditionFraction = conditionFraction; + if (condition > 0 && splitIndex == conditionFeature) + { + coldConditionFraction = 0; + --uniqueDepth; + --uniqueDepthPWeights; + } + else if (condition < 0 && splitIndex == conditionFeature) + { + hotConditionFraction *= hotZeroFraction; + coldConditionFraction *= coldZeroFraction; + --uniqueDepth; + --uniqueDepthPWeights; + } + + treeShap( + tree, x, phi, featureHelper, hotIndex, depth + 1, uniqueDepth + 1, uniqueDepthPWeights + 1, uniquePath, pWeights, pWeightsResidual, + hotZeroFraction * incomingZeroFraction, incomingOneFraction, splitIndex, condition, conditionFeature, hotConditionFraction); + + treeShap( + tree, x, phi, featureHelper, coldIndex, depth + 1, uniqueDepth + 1, uniqueDepthPWeights + 1, uniquePath, pWeights, pWeightsResidual, + coldZeroFraction * incomingZeroFraction, 0, splitIndex, condition, conditionFeature, coldConditionFraction); +} + +/** + * \brief Version 1, i.e. 
first Fast TreeSHAP algorithm + * \param tree current tree + * \param x dense data matrix + * \param phi dense output matrix of feature attributions + * \param featureHelper pointer to a FeatureTypes object (required to traverse tree) + * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) + * \param conditionFeature the index of the feature to fix + */ +template +inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, + const FeatureTypes * featureHelper, int condition, FeatureIndexType conditionFeature) +{ + services::Status st; + + // pre-allocate space for the unique path data and pWeights + const int depth = tree->getMaxLvl() + 2; + const size_t nElements = (depth * (depth + 1)) / 2; + + TArray uniquePathData(nElements); + DAAL_CHECK_MALLOC(uniquePathData.get()); + + TArray pWeights(nElements); + DAAL_CHECK_MALLOC(pWeights.get()); + + treeShap(tree, x, phi, featureHelper, 1, 0, 0, 0, uniquePathData.get(), pWeights.get(), 1, + 1, 1, -1, condition, conditionFeature, 1); + + return st; +} +} // namespace v1 + +} // namespace internal + +enum TreeShapVersion +{ + lundberg = 0, /** https://arxiv.org/abs/1802.03888 */ + fast_v1, /** https://arxiv.org/abs/2109.09847 */ +}; + +/** + * \brief Recursive function that computes the feature attributions for a single tree. 
+ * \param tree current tree + * \param x dense data matrix + * \param phi dense output matrix of feature attributions + * \param featureHelper pointer to a FeatureTypes object (required to traverse tree) + * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) + * \param conditionFeature the index of the feature to fix + */ +template +inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, + const FeatureTypes * featureHelper, int condition, FeatureIndexType conditionFeature, + TreeShapVersion shapVersion = fast_v1) +{ + DAAL_ASSERT(x); + DAAL_ASSERT(phi); + DAAL_ASSERT(featureHelper); + + switch (shapVersion) + { + case lundberg: + return treeshap::internal::v0::treeShap(tree, x, phi, featureHelper, condition, + conditionFeature); + case fast_v1: + return treeshap::internal::v1::treeShap(tree, x, phi, featureHelper, condition, + conditionFeature); + default: return services::Status(ErrorMethodNotImplemented); + } +} + +} // namespace treeshap +} // namespace gbt +} // namespace algorithms +} // namespace daal + +#endif // __TREESHAP_H__ diff --git a/cpp/daal/src/services/error_handling.cpp b/cpp/daal/src/services/error_handling.cpp index e243239fc08..3d0d0d8a42c 100644 --- a/cpp/daal/src/services/error_handling.cpp +++ b/cpp/daal/src/services/error_handling.cpp @@ -930,6 +930,7 @@ void ErrorMessageCollection::parseResourceFile() // GBT error: -30000..-30099 add(ErrorGbtIncorrectNumberOfTrees, "Number of trees in the model is not consistent with the number of classes"); add(ErrorGbtPredictIncorrectNumberOfIterations, "Number of iterations value in GBT parameter is not consistent with the model"); + add(ErrorGbtPredictShapOptions, "Incompatible SHAP options. Can calculate either contributions or interactions, not both"); //Math errors: -90000..-90099 add(ErrorDataSourseNotAvailable, "ErrorDataSourseNotAvailable");