-
Notifications
You must be signed in to change notification settings - Fork 0
/
yolo_engine.py
82 lines (65 loc) · 3.46 KB
/
yolo_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from datetime import datetime
import torch
from ultralytics import YOLO
import cv2
import math
from django.core.cache import cache
# Memo of YOLO models that load_model() has already built, keyed by the
# weights name/path it was called with; repeated calls reuse the same instance.
loaded_models = dict()

# Per-weights lookup tables: position i in a list is the label used for
# predicted class id i from that model.
# NOTE(review): the "yolov8n.pt" table follows the COCO-80 ordering except for
# "tank" at index 6 (stock COCO: "train") and "military vehicle" at index 79
# (stock COCO: "toothbrush") — presumably a deliberate relabelling; confirm.
model_class_names = {
    "yolov8n.pt": [
        # 0-9
        "person", "bicycle", "car", "motorbike", "aeroplane",
        "bus", "tank", "truck", "boat", "traffic light",
        # 10-19
        "fire hydrant", "stop sign", "parking meter", "bench", "bird",
        "cat", "dog", "horse", "sheep", "cow",
        # 20-29
        "elephant", "bear", "zebra", "giraffe", "backpack",
        "umbrella", "handbag", "tie", "suitcase", "frisbee",
        # 30-39
        "skis", "snowboard", "sports ball", "kite", "baseball bat",
        "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
        # 40-49
        "wine glass", "cup", "fork", "knife", "spoon",
        "bowl", "banana", "apple", "sandwich", "orange",
        # 50-59
        "broccoli", "carrot", "hot dog", "pizza", "donut",
        "cake", "chair", "sofa", "pottedplant", "bed",
        # 60-69
        "diningtable", "toilet", "tvmonitor", "laptop", "mouse",
        "remote", "keyboard", "cell phone", "microwave", "oven",
        # 70-79
        "toaster", "sink", "refrigerator", "book", "clock",
        "vase", "scissors", "teddy bear", "hair drier", "military vehicle",
    ],
    "best3.pt": ["tank", "0", "1"],
    "models/best_IDF_tank.pt": [
        "Tank", "cars", "cake", "fish", "horse",
        "sheep", "cow", "elephant", "zebra", "giraffe",
    ],
    "models/tankNZ.pt": ["tank"],
    # Register further weights files and their label lists here.
}
def load_model(model_name):
    """Return the YOLO model for *model_name*, constructing it on first use.

    Built models are memoised in the module-level ``loaded_models`` dict, so a
    second call with the same name returns the very same object without
    reloading the weights. Newly built models are moved to the CPU device.

    :param model_name: weights name/path handed to ``ultralytics.YOLO``.
    :return: the cached model instance.
    """
    if model_name in loaded_models:
        # Seen before in this process: hand back the existing instance.
        return loaded_models[model_name]

    # First request for these weights: build the model and force CPU inference.
    cpu_model = YOLO(model_name).to(torch.device('cpu'))
    loaded_models[model_name] = cpu_model
    return cpu_model
def _draw_detection(frame, x1, y1, x2, y2, label):
    """Draw a magenta box plus a filled label tab above its top-left corner, in place."""
    box_color = (255, 0, 255)
    cv2.rectangle(frame, (x1, y1), (x2, y2), box_color, 2)
    (text_w, text_h), _baseline = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, thickness=1)
    # Filled background for the text, sitting just above the box's top edge.
    cv2.rectangle(frame, (x1, y1), (x1 + text_w, y1 - text_h - 3), box_color, -1)
    cv2.putText(frame, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                (255, 255, 255), thickness=1)


def video_detection(frame, model_name, *, conf_threshold=0.7):
    """Run YOLO on one frame, annotate confident boxes and log tank sightings.

    The frame is modified in place. Every box whose confidence is at least
    *conf_threshold* gets a magenta rectangle and a ``"<label> <conf>"`` tag.
    Boxes whose label is "tank" (case-insensitive) are appended to the
    ``'detections'`` list stored in the Django cache with a 300 s timeout.

    :param frame: image passed straight to the YOLO model. Its
        ``shape[0]``/``shape[1]`` are used as height/width for clipping.
        Presumably an OpenCV (numpy) image — confirm with callers.
    :param model_name: weights name used for ``load_model`` and as the key
        into ``model_class_names``.
    :param conf_threshold: minimum confidence for a box to be drawn and
        recorded (keyword-only; defaults to the previously hard-coded 0.7).
    :return: the same ``frame`` object, annotated.
    """
    model = load_model(model_name)
    class_names = model_class_names.get(model_name, [])
    # Drawing does not change the frame's size, so clip limits are loop-invariant.
    frame_h, frame_w = frame.shape[0], frame.shape[1]

    # NOTE(review): this is a non-atomic read-modify-write of a shared cache key.
    # Frames handled concurrently by different requests can overwrite each
    # other's appends. The list is also never trimmed, and each set()
    # refreshes its timeout, so sustained sightings grow it without bound.
    # Consider capping it.
    detections = cache.get('detections', [])
    appended = False

    for result in model(frame, stream=True):
        for box in result.boxes:
            cls = int(box.cls[0].item())
            if cls >= len(class_names):
                print(f"Warning: class index {cls} is out of range for classNames")
                continue

            conf = box.conf[0].item()
            # Confidences of 1.0 or more are reported as 0.99.
            # Presumably this avoids a "1.00" label; confirm the intent.
            if conf >= 1.0:
                conf = 0.99
            if conf < conf_threshold:
                continue

            # Pixel corners, clipped to the frame bounds.
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(frame_w, x2), min(frame_h, y2)

            class_name = class_names[cls]
            _draw_detection(frame, x1, y1, x2, y2, f'{class_name} {conf:.2f}')

            # Case-insensitive: "models/best_IDF_tank.pt" labels the class "Tank".
            if class_name.lower() == 'tank':
                detections.append({
                    'class': 'tank',
                    'confidence': round(conf, 2),
                    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                })
                appended = True

    # Write back once per frame instead of once per appended detection.
    if appended:
        cache.set('detections', detections, timeout=300)
    return frame
# NOTE(review): this statement sits at module level, so it runs once when the
# module is imported, not after each frame. Nothing in this module opens an
# OpenCV HighGUI window, so it looks like a leftover no-op. HighGUI calls may
# raise on headless OpenCV builds; confirm before keeping it.
cv2.destroyAllWindows()