Skip to content

Commit

Permalink
Consistently add cover to DAAL APIs, add output parameters to end of…
Browse files Browse the repository at this point in the history
… function arguments
  • Loading branch information
ahuber21 committed Sep 27, 2023
1 parent 7ba11f4 commit 88b770b
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 528 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ class DAAL_EXPORT ModelBuilder
* \param[in] classLabel Class label to be predicted
* \param[in] cover Cover of the node
* \return Node identifier
*/
NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel)
NodeId addLeafNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t classLabel, const double cover)
{
NodeId resId;
const double cover = 0.0; // TODO: Add cover
_status |= addLeafNodeInternal(treeId, parentId, position, classLabel, cover, resId);
services::throwIfPossible(_status);
return resId;
Expand All @@ -126,10 +125,9 @@ class DAAL_EXPORT ModelBuilder
* \param[in] proba Array with probability values for each class
* \param[in] cover Cover of the node
* \return Node identifier
*/
NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba)
NodeId addLeafNodeByProba(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba, const double cover)
{
NodeId resId;
const double cover = 0.0; // TODO: Add cover
_status |= addLeafNodeByProbaInternal(treeId, parentId, position, proba, cover, resId);
services::throwIfPossible(_status);
return resId;
Expand All @@ -144,10 +142,10 @@ class DAAL_EXPORT ModelBuilder
* \param[in] featureValue Feature value for splitting
* \param[in] cover Cover of the node
* \return Node identifier
*/
NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue)
NodeId addSplitNode(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex, const double featureValue,
const double cover)
{
NodeId resId;
const double cover = 0.0; // TODO: Add cover
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, cover, resId);
services::throwIfPossible(_status);
return resId;
Expand Down Expand Up @@ -192,7 +190,7 @@ class DAAL_EXPORT ModelBuilder
services::Status addLeafNodeByProbaInternal(const TreeId treeId, const NodeId parentId, const size_t position, const double * const proba,
const double cover, NodeId & res);
// NOTE(review): rendered diff without +/- markers. Of the two continuation lines below, the first is the
// removed parameter list and the second is its replacement, which inserts `defaultLeft` ahead of the `res`
// out-parameter. The inline addSplitNode wrapper shown earlier in this header still calls
// addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, cover, resId) with no
// defaultLeft argument; confirm that call site is updated to match this declaration.
services::Status addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
const double featureValue, const double cover, NodeId & res);
const double featureValue, const double cover, const int defaultLeft, NodeId & res);

private:
size_t _nClasses;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,13 @@ class DAAL_EXPORT ModelBuilder
* \param[in] parentId Parent node to which new node is added (use noParent for root node)
* \param[in] position Position in parent (e.g. 0 for left and 1 for right child in a binary tree)
* \param[in] response Response value for leaf node to be predicted
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response)
NodeId addLeafNode(TreeId treeId, NodeId parentId, size_t position, double response, double cover)
{
// NOTE(review): rendered diff without +/- markers. Each pair of near-identical lines here is the removed
// line followed by its replacement. The replacement signature takes `cover` from the caller and forwards
// it to addLeafNodeInternal just before the resId out-parameter.
NodeId resId;
_status |= addLeafNodeInternal(treeId, parentId, position, response, resId);
_status |= addLeafNodeInternal(treeId, parentId, position, response, cover, resId);
// The result is OR-ed into _status and handed to throwIfPossible before the node id is returned.
services::throwIfPossible(_status);
return resId;
}
Expand All @@ -127,12 +128,13 @@ class DAAL_EXPORT ModelBuilder
* \param[in] featureIndex Feature index for splitting
* \param[in] featureValue Feature value for splitting
* \param[in] defaultLeft Behaviour in case of missing values
* \param[in] cover Cover (Hessian sum) of the node
* \return Node identifier
*/
NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft = 0)
NodeId addSplitNode(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
{
// NOTE(review): rendered diff without +/- markers. Each pair of near-identical lines here is the removed
// line followed by its replacement. The removed signature defaulted defaultLeft to 0; the replacement has
// no default and adds `cover` after it, so callers that omitted defaultLeft must now pass both explicitly.
NodeId resId;
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, resId, defaultLeft);
// Replacement call: inputs first (defaultLeft, cover), resId out-parameter last.
_status |= addSplitNodeInternal(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover, resId);
services::throwIfPossible(_status);
return resId;
}
Expand All @@ -159,9 +161,9 @@ class DAAL_EXPORT ModelBuilder
services::Status _status;
services::Status initialize(size_t nFeatures, size_t nIterations, size_t nClasses);
services::Status createTreeInternal(size_t nNodes, size_t classLabel, TreeId & resId);
// NOTE(review): rendered diff without +/- markers. The first two declarations are the removed versions
// (addLeafNodeInternal without cover; addSplitNodeInternal with res before defaultLeft). The last two are
// their replacements, with `cover` added and the `res` out-parameter moved to the end.
// The replacement addSplitNodeInternal takes defaultLeft before cover, whereas the regression builder and the
// dtrees helper shown elsewhere in this diff take cover before defaultLeft; confirm the difference is intended.
services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res);
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, NodeId & res,
int defaultLeft);
services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, const double cover, NodeId & res);
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft,
const double cover, NodeId & res);
services::Status convertModelInternal();
size_t _nClasses;
size_t _nIterations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class DAAL_EXPORT ModelBuilder
services::Status createTreeInternal(size_t nNodes, TreeId & resId);
services::Status addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res);
// NOTE(review): rendered diff without +/- markers. Of the two continuation lines below, the first is the
// removed parameter tail (res, defaultLeft) and the second is its replacement (defaultLeft, res), which
// moves the `res` out-parameter to the end of the argument list.
services::Status addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue, double cover,
NodeId & res, int defaultLeft);
int defaultLeft, NodeId & res);
services::Status convertModelInternal();
};
/** @} */
Expand Down
2 changes: 1 addition & 1 deletion cpp/daal/src/algorithms/dtrees/dtrees_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ void setProbabilities(const size_t treeId, const size_t nodeId, const size_t res
}

