diff --git a/sklearnex/cluster/k_means.py b/sklearnex/cluster/k_means.py index f1567ad5c7..47ec5ca0d6 100644 --- a/sklearnex/cluster/k_means.py +++ b/sklearnex/cluster/k_means.py @@ -90,8 +90,11 @@ def _initialize_onedal_estimator(self): self._onedal_estimator = onedal_KMeans(**onedal_params) - def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None): + def _onedal_fit_supported( + self, method_name, _is_gpu, X, y=None, sample_weight=None + ): assert method_name == "fit" + assert _is_gpu is not None class_name = self.__class__.__name__ patching_status = PatchingConditionsChain(f"sklearn.cluster.{class_name}.fit") @@ -105,9 +108,12 @@ def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None): ) correct_count = self.n_clusters < sample_count - is_data_supported = ( - _is_csr(X) and daal_check_version((2024, "P", 700)) - ) or not issparse(X) + if not _is_gpu: + is_data_supported = ( + _is_csr(X) and daal_check_version((2024, "P", 700)) + ) or not issparse(X) + else: + is_data_supported = not (_is_csr(X) or issparse(X)) _acceptable_sample_weights = self._validate_sample_weight(sample_weight, X) @@ -124,7 +130,7 @@ def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None): ), ( is_data_supported, - "Supported data formats: Dense, CSR (oneDAL version >= 2024.7.0).", + "Supported data formats: Dense, CSR (oneDAL version >= 2024.7.0 on CPU).", ), ] ) @@ -295,17 +301,25 @@ def _onedal_predict(self, X, sample_weight=None, queue=None): return self._onedal_estimator.predict(X, queue=queue) - def _onedal_supported(self, method_name, *data): + def _onedal_gpu_supported(self, method_name, *data): + _is_gpu = True if method_name == "fit": - return self._onedal_fit_supported(method_name, *data) + return self._onedal_fit_supported(method_name, _is_gpu, *data) if method_name in ["predict", "score"]: return self._onedal_predict_supported(method_name, *data) raise RuntimeError( f"Unknown method {method_name} in {self.__class__.__name__}" ) - _onedal_gpu_supported = _onedal_supported - _onedal_cpu_supported = _onedal_supported + def _onedal_cpu_supported(self, method_name, *data): + _is_gpu = False + if method_name == "fit": + return self._onedal_fit_supported(method_name, _is_gpu, *data) + if method_name in ["predict", "score"]: + return self._onedal_predict_supported(method_name, *data) + raise RuntimeError( + f"Unknown method {method_name} in {self.__class__.__name__}" + ) @wrap_output_data def fit_transform(self, X, y=None, sample_weight=None):