From 63077beab316798766e9a22d6db3e66170b6e75e Mon Sep 17 00:00:00 2001 From: fcakyon Date: Mon, 16 Dec 2024 19:48:44 +0300 Subject: [PATCH] Refactor and enhance Ultralytics model utilities for improved OBB handling - Reordered the import of `ObjectPrediction` in `base.py` so that local imports are sorted consistently. - Improved code readability in `ultralytics.py` by reformatting the concatenation of OBB data and the mask-resizing call. - Cleaned up whitespace in `cv.py` to enhance code clarity in the OBB to COCO conversion function. These changes contribute to a more maintainable and user-friendly codebase for the SAHI library's Ultralytics model integration. --- sahi/models/base.py | 2 +- sahi/models/ultralytics.py | 20 ++++++++++++-------- sahi/utils/cv.py | 6 +++--- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/sahi/models/base.py b/sahi/models/base.py index 2ea8bad2..a434728d 100644 --- a/sahi/models/base.py +++ b/sahi/models/base.py @@ -5,10 +5,10 @@ import numpy as np +from sahi.prediction import ObjectPrediction from sahi.utils.import_utils import is_available from sahi.utils.torch import select_device as select_torch_device -from sahi.prediction import ObjectPrediction class DetectionModel: def __init__( diff --git a/sahi/models/ultralytics.py b/sahi/models/ultralytics.py index a0646753..dd421f5c 100644 --- a/sahi/models/ultralytics.py +++ b/sahi/models/ultralytics.py @@ -89,13 +89,18 @@ def perform_inference(self, image: np.ndarray): prediction_result = [ ( # Get OBB data: xyxy, conf, cls - torch.cat([ - result.obb.xyxy, # box coordinates - result.obb.conf.unsqueeze(-1), # confidence scores - result.obb.cls.unsqueeze(-1), # class ids - ], dim=1) if result.obb is not None else torch.empty((0, 6), device=self.model.device), + torch.cat( + [ + result.obb.xyxy, # box coordinates + result.obb.conf.unsqueeze(-1), # confidence scores + result.obb.cls.unsqueeze(-1), # class ids + ], + dim=1, + ) + if result.obb is not None + else torch.empty((0, 6), 
device=self.model.device), # Get OBB points in (N, 4, 2) format - result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=self.model.device) + result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=self.model.device), ) for result in prediction_result ] @@ -197,8 +202,7 @@ def _create_object_prediction_list_from_original_predictions( bool_mask = masks_or_points[pred_ind] # Resize mask to original image size bool_mask = cv2.resize( - bool_mask.astype(np.uint8), - (self._original_shape[1], self._original_shape[0]) + bool_mask.astype(np.uint8), (self._original_shape[1], self._original_shape[0]) ) segmentation = get_coco_segmentation_from_bool_mask(bool_mask) else: # is_obb diff --git a/sahi/utils/cv.py b/sahi/utils/cv.py index 80ea1b49..693097aa 100644 --- a/sahi/utils/cv.py +++ b/sahi/utils/cv.py @@ -691,7 +691,7 @@ def get_bbox_from_coco_segmentation(coco_segmentation): def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[float]]: """ Convert OBB (Oriented Bounding Box) points to COCO polygon format. - + Args: obb_points: np.ndarray OBB points tensor from ultralytics.engine.results.OBB @@ -702,13 +702,13 @@ def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[f """ # Convert from (4,2) to [x1,y1,x2,y2,x3,y3,x4,y4] format points = obb_points.reshape(-1).tolist() - + # Create polygon from points and close it by repeating first point polygons = [] # Add first point to end to close polygon closed_polygon = points + [points[0], points[1]] polygons.append(closed_polygon) - + return polygons