There is an issue with `batched_nms` when there is only one detection. In line … and the following line, the …
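Here is a minimal reproduction of the single-detection shape problem. It assumes the predictions tensor has shape `[num_boxes, 6]`, with the score in column 4 and the category id in column 5, which is how `batched_nms` indexes it. The box values are made up for illustration.

```python
import torch

# A single detection: [x1, y1, x2, y2, score, category_id] (assumed layout)
predictions = torch.tensor([[10.0, 10.0, 50.0, 50.0, 0.9, 0.0]])

squeezed = predictions[:, 5].squeeze()  # shape [1] collapses to a 0-d tensor
plain = predictions[:, 5]               # stays 1-d with shape [1]
print(squeezed.shape, plain.shape)      # torch.Size([]) torch.Size([1])

# A boolean mask built with zeros_like from the 0-d tensor is also 0-d, so
# indexing it with a 1-d index tensor (as keep_mask[curr_indices[...]] = True
# does) raises an IndexError.
mask = torch.zeros_like(squeezed, dtype=torch.bool)
try:
    mask[torch.tensor([0])] = True
    failed = False
except IndexError as err:
    failed = True
    print("IndexError:", err)
```

Without `.squeeze()`, the column keeps its shape `[1]` and the same indexing works.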
Answered by SunHao-AI, Dec 16, 2024
I encountered the same problem. Removing the `.squeeze()` calls fixes it; with the following version the code runs:

```python
import torch


def batched_nms(predictions: torch.Tensor, match_metric: str = "IOU", match_threshold: float = 0.5):
    """
    Apply non-maximum suppression per category to avoid detecting too many
    overlapping bounding boxes for a given object.

    Args:
        predictions: (tensor) The location preds for the image along with the
            class pred scores and category ids, Shape: [num_boxes, 6].
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap threshold for the match metric.
    Returns:
        A list of kept indices, sorted by descending score.
    """
    # Previously: scores = predictions[:, 4].squeeze()
    #             category_ids = predictions[:, 5].squeeze()
    # With a single detection, .squeeze() turns the shape-[1] columns into
    # 0-d tensors, and indexing keep_mask below then fails.
    scores = predictions[:, 4]
    category_ids = predictions[:, 5]
    keep_mask = torch.zeros_like(category_ids, dtype=torch.bool)
    for category_id in torch.unique(category_ids):
        curr_indices = torch.where(category_ids == category_id)[0]
        # `nms` is the single-class NMS defined in the same module as batched_nms;
        # it returns indices into predictions[curr_indices]
        curr_keep_indices = nms(predictions[curr_indices], match_metric, match_threshold)
        keep_mask[curr_indices[curr_keep_indices]] = True
    keep_indices = torch.where(keep_mask)[0]
    # sort selected indices by their scores
    keep_indices = keep_indices[scores[keep_indices].sort(descending=True)[1]].tolist()
    return keep_indices
```
Answer selected by fcakyon