diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index ef125031481..0e4979f82a1 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -335,6 +335,7 @@ Complete list of metrics MeanPairwiseDistance MeanSquaredError metric.Metric + metric_group.MetricGroup metrics_lambda.MetricsLambda MultiLabelConfusionMatrix MutualInformation diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 865218af359..27a949cacca 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -157,7 +157,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]): _check_signature(process_function, "process_function", self, None) # generator provided by self._internal_run_as_gen - self._internal_run_generator: Optional[Generator] = None + self._internal_run_generator: Optional[Generator[Any, None, State]] = None def register_events( self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None @@ -951,7 +951,7 @@ def _internal_run(self) -> State: self._internal_run_generator = None return out.value - def _internal_run_as_gen(self) -> Generator: + def _internal_run_as_gen(self) -> Generator[Any, None, State]: self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False self._init_timers(self.state) try: diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index e4f4e24337c..142a13e5934 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -22,6 +22,7 @@ from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance from ignite.metrics.mean_squared_error import MeanSquaredError from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage +from ignite.metrics.metric_group import MetricGroup from ignite.metrics.metrics_lambda import MetricsLambda from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix from ignite.metrics.mutual_information import MutualInformation @@ -41,6 +42,7 @@ "Metric", "Accuracy", "Loss", + "MetricGroup", "MetricsLambda", "MeanAbsoluteError", "MeanPairwiseDistance", diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py new file mode 100644 index 00000000000..58a52f658ae --- /dev/null +++ b/ignite/metrics/metric_group.py @@ -0,0 +1,54 @@ +from typing import Any, Callable, Dict, Sequence + +import torch + +from ignite.metrics import Metric + + +class MetricGroup(Metric): + """ + A class for grouping metrics so that user could manage them easier. + + Args: + metrics: a dictionary of names to metric instances. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. `output_transform` of each metric in the group is also + called upon its update. + + Examples: + We construct a group of metrics, attach them to the engine at once and retrieve their result. + + .. code-block:: python + + import torch + + metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}) + metric_group.attach(default_evaluator, "eval_metrics") + y_true = torch.tensor([1, 0, 1, 1, 0, 1]) + y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) + state = default_evaluator.run([[y_pred, y_true]]) + + # Metrics individually available in `state.metrics` + state.metrics["acc"], state.metrics["precision"], state.metrics["loss"] + + # And also altogether + state.metrics["eval_metrics"] + """ + + _state_dict_all_req_keys = ("metrics",) + + def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x): + self.metrics = metrics + super(MetricGroup, self).__init__(output_transform=output_transform) + + def reset(self) -> None: + for m in self.metrics.values(): + m.reset() + + def update(self, output: Sequence[torch.Tensor]) -> None: + for m in self.metrics.values(): + m.update(m._output_transform(output)) + + def compute(self) -> Dict[str, Any]: + return {k: m.compute() for k, m in self.metrics.items()} diff --git a/tests/ignite/metrics/test_metric_group.py b/tests/ignite/metrics/test_metric_group.py new file mode 100644 index 00000000000..237df966e05 --- /dev/null +++ b/tests/ignite/metrics/test_metric_group.py @@ -0,0 +1,118 @@ +import pytest +import torch + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.metrics import Accuracy, MetricGroup, Precision + +torch.manual_seed(41) + + +def test_update(): + precision = Precision() + accuracy = Accuracy() + + group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()}) + + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update((y_pred, y)) + accuracy.update((y_pred, y)) + group.update((y_pred, y)) + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +def test_output_transform(): + def drop_first(output): + y_pred, y = output + return (y_pred[1:], y[1:]) + + precision = Precision(output_transform=drop_first) + accuracy = Accuracy(output_transform=drop_first) + + group = MetricGroup( + {"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)} + ) + + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update(drop_first(drop_first((y_pred, y)))) + accuracy.update(drop_first(drop_first((y_pred, y)))) + group.update(drop_first((y_pred, y))) + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +def test_compute(): + precision = Precision() + accuracy = Accuracy() + + group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()}) + + for _ in range(3): + y_pred = torch.randint(0, 2, (100,)) + y = torch.randint(0, 2, (100,)) + + precision.update((y_pred, y)) + accuracy.update((y_pred, y)) + group.update((y_pred, y)) + + assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()} + + precision.reset() + accuracy.reset() + group.reset() + + assert precision.state_dict() == group.metrics["precision"].state_dict() + assert accuracy.state_dict() == group.metrics["accuracy"].state_dict() + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + rank = idist.get_rank() + torch.manual_seed(12 + rank) + + n_epochs = 3 + n_iters = 5 + batch_size = 10 + device = idist.device() + + y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device) + y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device) + + def update(_, i): + return ( + y_pred[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + + engine = Engine(update) + + precision = Precision() + precision.attach(engine, "precision") + + accuracy = Accuracy() + accuracy.attach(engine, "accuracy") + + group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()}) + group.attach(engine, "eval_metrics") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + assert "eval_metrics" in engine.state.metrics + assert "eval_metrics.accuracy" in engine.state.metrics + assert "eval_metrics.precision" in engine.state.metrics + + assert engine.state.metrics["eval_metrics"] == { + "eval_metrics.accuracy": engine.state.metrics["accuracy"], + "eval_metrics.precision": engine.state.metrics["precision"], + } + assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"] + assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"]