diff --git a/docs/source/defaults.rst b/docs/source/defaults.rst index 0a8409e9127..ace7108dd7f 100644 --- a/docs/source/defaults.rst +++ b/docs/source/defaults.rst @@ -12,6 +12,7 @@ from ignite.engine import * from ignite.handlers import * from ignite.metrics import * + from ignite.metrics.clustering import * from ignite.metrics.regression import * from ignite.utils import * diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 29943e98343..0bf58ace047 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -382,6 +382,9 @@ Complete list of metrics regression.KendallRankCorrelation regression.R2Score regression.WaveHedgesDistance + clustering.SilhouetteScore + clustering.DaviesBouldinScore + clustering.CalinskiHarabaszScore .. note:: diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 9f2c2303bc8..f705faa41fd 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -1,3 +1,4 @@ +import ignite.metrics.clustering import ignite.metrics.regression from ignite.metrics.accumulation import Average, GeometricAverage, VariableAccumulation @@ -82,6 +83,7 @@ "RougeN", "RougeL", "regression", + "clustering", "AveragePrecision", "CohenKappa", "GpuInfo", diff --git a/ignite/metrics/clustering/__init__.py b/ignite/metrics/clustering/__init__.py new file mode 100644 index 00000000000..8ca86613501 --- /dev/null +++ b/ignite/metrics/clustering/__init__.py @@ -0,0 +1,3 @@ +from ignite.metrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore +from ignite.metrics.clustering.davies_bouldin_score import DaviesBouldinScore +from ignite.metrics.clustering.silhouette_score import SilhouetteScore diff --git a/ignite/metrics/clustering/_base.py b/ignite/metrics/clustering/_base.py new file mode 100644 index 00000000000..0789855a471 --- /dev/null +++ b/ignite/metrics/clustering/_base.py @@ -0,0 +1,42 @@ +from typing import Tuple + +from torch import Tensor + +from ignite.exceptions import NotComputableError +from ignite.metrics.epoch_metric import EpochMetric + + +class _ClusteringMetricBase(EpochMetric): + required_output_keys = ("features", "labels") + + def _check_shape(self, output: Tuple[Tensor, Tensor]) -> None: + features, labels = output + if features.ndimension() != 2: + raise ValueError("Features should be of shape (batch_size, n_targets).") + + if labels.ndimension() != 1: + raise ValueError("Labels should be of shape (batch_size, ).") + + def _check_type(self, output: Tuple[Tensor, Tensor]) -> None: + features, labels = output + if len(self._predictions) < 1: + return + dtype_preds = self._predictions[-1].dtype + if dtype_preds != features.dtype: + raise ValueError( + f"Incoherent types between input features and stored features: {dtype_preds} vs {features.dtype}" + ) + + dtype_targets = self._targets[-1].dtype + if dtype_targets != labels.dtype: + raise ValueError( + f"Incoherent types between input labels and stored labels: {dtype_targets} vs {labels.dtype}" + ) + + def compute(self) -> float: + if len(self._predictions) < 1 or len(self._targets) < 1: + raise NotComputableError( + f"{self.__class__.__name__} must have at least one example before it can be computed." + ) + + return super().compute() diff --git a/ignite/metrics/clustering/calinski_harabasz_score.py b/ignite/metrics/clustering/calinski_harabasz_score.py new file mode 100644 index 00000000000..fe58ac46151 --- /dev/null +++ b/ignite/metrics/clustering/calinski_harabasz_score.py @@ -0,0 +1,106 @@ +from typing import Any, Callable, Union + +import torch +from torch import Tensor + +from ignite.metrics.clustering._base import _ClusteringMetricBase + +__all__ = ["CalinskiHarabaszScore"] + + +def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float: + from sklearn.metrics import calinski_harabasz_score + + np_features = features.numpy() + np_labels = labels.numpy() + score = calinski_harabasz_score(np_features, np_labels) + return score + + +class CalinskiHarabaszScore(_ClusteringMetricBase): + r"""Calculates the + `Calinski-Harabasz score `_. + + The Calinski-Harabasz score evaluates the quality of clustering results. + + More details can be found + `here `_. + + A higher Calinski-Harabasz score indicates that + the clustering result is good (i.e., clusters are well-separated). + + The computation of this metric is implemented with + `sklearn.metrics.calinski_harabasz_score + `_. + + - ``update`` must receive output of the form ``(features, labels)`` + or ``{'features': features, 'labels': labels}``. + - `features` and `labels` must be of same shape `(B, D)` and `(B,)`. + + Parameters are inherited from ``EpochMetric.__init__``. + + Args: + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, metrics require the output as ``(features, labels)`` + or ``{'features': features, 'labels': labels}``. + check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no + issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``. + Default, True. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + Alternatively, ``output_transform`` can be used to handle this. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in format of + ``(features, labels)`` or ``{'features': features, 'labels': labels, ...}``. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = CalinskiHarabaszScore() + metric.attach(default_evaluator, "calinski_harabasz_score") + X = torch.tensor([ + [-1.04, -0.71, -1.42, -0.28, -0.43], + [0.47, 0.96, -0.43, 1.57, -2.24], + [-0.62, -0.29, 0.10, -0.72, -1.69], + [0.96, -0.77, 0.60, -0.89, 0.49], + [-1.33, -1.53, 0.25, -1.60, -2.0], + [-0.63, -0.55, -1.03, -0.89, -0.77], + [-0.26, -1.67, -0.24, -1.33, -0.40], + [-0.20, -1.34, -0.52, -1.55, -1.50], + [2.68, 1.13, 2.51, 0.80, 0.92], + [0.33, 2.88, 1.35, -0.56, 1.71] + ]) + Y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2]) + state = default_evaluator.run([{"features": X, "labels": Y}]) + print(state.metrics["calinski_harabasz_score"]) + + .. testoutput:: + + 5.733935121807529 + + .. versionadded:: 0.5.2 + """ + + def __init__( + self, + output_transform: Callable[..., Any] = lambda x: x, + check_compute_fn: bool = True, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, + ) -> None: + try: + from sklearn.metrics import calinski_harabasz_score # noqa: F401 + except ImportError: + raise ModuleNotFoundError("This module requires scikit-learn to be installed.") + + super().__init__(_calinski_harabasz_score, output_transform, check_compute_fn, device, skip_unrolling) diff --git a/ignite/metrics/clustering/davies_bouldin_score.py b/ignite/metrics/clustering/davies_bouldin_score.py new file mode 100644 index 00000000000..b34ec69f51a --- /dev/null +++ b/ignite/metrics/clustering/davies_bouldin_score.py @@ -0,0 +1,106 @@ +from typing import Any, Callable, Union + +import torch +from torch import Tensor + +from ignite.metrics.clustering._base import _ClusteringMetricBase + +__all__ = ["DaviesBouldinScore"] + + +def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float: + from sklearn.metrics import davies_bouldin_score + + np_features = features.numpy() + np_labels = labels.numpy() + score = davies_bouldin_score(np_features, np_labels) + return score + + +class DaviesBouldinScore(_ClusteringMetricBase): + r"""Calculates the + `Davies-Bouldin score `_. + + The Davies-Bouldin score evaluates the quality of clustering results. + + More details can be found + `here `_. + + The Davies-Bouldin score is non-negative, + where values closer to zero indicate that the clustering result is good (i.e., clusters are well-separated). + + The computation of this metric is implemented with + `sklearn.metrics.davies_bouldin_score + `_. + + - ``update`` must receive output of the form ``(features, labels)`` + or ``{'features': features, 'labels': labels}``. + - `features` and `labels` must be of same shape `(B, D)` and `(B,)`. + + Parameters are inherited from ``EpochMetric.__init__``. + + Args: + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, metrics require the output as ``(features, labels)`` + or ``{'features': features, 'labels': labels}``. + check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no + issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``. + Default, True. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + Alternatively, ``output_transform`` can be used to handle this. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in format of + ``(features, labels)`` or ``{'features': features, 'labels': labels, ...}``. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = DaviesBouldinScore() + metric.attach(default_evaluator, "davies_bouldin_score") + X = torch.tensor([ + [-1.04, -0.71, -1.42, -0.28, -0.43], + [0.47, 0.96, -0.43, 1.57, -2.24], + [-0.62, -0.29, 0.10, -0.72, -1.69], + [0.96, -0.77, 0.60, -0.89, 0.49], + [-1.33, -1.53, 0.25, -1.60, -2.0], + [-0.63, -0.55, -1.03, -0.89, -0.77], + [-0.26, -1.67, -0.24, -1.33, -0.40], + [-0.20, -1.34, -0.52, -1.55, -1.50], + [2.68, 1.13, 2.51, 0.80, 0.92], + [0.33, 2.88, 1.35, -0.56, 1.71] + ]) + Y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2]) + state = default_evaluator.run([{"features": X, "labels": Y}]) + print(state.metrics["davies_bouldin_score"]) + + .. testoutput:: + + 1.3838673743829881 + + .. versionadded:: 0.5.2 + """ + + def __init__( + self, + output_transform: Callable[..., Any] = lambda x: x, + check_compute_fn: bool = True, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, + ) -> None: + try: + from sklearn.metrics import davies_bouldin_score # noqa: F401 + except ImportError: + raise ModuleNotFoundError("This module requires scikit-learn to be installed.") + + super().__init__(_davies_bouldin_score, output_transform, check_compute_fn, device, skip_unrolling) diff --git a/ignite/metrics/clustering/silhouette_score.py b/ignite/metrics/clustering/silhouette_score.py new file mode 100644 index 00000000000..39b28c5d040 --- /dev/null +++ b/ignite/metrics/clustering/silhouette_score.py @@ -0,0 +1,117 @@ +from typing import Any, Callable, Optional, Union + +import torch +from torch import Tensor + +from ignite.metrics.clustering._base import _ClusteringMetricBase + +__all__ = ["SilhouetteScore"] + + +class SilhouetteScore(_ClusteringMetricBase): + r"""Calculates the + `silhouette score `_. + + The silhouette score evaluates the quality of clustering results. + + .. math:: + s = \frac{b-a}{\max(a,b)} + + where: + + - :math:`a` is the mean distance between a sample and all other points in the same cluster. + - :math:`b` is the mean distance between a sample and all other points in the next nearest cluster. + + More details can be found + `here `_. + + The silhouette score ranges from -1 to +1, + where the score becomes close to +1 when the clustering result is good (i.e., clusters are well-separated). + + The computation of this metric is implemented with + `sklearn.metrics.silhouette_score + `_. + + - ``update`` must receive output of the form ``(features, labels)`` + or ``{'features': features, 'labels': labels}``. + - `features` and `labels` must be of same shape `(B, D)` and `(B,)`. + + Parameters are inherited from ``EpochMetric.__init__``. + + Args: + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, metrics require the output as ``(features, labels)`` + or ``{'features': features, 'labels': labels}``. + check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no + issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``. + Default, True. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + Alternatively, ``output_transform`` can be used to handle this. + silhouette_kwargs: additional arguments passed to ``sklearn.metrics.silhouette_score``. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in format of + ``(features, labels)`` or ``{'features': features, 'labels': labels, ...}``. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = SilhouetteScore() + metric.attach(default_evaluator, "silhouette_score") + X = torch.tensor([ + [-1.04, -0.71, -1.42, -0.28, -0.43], + [0.47, 0.96, -0.43, 1.57, -2.24], + [-0.62, -0.29, 0.10, -0.72, -1.69], + [0.96, -0.77, 0.60, -0.89, 0.49], + [-1.33, -1.53, 0.25, -1.60, -2.0], + [-0.63, -0.55, -1.03, -0.89, -0.77], + [-0.26, -1.67, -0.24, -1.33, -0.40], + [-0.20, -1.34, -0.52, -1.55, -1.50], + [2.68, 1.13, 2.51, 0.80, 0.92], + [0.33, 2.88, 1.35, -0.56, 1.71] + ]) + Y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2]) + state = default_evaluator.run([{"features": X, "labels": Y}]) + print(state.metrics["silhouette_score"]) + + .. testoutput:: + + 0.12607366 + + .. versionadded:: 0.5.2 + """ + + def __init__( + self, + output_transform: Callable[..., Any] = lambda x: x, + check_compute_fn: bool = True, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, + silhouette_kwargs: Optional[dict] = None, + ) -> None: + try: + from sklearn.metrics import silhouette_score # noqa: F401 + except ImportError: + raise ModuleNotFoundError("This module requires scikit-learn to be installed.") + + self._silhouette_kwargs = {} if silhouette_kwargs is None else silhouette_kwargs + + super().__init__(self._silhouette_score, output_transform, check_compute_fn, device, skip_unrolling) + + def _silhouette_score(self, features: Tensor, labels: Tensor) -> float: + from sklearn.metrics import silhouette_score + + np_features = features.numpy() + np_labels = labels.numpy() + score = silhouette_score(np_features, np_labels, **self._silhouette_kwargs) + return score diff --git a/tests/ignite/metrics/clustering/__init__.py b/tests/ignite/metrics/clustering/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/ignite/metrics/clustering/test_calinski_harabasz_score.py b/tests/ignite/metrics/clustering/test_calinski_harabasz_score.py new file mode 100644 index 00000000000..5846707c870 --- /dev/null +++ b/tests/ignite/metrics/clustering/test_calinski_harabasz_score.py @@ -0,0 +1,195 @@ +from typing import Tuple + +import numpy as np +import pytest + +import torch +from sklearn.metrics import calinski_harabasz_score +from torch import Tensor + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.clustering import CalinskiHarabaszScore + + +def test_zero_sample(): + with pytest.raises( + NotComputableError, match="CalinskiHarabaszScore must have at least one example before it can be computed" + ): + metric = CalinskiHarabaszScore() + metric.compute() + + +def test_wrong_output_shape(): + wrong_features = torch.zeros(4, dtype=torch.float) + correct_features = torch.zeros(4, 3, dtype=torch.float) + wrong_labels = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]], dtype=torch.long) + correct_labels = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + with pytest.raises(ValueError, match=r"Features should be of shape \(batch_size, n_targets\)"): + metric = CalinskiHarabaszScore() + metric.update((wrong_features, correct_labels)) + + with pytest.raises(ValueError, match=r"Labels should be of shape \(batch_size, \)"): + metric = CalinskiHarabaszScore() + metric.update((correct_features, wrong_labels)) + + +def test_wrong_output_dtype(): + wrong_features = torch.zeros(4, 3, dtype=torch.long) + correct_features = torch.zeros(4, 3, dtype=torch.float) + wrong_labels = torch.tensor([0, 0, 1, 1], dtype=torch.float) + correct_labels = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + with pytest.raises(ValueError, match=r"Incoherent types between input features and stored features"): + metric = CalinskiHarabaszScore() + metric.update((correct_features, correct_labels)) + metric.update((wrong_features, correct_labels)) + + with pytest.raises(ValueError, match=r"Incoherent types between input labels and stored labels"): + metric = CalinskiHarabaszScore() + metric.update((correct_features, correct_labels)) + metric.update((correct_features, wrong_labels)) + + +@pytest.fixture(params=list(range(2))) +def test_case(request): + N = 100 + NDIM = 10 + BS = 10 + + # well-clustered case + random_order = torch.from_numpy(np.random.permutation(N * 3)) + x1 = torch.cat( + [ + torch.normal(-5.0, 1.0, size=(N, NDIM)), + torch.normal(5.0, 1.0, size=(N, NDIM)), + torch.normal(0.0, 1.0, size=(N, NDIM)), + ] + ).float()[random_order] + y1 = torch.tensor([0] * N + [1] * N + [2] * N, dtype=torch.long)[random_order] + + # poorly-clustered case + x2 = torch.cat( + [ + torch.normal(-1.0, 1.0, size=(N, NDIM)), + torch.normal(0.0, 1.0, size=(N, NDIM)), + torch.normal(1.0, 1.0, size=(N, NDIM)), + ] + ).float() + y2 = torch.from_numpy(np.random.choice(3, size=N * 3)).long() + + return [ + (x1, y1, BS), + (x2, y2, BS), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_integration(n_times: int, test_case: Tuple[Tensor, Tensor, Tensor]): + features, labels, batch_size = test_case + + np_features = features.numpy() + np_labels = labels.numpy() + + def update_fn(engine: Engine, batch): + idx = (engine.state.iteration - 1) * batch_size + feature_batch = np_features[idx : idx + batch_size] + label_batch = np_labels[idx : idx + batch_size] + return torch.from_numpy(feature_batch), torch.from_numpy(label_batch) + + engine = Engine(update_fn) + + m = CalinskiHarabaszScore() + m.attach(engine, "calinski_harabasz_score") + + data = list(range(np_features.shape[0] // batch_size)) + s = engine.run(data, max_epochs=1).metrics["calinski_harabasz_score"] + + np_ans = calinski_harabasz_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=1e-5) == s + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_compute(self): + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + torch.manual_seed(10 + rank) + for metric_device in metric_devices: + m = CalinskiHarabaszScore(device=metric_device) + + random_order = torch.from_numpy(np.random.permutation(200)) + features = torch.cat([torch.normal(-1.0, 1.0, size=(100, 10)), torch.normal(1.0, 1.0, size=(100, 10))]).to( + device + )[random_order] + labels = torch.tensor([0] * 100 + [1] * 100, dtype=torch.long, device=device)[random_order] + + m.update((features, labels)) + + features = idist.all_gather(features) + labels = idist.all_gather(labels) + + np_features = features.cpu().numpy() + np_labels = labels.cpu().numpy() + + np_ans = calinski_harabasz_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=1e-5) == m.compute() + + @pytest.mark.parametrize("n_epochs", [1, 2]) + def test_integration(self, n_epochs: int): + tol = 1e-5 + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + n_iters = 80 + batch_size = 16 + + for metric_device in metric_devices: + torch.manual_seed(12 + rank) + + cluster_size = n_iters * batch_size // 2 + random_order = torch.from_numpy(np.random.permutation(n_iters * batch_size)) + features = torch.cat( + [torch.normal(-1.0, 1.0, size=(cluster_size, 10)), torch.normal(1.0, 1.0, size=(cluster_size, 10))] + ).to(device)[random_order] + labels = torch.tensor([0] * cluster_size + [1] * cluster_size, dtype=torch.long, device=device)[ + random_order + ] + + engine = Engine( + lambda e, i: ( + features[i * batch_size : (i + 1) * batch_size], + labels[i * batch_size : (i + 1) * batch_size], + ) + ) + + chs = CalinskiHarabaszScore(device=metric_device) + chs.attach(engine, "calinski_harabasz_score") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + features = idist.all_gather(features) + labels = idist.all_gather(labels) + + assert "calinski_harabasz_score" in engine.state.metrics + + res = engine.state.metrics["calinski_harabasz_score"] + + np_labels = labels.cpu().numpy() + np_features = features.cpu().numpy() + + np_ans = calinski_harabasz_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=tol) == res diff --git a/tests/ignite/metrics/clustering/test_davies_bouldin_score.py b/tests/ignite/metrics/clustering/test_davies_bouldin_score.py new file mode 100644 index 00000000000..407355ee95b --- /dev/null +++ b/tests/ignite/metrics/clustering/test_davies_bouldin_score.py @@ -0,0 +1,195 @@ +from typing import Tuple + +import numpy as np +import pytest + +import torch +from sklearn.metrics import davies_bouldin_score +from torch import Tensor + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.clustering import DaviesBouldinScore + + +def test_zero_sample(): + with pytest.raises( + NotComputableError, match="DaviesBouldinScore must have at least one example before it can be computed" + ): + metric = DaviesBouldinScore() + metric.compute() + + +def test_wrong_output_shape(): + wrong_features = torch.zeros(4, dtype=torch.float) + correct_features = torch.zeros(4, 3, dtype=torch.float) + wrong_labels = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]], dtype=torch.long) + correct_labels = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + with pytest.raises(ValueError, match=r"Features should be of shape \(batch_size, n_targets\)"): + metric = DaviesBouldinScore() + metric.update((wrong_features, correct_labels)) + + with pytest.raises(ValueError, match=r"Labels should be of shape \(batch_size, \)"): + metric = DaviesBouldinScore() + metric.update((correct_features, wrong_labels)) + + +def test_wrong_output_dtype(): + wrong_features = torch.zeros(4, 3, dtype=torch.long) + correct_features = torch.zeros(4, 3, dtype=torch.float) + wrong_labels = torch.tensor([0, 0, 1, 1], dtype=torch.float) + correct_labels = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + with pytest.raises(ValueError, match=r"Incoherent types between input features and stored features"): + metric = DaviesBouldinScore() + metric.update((correct_features, correct_labels)) + metric.update((wrong_features, correct_labels)) + + with pytest.raises(ValueError, match=r"Incoherent types between input labels and stored labels"): + metric = DaviesBouldinScore() + metric.update((correct_features, correct_labels)) + metric.update((correct_features, wrong_labels)) + + +@pytest.fixture(params=list(range(2))) +def test_case(request): + N = 100 + NDIM = 10 + BS = 10 + + # well-clustered case + random_order = torch.from_numpy(np.random.permutation(N * 3)) + x1 = torch.cat( + [ + torch.normal(-5.0, 1.0, size=(N, NDIM)), + torch.normal(5.0, 1.0, size=(N, NDIM)), + torch.normal(0.0, 1.0, size=(N, NDIM)), + ] + ).float()[random_order] + y1 = torch.tensor([0] * N + [1] * N + [2] * N, dtype=torch.long)[random_order] + + # poorly-clustered case + x2 = torch.cat( + [ + torch.normal(-1.0, 1.0, size=(N, NDIM)), + torch.normal(0.0, 1.0, size=(N, NDIM)), + torch.normal(1.0, 1.0, size=(N, NDIM)), + ] + ).float() + y2 = torch.from_numpy(np.random.choice(3, size=N * 3)).long() + + return [ + (x1, y1, BS), + (x2, y2, BS), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_integration(n_times: int, test_case: Tuple[Tensor, Tensor, Tensor]): + features, labels, batch_size = test_case + + np_features = features.numpy() + np_labels = labels.numpy() + + def update_fn(engine: Engine, batch): + idx = (engine.state.iteration - 1) * batch_size + feature_batch = np_features[idx : idx + batch_size] + label_batch = np_labels[idx : idx + batch_size] + return torch.from_numpy(feature_batch), torch.from_numpy(label_batch) + + engine = Engine(update_fn) + + m = DaviesBouldinScore() + m.attach(engine, "davies_bouldin_score") + + data = list(range(np_features.shape[0] // batch_size)) + s = engine.run(data, max_epochs=1).metrics["davies_bouldin_score"] + + np_ans = davies_bouldin_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=1e-5) == s + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_compute(self): + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + torch.manual_seed(10 + rank) + for metric_device in metric_devices: + m = DaviesBouldinScore(device=metric_device) + + random_order = torch.from_numpy(np.random.permutation(200)) + features = torch.cat([torch.normal(-1.0, 1.0, size=(100, 10)), torch.normal(1.0, 1.0, size=(100, 10))]).to( + device + )[random_order] + labels = torch.tensor([0] * 100 + [1] * 100, dtype=torch.long, device=device)[random_order] + + m.update((features, labels)) + + features = idist.all_gather(features) + labels = idist.all_gather(labels) + + np_features = features.cpu().numpy() + np_labels = labels.cpu().numpy() + + np_ans = davies_bouldin_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=1e-5) == m.compute() + + @pytest.mark.parametrize("n_epochs", [1, 2]) + def test_integration(self, n_epochs: int): + tol = 1e-5 + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + n_iters = 80 + batch_size = 16 + + for metric_device in metric_devices: + torch.manual_seed(12 + rank) + + cluster_size = n_iters * batch_size // 2 + random_order = torch.from_numpy(np.random.permutation(n_iters * batch_size)) + features = torch.cat( + [torch.normal(-1.0, 1.0, size=(cluster_size, 10)), torch.normal(1.0, 1.0, size=(cluster_size, 10))] + ).to(device)[random_order] + labels = torch.tensor([0] * cluster_size + [1] * cluster_size, dtype=torch.long, device=device)[ + random_order + ] + + engine = Engine( + lambda e, i: ( + features[i * batch_size : (i + 1) * batch_size], + labels[i * batch_size : (i + 1) * batch_size], + ) + ) + + dbs = DaviesBouldinScore(device=metric_device) + dbs.attach(engine, "davies_bouldin_score") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + features = idist.all_gather(features) + labels = idist.all_gather(labels) + + assert "davies_bouldin_score" in engine.state.metrics + + res = engine.state.metrics["davies_bouldin_score"] + + np_labels = labels.cpu().numpy() + np_features = features.cpu().numpy() + + np_ans = davies_bouldin_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=tol) == res diff --git a/tests/ignite/metrics/clustering/test_silhouette_score.py b/tests/ignite/metrics/clustering/test_silhouette_score.py new file mode 100644 index 00000000000..436843955fe --- /dev/null +++ b/tests/ignite/metrics/clustering/test_silhouette_score.py @@ -0,0 +1,195 @@ +from typing import Tuple + +import numpy as np +import pytest + +import torch +from sklearn.metrics import silhouette_score +from torch import Tensor + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.clustering import SilhouetteScore + + +def test_zero_sample(): + with pytest.raises( + NotComputableError, match="SilhouetteScore must have at least one example before it can be computed" + ): + metric = SilhouetteScore() + metric.compute() + + +def test_wrong_output_shape(): + wrong_features = torch.zeros(4, dtype=torch.float) + correct_features = torch.zeros(4, 3, dtype=torch.float) + wrong_labels = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]], dtype=torch.long) + correct_labels = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + with pytest.raises(ValueError, match=r"Features should be of shape \(batch_size, n_targets\)"): + metric = SilhouetteScore() + metric.update((wrong_features, correct_labels)) + + with pytest.raises(ValueError, match=r"Labels should be of shape \(batch_size, \)"): + metric = SilhouetteScore() + metric.update((correct_features, wrong_labels)) + + +def test_wrong_output_dtype(): + wrong_features = torch.zeros(4, 3, dtype=torch.long) + correct_features = torch.zeros(4, 3, dtype=torch.float) + wrong_labels = torch.tensor([0, 0, 1, 1], dtype=torch.float) + correct_labels = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + with pytest.raises(ValueError, match=r"Incoherent types between input features and stored features"): + metric = SilhouetteScore() + metric.update((correct_features, correct_labels)) + metric.update((wrong_features, correct_labels)) + + with pytest.raises(ValueError, match=r"Incoherent types between input labels and stored labels"): + metric = SilhouetteScore() + metric.update((correct_features, correct_labels)) + metric.update((correct_features, wrong_labels)) + + +@pytest.fixture(params=list(range(2))) +def test_case(request): + N = 100 + NDIM = 10 + BS = 10 + + # well-clustered case + random_order = torch.from_numpy(np.random.permutation(N * 3)) + x1 = torch.cat( + [ + torch.normal(-5.0, 1.0, size=(N, NDIM)), + torch.normal(5.0, 1.0, size=(N, NDIM)), + torch.normal(0.0, 1.0, size=(N, NDIM)), + ] + ).float()[random_order] + y1 = torch.tensor([0] * N + [1] * N + [2] * N, dtype=torch.long)[random_order] + + # poorly-clustered case + x2 = torch.cat( + [ + torch.normal(-1.0, 1.0, size=(N, NDIM)), + torch.normal(0.0, 1.0, size=(N, NDIM)), + torch.normal(1.0, 1.0, size=(N, NDIM)), + ] + ).float() + y2 = torch.from_numpy(np.random.choice(3, size=N * 3)).long() + + return [ + (x1, y1, BS), + (x2, y2, BS), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_integration(n_times: int, test_case: Tuple[Tensor, Tensor, Tensor]): + features, labels, batch_size = test_case + + np_features = features.numpy() + np_labels = labels.numpy() + + def update_fn(engine: Engine, batch): + idx = (engine.state.iteration - 1) * batch_size + feature_batch = np_features[idx : idx + batch_size] + label_batch = np_labels[idx : idx + batch_size] + return torch.from_numpy(feature_batch), torch.from_numpy(label_batch) + + engine = Engine(update_fn) + + m = SilhouetteScore() + m.attach(engine, "silhouette") + + data = list(range(np_features.shape[0] // batch_size)) + s = engine.run(data, max_epochs=1).metrics["silhouette"] + + np_ans = silhouette_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=1e-5) == s + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_compute(self): + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + torch.manual_seed(10 + rank) + for metric_device in metric_devices: + m = SilhouetteScore(device=metric_device) + + random_order = torch.from_numpy(np.random.permutation(200)) + features = torch.cat([torch.normal(-1.0, 1.0, size=(100, 10)), torch.normal(1.0, 1.0, size=(100, 10))]).to( + device + )[random_order] + labels = torch.tensor([0] * 100 + [1] * 100, dtype=torch.long, device=device)[random_order] + + m.update((features, labels)) + + features = idist.all_gather(features) + labels = idist.all_gather(labels) + + np_features = features.cpu().numpy() + np_labels = labels.cpu().numpy() + + np_ans = silhouette_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=1e-5) == m.compute() + + @pytest.mark.parametrize("n_epochs", [1, 2]) + def test_integration(self, n_epochs: int): + tol = 1e-5 + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + n_iters = 80 + batch_size = 16 + + for metric_device in metric_devices: + torch.manual_seed(12 + rank) + + cluster_size = n_iters * batch_size // 2 + random_order = torch.from_numpy(np.random.permutation(n_iters * batch_size)) + features = torch.cat( + [torch.normal(-1.0, 1.0, size=(cluster_size, 10)), torch.normal(1.0, 1.0, size=(cluster_size, 10))] + ).to(device)[random_order] + labels = torch.tensor([0] * cluster_size + [1] * cluster_size, dtype=torch.long, device=device)[ + random_order + ] + + engine = Engine( + lambda e, i: ( + features[i * batch_size : (i + 1) * batch_size], + labels[i * batch_size : (i + 1) * batch_size], + ) + ) + + silhouette = SilhouetteScore(device=metric_device) + silhouette.attach(engine, "silhouette") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + features = idist.all_gather(features) + labels = idist.all_gather(labels) + + assert "silhouette" in engine.state.metrics + + res = engine.state.metrics["silhouette"] + + np_labels = labels.cpu().numpy() + np_features = features.cpu().numpy() + + np_ans = silhouette_score(np_features, np_labels) + + assert pytest.approx(np_ans, rel=tol) == res