Skip to content

Commit

Permalink
clean up inference APIs and versioning
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Sep 27, 2023
1 parent 0ea9d49 commit 598adb2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
20 changes: 13 additions & 7 deletions src/gbt_model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
#include <daal.h>
#include "onedal/version.hpp"

#if (((MAJOR_VERSION == 2023) && (MINOR_VERSION >= 2)) || (MAJOR_VERSION > 2023))
#define _gbt_inference_has_missing_values_support 1
#if (((MAJOR_VERSION == 2024) && (MINOR_VERSION >= 1)) || (MAJOR_VERSION > 2024))
#define _gbt_inference_api_version 2
#elif (((MAJOR_VERSION == 2023) && (MINOR_VERSION >= 2)) || (MAJOR_VERSION > 2023))
#define _gbt_inference_api_version 1
#else
#define _gbt_inference_has_missing_values_support 0
#define _gbt_inference_api_version 0
#endif

typedef daal::algorithms::gbt::classification::ModelBuilder c_gbt_classification_model_builder;
Expand All @@ -49,9 +51,11 @@ static daal::algorithms::gbt::regression::ModelPtr * get_gbt_regression_model_bu
return RAW<daal::algorithms::gbt::regression::ModelPtr>()(obj_->getModel());
}

c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft)
c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft)
{
#if _gbt_inference_has_missing_values_support
#if (_gbt_inference_api_version == 2)
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, cover, defaultLeft);
#elif (_gbt_inference_api_version == 1)
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, defaultLeft);
#else
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue);
Expand All @@ -60,10 +64,12 @@ c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_

c_gbt_reg_node_id regAddSplitNodeWrapper(c_gbt_regression_model_builder * c_ptr, c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft)
{
#if _gbt_inference_has_missing_values_support
#if (_gbt_inference_api_version == 2)
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, cover, defaultLeft);
#elif (_gbt_inference_api_version == 1)
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, defaultLeft);
#else
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, cover);
return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue);
#endif
}

Expand Down
10 changes: 4 additions & 6 deletions src/gbt_model_builder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ cdef extern from "gbt_model_builder.h":
cdef cppclass c_gbt_classification_model_builder:
c_gbt_classification_model_builder(size_t nFeatures, size_t nIterations, size_t nClasses) except +
c_gbt_clf_tree_id createTree(size_t nNodes, size_t classLabel)
c_gbt_clf_node_id addLeafNode(c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, double response)
c_gbt_clf_node_id addLeafNode(c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, double response, double cover)

cdef cppclass c_gbt_regression_model_builder:
c_gbt_regression_model_builder(size_t nFeatures, size_t nIterations) except +
c_gbt_reg_tree_id createTree(size_t nNodes)
c_gbt_reg_node_id addLeafNode(c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, double response, double cover)

cdef c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft)
cdef c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft)
cdef c_gbt_reg_node_id regAddSplitNodeWrapper(c_gbt_regression_model_builder * c_ptr, c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft)

cdef class gbt_classification_model_builder:
Expand Down Expand Up @@ -76,8 +76,7 @@ cdef class gbt_classification_model_builder:
:param double cover: cover (sum_hess) of the leaf node
:rtype: node identifier
'''
# TODO: Forward cover to oneDAL
return self.c_ptr.addLeafNode(tree_id, parent_id, position, response)
return self.c_ptr.addLeafNode(tree_id, parent_id, position, response, cover)

def add_split(self, c_gbt_clf_tree_id tree_id, size_t feature_index, double feature_value, double cover, int default_left, c_gbt_clf_node_id parent_id=c_gbt_clf_no_parent, size_t position=0):
'''
Expand All @@ -92,8 +91,7 @@ cdef class gbt_classification_model_builder:
:param int default_left: default behaviour in case of missing value
:rtype: node identifier
'''
# TODO: Forward cover to oneDAL
return clfAddSplitNodeWrapper(self.c_ptr, tree_id, parent_id, position, feature_index, feature_value, default_left)
return clfAddSplitNodeWrapper(self.c_ptr, tree_id, parent_id, position, feature_index, feature_value, cover, default_left)

def model(self):
'''
Expand Down

0 comments on commit 598adb2

Please sign in to comment.