Skip to content

Commit

Permalink
[fix] from img_size (w, h) to image_shape (h, w)
Browse files Browse the repository at this point in the history
  • Loading branch information
hglee98 committed Nov 15, 2024
1 parent af737d7 commit 9d20bb6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/netspresso_trainer/postprocessors/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def rtdetr_decode(pred, original_shape, num_top_queries=300, score_thresh=0.0):

num_classes = logits.shape[-1]
h, w = original_shape[1], original_shape[2]
boxes = transform_bbox(boxes, "cxcywhn -> xyxy", img_size=(w, h))
boxes = transform_bbox(boxes, "cxcywhn -> xyxy", image_shape=(h, w))

scores = torch.sigmoid(logits)
scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1)
Expand Down
30 changes: 15 additions & 15 deletions src/netspresso_trainer/utils/bbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def transform_bbox(bboxes: Union[Tensor, Proxy],
indicator="xywh -> xyxy",
img_size: Optional[Union[int, Tuple[int, int]]]=None):
image_shape: Optional[Union[int, Tuple[int, int]]]=None):
def is_normalized(fmt: str) -> bool:
return fmt.endswith('n')

Expand All @@ -21,16 +21,16 @@ def is_normalized(fmt: str) -> bool:
assert out_type in VALID_OUT_TYPE, f"Invalid out_type: '{out_type}'. Must be one of {VALID_OUT_TYPE}."

if is_normalized(in_type):
assert img_size is not None, f"img_size is required for normalized conversion: {indicator}"
if isinstance(img_size, int):
img_width = img_height = img_size
assert image_shape is not None, f"image_shape is required for normalized conversion: {indicator}"
if isinstance(image_shape, int):
img_height = img_width = image_shape
else:
img_width, img_height = img_size
assert isinstance(img_width, int) and isinstance(img_height, int), \
f"Invalid type: (width: {type(img_width)}, height: {type(img_height)}. Must be (int, int))"
img_height, img_width = image_shape
assert isinstance(img_height, int) and isinstance(img_width, int), \
f"Invalid type: (height: {type(img_height)}, width: {type(img_width)}. Must be (int, int))"
in_type = in_type[:-1]
else:
img_width = img_height = 1.0
img_height = img_width = 1.0

if in_type == "xyxy":
x_min, y_min, x_max, y_max = bboxes.unbind(-1)
Expand All @@ -53,16 +53,16 @@ def is_normalized(fmt: str) -> bool:
assert (y_max >= y_min).all(), "Invalid box: y_max < y_min"

if is_normalized(out_type):
assert img_size is not None, f"img_size is required for normalized conversion: {indicator}"
if isinstance(img_size, int):
img_width = img_height = img_size
assert image_shape is not None, f"img_size is required for normalized conversion: {indicator}"
if isinstance(image_shape, int):
img_height = img_width = image_shape
else:
img_width, img_height = img_size
assert isinstance(img_width, int) and isinstance(img_height, int), \
f"Invalid type: (width: {type(img_width)}, height: {type(img_height)}. Must be (int, int))"
img_height, img_width = image_shape
assert isinstance(img_height, int) and isinstance(img_width, int), \
f"Invalid type: (height: {type(img_height)}, width: {type(img_width)}. Must be (int, int))"
out_type = out_type[:-1]
else:
img_width = img_height = 1.0
img_height = img_width = 1.0

x_min /= img_width
y_min /= img_height
Expand Down

0 comments on commit 9d20bb6

Please sign in to comment.