Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Jul 25, 2024
1 parent 36cfdc6 commit bdd0fe4
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ Complete list of metrics
MeanPairwiseDistance
MeanSquaredError
metric.Metric
metric_group.MetricGroup
metrics_lambda.MetricsLambda
MultiLabelConfusionMatrix
MutualInformation
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
from ignite.metrics.metric_group import MetricGroup
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
from ignite.metrics.mutual_information import MutualInformation
Expand All @@ -41,6 +42,7 @@
"Metric",
"Accuracy",
"Loss",
"MetricGroup",
"MetricsLambda",
"MeanAbsoluteError",
"MeanPairwiseDistance",
Expand Down
35 changes: 32 additions & 3 deletions ignite/metrics/metric_group.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
from typing import Any, Dict
from typing import Any, Callable, Dict

from ignite.metrics import Metric


class MetricGroup(Metric):
def __init__(self, metrics: Dict[str, Metric]):
"""
A class for grouping metrics so that user could manage them easier.
Args:
metrics: a dictionary of names to metric instances.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. `output_transform` of each metric in the group is also
called upon its update.
Examples:
We construct a group of metrics, attach them to the engine at once and retrieve their result.
.. code-block:: python
metric_group = {'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}
metric_group.attach(default_evaluator, "eval_metrics")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
# Metrics individually available in `state.metrics`
state.metrics["acc"], state.metrics["precision"], state.metrics["loss"]
# And also altogether
state.metrics["eval_metrics]
"""

_state_dict_all_req_keys = ("metrics",)

def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x):
self.metrics = metrics
super(MetricGroup, self).__init__()
super(MetricGroup, self).__init__(output_transform=output_transform)

def reset(self):
for m in self.metrics.values():
Expand Down
118 changes: 118 additions & 0 deletions tests/ignite/metrics/test_metric_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pytest
import torch

from ignite import distributed as idist
from ignite.engine import Engine
from ignite.metrics import Accuracy, MetricGroup, Precision

torch.manual_seed(41)


def test_update():
precision = Precision()
accuracy = Accuracy()

group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update((y_pred, y))
accuracy.update((y_pred, y))
group.update((y_pred, y))

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_output_transform():
def drop_first(output):
y_pred, y = output
return (y_pred[1:], y[1:])

precision = Precision(output_transform=drop_first)
accuracy = Accuracy(output_transform=drop_first)

group = MetricGroup(
{"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)}
)

y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update(drop_first(drop_first((y_pred, y))))
accuracy.update(drop_first(drop_first((y_pred, y))))
group.update(drop_first((y_pred, y)))

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


def test_compute():
precision = Precision()
accuracy = Accuracy()

group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})

for _ in range(3):
y_pred = torch.randint(0, 2, (100,))
y = torch.randint(0, 2, (100,))

precision.update((y_pred, y))
accuracy.update((y_pred, y))
group.update((y_pred, y))

assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()}

precision.reset()
accuracy.reset()
group.reset()

assert precision.state_dict() == group.metrics["precision"].state_dict()
assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()


@pytest.mark.usefixtures("distributed")
class TestDistributed:
def test_integration(self):
rank = idist.get_rank()
torch.manual_seed(12 + rank)

n_epochs = 3
n_iters = 5
batch_size = 10
device = idist.device()

y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device)
y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device)

def update(_, i):
return (
y_pred[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)

precision = Precision()
precision.attach(engine, "precision")

accuracy = Accuracy()
accuracy.attach(engine, "accuracy")

group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()})
group.attach(engine, "eval_metrics")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

assert "eval_metrics" in engine.state.metrics
assert "eval_metrics.accuracy" in engine.state.metrics
assert "eval_metrics.precision" in engine.state.metrics

assert engine.state.metrics["eval_metrics"] == {
"eval_metrics.accuracy": engine.state.metrics["accuracy"],
"eval_metrics.precision": engine.state.metrics["precision"],
}
assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"]
assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"]

0 comments on commit bdd0fe4

Please sign in to comment.