diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index 8b925392bdc..cc2c81de4f8 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -1,4 +1,6 @@ -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Sequence + +import torch from ignite.metrics import Metric @@ -18,7 +20,7 @@ class MetricGroup(Metric): We construct a group of metrics, attach them to the engine at once and retrieve their result. .. code-block:: python - metric_group = {'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())} + metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}) metric_group.attach(default_evaluator, "eval_metrics") y_true = torch.tensor([1, 0, 1, 1, 0, 1]) y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) @@ -41,7 +43,7 @@ def reset(self): for m in self.metrics.values(): m.reset() - def update(self, output): + def update(self, output: Sequence[torch.Tensor]): for m in self.metrics.values(): m.update(m._output_transform(output))