Skip to content

Commit

Permalink
add testing script
Browse files Browse the repository at this point in the history
  • Loading branch information
RocketFlash committed Aug 22, 2019
1 parent b570421 commit 7da3f9c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
18 changes: 18 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import keras.backend as K
from model import SiameseNet
from data_loader import SiameseImageLoader
import matplotlib.pyplot as plt
from keras import optimizers
from keras.models import Model, load_model
import albumentations as A
import cv2
import time


model = SiameseNet('configs/road_signs.yml')
model.load_model('{}best_model_4.h5'.format(model.weights_save_path))
model.load_encodings('{}encodings.pkl'.format(model.encodings_path))


model_accuracy = model.calculate_prediction_accuracy()
print('Model accuracy on validation set: {}'.format(model_accuracy))
13 changes: 5 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint


n_epochs = 5
n_steps_per_epoch = 6
batch_size = 1
n_epochs = 1000
n_steps_per_epoch = 600
batch_size = 8
val_steps = 100

# model = SiameseNet('configs/road_signs.yml')
model = SiameseNet('configs/plates.yml')

model = SiameseNet('configs/road_signs.yml')

initial_lr = 1e-4
decay_factor = 0.99
Expand All @@ -23,7 +21,7 @@
decay_factor ** np.floor(x/step_size)),
EarlyStopping(patience=50, verbose=1),
TensorBoard(log_dir=model.tensorboard_log_path),
ModelCheckpoint(filepath=os.path.join(model.weights_save_path, 'best_model.h5'),
ModelCheckpoint(filepath=os.path.join(model.weights_save_path, 'best_model_4.h5'),
verbose=1, monitor='val_loss', save_best_only=True)
]

Expand All @@ -32,7 +30,6 @@


model.generate_encodings()
# model.load_encodings('encodings/encodings.pkl')

model_accuracy = model.calculate_prediction_accuracy()
print('Model accuracy on validation set: {}'.format(model_accuracy))

0 comments on commit 7da3f9c

Please sign in to comment.