Skip to content

Commit

Permalink
[feat] Implement yolov9 detection post-processor
Browse files Browse the repository at this point in the history
  • Loading branch information
hglee98 committed Nov 28, 2024
1 parent 6cf0e0d commit 5b41cfd
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
44 changes: 43 additions & 1 deletion src/netspresso_trainer/postprocessors/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torchvision.models.detection._utils import BoxCoder, _topk_min
from torchvision.ops import boxes as box_ops

from netspresso_trainer.utils.bbox_utils import transform_bbox
from netspresso_trainer.utils.bbox_utils import generate_anchors, transform_bbox

from ..models.utils import ModelOutput

Expand Down Expand Up @@ -205,6 +205,45 @@ def yolo_fastest_head_decode(pred, original_shape, score_thresh=0.7, anchors=Non

return detections

def yolo_head_decode(pred, original_shape, score_thresh=0.7):
pred = pred['pred']
if isinstance(pred, dict):
pred = pred['outputs']
pred[0][0].type()
h, w = original_shape[1], original_shape[2]
device = pred[0][0].device
stage_strides= [original_shape[-1] // bbox_reg.shape[-1] for bbox_reg, _, _ in pred]
offset, scaler = generate_anchors((h, w), stage_strides)
offset = offset.to(device)
scaler = scaler.to(device)

pred_bbox_reg, pred_class_logits = [], []
for layer_output in pred:
bbox_reg, _, class_logits = layer_output
b, c, h, w = bbox_reg.shape
reg = bbox_reg.view(b, c, -1).permute(0, 2, 1)
pred_bbox_reg.append(reg)

b, c, h, w = class_logits.shape
logits = class_logits.view(b, c, -1).permute(0, 2, 1)
pred_class_logits.append(logits)

pred_bbox_reg = torch.concat(pred_bbox_reg, dim=1)
pred_class_logits = torch.concat(pred_class_logits, dim=1).sigmoid()

pred_xyxy = pred_bbox_reg * scaler.view(1, -1, 1)
lt, rb = pred_xyxy.chunk(2, dim=-1)
pred_bbox_reg = torch.cat([offset - lt, offset + rb], dim=-1)

detections = []
for bbox, cls_logits in zip(pred_bbox_reg, pred_class_logits):
class_conf, class_pred = torch.max(cls_logits, 1, keepdim=True)
conf_mask = (class_conf.squeeze() >= score_thresh).squeeze()

detections.append(
torch.cat((bbox, torch.ones_like(class_pred), class_conf, class_pred.float()), 1)[conf_mask]
)
return detections

def nms(prediction, nms_thresh=0.45, class_agnostic=False):
output = [torch.zeros(0, 7).to(prediction[0].device) for i in range(len(prediction))]
Expand Down Expand Up @@ -247,6 +286,9 @@ def __init__(self, conf_model):
elif head_name == 'yolo_fastest_head_v2':
self.decode_outputs = partial(yolo_fastest_head_decode, score_thresh=params.score_thresh, anchors=params.anchors)
self.postprocess = partial(nms, nms_thresh=params.nms_thresh, class_agnostic=params.class_agnostic)
elif head_name == 'yolo_detection_head':
self.decode_outputs = partial(yolo_head_decode, score_thresh=params.score_thresh)
self.postprocess = partial(nms, nms_thresh=params.nms_thresh, class_agnostic=params.class_agnostic)
elif head_name == 'rtdetr_head':
self.decode_outputs = partial(rtdetr_decode, num_top_queries=params.num_top_queries, score_thresh=params.score_thresh)
self.postprocess = None
Expand Down
1 change: 1 addition & 0 deletions src/netspresso_trainer/postprocessors/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'pidnet': SegmentationPostprocessor,
'anchor_decoupled_head': DetectionPostprocessor,
'yolo_fastest_head_v2': DetectionPostprocessor,
'yolo_detection_head': DetectionPostprocessor,
'rtmcc': PoseEstimationPostprocessor,
'rtdetr_head': DetectionPostprocessor,
}

0 comments on commit 5b41cfd

Please sign in to comment.