From b8b7e3c57bfda06ad96017e94c763e3a1e4d0d6a Mon Sep 17 00:00:00 2001 From: KulikovNikita Date: Fri, 7 Jul 2023 11:06:55 +0100 Subject: [PATCH] KMeans fix (#1351) * Unified with stock sklearn * Minor changes in checks --- onedal/cluster/kmeans.py | 48 ++++++++++++++++------------------------ 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index ad87fdb472..0bbf04a1c2 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -115,48 +115,37 @@ def _check_params_vs_input( self._n_init = self.n_init if self._n_init == "warn": warnings.warn( - "The default value of `n_init` will change from " - f"{default_n_init} to 'auto' in 1.4. Set the value of `n_init`" - " explicitly to suppress the warning", + ( + "The default value of `n_init` will change from " + f"{default_n_init} to 'auto' in 1.4. Set the value of `n_init`" + " explicitly to suppress the warning" + ), FutureWarning, + stacklevel=2, ) self._n_init = default_n_init if self._n_init == "auto": - if self.init == "k-means++": + if isinstance(self.init, str) and self.init == "k-means++": self._n_init = 1 - else: + elif isinstance(self.init, str) and self.init == "random": + self._n_init = default_n_init + elif callable(self.init): self._n_init = default_n_init + else: # array-like + self._n_init = 1 if _is_arraylike_not_scalar(self.init) and self._n_init != 1: - warnings.warn( - "Explicit initial center position passed: performing only" - f" one init in {self.__class__.__name__} instead of " - f"n_init={self._n_init}.", - RuntimeWarning, - stacklevel=2, - ) - self._n_init = 1 - - self._algorithm = self.algorithm - if self._algorithm in ("auto", "full"): warnings.warn( ( - f"algorithm='{self._algorithm}' is deprecated, it will be " - "removed in 1.3. Using 'lloyd' instead." - ), - FutureWarning, - ) - self._algorithm = "lloyd" - if self._algorithm == "elkan" and self.n_clusters == 1: - warnings.warn( - ( - "algorithm='elkan' doesn't make sense for a single " - "cluster. Using 'lloyd' instead." + "Explicit initial center position passed: performing only" + f" one init in {self.__class__.__name__} instead of " + f"n_init={self._n_init}." ), RuntimeWarning, + stacklevel=2, ) - self._algorithm = "lloyd" - assert self._algorithm == "lloyd" + self._n_init = 1 + assert self.algorithm == "lloyd" def _get_policy(self, queue, *data): return _get_policy(queue, *data) @@ -413,6 +402,7 @@ def __init__( self.copy_x = copy_x self.algorithm = algorithm + assert self.algorithm == "lloyd" def fit(self, X, queue=None): return super()._fit(X, _backend.kmeans.clustering, queue)