services::Status addSplitNodeInternal(data_management::DataCollectionPtr & serializationData, size_t treeId, size_t parentId, size_t position,
size_t featureIndex, double featureValue, double cover, size_t & res, int defaultLeft)
size_t featureIndex, double featureValue, double cover, int defaultLeft, size_t & res)
{
const size_t noParent = static_cast<size_t>(-1);
services::Status s;
Expand Down
2 changes: 1 addition & 1 deletion cpp/daal/src/algorithms/dtrees/dtrees_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ void setNode(DecisionTreeNode & node, int featureIndex, size_t classLabel, doubl
void setNode(DecisionTreeNode & node, int featureIndex, double response, double cover);

// NOTE(review): rendered diff without +/- markers. Of the two continuation lines below, the first is the
// removed parameter tail and the second is its replacement. The removed form gave defaultLeft a default of 0
// as the last parameter; the replacement moves it ahead of the `res` out-parameter and has no default, so
// every caller must now pass defaultLeft explicitly.
services::Status addSplitNodeInternal(data_management::DataCollectionPtr & serializationData, size_t treeId, size_t parentId, size_t position,
size_t featureIndex, double featureValue, double cover, size_t & res, int defaultLeft = 0);
size_t featureIndex, double featureValue, double cover, int defaultLeft, size_t & res);

void setProbabilities(const size_t treeId, const size_t nodeId, const size_t response, const data_management::DataCollectionPtr probTbl,
const double * const prob);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ services::Status ModelBuilder::addLeafNodeByProbaInternal(const TreeId treeId, c
}

services::Status ModelBuilder::addSplitNodeInternal(const TreeId treeId, const NodeId parentId, const size_t position, const size_t featureIndex,
const double featureValue, const double cover, NodeId & res)
const double featureValue, const double cover, const int defaultLeft, NodeId & res)
{
// NOTE(review): rendered diff without +/- markers. Each pair of near-identical lines here is the removed
// line followed by its replacement. The replacement accepts defaultLeft and forwards it to the shared dtrees
// helper instead of omitting that argument.
// Obtain the decision-forest classification ModelImpl reference from _model; its _serializationData is what
// the dtrees helper receives.
decision_forest::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef<decision_forest::classification::internal::ModelImpl, ModelPtr>(_model);
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
featureValue, cover, res);
featureValue, cover, defaultLeft, res);
}

} // namespace interface2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,21 @@ services::Status ModelBuilder::createTreeInternal(size_t nNodes, size_t clasLabe
}
}

services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, NodeId & res)
services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res)
{
// NOTE(review): rendered diff without +/- markers. The first signature is the removed line and the second
// is its replacement, which receives `cover` from the caller.
gbt::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef<daal::algorithms::gbt::classification::internal::ModelImpl, ModelPtr>(_model);
// Removed by this commit: the hard-coded placeholder is superseded by the `cover` parameter.
const double cover = 0.0; // TODO: Add cover
// Delegate to the shared dtrees helper, passing the model's _serializationData and the node inputs.
return daal::algorithms::dtrees::internal::addLeafNodeInternal<double>(modelImplRef._serializationData, treeId, parentId, position, response,
cover, res);
}

