Skip to content

Commit

Permalink
KMeans fix (#1351)
Browse files Browse the repository at this point in the history
* Unified with stock sklearn

* Minor changes in checks
  • Loading branch information
KulikovNikita authored Jul 7, 2023
1 parent efe8bda commit b8b7e3c
Showing 1 changed file with 19 additions and 29 deletions.
48 changes: 19 additions & 29 deletions onedal/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,48 +115,37 @@ def _check_params_vs_input(
self._n_init = self.n_init
if self._n_init == "warn":
warnings.warn(
"The default value of `n_init` will change from "
f"{default_n_init} to 'auto' in 1.4. Set the value of `n_init`"
" explicitly to suppress the warning",
(
"The default value of `n_init` will change from "
f"{default_n_init} to 'auto' in 1.4. Set the value of `n_init`"
" explicitly to suppress the warning"
),
FutureWarning,
stacklevel=2,
)
self._n_init = default_n_init
if self._n_init == "auto":
if self.init == "k-means++":
if isinstance(self.init, str) and self.init == "k-means++":
self._n_init = 1
else:
elif isinstance(self.init, str) and self.init == "random":
self._n_init = default_n_init
elif callable(self.init):
self._n_init = default_n_init
else: # array-like
self._n_init = 1

if _is_arraylike_not_scalar(self.init) and self._n_init != 1:
warnings.warn(
"Explicit initial center position passed: performing only"
f" one init in {self.__class__.__name__} instead of "
f"n_init={self._n_init}.",
RuntimeWarning,
stacklevel=2,
)
self._n_init = 1

self._algorithm = self.algorithm
if self._algorithm in ("auto", "full"):
warnings.warn(
(
f"algorithm='{self._algorithm}' is deprecated, it will be "
"removed in 1.3. Using 'lloyd' instead."
),
FutureWarning,
)
self._algorithm = "lloyd"
if self._algorithm == "elkan" and self.n_clusters == 1:
warnings.warn(
(
"algorithm='elkan' doesn't make sense for a single "
"cluster. Using 'lloyd' instead."
"Explicit initial center position passed: performing only"
f" one init in {self.__class__.__name__} instead of "
f"n_init={self._n_init}."
),
RuntimeWarning,
stacklevel=2,
)
self._algorithm = "lloyd"
assert self._algorithm == "lloyd"
self._n_init = 1
assert self.algorithm == "lloyd"

def _get_policy(self, queue, *data):
return _get_policy(queue, *data)
Expand Down Expand Up @@ -413,6 +402,7 @@ def __init__(

self.copy_x = copy_x
self.algorithm = algorithm
assert self.algorithm == "lloyd"

def fit(self, X, queue=None):
return super()._fit(X, _backend.kmeans.clustering, queue)
Expand Down

0 comments on commit b8b7e3c

Please sign in to comment.