diff --git a/onedal/ensemble/forest.cpp b/onedal/ensemble/forest.cpp
index 939a1e23fd..51c8dd7dca 100644
--- a/onedal/ensemble/forest.cpp
+++ b/onedal/ensemble/forest.cpp
@@ -220,6 +220,17 @@ void init_train_ops(py::module_& m) {
             train_ops ops(policy, input_t{ data, responses, weights }, params2desc{});
             return fptype2t{ method2t{ Task{}, ops } }(params);
         });
+    m.def("train",
+          [](const Policy& policy,
+             const py::dict& params,
+             const table& data,
+             const table& responses) {
+              using namespace decision_forest;
+              using input_t = train_input<Task>;
+
+              train_ops ops(policy, input_t{ data, responses }, params2desc{});
+              return fptype2t{ method2t{ Task{}, ops } }(params);
+          });
 }
 
 template <typename Policy, typename Task>
diff --git a/onedal/ensemble/forest.py b/onedal/ensemble/forest.py
index 173710b3d0..9e7b37eed2 100644
--- a/onedal/ensemble/forest.py
+++ b/onedal/ensemble/forest.py
@@ -331,48 +331,22 @@ def _validate_targets(self, y, dtype):
             self.classes_ = None
         return _column_or_1d(y, warn=True).astype(dtype, copy=False)
 
-    def _get_sample_weight(self, X, y, sample_weight):
-        n_samples = X.shape[0]
-        dtype = X.dtype
-        if n_samples == 1:
-            raise ValueError("n_samples=1")
-
-        sample_weight = np.asarray(
-            [] if sample_weight is None else sample_weight, dtype=dtype
+    def _get_sample_weight(self, sample_weight, X):
+        sample_weight = np.asarray(sample_weight, dtype=X.dtype).ravel()
+
+        sample_weight = _check_array(
+            sample_weight, accept_sparse=False, ensure_2d=False, dtype=X.dtype, order="C"
         )
-        sample_weight = sample_weight.ravel()
-        sample_weight_count = sample_weight.shape[0]
-        if sample_weight_count != 0 and sample_weight_count != n_samples:
+
+        if sample_weight.size != X.shape[0]:
             raise ValueError(
                 "sample_weight and X have incompatible shapes: "
                 "%r vs %r\n"
                 "Note: Sparse matrices cannot be indexed w/"
                 "boolean masks (use `indices=True` in CV)."
-                % (len(sample_weight), X.shape)
+                % (sample_weight.shape, X.shape)
             )
-        if sample_weight_count == 0:
-            sample_weight = np.ones(n_samples, dtype=dtype)
-        elif isinstance(sample_weight, Number):
-            sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
-        else:
-            sample_weight = _check_array(
-                sample_weight,
-                accept_sparse=False,
-                ensure_2d=False,
-                dtype=dtype,
-                order="C",
-            )
-            if sample_weight.ndim != 1:
-                raise ValueError("Sample weights must be 1D array or scalar")
-
-            if sample_weight.shape != (n_samples,):
-                raise ValueError(
-                    "sample_weight.shape == {}, expected {}!".format(
-                        sample_weight.shape, (n_samples,)
-                    )
-                )
         return sample_weight
 
     def _get_policy(self, queue, *data):
@@ -387,16 +361,21 @@ def _fit(self, X, y, sample_weight, module, queue):
             accept_sparse="csr",
         )
         y = self._validate_targets(y, X.dtype)
-        sample_weight = self._get_sample_weight(X, y, sample_weight)
 
         self.n_features_in_ = X.shape[1]
         if not sklearn_check_version("1.0"):
             self.n_features_ = self.n_features_in_
 
-        policy = self._get_policy(queue, X, y, sample_weight)
-        X, y, sample_weight = _convert_to_supported(policy, X, y, sample_weight)
-        params = self._get_onedal_params(X)
-        train_result = module.train(policy, params, *to_table(X, y, sample_weight))
+        if sample_weight is not None and len(sample_weight) > 0:
+            sample_weight = self._get_sample_weight(sample_weight, X)
+            data = (X, y, sample_weight)
+        else:
+            data = (X, y)
+        policy = self._get_policy(queue, *data)
+        data = _convert_to_supported(policy, *data)
+        params = self._get_onedal_params(data[0])
+        train_result = module.train(policy, params, *to_table(*data))
+
         self._onedal_model = train_result.model
 
         if self.oob_score:
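Note on the two hunks above: `_fit` now keys the overload choice on whether usable weights were passed, which pairs with the new two-table `train` binding in forest.cpp, so no ones-vector has to be allocated and converted on the Python side in the unweighted case. A minimal Python sketch of that dispatch rule (`select_train_data` is an illustrative name, not part of the module):

    import numpy as np

    def select_train_data(X, y, sample_weight=None):
        """Illustrative helper: pick the table tuple handed to the oneDAL train op."""
        if sample_weight is not None and len(sample_weight) > 0:
            sample_weight = np.asarray(sample_weight, dtype=X.dtype).ravel()
            if sample_weight.size != X.shape[0]:
                raise ValueError("sample_weight and X have incompatible shapes")
            return (X, y, sample_weight)
        # No usable weights: the new two-table C++ overload gets selected and
        # oneDAL weights every sample equally.
        return (X, y)

    X = np.ones((4, 2), dtype=np.float64)
    y = np.array([0.0, 1.0, 0.0, 1.0])
    assert len(select_train_data(X, y)) == 2                 # unweighted overload
    assert len(select_train_data(X, y, [1, 2, 3, 4])) == 3   # weighted overload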
diff --git a/onedal/primitives/tree_visitor.cpp b/onedal/primitives/tree_visitor.cpp
index 0f147e1e8e..3ced21967e 100644
--- a/onedal/primitives/tree_visitor.cpp
+++ b/onedal/primitives/tree_visitor.cpp
@@ -82,7 +82,6 @@ struct tree_state {
     std::size_t leaf_count;
     std::size_t class_count;
 };
-// Declaration and implementation.
 
 template <typename Task>
 class node_count_visitor {
@@ -100,6 +99,7 @@ class node_count_visitor {
         return true;
     }
 
+    std::size_t n_nodes;
     std::size_t depth;
     std::size_t n_leaf_nodes;
 };
@@ -133,6 +133,8 @@ class to_sklearn_tree_object_visitor : public tree_state {
                                    std::size_t _max_n_classes);
     bool call(const df::leaf_node_info<Task>& info);
     bool call(const df::split_node_info<Task>& info);
+    double* value_ar_ptr;
+    skl_tree_node* node_ar_ptr;
 
 protected:
     std::size_t node_id;
@@ -163,42 +165,49 @@ to_sklearn_tree_object_visitor<Task>::to_sklearn_tree_object_visitor(std::size_t
     auto value_ar_strides = py::array::StridesContainer(
         { this->class_count * sizeof(double), this->class_count * sizeof(double), sizeof(double) });
 
-    skl_tree_node* node_ar_ptr = new skl_tree_node[this->node_count];
+    OVERFLOW_CHECK_BY_MULTIPLICATION(std::size_t, this->node_count, this->class_count);
+    this->node_ar_ptr = new skl_tree_node[this->node_count];
+    this->value_ar_ptr = new double[this->node_count * this->class_count]();
 
-    OVERFLOW_CHECK_BY_MULTIPLICATION(std::size_t, this->node_count, this->class_count);
-    double* value_ar_ptr =
-        new double[this->node_count * 1 *
-                   this->class_count](); // oneDAL only supports scalar responses for now
+    // py::array_t doesn't initialize the underlying memory with the object's constructor,
+    // so the values would not match what is defined above; this must be done on the C++ side.
 
-    this->node_ar = py::array_t<skl_tree_node>(node_ar_shape, node_ar_strides, node_ar_ptr);
-    this->value_ar = py::array_t<double>(value_ar_shape, value_ar_strides, value_ar_ptr);
+    py::capsule free_value_ar(this->value_ar_ptr, [](void* f) {
+        double* value_ar_ptr = reinterpret_cast<double*>(f);
+        delete[] value_ar_ptr;
+    });
+
+    py::capsule free_node_ar(this->node_ar_ptr, [](void* f) {
+        skl_tree_node* node_ar_ptr = reinterpret_cast<skl_tree_node*>(f);
+        delete[] node_ar_ptr;
+    });
+
+    this->node_ar = py::array_t<skl_tree_node>(node_ar_shape, node_ar_strides, this->node_ar_ptr, free_node_ar);
+    this->value_ar = py::array_t<double>(value_ar_shape, value_ar_strides, this->value_ar_ptr, free_value_ar);
 }
 
 template <typename Task>
 bool to_sklearn_tree_object_visitor<Task>::call(const df::split_node_info<Task>& info) {
-    py::buffer_info node_ar_buf = this->node_ar.request();
-
-    skl_tree_node* node_ar_ptr = static_cast<skl_tree_node*>(node_ar_buf.ptr);
-
     if (info.get_level() > 0) { // has parents
         Py_ssize_t parent = parents[info.get_level() - 1];
-        if (node_ar_ptr[parent].left_child > 0) {
-            assert(node_ar_ptr[node_id].right_child < 0);
-            node_ar_ptr[parent].right_child = node_id;
+        if (this->node_ar_ptr[parent].left_child > 0) {
+            assert(this->node_ar_ptr[node_id].right_child < 0);
+            this->node_ar_ptr[parent].right_child = node_id;
         }
         else {
-            node_ar_ptr[parent].left_child = node_id;
+            this->node_ar_ptr[parent].left_child = node_id;
         }
     }
+
     parents[info.get_level()] = node_id;
-    node_ar_ptr[node_id].feature = info.get_feature_index();
-    node_ar_ptr[node_id].threshold = info.get_feature_value();
-    node_ar_ptr[node_id].impurity = info.get_impurity();
-    node_ar_ptr[node_id].n_node_samples = info.get_sample_count();
-    node_ar_ptr[node_id].weighted_n_node_samples = info.get_sample_count();
-    node_ar_ptr[node_id].missing_go_to_left = false;
+    this->node_ar_ptr[node_id].feature = info.get_feature_index();
+    this->node_ar_ptr[node_id].threshold = info.get_feature_value();
+    this->node_ar_ptr[node_id].impurity = info.get_impurity();
+    this->node_ar_ptr[node_id].n_node_samples = info.get_sample_count();
+    this->node_ar_ptr[node_id].weighted_n_node_samples = info.get_sample_count();
+    this->node_ar_ptr[node_id].missing_go_to_left = false;
 
     // wrap-up
     ++node_id;
@@ -208,25 +217,21 @@ bool to_sklearn_tree_object_visitor<Task>::call(const df::split_node_info<Task>&
 
 // stuff that is done for all leaf node types
 template <typename Task>
 void to_sklearn_tree_object_visitor<Task>::_onLeafNode(const df::leaf_node_info<Task>& info) {
-    py::buffer_info node_ar_buf = this->node_ar.request();
-
-    skl_tree_node* node_ar_ptr = static_cast<skl_tree_node*>(node_ar_buf.ptr);
-
     if (info.get_level()) {
         Py_ssize_t parent = parents[info.get_level() - 1];
-        if (node_ar_ptr[parent].left_child > 0) {
-            assert(node_ar_ptr[node_id].right_child < 0);
-            node_ar_ptr[parent].right_child = node_id;
+        if (this->node_ar_ptr[parent].left_child > 0) {
+            assert(this->node_ar_ptr[node_id].right_child < 0);
+            this->node_ar_ptr[parent].right_child = node_id;
         }
         else {
-            node_ar_ptr[parent].left_child = node_id;
+            this->node_ar_ptr[parent].left_child = node_id;
         }
     }
 
-    node_ar_ptr[node_id].impurity = info.get_impurity();
-    node_ar_ptr[node_id].n_node_samples = info.get_sample_count();
-    node_ar_ptr[node_id].weighted_n_node_samples = info.get_sample_count();
-    node_ar_ptr[node_id].missing_go_to_left = false;
+    this->node_ar_ptr[node_id].impurity = info.get_impurity();
+    this->node_ar_ptr[node_id].n_node_samples = info.get_sample_count();
+    this->node_ar_ptr[node_id].weighted_n_node_samples = info.get_sample_count();
+    this->node_ar_ptr[node_id].missing_go_to_left = false;
 }
 
 template <>
@@ -235,10 +240,7 @@ bool to_sklearn_tree_object_visitor<df::task::regression>::call(
     _onLeafNode(info);
     OVERFLOW_CHECK_BY_MULTIPLICATION(std::size_t, node_id, class_count);
 
-    py::buffer_info value_ar_buf = this->value_ar.request();
-    double* value_ar_ptr = static_cast<double*>(value_ar_buf.ptr);
-
-    value_ar_ptr[node_id * 1 * this->class_count] = info.get_response();
+    this->value_ar_ptr[node_id * this->class_count] = info.get_response();
 
     // wrap-up
     ++node_id;
@@ -248,26 +250,19 @@ bool to_sklearn_tree_object_visitor<df::task::regression>::call(
 
 template <>
 bool to_sklearn_tree_object_visitor<df::task::classification>::call(
     const df::leaf_node_info<df::task::classification>& info) {
-    py::buffer_info value_ar_buf = this->value_ar.request();
-    double* value_ar_ptr = static_cast<double*>(value_ar_buf.ptr);
-
-    if (info.get_level() > 0) {
-        std::size_t depth = static_cast<std::size_t>(info.get_level()) - 1;
-        while (depth >= 0) {
-            const std::size_t id = parents[depth];
-            OVERFLOW_CHECK_BY_MULTIPLICATION(std::size_t, id, this->class_count);
-            const auto row = id * 1 * this->class_count;
-            OVERFLOW_CHECK_BY_ADDING(std::size_t, row, info.get_response());
-            value_ar_ptr[row + info.get_response()] += info.get_sample_count();
-            if (depth == 0) {
-                break;
-            }
-            --depth;
-        }
+
+    std::size_t depth = static_cast<std::size_t>(info.get_level());
+    const std::size_t label = info.get_response(); // these may be slow accesses due to the oneDAL abstraction,
+    const double nNodeSampleCount = info.get_sample_count(); // so do them only once
+
+    while (depth--) {
+        const std::size_t id = parents[depth];
+        const std::size_t row = id * this->class_count;
+        this->value_ar_ptr[row + label] += nNodeSampleCount;
     }
     _onLeafNode(info);
-    OVERFLOW_CHECK_BY_ADDING(std::size_t, node_id * 1 * this->class_count, info.get_response());
-    value_ar_ptr[node_id * 1 * this->class_count + info.get_response()] += info.get_sample_count();
+    this->value_ar_ptr[node_id * this->class_count + label] += nNodeSampleCount;
 
     // wrap-up
     ++node_id;
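For review context: the rewritten classification leaf handler keeps the old semantics (the leaf's sample count is added to the label's column in every ancestor's row of value_ar), while `while (depth--)` sidesteps the pitfall that forced the old `if (depth == 0) break;` guard: with an unsigned `std::size_t depth`, the condition `depth >= 0` is always true. A Python sketch of the equivalent accumulation (function and argument names are illustrative, not the visitor's API):

    import numpy as np

    def accumulate_leaf(value_ar, parents, level, label, sample_count, class_count):
        """Add a leaf's sample count to its label column in every ancestor's row."""
        depth = level
        while depth:                 # plays the role of `while (depth--)` in the C++
            depth -= 1
            row = parents[depth] * class_count   # ancestor node id at this depth
            value_ar[row + label] += sample_count

    # Toy tree: root (id 0) -> split (id 1) -> leaf at level 2, class 1 of 3.
    value_ar = np.zeros(4 * 3)       # node_count * class_count, flat as in the visitor
    accumulate_leaf(value_ar, parents=[0, 1], level=2, label=1,
                    sample_count=5.0, class_count=3)
    assert value_ar[0 * 3 + 1] == 5.0
    assert value_ar[1 * 3 + 1] == 5.0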
diff --git a/sklearnex/preview/ensemble/extra_trees.py b/sklearnex/preview/ensemble/extra_trees.py
index 9487882746..4cf2bc7fc8 100644
--- a/sklearnex/preview/ensemble/extra_trees.py
+++ b/sklearnex/preview/ensemble/extra_trees.py
@@ -191,6 +191,22 @@ def check_sample_weight(self, sample_weight, X, dtype=None):
             )
         return sample_weight
 
+    @property
+    def estimators_(self):
+        if hasattr(self, "_cached_estimators_"):
+            if self._cached_estimators_ is None and self._onedal_model:
+                self._estimators_()
+            return self._cached_estimators_
+        else:
+            raise AttributeError(
+                f"'{self.__class__.__name__}' has no attribute 'estimators_'"
+            )
+
+    @estimators_.setter
+    def estimators_(self, estimators):
+        # Needed to allow for proper sklearn operation in fallback mode
+        self._cached_estimators_ = estimators
+
 
 class ExtraTreesClassifier(sklearn_ExtraTreesClassifier, BaseTree):
     __doc__ = sklearn_ExtraTreesClassifier.__doc__
@@ -541,17 +557,13 @@ def predict_proba(self, X):
     def n_features_(self):
         return self.n_features_in_
 
-    @property
     def _estimators_(self):
-        if hasattr(self, "_cached_estimators_"):
-            if self._cached_estimators_:
-                return self._cached_estimators_
-        if sklearn_check_version("0.22"):
-            check_is_fitted(self)
-        else:
-            check_is_fitted(self, "_onedal_model")
+        # _estimators_ should only be called if _onedal_model exists
+        check_is_fitted(self, "_onedal_model")
         classes_ = self.classes_[0]
-        n_classes_ = self.n_classes_[0]
+        n_classes_ = (
+            self.n_classes_ if isinstance(self.n_classes_, int) else self.n_classes_[0]
+        )
         # convert model to estimators
         params = {
             "criterion": self.criterion,
@@ -570,7 +582,9 @@ def _estimators_(self):
         # we need to set est.tree_ field with Trees constructed from Intel(R)
         # oneAPI Data Analytics Library solution
         estimators_ = []
+
         random_state_checked = check_random_state(self.random_state)
+
         for i in range(self.n_estimators):
             est_i = clone(est)
             est_i.set_params(
@@ -599,7 +613,6 @@ def _estimators_(self):
             estimators_.append(est_i)
 
         self._cached_estimators_ = estimators_
-        return estimators_
 
     def _onedal_cpu_supported(self, method_name, *data):
         class_name = self.__class__.__name__
@@ -713,7 +726,7 @@ def _onedal_gpu_supported(self, method_name, *data):
                 (self.warm_start is False, "Warm start is not supported."),
                 (
                     daal_check_version((2023, "P", 100)),
-                    "ExtraTrees only supported starting from oneDAL version 2023.1",
+                    "ExtraTrees supported starting from oneDAL version 2023.1",
                 ),
             ]
         )
@@ -823,6 +836,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
                 onedal_params["min_impurity_split"] = self.min_impurity_split
             else:
                 onedal_params["min_impurity_split"] = None
+
+        # Lazy evaluation of estimators_
         self._cached_estimators_ = None
 
         # Compute
@@ -832,7 +847,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
         self._save_attributes()
         if sklearn_check_version("1.2"):
             self._estimator = ExtraTreeClassifier()
-        self.estimators_ = self._estimators_
+
         # Decapsulate classes_ attributes
         self.n_classes_ = self.n_classes_[0]
         self.classes_ = self.classes_[0]
@@ -840,7 +855,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
 
     def _onedal_predict(self, X, queue=None):
         X = check_array(X, dtype=[np.float32, np.float64])
-        check_is_fitted(self)
+        check_is_fitted(self, "_onedal_model")
+
         if sklearn_check_version("1.0"):
             self._check_feature_names(X, reset=False)
 
@@ -849,7 +865,8 @@ def _onedal_predict(self, X, queue=None):
 
     def _onedal_predict_proba(self, X, queue=None):
         X = check_array(X, dtype=[np.float64, np.float32])
-        check_is_fitted(self)
+        check_is_fitted(self, "_onedal_model")
+
         if sklearn_check_version("0.23"):
             self._check_n_features(X, reset=False)
         if sklearn_check_version("1.0"):
@@ -967,15 +984,9 @@ def __init__(
         self.max_bins = max_bins
         self.min_bin_size = min_bin_size
 
-    @property
     def _estimators_(self):
-        if hasattr(self, "_cached_estimators_"):
-            if self._cached_estimators_:
-                return self._cached_estimators_
-        if sklearn_check_version("0.22"):
-            check_is_fitted(self)
-        else:
-            check_is_fitted(self, "_onedal_model")
+        # _estimators_ should only be called if _onedal_model exists
+        check_is_fitted(self, "_onedal_model")
         # convert model to estimators
         params = {
             "criterion": self.criterion,
@@ -995,6 +1006,7 @@ def _estimators_(self):
         # oneAPI Data Analytics Library solution
         estimators_ = []
         random_state_checked = check_random_state(self.random_state)
+
         for i in range(self.n_estimators):
             est_i = clone(est)
             est_i.set_params(
@@ -1020,7 +1032,7 @@ def _estimators_(self):
             est_i.tree_.__setstate__(tree_i_state_dict)
             estimators_.append(est_i)
 
-        return estimators_
+        self._cached_estimators_ = estimators_
 
     def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
         if sp.issparse(y):
@@ -1319,20 +1331,26 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
         }
         if daal_check_version((2023, "P", 101)):
             onedal_params["splitter_mode"] = "random"
+
+        # Lazy evaluation of estimators_
         self._cached_estimators_ = None
+
         self._onedal_estimator = self._onedal_regressor(**onedal_params)
         self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
         self._save_attributes()
         if sklearn_check_version("1.2"):
             self._estimator = ExtraTreeRegressor()
-        self.estimators_ = self._estimators_
+
         return self
 
     def _onedal_predict(self, X, queue=None):
+        X = check_array(X, dtype=[np.float32, np.float64])
+        check_is_fitted(self, "_onedal_model")
+
         if sklearn_check_version("1.0"):
             self._check_feature_names(X, reset=False)
-        X = self._validate_X_predict(X)
+
        return self._onedal_estimator.predict(X, queue=queue)
 
     def fit(self, X, y, sample_weight=None):
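The `estimators_` property added at the top of this file defers the expensive oneDAL-model-to-sklearn-trees conversion until the attribute is first read, instead of doing it inside `_onedal_fit`. A self-contained sketch of that caching pattern (the `Model` class and stand-in bodies are illustrative):

    class Model:
        def fit(self):
            self._onedal_model = object()    # stand-in for the trained oneDAL model
            self._cached_estimators_ = None  # conversion deferred out of fit
            return self

        def _estimators_(self):
            # stand-in for the per-tree conversion loop; fills the cache in place
            self._cached_estimators_ = ["tree-0", "tree-1"]

        @property
        def estimators_(self):
            if hasattr(self, "_cached_estimators_"):
                if self._cached_estimators_ is None and self._onedal_model:
                    self._estimators_()
                return self._cached_estimators_
            raise AttributeError(
                f"'{type(self).__name__}' has no attribute 'estimators_'"
            )

        @estimators_.setter
        def estimators_(self, estimators):
            # fallback mode: stock sklearn fit assigns estimators_ directly
            self._cached_estimators_ = estimators

    m = Model().fit()        # fast: no trees are built here
    trees = m.estimators_    # first access runs _estimators_() exactly once
    assert trees == ["tree-0", "tree-1"] and m.estimators_ is trees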
diff --git a/sklearnex/preview/ensemble/forest.py b/sklearnex/preview/ensemble/forest.py
index 8b75b104cb..64f4558108 100755
--- a/sklearnex/preview/ensemble/forest.py
+++ b/sklearnex/preview/ensemble/forest.py
@@ -194,6 +194,22 @@ def check_sample_weight(self, sample_weight, X, dtype=None):
             )
         return sample_weight
 
+    @property
+    def estimators_(self):
+        if hasattr(self, "_cached_estimators_"):
+            if self._cached_estimators_ is None and self._onedal_model:
+                self._estimators_()
+            return self._cached_estimators_
+        else:
+            raise AttributeError(
+                f"'{self.__class__.__name__}' has no attribute 'estimators_'"
+            )
+
+    @estimators_.setter
+    def estimators_(self, estimators):
+        # Needed to allow for proper sklearn operation in fallback mode
+        self._cached_estimators_ = estimators
+
 
 class RandomForestClassifier(sklearn_RandomForestClassifier, BaseRandomForest):
     __doc__ = sklearn_RandomForestClassifier.__doc__
@@ -430,9 +446,7 @@ def _onedal_ready(self, X, y, sample_weight):
         correct_ccp_alpha = self.ccp_alpha == 0.0
         correct_criterion = self.criterion == "gini"
         correct_warm_start = self.warm_start is False
-        correct_monotonic_cst = (
-            sklearn_check_version("1.4") and self.monotonic_cst is None
-        )
+        correct_monotonic_cst = getattr(self, "monotonic_cst", None) is None
         if correct_sparsity and sklearn_check_version("1.4"):
             try:
                 _assert_all_finite(X)
@@ -456,6 +470,7 @@ def _onedal_ready(self, X, y, sample_weight):
                 correct_warm_start,
                 correct_monotonic_cst,
                 correct_finiteness,
+                self.class_weight != "balanced_subsample",
             ]
         )
         if ready:
@@ -576,17 +591,12 @@ def predict_proba(self, X):
     def n_features_(self):
         return self.n_features_in_
 
-    @property
     def _estimators_(self):
-        if hasattr(self, "_cached_estimators_"):
-            if self._cached_estimators_:
-                return self._cached_estimators_
-        if sklearn_check_version("0.22"):
-            check_is_fitted(self)
-        else:
-            check_is_fitted(self, "_onedal_model")
+        check_is_fitted(self, "_onedal_model")
         classes_ = self.classes_[0]
-        n_classes_ = self.n_classes_[0]
+        n_classes_ = (
+            self.n_classes_ if isinstance(self.n_classes_, int) else self.n_classes_[0]
+        )
         # convert model to estimators
         params = {
             "criterion": self.criterion,
@@ -606,6 +616,7 @@ def _estimators_(self):
         # oneAPI Data Analytics Library solution
         estimators_ = []
         random_state_checked = check_random_state(self.random_state)
+
         for i in range(self.n_estimators):
             est_i = clone(est)
             est_i.set_params(
@@ -634,7 +645,6 @@ def _estimators_(self):
             estimators_.append(est_i)
 
         self._cached_estimators_ = estimators_
-        return estimators_
 
     def _onedal_cpu_supported(self, method_name, *data):
         if method_name == "fit":
@@ -814,6 +824,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
         }
         if daal_check_version((2023, "P", 101)):
             onedal_params["splitter_mode"] = self.splitter_mode
+
+        # Lazy evaluation of estimators_
         self._cached_estimators_ = None
 
         # Compute
@@ -823,7 +835,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
         self._save_attributes()
         if sklearn_check_version("1.2"):
             self._estimator = DecisionTreeClassifier()
-        self.estimators_ = self._estimators_
+
         # Decapsulate classes_ attributes
         self.n_classes_ = self.n_classes_[0]
         self.classes_ = self.classes_[0]
@@ -831,7 +843,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
 
     def _onedal_predict(self, X, queue=None):
         X = check_array(X, dtype=[np.float32, np.float64])
-        check_is_fitted(self)
+        check_is_fitted(self, "_onedal_model")
+
         if sklearn_check_version("1.0"):
             self._check_feature_names(X, reset=False)
 
@@ -840,7 +853,8 @@ def _onedal_predict(self, X, queue=None):
 
     def _onedal_predict_proba(self, X, queue=None):
         X = check_array(X, dtype=[np.float64, np.float32])
-        check_is_fitted(self)
+        check_is_fitted(self, "_onedal_model")
+
         if sklearn_check_version("0.23"):
             self._check_n_features(X, reset=False)
         if sklearn_check_version("1.0"):
@@ -1019,15 +1033,8 @@ def __init__(
         self.min_impurity_split = None
         self.splitter_mode = splitter_mode
 
-    @property
     def _estimators_(self):
-        if hasattr(self, "_cached_estimators_"):
-            if self._cached_estimators_:
-                return self._cached_estimators_
-        if sklearn_check_version("0.22"):
-            check_is_fitted(self)
-        else:
-            check_is_fitted(self, "_onedal_model")
+        check_is_fitted(self, "_onedal_model")
         # convert model to estimators
         params = {
             "criterion": self.criterion,
@@ -1047,6 +1054,7 @@ def _estimators_(self):
         # oneAPI Data Analytics Library solution
         estimators_ = []
         random_state_checked = check_random_state(self.random_state)
+
         for i in range(self.n_estimators):
             est_i = clone(est)
             est_i.set_params(
@@ -1071,8 +1079,7 @@ def _estimators_(self):
             )
             est_i.tree_.__setstate__(tree_i_state_dict)
             estimators_.append(est_i)
-
-        return estimators_
+        self._cached_estimators_ = estimators_
 
     def _onedal_ready(self, X, y, sample_weight):
         # TODO:
@@ -1254,20 +1261,27 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
         }
         if daal_check_version((2023, "P", 101)):
             onedal_params["splitter_mode"] = self.splitter_mode
+
+        # Lazy evaluation of estimators_
         self._cached_estimators_ = None
+
+        # Compute
         self._onedal_estimator = self._onedal_regressor(**onedal_params)
         self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
         self._save_attributes()
         if sklearn_check_version("1.2"):
             self._estimator = DecisionTreeRegressor()
-        self.estimators_ = self._estimators_
+
         return self
 
     def _onedal_predict(self, X, queue=None):
+        X = check_array(X, dtype=[np.float32, np.float64])
+        check_is_fitted(self, "_onedal_model")
+
         if sklearn_check_version("1.0"):
             self._check_feature_names(X, reset=False)
-        X = self._validate_X_predict(X)
+
         return self._onedal_estimator.predict(X, queue=queue)
 
     def fit(self, X, y, sample_weight=None):
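One more note, on the `_onedal_ready` change above: `getattr(self, "monotonic_cst", None) is None` gates offload correctly on every sklearn version, because the attribute simply does not exist before sklearn 1.4. A small sketch of why the two forms differ (the classes are illustrative stand-ins for pre- and post-1.4 estimators):

    class Pre14Estimator:      # sklearn < 1.4: monotonic_cst does not exist
        pass

    class Post14Estimator:     # sklearn >= 1.4: exists and defaults to None
        monotonic_cst = None

    for est in (Pre14Estimator(), Post14Estimator()):
        # True whenever no constraints are set, whether or not the attribute
        # exists, so no sklearn_check_version("1.4") guard is needed.
        assert getattr(est, "monotonic_cst", None) is None

    # Direct attribute access would need the version guard that was removed:
    try:
        Pre14Estimator().monotonic_cst
    except AttributeError:
        pass  # this is why the old code wrapped the check in a version test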