services::Status ModelBuilder::addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue,
NodeId & res, int defaultLeft)
int defaultLeft, const double cover, NodeId & res)
{
// NOTE(review): rendered diff without +/- markers. Each pair of near-identical lines here is the removed
// line followed by its replacement. The replacement receives (defaultLeft, cover, res) but forwards them to
// the dtrees helper as (cover, defaultLeft, res), which is the order the helper's declaration in this diff uses.
gbt::classification::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef<daal::algorithms::gbt::classification::internal::ModelImpl, ModelPtr>(_model);
// Removed by this commit: the hard-coded placeholder is superseded by the `cover` parameter.
const double cover = 0.0; // TODO: Add cover
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
featureValue, cover, res, defaultLeft);
featureValue, cover, defaultLeft, res);
}

} // namespace interface1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,20 @@ services::Status ModelBuilder::createTreeInternal(size_t nNodes, TreeId & resId)
}

services::Status ModelBuilder::addLeafNodeInternal(TreeId treeId, NodeId parentId, size_t position, double response, double cover, NodeId & res)
{
{
// NOTE(review): rendered diff without +/- markers. The two opening-brace lines above are a removed/added
// pair whose visible text is identical, so the change is presumably whitespace-only; confirm against the
// raw diff.
// Obtain the GBT regression ModelImpl reference from _model and delegate to the shared dtrees helper,
// passing its _serializationData along with the node inputs.
gbt::regression::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef<daal::algorithms::gbt::regression::internal::ModelImpl, ModelPtr>(_model);
return daal::algorithms::dtrees::internal::addLeafNodeInternal<double>(modelImplRef._serializationData, treeId, parentId, position, response,
cover, res);
}

services::Status ModelBuilder::addSplitNodeInternal(TreeId treeId, NodeId parentId, size_t position, size_t featureIndex, double featureValue,
double cover, NodeId & res, int defaultLeft)
double cover, int defaultLeft, NodeId & res)
{
// NOTE(review): rendered diff without +/- markers. Each pair of near-identical lines here is the removed
// line followed by its replacement. The replacement moves the `res` out-parameter after defaultLeft, both
// in this signature and in the forwarded call, matching the dtrees helper's declaration in this diff.
gbt::regression::internal::ModelImpl & modelImplRef =
daal::algorithms::dtrees::internal::getModelRef<daal::algorithms::gbt::regression::internal::ModelImpl, ModelPtr>(_model);
return daal::algorithms::dtrees::internal::addSplitNodeInternal(modelImplRef._serializationData, treeId, parentId, position, featureIndex,
featureValue, cover, res, defaultLeft);
featureValue, cover, defaultLeft, res);
}

} // namespace interface1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,19 +337,6 @@ services::Status PredictRegressionTask<algorithmFPType, cpu>::predictContributio
const size_t nColumnsPhi = nColumnsData + 1;
const size_t biasTermIndex = nColumnsPhi - 1;

// some model details (populated only for Fast TreeSHAP v2)
gbt::treeshap::ModelDetails<algorithmFPType> modelDetails(_aTree.get(), iTree, nTrees);
if (modelDetails.requiresPrecompute)
{
for (size_t currentTreeIndex = iTree; currentTreeIndex < iTree + nTrees; ++currentTreeIndex)
{
// regression model builder tree 0 contains only the base_score and must be skipped
if (currentTreeIndex == 0) continue;
const gbt::internal::GbtDecisionTree * currentTree = _aTree[currentTreeIndex];
gbt::treeshap::computeCombinationSum(currentTree, currentTreeIndex, modelDetails);
}
}

for (size_t iRow = 0; iRow < nRowsData; ++iRow)
{
const algorithmFPType * currentX = x + (iRow * nColumnsData);
Expand All @@ -360,8 +347,8 @@ services::Status PredictRegressionTask<algorithmFPType, cpu>::predictContributio
if (currentTreeIndex == 0) continue;

const gbt::internal::GbtDecisionTree * currentTree = _aTree[currentTreeIndex];
st |= gbt::treeshap::treeShap<algorithmFPType, hasUnorderedFeatures, hasAnyMissing>(
currentTree, currentTreeIndex, currentX, phi, &_featHelper, condition, conditionFeature, modelDetails);
st |= gbt::treeshap::treeShap<algorithmFPType, hasUnorderedFeatures, hasAnyMissing>(currentTree, currentX, phi, &_featHelper, condition,
conditionFeature);
}

if (condition == 0)
Expand Down
2 changes: 0 additions & 2 deletions cpp/daal/src/algorithms/dtrees/gbt/treeshap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,6 @@ float unwoundPathSumZero(const float * partialWeights, unsigned uniqueDepth, uns
}
} // namespace v1

namespace v2
{}
} // namespace internal
} // namespace treeshap
} // namespace gbt
Expand Down
Loading

0 comments on commit 88b770b

Please sign in to comment.