forked from openvinotoolkit/anomalib
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add binning capability to AUPRO (openvinotoolkit#1145)
* Add the capability to compute binned AUPRO. * fix linting * use directly binary_roc * update CHANGELOG.md * improve test by doing 2 different ones (aupro and binned aupro) + renamed few variables in the tests * only allow num_thresholds as input + fix tests + add threshold computing utilities * add binning tests * use binary roc directly * remove unused import and rename some * device for thresholds * fix linting * use torch.all * fix linting * fix linting * remove observe time --------- Co-authored-by: Samet Akcay <samet.akcay@intel.com>
- Loading branch information
1 parent
0b5d969
commit 7feee1e
Showing
5 changed files
with
133 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Optional | ||
|
||
from torch import Tensor, linspace | ||
from torch import device as torch_device | ||
|
||
|
||
def thresholds_between_min_and_max( | ||
preds: Tensor, num_thresholds: int = 100, device: Optional[torch_device] = None | ||
) -> Tensor: | ||
return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device) | ||
|
||
|
||
def thresholds_between_0_and_1(num_thresholds: int = 100, device: Optional[torch_device] = None) -> Tensor: | ||
return linspace(start=0, end=1, steps=num_thresholds, device=device) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from torch import Tensor, all as torch_all | ||
|
||
from anomalib.utils.metrics.binning import thresholds_between_min_and_max, thresholds_between_0_and_1 | ||
|
||
|
||
def test_thresholds_between_min_and_max(): | ||
preds = Tensor([1, 10]) | ||
assert torch_all(thresholds_between_min_and_max(preds, 2) == preds) | ||
|
||
|
||
def test_thresholds_between_0_and_1(): | ||
expected = Tensor([0, 1]) | ||
assert torch_all(thresholds_between_0_and_1(2) == expected) |