Add MetricGroup feature #3266

Merged: 9 commits, Aug 1, 2024
Changes from all commits
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -335,6 +335,7 @@ Complete list of metrics
MeanPairwiseDistance
MeanSquaredError
metric.Metric
metric_group.MetricGroup
metrics_lambda.MetricsLambda
MultiLabelConfusionMatrix
MutualInformation
4 changes: 2 additions & 2 deletions ignite/engine/engine.py
@@ -157,7 +157,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
        _check_signature(process_function, "process_function", self, None)

        # generator provided by self._internal_run_as_gen
-        self._internal_run_generator: Optional[Generator] = None
+        self._internal_run_generator: Optional[Generator[Any, None, State]] = None

    def register_events(
        self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
@@ -951,7 +951,7 @@ def _internal_run(self) -> State:
        self._internal_run_generator = None
        return out.value

-    def _internal_run_as_gen(self) -> Generator:
+    def _internal_run_as_gen(self) -> Generator[Any, None, State]:
        self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
        self._init_timers(self.state)
        try:
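The engine change above only tightens a return-type annotation. For readers unfamiliar with the three-parameter form, here is a minimal sketch (not part of this PR; the `State` class below is a stand-in, not ignite's) of what `Generator[Any, None, State]` expresses: the generator yields arbitrary values, is never sent anything, and finishes by returning a `State` that the caller reads from `StopIteration.value`.

from typing import Any, Generator


class State:
    """Stand-in for an engine state object, used only for illustration."""

    epoch: int = 0


def run_as_gen() -> Generator[Any, None, State]:
    # Yields Any, accepts no sent values (None), and returns a State.
    state = State()
    for epoch in range(2):
        state.epoch = epoch
        yield epoch  # intermediate value handed back to the driver loop
    return state  # surfaces to the caller as StopIteration.value


gen = run_as_gen()
try:
    while True:
        next(gen)
except StopIteration as stop:
    final_state = stop.value  # the State returned by the generator
    print(final_state.epoch)  # 1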
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -22,6 +22,7 @@
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
from ignite.metrics.metric_group import MetricGroup
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
from ignite.metrics.mutual_information import MutualInformation
@@ -41,6 +42,7 @@
"Metric",
"Accuracy",
"Loss",
"MetricGroup",
"MetricsLambda",
"MeanAbsoluteError",
"MeanPairwiseDistance",
54 changes: 54 additions & 0 deletions ignite/metrics/metric_group.py
@@ -0,0 +1,54 @@
from typing import Any, Callable, Dict, Sequence

import torch

from ignite.metrics import Metric


class MetricGroup(Metric):
    """
    A class for grouping metrics so that the user can manage them more easily.

    Args:
        metrics: a dictionary of names to metric instances.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metrics. The ``output_transform`` of each metric in the group
            is also applied when that metric is updated.

    Examples:
        We construct a group of metrics, attach them to the engine at once and retrieve their result.

        .. code-block:: python

            import torch
            from torch import nn

            from ignite.metrics import Accuracy, Loss, MetricGroup, Precision

            metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())})
            metric_group.attach(default_evaluator, "eval_metrics")
            y_true = torch.tensor([1, 0, 1, 1, 0, 1])
            y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
            state = default_evaluator.run([[y_pred, y_true]])

            # Metrics individually available in `state.metrics`
            state.metrics["acc"], state.metrics["precision"], state.metrics["loss"]

            # And also altogether
            state.metrics["eval_metrics"]
    """

    _state_dict_all_req_keys = ("metrics",)

    def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x):
        self.metrics = metrics
        super(MetricGroup, self).__init__(output_transform=output_transform)

    def reset(self) -> None:
        for m in self.metrics.values():
            m.reset()

    def update(self, output: Sequence[torch.Tensor]) -> None:
        # Each member metric applies its own output_transform before updating its state.
        for m in self.metrics.values():
            m.update(m._output_transform(output))

    def compute(self) -> Dict[str, Any]:
        return {k: m.compute() for k, m in self.metrics.items()}
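To make the reset/update/compute semantics of the new class concrete, here is a minimal standalone sketch (not taken from this PR) that exercises a group without attaching it to an engine; `Accuracy` and `Precision` are existing ignite metrics, and the `(y_pred, y)` ordering follows the docstring example above.

import torch

from ignite.metrics import Accuracy, MetricGroup, Precision

group = MetricGroup({"acc": Accuracy(), "precision": Precision()})

y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
y_true = torch.tensor([1, 0, 1, 1, 0, 1])

group.reset()                   # reset() cascades to every member metric
group.update((y_pred, y_true))  # update() forwards the output to each member
print(group.compute())          # one entry per member, e.g. {"acc": ..., "precision": ...}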
118 changes: 118 additions & 0 deletions tests/ignite/metrics/test_metric_group.py
@@ -0,0 +1,118 @@
import pytest
import torch

from ignite import distributed as idist
from ignite.engine import Engine
from ignite.metrics import Accuracy, MetricGroup, Precision

torch.manual_seed(41)


def test_update():
    precision = Precision()
    accuracy = Accuracy()

    group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

    y_pred = torch.randint(0, 2, (100,))
    y = torch.randint(0, 2, (100,))

    precision.update((y_pred, y))
    accuracy.update((y_pred, y))
    group.update((y_pred, y))

    assert precision.state_dict() == group.metrics["precision"].state_dict()
    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_output_transform():
    def drop_first(output):
        y_pred, y = output
        return (y_pred[1:], y[1:])

    precision = Precision(output_transform=drop_first)
    accuracy = Accuracy(output_transform=drop_first)

    group = MetricGroup(
        {"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)}
    )

    y_pred = torch.randint(0, 2, (100,))
    y = torch.randint(0, 2, (100,))

    # MetricGroup.update applies each member metric's own output_transform, so the members
    # effectively see drop_first applied twice; apply it twice manually on the standalone
    # metrics so that their states match.
    precision.update(drop_first(drop_first((y_pred, y))))
    accuracy.update(drop_first(drop_first((y_pred, y))))
    group.update(drop_first((y_pred, y)))

    assert precision.state_dict() == group.metrics["precision"].state_dict()
    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_compute():
    precision = Precision()
    accuracy = Accuracy()

    group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

    for _ in range(3):
        y_pred = torch.randint(0, 2, (100,))
        y = torch.randint(0, 2, (100,))

        precision.update((y_pred, y))
        accuracy.update((y_pred, y))
        group.update((y_pred, y))

    assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()}

    precision.reset()
    accuracy.reset()
    group.reset()

    assert precision.state_dict() == group.metrics["precision"].state_dict()
    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        rank = idist.get_rank()
        torch.manual_seed(12 + rank)

        n_epochs = 3
        n_iters = 5
        batch_size = 10
        device = idist.device()

        y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device)
        y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device)

        def update(_, i):
            return (
                y_pred[i * batch_size : (i + 1) * batch_size],
                y_true[i * batch_size : (i + 1) * batch_size],
            )

        engine = Engine(update)

        precision = Precision()
        precision.attach(engine, "precision")

        accuracy = Accuracy()
        accuracy.attach(engine, "accuracy")

        group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()})
        group.attach(engine, "eval_metrics")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        # Since the group's compute() returns a dict, its entries are exposed both under the
        # attach name and as individual keys in engine.state.metrics.
        assert "eval_metrics" in engine.state.metrics
        assert "eval_metrics.accuracy" in engine.state.metrics
        assert "eval_metrics.precision" in engine.state.metrics

        assert engine.state.metrics["eval_metrics"] == {
            "eval_metrics.accuracy": engine.state.metrics["accuracy"],
            "eval_metrics.precision": engine.state.metrics["precision"],
        }
        assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"]
        assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"]