diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6a910dfb1f..508e465e7d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Add AUPRO binning capability by @yann-cv
 - Add support for receiving dataset paths as a list by @harimkang in https://github.com/openvinotoolkit/anomalib/pull/1265
 
 ### Changed
diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py
index ba8608fd3d..ad2a615e5c 100644
--- a/src/anomalib/utils/metrics/aupro.py
+++ b/src/anomalib/utils/metrics/aupro.py
@@ -11,7 +11,8 @@
 from matplotlib.figure import Figure
 from torch import Tensor
 from torchmetrics import Metric
-from torchmetrics.functional import auc, roc
+from torchmetrics.functional import auc
+from torchmetrics.functional.classification import binary_roc
 from torchmetrics.utilities.data import dim_zero_cat
 
 from anomalib.utils.metrics.pro import (
@@ -19,6 +20,7 @@
     connected_components_gpu,
 )
 
+from .binning import thresholds_between_0_and_1, thresholds_between_min_and_max
 from .plotting_utils import plot_figure
 
 
@@ -30,6 +32,13 @@ class AUPRO(Metric):
     full_state_update: bool = False
     preds: list[Tensor]
     target: list[Tensor]
+    # When not None, the computation is performed in constant memory by computing the ROC curve
+    # for a fixed number of threshold buckets.
+    # Warning: the thresholds are evenly distributed between the min and max predictions
+    # if all predictions are inside [0, 1]. Otherwise, the thresholds are evenly distributed between 0 and 1.
+    # This warning can be removed when https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed
+    # and the ROC curve can be computed with formatting deactivated.
+    num_thresholds: int | None
 
     def __init__(
         self,
@@ -38,6 +47,7 @@
         process_group: Any | None = None,
         dist_sync_fn: Callable | None = None,
         fpr_limit: float = 0.3,
+        num_thresholds: int | None = None,
     ) -> None:
         super().__init__(
             compute_on_step=compute_on_step,
@@ -49,6 +59,7 @@
         self.add_state("preds", default=[], dist_reduce_fx="cat")  # pylint: disable=not-callable
         self.add_state("target", default=[], dist_reduce_fx="cat")  # pylint: disable=not-callable
         self.register_buffer("fpr_limit", torch.tensor(fpr_limit))
+        self.num_thresholds = num_thresholds
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         """Update state with new values.
@@ -96,9 +107,29 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso
         Returns:
             tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
         """
+        if self.num_thresholds is not None:
+            # binary_roc applies a sigmoid to the predictions before computing the ROC curve
+            # when some predictions are out of [0, 1], so the binning between min and max predictions
+            # cannot be applied in that case. This can be removed when
+            # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed and
+            # the ROC curve can be computed with formatting deactivated.
+
+            if torch.all((0 <= preds) * (preds <= 1)):
+                thresholds = thresholds_between_min_and_max(preds, self.num_thresholds, self.device)
+            else:
+                thresholds = thresholds_between_0_and_1(self.num_thresholds, self.device)
+
+        else:
+            thresholds = None
 
         # compute the global fpr-size
-        fpr: Tensor = roc(preds, target)[0]  # only need fpr
+        fpr: Tensor = binary_roc(
+            preds=preds,
+            target=target,
+            thresholds=thresholds,
+        )[
+            0
+        ]  # only need fpr
         output_size = torch.where(fpr <= self.fpr_limit)[0].size(0)
 
         # compute the PRO curve by aggregating per-region tpr/fpr curves/values.
@@ -120,7 +151,11 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tenso
             mask = cca == label
             # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other
             # label in labels as FPs. We also don't need to return the thresholds
-            _fpr, _tpr = roc(preds[background | mask], mask[background | mask])[:-1]
+            _fpr, _tpr = binary_roc(
+                preds=preds[background | mask],
+                target=mask[background | mask],
+                thresholds=thresholds,
+            )[:-1]
 
             # catch edge-case where ROC only has fpr vals > self.fpr_limit
             if _fpr[_fpr <= self.fpr_limit].max() == 0:
diff --git a/src/anomalib/utils/metrics/binning.py b/src/anomalib/utils/metrics/binning.py
new file mode 100644
index 0000000000..f92e0b20ab
--- /dev/null
+++ b/src/anomalib/utils/metrics/binning.py
@@ -0,0 +1,14 @@
+from typing import Optional
+
+from torch import Tensor, linspace
+from torch import device as torch_device
+
+
+def thresholds_between_min_and_max(
+    preds: Tensor, num_thresholds: int = 100, device: Optional[torch_device] = None
+) -> Tensor:
+    return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device)
+
+
+def thresholds_between_0_and_1(num_thresholds: int = 100, device: Optional[torch_device] = None) -> Tensor:
+    return linspace(start=0, end=1, steps=num_thresholds, device=device)
diff --git a/tests/pre_merge/utils/metrics/test_aupro.py b/tests/pre_merge/utils/metrics/test_aupro.py
index 88466d0566..5e17394926 100644
--- a/tests/pre_merge/utils/metrics/test_aupro.py
+++ b/tests/pre_merge/utils/metrics/test_aupro.py
@@ -8,53 +8,85 @@
 
 def pytest_generate_tests(metafunc):
-    if metafunc.function is test_pro:
-        labels = [
-            torch.tensor(
+    labels = [
+        torch.tensor(
+            [
                 [
-                    [
-                        [0, 0, 0, 1, 0, 0, 0],
-                    ]
-                    * 400,
+                    [0, 0, 0, 1, 0, 0, 0],
                 ]
-            ),
-            torch.tensor(
+                * 400,
+            ]
+        ),
+        torch.tensor(
+            [
                 [
-                    [
-                        [0, 1, 0, 1, 0, 1, 0],
-                    ]
-                    * 400,
+                    [0, 1, 0, 1, 0, 1, 0],
                 ]
-            ),
-        ]
-        preds = torch.arange(2800) / 2800.0
-        preds = preds.view(1, 1, 400, 7)
+                * 400,
+            ]
+        ),
+    ]
+    preds = torch.arange(2800) / 2800.0
+    preds = preds.view(1, 1, 400, 7)
 
-        preds = [preds, preds]
+    preds = [preds, preds]
 
-        fpr_limit = [1 / 3, 1 / 3]
-        aupro = [torch.tensor(1 / 6), torch.tensor(1 / 6)]
+    fpr_limit = [1 / 3, 1 / 3]
+    expected_aupro = [torch.tensor(1 / 6), torch.tensor(1 / 6)]
 
-        # Also test that per-region aupros are averaged
-        labels.append(torch.cat(labels))
-        preds.append(torch.cat(preds))
-        fpr_limit.append(float(np.mean(fpr_limit)))
-        aupro.append(torch.tensor(np.mean(aupro)))
+    # Also test that per-region aupros are averaged
+    labels.append(torch.cat(labels))
+    preds.append(torch.cat(preds))
+    fpr_limit.append(float(np.mean(fpr_limit)))
+    expected_aupro.append(torch.tensor(np.mean(expected_aupro)))
 
-        vals = list(zip(labels, preds, fpr_limit, aupro))
-        metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "aupro"), argvalues=vals)
+    threshold_count = [
+        200,
+        200,
+        200,
+    ]
+    if metafunc.function is test_aupro:
+        vals = list(zip(labels, preds, fpr_limit, expected_aupro))
+        metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "expected_aupro"), argvalues=vals)
+    elif metafunc.function is test_binned_aupro:
+        vals = list(zip(labels, preds, threshold_count))
+        metafunc.parametrize(argnames=("labels", "preds", "threshold_count"), argvalues=vals)
 
-def test_pro(labels, preds, fpr_limit, aupro):
-    pro = AUPRO(fpr_limit=fpr_limit)
-    pro.update(preds, labels)
-    computed_aupro = pro.compute()
+
+def test_aupro(labels, preds, fpr_limit, expected_aupro):
+    aupro = AUPRO(fpr_limit=fpr_limit)
+    aupro.update(preds, labels)
+    computed_aupro = aupro.compute()
 
     tmp_labels = [label.squeeze().numpy() for label in labels]
     tmp_preds = [pred.squeeze().numpy() for pred in preds]
-    ref_pro = torch.tensor(calculate_au_pro(tmp_labels, tmp_preds, integration_limit=fpr_limit)[0], dtype=torch.float)
+    ref_aupro = torch.tensor(calculate_au_pro(tmp_labels, tmp_preds, integration_limit=fpr_limit)[0], dtype=torch.float)
+
+    TOL = 0.001
+    assert torch.allclose(computed_aupro, expected_aupro, atol=TOL)
+    assert torch.allclose(computed_aupro, ref_aupro, atol=TOL)
+
+
+def test_binned_aupro(labels, preds, threshold_count):
+    aupro = AUPRO()
+    computed_not_binned_aupro = aupro(preds, labels)
+
+    binned_pro = AUPRO(num_thresholds=threshold_count)
+    computed_binned_aupro = binned_pro(preds, labels)
 
     TOL = 0.001
-    assert torch.allclose(computed_aupro, aupro, atol=TOL)
-    assert torch.allclose(computed_aupro, ref_pro, atol=TOL)
-    assert torch.allclose(aupro, ref_pro, atol=TOL)
+    # with threshold binning the roc curve computed within the metric is more memory efficient
+    # but a bit less accurate. So we check the difference in order to validate the binning effect.
+    assert computed_binned_aupro != computed_not_binned_aupro
+    assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL)
+
+    # test with prediction higher than 1
+    preds = preds * 2
+    computed_binned_aupro = binned_pro(preds, labels)
+    computed_not_binned_aupro = aupro(preds, labels)
+
+    # with threshold binning the roc curve computed within the metric is more memory efficient
+    # but a bit less accurate. So we check the difference in order to validate the binning effect.
+    assert computed_binned_aupro != computed_not_binned_aupro
+    assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL)
diff --git a/tests/pre_merge/utils/metrics/test_binning.py b/tests/pre_merge/utils/metrics/test_binning.py
new file mode 100644
index 0000000000..d256d95a7a
--- /dev/null
+++ b/tests/pre_merge/utils/metrics/test_binning.py
@@ -0,0 +1,13 @@
+from torch import Tensor, all as torch_all
+
+from anomalib.utils.metrics.binning import thresholds_between_min_and_max, thresholds_between_0_and_1
+
+
+def test_thresholds_between_min_and_max():
+    preds = Tensor([1, 10])
+    assert torch_all(thresholds_between_min_and_max(preds, 2) == preds)
+
+
+def test_thresholds_between_0_and_1():
+    expected = Tensor([0, 1])
+    assert torch_all(thresholds_between_0_and_1(2) == expected)
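For reference, a minimal usage sketch of the new `num_thresholds` option. It is only an illustration under the same assumptions the tests above make: that `AUPRO` is importable from `anomalib.utils.metrics`, and that the anomaly map and mask follow the shapes used in `test_aupro.py`; the example data itself is illustrative.

```python
import torch

from anomalib.utils.metrics import AUPRO  # assumed import path, as used by the tests

# Illustrative fixture mirroring the tests: a (1, 1, 400, 7) anomaly map and a
# (1, 400, 7) ground-truth mask with a single anomalous column.
target = torch.zeros(1, 400, 7, dtype=torch.int)
target[..., 3] = 1
preds = torch.arange(2800).view(1, 1, 400, 7) / 2800.0

# Default behaviour: exact ROC curves over all unique prediction values.
exact = AUPRO(fpr_limit=1 / 3)
exact.update(preds, target)

# Binned variant: constant-memory computation over 200 fixed thresholds.
binned = AUPRO(fpr_limit=1 / 3, num_thresholds=200)
binned.update(preds, target)

# Binning trades a little accuracy for bounded memory; per the tests, the two
# values differ slightly but agree within an absolute tolerance of 1e-3.
print(exact.compute(), binned.compute())
```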