Skip to content

Commit

Permalink
[fix] yolov9 postprocessor logic for correct outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
hglee98 committed Dec 6, 2024
1 parent 6c6a358 commit 86d4149
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/netspresso_trainer/postprocessors/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def yolo_head_decode(pred, original_shape, score_thresh=0.7):
pred = pred['pred']
if isinstance(pred, dict):
pred = pred['outputs']
pred[0][0].type()
h, w = original_shape[1], original_shape[2]
device = pred[0][0].device
stage_strides= [original_shape[-1] // bbox_reg.shape[-1] for bbox_reg, _, _ in pred]
Expand All @@ -221,11 +220,11 @@ def yolo_head_decode(pred, original_shape, score_thresh=0.7):
for layer_output in pred:
bbox_reg, _, class_logits = layer_output
b, c, h, w = bbox_reg.shape
reg = bbox_reg.view(b, c, -1).permute(0, 2, 1)
reg = bbox_reg.permute(0, 2, 3, 1).view(b, h*w, c)
pred_bbox_reg.append(reg)

b, c, h, w = class_logits.shape
logits = class_logits.view(b, c, -1).permute(0, 2, 1)
logits = class_logits.permute(0, 2, 3, 1).view(b, h*w, c)
pred_class_logits.append(logits)

pred_bbox_reg = torch.concat(pred_bbox_reg, dim=1)
Expand Down

0 comments on commit 86d4149

Please sign in to comment.