From 62feff20098bdd5f6ad92eb324ac5838cf670120 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Tue, 16 Jul 2024 19:26:02 +0330 Subject: [PATCH 1/7] Initial commit --- ignite/metrics/metric_group.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 ignite/metrics/metric_group.py diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py new file mode 100644 index 00000000000..fdbba17ad5c --- /dev/null +++ b/ignite/metrics/metric_group.py @@ -0,0 +1,20 @@ +from typing import Any, Dict + +from ignite.metrics import Metric + + +class MetricGroup(Metric): + def __init__(self, metrics: Dict[str, Metric]): + self.metrics = metrics + super(MetricGroup, self).__init__() + + def reset(self): + for m in self.metrics.values(): + m.reset() + + def update(self, output): + for m in self.metrics.values(): + m.update(m._output_transform(output)) + + def compute(self) -> Dict[str, Any]: + return {k: m.compute() for k, m in self.metrics.items()} From 5ed70687147195435b94b1e5610bf07234f301c6 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Thu, 25 Jul 2024 22:22:39 +0330 Subject: [PATCH 2/7] Add tests --- docs/source/metrics.rst | 1 + ignite/metrics/__init__.py | 2 + ignite/metrics/metric_group.py | 35 ++++++- tests/ignite/metrics/test_metric_group.py | 118 ++++++++++++++++++++++ 4 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 tests/ignite/metrics/test_metric_group.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index ef125031481..0e4979f82a1 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -335,6 +335,7 @@ Complete list of metrics MeanPairwiseDistance MeanSquaredError metric.Metric + metric_group.MetricGroup metrics_lambda.MetricsLambda MultiLabelConfusionMatrix MutualInformation diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index e4f4e24337c..142a13e5934 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -22,6 +22,7 @@ from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance from ignite.metrics.mean_squared_error import MeanSquaredError from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage +from ignite.metrics.metric_group import MetricGroup from ignite.metrics.metrics_lambda import MetricsLambda from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix from ignite.metrics.mutual_information import MutualInformation @@ -41,6 +42,7 @@ "Metric", "Accuracy", "Loss", + "MetricGroup", "MetricsLambda", "MeanAbsoluteError", "MeanPairwiseDistance", diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index fdbba17ad5c..8b925392bdc 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -1,12 +1,41 @@ -from typing import Any, Dict +from typing import Any, Callable, Dict from ignite.metrics import Metric class MetricGroup(Metric): - def __init__(self, metrics: Dict[str, Metric]): + """ + A class for grouping metrics so that user could manage them easier. + + Args: + metrics: a dictionary of names to metric instances. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. `output_transform` of each metric in the group is also + called upon its update. + + Examples: + We construct a group of metrics, attach them to the engine at once and retrieve their result. + + .. code-block:: python + metric_group = {'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())} + metric_group.attach(default_evaluator, "eval_metrics") + y_true = torch.tensor([1, 0, 1, 1, 0, 1]) + y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) + state = default_evaluator.run([[y_pred, y_true]]) + + # Metrics individually available in `state.metrics` + state.metrics["acc"], state.metrics["precision"], state.metrics["loss"] + + # And also altogether + state.metrics["eval_metrics] + """ + + _state_dict_all_req_keys = ("metrics",) + + def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x): self.metrics = metrics - super(MetricGroup, self).__init__() + super(MetricGroup, self).__init__(output_transform=output_transform) def reset(self): for m in self.metrics.values(): diff --git a/tests/ignite/metrics/test_metric_group.py b/tests/ignite/metrics/test_metric_group.py new file mode 100644 index 00000000000..237df966e05 --- /dev/null +++ b/tests/ignite/metrics/test_metric_group.py @@ -0,0 +1,118 @@ +import pytest +import torch + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.metrics import Accuracy, MetricGroup, Precision + +torch.manual_seed(41) + + +def test_update(): + precision = Precision() + accuracy = Accuracy() + + group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()}) + + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update((y_pred, y)) + accuracy.update((y_pred, y)) + group.update((y_pred, y)) + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +def test_output_transform(): + def drop_first(output): + y_pred, y = output + return (y_pred[1:], y[1:]) + + precision = Precision(output_transform=drop_first) + accuracy = Accuracy(output_transform=drop_first) + + group = MetricGroup( + {"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)} + ) + + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update(drop_first(drop_first((y_pred, y)))) + accuracy.update(drop_first(drop_first((y_pred, y)))) + group.update(drop_first((y_pred, y))) + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +def test_compute(): + precision = Precision() + accuracy = Accuracy() + + group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()}) + + for _ in range(3): + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update((y_pred, y)) + accuracy.update((y_pred, y)) + group.update((y_pred, y)) + + assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()} + + precision.reset() + accuracy.reset() + group.reset() + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + rank = idist.get_rank() + torch.manual_seed(12 + rank) + + n_epochs = 3 + n_iters = 5 + batch_size = 10 + device = idist.device() + + y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device) + y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device) + + def update(_, i): + return ( + y_pred[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + + engine = Engine(update) + + precision = Precision() + precision.attach(engine, "precision") + + accuracy = Accuracy() + accuracy.attach(engine, "accuracy") + + group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()}) + group.attach(engine, "eval_metrics") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + assert "eval_metrics" in engine.state.metrics + assert "eval_metrics.accuracy" in engine.state.metrics + assert "eval_metrics.precision" in engine.state.metrics + + assert engine.state.metrics["eval_metrics"] == { + "eval_metrics.accuracy": engine.state.metrics["accuracy"], + "eval_metrics.precision": engine.state.metrics["precision"], + } + assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"] + assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"] From c8b368fcdb191c6236178208861aca850d94ccc4 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 31 Jul 2024 00:01:10 +0330 Subject: [PATCH 3/7] Fix two typos --- ignite/metrics/metric_group.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index 8b925392bdc..cc2c81de4f8 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -1,4 +1,6 @@ -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Sequence + +import torch from ignite.metrics import Metric @@ -18,7 +20,7 @@ class MetricGroup(Metric): We construct a group of metrics, attach them to the engine at once and retrieve their result. .. code-block:: python - metric_group = {'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())} + metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}) metric_group.attach(default_evaluator, "eval_metrics") y_true = torch.tensor([1, 0, 1, 1, 0, 1]) y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) @@ -41,7 +43,7 @@ def reset(self): for m in self.metrics.values(): m.reset() - def update(self, output): + def update(self, output: Sequence[torch.Tensor]): for m in self.metrics.values(): m.update(m._output_transform(output)) From 44f3558eb4430926eb933ef831cdb2228b0026b7 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 31 Jul 2024 20:28:21 +0330 Subject: [PATCH 4/7] Fix Mypy --- ignite/metrics/metric_group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index cc2c81de4f8..35743c42d21 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -39,11 +39,11 @@ def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lamb self.metrics = metrics super(MetricGroup, self).__init__(output_transform=output_transform) - def reset(self): + def reset(self) -> None: for m in self.metrics.values(): m.reset() - def update(self, output: Sequence[torch.Tensor]): + def update(self, output: Sequence[torch.Tensor]) -> None: for m in self.metrics.values(): m.update(m._output_transform(output)) From bd8e88b53353859343699670fa18ccc20ea59324 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 31 Jul 2024 20:48:01 +0330 Subject: [PATCH 5/7] Fix engine mypy issue --- ignite/engine/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 865218af359..27a949cacca 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -157,7 +157,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]): _check_signature(process_function, "process_function", self, None) # generator provided by self._internal_run_as_gen - self._internal_run_generator: Optional[Generator] = None + self._internal_run_generator: Optional[Generator[Any, None, State]] = None def register_events( self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None @@ -951,7 +951,7 @@ def _internal_run(self) -> State: self._internal_run_generator = None return out.value - def _internal_run_as_gen(self) -> Generator: + def _internal_run_as_gen(self) -> Generator[Any, None, State]: self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False self._init_timers(self.state) try: From 8471c4c8b38fba2cce1d8a4aeff20f7f78f8099c Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 31 Jul 2024 21:01:03 +0330 Subject: [PATCH 6/7] Fix docstring --- ignite/metrics/metric_group.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index 35743c42d21..c1e87b5771e 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -20,17 +20,19 @@ class MetricGroup(Metric): We construct a group of metrics, attach them to the engine at once and retrieve their result. .. code-block:: python - metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}) - metric_group.attach(default_evaluator, "eval_metrics") - y_true = torch.tensor([1, 0, 1, 1, 0, 1]) - y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) - state = default_evaluator.run([[y_pred, y_true]]) + import torch - # Metrics individually available in `state.metrics` - state.metrics["acc"], state.metrics["precision"], state.metrics["loss"] + metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}) + metric_group.attach(default_evaluator, "eval_metrics") + y_true = torch.tensor([1, 0, 1, 1, 0, 1]) + y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) + state = default_evaluator.run([[y_pred, y_true]]) - # And also altogether - state.metrics["eval_metrics] + # Metrics individually available in `state.metrics` + state.metrics["acc"], state.metrics["precision"], state.metrics["loss"] + + # And also altogether + state.metrics["eval_metrics] """ _state_dict_all_req_keys = ("metrics",) From 83019164e7fa3cb8a540233403af8a515687fc9f Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 31 Jul 2024 21:14:41 +0330 Subject: [PATCH 7/7] Fix another problem in docstring --- ignite/metrics/metric_group.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index c1e87b5771e..58a52f658ae 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -20,6 +20,7 @@ class MetricGroup(Metric): We construct a group of metrics, attach them to the engine at once and retrieve their result. .. code-block:: python + import torch metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}) @@ -32,7 +33,7 @@ class MetricGroup(Metric): state.metrics["acc"], state.metrics["precision"], state.metrics["loss"] # And also altogether - state.metrics["eval_metrics] + state.metrics["eval_metrics"] """ _state_dict_all_req_keys = ("metrics",)