Skip to content

Commit

Permalink
add ability to use net without config file
Browse files Browse the repository at this point in the history
  • Loading branch information
RocketFlash committed Aug 29, 2019
1 parent e3fd702 commit 5c448ac
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 42 deletions.
71 changes: 36 additions & 35 deletions siamese_net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,41 @@ class SiameseNet:
mode = 'triplet' -> Triplen network
"""

def __init__(self, cfg_file):

params = parse_net_params(cfg_file)
self.input_shape = params['input_shape']
self.encodings_len = params['encodings_len']
self.backbone = params['backbone']
self.backbone_weights = params['backbone_weights']
self.distance_type = params['distance_type']
self.mode = params['mode']
self.project_name = params['project_name']
self.optimizer = params['optimizer']
self.freeze_backbone = params['freeze_backbone']
self.data_loader = params['loader']

self.model = []
self.base_model = []
self.l_model = []

self.encodings_path = params['encodings_path']
self.plots_path = params['plots_path']
self.tensorboard_log_path = params['tensorboard_log_path']
self.weights_save_path = params['weights_save_path']
self.model_save_name = params['model_save_name']

os.makedirs(self.encodings_path, exist_ok=True)
os.makedirs(self.plots_path, exist_ok=True)
os.makedirs(self.tensorboard_log_path, exist_ok=True)
os.makedirs(self.weights_save_path, exist_ok=True)

if self.mode == 'siamese':
self._create_model_siamese()
elif self.mode == 'triplet':
self._create_model_triplet()

self.encoded_training_data = {}
def __init__(self, cfg_file=None):
if cfg_file:
params = parse_net_params(cfg_file)
self.input_shape = params['input_shape']
self.encodings_len = params['encodings_len']
self.backbone = params['backbone']
self.backbone_weights = params['backbone_weights']
self.distance_type = params['distance_type']
self.mode = params['mode']
self.project_name = params['project_name']
self.optimizer = params['optimizer']
self.freeze_backbone = params['freeze_backbone']
self.data_loader = params['loader']
self.model = []
self.base_model = []
self.l_model = []

self.encodings_path = params['encodings_path']
self.plots_path = params['plots_path']
self.tensorboard_log_path = params['tensorboard_log_path']
self.weights_save_path = params['weights_save_path']
self.model_save_name = params['model_save_name']

os.makedirs(self.encodings_path, exist_ok=True)
os.makedirs(self.plots_path, exist_ok=True)
os.makedirs(self.tensorboard_log_path, exist_ok=True)
os.makedirs(self.weights_save_path, exist_ok=True)

if self.mode == 'siamese':
self._create_model_siamese()
elif self.mode == 'triplet':
self._create_model_triplet()
self.encoded_training_data = {}


def _create_base_model(self):
Expand Down Expand Up @@ -219,6 +219,7 @@ def load_model(self,file_path):
custom_objects={'contrastive_loss': lac.contrastive_loss,
'accuracy': lac.accuracy,
'triplet_loss': lac.triplet_loss})
self.input_shape = list(self.model.inputs[0].shape[1:])
self.base_model = Model(inputs=[self.model.layers[3].get_input_at(0)],
outputs=[self.model.layers[3].layers[-1].output])
self.base_model._make_predict_function()
Expand Down
14 changes: 7 additions & 7 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from model import SiameseNet
from siamese_net.model import SiameseNet

model = SiameseNet('SiameseNet/configs/road_signs.yml')
model.load_model('{}best_model_4.h5'.format(model.weights_save_path))
model.load_encodings('{}encodings.pkl'.format(model.encodings_path))
model = SiameseNet()
model.load_model('weights/road_signs/best_model_4.h5')
model.load_encodings('encodings/road_signs/encodings.pkl')


model_accuracy = model.calculate_prediction_accuracy()
print('Model accuracy on validation set: {}'.format(model_accuracy))
image_path = '/home/rauf/datasets/road_signs/road_signs_separated/val/1_1/rtsd-r1_train_006470.png'
model_prediction = model.predict(image_path)
print('Model prediction: {}'.format(model_prediction))

0 comments on commit 5c448ac

Please sign in to comment.