Add binning capability to AUPRO (openvinotoolkit#1145)
* Add the capability to compute binned AUPRO.

* fix linting

* use directly binary_roc

* update CHANGELOG.md

* improve tests by splitting into 2 different ones (aupro and binned aupro) + renamed a few variables in the tests

* only allow num_thresholds as input + fix tests + add threshold computing utilities

* add binning tests

* use binary roc directly

* remove an unused import and rename some others

* device for thresholds

* fix linting

* use torch.all

* fix linting

* fix linting

* remove observe time

---------

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
Yann-CV and samet-akcay committed Aug 21, 2023
1 parent 0b5d969 commit 7feee1e
Showing 5 changed files with 133 additions and 38 deletions.
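For orientation before the diffs, here is a minimal usage sketch of the new num_thresholds argument. It is an illustration only, not part of the commit: the import path is assumed from the module layout below, and the tensor shapes mirror tests/pre_merge/utils/metrics/test_aupro.py.

import torch

from anomalib.utils.metrics import AUPRO  # assumed import path

# Ground-truth mask and anomaly map shaped as in the tests below.
labels = torch.tensor([[[0, 0, 0, 1, 0, 0, 0]] * 400])    # (1, 400, 7)
preds = (torch.arange(2800) / 2800.0).view(1, 1, 400, 7)  # (1, 1, 400, 7)

exact = AUPRO(fpr_limit=1 / 3)                       # default: exhaustive thresholds
binned = AUPRO(fpr_limit=1 / 3, num_thresholds=200)  # new: constant-memory binned ROC

exact.update(preds, labels)
binned.update(preds, labels)

# The binned value tracks the exact one closely (the new test uses atol=1e-3) at a lower memory cost.
print(exact.compute(), binned.compute())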
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- AUPRO binning capability by @yann-cv
- Add support for receiving dataset paths as a list by @harimkang in https://github.com/openvinotoolkit/anomalib/pull/1265

### Changed
41 changes: 38 additions & 3 deletions src/anomalib/utils/metrics/aupro.py
@@ -11,14 +11,16 @@
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import auc, roc
from torchmetrics.functional import auc
from torchmetrics.functional.classification import binary_roc
from torchmetrics.utilities.data import dim_zero_cat

from anomalib.utils.metrics.pro import (
connected_components_cpu,
connected_components_gpu,
)

from .binning import thresholds_between_0_and_1, thresholds_between_min_and_max
from .plotting_utils import plot_figure


@@ -30,6 +32,13 @@ class AUPRO(Metric):
full_state_update: bool = False
preds: list[Tensor]
target: list[Tensor]
# When not None, the computation is performed in constant memory by computing the roc curve
# over a fixed number of threshold buckets.
# Warning: The thresholds are evenly distributed between the min and max predictions
# if all predictions are inside [0, 1]. Otherwise, the thresholds are evenly distributed between 0 and 1.
# This warning can be removed when https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed
# and the roc curve is computed with deactivated formatting
num_thresholds: int | None

def __init__(
self,
@@ -38,6 +47,7 @@ def __init__(
process_group: Any | None = None,
dist_sync_fn: Callable | None = None,
fpr_limit: float = 0.3,
num_thresholds: int | None = None,
) -> None:
super().__init__(
compute_on_step=compute_on_step,
@@ -49,6 +59,7 @@
self.add_state("preds", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable
self.add_state("target", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable
self.register_buffer("fpr_limit", torch.tensor(fpr_limit))
self.num_thresholds = num_thresholds

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update state with new values.
@@ -96,9 +107,29 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tensor, Tensor]:
Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
"""
if self.num_thresholds is not None:
# binary_roc applies a sigmoid to the predictions before computing the roc curve,
# so when some predictions are outside [0, 1] the binning between the min and max
# predictions cannot be applied. This can be removed when
# https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed and
# the roc curve is computed with deactivated formatting.

if torch.all((0 <= preds) * (preds <= 1)):
thresholds = thresholds_between_min_and_max(preds, self.num_thresholds, self.device)
else:
thresholds = thresholds_between_0_and_1(self.num_thresholds, self.device)

else:
thresholds = None

# compute the global fpr-size
fpr: Tensor = roc(preds, target)[0] # only need fpr
fpr: Tensor = binary_roc(
preds=preds,
target=target,
thresholds=thresholds,
)[
0
] # only need fpr
output_size = torch.where(fpr <= self.fpr_limit)[0].size(0)

# compute the PRO curve by aggregating per-region tpr/fpr curves/values.
@@ -120,7 +151,11 @@ def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tensor, Tensor]:
mask = cca == label
# Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other
# labels in labels as FPs. We also don't need to return the thresholds
_fpr, _tpr = roc(preds[background | mask], mask[background | mask])[:-1]
_fpr, _tpr = binary_roc(
preds=preds[background | mask],
target=mask[background | mask],
thresholds=thresholds,
)[:-1]

# catch edge-case where ROC only has fpr vals > self.fpr_limit
if _fpr[_fpr <= self.fpr_limit].max() == 0:
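To make the new threshold handling in compute_pro easier to follow, here is a standalone sketch that restates the branch choosing the threshold grid. It is an illustration under assumptions, not code from the commit: it reuses the two helpers added in binning.py below and torchmetrics' binary_roc, and binned_fpr_tpr is a hypothetical name.

import torch
from torchmetrics.functional.classification import binary_roc

from anomalib.utils.metrics.binning import (
    thresholds_between_0_and_1,
    thresholds_between_min_and_max,
)


def binned_fpr_tpr(preds: torch.Tensor, target: torch.Tensor, num_thresholds: int):
    """Illustrative restatement of the threshold selection used by AUPRO.compute_pro."""
    if torch.all((0 <= preds) & (preds <= 1)):
        # Predictions already look like probabilities: bin between their min and max.
        thresholds = thresholds_between_min_and_max(preds, num_thresholds, preds.device)
    else:
        # binary_roc applies a sigmoid to out-of-range scores, so a min/max grid would not match;
        # fall back to an even grid in [0, 1] (see Lightning-AI/torchmetrics#1526).
        thresholds = thresholds_between_0_and_1(num_thresholds, preds.device)
    # With a fixed thresholds tensor, binary_roc runs in constant memory w.r.t. the number of pixels.
    fpr, tpr, _ = binary_roc(preds=preds.flatten(), target=target.flatten().long(), thresholds=thresholds)
    return fpr, tpr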
14 changes: 14 additions & 0 deletions src/anomalib/utils/metrics/binning.py
@@ -0,0 +1,14 @@
from typing import Optional

from torch import Tensor, linspace
from torch import device as torch_device


def thresholds_between_min_and_max(
preds: Tensor, num_thresholds: int = 100, device: Optional[torch_device] = None
) -> Tensor:
return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device)


def thresholds_between_0_and_1(num_thresholds: int = 100, device: Optional[torch_device] = None) -> Tensor:
return linspace(start=0, end=1, steps=num_thresholds, device=device)
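
Usage note for the two helpers above; the values follow directly from torch.linspace and are not part of the commit.

import torch

from anomalib.utils.metrics.binning import thresholds_between_0_and_1, thresholds_between_min_and_max

thresholds_between_min_and_max(torch.tensor([0.2, 0.8]), num_thresholds=4)
# -> tensor([0.2000, 0.4000, 0.6000, 0.8000]): evenly spaced between preds.min() and preds.max()

thresholds_between_0_and_1(num_thresholds=5)
# -> tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]): fallback grid used when preds fall outside [0, 1]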
102 changes: 67 additions & 35 deletions tests/pre_merge/utils/metrics/test_aupro.py
@@ -8,53 +8,85 @@


def pytest_generate_tests(metafunc):
if metafunc.function is test_pro:
labels = [
torch.tensor(
labels = [
torch.tensor(
[
[
[
[0, 0, 0, 1, 0, 0, 0],
]
* 400,
[0, 0, 0, 1, 0, 0, 0],
]
),
torch.tensor(
* 400,
]
),
torch.tensor(
[
[
[
[0, 1, 0, 1, 0, 1, 0],
]
* 400,
[0, 1, 0, 1, 0, 1, 0],
]
),
]
preds = torch.arange(2800) / 2800.0
preds = preds.view(1, 1, 400, 7)
* 400,
]
),
]
preds = torch.arange(2800) / 2800.0
preds = preds.view(1, 1, 400, 7)

preds = [preds, preds]
preds = [preds, preds]

fpr_limit = [1 / 3, 1 / 3]
aupro = [torch.tensor(1 / 6), torch.tensor(1 / 6)]
fpr_limit = [1 / 3, 1 / 3]
expected_aupro = [torch.tensor(1 / 6), torch.tensor(1 / 6)]

# Also test that per-region aupros are averaged
labels.append(torch.cat(labels))
preds.append(torch.cat(preds))
fpr_limit.append(float(np.mean(fpr_limit)))
aupro.append(torch.tensor(np.mean(aupro)))
# Also test that per-region aupros are averaged
labels.append(torch.cat(labels))
preds.append(torch.cat(preds))
fpr_limit.append(float(np.mean(fpr_limit)))
expected_aupro.append(torch.tensor(np.mean(expected_aupro)))

vals = list(zip(labels, preds, fpr_limit, aupro))
metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "aupro"), argvalues=vals)
threshold_count = [
200,
200,
200,
]

if metafunc.function is test_aupro:
vals = list(zip(labels, preds, fpr_limit, expected_aupro))
metafunc.parametrize(argnames=("labels", "preds", "fpr_limit", "expected_aupro"), argvalues=vals)
elif metafunc.function is test_binned_aupro:
vals = list(zip(labels, preds, threshold_count))
metafunc.parametrize(argnames=("labels", "preds", "threshold_count"), argvalues=vals)

def test_pro(labels, preds, fpr_limit, aupro):
pro = AUPRO(fpr_limit=fpr_limit)
pro.update(preds, labels)
computed_aupro = pro.compute()

def test_aupro(labels, preds, fpr_limit, expected_aupro):
aupro = AUPRO(fpr_limit=fpr_limit)
aupro.update(preds, labels)
computed_aupro = aupro.compute()

tmp_labels = [label.squeeze().numpy() for label in labels]
tmp_preds = [pred.squeeze().numpy() for pred in preds]
ref_pro = torch.tensor(calculate_au_pro(tmp_labels, tmp_preds, integration_limit=fpr_limit)[0], dtype=torch.float)
ref_aupro = torch.tensor(calculate_au_pro(tmp_labels, tmp_preds, integration_limit=fpr_limit)[0], dtype=torch.float)

TOL = 0.001
assert torch.allclose(computed_aupro, expected_aupro, atol=TOL)
assert torch.allclose(computed_aupro, ref_aupro, atol=TOL)


def test_binned_aupro(labels, preds, threshold_count):
aupro = AUPRO()
computed_not_binned_aupro = aupro(preds, labels)

binned_pro = AUPRO(num_thresholds=threshold_count)
computed_binned_aupro = binned_pro(preds, labels)

TOL = 0.001
assert torch.allclose(computed_aupro, aupro, atol=TOL)
assert torch.allclose(computed_aupro, ref_pro, atol=TOL)
assert torch.allclose(aupro, ref_pro, atol=TOL)
# with threshold binning, the roc curve computed within the metric is more memory efficient
# but a bit less accurate, so we check that the two values differ (binning took effect) yet remain close
assert computed_binned_aupro != computed_not_binned_aupro
assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL)

# test with prediction higher than 1
preds = preds * 2
computed_binned_aupro = binned_pro(preds, labels)
computed_not_binned_aupro = aupro(preds, labels)

# with threshold binning, the roc curve computed within the metric is more memory efficient
# but a bit less accurate, so we check that the two values differ (binning took effect) yet remain close
assert computed_binned_aupro != computed_not_binned_aupro
assert torch.allclose(computed_not_binned_aupro, computed_binned_aupro, atol=TOL)
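
The memory/accuracy trade-off asserted in test_binned_aupro can also be seen directly on binary_roc. A hedged illustration with synthetic scores, not part of the commit:

import torch
from torchmetrics.functional.classification import binary_roc

preds = torch.rand(10_000)
target = (preds + 0.1 * torch.randn(10_000) > 0.5).int()

# Exact mode keeps roughly one threshold per unique score, so the curve grows with the input size.
fpr_exact, tpr_exact, thr_exact = binary_roc(preds, target, thresholds=None)

# Binned mode evaluates a fixed grid, so the curve size stays constant regardless of the input size.
fpr_binned, tpr_binned, thr_binned = binary_roc(preds, target, thresholds=200)

print(thr_exact.numel(), thr_binned.numel())  # on the order of 10_000 vs exactly 200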
13 changes: 13 additions & 0 deletions tests/pre_merge/utils/metrics/test_binning.py
@@ -0,0 +1,13 @@
from torch import Tensor, all as torch_all

from anomalib.utils.metrics.binning import thresholds_between_min_and_max, thresholds_between_0_and_1


def test_thresholds_between_min_and_max():
preds = Tensor([1, 10])
assert torch_all(thresholds_between_min_and_max(preds, 2) == preds)


def test_thresholds_between_0_and_1():
expected = Tensor([0, 1])
assert torch_all(thresholds_between_0_and_1(2) == expected)
