Skip to content

Commit

Permalink
ensure segmentation has type boolean for now (#507)
Browse files Browse the repository at this point in the history
* ensure that segmentation arrays are always of type boolean

* fix metric with new restriction

* more precise check on values of segmentation

* Trigger Build
  • Loading branch information
ismael-mendoza authored Jul 1, 2024
1 parent 46d9adf commit 7795588
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 10 deletions.
20 changes: 15 additions & 5 deletions btk/blend_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _validate_catalog(self, catalog: Table):
)
return catalog

def _validate_segmentation(self, segmentation):
def _validate_segmentation(self, segmentation: Optional[np.ndarray]):
if segmentation is not None:
if self.image_size is None or self.n_bands is None:
raise ValueError("`image_size` must be specified if segmentation is provided")
Expand All @@ -233,14 +233,17 @@ def _validate_segmentation(self, segmentation):
"The predicted segmentation of at least one of your deblended images "
"has the wrong shape. It should be `(max_n_sources, image_size, image_size)`."
)
if segmentation.min() < 0 or segmentation.max() > 1:
cond1 = np.any(np.greater(segmentation, 0) & np.less(segmentation, 1))
cond2 = np.any(np.less(segmentation, 0) | np.greater(segmentation, 1))
if cond1 or cond2:
raise ValueError(
"The predicted segmentation of at least one of your deblended images "
"has values outside the range [0, 1]."
"has values different than 0 or 1."
)
return segmentation.astype(bool)
return segmentation

def _validate_deblended_images(self, deblended_images):
def _validate_deblended_images(self, deblended_images: Optional[np.ndarray]):
if deblended_images is not None:
if self.image_size is None or self.n_bands is None:
raise ValueError(
Expand Down Expand Up @@ -335,7 +338,14 @@ def _validate_segmentation(self, segmentation: Optional[np.ndarray] = None) -> n
self.image_size,
self.image_size,
)
assert segmentation.min() >= 0 and segmentation.max() <= 1
cond1 = np.any(np.greater(segmentation, 0) & np.less(segmentation, 1))
cond2 = np.any(np.less(segmentation, 0) | np.greater(segmentation, 1))
if cond1 or cond2:
raise ValueError(
"The predicted segmentation of at least one of your deblended images "
"has values different than 0 or 1."
)
return segmentation.astype(bool)
return segmentation

def _validate_deblended_images(
Expand Down
10 changes: 9 additions & 1 deletion btk/metrics/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ def _get_data(


class IoU(SegmentationMetric):
"""Intersection-over-Union class metric."""
"""Intersection-over-Union class metric.
Note that this metric assumes that the input segmentation arrays are booleans. An error is
raised if that condition is not met.
"""

def _get_data(
self,
Expand All @@ -31,6 +35,10 @@ def _get_data(
) -> Dict[str, np.ndarray]:
assert seg1.shape == seg2.shape
assert seg1.ndim == 4 # batch, max_n_sources, x, y

if not (seg1.dtype == "bool" and seg2.dtype == "bool"):
raise TypeError("Segmentation arrays for the IoU metric should be of boolean type.")

ious = np.full((self.batch_size, seg1.shape[1]), fill_value=np.nan)
for ii in range(self.batch_size):
n_sources1 = np.sum(np.sum(seg1[ii], axis=(-1, -2)) > 0)
Expand Down
3 changes: 2 additions & 1 deletion btk/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def mse(images1: np.ndarray, images2: np.ndarray) -> np.ndarray:
def iou(seg1: np.ndarray, seg2: np.ndarray) -> np.ndarray:
"""Calculates intersection-over-union (IoU) given two semgentation arrays.
The segmentation arrays should each have values of 1 or 0s only.
This metric assumes that the input arrays are boolean. Otherwise the arrays are
casted to boolean arrays before the computation.
Args:
seg1: Array of shape `NHW` containing `N` segmentation maps each of
Expand Down
6 changes: 3 additions & 3 deletions notebooks/01-advanced-metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fd5229bcf70>"
"<matplotlib.image.AxesImage at 0x7fa7ea42f550>"
]
},
"execution_count": null,
Expand Down Expand Up @@ -297,7 +297,7 @@
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fd522a6bd30>"
"<matplotlib.image.AxesImage at 0x7fa7ea42f0d0>"
]
},
"execution_count": null,
Expand Down Expand Up @@ -378,7 +378,7 @@
"from btk.metrics.segmentation import IoU\n",
"\n",
"iou = IoU(match.batch_size)\n",
"iou(segs1, segs2) # automatically ignores unmached segmentation"
"iou(segs1.astype(bool), segs2.astype(bool)) # automatically ignores unmached segmentation"
]
},
{
Expand Down

0 comments on commit 7795588

Please sign in to comment.