This repository has been archived by the owner on Apr 2, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
run.py
58 lines (51 loc) · 2.06 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import pickle
import time
import numpy as np
from modules import UNetModel
from modules import ImageGenerator
# Full-resolution image width; each menu option divides it by a power of two.
BASE_IMAGE_WIDTH = 1280

# Menu letter -> divisor applied to BASE_IMAGE_WIDTH.
SCALE_DIVISORS = {'a': 1, 'b': 2, 'c': 4, 'd': 8}

# Training / evaluation schedule.
NUM_EPOCHS = 10
STEPS_PER_EPOCH = 2035
VALIDATION_STEPS = 252
TEST_STEPS = 252


def parse_scale_choice(choice):
    """Return the image width selected by a menu choice, or None if invalid.

    Surrounding whitespace is ignored and the letter is matched
    case-insensitively, so ' B ' selects the same scale as 'b'.

    Args:
        choice: Raw text entered by the user.

    Returns:
        BASE_IMAGE_WIDTH divided by the chosen divisor, or None when the
        input is not one of the menu letters.
    """
    divisor = SCALE_DIVISORS.get(choice.strip().lower())
    if divisor is None:
        return None
    return BASE_IMAGE_WIDTH // divisor


def prompt_image_width():
    """Ask the user for an image scale factor until a valid one is entered.

    Returns:
        The image width corresponding to the accepted menu choice.
    """
    while True:
        print('Smaller images allow training to be performed faster with a slight loss in accuracy\n')
        choice = input('Select the image scale factor\n(a) 1\n(b) 1/2\n(c) 1/4\n(d) 1/8\n')
        width = parse_scale_choice(choice)
        if width is not None:
            return width
        # Any unrecognised input (empty, digits, unknown letters) re-prompts
        # with an explicit message instead of looping silently.
        print('Please try again')


def average(values):
    """Return the mean of *values* as a float, or NaN if nothing was recorded.

    A missing metric (None) or an empty list yields NaN rather than raising,
    so the summary can still be printed after a long training run.
    """
    if values is None or len(values) == 0:
        return float('nan')
    return float(np.mean(values))


def main():
    """Train the model from UNetModel.get_unet_model, save it, and report scores.

    Side effects: reads the scale choice from stdin, writes 'model.h5' and
    'history.p' to the current directory, and prints timing/score summaries.
    """
    img_width = prompt_image_width()
    print('The image width is', img_width)
    trn_gen, val_gen, tst_gen = ImageGenerator.get_generators(img_width)
    model, callbacks = UNetModel.get_unet_model(img_width)

    print('Training initialized\n')
    start = time.time()
    history = model.fit_generator(trn_gen,
                                  steps_per_epoch=STEPS_PER_EPOCH,
                                  epochs=NUM_EPOCHS,
                                  validation_data=val_gen,
                                  validation_steps=VALIDATION_STEPS,
                                  callbacks=callbacks)
    stop = time.time()

    print('Training complete\nSaving model')
    model.save('model.h5')
    # The pickle module has no `save`; `dump` writes the object, and the
    # `with` block guarantees the file is flushed and closed.
    with open('history.p', 'wb') as history_file:
        pickle.dump(history.history, history_file)

    trn_acc = history.history.get('dice_coef')
    val_acc = history.history.get('val_dice_coef')
    # evaluate_generator returns [loss, *metrics]; index 1 is assumed to be
    # the dice coefficient (first compiled metric) -- TODO confirm against the
    # metrics passed to compile() in UNetModel. The weights do not change
    # between these calls; repeating NUM_EPOCHS times presumably averages over
    # generator sampling.
    tst_acc = [model.evaluate_generator(tst_gen, steps=TEST_STEPS)[1]
               for _ in range(NUM_EPOCHS)]

    print('Training Time:', stop - start, 'seconds')
    print('Average Training Accuracy: ', average(trn_acc))
    print('Average Validation Accuracy: ', average(val_acc))
    print('Average Testing Accuracy: ', average(tst_acc))


if __name__ == '__main__':
    main()