Skip to content

Commit

Permalink
[fix] handle error when certain classes have no object in dataset due…
Browse files Browse the repository at this point in the history
… to the class imbalance problem.
  • Loading branch information
hglee98 committed Oct 29, 2024
1 parent 8c83b70 commit 6313580
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/netspresso_trainer/metrics/detection/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def average_precisions_per_class(
prediction_confidence: np.ndarray,
prediction_class_ids: np.ndarray,
true_class_ids: np.ndarray,
num_classes: int = 80,
eps: float = 1e-16,
) -> np.ndarray:
"""
Expand All @@ -143,6 +144,7 @@ def average_precisions_per_class(
prediction_confidence (np.ndarray): Objectness value from 0-1.
prediction_class_ids (np.ndarray): Predicted object classes.
true_class_ids (np.ndarray): True object classes.
num_classes (int): The number of classes.
eps (float, optional): Small value to prevent division by zero.
Returns:
Expand All @@ -153,7 +155,6 @@ def average_precisions_per_class(
prediction_class_ids = prediction_class_ids[sorted_indices]

unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
num_classes = unique_classes.shape[0]

average_precisions = np.zeros((num_classes, matches.shape[1]))

Expand All @@ -172,7 +173,7 @@ def average_precisions_per_class(

for iou_level_idx in range(matches.shape[1]):
average_precisions[
class_idx, iou_level_idx
int(class_id), iou_level_idx
] = compute_average_precision(
recall[:, iou_level_idx], precision[:, iou_level_idx]
)
Expand Down Expand Up @@ -233,7 +234,7 @@ def calibrate(self, predictions, targets, **kwargs):
# Compute average precisions if any matches exist
if stats:
concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
average_precisions = average_precisions_per_class(*concatenated_stats)
average_precisions = average_precisions_per_class(*concatenated_stats, num_classes=self.num_classes)

if self.classwise_analysis:
for i, classwise_meter in enumerate(self.classwise_metric_meters):
Expand All @@ -255,7 +256,7 @@ def calibrate(self, predictions, targets, **kwargs):
# Compute average precisions if any matches exist
if stats:
concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
average_precisions = average_precisions_per_class(*concatenated_stats)
average_precisions = average_precisions_per_class(*concatenated_stats, num_classes=self.num_classes)

if self.classwise_analysis:
for i, classwise_meter in enumerate(self.classwise_metric_meters):
Expand All @@ -277,7 +278,7 @@ def calibrate(self, predictions, targets, **kwargs):
# Compute average precisions if any matches exist
if stats:
concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
average_precisions = average_precisions_per_class(*concatenated_stats)
average_precisions = average_precisions_per_class(*concatenated_stats, num_classes=self.num_classes)

if self.classwise_analysis:
for i, classwise_meter in enumerate(self.classwise_metric_meters):
Expand Down

0 comments on commit 6313580

Please sign in to comment.