-
-
Notifications
You must be signed in to change notification settings - Fork 615
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add silhouette score metric * add Davies Bouldin score metric * add Calinski-Harabasz score metric * small modification on docstring * update docstring * remove extra kwargs for calinski_harabasz_score * simplify imports for sklearn.metrics functions * add import of ignite.metrics.clustering * add __all__ * update compute_fn style * fix type hint * fix formatting --------- Co-authored-by: vfdev <vfdev.5@gmail.com>
- Loading branch information
Showing
12 changed files
with
965 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ignite.metrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore | ||
from ignite.metrics.clustering.davies_bouldin_score import DaviesBouldinScore | ||
from ignite.metrics.clustering.silhouette_score import SilhouetteScore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Tuple | ||
|
||
from torch import Tensor | ||
|
||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics.epoch_metric import EpochMetric | ||
|
||
|
||
class _ClusteringMetricBase(EpochMetric): | ||
required_output_keys = ("features", "labels") | ||
|
||
def _check_shape(self, output: Tuple[Tensor, Tensor]) -> None: | ||
features, labels = output | ||
if features.ndimension() != 2: | ||
raise ValueError("Features should be of shape (batch_size, n_targets).") | ||
|
||
if labels.ndimension() != 1: | ||
raise ValueError("Labels should be of shape (batch_size, ).") | ||
|
||
def _check_type(self, output: Tuple[Tensor, Tensor]) -> None: | ||
features, labels = output | ||
if len(self._predictions) < 1: | ||
return | ||
dtype_preds = self._predictions[-1].dtype | ||
if dtype_preds != features.dtype: | ||
raise ValueError( | ||
f"Incoherent types between input features and stored features: {dtype_preds} vs {features.dtype}" | ||
) | ||
|
||
dtype_targets = self._targets[-1].dtype | ||
if dtype_targets != labels.dtype: | ||
raise ValueError( | ||
f"Incoherent types between input labels and stored labels: {dtype_targets} vs {labels.dtype}" | ||
) | ||
|
||
def compute(self) -> float: | ||
if len(self._predictions) < 1 or len(self._targets) < 1: | ||
raise NotComputableError( | ||
f"{self.__class__.__name__} must have at least one example before it can be computed." | ||
) | ||
|
||
return super().compute() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from typing import Any, Callable, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from ignite.metrics.clustering._base import _ClusteringMetricBase | ||
|
||
__all__ = ["CalinskiHarabaszScore"] | ||
|
||
|
||
def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float: | ||
from sklearn.metrics import calinski_harabasz_score | ||
|
||
np_features = features.numpy() | ||
np_labels = labels.numpy() | ||
score = calinski_harabasz_score(np_features, np_labels) | ||
return score | ||
|
||
|
||
class CalinskiHarabaszScore(_ClusteringMetricBase): | ||
r"""Calculates the | ||
`Calinski-Harabasz score <https://en.wikipedia.org/wiki/Calinski%E2%80%93Harabasz_index>`_. | ||
The Calinski-Harabasz score evaluates the quality of clustering results. | ||
More details can be found | ||
`here <https://scikit-learn.org/stable/modules/clustering.html#calinski-harabasz-index>`_. | ||
A higher Calinski-Harabasz score indicates that | ||
the clustering result is good (i.e., clusters are well-separated). | ||
The computation of this metric is implemented with | ||
`sklearn.metrics.calinski_harabasz_score | ||
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.calinski_harabasz_score.html>`_. | ||
- ``update`` must receive output of the form ``(features, labels)`` | ||
or ``{'features': features, 'labels': labels}``. | ||
- `features` and `labels` must be of same shape `(B, D)` and `(B,)`. | ||
Parameters are inherited from ``EpochMetric.__init__``. | ||
Args: | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
By default, metrics require the output as ``(features, labels)`` | ||
or ``{'features': features, 'labels': labels}``. | ||
check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no | ||
issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``. | ||
Default, True. | ||
device: specifies which device updates are accumulated on. Setting the | ||
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is | ||
non-blocking. By default, CPU. | ||
skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be | ||
true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` | ||
Alternatively, ``output_transform`` can be used to handle this. | ||
Examples: | ||
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. | ||
The output of the engine's ``process_function`` needs to be in format of | ||
``(features, labels)`` or ``{'features': features, 'labels': labels, ...}``. | ||
.. include:: defaults.rst | ||
:start-after: :orphan: | ||
.. testcode:: | ||
metric = CalinskiHarabaszScore() | ||
metric.attach(default_evaluator, "calinski_harabasz_score") | ||
X = torch.tensor([ | ||
[-1.04, -0.71, -1.42, -0.28, -0.43], | ||
[0.47, 0.96, -0.43, 1.57, -2.24], | ||
[-0.62, -0.29, 0.10, -0.72, -1.69], | ||
[0.96, -0.77, 0.60, -0.89, 0.49], | ||
[-1.33, -1.53, 0.25, -1.60, -2.0], | ||
[-0.63, -0.55, -1.03, -0.89, -0.77], | ||
[-0.26, -1.67, -0.24, -1.33, -0.40], | ||
[-0.20, -1.34, -0.52, -1.55, -1.50], | ||
[2.68, 1.13, 2.51, 0.80, 0.92], | ||
[0.33, 2.88, 1.35, -0.56, 1.71] | ||
]) | ||
Y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2]) | ||
state = default_evaluator.run([{"features": X, "labels": Y}]) | ||
print(state.metrics["calinski_harabasz_score"]) | ||
.. testoutput:: | ||
5.733935121807529 | ||
.. versionadded:: 0.5.2 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
output_transform: Callable[..., Any] = lambda x: x, | ||
check_compute_fn: bool = True, | ||
device: Union[str, torch.device] = torch.device("cpu"), | ||
skip_unrolling: bool = False, | ||
) -> None: | ||
try: | ||
from sklearn.metrics import calinski_harabasz_score # noqa: F401 | ||
except ImportError: | ||
raise ModuleNotFoundError("This module requires scikit-learn to be installed.") | ||
|
||
super().__init__(_calinski_harabasz_score, output_transform, check_compute_fn, device, skip_unrolling) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from typing import Any, Callable, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from ignite.metrics.clustering._base import _ClusteringMetricBase | ||
|
||
__all__ = ["DaviesBouldinScore"] | ||
|
||
|
||
def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float: | ||
from sklearn.metrics import davies_bouldin_score | ||
|
||
np_features = features.numpy() | ||
np_labels = labels.numpy() | ||
score = davies_bouldin_score(np_features, np_labels) | ||
return score | ||
|
||
|
||
class DaviesBouldinScore(_ClusteringMetricBase): | ||
r"""Calculates the | ||
`Davies-Bouldin score <https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index>`_. | ||
The Davies-Bouldin score evaluates the quality of clustering results. | ||
More details can be found | ||
`here <https://scikit-learn.org/1.5/modules/clustering.html#davies-bouldin-index>`_. | ||
The Davies-Bouldin score is non-negative, | ||
where values closer to zero indicate that the clustering result is good (i.e., clusters are well-separated). | ||
The computation of this metric is implemented with | ||
`sklearn.metrics.davies_bouldin_score | ||
<https://scikit-learn.org/1.5/modules/generated/sklearn.metrics.davies_bouldin_score.html>`_. | ||
- ``update`` must receive output of the form ``(features, labels)`` | ||
or ``{'features': features, 'labels': labels}``. | ||
- `features` and `labels` must be of same shape `(B, D)` and `(B,)`. | ||
Parameters are inherited from ``EpochMetric.__init__``. | ||
Args: | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
By default, metrics require the output as ``(features, labels)`` | ||
or ``{'features': features, 'labels': labels}``. | ||
check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no | ||
issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``. | ||
Default, True. | ||
device: specifies which device updates are accumulated on. Setting the | ||
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is | ||
non-blocking. By default, CPU. | ||
skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be | ||
true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` | ||
Alternatively, ``output_transform`` can be used to handle this. | ||
Examples: | ||
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. | ||
The output of the engine's ``process_function`` needs to be in format of | ||
``(features, labels)`` or ``{'features': features, 'labels': labels, ...}``. | ||
.. include:: defaults.rst | ||
:start-after: :orphan: | ||
.. testcode:: | ||
metric = DaviesBouldinScore() | ||
metric.attach(default_evaluator, "davies_bouldin_score") | ||
X = torch.tensor([ | ||
[-1.04, -0.71, -1.42, -0.28, -0.43], | ||
[0.47, 0.96, -0.43, 1.57, -2.24], | ||
[-0.62, -0.29, 0.10, -0.72, -1.69], | ||
[0.96, -0.77, 0.60, -0.89, 0.49], | ||
[-1.33, -1.53, 0.25, -1.60, -2.0], | ||
[-0.63, -0.55, -1.03, -0.89, -0.77], | ||
[-0.26, -1.67, -0.24, -1.33, -0.40], | ||
[-0.20, -1.34, -0.52, -1.55, -1.50], | ||
[2.68, 1.13, 2.51, 0.80, 0.92], | ||
[0.33, 2.88, 1.35, -0.56, 1.71] | ||
]) | ||
Y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2]) | ||
state = default_evaluator.run([{"features": X, "labels": Y}]) | ||
print(state.metrics["davies_bouldin_score"]) | ||
.. testoutput:: | ||
1.3838673743829881 | ||
.. versionadded:: 0.5.2 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
output_transform: Callable[..., Any] = lambda x: x, | ||
check_compute_fn: bool = True, | ||
device: Union[str, torch.device] = torch.device("cpu"), | ||
skip_unrolling: bool = False, | ||
) -> None: | ||
try: | ||
from sklearn.metrics import davies_bouldin_score # noqa: F401 | ||
except ImportError: | ||
raise ModuleNotFoundError("This module requires scikit-learn to be installed.") | ||
|
||
super().__init__(_davies_bouldin_score, output_transform, check_compute_fn, device, skip_unrolling) |
Oops, something went wrong.