From 3dd503c255825226da883314d3bc3a87c3de0480 Mon Sep 17 00:00:00 2001 From: cho-96 Date: Fri, 28 Jun 2024 15:39:36 +0900 Subject: [PATCH] input validation --- app.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/app.py b/app.py index 6fddcbaa1..6c095681d 100644 --- a/app.py +++ b/app.py @@ -7,10 +7,12 @@ def yolov10_inference(image, video, model_id, image_size, conf_threshold): model = YOLOv10.from_pretrained(f'jameslahm/{model_id}') if image: + assert image is not None, "Image input is required." results = model.predict(source=image, imgsz=image_size, conf=conf_threshold) annotated_image = results[0].plot() return annotated_image[:, :, ::-1], None else: + assert video is not None, "Video input is required." video_path = tempfile.mktemp(suffix=".webm") with open(video_path, "wb") as f: with open(video, "rb") as g: @@ -90,7 +92,7 @@ def app(): def update_visibility(input_type): image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False) video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True) - output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False) + output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible(False)) output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True) return image, video, output_image, output_video @@ -102,6 +104,10 @@ def update_visibility(input_type): ) def run_inference(image, video, model_id, image_size, conf_threshold, input_type): + assert model_id in ["yolov10n", "yolov10s", "yolov10m", "yolov10b", "yolov10l", "yolov10x"], "Invalid model ID." + assert 320 <= image_size <= 1280, "Image size out of range." + assert 0.0 <= conf_threshold <= 1.0, "Confidence threshold out of range." + if input_type == "Image": return yolov10_inference(image, None, model_id, image_size, conf_threshold) else: