# infer.py
# Forked from yeyupiaoling/Pytorch-MobileFaceNet
import argparse
import functools
import os
import time
import cv2
import numpy as np
import torch
from PIL import ImageDraw, ImageFont, Image
from detection.face_detect import MTCNN
from utils.utils import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('image_path',               str,   'dataset/test.jpg',             'Path of the image to predict')
add_arg('face_db_path',             str,   'face_db',                      'Path of the face database directory')
add_arg('threshold',                float, 0.6,                            'Similarity threshold for recognition')
add_arg('mobilefacenet_model_path', str,   'save_model/mobilefacenet.pth', 'Path of the MobileFaceNet inference model')
add_arg('mtcnn_model_path',         str,   'save_model/mtcnn',             'Path of the MTCNN inference model')
args = parser.parse_args()
print_arguments(args)


class Predictor:
    def __init__(self, mtcnn_model_path, mobilefacenet_model_path, face_db_path, threshold=0.7):
        self.threshold = threshold
        self.mtcnn = MTCNN(model_path=mtcnn_model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the TorchScript MobileFaceNet model
        self.model = torch.jit.load(mobilefacenet_model_path)
        self.model.to(self.device)
        self.model.eval()
        self.faces_db = self.load_face_db(face_db_path)

    def load_face_db(self, face_db_path):
        faces_db = {}
        for path in os.listdir(face_db_path):
            # The file name (without extension) is used as the person's name
            name = os.path.basename(path).split('.')[0]
            image_path = os.path.join(face_db_path, path)
            img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), -1)
            imgs, _ = self.mtcnn.infer_image(img)
            if imgs is None or len(imgs) > 1:
                print('Image %s in the face database does not contain exactly one face, skipping it' % image_path)
                continue
            imgs = self.process(imgs)
            feature = self.infer(imgs[0])
            faces_db[name] = feature[0][0]
        return faces_db

    @staticmethod
    def process(imgs):
        # Transpose HWC -> CHW and normalize pixel values to [-1, 1]
        imgs1 = []
        for img in imgs:
            img = img.transpose((2, 0, 1))
            img = (img - 127.5) / 127.5
            imgs1.append(img)
        return imgs1

    # Predict the feature vectors of one or more face images
    def infer(self, imgs):
        assert len(imgs.shape) == 3 or len(imgs.shape) == 4
        if len(imgs.shape) == 3:
            imgs = imgs[np.newaxis, :]
        # TODO Batched prediction did not work for unknown reasons
        # (note that the commented-out attempt below passed `img` instead of `imgs`),
        # so images are fed through the model one at a time.
        '''
        imgs = torch.tensor(imgs, dtype=torch.float32, device=self.device)
        features = self.model(imgs)
        features = features.detach().cpu().numpy()
        '''
        features = []
        for i in range(imgs.shape[0]):
            img = imgs[i][np.newaxis, :]
            img = torch.tensor(img, dtype=torch.float32, device=self.device)
            # Run the model on a single image
            feature = self.model(img)
            feature = feature.detach().cpu().numpy()
            features.append(feature)
        return features

    def recognition(self, image_path):
        img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), -1)
        s = time.time()
        imgs, boxes = self.mtcnn.infer_image(img)
        print('Face detection time: %dms' % int((time.time() - s) * 1000))
        if imgs is None:
            return None, None
        imgs = self.process(imgs)
        imgs = np.array(imgs, dtype='float32')
        s = time.time()
        features = self.infer(imgs)
        print('Face recognition time: %dms' % int((time.time() - s) * 1000))
        names = []
        probs = []
        for i in range(len(features)):
            feature = features[i][0]
            results_dict = {}
            for name in self.faces_db.keys():
                feature1 = self.faces_db[name]
                # Cosine similarity between the detected face and the database face
                prob = np.dot(feature, feature1) / (np.linalg.norm(feature) * np.linalg.norm(feature1))
                results_dict[name] = prob
            results = sorted(results_dict.items(), key=lambda d: d[1], reverse=True)
            print('Face comparison results:', results)
            result = results[0]
            prob = float(result[1])
            probs.append(prob)
            if prob > self.threshold:
                name = result[0]
                names.append(name)
            else:
                names.append('unknown')
        return boxes, names

    def add_text(self, img, text, left, top, color=(0, 0, 0), size=20):
        # Draw text with PIL so that non-ASCII names render correctly (requires simfang.ttf)
        if isinstance(img, np.ndarray):
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype('simfang.ttf', size)
        draw.text((left, top), text, color, font=font)
        return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

    # Draw the face bounding boxes and recognized names
    def draw_face(self, image_path, boxes_c, names):
        img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), -1)
        if boxes_c is not None:
            for i in range(boxes_c.shape[0]):
                bbox = boxes_c[i, :4]
                name = names[i]
                corpbbox = [int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])]
                # Draw the face bounding box
                cv2.rectangle(img, (corpbbox[0], corpbbox[1]),
                              (corpbbox[2], corpbbox[3]), (255, 0, 0), 1)
                # Draw the recognized name above the box
                img = self.add_text(img, name, corpbbox[0], corpbbox[1] - 15, color=(0, 0, 255), size=12)
        cv2.imshow("result", img)
        cv2.waitKey(0)


if __name__ == '__main__':
    predictor = Predictor(args.mtcnn_model_path, args.mobilefacenet_model_path, args.face_db_path,
                          threshold=args.threshold)
    start = time.time()
    boxes, names = predictor.recognition(args.image_path)
    if boxes is None:
        print('No face was detected in the input image.')
    else:
        print('Predicted face locations:', boxes.astype(np.int_).tolist())
        print('Recognized face names:', names)
        print('Total recognition time: %dms' % int((time.time() - start) * 1000))
        predictor.draw_face(args.image_path, boxes, names)
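
# Minimal usage sketch (an assumption: the add_arguments helper in utils.utils registers
# each option as a standard argparse "--name" flag, matching the defaults declared above):
#
#   python infer.py --image_path=dataset/test.jpg --face_db_path=face_db --threshold=0.6 \
#       --mobilefacenet_model_path=save_model/mobilefacenet.pth --mtcnn_model_path=save_model/mtcnn
#
# The Predictor class can also be used programmatically, e.g.:
#
#   predictor = Predictor('save_model/mtcnn', 'save_model/mobilefacenet.pth', 'face_db', threshold=0.6)
#   boxes, names = predictor.recognition('path/to/image.jpg')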