# predict.py
# Forked from yu4u/age-gender-estimation.
import os
import cv2
import dlib
import numpy as np
import argparse
from wide_resnet import WideResNet
import glob
def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script does when run from the command line.

    Returns:
        argparse.Namespace with the attributes ``weight_file`` (str or
        ``None``), ``depth`` (int, default 16) and ``width`` (int, default 8).
    """
    parser = argparse.ArgumentParser(
        # The script reads still images from test/*.jpg, not a web cam.
        description="This script detects faces in the JPEG images under test/, "
                    "and estimates age and gender for the detected faces.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--weight_file", type=str, default=None,
                        help="path to weight file (e.g. weights.18-4.06.hdf5)")
    parser.add_argument("--depth", type=int, default=16,
                        help="depth of network")
    parser.add_argument("--width", type=int, default=8,
                        help="width of network")
    return parser.parse_args(argv)
def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX,
               font_scale=1, thickness=2):
    """Draw ``label`` in white on a filled (255, 0, 0) box anchored at ``point``.

    ``point`` is used as the text origin passed to ``cv2.putText``; the
    background box spans the measured text width and extends upwards from
    ``point`` by the measured text height. ``image`` is modified in place.
    """
    # getTextSize returns ((width, height), baseline); only the size is used.
    (text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)
    left, bottom = point
    box_top_left = (left, bottom - text_h)
    box_bottom_right = (left + text_w, bottom)
    cv2.rectangle(image, box_top_left, box_bottom_right, (255, 0, 0), cv2.FILLED)
    cv2.putText(image, label, point, font, font_scale, (255, 255, 255), thickness)
def _crop_face(img, det, img_size, margin=0.4):
    """Cut the face found by ``det`` out of ``img`` and resize it to a square.

    The detection box is widened on every side by ``margin`` times its
    width/height, clamped to the image bounds, cropped, and resized to
    ``(img_size, img_size)``.
    """
    img_h, img_w = img.shape[:2]
    x1, y1 = det.left(), det.top()
    x2, y2 = det.right() + 1, det.bottom() + 1
    w, h = det.width(), det.height()
    xw1 = max(int(x1 - margin * w), 0)
    yw1 = max(int(y1 - margin * h), 0)
    xw2 = min(int(x2 + margin * w), img_w - 1)
    yw2 = min(int(y2 + margin * h), img_h - 1)
    return cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1, :], (img_size, img_size))


def main():
    """Estimate age and gender for each ``test/*.jpg`` image and print the result.

    Loads the WideResNet weights (``--weight_file`` or, when that is not
    given, ``pretrained_models/weights.18-4.06.hdf5``), runs the dlib frontal
    face detector on every image, and prints one line per image. Only images
    with exactly one detected face are predicted; images with none or several
    faces, and files OpenCV cannot read, are reported and skipped.
    """
    args = get_args()
    depth = args.depth
    k = args.width
    weight_file = args.weight_file
    if not weight_file:
        weight_file = os.path.join("pretrained_models", "weights.18-4.06.hdf5")

    # for face detection
    detector = dlib.get_frontal_face_detector()

    # load model and weights
    img_size = 64
    model = WideResNet(img_size, depth=depth, k=k)()
    model.load_weights(weight_file)

    # Ages 0..100 as a column vector; loop-invariant, so it is built once.
    ages = np.arange(0, 101).reshape(101, 1)

    # Sorted so the output order is deterministic (glob order is unspecified).
    for image_path in sorted(glob.glob("test/*.jpg")):
        img = cv2.imread(image_path, 3)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            print(image_path + " could not be read... skipping")
            continue
        # img = imutils.resize(img, width=400)
        input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # detect faces using dlib detector
        detected = detector(input_img, 1)
        if len(detected) == 0:
            print("no faces detected... skipping")
            continue
        if len(detected) > 1:
            print("multiple faces detected... not supported yet")
            continue

        # Exactly one face. The crop is taken from the untouched image: no
        # rectangle is drawn first, because the image is never displayed or
        # saved and the drawn box would end up inside the model input.
        # NOTE(review): the crop comes from the BGR image (as before), not the
        # RGB copy used for detection -- presumably matches training; confirm.
        faces = np.empty((1, img_size, img_size, 3))
        faces[0, :, :, :] = _crop_face(img, detected[0], img_size)

        # predict age and gender of the detected face
        results = model.predict(faces)
        # First model output, element 0: above 0.5 is reported as "F".
        predicted_genders = "F" if results[0][0][0] > 0.5 else "M"
        # Second model output weighted by the ages 0..100, truncated to int.
        predicted_ages = str(int(results[1][0].dot(ages).flatten()[0]))
        print(image_path + " est:" + predicted_ages + " " + predicted_genders
              + " debug:" + str(results[0][0][0]) + " " + str(results[0][0][1]))
# Run the prediction loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()