-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_model.py
68 lines (62 loc) · 2.37 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from data_creation import key_check
import cv2
import numpy as np
from PIL import ImageGrab
import time
import keyboardPress as kp
import sys
from alexnet import alexnet
import argparse
# Pause flag read and toggled by main(); the script starts paused.
stopped = True
# Frame dimensions: passed to alexnet() and used by main() for the
# cv2.resize target size and the reshape fed to model.predict().
WIDTH = 120
HEIGHT = 90
# Third argument to alexnet(); also embedded in the model file name.
LR = 1e-3
# In this file EPOCHS is only used to build the model file name.
EPOCHS = 9
# With the defaults above this formats to
# './models/gtasa-drive-0.001-9-video.model'.
NAME = './models/gtasa-drive-{}-{}-video.model'.format(LR, EPOCHS)
# The network is constructed and the saved model file is loaded at import
# time, using the values of WIDTH/HEIGHT/LR/NAME as they are at this point.
model = alexnet(WIDTH, HEIGHT, LR)
model.load(NAME)
# Cut-offs compared (with >) against the prediction vector in main():
# prediction[1] vs fw_threshold, prediction[2] vs left_threshold,
# prediction[0] vs right_threshold, checked in that order.
fw_threshold = 0.42
right_threshold = 0.30
left_threshold = 0.97
def main():
    """Run the capture / predict / key-press loop until 'q' is detected.

    Every pass through the loop does three things:

    1. If key_check() reports 'P', flip the global ``stopped`` flag, print
       the new state, call kp.full_stop() and sleep for one second.
    2. While not stopped, grab the screen region (50, 50, 800, 650),
       convert it to grayscale, resize it to (WIDTH, HEIGHT), pass it to
       ``model.predict`` and issue at most one kp call, chosen by comparing
       the prediction against the module-level thresholds.
    3. Poll cv2.waitKey(25); when it yields 'q', destroy all cv2 windows
       and leave the loop.
    """
    global stopped

    print("OH BOY")
    capture_box = (50, 50, 800, 650)

    while True:
        # 'P' toggles between paused and running.
        if 'P' in key_check():
            stopped = not stopped
            if stopped:
                print('Stopped')
            else:
                print('Resumed')
            kp.full_stop()
            # One-second pause after a toggle before polling keys again.
            time.sleep(1)

        if not stopped:
            grabbed = ImageGrab.grab(bbox=capture_box)
            gray = cv2.cvtColor(np.array(grabbed), cv2.COLOR_RGB2GRAY)
            frame = cv2.resize(gray, (WIDTH, HEIGHT))
            model_input = frame.reshape(WIDTH, HEIGHT, 1)
            prediction = model.predict([model_input])[0]
            print(prediction)

            # First matching threshold wins: index 1, then 2, then 0.
            if prediction[1] > fw_threshold:
                kp.forward()
            elif prediction[2] > left_threshold:
                kp.turn_left_f()
            elif prediction[0] > right_threshold:
                kp.turn_right_f()

        # NOTE(review): this function never opens a cv2 window, so waitKey
        # may never receive 'q' here - confirm how the loop is meant to end.
        pressed_key = cv2.waitKey(25) & 0xFF
        if pressed_key == ord('q'):
            cv2.destroyAllWindows()
            break
if __name__ == "__main__":
    # Command-line overrides for the module-level settings, followed by the
    # call to main(). Every option is optional; omitted options keep the
    # module-level value. Conversion is delegated to argparse (type=...) so
    # a malformed value produces a usage error instead of a traceback.
    parser = argparse.ArgumentParser(description='Test model')
    parser.add_argument('-w', '--width', type=int, default=WIDTH,
                        metavar='WIDTH', help='Width')
    # '--heigth' is the historical (misspelled) flag and is kept so existing
    # invocations still work; '--height' is the correctly spelled alias.
    # The destination stays 'ht' (derived from the first long option).
    parser.add_argument('--ht', '--heigth', '--height', type=int,
                        default=HEIGHT, metavar='HEIGHT', help='Height')
    parser.add_argument('--lr', type=float, default=LR,
                        metavar='LR', help='Learning rate')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS,
                        metavar='EPOCHS', help='Epochs')
    parser.add_argument('--fw', '--fw-th', type=float, default=fw_threshold,
                        metavar='THRESHOLD',
                        help='Forward threshold, should be between 1 and 0')
    parser.add_argument('-s', '--s-th', type=float, default=None,
                        metavar='THRESHOLD',
                        help='Sides threshold, should be between 1 and 0 '
                             '(accepted but currently not applied)')
    args = parser.parse_args()

    WIDTH = args.width
    HEIGHT = args.ht
    LR = args.lr
    EPOCHS = args.epochs
    fw_threshold = args.fw

    # The sides threshold was never wired to left_threshold/right_threshold;
    # say so explicitly rather than silently dropping the value.
    if args.s_th is not None:
        print('Warning: -s/--s-th is not applied; '
              'left_threshold and right_threshold keep their built-in values')

    # NOTE(review): `model` is built and loaded at module level before this
    # block runs, so overriding WIDTH/HEIGHT/LR/EPOCHS or NAME here does not
    # rebuild or reload it. A width/height override therefore changes only
    # the resize/reshape in main() - confirm whether a rebuild is intended.
    NAME = './models/gtasa-drive-0.001-9-video.model'
    main()