Skip to content

Commit

Permalink
Refactor and enhance Ultralytics model utilities for improved OBB han…
Browse files Browse the repository at this point in the history
…dling

- Reintroduced the import of `ObjectPrediction` in `base.py` for better modularity.
- Improved code readability in `ultralytics.py` by restructuring the concatenation of OBB data and resizing masks.
- Cleaned up whitespace in `cv.py` to enhance code clarity in the OBB to COCO conversion function.

These changes contribute to a more maintainable and user-friendly codebase for the SAHI library's Ultralytics model integration.
  • Loading branch information
fcakyon committed Dec 16, 2024
1 parent 0271ba3 commit 63077be
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sahi/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import numpy as np

from sahi.prediction import ObjectPrediction
from sahi.utils.import_utils import is_available
from sahi.utils.torch import select_device as select_torch_device

from sahi.prediction import ObjectPrediction

class DetectionModel:
def __init__(
Expand Down
20 changes: 12 additions & 8 deletions sahi/models/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,18 @@ def perform_inference(self, image: np.ndarray):
prediction_result = [
(
# Get OBB data: xyxy, conf, cls
torch.cat([
result.obb.xyxy, # box coordinates
result.obb.conf.unsqueeze(-1), # confidence scores
result.obb.cls.unsqueeze(-1), # class ids
], dim=1) if result.obb is not None else torch.empty((0, 6), device=self.model.device),
torch.cat(
[
result.obb.xyxy, # box coordinates
result.obb.conf.unsqueeze(-1), # confidence scores
result.obb.cls.unsqueeze(-1), # class ids
],
dim=1,
)
if result.obb is not None
else torch.empty((0, 6), device=self.model.device),
# Get OBB points in (N, 4, 2) format
result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=self.model.device)
result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=self.model.device),
)
for result in prediction_result
]
Expand Down Expand Up @@ -197,8 +202,7 @@ def _create_object_prediction_list_from_original_predictions(
bool_mask = masks_or_points[pred_ind]
# Resize mask to original image size
bool_mask = cv2.resize(
bool_mask.astype(np.uint8),
(self._original_shape[1], self._original_shape[0])
bool_mask.astype(np.uint8), (self._original_shape[1], self._original_shape[0])
)
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
else: # is_obb
Expand Down
6 changes: 3 additions & 3 deletions sahi/utils/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def get_bbox_from_coco_segmentation(coco_segmentation):
def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[float]]:
"""
Convert OBB (Oriented Bounding Box) points to COCO polygon format.
Args:
obb_points: np.ndarray
OBB points tensor from ultralytics.engine.results.OBB
Expand All @@ -702,13 +702,13 @@ def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[f
"""
# Convert from (4,2) to [x1,y1,x2,y2,x3,y3,x4,y4] format
points = obb_points.reshape(-1).tolist()

# Create polygon from points and close it by repeating first point
polygons = []
# Add first point to end to close polygon
closed_polygon = points + [points[0], points[1]]
polygons.append(closed_polygon)

return polygons


Expand Down

0 comments on commit 63077be

Please sign in to comment.