From 346de6fd51f0ce70d61dffab0cbe894b1ff86dbe Mon Sep 17 00:00:00 2001 From: fcakyon Date: Mon, 16 Dec 2024 17:55:02 +0300 Subject: [PATCH] Add YOLO11 model predictions to test suite - Introduced tests for YOLO11 model predictions in `test_predict.py`, including both full-sized and sliced predictions. - Implemented model initialization and image preparation for testing. - Added assertions to verify the detection results for different object categories (person, truck, car). - Enhanced the test suite to ensure compatibility with the new Ultralytics model utilities. These additions improve the coverage of the testing framework for the Ultralytics YOLO11 model and ensure accurate prediction functionality. --- tests/test_predict.py | 153 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) diff --git a/tests/test_predict.py b/tests/test_predict.py index eb108039a..f7c123c2f 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -8,6 +8,7 @@ import numpy as np from sahi.utils.cv import read_image +from sahi.utils.ultralytics import UltralyticsTestConstants, download_yolo11n_model MODEL_DEVICE = "cpu" CONFIDENCE_THRESHOLD = 0.5 @@ -294,10 +295,123 @@ def test_get_sliced_prediction_yolov5(self): num_car += 1 self.assertEqual(num_car, 11) + def test_get_prediction_yolo11(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + from sahi.predict import get_prediction + + # init model + download_yolo11n_model() + + yolo11_detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=False, + image_size=IMAGE_SIZE, + ) + yolo11_detection_model.load_model() + + # prepare image + image_path = "tests/data/small-vehicles1.jpeg" + image = read_image(image_path) + + # get full sized prediction + prediction_result = get_prediction( + image=image, + detection_model=yolo11_detection_model, + shift_amount=[0, 0], + full_shape=None, + postprocess=None + ) + object_prediction_list = prediction_result.object_prediction_list + + # compare + self.assertGreater(len(object_prediction_list), 0) + num_person = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "person": + num_person += 1 + self.assertEqual(num_person, 0) + num_truck = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "truck": + num_truck += 1 + self.assertEqual(num_truck, 0) + num_car = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "car": + num_car += 1 + self.assertGreater(num_car, 0) + + def test_get_sliced_prediction_yolo11(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + from sahi.predict import get_sliced_prediction + + # init model + download_yolo11n_model() + + yolo11_detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=False, + image_size=IMAGE_SIZE, + ) + yolo11_detection_model.load_model() + + # prepare image + image_path = "tests/data/small-vehicles1.jpeg" + + slice_height = 512 + slice_width = 512 + overlap_height_ratio = 0.1 + overlap_width_ratio = 0.2 + postprocess_type = "GREEDYNMM" + match_metric = "IOS" + match_threshold = 0.5 + class_agnostic = True + + # get sliced prediction + prediction_result = get_sliced_prediction( + image=image_path, + detection_model=yolo11_detection_model, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + perform_standard_pred=False, + postprocess_type=postprocess_type, + postprocess_match_threshold=match_threshold, + postprocess_match_metric=match_metric, + postprocess_class_agnostic=class_agnostic, + ) + object_prediction_list = prediction_result.object_prediction_list + + # compare + self.assertGreater(len(object_prediction_list), 0) + num_person = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "person": + num_person += 1 + self.assertEqual(num_person, 0) + num_truck = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "truck": + num_truck += 1 + self.assertEqual(num_truck, 0) + num_car = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "car": + num_car += 1 + self.assertGreater(num_car, 0) + def test_coco_json_prediction(self): from sahi.predict import predict from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_yolox_tiny_model from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model + from sahi.utils.ultralytics import UltralyticsTestConstants, download_yolo11n_model # init model download_mmdet_yolox_tiny_model() @@ -382,6 +496,45 @@ def test_coco_json_prediction(self): verbose=1, ) + # init model + download_yolo11n_model() + + # prepare paths + dataset_json_path = "tests/data/coco_utils/terrain_all_coco.json" + source = "tests/data/coco_utils/" + project_dir = "tests/data/predict_result" + + # get sliced prediction + if os.path.isdir(project_dir): + shutil.rmtree(project_dir, ignore_errors=True) + predict( + model_type="ultralytics", + model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH, + model_config_path=None, + model_confidence_threshold=CONFIDENCE_THRESHOLD, + model_device=MODEL_DEVICE, + model_category_mapping=None, + model_category_remapping=None, + source=source, + no_sliced_prediction=False, + no_standard_prediction=True, + slice_height=512, + slice_width=512, + overlap_height_ratio=0.2, + overlap_width_ratio=0.2, + postprocess_type=postprocess_type, + postprocess_match_metric=match_metric, + postprocess_match_threshold=match_threshold, + postprocess_class_agnostic=class_agnostic, + novisual=True, + export_pickle=False, + export_crop=False, + dataset_json_path=dataset_json_path, + project=project_dir, + name="exp", + verbose=1, + ) + def test_video_prediction(self): from os import path