diff --git a/src/gbt_model_builder.h b/src/gbt_model_builder.h index 0f7335da65..00e418417e 100644 --- a/src/gbt_model_builder.h +++ b/src/gbt_model_builder.h @@ -22,10 +22,12 @@ #include #include "onedal/version.hpp" -#if (((MAJOR_VERSION == 2023) && (MINOR_VERSION >= 2)) || (MAJOR_VERSION > 2023)) - #define _gbt_inference_has_missing_values_support 1 +#if (((MAJOR_VERSION == 2024) && (MINOR_VERSION >= 1)) || (MAJOR_VERSION > 2024)) + #define _gbt_inference_api_version 2 +#elif (((MAJOR_VERSION == 2023) && (MINOR_VERSION >= 2)) || (MAJOR_VERSION > 2023)) + #define _gbt_inference_api_version 1 #else - #define _gbt_inference_has_missing_values_support 0 + #define _gbt_inference_api_version 0 #endif typedef daal::algorithms::gbt::classification::ModelBuilder c_gbt_classification_model_builder; @@ -49,9 +51,11 @@ static daal::algorithms::gbt::regression::ModelPtr * get_gbt_regression_model_bu return RAW()(obj_->getModel()); } -c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft) +c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft) { -#if _gbt_inference_has_missing_values_support +#if (_gbt_inference_api_version == 2) + return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, cover, defaultLeft); +#elif (_gbt_inference_api_version == 1) return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, defaultLeft); #else return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue); @@ -60,10 +64,12 @@ c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ c_gbt_reg_node_id regAddSplitNodeWrapper(c_gbt_regression_model_builder * c_ptr, c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft) { -#if _gbt_inference_has_missing_values_support +#if (_gbt_inference_api_version == 2) return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, cover, defaultLeft); +#elif (_gbt_inference_api_version == 1) + return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, defaultLeft); #else - return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue, cover); + return c_ptr->addSplitNode(treeId, parentId, position, featureIndex, featureValue); #endif } diff --git a/src/gbt_model_builder.pyx b/src/gbt_model_builder.pyx index 0ea68a44f4..9f1fd4acb6 100644 --- a/src/gbt_model_builder.pyx +++ b/src/gbt_model_builder.pyx @@ -33,14 +33,14 @@ cdef extern from "gbt_model_builder.h": cdef cppclass c_gbt_classification_model_builder: c_gbt_classification_model_builder(size_t nFeatures, size_t nIterations, size_t nClasses) except + c_gbt_clf_tree_id createTree(size_t nNodes, size_t classLabel) - c_gbt_clf_node_id addLeafNode(c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, double response) + c_gbt_clf_node_id addLeafNode(c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, double response, double cover) cdef cppclass c_gbt_regression_model_builder: c_gbt_regression_model_builder(size_t nFeatures, size_t nIterations) except + c_gbt_reg_tree_id createTree(size_t nNodes) c_gbt_reg_node_id addLeafNode(c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, double response, double cover) - cdef c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, int defaultLeft) + cdef c_gbt_clf_node_id clfAddSplitNodeWrapper(c_gbt_classification_model_builder * c_ptr, c_gbt_clf_tree_id treeId, c_gbt_clf_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft) cdef c_gbt_reg_node_id regAddSplitNodeWrapper(c_gbt_regression_model_builder * c_ptr, c_gbt_reg_tree_id treeId, c_gbt_reg_node_id parentId, size_t position, size_t featureIndex, double featureValue, double cover, int defaultLeft) cdef class gbt_classification_model_builder: @@ -76,8 +76,7 @@ cdef class gbt_classification_model_builder: :param double cover: cover (sum_hess) of the leaf node :rtype: node identifier ''' - # TODO: Forward cover to oneDAL - return self.c_ptr.addLeafNode(tree_id, parent_id, position, response) + return self.c_ptr.addLeafNode(tree_id, parent_id, position, response, cover) def add_split(self, c_gbt_clf_tree_id tree_id, size_t feature_index, double feature_value, double cover, int default_left, c_gbt_clf_node_id parent_id=c_gbt_clf_no_parent, size_t position=0): ''' @@ -92,8 +91,7 @@ cdef class gbt_classification_model_builder: :param int default_left: default behaviour in case of missing value :rtype: node identifier ''' - # TODO: Forward cover to oneDAL - return clfAddSplitNodeWrapper(self.c_ptr, tree_id, parent_id, position, feature_index, feature_value, default_left) + return clfAddSplitNodeWrapper(self.c_ptr, tree_id, parent_id, position, feature_index, feature_value, cover, default_left) def model(self): '''