diff --git a/src/netspresso_trainer/postprocessors/detection.py b/src/netspresso_trainer/postprocessors/detection.py index 8d8abc14..604fc699 100644 --- a/src/netspresso_trainer/postprocessors/detection.py +++ b/src/netspresso_trainer/postprocessors/detection.py @@ -33,7 +33,7 @@ def rtdetr_decode(pred, original_shape, num_top_queries=300, score_thresh=0.0): num_classes = logits.shape[-1] h, w = original_shape[1], original_shape[2] - boxes = transform_bbox(boxes, "cxcywhn -> xyxy", img_size=(w, h)) + boxes = transform_bbox(boxes, "cxcywhn -> xyxy", image_shape=(h, w)) scores = torch.sigmoid(logits) scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1) diff --git a/src/netspresso_trainer/utils/bbox_utils.py b/src/netspresso_trainer/utils/bbox_utils.py index a7b73f15..d76fa0b0 100644 --- a/src/netspresso_trainer/utils/bbox_utils.py +++ b/src/netspresso_trainer/utils/bbox_utils.py @@ -10,7 +10,7 @@ def transform_bbox(bboxes: Union[Tensor, Proxy], indicator="xywh -> xyxy", - img_size: Optional[Union[int, Tuple[int, int]]]=None): + image_shape: Optional[Union[int, Tuple[int, int]]]=None): def is_normalized(fmt: str) -> bool: return fmt.endswith('n') @@ -21,16 +21,16 @@ def is_normalized(fmt: str) -> bool: assert out_type in VALID_OUT_TYPE, f"Invalid out_type: '{out_type}'. Must be one of {VALID_OUT_TYPE}." if is_normalized(in_type): - assert img_size is not None, f"img_size is required for normalized conversion: {indicator}" - if isinstance(img_size, int): - img_width = img_height = img_size + assert image_shape is not None, f"image_shape is required for normalized conversion: {indicator}" + if isinstance(image_shape, int): + img_height = img_width = image_shape else: - img_width, img_height = img_size - assert isinstance(img_width, int) and isinstance(img_height, int), \ - f"Invalid type: (width: {type(img_width)}, height: {type(img_height)}. 
Must be (int, int))" + img_height, img_width = image_shape + assert isinstance(img_height, int) and isinstance(img_width, int), \ + f"Invalid type: (height: {type(img_height)}, width: {type(img_width)}. Must be (int, int))" in_type = in_type[:-1] else: - img_width = img_height = 1.0 + img_height = img_width = 1.0 if in_type == "xyxy": x_min, y_min, x_max, y_max = bboxes.unbind(-1) @@ -53,16 +53,16 @@ def is_normalized(fmt: str) -> bool: assert (y_max >= y_min).all(), "Invalid box: y_max < y_min" if is_normalized(out_type): - assert img_size is not None, f"img_size is required for normalized conversion: {indicator}" - if isinstance(img_size, int): - img_width = img_height = img_size + assert image_shape is not None, f"image_shape is required for normalized conversion: {indicator}" + if isinstance(image_shape, int): + img_height = img_width = image_shape else: - img_width, img_height = img_size - assert isinstance(img_width, int) and isinstance(img_height, int), \ - f"Invalid type: (width: {type(img_width)}, height: {type(img_height)}. Must be (int, int))" + img_height, img_width = image_shape + assert isinstance(img_height, int) and isinstance(img_width, int), \ + f"Invalid type: (height: {type(img_height)}, width: {type(img_width)}. Must be (int, int))" out_type = out_type[:-1] else: - img_width = img_height = 1.0 + img_height = img_width = 1.0 x_min /= img_width y_min /= img_height