Skip to content

Commit

Permalink
align tree clf/reg APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Sep 27, 2023
1 parent 72bdfb3 commit 4a6a507
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/gbt_model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,21 @@ static daal::algorithms::gbt::regression::ModelPtr * get_gbt_regression_model_bu
return RAW<daal::algorithms::gbt::regression::ModelPtr>()(obj_->getModel());
}

c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft)
c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
{
#if (_gbt_inference_api_version == 2)
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, cover, defaultLeft);
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover);
#elif (_gbt_inference_api_version == 1)
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, defaultLeft);
#else
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue);
#endif
}

c_gbt_reg_node_id regAddSplitNodeWrapper(c_gbt_regression_model_builder * c_ptr, c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft)
c_gbt_reg_node_id regAddSplitNodeWrapper(c_gbt_regression_model_builder * c_ptr, c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
{
#if (_gbt_inference_api_version == 2)
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, cover, defaultLeft);
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, defaultLeft, cover);
#elif (_gbt_inference_api_version == 1)
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, defaultLeft);
#else
Expand Down
14 changes: 7 additions & 7 deletions src/gbt_model_builder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ cdef extern from "gbt_model_builder.h":
c_gbt_reg_tree_id createTree(size_t nNodes)
c_gbt_reg_node_id addLeafNode(c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, double response, double cover)

cdef c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft)
cdef c_gbt_reg_node_id regAddSplitNodeWrapper(c_gbt_regression_model_builder * c_ptr, c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft)
cdef c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)
cdef c_gbt_reg_node_id regAddSplitNodeWrapper(c_gbt_regression_model_builder * c_ptr, c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft, double cover)

cdef class gbt_classification_model_builder:
'''
Expand Down Expand Up @@ -78,7 +78,7 @@ cdef class gbt_classification_model_builder:
'''
return self.c_ptr.addLeafNode(tree_id, parent_id, position, response, cover)

def add_split(self, c_gbt_clf_tree_id tree_id, size_t feature_index, double feature_value, double cover, int default_left, c_gbt_clf_node_id parent_id=c_gbt_clf_no_parent, size_t position=0):
def add_split(self, c_gbt_clf_tree_id tree_id, size_t feature_index, double feature_value, int default_left, double cover, c_gbt_clf_node_id parent_id=c_gbt_clf_no_parent, size_t position=0):
'''
Create Split node and add it to certain tree.
Expand All @@ -87,11 +87,11 @@ cdef class gbt_classification_model_builder:
:param size_t position: position in parent (e.g. 0 for left and 1 for right child in a binary tree)
:param size_t feature_index: feature index for spliting
:param double feature_value: feature value for spliting
:param double cover: cover (sum_hess) of the solit node
:param int default_left: default behaviour in case of missing value
:param double cover: cover (sum_hess) of the solit node
:rtype: node identifier
'''
return clfAddSplitNodeWrapper(self.c_ptr, tree_id, parent_id, position, feature_index, feature_value, cover, default_left)
return clfAddSplitNodeWrapper(self.c_ptr, tree_id, parent_id, position, feature_index, feature_value, default_left, cover)

def model(self):
'''
Expand Down Expand Up @@ -138,7 +138,7 @@ cdef class gbt_regression_model_builder:
'''
return self.c_ptr.addLeafNode(tree_id, parent_id, position, response, cover)

def add_split(self, c_gbt_reg_tree_id tree_id, size_t feature_index, double feature_value, double cover, int default_left, c_gbt_reg_node_id parent_id=c_gbt_reg_no_parent, size_t position=0):
def add_split(self, c_gbt_reg_tree_id tree_id, size_t feature_index, double feature_value, int default_left, double cover, c_gbt_reg_node_id parent_id=c_gbt_reg_no_parent, size_t position=0):
'''
Create Split node and add it to certain tree.
Expand All @@ -151,7 +151,7 @@ cdef class gbt_regression_model_builder:
:param int default_left: default behaviour in case of missing value
:rtype: node identifier
'''
return regAddSplitNodeWrapper(self.c_ptr, tree_id, parent_id, position, feature_index, feature_value, cover, default_left)
return regAddSplitNodeWrapper(self.c_ptr, tree_id, parent_id, position, feature_index, feature_value, default_left, cover)

def model(self):
'''
Expand Down

0 comments on commit 4a6a507

Please sign in to comment.