Skip to content

Commit

Permalink
Add MetricGroup feature (#3266)
Browse files Browse the repository at this point in the history
* Initial commit

* Add tests

* Fix two typos

* Fix Mypy

* Fix engine mypy issue

* Fix docstring

* Fix another problem in docstring

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
sadra-barikbin and vfdev-5 authored Aug 1, 2024
1 parent 65352ad commit 4c93282
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ Complete list of metrics
MeanPairwiseDistance
MeanSquaredError
metric.Metric
metric_group.MetricGroup
metrics_lambda.MetricsLambda
MultiLabelConfusionMatrix
MutualInformation
Expand Down
4 changes: 2 additions & 2 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
_check_signature(process_function, "process_function", self, None)

# generator provided by self._internal_run_as_gen
self._internal_run_generator: Optional[Generator] = None
self._internal_run_generator: Optional[Generator[Any, None, State]] = None

def register_events(
self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
Expand Down Expand Up @@ -951,7 +951,7 @@ def _internal_run(self) -> State:
self._internal_run_generator = None
return out.value

def _internal_run_as_gen(self) -> Generator:
def _internal_run_as_gen(self) -> Generator[Any, None, State]:
self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
self._init_timers(self.state)
try:
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
from ignite.metrics.metric_group import MetricGroup
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
from ignite.metrics.mutual_information import MutualInformation
Expand All @@ -41,6 +42,7 @@
"Metric",
"Accuracy",
"Loss",
"MetricGroup",
"MetricsLambda",
"MeanAbsoluteError",
"MeanPairwiseDistance",
Expand Down
54 changes: 54 additions & 0 deletions ignite/metrics/metric_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Callable, Dict, Sequence

import torch

from ignite.metrics import Metric


class MetricGroup(Metric):
"""
A class for grouping metrics so that user could manage them easier.
Args:
metrics: a dictionary of names to metric instances.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. `output_transform` of each metric in the group is also
called upon its update.
Examples:
We construct a group of metrics, attach them to the engine at once and retrieve their result.
.. code-block:: python
import torch
metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())})
metric_group.attach(default_evaluator, "eval_metrics")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
# Metrics individually available in `state.metrics`
state.metrics["acc"], state.metrics["precision"], state.metrics["loss"]
# And also altogether
state.metrics["eval_metrics"]
"""

_state_dict_all_req_keys = ("metrics",)

def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x):
self.metrics = metrics
super(MetricGroup, self).__init__(output_transform=output_transform)

def reset(self) -> None:
for m in self.metrics.values():
m.reset()

def update(self, output: Sequence[torch.Tensor]) -> None:
for m in self.metrics.values():
m.update(m._output_transform(output))

def compute(self) -> Dict[str, Any]:
return {k: m.compute() for k, m in self.metrics.items()}
118 changes: 118 additions & 0 deletions tests/ignite/metrics/test_metric_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pytest
import torch

from ignite import distributed as idist
from ignite.engine import Engine
from ignite.metrics import Accuracy, MetricGroup, Precision

torch.manual_seed(41)


def test_update():
precision = Precision()
accuracy = Accuracy()

group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update((y_pred, y))
accuracy.update((y_pred, y))
group.update((y_pred, y))

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_output_transform():
def drop_first(output):
y_pred, y = output
return (y_pred[1:], y[1:])

precision = Precision(output_transform=drop_first)
accuracy = Accuracy(output_transform=drop_first)

group = MetricGroup(
{"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)}
)

y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update(drop_first(drop_first((y_pred, y))))
accuracy.update(drop_first(drop_first((y_pred, y))))
group.update(drop_first((y_pred, y)))

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_compute():
precision = Precision()
accuracy = Accuracy()

group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

for _ in range(3):
y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update((y_pred, y))
accuracy.update((y_pred, y))
group.update((y_pred, y))

assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()}

precision.reset()
accuracy.reset()
group.reset()

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


@pytest.mark.usefixtures("distributed")
class TestDistributed:
def test_integration(self):
rank = idist.get_rank()
torch.manual_seed(12 + rank)

n_epochs = 3
n_iters = 5
batch_size = 10
device = idist.device()

y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device)
y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device)

def update(_, i):
return (
y_pred[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)

precision = Precision()
precision.attach(engine, "precision")

accuracy = Accuracy()
accuracy.attach(engine, "accuracy")

group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()})
group.attach(engine, "eval_metrics")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

assert "eval_metrics" in engine.state.metrics
assert "eval_metrics.accuracy" in engine.state.metrics
assert "eval_metrics.precision" in engine.state.metrics

assert engine.state.metrics["eval_metrics"] == {
"eval_metrics.accuracy": engine.state.metrics["accuracy"],
"eval_metrics.precision": engine.state.metrics["precision"],
}
assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"]
assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"]

0 comments on commit 4c93282

Please sign in to comment.