diff --git a/.ci/pipeline/ci.yml b/.ci/pipeline/ci.yml
index 811a046f10b..5b1412635fc 100755
--- a/.ci/pipeline/ci.yml
+++ b/.ci/pipeline/ci.yml
@@ -249,6 +249,10 @@ jobs:
--test_thread_mode=par
displayName: 'cpp-examples-thread-release-dynamic'
+ - script: |
+ bazel test //cpp/daal:tests
+ displayName: 'daal-tests-algorithms'
+
- script: |
bazel test //cpp/oneapi/dal:tests \
--config=host \
diff --git a/cpp/daal/BUILD b/cpp/daal/BUILD
index 5e7f640c782..cea370c9b2d 100644
--- a/cpp/daal/BUILD
+++ b/cpp/daal/BUILD
@@ -1,4 +1,8 @@
package(default_visibility = ["//visibility:public"])
+load("@onedal//dev/bazel:dal.bzl",
+ "dal_test_suite",
+ "dal_collect_test_suites",
+)
load("@onedal//dev/bazel:daal.bzl",
"daal_module",
"daal_static_lib",
@@ -28,7 +32,7 @@ daal_module(
deps = select({
"@config//:backend_ref": [ "@openblas//:openblas",
],
- "//conditions:default": [ "@micromkl//:mkl_thr",
+ "//conditions:default": [ "@micromkl//:mkl_thr",
],
}),
)
@@ -54,7 +58,7 @@ daal_module(
"DAAL_HIDE_DEPRECATED",
],
deps = select({
- "@config//:backend_ref": [
+ "@config//:backend_ref": [
":public_includes",
"@openblas//:headers",
],
@@ -123,11 +127,11 @@ daal_module(
hdrs = glob(["src/sycl/**/*.h", "src/sycl/**/*.cl"]),
srcs = glob(["src/sycl/**/*.cpp"]),
deps = select({
- "@config//:backend_ref": [
+ "@config//:backend_ref": [
":services",
"@onedal//cpp/daal/src/algorithms/engines:kernel",
],
- "//conditions:default": [
+ "//conditions:default": [
":services",
"@onedal//cpp/daal/src/algorithms/engines:kernel",
"@micromkl_dpc//:headers",
@@ -146,13 +150,13 @@ daal_module(
"TBB_USE_ASSERT=0",
],
deps = select({
- "@config//:backend_ref": [
+ "@config//:backend_ref": [
":threading_headers",
":mathbackend_thread",
"@tbb//:tbb",
"@tbb//:tbbmalloc",
],
- "//conditions:default": [
+ "//conditions:default": [
":threading_headers",
":mathbackend_thread",
"@tbb//:tbb",
@@ -269,7 +273,7 @@ daal_dynamic_lib(
],
def_file = select({
"@config//:backend_ref": "src/threading/export_lnx32e.ref.def",
- "//conditions:default": "src/threading/export_lnx32e.mkl.def",
+ "//conditions:default": "src/threading/export_lnx32e.mkl.def",
}),
)
@@ -316,3 +320,22 @@ filegroup(
":thread_static",
],
)
+
+dal_test_suite(
+ name = "unit_tests",
+ framework = "catch2",
+ srcs = glob([
+ "test/*.cpp",
+ ]),
+)
+
+dal_collect_test_suites(
+ name = "tests",
+ root = "@onedal//cpp/daal/src/algorithms",
+ modules = [
+ "dtrees/gbt/regression"
+ ],
+ tests = [
+ ":unit_tests",
+ ],
+)
diff --git a/cpp/daal/include/algorithms/decision_forest/decision_forest_classification_model_builder.h b/cpp/daal/include/algorithms/decision_forest/decision_forest_classification_model_builder.h
index cef25464418..46b5aaa804b 100644
--- a/cpp/daal/include/algorithms/decision_forest/decision_forest_classification_model_builder.h
+++ b/cpp/daal/include/algorithms/decision_forest/decision_forest_classification_model_builder.h
@@ -107,32 +107,50 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] classLabel Class label to be predicted
+ * \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
- NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel)
+ NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, const double cover)
{
NodeId resId;
- _status |= addLeafNodeInternal(treeId, parentId, position, classLabel, resId);
+ _status |= addLeafNodeInternal(treeId, parentId, position, classLabel, cover, resId);
services::throwIfPossible(_status);
return resId;
}
+ /**
+ * \DAAL_DEPRECATED
+ */
+ DAAL_DEPRECATED NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel)
+ {
+ return addLeafNode(treeId, parentId, position, classLabel, 0);
+ }
+
/**
* Create Leaf node and add it to certain tree
* \param[in] treeId Tree to which new node is added
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] proba Array with probability values for each class
+ * \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
- NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba)
+ NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba, const double cover)
{
NodeId resId;
- _status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, resId);
+ _status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, cover, resId);
services::throwIfPossible(_status);
return resId;
}
+ /**
+ * \DAAL_DEPRECATED
+ */
+ DAAL_DEPRECATED NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba)
+ {
+ return addLeafNodeByProba(treeId, parentId, position, proba, 0);
+ }
+
/**
* Create Split node and add it to certain tree
* \param[in] treeId Tree to which new node is added
@@ -140,16 +158,28 @@ class DAAL_EXPORT ModelBuilder
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
+ * \param[in] defaultLeft Behaviour in case of missing values
+ * \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
- NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue)
+ NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue,
+ const int defaultLeft, const double cover)
{
NodeId resId;
- _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId);
+ _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}
+ /**
+ * \DAAL_DEPRECATED
+ */
+ DAAL_DEPRECATED NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
+ const double featureValue)
+ {
+ return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0);
+ }
+
void setNFeatures(size_t nFeatures)
{
if (!_model.get())
@@ -184,11 +214,12 @@ class DAAL_EXPORT ModelBuilder
services::Status _status;
services::Status initialize(const size_t nClasses, const size_t nTrees);
services::Status createTreeInternal(const size_t nNodes, TreeId & resId);
- services::Status addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, NodeId & res);
+ services::Status addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel,
+ const double cover, NodeId & res);
services::Status addLeafNodeByProbaInternal(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba,
- NodeId & res);
+ const double cover, NodeId & res);
services::Status addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
- const double featureValue, NodeId & res);
+ const double featureValue, const int defaultLeft, const double cover, NodeId & res);
private:
size_t _nClasses;
diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model.h
index 69f348c2baf..f6e0a7e8188 100644
--- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model.h
+++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model.h
@@ -123,6 +123,20 @@ class DAAL_EXPORT Model : public classifier::Model
*/
virtual size_t getNumberOfTrees() const = 0;
+ /**
+ * \brief Set the Prediction Bias term
+ *
+ * \param value global prediction bias
+ */
+ virtual void setPredictionBias(double value) = 0;
+
+ /**
+ * \brief Get the Prediction Bias term
+ *
+ * \return double prediction bias
+ */
+ virtual double getPredictionBias() const = 0;
+
protected:
Model() : classifier::Model() {}
};
diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model_builder.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model_builder.h
index d72cff6e9f4..b8d1c5da47d 100644
--- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model_builder.h
+++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_model_builder.h
@@ -109,16 +109,25 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] response Response value for leaf node to be predicted
+ * \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
- NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
+ NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response, double cover)
{
NodeId resId;
- _status |= addLeafNodeInternal(treeId, parentId, position, response, resId);
+ _status |= addLeafNodeInternal(treeId, parentId, position, response, cover, resId);
services::throwIfPossible(_status);
return resId;
}
+ /**
+ * \DAAL_DEPRECATED
+ */
+ DAAL_DEPRECATED NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
+ {
+ return addLeafNode(treeId, parentId, position, response, 0);
+ }
+
/**
* Create Split node and add it to certain tree
* \param[in] treeId Tree to which new node is added
@@ -127,16 +136,25 @@ class DAAL_EXPORT ModelBuilder
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] defaultLeft Behaviour in case of missing values
+ * \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
- NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft = 0)
+ NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
{
NodeId resId;
- _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId, defaultLeft);
+ _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}
+ /**
+ * \DAAL_DEPRECATED
+ */
+ DAAL_DEPRECATED NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue)
+ {
+ return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0);
+ }
+
/**
* Get built model
* \return Model pointer
@@ -159,9 +177,9 @@ class DAAL_EXPORT ModelBuilder
services::Status _status;
services::Status initialize(size_t nFeatures, size_t nIterations, size_t nClasses);
services::Status createTreeInternal(size_t nNodes, size_t classLabel, TreeId & resId);
- services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res);
- services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, NodeId & res,
- int defaultLeft);
+ services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, const double cover, NodeId & res);
+ services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft,
+ const double cover, NodeId & res);
services::Status convertModelInternal();
size_t _nClasses;
size_t _nIterations;
diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h
index ec6fe54ca9d..58aeea58a3a 100755
--- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h
+++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h
@@ -56,6 +56,17 @@ enum Method
defaultDense = 0 /*!< Default method */
};
+/**
+ *
+ * Available identifiers to specify the result to compute - results are mutually exclusive
+ */
+enum ResultToComputeId
+{
+ predictionResult = (1 << 0), /*!< Compute the regular prediction */
+ shapContributions = (1 << 1), /*!< Compute SHAP contribution values */
+ shapInteractions = (1 << 2) /*!< Compute SHAP interaction values */
+};
+
/**
* \brief Contains version 2.0 of the Intel(R) oneAPI Data Analytics Library interface.
*/
@@ -70,9 +81,12 @@ namespace interface2
/* [Parameter source code] */
struct DAAL_EXPORT Parameter : public daal::algorithms::classifier::Parameter
{
- Parameter(size_t nClasses = 2) : daal::algorithms::classifier::Parameter(nClasses), nIterations(0) {}
- Parameter(const Parameter & o) : daal::algorithms::classifier::Parameter(o), nIterations(o.nIterations) {}
- size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */
+ typedef daal::algorithms::classifier::Parameter super;
+
+ Parameter(size_t nClasses = 2) : super(nClasses), nIterations(0), resultsToCompute(predictionResult) {}
+ Parameter(const Parameter & o) : super(o), nIterations(o.nIterations), resultsToCompute(o.resultsToCompute) {}
+ size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */
+ DAAL_UINT64 resultsToCompute; /*!< 64 bit integer flag that indicates the results to compute */
};
/* [Parameter source code] */
} // namespace interface2
diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model.h
index 99607a6d16f..6e5e2b33027 100644
--- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model.h
+++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model.h
@@ -123,6 +123,20 @@ class DAAL_EXPORT Model : public algorithms::regression::Model
*/
virtual size_t getNumberOfTrees() const = 0;
+ /**
+ * \brief Set the Prediction Bias term
+ *
+ * \param value global prediction bias
+ */
+ virtual void setPredictionBias(double value) = 0;
+
+ /**
+ * \brief Get the Prediction Bias term
+ *
+ * \return double prediction bias
+ */
+ virtual double getPredictionBias() const = 0;
+
protected:
Model();
};
diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model_builder.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model_builder.h
index 2a2dd4bcfe9..c9343f03d8d 100644
--- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model_builder.h
+++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_model_builder.h
@@ -107,16 +107,25 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] response Response value for leaf node to be predicted
+ * \param[in] cover Cover of the node (sum_hess)
* \return Node identifier
*/
- NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
+ NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response, double cover)
{
NodeId resId;
- _status |= addLeafNodeInternal(treeId, parentId, position, response, resId);
+ _status |= addLeafNodeInternal(treeId, parentId, position, response, cover, resId);
services::throwIfPossible(_status);
return resId;
}
+ /**
+ * \DAAL_DEPRECATED
+ */
+ DAAL_DEPRECATED NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
+ {
+ return addLeafNode(treeId, parentId, position, response, 0);
+ }
+
/**
* Create Split node and add it to certain tree
* \param[in] treeId Tree to which new node is added
@@ -125,16 +134,25 @@ class DAAL_EXPORT ModelBuilder
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] defaultLeft Behaviour in case of missing values
+ * \param[in] cover Cover of the node (sum_hess)
* \return Node identifier
*/
- NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft = 0)
+ NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
{
NodeId resId;
- _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId, defaultLeft);
+ _status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}
+ /**
+ * \DAAL_DEPRECATED
+ */
+ DAAL_DEPRECATED NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue)
+ {
+ return addSplitNode(treeId, parentId, position, featureIndex, featureValue, 0, 0);
+ }
+
/**
* Get built model
* \return Model pointer
@@ -157,9 +175,9 @@ class DAAL_EXPORT ModelBuilder
services::Status _status;
services::Status initialize(size_t nFeatures, size_t nIterations);
services::Status createTreeInternal(size_t nNodes, TreeId & resId);
- services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res);
- services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, NodeId & res,
- int defaultLeft);
+ services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res);
+ services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft,
+ double cover, NodeId & res);
services::Status convertModelInternal();
};
/** @} */
diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_predict_types.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_predict_types.h
index 8a9cc5da552..f69591a9a6e 100644
--- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_predict_types.h
+++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_regression_predict_types.h
@@ -90,6 +90,17 @@ enum ResultId
lastResultId = prediction
};
+/**
+ *
+ * Available identifiers to specify the result to compute - results are mutually exclusive
+ */
+enum ResultToComputeId
+{
+ predictionResult = (1 << 0), /*!< Compute the regular prediction */
+ shapContributions = (1 << 1), /*!< Compute SHAP contribution values */
+ shapInteractions = (1 << 2) /*!< Compute SHAP interaction values */
+};
+
/**
* \brief Contains version 1.0 of the Intel(R) oneAPI Data Analytics Library interface
*/
@@ -104,9 +115,12 @@ namespace interface1
/* [Parameter source code] */
struct DAAL_EXPORT Parameter : public daal::algorithms::Parameter
{
- Parameter() : daal::algorithms::Parameter(), nIterations(0) {}
- Parameter(const Parameter & o) : daal::algorithms::Parameter(o), nIterations(o.nIterations) {}
- size_t nIterations; /*!< Number of iterations of the trained model to be uses for prediction*/
+ typedef daal::algorithms::Parameter super;
+
+ Parameter() : super(), nIterations(0), resultsToCompute(predictionResult) {}
+ Parameter(const Parameter & o) : super(o), nIterations(o.nIterations), resultsToCompute(o.resultsToCompute) {}
+ size_t nIterations; /*!< Number of iterations of the trained model to be used for prediction */
+ DAAL_UINT64 resultsToCompute; /*!< 64 bit integer flag that indicates the results to compute */
};
/* [Parameter source code] */
diff --git a/cpp/daal/include/algorithms/tree_utils/tree_utils.h b/cpp/daal/include/algorithms/tree_utils/tree_utils.h
index c1e98513b5b..d98ef4779cc 100644
--- a/cpp/daal/include/algorithms/tree_utils/tree_utils.h
+++ b/cpp/daal/include/algorithms/tree_utils/tree_utils.h
@@ -59,6 +59,7 @@ struct DAAL_EXPORT SplitNodeDescriptor : public NodeDescriptor
{
size_t featureIndex; /*!< Feature used for splitting the node */
double featureValue; /*!< Threshold value at the node */
+ double coverValue; /*!< Cover, a sum of the Hessian values of the loss function evaluated at the points flowing through the node */
};
/**
diff --git a/cpp/daal/include/algorithms/tree_utils/tree_utils_classification.h b/cpp/daal/include/algorithms/tree_utils/tree_utils_classification.h
index a776a5412d9..f8d13c56a9f 100644
--- a/cpp/daal/include/algorithms/tree_utils/tree_utils_classification.h
+++ b/cpp/daal/include/algorithms/tree_utils/tree_utils_classification.h
@@ -50,8 +50,9 @@ namespace interface2
*/
struct DAAL_EXPORT LeafNodeDescriptor : public NodeDescriptor
{
- size_t label; /*!< Label to be predicted when reaching the leaf */
- const double * prob; /*!< Probabilities estimation for the leaf */
+ size_t label; /*!< Label to be predicted when reaching the leaf */
+ const double * prob; /*!< Probabilities estimation for the leaf */
+ const double * cover; /*!< Cover (sum_hess) for the leaf */
};
typedef daal::algorithms::tree_utils::TreeNodeVisitor TreeNodeVisitor;
diff --git a/cpp/daal/include/algorithms/tree_utils/tree_utils_regression.h b/cpp/daal/include/algorithms/tree_utils/tree_utils_regression.h
index 9890556ea12..8e587f984ad 100644
--- a/cpp/daal/include/algorithms/tree_utils/tree_utils_regression.h
+++ b/cpp/daal/include/algorithms/tree_utils/tree_utils_regression.h
@@ -50,7 +50,8 @@ namespace interface1
*/
struct DAAL_EXPORT LeafNodeDescriptor : public NodeDescriptor
{
- double response; /*!< Value to be predicted when reaching the leaf */
+ double response; /*!< Value to be predicted when reaching the leaf */
+ double coverValue; /*!< Cover (sum_hess) for the leaf */
};
typedef daal::algorithms::tree_utils::TreeNodeVisitor TreeNodeVisitor;
diff --git a/cpp/daal/include/services/error_indexes.h b/cpp/daal/include/services/error_indexes.h
index 8d5ca7c79e7..9c9c634cd08 100644
--- a/cpp/daal/include/services/error_indexes.h
+++ b/cpp/daal/include/services/error_indexes.h
@@ -378,6 +378,7 @@ enum ErrorID
// GBT error: -30000..-30099
ErrorGbtIncorrectNumberOfTrees = -30000, /*!< Number of trees in the model is not consistent with the number of classes */
ErrorGbtPredictIncorrectNumberOfIterations = -30001, /*!< Number of iterations value in GBT parameter is not consistent with the model */
+ ErrorGbtPredictShapOptions = -30002, /*!< For SHAP values, calculate either contributions or interactions, not both */
// Data management errors: -80001..
ErrorUserAllocatedMemory = -80001, /*!< Couldn't free memory allocated by user */
diff --git a/cpp/daal/src/algorithms/dtrees/dtrees_model.cpp b/cpp/daal/src/algorithms/dtrees/dtrees_model.cpp
index 37a06dfe111..32a298dc34b 100644
--- a/cpp/daal/src/algorithms/dtrees/dtrees_model.cpp
+++ b/cpp/daal/src/algorithms/dtrees/dtrees_model.cpp
@@ -180,17 +180,19 @@ services::Status createTreeInternal(data_management::DataCollectionPtr & seriali
return s;
}
-void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel)
+void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel, double cover)
{
node.featureIndex = featureIndex;
node.leftIndexOrClass = classLabel;
+ node.cover = cover;
node.featureValueOrResponse = 0;
}
-void setNode(DecisionTreeNode & node, int featureIndex, double response)
+void setNode(DecisionTreeNode & node, int featureIndex, double response, double cover)
{
node.featureIndex = featureIndex;
node.leftIndexOrClass = 0;
+ node.cover = cover;
node.featureValueOrResponse = response;
}
@@ -222,7 +224,7 @@ void setProbabilities(const size_t treeId, const size_t nodeId, const size_t res
}
services::Status addSplitNodeInternal(data_management::DataCollectionPtr & serializationData, size_t treeId, size_t parentId, size_t position,
- size_t featureIndex, double featureValue, size_t & res, int defaultLeft)
+ size_t featureIndex, double featureValue, int defaultLeft, double cover, size_t & res)
{
const size_t noParent = static_cast<size_t>(-1);
services::Status s;
@@ -243,6 +245,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria
aNode[0].defaultLeft = defaultLeft;
aNode[0].leftIndexOrClass = 0;
aNode[0].featureValueOrResponse = featureValue;
+ aNode[0].cover = cover;
nodeId = 0;
}
else if (aNode[parentId].featureIndex < 0)
@@ -262,6 +265,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria
aNode[nodeId].defaultLeft = defaultLeft;
aNode[nodeId].leftIndexOrClass = 0;
aNode[nodeId].featureValueOrResponse = featureValue;
+ aNode[nodeId].cover = cover;
}
}
if ((aNode[parentId].leftIndexOrClass > 0) && (position == 0))
@@ -274,6 +278,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria
aNode[nodeId].defaultLeft = defaultLeft;
aNode[nodeId].leftIndexOrClass = 0;
aNode[nodeId].featureValueOrResponse = featureValue;
+ aNode[nodeId].cover = cover;
}
}
if ((aNode[parentId].leftIndexOrClass == 0) && (position == 0))
@@ -296,6 +301,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria
aNode[nodeId].defaultLeft = defaultLeft;
aNode[nodeId].leftIndexOrClass = 0;
aNode[nodeId].featureValueOrResponse = featureValue;
+ aNode[nodeId].cover = cover;
aNode[parentId].leftIndexOrClass = nodeId;
if (((nodeId + 1) < nRows) && (aNode[nodeId + 1].featureIndex == __NODE_FREE_ID))
{
@@ -332,6 +338,7 @@ services::Status addSplitNodeInternal(data_management::DataCollectionPtr & seria
aNode[nodeId].defaultLeft = defaultLeft;
aNode[nodeId].leftIndexOrClass = 0;
aNode[nodeId].featureValueOrResponse = featureValue;
+ aNode[nodeId].cover = cover;
}
else
{
diff --git a/cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h b/cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h
index 0eee0e757f3..08c04d677a8 100644
--- a/cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h
+++ b/cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h
@@ -53,6 +53,7 @@ struct DecisionTreeNode
ClassIndexType leftIndexOrClass; //split: left node index, classification leaf: class index
ModelFPType featureValueOrResponse; //split: feature value, regression tree leaf: response
int defaultLeft; //split: if 1: go to the yes branch for missing value
+ ModelFPType cover; //split: cover (sum_hess) of the node
DAAL_FORCEINLINE bool isSplit() const { return featureIndex != -1; }
ModelFPType featureValue() const { return featureValueOrResponse; }
};
@@ -60,12 +61,13 @@ struct DecisionTreeNode
class DecisionTreeTable : public data_management::AOSNumericTable
{
public:
- DecisionTreeTable(size_t rowCount = 0) : data_management::AOSNumericTable(sizeof(DecisionTreeNode), 4, rowCount)
+ DecisionTreeTable(size_t rowCount = 0) : data_management::AOSNumericTable(sizeof(DecisionTreeNode), 5, rowCount)
{
setFeature(0, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, featureIndex));
setFeature(1, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, leftIndexOrClass));
setFeature(2, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, featureValueOrResponse));
setFeature(3, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, defaultLeft));
+ setFeature(4, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, cover));
allocateDataMemory();
}
};
template <typename ModelImplType, typename ModelTypePtr>
ModelImplType & getModelRef(ModelTypePtr & modelPtr);
services::Status createTreeInternal(data_management::DataCollectionPtr & serializationData, size_t nNodes, size_t & resId);
-void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel);
+void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel, double cover);
-void setNode(DecisionTreeNode & node, int featureIndex, double response);
+void setNode(DecisionTreeNode & node, int featureIndex, double response, double cover);
services::Status addSplitNodeInternal(data_management::DataCollectionPtr & serializationData, size_t treeId, size_t parentId, size_t position,
- size_t featureIndex, double featureValue, size_t & res, int defaultLeft = 0);
+ size_t featureIndex, double featureValue, int defaultLeft, double cover, size_t & res);
void setProbabilities(const size_t treeId, const size_t nodeId, const size_t response, const data_management::DataCollectionPtr probTbl,
const double * const prob);
template <typename ClassOrResponseType>
static services::Status addLeafNodeInternal(const data_management::DataCollectionPtr & serializationData, const size_t treeId, const size_t parentId,
- const size_t position, ClassOrResponseType response, size_t & res,
+ const size_t position, ClassOrResponseType response, double cover, size_t & res,
const data_management::DataCollectionPtr probTbl = data_management::DataCollectionPtr(),
const double * const prob = nullptr, const size_t nClasses = 0)
{
@@ -373,7 +375,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio
size_t nodeId = 0;
if (parentId == noParent)
{
- setNode(aNode[0], -1, response);
+ setNode(aNode[0], -1, response, cover);
setProbabilities(treeId, 0, response, probTbl, prob);
nodeId = 0;
}
@@ -390,7 +392,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio
nodeId = reservedId;
if (aNode[reservedId].featureIndex == __NODE_RESERVED_ID)
{
- setNode(aNode[nodeId], -1, response);
+ setNode(aNode[nodeId], -1, response, cover);
setProbabilities(treeId, nodeId, response, probTbl, prob);
}
}
@@ -400,7 +402,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio
nodeId = reservedId;
if (aNode[reservedId].featureIndex == __NODE_RESERVED_ID)
{
- setNode(aNode[nodeId], -1, response);
+ setNode(aNode[nodeId], -1, response, cover);
setProbabilities(treeId, nodeId, response, probTbl, prob);
}
}
@@ -420,7 +422,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio
{
return services::Status(services::ErrorID::ErrorIncorrectParameter);
}
- setNode(aNode[nodeId], -1, response);
+ setNode(aNode[nodeId], -1, response, cover);
setProbabilities(treeId, nodeId, response, probTbl, prob);
aNode[parentId].leftIndexOrClass = nodeId;
if (((nodeId + 1) < nRows) && (aNode[nodeId + 1].featureIndex == __NODE_FREE_ID))
@@ -454,7 +456,7 @@ static services::Status addLeafNodeInternal(const data_management::DataCollectio
nodeId = leftEmptyId + 1;
if (nodeId < nRows)
{
- setNode(aNode[nodeId], -1, response);
+ setNode(aNode[nodeId], -1, response, cover);
setProbabilities(treeId, nodeId, response, probTbl, prob);
}
else
diff --git a/cpp/daal/src/algorithms/dtrees/forest/classification/df_classification_model_builder.cpp b/cpp/daal/src/algorithms/dtrees/forest/classification/df_classification_model_builder.cpp
index c801da2c4cb..ad53f087081 100644
--- a/cpp/daal/src/algorithms/dtrees/forest/classification/df_classification_model_builder.cpp
+++ b/cpp/daal/src/algorithms/dtrees/forest/classification/df_classification_model_builder.cpp
@@ -72,7 +72,7 @@ services::Status ModelBuilder::createTreeInternal(const size_t nNodes, TreeId &
}
services::Status ModelBuilder::addLeafNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel,
- NodeId & res)
+ const double cover, NodeId & res)
{
decision_forest::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef(_model);
@@ -81,7 +81,7 @@ services::Status ModelBuilder::addLeafNodeInternal(const TreeId treeId, const No
return services::Status(services::ErrorID::ErrorIncorrectParameter);
}
return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, classLabel,
- res, modelImplRef._probTbl);
+ cover, res, modelImplRef._probTbl);
}
bool checkProba(const double * const proba, const size_t nClasses)
@@ -104,7 +104,7 @@ bool checkProba(const double * const proba, const size_t nClasses)
}
services::Status ModelBuilder::addLeafNodeByProbaInternal(const TreeId treeId, const NodeId parentId, const size_t position,
- const double * const proba, NodeId & res)
+ const double * const proba, const double cover, NodeId & res)
{
decision_forest::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef(_model);
@@ -112,17 +112,17 @@ services::Status ModelBuilder::addLeafNodeByProbaInternal(const TreeId treeId, c
{
return services::Status(services::ErrorID::ErrorIncorrectParameter);
}
- return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, 0, res,
+ return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, 0, cover, res,
modelImplRef._probTbl, proba, _nClasses);
}
services::Status ModelBuilder::addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
- const double featureValue, NodeId & res)
+ const double featureValue, const int defaultLeft, const double cover, NodeId & res)
{
decision_forest::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef(_model);
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
- featureValue, res);
+ featureValue, defaultLeft, cover, res);
}
} // namespace interface2
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/BUILD b/cpp/daal/src/algorithms/dtrees/gbt/BUILD
index 712d778e8ae..9d717dae310 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/BUILD
+++ b/cpp/daal/src/algorithms/dtrees/gbt/BUILD
@@ -1,4 +1,7 @@
package(default_visibility = ["//visibility:public"])
+load("@onedal//dev/bazel:dal.bzl",
+ "dal_collect_test_suites",
+)
load("@onedal//dev/bazel:daal.bzl", "daal_module")
daal_module(
@@ -11,3 +14,11 @@ daal_module(
"@onedal//cpp/daal/src/algorithms/dtrees:kernel",
],
)
+
+dal_collect_test_suites(
+ name = "tests",
+ root = "@onedal//cpp/oneapi/dal/algo",
+ modules = [
+ "regression"
+ ],
+)
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model.cpp b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model.cpp
index 1e4c6b71234..afb01125ef6 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model.cpp
+++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model.cpp
@@ -84,10 +84,21 @@ void ModelImpl::traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisito
ImplType::traverseBFS(iTree, visitor);
}
+void ModelImpl::setPredictionBias(double value)
+{
+ _predictionBias = value;
+}
+
+double ModelImpl::getPredictionBias() const
+{
+ return _predictionBias;
+}
+
services::Status ModelImpl::serializeImpl(data_management::InputDataArchive * arch)
{
auto s = algorithms::classifier::Model::serialImpl(arch);
arch->set(this->_nFeatures); //algorithms::classifier::internal::ModelInternal
+ arch->set(this->_predictionBias);
return s.add(ImplType::serialImpl(arch));
}
@@ -95,6 +106,7 @@ services::Status ModelImpl::deserializeImpl(const data_management::OutputDataArc
{
auto s = algorithms::classifier::Model::serialImpl(arch);
arch->set(this->_nFeatures); //algorithms::classifier::internal::ModelInternal
+ arch->set(this->_predictionBias);
return s.add(ImplType::serialImpl(
arch, COMPUTE_DAAL_VERSION(arch->getMajorVersion(), arch->getMinorVersion(), arch->getUpdateVersion())));
}
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_builder.cpp b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_builder.cpp
index 1c0462dd7a1..c4c195e400f 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_builder.cpp
+++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_builder.cpp
@@ -68,7 +68,7 @@ services::Status ModelBuilder::initialize(size_t nFeatures, size_t nIterations,
return s;
}
-services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t clasLabel, TreeId & resId)
+services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t classLabel, TreeId & resId)
{
gbt::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef(_model);
@@ -83,19 +83,22 @@ services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t clasLabe
{
return Status(ErrorID::ErrorIncorrectParameter);
}
- if (clasLabel > (_nClasses - 1))
+ if (_nClasses <= classLabel)
{
return Status(ErrorID::ErrorIncorrectParameter);
}
- TreeId treeId = clasLabel * _nIterations;
+ TreeId treeId = classLabel * _nIterations;
const SerializationIface * isEmptyTreeTable = (*(modelImplRef._serializationData))[treeId].get();
- const size_t nTrees = (clasLabel + 1) * _nIterations;
+ const size_t nTrees = (classLabel + 1) * _nIterations;
while (isEmptyTreeTable && treeId < nTrees)
{
treeId++;
isEmptyTreeTable = (*(modelImplRef._serializationData))[treeId].get();
}
- if (treeId == nTrees) return Status(ErrorID::ErrorIncorrectParameter);
+ if (treeId == nTrees)
+ {
+ return Status(ErrorID::ErrorIncorrectParameter);
+ }
services::SharedPtr treeTablePtr(
new DecisionTreeTable(nNodes)); //DecisionTreeTable* const treeTable = new DecisionTreeTable(nNodes);
@@ -120,22 +123,21 @@ services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t clasLabe
}
}
-services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res)
+services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res)
{
gbt::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef(_model);
return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, response,
- res);
- ;
+ cover, res);
}
services::Status ModelBuilder::addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue,
- NodeId & res, int defaultLeft)
+ int defaultLeft, const double cover, NodeId & res)
{
gbt::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef(_model);
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
- featureValue, res, defaultLeft);
+ featureValue, defaultLeft, cover, res);
}
} // namespace interface1
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h
index 8e9fbbc216e..931780a5f1c 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h
+++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h
@@ -61,10 +61,17 @@ class ModelImpl : public daal::algorithms::gbt::classification::Model,
virtual void traverseDFS(size_t iTree, tree_utils::regression::TreeNodeVisitor & visitor) const DAAL_C11_OVERRIDE;
virtual void traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisitor & visitor) const DAAL_C11_OVERRIDE;
+ virtual void setPredictionBias(double value) DAAL_C11_OVERRIDE;
+ virtual double getPredictionBias() const DAAL_C11_OVERRIDE;
+
virtual services::Status serializeImpl(data_management::InputDataArchive * arch) DAAL_C11_OVERRIDE;
virtual services::Status deserializeImpl(const data_management::OutputDataArchive * arch) DAAL_C11_OVERRIDE;
virtual size_t getNumberOfTrees() const DAAL_C11_OVERRIDE;
+
+private:
+ /* global bias applied to predictions */
+ double _predictionBias;
};
} // namespace internal
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i
index 0a79b9d040b..00a9ee5884d 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i
+++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i
@@ -38,6 +38,7 @@
#include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i"
#include "src/algorithms/objective_function/cross_entropy_loss/cross_entropy_loss_dense_default_batch_kernel.h"
#include "src/services/service_algo_utils.h"
+#include
using namespace daal::internal;
using namespace daal::services::internal;
@@ -56,111 +57,49 @@ namespace internal
{
//////////////////////////////////////////////////////////////////////////////////////////
-// PredictBinaryClassificationTask
+// PredictBinaryClassificationTask - declaration
//////////////////////////////////////////////////////////////////////////////////////////
template
class PredictBinaryClassificationTask : public gbt::regression::prediction::internal::PredictRegressionTask
{
public:
typedef gbt::regression::prediction::internal::PredictRegressionTask super;
+
+public:
+ /**
+ * \brief Construct a new Predict Binary Classification Task object
+ *
+ * \param x NumericTable observation data
+ * \param y NumericTable prediction data
+ * \param prob NumericTable probability data
+ */
PredictBinaryClassificationTask(const NumericTable * x, NumericTable * y, NumericTable * prob) : super(x, y), _prob(prob) {}
- services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp)
- {
- DAAL_ASSERT(!nIterations || nIterations <= m->size());
- DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data));
- const auto nTreesTotal = (nIterations ? nIterations : m->size());
- this->_aTree.reset(nTreesTotal);
- DAAL_CHECK_MALLOC(this->_aTree.get());
- for (size_t i = 0; i < nTreesTotal; ++i) this->_aTree[i] = m->at(i);
- const auto nRows = this->_data->getNumberOfRows();
- services::Status s;
- DAAL_OVERFLOW_CHECK_BY_MULTIPLICATION(size_t, nRows, sizeof(algorithmFPType));
- //compute raw boosted values
- if (this->_res && _prob)
- {
- WriteOnlyRows resBD(this->_res, 0, nRows);
- DAAL_CHECK_BLOCK_STATUS(resBD);
- const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) };
- algorithmFPType * res = resBD.get();
- WriteOnlyRows probBD(_prob, 0, nRows);
- DAAL_CHECK_BLOCK_STATUS(probBD);
- algorithmFPType * prob_pred = probBD.get();
- TArray expValPtr(nRows);
- algorithmFPType * expVal = expValPtr.get();
- DAAL_CHECK_MALLOC(expVal);
- s = super::runInternal(pHostApp, this->_res);
- if (!s) return s;
-
- auto nBlocks = daal::threader_get_threads_number();
- const size_t blockSize = nRows / nBlocks;
- nBlocks += (nBlocks * blockSize != nRows);
-
- daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) {
- const size_t startRow = iBlock * blockSize;
- const size_t finishRow = (((iBlock + 1) == nBlocks) ? nRows : (iBlock + 1) * blockSize);
- daal::internal::MathInst::vExp(finishRow - startRow, res + startRow, expVal + startRow);
-
- PRAGMA_IVDEP
- PRAGMA_VECTOR_ALWAYS
- for (size_t iRow = startRow; iRow < finishRow; ++iRow)
- {
- res[iRow] = label[services::internal::SignBit::get(res[iRow])];
- prob_pred[2 * iRow + 1] = expVal[iRow] / (algorithmFPType(1.) + expVal[iRow]);
- prob_pred[2 * iRow] = algorithmFPType(1.) - prob_pred[2 * iRow + 1];
- }
- });
- }
- else if ((!this->_res) && _prob)
- {
- WriteOnlyRows probBD(_prob, 0, nRows);
- DAAL_CHECK_BLOCK_STATUS(probBD);
- algorithmFPType * prob_pred = probBD.get();
- TArray expValPtr(nRows);
- algorithmFPType * expVal = expValPtr.get();
- NumericTablePtr expNT = HomogenNumericTableCPU::create(expVal, 1, nRows, &s);
- DAAL_CHECK_MALLOC(expVal);
- s = super::runInternal(pHostApp, expNT.get());
- if (!s) return s;
-
- auto nBlocks = daal::threader_get_threads_number();
- const size_t blockSize = nRows / nBlocks;
- nBlocks += (nBlocks * blockSize != nRows);
- daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) {
- const size_t startRow = iBlock * blockSize;
- const size_t finishRow = (((iBlock + 1) == nBlocks) ? nRows : (iBlock + 1) * blockSize);
- daal::internal::MathInst::vExp(finishRow - startRow, expVal + startRow, expVal + startRow);
- for (size_t iRow = startRow; iRow < finishRow; ++iRow)
- {
- prob_pred[2 * iRow + 1] = expVal[iRow] / (algorithmFPType(1.) + expVal[iRow]);
- prob_pred[2 * iRow] = algorithmFPType(1.) - prob_pred[2 * iRow + 1];
- }
- });
- }
- else if (this->_res && (!_prob))
- {
- WriteOnlyRows resBD(this->_res, 0, nRows);
- DAAL_CHECK_BLOCK_STATUS(resBD);
- const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) };
- algorithmFPType * res = resBD.get();
- s = super::runInternal(pHostApp, this->_res);
- if (!s) return s;
-
- for (size_t iRow = 0; iRow < nRows; ++iRow)
- {
- //probablity is a sigmoid(f) hence sign(f) can be checked
- res[iRow] = label[services::internal::SignBit::get(res[iRow])];
- }
- }
- return s;
- }
+ /**
+ * \brief Run prediction for the given model
+ *
+ * \param m The model for which to run prediction
+ * \param nIterations Number of iterations
+ * \param pHostApp HostAppInterface
+ * \return services::Status
+ */
+ services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp);
+
+protected:
+ /**
+ * \brief Convert the model bias to a margin, considering the sigmoid activation
+ *
+ * \param bias Bias in class score units
+ * \return algorithmFPType Bias in sigmoid margin (logit) units
+ */
+ algorithmFPType getMarginFromModelBias(algorithmFPType bias) const;
protected:
NumericTable * _prob;
};
//////////////////////////////////////////////////////////////////////////////////////////
-// PredictMulticlassTask
+// PredictMulticlassTask - declaration
//////////////////////////////////////////////////////////////////////////////////////////
template
class PredictMulticlassTask
@@ -171,200 +110,603 @@ public:
typedef daal::tls ClassesRawBoostedTlsBase;
typedef daal::TlsMem ClassesRawBoostedTls;
+ /**
+ * \brief Construct a new Predict Multiclass Task object
+ *
+ * \param x NumericTable observation data
+ * \param y NumericTable prediction data
+ * \param prob NumericTable probability data
+ */
PredictMulticlassTask(const NumericTable * x, NumericTable * y, NumericTable * prob) : _data(x), _res(y), _prob(prob) {}
+
+ /**
+ * \brief Run prediction for the given model
+ *
+ * \param m The model for which to run prediction
+ * \param nClasses Number of data classes
+ * \param nIterations Number of iterations
+ * \param pHostApp HostAppInterface
+ * \return services::Status
+ */
services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nClasses, size_t nIterations, services::HostAppIface * pHostApp);
protected:
- services::Status predictByAllTrees(size_t nTreesTotal, size_t nClasses, const DimType & dim);
-
+ /** Dispatcher type for template dispatching */
template
using dispatcher_t = gbt::prediction::internal::PredictDispatcher;
+
+ /**
+ * \brief Helper boolean constant to populate template dispatcher
+ *
+ * \param val A boolean value, known at compile time
+ */
+ template
+ struct BooleanConstant
+ {
+ typedef BooleanConstant type;
+ };
+
+ /**
+ * \brief Run prediction for all trees
+ *
+ * \param nTreesTotal Total number of trees in model
+ * \param nClasses Number of data classes
+ * \param bias Global prediction bias (e.g. base_score in XGBoost)
+ * \param dim DimType helper
+ * \return services::Status
+ */
+ services::Status predictByAllTrees(size_t nTreesTotal, size_t nClasses, double bias, const DimType & dim);
+
+ /**
+ * \brief Make prediction for a number of trees
+ *
+ * \param hasUnorderedFeatures Data has unordered features yes/no
+ * \param hasAnyMissing Data has missing values yes/no
+ * \param val Output prediction
+ * \param iFirstTree Index of first tree
+ * \param nTrees Number of trees included in prediction
+ * \param nClasses Number of data classes
+ * \param x Input observation data
+ * \param dispatcher Template dispatcher helper
+ */
template
void predictByTrees(algorithmFPType * val, size_t iFirstTree, size_t nTrees, size_t nClasses, const algorithmFPType * x,
const dispatcher_t & dispatcher);
+
+ /**
+ * \brief Make prediction for a number of trees leveraging vector instructions
+ *
+ * \param hasUnorderedFeatures Data has unordered features yes/no
+ * \param hasAnyMissing Data has missing values yes/no
+ * \param vectorBlockSize Vector instruction block size
+ * \param val Output prediction
+ * \param iFirstTree Index of first tree
+ * \param nTrees Number of trees included in prediction
+ * \param nClasses Number of data classes
+ * \param x Input observation data
+ * \param dispatcher Template dispatcher helper
+ */
template
void predictByTreesVector(algorithmFPType * val, size_t iFirstTree, size_t nTrees, size_t nClasses, const algorithmFPType * x,
const dispatcher_t & dispatcher);
- template
- struct BooleanConstant
- {
- typedef BooleanConstant type;
- };
+ /**
+ * \brief Assign a class index to the result
+ *
+ * \param res Pointer to result array
+ * \param val Value of current prediction
+ * \param iRow
+ * \param i
+ * \param nClasses Number of data classes
+ * \param dispatcher Template dispatcher helper
+ */
+ inline void updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, BooleanConstant dispatcher);
+
+ /**
+ * \brief Empty overload used when assigning results is not required.
+ *
+ * \param res Pointer to result array
+ * \param val Value of current prediction
+ * \param iRow
+ * \param i
+ * \param nClasses Number of data classes
+ * \param dispatcher Template dispatcher helper
+ */
+ inline void updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, BooleanConstant dispatcher);
+
+ /**
+ * \brief Prepare the buff pointer for the next use. All steps reuse the same memory.
+ *
+ * \param buff Pointer to a buffer
+ * \param buf_shift
+ * \param buf_size
+ * \param dispatcher
+ * \return algorithmFPType* Pointer to the input buffer
+ */
+ inline algorithmFPType * updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, BooleanConstant dispatcher);
+
+ /**
+ * \brief Prepare the buff pointer for the next use. Each step has its own memory.
+ *
+ * \param buff
+ * \param buf_shift
+ * \param buf_size
+ * \param dispatcher
+ * \return algorithmFPType*
+ */
+ inline algorithmFPType * updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, BooleanConstant dispatcher);
+
+ /**
+ * \brief Get the total number of nodes in all trees for tree number [1, 2, ... nTrees]
+ *
+ * \param nTrees Number of trees that contribute to the sum
+ * \return size_t Number of nodes in all contributing trees
+ */
+ inline size_t getNumberOfNodes(size_t nTrees);
+
+ /**
+ * \brief Check for missing data
+ *
+ * \param x Input observation data
+ * \param nTrees Number of contributing trees
+ * \param nRows Number of rows in input observation data to be considered
+ * \param nColumns Number of columns in input observation data to be considered
+ * \return true If runtime check for missing is required
+ * \return false If runtime check for missing is not required
+ */
+ inline bool checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns);
+
+ /**
+ * \brief Traverse a number of trees to get prediction results
+ *
+ * \param hasUnorderedFeatures Data has unordered features yes/no
+ * \param hasAnyMissing Data has missing values yes/no
+ * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no)
+ * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no)
+ * \param vectorBlockSize Vector instruction block size
+ * \param nTrees Number of trees contributing to prediction
+ * \param nClasses Number of data classes
+ * \param nRows Number of rows in observation data for which prediction is run
+ * \param nColumns Number of columns in observation data
+ * \param x Input observation data
+ * \param buff A pre-allocated buffer for computations
+ * \param[out] res Output prediction result
+ */
+ template
+ inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
+ algorithmFPType * res);
+
+ /**
+ * \brief Traverse a number of trees to get prediction results
+ *
+ * \param hasAnyMissing Data has missing values yes/no
+ * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no)
+ * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no)
+ * \param vectorBlockSize Vector instruction block size
+ * \param nTrees Number of trees contributing to prediction
+ * \param nClasses Number of data classes
+ * \param nRows Number of rows in observation data for which prediction is run
+ * \param nColumns Number of columns in observation data
+ * \param x Input observation data
+ * \param buff A pre-allocated buffer for computations
+ * \param[out] res Output prediction result
+ */
+ template
+ inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
+ algorithmFPType * res);
+
+ /**
+ * \brief Traverse a number of trees to get prediction results
+ *
+ * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no)
+ * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no)
+ * \param vectorBlockSize Vector instruction block size
+ * \param nTrees Number of trees contributing to prediction
+ * \param nClasses Number of data classes
+ * \param nRows Number of rows in observation data for which prediction is run
+ * \param nColumns Number of columns in observation data
+ * \param x Input observation data
+ * \param buff A pre-allocated buffer for computations
+ * \param[out] res Output prediction result
+ */
+ template
+ inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
+ algorithmFPType * res);
+
+ /**
+ * \brief Traverse a number of trees to get prediction results
+ *
+ * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no)
+ * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no)
+ * \param vectorBlockSizeFactor Vector instruction block size - recursively decremented until it becomes equal to dim.vectorBlockSizeFactor or equal to DimType::minVectorBlockSizeFactor
+ * \param nTrees Number of trees contributing to prediction
+ * \param nClasses Number of data classes
+ * \param nRows Number of rows in observation data for which prediction is run
+ * \param nColumns Number of columns in observation data
+ * \param x Input observation data
+ * \param buff A pre-allocated buffer for computations
+ * \param[out] res Output prediction result
+ */
+ template
+ inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
+ algorithmFPType * res, const DimType & dim, BooleanConstant keepLooking);
+
+ /**
+ * \brief Traverse a number of trees to get prediction results
+ *
+ * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no)
+ * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no)
+ * \param vectorBlockSizeFactor Vector instruction block size - recursively decremented until it becomes equal to dim.vectorBlockSizeFactor or equal to DimType::minVectorBlockSizeFactor
+ * \param nTrees Number of trees contributing to prediction
+ * \param nClasses Number of data classes
+ * \param nRows Number of rows in observation data for which prediction is run
+ * \param nColumns Number of columns in observation data
+ * \param x Input observation data
+ * \param buff A pre-allocated buffer for computations
+ * \param[out] res Output prediction result
+ * \param dim DimType helper
+ * \param keepLooking
+ */
+ template
+ inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
+ algorithmFPType * res, const DimType & dim, BooleanConstant keepLooking);
+
+ /**
+ * \brief Traverse a number of trees to get prediction results
+ *
+ * \param isResValidPtr Result pointer is valid yes/no (write result to the pointer if yes, skip if no)
+ * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no)
+ * \param nTrees Number of trees contributing to prediction
+ * \param nClasses Number of data classes
+ * \param nRows Number of rows in observation data for which prediction is run
+ * \param nColumns Number of columns in observation data
+ * \param x Input observation data
+ * \param buff A pre-allocated buffer for computations
+ * \param[out] res Output prediction result
+ * \param dim DimType helper
+ */
+ template
+ inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
+ algorithmFPType * res, const DimType & dim);
+
+ /**
+ * \brief Traverse a number of trees to get prediction results
+ *
+ * \param reuseBuffer Re-use buffer yes/no (will fill buffer with zero if yes, shift buff pointer if no)
+ * \param nTrees Number of trees contributing to prediction
+ * \param nClasses Number of data classes
+ * \param nRows Number of rows in observation data for which prediction is run
+ * \param nColumns Number of columns in observation data
+ * \param x Input observation data
+ * \param buff A pre-allocated buffer for computations
+ * \param[out] res Output prediction result
+ * \param dim DimType helper
+ */
+ template
+ inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
+ algorithmFPType * res, const DimType & dim);
- inline void updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, BooleanConstant dispatcher)
- {
- res[iRow + i] = getMaxClass(val + i * nClasses, nClasses);
- }
+ /**
+ * \brief Get index of element with maximum value / activation
+ *
+ * \param val Pointer to values
+ * \param nClasses Number of columns per value
+ * \return size_t The column index with the maximum value
+ */
+ size_t getMaxClass(const algorithmFPType * val, size_t nClasses) const;
- inline void updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses, BooleanConstant dispatcher)
- {}
+protected:
+ const NumericTable * _data;
+ NumericTable * _res;
+ NumericTable * _prob;
+ dtrees::internal::FeatureTypes _featHelper;
+ TArray _aTree;
+};
- inline algorithmFPType * updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, BooleanConstant dispatcher)
- {
- services::internal::service_memset_seq(buff, algorithmFPType(0), buf_size);
- return buff;
- }
+//////////////////////////////////////////////////////////////////////////////////////////
+// PredictBinaryClassificationTask - implementation
+//////////////////////////////////////////////////////////////////////////////////////////
+template
+services::Status PredictBinaryClassificationTask::run(const gbt::classification::internal::ModelImpl * m, size_t nIterations,
+ services::HostAppIface * pHostApp)
+{
+ DAAL_ASSERT(!nIterations || nIterations <= m->size());
+ DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data));
+ const auto nTreesTotal = (nIterations ? nIterations : m->size());
+ this->_aTree.reset(nTreesTotal);
+ DAAL_CHECK_MALLOC(this->_aTree.get());
+ for (size_t i = 0; i < nTreesTotal; ++i) this->_aTree[i] = m->at(i);
- inline algorithmFPType * updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size, BooleanConstant dispatcher)
- {
- return buff + buf_shift;
- }
+ const auto nRows = this->_data->getNumberOfRows();
+ services::Status s;
+ DAAL_OVERFLOW_CHECK_BY_MULTIPLICATION(size_t, nRows, sizeof(algorithmFPType));
- inline size_t getNumberOfNodes(size_t nTrees)
+ // we convert the bias to a margin if it exceeds FLT_EPSILON
+ // otherwise the margin is 0
+ algorithmFPType margin(0);
+ if (m->getPredictionBias() > FLT_EPSILON)
{
- size_t nNodesTotal = 0;
- for (size_t iTree = 0; iTree < nTrees; ++iTree)
- {
- nNodesTotal += this->_aTree[iTree]->getNumberOfNodes();
- }
- return nNodesTotal;
+ margin = getMarginFromModelBias(m->getPredictionBias());
}
- inline bool checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns)
+ // compute raw boosted values
+ if (this->_res && _prob)
{
- size_t nLvlTotal = 0;
- for (size_t iTree = 0; iTree < nTrees; ++iTree)
- {
- nLvlTotal += this->_aTree[iTree]->getMaxLvl();
- }
- if (nLvlTotal <= nColumns)
- {
- // Checking is compicated. Better to do it during inferense.
- return true;
- }
- else
- {
- for (size_t idx = 0; idx < nRows * nColumns; ++idx)
+ WriteOnlyRows resBD(this->_res, 0, nRows);
+ DAAL_CHECK_BLOCK_STATUS(resBD);
+ const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) };
+ algorithmFPType * res = resBD.get();
+ WriteOnlyRows probBD(_prob, 0, nRows);
+ DAAL_CHECK_BLOCK_STATUS(probBD);
+ algorithmFPType * prob_pred = probBD.get();
+ TArray expValPtr(nRows);
+ algorithmFPType * expVal = expValPtr.get();
+ DAAL_CHECK_MALLOC(expVal);
+ s = super::runInternal(pHostApp, this->_res, margin, false, false);
+ if (!s) return s;
+
+ auto nBlocks = daal::threader_get_threads_number();
+ const size_t blockSize = nRows / nBlocks;
+ nBlocks += (nBlocks * blockSize != nRows);
+
+ daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) {
+ const size_t startRow = iBlock * blockSize;
+ const size_t finishRow = (((iBlock + 1) == nBlocks) ? nRows : (iBlock + 1) * blockSize);
+ daal::internal::MathInst::vExp(finishRow - startRow, res + startRow, expVal + startRow);
+
+ PRAGMA_IVDEP
+ PRAGMA_VECTOR_ALWAYS
+ for (size_t iRow = startRow; iRow < finishRow; ++iRow)
{
- if (checkFinitenessByComparison(x[idx])) return true;
+ res[iRow] = label[services::internal::SignBit::get(res[iRow])];
+ prob_pred[2 * iRow + 1] = expVal[iRow] / (algorithmFPType(1.) + expVal[iRow]);
+ prob_pred[2 * iRow] = algorithmFPType(1.) - prob_pred[2 * iRow + 1];
}
- }
- return false;
+ });
}
- template
- inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
- algorithmFPType * res)
+ else if ((!this->_res) && _prob)
{
- dispatcher_t dispatcher;
- size_t iRow = 0;
- for (; iRow + vectorBlockSize <= nRows; iRow += vectorBlockSize)
- {
- algorithmFPType * val = updateBuffer(buff, iRow * nClasses, nClasses * vectorBlockSize, BooleanConstant());
- predictByTreesVector(val, 0, nTrees, nClasses, x + iRow * nColumns, dispatcher);
- for (size_t i = 0; i < vectorBlockSize; ++i)
+ WriteOnlyRows probBD(_prob, 0, nRows);
+ DAAL_CHECK_BLOCK_STATUS(probBD);
+ algorithmFPType * prob_pred = probBD.get();
+ TArray expValPtr(nRows);
+ algorithmFPType * expVal = expValPtr.get();
+ NumericTablePtr expNT = HomogenNumericTableCPU::create(expVal, 1, nRows, &s);
+ DAAL_CHECK_MALLOC(expVal);
+ s = super::runInternal(pHostApp, expNT.get(), margin, false, false);
+ if (!s) return s;
+
+ auto nBlocks = daal::threader_get_threads_number();
+ const size_t blockSize = nRows / nBlocks;
+ nBlocks += (nBlocks * blockSize != nRows);
+ daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) {
+ const size_t startRow = iBlock * blockSize;
+ const size_t finishRow = (((iBlock + 1) == nBlocks) ? nRows : (iBlock + 1) * blockSize);
+ daal::internal::MathInst::vExp(finishRow - startRow, expVal + startRow, expVal + startRow);
+ for (size_t iRow = startRow; iRow < finishRow; ++iRow)
{
- updateResult(res, val, iRow, i, nClasses, BooleanConstant());
+ prob_pred[2 * iRow + 1] = expVal[iRow] / (algorithmFPType(1.) + expVal[iRow]);
+ prob_pred[2 * iRow] = algorithmFPType(1.) - prob_pred[2 * iRow + 1];
}
- }
- for (; iRow < nRows; ++iRow)
+ });
+ }
+ else if (this->_res && (!_prob))
+ {
+ WriteOnlyRows resBD(this->_res, 0, nRows);
+ DAAL_CHECK_BLOCK_STATUS(resBD);
+ const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) };
+ algorithmFPType * res = resBD.get();
+ s = super::runInternal(pHostApp, this->_res, margin, false, false);
+ if (!s) return s;
+
+ typedef services::internal::SignBit SignBit;
+
+ PRAGMA_IVDEP
+ for (size_t iRow = 0; iRow < nRows; ++iRow)
{
- algorithmFPType * val = updateBuffer(buff, iRow * nClasses, nClasses, BooleanConstant());
- predictByTrees(val, 0, nTrees, nClasses, x + iRow * nColumns, dispatcher);
- updateResult(res, val, iRow, 0, nClasses, BooleanConstant());
+            // the probability equals sigmoid(margin), so the predicted label can be decided from the sign of the margin
+ const algorithmFPType initial = res[iRow];
+ const int sign = SignBit::get(initial);
+ res[iRow] = label[sign];
}
}
+ return s;
+}
- template
- inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
- algorithmFPType * res)
+template
+algorithmFPType PredictBinaryClassificationTask::getMarginFromModelBias(algorithmFPType bias) const
+{
+ DAAL_ASSERT((0.0 < bias) && (bias < 1.0));
+ constexpr algorithmFPType one(1);
+ // convert bias to margin
+ return -one * daal::internal::MathInst::sLog(one / bias - one);
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// PredictMulticlassTask - implementation
+//////////////////////////////////////////////////////////////////////////////////////////
+
+template
+void PredictMulticlassTask::updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses,
+ BooleanConstant dispatcher)
+{
+ res[iRow + i] = getMaxClass(val + i * nClasses, nClasses);
+}
+
+template
+void PredictMulticlassTask::updateResult(algorithmFPType * res, algorithmFPType * val, size_t iRow, size_t i, size_t nClasses,
+ BooleanConstant dispatcher)
+{}
+
+template
+algorithmFPType * PredictMulticlassTask::updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size,
+ BooleanConstant dispatcher)
+{
+ services::internal::service_memset_seq(buff, algorithmFPType(0), buf_size);
+ return buff;
+}
+
+template
+algorithmFPType * PredictMulticlassTask::updateBuffer(algorithmFPType * buff, size_t buf_shift, size_t buf_size,
+ BooleanConstant dispatcher)
+{
+ return buff + buf_shift;
+}
+
+template
+size_t PredictMulticlassTask::getNumberOfNodes(size_t nTrees)
+{
+ size_t nNodesTotal = 0;
+ for (size_t iTree = 0; iTree < nTrees; ++iTree)
{
- if (this->_featHelper.hasUnorderedFeatures())
- {
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
- }
- else
- {
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
- }
+ nNodesTotal += this->_aTree[iTree]->getNumberOfNodes();
}
+ return nNodesTotal;
+}
- template
- inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
- algorithmFPType * res)
+template
+bool PredictMulticlassTask::checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns)
+{
+ size_t nLvlTotal = 0;
+ for (size_t iTree = 0; iTree < nTrees; ++iTree)
{
- const bool hasAnyMissing = checkForMissing(x, nTrees, nRows, nColumns);
- if (hasAnyMissing)
- {
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
- }
- else
+ nLvlTotal += this->_aTree[iTree]->getMaxLvl();
+ }
+ if (nLvlTotal <= nColumns)
+ {
+        // An exhaustive upfront scan would cost more than it saves here; assume missing values may be present and handle them during inference.
+ return true;
+ }
+ else
+ {
+ for (size_t idx = 0; idx < nRows * nColumns; ++idx)
{
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
+ if (checkFinitenessByComparison(x[idx])) return true;
}
}
+ return false;
+}
- // Recursivelly checking template parameter until it becomes equal to dim.vectorBlockSizeFactor or equal to DimType::minVectorBlockSizeFactor.
- template
- inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
- algorithmFPType * res, const DimType & dim, BooleanConstant keepLooking)
+template
+template
+void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x,
+ algorithmFPType * buff, algorithmFPType * res)
+{
+ dispatcher_t dispatcher;
+ size_t iRow = 0;
+ for (; iRow + vectorBlockSize <= nRows; iRow += vectorBlockSize)
{
- constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep;
- if (dim.vectorBlockSizeFactor == vectorBlockSizeFactor)
- {
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
- }
- else
+ algorithmFPType * val = updateBuffer(buff, iRow * nClasses, nClasses * vectorBlockSize, BooleanConstant());
+ predictByTreesVector(val, 0, nTrees, nClasses, x + iRow * nColumns, dispatcher);
+ for (size_t i = 0; i < vectorBlockSize; ++i)
{
- predict(
- nTrees, nClasses, nRows, nColumns, x, buff, res, dim, BooleanConstant());
+ updateResult(res, val, iRow, i, nClasses, BooleanConstant());
}
}
+ for (; iRow < nRows; ++iRow)
+ {
+ algorithmFPType * val = updateBuffer(buff, iRow * nClasses, nClasses, BooleanConstant());
+ predictByTrees(val, 0, nTrees, nClasses, x + iRow * nColumns, dispatcher);
+ updateResult(res, val, iRow, 0, nClasses, BooleanConstant());
+ }
+}
- template
- inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
- algorithmFPType * res, const DimType & dim, BooleanConstant keepLooking)
+template
+template
+void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x,
+ algorithmFPType * buff, algorithmFPType * res)
+{
+ if (this->_featHelper.hasUnorderedFeatures())
{
- constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep;
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
+ }
+ else
+ {
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
}
+}
- template
- inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
- algorithmFPType * res, const DimType & dim)
+template
+template
+void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x,
+ algorithmFPType * buff, algorithmFPType * res)
+{
+ const bool hasAnyMissing = checkForMissing(x, nTrees, nRows, nColumns);
+ if (hasAnyMissing)
{
- constexpr size_t maxVectorBlockSizeFactor = DimType::maxVectorBlockSizeFactor;
- if (maxVectorBlockSizeFactor > 1)
- {
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim,
- BooleanConstant());
- }
- else
- {
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim,
- BooleanConstant());
- }
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
+ }
+ else
+ {
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
}
+}
- template
- inline void predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * buff,
- algorithmFPType * res, const DimType & dim)
+template
+template
+void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x,
+ algorithmFPType * buff, algorithmFPType * res, const DimType & dim,
+ BooleanConstant keepLooking)
+{
+ constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep;
+ if (dim.vectorBlockSizeFactor == vectorBlockSizeFactor)
{
- if (res)
- {
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim);
- }
- else
- {
- predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim);
- }
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
+ }
+ else
+ {
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim,
+ BooleanConstant());
}
+}
+
+template
+template
+void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x,
+ algorithmFPType * buff, algorithmFPType * res, const DimType & dim,
+ BooleanConstant keepLooking)
+{
+ constexpr size_t vectorBlockSizeStep = DimType::vectorBlockSizeStep;
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res);
+}
- void softmax(algorithmFPType * Input, algorithmFPType * Output, size_t nRows, size_t nCols);
+template
+template
+void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x,
+ algorithmFPType * buff, algorithmFPType * res, const DimType & dim)
+{
+ constexpr size_t maxVectorBlockSizeFactor = DimType::maxVectorBlockSizeFactor;
+ if (maxVectorBlockSizeFactor > 1)
+ {
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim, BooleanConstant());
+ }
+ else
+ {
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim, BooleanConstant());
+ }
+}
- size_t getMaxClass(const algorithmFPType * val, size_t nClasses) const
+template
+template
+void PredictMulticlassTask::predict(size_t nTrees, size_t nClasses, size_t nRows, size_t nColumns, const algorithmFPType * x,
+ algorithmFPType * buff, algorithmFPType * res, const DimType & dim)
+{
+ if (res)
+ {
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim);
+ }
+ else
{
- return services::internal::getMaxElementIndex(val, nClasses);
+ predict(nTrees, nClasses, nRows, nColumns, x, buff, res, dim);
}
+}
-protected:
- const NumericTable * _data;
- NumericTable * _res;
- NumericTable * _prob;
- dtrees::internal::FeatureTypes _featHelper;
- TArray _aTree;
-};
+template
+size_t PredictMulticlassTask::getMaxClass(const algorithmFPType * val, size_t nClasses) const
+{
+ return services::internal::getMaxElementIndex(val, nClasses);
+}
//////////////////////////////////////////////////////////////////////////////////////////
// PredictKernel
@@ -398,7 +740,7 @@ services::Status PredictMulticlassTask::run(const gbt::cla
DimType dim(*_data, nTreesTotal, getNumberOfNodes(nTreesTotal));
- return predictByAllTrees(nTreesTotal, nClasses, dim);
+ return predictByAllTrees(nTreesTotal, nClasses, m->getPredictionBias(), dim);
}
template
@@ -433,7 +775,7 @@ void PredictMulticlassTask::predictByTreesVector(algorithm
}
template
-services::Status PredictMulticlassTask::predictByAllTrees(size_t nTreesTotal, size_t nClasses, const DimType & dim)
+services::Status PredictMulticlassTask::predictByAllTrees(size_t nTreesTotal, size_t nClasses, double bias, const DimType & dim)
{
WriteOnlyRows resBD(_res, 0, dim.nRowsTotal);
DAAL_CHECK_BLOCK_STATUS(resBD);
@@ -449,7 +791,7 @@ services::Status PredictMulticlassTask::predictByAllTrees(
DAAL_OVERFLOW_CHECK_BY_MULTIPLICATION(size_t, nRows * nClasses, sizeof(algorithmFPType));
TArray valPtr(nRows * nClasses);
algorithmFPType * valFull = valPtr.get();
- services::internal::service_memset(valFull, algorithmFPType(0), nRows * nClasses);
+ services::internal::service_memset(valFull, algorithmFPType(bias), nRows * nClasses);
daal::threader_for(dim.nDataBlocks, dim.nDataBlocks, [&](size_t iBlock) {
const size_t iStartRow = iBlock * dim.nRowsInBlock;
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h
index 1a32131f8b8..0efa3707206 100755
--- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h
+++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h
@@ -51,9 +51,10 @@ class PredictKernel : public daal::algorithms::Kernel
/**
* \brief Compute gradient boosted trees prediction results.
*
- * \param a[in] Matrix of input variables X
- * \param m[in] Gradient boosted trees model obtained on training stage
- * \param r[out] Prediction results
+ * \param a[in] Matrix of input variables X
+ * \param m[in] Gradient boosted trees model obtained on training stage
+ * \param r[out] Prediction results
+ * \param prob[out] Prediction class probabilities
* \param nClasses[in] Number of classes in gradient boosted trees algorithm parameter
* \param nIterations[in] Number of iterations to predict in gradient boosted trees algorithm parameter
*/
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp
index f3ac56aa5f5..58e3da1a52b 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp
+++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp
@@ -96,20 +96,27 @@ services::Status Input::check(const daal::algorithms::Parameter * parameter, int
size_t nClasses = 0, nIterations = 0;
- const gbt::classification::prediction::interface2::Parameter * pPrm2 =
+ const gbt::classification::prediction::interface2::Parameter * pPrm =
dynamic_cast(parameter);
- if (pPrm2)
+ if (pPrm)
{
- nClasses = pPrm2->nClasses;
- nIterations = pPrm2->nIterations;
+ nClasses = pPrm->nClasses;
+ nIterations = pPrm->nIterations;
}
else
+ {
return services::ErrorNullParameterNotSupported;
+ }
auto maxNIterations = pModel->getNumberOfTrees();
if (nClasses > 2) maxNIterations /= nClasses;
DAAL_CHECK((nClasses < 3) || (pModel->getNumberOfTrees() % nClasses == 0), services::ErrorGbtIncorrectNumberOfTrees);
DAAL_CHECK((nIterations == 0) || (nIterations <= maxNIterations), services::ErrorGbtPredictIncorrectNumberOfIterations);
+
+ const bool predictContribs = pPrm->resultsToCompute & shapContributions;
+ const bool predictInteractions = pPrm->resultsToCompute & shapInteractions;
+ DAAL_CHECK(!(predictContribs || predictInteractions), services::ErrorMethodNotImplemented);
+
return s;
}
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/gbt_model.cpp b/cpp/daal/src/algorithms/dtrees/gbt/gbt_model.cpp
index 05974244d41..391a267f62c 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/gbt_model.cpp
+++ b/cpp/daal/src/algorithms/dtrees/gbt/gbt_model.cpp
@@ -23,7 +23,6 @@
#include "services/daal_defines.h"
#include "src/algorithms/dtrees/gbt/gbt_model_impl.h"
-#include "src/algorithms/dtrees/dtrees_model_impl_common.h"
using namespace daal::data_management;
using namespace daal::services;
@@ -63,8 +62,8 @@ void ModelImpl::traverseDF(size_t iTree, algorithms::regression::TreeNodeVisitor
const GbtDecisionTree & gbtTree = *at(iTree);
- const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints();
- const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
+ const ModelFPType * splitPoints = gbtTree.getSplitPoints();
+ const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
auto onSplitNodeFunc = [&splitPoints, &splitFeatures, &visitor](size_t iRowInTable, size_t level) -> bool {
return visitor.onSplitNode(level, splitFeatures[iRowInTable], splitPoints[iRowInTable]);
@@ -83,8 +82,8 @@ void ModelImpl::traverseBF(size_t iTree, algorithms::regression::TreeNodeVisitor
const GbtDecisionTree & gbtTree = *at(iTree);
- const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints();
- const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
+ const ModelFPType * splitPoints = gbtTree.getSplitPoints();
+ const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
auto onSplitNodeFunc = [&splitFeatures, &splitPoints, &visitor](size_t iRowInTable, size_t level) -> bool {
return visitor.onSplitNode(level, splitFeatures[iRowInTable], splitPoints[iRowInTable]);
@@ -108,10 +107,10 @@ void ModelImpl::traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisito
const GbtDecisionTree & gbtTree = *at(iTree);
- const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints();
- const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
- const int * nodeSamplesCount = getNodeSampleCount(iTree);
- const double * imp = getImpVals(iTree);
+ const ModelFPType * splitPoints = gbtTree.getSplitPoints();
+ const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
+ const int * nodeSamplesCount = getNodeSampleCount(iTree);
+ const double * imp = getImpVals(iTree);
auto onSplitNodeFunc = [&splitFeatures, &splitPoints, &nodeSamplesCount, &imp, &visitor](size_t iRowInTable, size_t level) -> bool {
tree_utils::SplitNodeDescriptor descSplit;
@@ -148,10 +147,10 @@ void ModelImpl::traverseDFS(size_t iTree, tree_utils::regression::TreeNodeVisito
const GbtDecisionTree & gbtTree = *at(iTree);
- const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints();
- const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
- const int * nodeSamplesCount = getNodeSampleCount(iTree);
- const double * imp = getImpVals(iTree);
+ const ModelFPType * splitPoints = gbtTree.getSplitPoints();
+ const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
+ const int * nodeSamplesCount = getNodeSampleCount(iTree);
+ const double * imp = getImpVals(iTree);
auto onSplitNodeFunc = [&splitFeatures, &splitPoints, &nodeSamplesCount, &imp, &visitor](size_t iRowInTable, size_t level) -> bool {
tree_utils::SplitNodeDescriptor descSplit;
@@ -226,15 +225,18 @@ void ModelImpl::destroy()
super::destroy();
}
-bool ModelImpl::nodeIsDummyLeaf(size_t idx, const GbtDecisionTree & gbtTree)
+bool ModelImpl::nodeIsDummyLeaf(size_t nodeIndex, const GbtDecisionTree & gbtTree)
{
- const gbt::prediction::internal::ModelFPType * splitPoints = gbtTree.getSplitPoints();
- const gbt::prediction::internal::FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
+ const size_t childArrayIndex = nodeIndex - 1;
+ const ModelFPType * splitPoints = gbtTree.getSplitPoints();
+ const FeatureIndexType * splitFeatures = gbtTree.getFeatureIndexesForSplit();
- if (idx)
+ if (childArrayIndex)
{
- const size_t parent = getIdxOfParent(idx);
- return splitPoints[parent] == splitPoints[idx] && splitFeatures[parent] == splitFeatures[idx];
+ // check if child node has same split feature and split value as parent
+ const size_t parent = getIdxOfParent(nodeIndex);
+ const size_t parentArrayIndex = parent - 1;
+ return splitPoints[parentArrayIndex] == splitPoints[childArrayIndex] && splitFeatures[parentArrayIndex] == splitFeatures[childArrayIndex];
}
return false;
}
@@ -245,16 +247,16 @@ bool ModelImpl::nodeIsLeaf(size_t idx, const GbtDecisionTree & gbtTree, const si
{
return true;
}
- else if (nodeIsDummyLeaf(2 * idx + 1, gbtTree)) // check, that left son is dummy
+ else if (nodeIsDummyLeaf(2 * idx, gbtTree)) // check, that left son is dummy
{
return true;
}
return false;
}
-size_t ModelImpl::getIdxOfParent(const size_t sonIdx)
+size_t ModelImpl::getIdxOfParent(const size_t childIdx)
{
- return sonIdx ? (sonIdx - 1) / 2 : 0;
+ return childIdx / 2;
}
void ModelImpl::decisionTreeToGbtTree(const DecisionTreeTable & tree, GbtDecisionTree & newTree)
@@ -270,9 +272,10 @@ void ModelImpl::decisionTreeToGbtTree(const DecisionTreeTable & tree, GbtDecisio
NodeType * sons = sonsArr.data();
NodeType * parents = parentsArr.data();
- gbt::prediction::internal::ModelFPType * const spitPoints = newTree.getSplitPoints();
- gbt::prediction::internal::FeatureIndexType * const featureIndexes = newTree.getFeatureIndexesForSplit();
- int * const defaultLeft = newTree.getdefaultLeftForSplit();
+ ModelFPType * const splitPoints = newTree.getSplitPoints();
+ FeatureIndexType * const featureIndexes = newTree.getFeatureIndexesForSplit();
+ ModelFPType * const nodeCoverValues = newTree.getNodeCoverValues();
+ int * const defaultLeft = newTree.getDefaultLeftForSplit();
for (size_t i = 0; i < nSourceNodes; ++i)
{
@@ -293,21 +296,22 @@ void ModelImpl::decisionTreeToGbtTree(const DecisionTreeTable & tree, GbtDecisio
if (p->isSplit())
{
- sons[nSons++] = arr + p->leftIndexOrClass;
- sons[nSons++] = arr + p->leftIndexOrClass + 1;
- featureIndexes[idxInTable] = p->featureIndex;
- defaultLeft[idxInTable] = p->defaultLeft;
+ sons[nSons++] = arr + p->leftIndexOrClass;
+ sons[nSons++] = arr + p->leftIndexOrClass + 1;
+ featureIndexes[idxInTable] = p->featureIndex;
+ nodeCoverValues[idxInTable] = p->cover;
+ defaultLeft[idxInTable] = p->defaultLeft;
DAAL_ASSERT(featureIndexes[idxInTable] >= 0);
- spitPoints[idxInTable] = p->featureValueOrResponse;
+ splitPoints[idxInTable] = p->featureValueOrResponse;
}
else
{
- sons[nSons++] = p;
- sons[nSons++] = p;
- featureIndexes[idxInTable] = 0;
- defaultLeft[idxInTable] = 0;
- DAAL_ASSERT(featureIndexes[idxInTable] >= 0);
- spitPoints[idxInTable] = p->featureValueOrResponse;
+ sons[nSons++] = p;
+ sons[nSons++] = p;
+ featureIndexes[idxInTable] = 0;
+ nodeCoverValues[idxInTable] = p->cover;
+ defaultLeft[idxInTable] = 0;
+ splitPoints[idxInTable] = p->featureValueOrResponse;
}
idxInTable++;
@@ -351,7 +355,8 @@ void ModelImpl::getMaxLvl(const dtrees::internal::DecisionTreeNode * const arr,
const GbtDecisionTree * ModelImpl::at(const size_t idx) const
{
- return (const GbtDecisionTree *)(*super::_serializationData)[idx].get();
+ auto * const rawTree = (*super::_serializationData)[idx].get();
+ return static_cast(rawTree);
}
} // namespace internal
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/gbt_model_impl.h b/cpp/daal/src/algorithms/dtrees/gbt/gbt_model_impl.h
index f56b4fc3126..5ae2a8504a0 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/gbt_model_impl.h
+++ b/cpp/daal/src/algorithms/dtrees/gbt/gbt_model_impl.h
@@ -26,7 +26,6 @@
#include "src/algorithms/dtrees/dtrees_model_impl.h"
#include "algorithms/regression/tree_traverse.h"
-#include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i"
#include "algorithms/tree_utils/tree_utils_regression.h"
#include "src/algorithms/dtrees/dtrees_model_impl_common.h"
#include "src/services/service_arrays.h"
@@ -41,12 +40,13 @@ namespace gbt
{
namespace internal
{
+typedef uint32_t FeatureIndexType;
+typedef float ModelFPType;
+typedef services::Collection NodeIdxArray;
+
static inline size_t getNumberOfNodesByLvls(const size_t nLvls)
{
- size_t nNodes = 2; // nNodes = pow(2, nLvl+1) - 1
- for (size_t i = 0; i < nLvls; ++i) nNodes *= 2;
- nNodes--;
- return nNodes;
+ return (1 << (nLvls + 1)) - 1;
}
template
@@ -61,43 +61,61 @@ class GbtDecisionTree : public SerializationIface
{
public:
DECLARE_SERIALIZABLE();
- using SplitPointType = HomogenNumericTable;
- using FeatureIndexesForSplitType = HomogenNumericTable;
+ using SplitPointType = HomogenNumericTable;
+ using NodeCoverType = HomogenNumericTable;
+ using FeatureIndexesForSplitType = HomogenNumericTable;
using defaultLeftForSplitType = HomogenNumericTable;
- GbtDecisionTree(const size_t nNodes, const size_t maxLvl, const size_t sourceNumOfNodes)
+ GbtDecisionTree(const size_t nNodes, const size_t maxLvl)
: _nNodes(nNodes),
_maxLvl(maxLvl),
- _sourceNumOfNodes(sourceNumOfNodes),
_splitPoints(SplitPointType::create(1, nNodes, NumericTableIface::doAllocate)),
_featureIndexes(FeatureIndexesForSplitType::create(1, nNodes, NumericTableIface::doAllocate)),
- _defaultLeft(defaultLeftForSplitType::create(1, nNodes, NumericTableIface::doAllocate))
+ _nodeCoverValues(NodeCoverType::create(1, nNodes, NumericTableIface::doAllocate)),
+ _defaultLeft(defaultLeftForSplitType::create(1, nNodes, NumericTableIface::doAllocate)),
+ nNodeSplitFeature(),
+ CoverFeature(),
+ GainFeature()
{}
- // for serailization only
- GbtDecisionTree() : _nNodes(0), _maxLvl(0), _sourceNumOfNodes(0) {}
+ // for serialization only
+ GbtDecisionTree() : _nNodes(0), _maxLvl(0) {}
+
+ ModelFPType * getSplitPoints() { return _splitPoints->getArray(); }
+
+ FeatureIndexType * getFeatureIndexesForSplit() { return _featureIndexes->getArray(); }
- gbt::prediction::internal::ModelFPType * getSplitPoints() { return _splitPoints->getArray(); }
+ int * getDefaultLeftForSplit() { return _defaultLeft->getArray(); }
- gbt::prediction::internal::FeatureIndexType * getFeatureIndexesForSplit() { return _featureIndexes->getArray(); }
+ const ModelFPType * getSplitPoints() const { return _splitPoints->getArray(); }
- int * getdefaultLeftForSplit() { return _defaultLeft->getArray(); }
+ const FeatureIndexType * getFeatureIndexesForSplit() const { return _featureIndexes->getArray(); }
- const gbt::prediction::internal::ModelFPType * getSplitPoints() const { return _splitPoints->getArray(); }
+ ModelFPType * getNodeCoverValues() { return _nodeCoverValues->getArray(); }
- const gbt::prediction::internal::FeatureIndexType * getFeatureIndexesForSplit() const { return _featureIndexes->getArray(); }
+ const ModelFPType * getNodeCoverValues() const { return _nodeCoverValues->getArray(); }
- const int * getdefaultLeftForSplit() const { return _defaultLeft->getArray(); }
+ const int * getDefaultLeftForSplit() const { return _defaultLeft->getArray(); }
size_t getNumberOfNodes() const { return _nNodes; }
size_t * getArrayNumSplitFeature() { return nNodeSplitFeature.data(); }
+ const size_t * getArrayNumSplitFeature() const { return nNodeSplitFeature.data(); }
+
size_t * getArrayCoverFeature() { return CoverFeature.data(); }
+ const size_t * getArrayCoverFeature() const { return CoverFeature.data(); }
+
+ services::Collection getCoverFeature() { return CoverFeature; }
+
+ const services::Collection & getCoverFeature() const { return CoverFeature; }
+
double * getArrayGainFeature() { return GainFeature.data(); }
- gbt::prediction::internal::FeatureIndexType getMaxLvl() const { return _maxLvl; }
+ const double * getArrayGainFeature() const { return GainFeature.data(); }
+
+ FeatureIndexType getMaxLvl() const { return _maxLvl; }
// recursive build of tree (breadth-first)
template
@@ -113,8 +131,8 @@ class GbtDecisionTree : public SerializationIface
int result = 0;
- gbt::prediction::internal::ModelFPType * const spitPoints = tree->getSplitPoints();
- gbt::prediction::internal::FeatureIndexType * const featureIndexes = tree->getFeatureIndexesForSplit();
+ ModelFPType * const splitPoints = tree->getSplitPoints();
+ FeatureIndexType * const featureIndexes = tree->getFeatureIndexesForSplit();
for (size_t i = 0; i < nNodes; ++i)
{
@@ -163,7 +181,7 @@ class GbtDecisionTree : public SerializationIface
DAAL_ASSERT(featureIndexes[idxInTable] >= 0);
nNodeSamplesVals[idxInTable] = (int)p->count;
impVals[idxInTable] = p->impurity;
- spitPoints[idxInTable] = p->featureValue;
+ splitPoints[idxInTable] = p->featureValue;
idxInTable++;
}
@@ -183,10 +201,10 @@ class GbtDecisionTree : public SerializationIface
{
arch->set(_nNodes);
arch->set(_maxLvl);
- arch->set(_sourceNumOfNodes);
arch->setSharedPtrObj(_splitPoints);
arch->setSharedPtrObj(_featureIndexes);
+ arch->setSharedPtrObj(_nodeCoverValues);
arch->setSharedPtrObj(_defaultLeft);
return services::Status();
@@ -194,10 +212,10 @@ class GbtDecisionTree : public SerializationIface
protected:
size_t _nNodes;
- gbt::prediction::internal::FeatureIndexType _maxLvl;
- size_t _sourceNumOfNodes;
+ FeatureIndexType _maxLvl;
services::SharedPtr _splitPoints;
services::SharedPtr _featureIndexes;
+ services::SharedPtr _nodeCoverValues;
services::SharedPtr _defaultLeft;
services::Collection nNodeSplitFeature;
services::Collection CoverFeature;
@@ -222,7 +240,7 @@ class GbtTreeImpl : public dtrees::internal::TreeImpl
getMaxLvl(*super::top(), nLvls, static_cast(-1));
const size_t nNodes = getNumberOfNodesByLvls(nLvls);
- *pTbl = new GbtDecisionTree(nNodes, nLvls, super::top()->numChildren() + 1);
+ *pTbl = new GbtDecisionTree(nNodes, nLvls);
*pTblImp = new HomogenNumericTable(1, nNodes, NumericTable::doAllocate);
*pTblSmplCnt = new HomogenNumericTable(1, nNodes, NumericTable::doAllocate);
@@ -264,6 +282,29 @@ using TreeImpRegression = GbtTreeImpl > >
using TreeImpClassification = GbtTreeImpl, Allocator>;
+struct DecisionTreeNode
+{
+ size_t dimension;
+ size_t leftIndexOrClass;
+ double cutPointOrDependantVariable;
+};
+
+class DecisionTreeTable : public data_management::AOSNumericTable
+{
+public:
+ DecisionTreeTable(size_t rowCount, services::Status & st) : data_management::AOSNumericTable(sizeof(DecisionTreeNode), 3, rowCount, st)
+ {
+ setFeature(0, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, dimension));
+ setFeature(1, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, leftIndexOrClass));
+ setFeature(2, DAAL_STRUCT_MEMBER_OFFSET(DecisionTreeNode, cutPointOrDependantVariable));
+ st |= allocateDataMemory();
+ }
+ DecisionTreeTable(services::Status & st) : DecisionTreeTable(0, st) {}
+};
+
+typedef services::SharedPtr DecisionTreeTablePtr;
+typedef services::SharedPtr DecisionTreeTableConstPtr;
+
class ModelImpl : protected dtrees::internal::ModelImpl
{
public:
@@ -292,9 +333,33 @@ class ModelImpl : protected dtrees::internal::ModelImpl
static services::Status treeToTable(TreeType & t, gbt::internal::GbtDecisionTree ** pTbl, HomogenNumericTable ** pTblImp,
HomogenNumericTable ** pTblSmplCnt, size_t nFeature);
-protected:
+ /**
+ * \brief Returns true if a node is a dummy leaf. A dummy leaf contains the same split feature & value as the parent
+ *
+     * \param idx 1-based index to the node array
+     * \param gbtTree tree containing nodes
+     *
+ * \return true if the node is a dummy leaf, false otherwise
+ */
static bool nodeIsDummyLeaf(size_t idx, const GbtDecisionTree & gbtTree);
+
+ /**
+     * \brief Return true if a node is a leaf
+ *
+ * \param idx 1-based index to the node array
+ * \param gbtTree tree containing nodes
+ * \param lvl current level in the tree
+ * \return true if the node is a leaf, false otherwise
+ */
static bool nodeIsLeaf(size_t idx, const GbtDecisionTree & gbtTree, const size_t lvl);
+
+protected:
+ /**
+ * \brief Return the node index of the provided node's parent
+ *
+     * \param sonIdx 1-based node index of the child
+ * \return size_t 1-based node index of the parent
+ */
static size_t getIdxOfParent(const size_t sonIdx);
static void getMaxLvl(const dtrees::internal::DecisionTreeNode * const arr, const size_t idx, size_t & maxLvl, size_t curLvl = 0);
@@ -306,21 +371,22 @@ class ModelImpl : protected dtrees::internal::ModelImpl
getMaxLvl(arr, 0, nLvls, static_cast(-1));
const size_t nNodes = getNumberOfNodesByLvls(nLvls);
- return new GbtDecisionTree(nNodes, nLvls, tree.getNumberOfRows());
+ return new GbtDecisionTree(nNodes, nLvls);
}
template
static void traverseGbtDF(size_t level, size_t iRowInTable, const GbtDecisionTree & gbtTree, OnSplitFunctor & visitSplit,
OnLeafFunctor & visitLeaf)
{
- if (!nodeIsLeaf(iRowInTable, gbtTree, level))
+ const size_t oneBasedNodeIndex = iRowInTable + 1;
+ if (!nodeIsLeaf(oneBasedNodeIndex, gbtTree, level))
{
if (!visitSplit(iRowInTable, level)) return; //do not continue traversing
traverseGbtDF(level + 1, iRowInTable * 2 + 1, gbtTree, visitSplit, visitLeaf);
traverseGbtDF(level + 1, iRowInTable * 2 + 2, gbtTree, visitSplit, visitLeaf);
}
- else if (!nodeIsDummyLeaf(iRowInTable, gbtTree))
+ else if (!nodeIsDummyLeaf(oneBasedNodeIndex, gbtTree))
{
if (!visitLeaf(iRowInTable, level)) return; //do not continue traversing
}
@@ -334,14 +400,15 @@ class ModelImpl : protected dtrees::internal::ModelImpl
{
for (size_t j = 0; j < (level ? 2 : 1); ++j)
{
- size_t iRowInTable = aCur[i] + j;
- if (!nodeIsLeaf(iRowInTable, gbtTree, level))
+ const size_t iRowInTable = aCur[i] + j;
+ const size_t oneBasedNodeIndex = iRowInTable + 1;
+ if (!nodeIsLeaf(oneBasedNodeIndex, gbtTree, level))
{
if (!visitSplit(iRowInTable, level)) return; //do not continue traversing
aNext.push_back(iRowInTable * 2 + 1);
}
- else if (!nodeIsDummyLeaf(iRowInTable, gbtTree))
+ else if (!nodeIsDummyLeaf(oneBasedNodeIndex, gbtTree))
{
if (!visitLeaf(iRowInTable, level)) return; //do not continue traversing
}
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i
index 18976987b37..fd76f8da721 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i
+++ b/cpp/daal/src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i
@@ -25,14 +25,15 @@
#ifndef __GBT_PREDICT_DENSE_DEFAULT_IMPL_I__
#define __GBT_PREDICT_DENSE_DEFAULT_IMPL_I__
+#include "data_management/data/internal/finiteness_checker.h"
+#include "src/algorithms/dtrees/dtrees_feature_type_helper.h"
#include "src/algorithms/dtrees/dtrees_model_impl.h"
-#include "src/algorithms/dtrees/dtrees_train_data_helper.i"
#include "src/algorithms/dtrees/dtrees_predict_dense_default_impl.i"
-#include "src/algorithms/dtrees/dtrees_feature_type_helper.h"
+#include "src/algorithms/dtrees/dtrees_train_data_helper.i"
#include "src/algorithms/dtrees/gbt/gbt_internal.h"
-#include "src/services/service_environment.h"
+#include "src/algorithms/dtrees/gbt/gbt_model_impl.h"
#include "src/services/service_defines.h"
-#include "data_management/data/internal/finiteness_checker.h"
+#include "src/services/service_environment.h"
namespace daal
{
@@ -44,8 +45,8 @@ namespace prediction
{
namespace internal
{
-typedef float ModelFPType;
-typedef uint32_t FeatureIndexType;
+using gbt::internal::ModelFPType;
+using gbt::internal::FeatureIndexType;
template
struct PredictDispatcher
@@ -101,7 +102,7 @@ inline void predictForTreeVector(const DecisionTreeType & t, const FeatureTypes
{
const ModelFPType * const values = t.getSplitPoints() - 1;
const FeatureIndexType * const fIndexes = t.getFeatureIndexesForSplit() - 1;
- const int * const defaultLeft = t.getdefaultLeftForSplit() - 1;
+ const int * const defaultLeft = t.getDefaultLeftForSplit() - 1;
const FeatureIndexType nFeat = featTypes.getNumberOfFeatures();
FeatureIndexType i[vectorBlockSize];
@@ -135,7 +136,7 @@ inline algorithmFPType predictForTree(const DecisionTreeType & t, const FeatureT
{
const ModelFPType * const values = (const ModelFPType *)t.getSplitPoints() - 1;
const FeatureIndexType * const fIndexes = t.getFeatureIndexesForSplit() - 1;
- const int * const defaultLeft = t.getdefaultLeftForSplit() - 1;
+ const int * const defaultLeft = t.getDefaultLeftForSplit() - 1;
const FeatureIndexType maxLvl = t.getMaxLvl();
@@ -167,7 +168,7 @@ struct TileDimensions
static constexpr size_t minVectorBlockSizeFactor = 2;
static constexpr size_t vectorBlockSizeStep = 16;
// optimalBlockSizeFactor is selected from benchmarking
- static constexpr size_t optimalBlockSizeFactor = 3;
+ static constexpr size_t optimalBlockSizeFactor = 5;
TileDimensions(const NumericTable & data, size_t nTrees, size_t nNodes)
: nTreesTotal(nTrees), nRowsTotal(data.getNumberOfRows()), nCols(data.getNumberOfColumns())
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/BUILD b/cpp/daal/src/algorithms/dtrees/gbt/regression/BUILD
index 7780e5acaa6..8e85e233059 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/regression/BUILD
+++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/BUILD
@@ -1,10 +1,12 @@
package(default_visibility = ["//visibility:public"])
load("@onedal//dev/bazel:daal.bzl", "daal_module")
+load("@onedal//dev/bazel:dal.bzl", "dal_test_suite")
daal_module(
name = "kernel",
- auto = True,
- opencl = True,
+ auto = False,
+ hdrs = glob(["**/*.h", "**/*.i", "**/*.cl"]),
+ srcs = glob(["*.cpp"]),
deps = [
"@onedal//cpp/daal:core",
"@onedal//cpp/daal:sycl",
@@ -12,3 +14,16 @@ daal_module(
"@onedal//cpp/daal/src/algorithms/dtrees/gbt:kernel",
],
)
+
+dal_test_suite(
+ name = "tests",
+ compile_as = [ "c++" ],
+ private = True,
+ srcs = glob([
+ "test/*unit.cpp",
+ ]),
+ dal_deps = [
+ "@onedal//cpp/daal/src/algorithms/regression:kernel",
+ "@onedal//cpp/daal/src/algorithms/dtrees/gbt:kernel",
+ ]
+)
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model.cpp b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model.cpp
index 2c2818b3104..0012688dc1e 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model.cpp
+++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model.cpp
@@ -85,9 +85,20 @@ void ModelImpl::traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisito
ImplType::traverseBFS(iTree, visitor);
}
+void ModelImpl::setPredictionBias(double value)
+{
+ _predictionBias = value;
+}
+
+double ModelImpl::getPredictionBias() const
+{
+ return _predictionBias;
+}
+
services::Status ModelImpl::serializeImpl(data_management::InputDataArchive * arch)
{
auto s = algorithms::regression::Model::serialImpl(arch);
+ arch->set(this->_predictionBias);
s.add(algorithms::regression::internal::ModelInternal::serialImpl(arch));
return s.add(ImplType::serialImpl(arch));
}
@@ -95,6 +106,7 @@ services::Status ModelImpl::serializeImpl(data_management::InputDataArchive * ar
services::Status ModelImpl::deserializeImpl(const data_management::OutputDataArchive * arch)
{
auto s = algorithms::regression::Model::serialImpl(arch);
+ arch->set(this->_predictionBias);
s.add(algorithms::regression::internal::ModelInternal::serialImpl(arch));
return s.add(ImplType::serialImpl(arch));
}
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_builder.cpp b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_builder.cpp
index aefc88d2ab0..7c1a1d10bf2 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_builder.cpp
+++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_builder.cpp
@@ -71,21 +71,21 @@ services::Status ModelBuilder::createTreeInternal(size_t nNodes, TreeId & resId)
return daal::algorithms::dtrees::internal::createTreeInternal(modelImplRef._serializationData, nNodes, resId);
}
-services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res)
+services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res)
{
gbt::regression::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef(_model);
return daal::algorithms::dtrees::internal::addLeafNodeInternal(modelImplRef._serializationData, treeId, parentId, position, response,
- res);
+ cover, res);
}
services::Status ModelBuilder::addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue,
- NodeId & res, int defaultLeft)
+ int defaultLeft, double cover, NodeId & res)
{
gbt::regression::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef(_model);
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
- featureValue, res, defaultLeft);
+ featureValue, defaultLeft, cover, res);
}
} // namespace interface1
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h
index a711f0d747c..13409fb1790 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h
+++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h
@@ -61,10 +61,17 @@ class ModelImpl : public daal::algorithms::gbt::regression::Model,
virtual void traverseDFS(size_t iTree, tree_utils::regression::TreeNodeVisitor & visitor) const DAAL_C11_OVERRIDE;
virtual void traverseBFS(size_t iTree, tree_utils::regression::TreeNodeVisitor & visitor) const DAAL_C11_OVERRIDE;
+ virtual void setPredictionBias(double value) DAAL_C11_OVERRIDE;
+ virtual double getPredictionBias() const DAAL_C11_OVERRIDE;
+
virtual services::Status serializeImpl(data_management::InputDataArchive * arch) DAAL_C11_OVERRIDE;
virtual services::Status deserializeImpl(const data_management::OutputDataArchive * arch) DAAL_C11_OVERRIDE;
virtual size_t getNumberOfTrees() const DAAL_C11_OVERRIDE;
+
+private:
+ /* Global bias applied to all predictions */
+ double _predictionBias;
};
} // namespace internal
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_container.h b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_container.h
index 8e3a231e1eb..d2e4bd4bdf7 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_container.h
+++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_container.h
@@ -61,8 +61,10 @@ services::Status BatchContainer::compute()
const gbt::regression::prediction::Parameter * par = static_cast(_par);
daal::services::Environment::env & env = *_env;
+ const bool predShapContributions = par->resultsToCompute & shapContributions;
+ const bool predShapInteractions = par->resultsToCompute & shapInteractions;
__DAAL_CALL_KERNEL(env, internal::PredictKernel, __DAAL_KERNEL_ARGUMENTS(algorithmFPType, method), compute,
- daal::services::internal::hostApp(*input), a, m, r, par->nIterations);
+ daal::services::internal::hostApp(*input), a, m, r, par->nIterations, predShapContributions, predShapInteractions);
}
} // namespace prediction
diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i
index cd2410e1692..7b506fe36b5 100644
--- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i
+++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i
@@ -26,15 +26,17 @@
#include "algorithms/algorithm.h"
#include "data_management/data/numeric_table.h"
-#include "src/algorithms/dtrees/gbt/regression/gbt_regression_predict_kernel.h"
-#include "src/threading/threading.h"
#include "services/daal_defines.h"
+#include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i"
#include "src/algorithms/dtrees/gbt/regression/gbt_regression_model_impl.h"
-#include "src/data_management/service_numeric_table.h"
+#include "src/algorithms/dtrees/gbt/regression/gbt_regression_predict_kernel.h"
+#include "src/algorithms/dtrees/gbt/treeshap.h"
+#include "src/algorithms/dtrees/regression/dtrees_regression_predict_dense_default_impl.i"
#include "src/algorithms/service_error_handling.h"
+#include "src/data_management/service_numeric_table.h"
#include "src/externals/service_memory.h"
-#include "src/algorithms/dtrees/regression/dtrees_regression_predict_dense_default_impl.i"
-#include "src/algorithms/dtrees/gbt/gbt_predict_dense_default_impl.i"
+#include "src/threading/threading.h"
+#include <cfloat> // DBL_EPSILON
using namespace daal::internal;
using namespace daal::services::internal;
@@ -51,6 +53,7 @@ namespace prediction
{
namespace internal
{
+using gbt::internal::FeatureIndexType;
//////////////////////////////////////////////////////////////////////////////////////////
// PredictRegressionTask
@@ -62,13 +65,16 @@ public:
typedef gbt::internal::GbtDecisionTree TreeType;
typedef gbt::prediction::internal::TileDimensions DimType;
PredictRegressionTask(const NumericTable * x, NumericTable * y) : _data(x), _res(y) {}
- services::Status run(const gbt::regression::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp);
+
+ services::Status run(const gbt::regression::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp,
+ bool predShapContributions, bool predShapInteractions);
protected:
template
using dispatcher_t = gbt::prediction::internal::PredictDispatcher;
- services::Status runInternal(services::HostAppIface * pHostApp, NumericTable * result);
+ services::Status runInternal(services::HostAppIface * pHostApp, NumericTable * result, double predictionBias, bool predShapContributions,
+ bool predShapInteractions);
template
algorithmFPType predictByTrees(size_t iFirstTree, size_t nTrees, const algorithmFPType * x,
const dispatcher_t & dispatcher);
@@ -78,29 +84,29 @@ protected:
inline size_t getNumberOfNodes(size_t nTrees)
{
- size_t nNodesTotal = 0;
- for (size_t iTree = 0; iTree < nTrees; ++iTree)
+ size_t nNodesTotal = 0ul;
+ for (size_t iTree = 0ul; iTree < nTrees; ++iTree)
{
- nNodesTotal += this->_aTree[iTree]->getNumberOfNodes();
+ nNodesTotal += _aTree[iTree]->getNumberOfNodes();
}
return nNodesTotal;
}
- inline bool checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns)
+ inline bool checkForMissing(const algorithmFPType * x, size_t nTrees, size_t nRows, size_t nColumns) const
{
- size_t nLvlTotal = 0;
- for (size_t iTree = 0; iTree < nTrees; ++iTree)
+ size_t nLvlTotal = 0ul;
+ for (size_t iTree = 0ul; iTree < nTrees; ++iTree)
{
- nLvlTotal += this->_aTree[iTree]->getMaxLvl();
+ nLvlTotal += _aTree[iTree]->getMaxLvl();
}
if (nLvlTotal <= nColumns)
{
- // Checking is compicated. Better to do it during inferense.
+ // Checking is complicated. Better to do it during inference
return true;
}
else
{
- for (size_t idx = 0; idx < nRows * nColumns; ++idx)
+ for (size_t idx = 0ul; idx < nRows * nColumns; ++idx)
{
if (checkFinitenessByComparison(x[idx])) return true;
}
@@ -126,7 +132,7 @@ protected:
template
inline void predict(size_t iTree, size_t nTrees, size_t nRows, size_t nColumns, const algorithmFPType * x, algorithmFPType * res)
{
- if (this->_featHelper.hasUnorderedFeatures())
+ if (_featHelper.hasUnorderedFeatures())
{
predict(iTree, nTrees, nRows, nColumns, x, res);
}
@@ -150,17 +156,86 @@ protected:
}
}
+ template
+ inline services::Status predictContributions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, algorithmFPType * res,
+ int condition, FeatureIndexType conditionFeature, const DimType & dim);
+
+ template
+ inline services::Status predictContributions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, algorithmFPType * res,
+ int condition, FeatureIndexType conditionFeature, const DimType & dim)
+ {
+ if (_featHelper.hasUnorderedFeatures())
+ {
+ return predictContributions(iTree, nTrees, nRowsData, x, res, condition, conditionFeature, dim);
+ }
+ else
+ {
+ return predictContributions(iTree, nTrees, nRowsData, x, res, condition, conditionFeature, dim);
+ }
+ }
+
+ // TODO: Add vectorBlockSize templates, similar to predict
+ // template
+ inline services::Status predictContributions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x, algorithmFPType * res,
+ int condition, FeatureIndexType conditionFeature, const DimType & dim)
+ {
+ const bool hasAnyMissing = checkForMissing(x, nTrees, nRowsData, dim.nCols);
+ if (hasAnyMissing)
+ {
+ return predictContributions(iTree, nTrees, nRowsData, x, res, condition, conditionFeature, dim);
+ }
+ else
+ {
+ return predictContributions(iTree, nTrees, nRowsData, x, res, condition, conditionFeature, dim);
+ }
+ }
+
+ template
+ inline services::Status predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x,
+ const algorithmFPType * nominal, algorithmFPType * res, const DimType & dim);
+
+ template