training.py
from keras import backend as K
from keras.optimizers import Adadelta, SGD
from keras.callbacks import EarlyStopping, ModelCheckpoint
from Image_Generator import TextImageGenerator
from Model import get_Model
from parameter import *  # provides img_w, img_h, batch_size, val_batch_size, downsample_factor
K.set_learning_phase(0)  # 0 = test-time learning phase for backend ops

# Model description and training
model = get_Model(training=True)
# Resume from a previously saved checkpoint if one is available
try:
    model.load_weights('LSTM+BN5--26--1.144.hdf5')
    print("...Previous weight data...")
except Exception:
    print("...New weight data...")
# Build the training data generator
train_file_path = './DB/train/'
tiger_train = TextImageGenerator(train_file_path, img_w, img_h, batch_size, downsample_factor)
tiger_train.build_data()

# Build the validation data generator
valid_file_path = './DB/test/'
tiger_val = TextImageGenerator(valid_file_path, img_w, img_h, val_batch_size, downsample_factor)
tiger_val.build_data()
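# NOTE (illustrative assumption, not taken from Image_Generator.py): for CTC
# training, next_batch() is expected to yield (inputs, outputs) pairs in which
# the inputs carry everything the in-model CTC layer needs and the outputs are
# a dummy target for the 'ctc' output, for example:
#
#   inputs = {'the_input': images, 'the_labels': labels,
#             'input_length': input_lengths, 'label_length': label_lengths}
#   outputs = {'ctc': np.zeros([batch_size])}
#
# The exact key names depend on how get_Model names its input layers.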
ada = Adadelta()

# EarlyStopping is defined here but not passed to fit_generator below, so it has no effect as written
early_stop = EarlyStopping(monitor='loss', min_delta=0.001, patience=4, mode='min', verbose=1)

# Save a checkpoint every epoch; the filename records the epoch and validation loss
checkpoint = ModelCheckpoint(filepath='LSTM+BN5--{epoch:02d}--{val_loss:.3f}.hdf5', monitor='loss', verbose=1, mode='min', period=1)
# the loss calc occurs elsewhere, so use a dummy lambda func for the loss
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=ada)
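# Illustrative sketch (an assumption about Model.py, not code from this repo): the
# 'ctc' output used by the dummy loss above is typically produced inside get_Model
# by a Lambda layer that wraps K.ctc_batch_cost, along these lines:
#
#   def ctc_lambda_func(args):
#       y_pred, labels, input_length, label_length = args
#       return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
#
#   loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
#       [y_pred, labels, input_length, label_length])
#
# With the loss computed inside the graph, compile() only needs a loss that passes
# y_pred straight through, which is what the dummy lambda above does.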
# Train with the data generators
model.fit_generator(generator=tiger_train.next_batch(),
                    steps_per_epoch=int(tiger_train.n / batch_size),
                    epochs=30,
                    callbacks=[checkpoint],
                    validation_data=tiger_val.next_batch(),
                    validation_steps=int(tiger_val.n / val_batch_size))