-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
89 lines (80 loc) · 3.23 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import time
import cv2
import numpy as np
from PIL import Image
from yolo import YOLO
if __name__ == "__main__":
    # Script entry point: build the detector once, then run it in one of four
    # modes selected by `mode`:
    #   "predict"     - interactively detect on single images typed by the user
    #   "video"       - detect on frames from a camera or video source
    #   "fps"         - measure detection speed on a fixed image
    #   "dir_predict" - detect on every image in a folder and save the results
    yolo = YOLO()
    mode = "predict"
    # Source handed to cv2.VideoCapture; 0 means detect from the camera.
    video_path = 0
    # Output video file; an empty string disables saving.
    video_save_path = ""
    video_fps = 25.0
    # Number of detections performed on the image when measuring fps.
    test_interval = 100
    dir_origin_path = "img/"
    dir_save_path = "img_out/"

    # "predict": single-image detection, repeated until the process is stopped.
    if mode == "predict":
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except Exception:
                # Any failure to open the file is reported and the user is
                # asked again. `Exception` (rather than a bare except) lets
                # KeyboardInterrupt / SystemExit still end the program.
                print('Open Error! Try again!')
                continue
            r_image = yolo.detect_image(image)
            r_image.show()

    # "video": frame-by-frame detection on a camera or video source.
    elif mode == "video":
        capture = cv2.VideoCapture(video_path)
        # The writer exists only when saving is requested; None otherwise so
        # the later write/release steps can be skipped safely.
        out = None
        if video_save_path != "":
            # Choose the codec used for the saved file.
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            size = (
                int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            )
            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        fps = 0.0
        try:
            while True:
                t1 = time.time()
                # Grab the next frame.
                ref, frame = capture.read()
                if not ref:
                    # No frame was returned (source ended or could not be
                    # read); stop instead of passing None to cvtColor.
                    print('Failed to read a frame from the video source; stopping.')
                    break
                # Convert BGR -> RGB before handing the frame to PIL.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Wrap the array as a PIL Image.
                frame = Image.fromarray(np.uint8(frame))
                # Run detection and go back to an array.
                frame = np.array(yolo.detect_image(frame))
                # Convert RGB -> BGR, the layout OpenCV expects for display.
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                # Average the previous estimate with this frame's rate.
                fps = (fps + (1. / (time.time() - t1))) / 2
                print("fps= %.2f" % (fps))
                frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                cv2.imshow("video", frame)
                c = cv2.waitKey(1) & 0xff
                if out is not None:
                    out.write(frame)

                # Key code 27 (ESC) ends the loop.
                if c == 27:
                    break
        finally:
            # Release resources exactly once, even if the loop raised.
            capture.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

    # "fps": measure detection speed on a fixed test image.
    elif mode == "fps":
        img = Image.open('img/street.jpg')
        tact_time = yolo.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')

    # "dir_predict": detect on every image file in a folder and save the output.
    elif mode == "dir_predict":
        # Imported here so the other modes do not require tqdm.
        import os
        from tqdm import tqdm

        img_names = os.listdir(dir_origin_path)
        for img_name in tqdm(img_names):
            # Only files with a recognised image extension are processed.
            if img_name.lower().endswith(
                    ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
                image_path = os.path.join(dir_origin_path, img_name)
                image = Image.open(image_path)
                r_image = yolo.detect_image(image)
                # Create the output folder on first use; exist_ok avoids the
                # check-then-create race of a separate os.path.exists test.
                os.makedirs(dir_save_path, exist_ok=True)
                r_image.save(os.path.join(dir_save_path, img_name))

    else:
        raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")