Skip to content

Commit

Permalink
Add YOLO11 model predictions to test suite
Browse files Browse the repository at this point in the history
- Introduced tests for YOLO11 model predictions in `test_predict.py`, including both full-sized and sliced predictions.
- Implemented model initialization and image preparation for testing.
- Added assertions to verify the detection results for different object categories (person, truck, car).
- Enhanced the test suite to ensure compatibility with the new Ultralytics model utilities.

These additions improve the coverage of the testing framework for the Ultralytics YOLO11 model and ensure accurate prediction functionality.
  • Loading branch information
fcakyon committed Dec 16, 2024
1 parent 59ae6e3 commit 346de6f
Showing 1 changed file with 153 additions and 0 deletions.
153 changes: 153 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

from sahi.utils.cv import read_image
from sahi.utils.ultralytics import UltralyticsTestConstants, download_yolo11n_model

MODEL_DEVICE = "cpu"
CONFIDENCE_THRESHOLD = 0.5
Expand Down Expand Up @@ -294,10 +295,123 @@ def test_get_sliced_prediction_yolov5(self):
num_car += 1
self.assertEqual(num_car, 11)

def test_get_prediction_yolo11(self):
from sahi.models.ultralytics import UltralyticsDetectionModel
from sahi.predict import get_prediction

# init model
download_yolo11n_model()

yolo11_detection_model = UltralyticsDetectionModel(
model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=False,
image_size=IMAGE_SIZE,
)
yolo11_detection_model.load_model()

# prepare image
image_path = "tests/data/small-vehicles1.jpeg"
image = read_image(image_path)

# get full sized prediction
prediction_result = get_prediction(
image=image,
detection_model=yolo11_detection_model,
shift_amount=[0, 0],
full_shape=None,
postprocess=None
)
object_prediction_list = prediction_result.object_prediction_list

# compare
self.assertGreater(len(object_prediction_list), 0)
num_person = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "person":
num_person += 1
self.assertEqual(num_person, 0)
num_truck = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "truck":
num_truck += 1
self.assertEqual(num_truck, 0)
num_car = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "car":
num_car += 1
self.assertGreater(num_car, 0)

def test_get_sliced_prediction_yolo11(self):
from sahi.models.ultralytics import UltralyticsDetectionModel
from sahi.predict import get_sliced_prediction

# init model
download_yolo11n_model()

yolo11_detection_model = UltralyticsDetectionModel(
model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=False,
image_size=IMAGE_SIZE,
)
yolo11_detection_model.load_model()

# prepare image
image_path = "tests/data/small-vehicles1.jpeg"

slice_height = 512
slice_width = 512
overlap_height_ratio = 0.1
overlap_width_ratio = 0.2
postprocess_type = "GREEDYNMM"
match_metric = "IOS"
match_threshold = 0.5
class_agnostic = True

# get sliced prediction
prediction_result = get_sliced_prediction(
image=image_path,
detection_model=yolo11_detection_model,
slice_height=slice_height,
slice_width=slice_width,
overlap_height_ratio=overlap_height_ratio,
overlap_width_ratio=overlap_width_ratio,
perform_standard_pred=False,
postprocess_type=postprocess_type,
postprocess_match_threshold=match_threshold,
postprocess_match_metric=match_metric,
postprocess_class_agnostic=class_agnostic,
)
object_prediction_list = prediction_result.object_prediction_list

# compare
self.assertGreater(len(object_prediction_list), 0)
num_person = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "person":
num_person += 1
self.assertEqual(num_person, 0)
num_truck = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "truck":
num_truck += 1
self.assertEqual(num_truck, 0)
num_car = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "car":
num_car += 1
self.assertGreater(num_car, 0)

def test_coco_json_prediction(self):
from sahi.predict import predict
from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_yolox_tiny_model
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model
from sahi.utils.ultralytics import UltralyticsTestConstants, download_yolo11n_model

# init model
download_mmdet_yolox_tiny_model()
Expand Down Expand Up @@ -382,6 +496,45 @@ def test_coco_json_prediction(self):
verbose=1,
)

# init model
download_yolo11n_model()

# prepare paths
dataset_json_path = "tests/data/coco_utils/terrain_all_coco.json"
source = "tests/data/coco_utils/"
project_dir = "tests/data/predict_result"

# get sliced prediction
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
predict(
model_type="ultralytics",
model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH,
model_config_path=None,
model_confidence_threshold=CONFIDENCE_THRESHOLD,
model_device=MODEL_DEVICE,
model_category_mapping=None,
model_category_remapping=None,
source=source,
no_sliced_prediction=False,
no_standard_prediction=True,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2,
postprocess_type=postprocess_type,
postprocess_match_metric=match_metric,
postprocess_match_threshold=match_threshold,
postprocess_class_agnostic=class_agnostic,
novisual=True,
export_pickle=False,
export_crop=False,
dataset_json_path=dataset_json_path,
project=project_dir,
name="exp",
verbose=1,
)

def test_video_prediction(self):
from os import path

Expand Down

0 comments on commit 346de6f

Please sign in to comment.