Commit

add knn for prediction
RocketFlash committed Sep 6, 2019
1 parent b8f6106 commit 5b0f3c5
Showing 8 changed files with 10,820 additions and 142 deletions.
1 change: 1 addition & 0 deletions configs/road_signs_resnext50_merged_dataset.yml
@@ -1,5 +1,6 @@
input_shape : [48, 48, 3]
encodings_len: 256
+ margin: 0.5
mode : 'triplet'
distance_type : 'l1'
backbone : 'resnext50'
3 changes: 2 additions & 1 deletion configs/road_signs_simple2_merged_dataset.yml
@@ -1,5 +1,6 @@
input_shape : [48, 48, 3]
- encodings_len: 1024
+ encodings_len: 256
+ margin: 0.5
mode : 'triplet'
distance_type : 'l1'
backbone : 'simple2'
7 changes: 7 additions & 0 deletions siamese_net/data_loader.py
@@ -215,6 +215,13 @@ def get_batch_triplets_mining(self,
triplet_positives.append(all_images[anchor_positive[1]])
triplet_negatives.append(all_images[hard_negative])
targets.append(1)

+ if len(triplet_anchors) == 0:
+     triplet_anchors.append(all_images[anchor_positive[0]])
+     triplet_positives.append(all_images[anchor_positive[1]])
+     triplet_negatives.append(all_images[negative_indices[0]])
+     targets.append(1)

triplet_anchors = np.array(triplet_anchors)
triplet_positives = np.array(triplet_positives)
triplet_negatives = np.array(triplet_negatives)
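The block added above guards against a corner case of hard-negative mining: if every candidate negative already sits farther from the anchor than the positive plus the margin, no triplet is selected and the batch arrays come back empty. Below is a minimal, self-contained sketch of that situation and of the fallback, using synthetic distances; pos_dist, neg_dists and the selection rule are illustrative assumptions, not the loader's actual code.

    import numpy as np

    margin = 0.5
    pos_dist = np.array([0.10, 0.20])               # hypothetical anchor-positive distances
    neg_dists = np.array([[2.0, 3.0], [2.5, 4.0]])  # hypothetical anchor-negative distances

    triplets = []
    for i, pd in enumerate(pos_dist):
        # a negative counts as "hard" only if it violates the margin (closer than pd + margin)
        hard = np.where(neg_dists[i] < pd + margin)[0]
        if len(hard) > 0:
            triplets.append((i, int(hard[0])))

    # every negative above is easy, so nothing was mined; without the fallback the
    # batch arrays would be empty and the training step would fail
    if len(triplets) == 0:
        triplets.append((0, 0))

    print(triplets)  # -> [(0, 0)]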
29 changes: 15 additions & 14 deletions siamese_net/losses_and_accuracies.py
@@ -10,7 +10,7 @@ def contrastive_loss(y_true, y_pred):
margin_square = K.square(K.maximum(margin - y_pred, 0))
return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)

- def triplet_loss(y_true, y_pred, alpha = 0.5):
+ def triplet_loss(margin = 0.5):
"""
Implementation of the triplet loss function
Arguments:
@@ -22,24 +22,25 @@ def triplet_loss(y_true, y_pred, alpha = 0.5):
Returns:
loss -- real number, value of the loss
"""

- total_lenght = y_pred.shape.as_list()[-1]
+ def loss_function(y_true, y_pred):
+     total_lenght = y_pred.shape.as_list()[-1]

- anchor = y_pred[:,0:int(total_lenght*1/3)]
- positive = y_pred[:,int(total_lenght*1/3):int(total_lenght*2/3)]
- negative = y_pred[:,int(total_lenght*2/3):int(total_lenght*3/3)]
+     anchor = y_pred[:,0:int(total_lenght*1/3)]
+     positive = y_pred[:,int(total_lenght*1/3):int(total_lenght*2/3)]
+     negative = y_pred[:,int(total_lenght*2/3):int(total_lenght*3/3)]

- # distance between the anchor and the positive
- pos_dist = K.sum(K.square(anchor-positive),axis=1)
+     # distance between the anchor and the positive
+     pos_dist = K.sum(K.square(anchor-positive),axis=1)

- # distance between the anchor and the negative
- neg_dist = K.sum(K.square(anchor-negative),axis=1)
+     # distance between the anchor and the negative
+     neg_dist = K.sum(K.square(anchor-negative),axis=1)

- # compute loss
- basic_loss = pos_dist-neg_dist+alpha
- loss = K.maximum(basic_loss,0.0)
+     # compute loss
+     basic_loss = pos_dist-neg_dist+margin
+     loss = K.maximum(basic_loss,0.0)
+     return loss

- return loss
+ return loss_function

def accuracy(y_true, y_pred):
'''Compute classification accuracy with a fixed threshold on distances.
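The refactor above adopts the usual Keras pattern for losses with hyperparameters: compile() only accepts a function of (y_true, y_pred), so the margin is captured in a closure and the inner loss_function is what Keras actually calls, which is also why load_model() in model.py below now registers the custom object under the key 'loss_function'. A minimal, self-contained sketch of the pattern with a toy stand-in network (the Dense layer and the 12-dimensional concatenated output are illustrative assumptions, not the project's model):

    import numpy as np
    import keras.backend as K
    from keras.models import Model
    from keras.layers import Input, Dense

    def triplet_loss(margin=0.5):
        # factory: captures `margin` and returns the two-argument loss Keras expects
        def loss_function(y_true, y_pred):
            emb_len = K.int_shape(y_pred)[-1] // 3      # output is anchor|positive|negative
            anchor   = y_pred[:, :emb_len]
            positive = y_pred[:, emb_len:2 * emb_len]
            negative = y_pred[:, 2 * emb_len:]
            pos_dist = K.sum(K.square(anchor - positive), axis=1)
            neg_dist = K.sum(K.square(anchor - negative), axis=1)
            return K.maximum(pos_dist - neg_dist + margin, 0.0)
        return loss_function

    inp = Input(shape=(12,))   # toy stand-in for the concatenated triplet embedding
    out = Dense(12)(inp)
    model = Model(inp, out)
    model.compile(loss=triplet_loss(margin=0.5), optimizer='adam')
    model.train_on_batch(np.random.rand(4, 12), np.zeros((4, 12)))  # dummy targets; y_true is ignored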
38 changes: 24 additions & 14 deletions siamese_net/model.py
@@ -1,21 +1,17 @@
import os
import glob
import numpy as np
import keras.backend as K
import tensorflow as tf
import cv2
import random
from keras.models import Model, load_model
from keras import optimizers
- from keras.layers import Dense, Input, Lambda, Dropout, Flatten
- from keras.layers import Conv2D, MaxPool2D, BatchNormalization, concatenate
- from classification_models import Classifiers
+ from keras.layers import Dense, Input, Lambda, concatenate
import pickle
from .utils import parse_net_params, load_encodings
from .backbones import get_backbone
from . import losses_and_accuracies as lac
import matplotlib.pyplot as plt

+ from sklearn.neighbors import KNeighborsClassifier

class SiameseNet:
"""
@@ -41,7 +37,8 @@ def __init__(self, cfg_file=None):
self.freeze_backbone = params['freeze_backbone']
self.data_loader = params['loader']
self.embeddings_normalization = params['embeddings_normalization']

+ self.margin = params['margin']

self.model = []
self.base_model = []
self.l_model = []
@@ -82,7 +79,7 @@ def _create_model_siamese(self):

self._create_base_model()
self.base_model._make_predict_function()

image_encoding_1 = self.base_model(input_image_1)
image_encoding_2 = self.base_model(input_image_2)

@@ -138,7 +135,7 @@ def _create_model_triplet(self):
print('Whole model summary')
self.model.summary()

- self.model.compile(loss=lac.triplet_loss, optimizer=self.optimizer)
+ self.model.compile(loss=lac.triplet_loss(self.margin), optimizer=self.optimizer)


def train_on_batch(self, batch_size=8, s="train"):
@@ -177,8 +174,9 @@ def train_generator_mining(self,
verbose=1):

train_generator = self.data_loader.generate_mining(self.base_model, n_classes, n_samples, negative_selection_mode=negative_selection_mode, s="train")
- val_generator = self.data_loader.generate_mining(self.base_model, n_classes, n_samples, negative_selection_mode=negative_selection_mode, s="val")

+ # val_generator = self.data_loader.generate_mining(self.base_model, n_classes, n_samples, negative_selection_mode=negative_selection_mode, s="val")
+ val_generator = self.data_loader.generate(8, mode=self.mode, s="val")

history = self.model.fit_generator(train_generator, steps_per_epoch=steps_per_epoch, epochs=epochs,
verbose=verbose, validation_data = val_generator, validation_steps = val_steps, callbacks=callbacks)
if self.plots_path:
@@ -228,7 +226,9 @@ def generate_encodings(self, save_file_name='encodings.pkl', max_num_samples_of_
self.encoded_training_data['labels'] = data_labels
self.encoded_training_data['encodings'] = np.squeeze(
np.array(data_encodings))

+ self.encoded_training_data['knn_classifier'] = KNeighborsClassifier(n_neighbors=1)
+ self.encoded_training_data['knn_classifier'].fit(self.encoded_training_data['encodings'],
+                                                  self.encoded_training_data['labels'])
f = open(os.path.join(self.encodings_path, save_file_name), "wb")
pickle.dump(self.encoded_training_data, f)
f.close()
@@ -241,7 +241,7 @@ def load_model(self,file_path):
self.model = load_model(file_path,
custom_objects={'contrastive_loss': lac.contrastive_loss,
'accuracy': lac.accuracy,
- 'triplet_loss': lac.triplet_loss,
+ 'loss_function': lac.triplet_loss(self.margin),
'RAdam': RAdam})
self.input_shape = list(self.model.inputs[0].shape[1:])
self.base_model = Model(inputs=[self.model.layers[3].get_input_at(0)],
@@ -265,12 +265,22 @@ def predict(self, image):
predicted_label = self.encoded_training_data['labels'][max_element]
return predicted_label

+ def predict_knn(self, image):
+     if type(image) is str:
+         img = cv2.imread(image)
+     else:
+         img = image
+     img = cv2.resize(img, (self.input_shape[0], self.input_shape[1]))
+     encoding = self.base_model.predict(np.expand_dims(img, axis=0))
+     predicted_label = self.encoded_training_data['knn_classifier'].predict(encoding)
+     return predicted_label

def calculate_prediction_accuracy(self):
correct = 0
total_n_of_images = len(self.data_loader.images_paths['val'])
for img_path, img_label in zip(self.data_loader.images_paths['val'],
self.data_loader.images_labels['val']):
- prediction = self.predict(img_path)
+ prediction = self.predict_knn(img_path)[0]
if prediction == img_label:
correct+=1
return correct/total_n_of_images
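The new predict_knn() replaces the manual search over stored encodings: generate_encodings() now fits a KNeighborsClassifier(n_neighbors=1) on the training embeddings and pickles it alongside them, so prediction reduces to a single nearest-neighbour query on the query image's embedding. A minimal sketch of that lookup on synthetic data (the 256-dimensional embeddings and the label names are illustrative assumptions):

    import numpy as np
    from sklearn.neighbors import KNeighborsClassifier

    # stand-ins for encoded_training_data['encodings'] and ['labels']
    encodings = np.random.rand(100, 256)
    labels = np.random.choice(['stop', 'give_way', 'speed_limit'], size=100)

    knn = KNeighborsClassifier(n_neighbors=1)   # 1-NN: take the label of the closest embedding
    knn.fit(encodings, labels)

    query = np.random.rand(1, 256)              # stand-in for base_model.predict(img)
    print(knn.predict(query)[0])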
1 change: 1 addition & 0 deletions siamese_net/utils.py
@@ -111,6 +111,7 @@ def parse_net_params(filename='configs/road_signs.yml'):
params['model_save_name'] = cfg['model_save_name']
if 'dataset_path' in cfg:
params['loader'] = SiameseImageLoader(cfg['dataset_path'],
+ margin = cfg['margin'],
input_shape=cfg['input_shape'],
augmentations=augmentations)

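Together with the two config changes at the top of this commit, the margin now has a single source of truth in the YAML file: parse_net_params() reads it here and forwards it to SiameseImageLoader for mining, and SiameseNet picks up the same value through params['margin'] (presumably set elsewhere in parse_net_params) when compiling with lac.triplet_loss. A hedged sketch of reading that key with PyYAML; the file name and keys follow the configs shown above, and building the loader and model is omitted.

    import yaml

    with open('configs/road_signs_simple2_merged_dataset.yml') as f:
        cfg = yaml.safe_load(f)

    margin = cfg['margin']   # 0.5 in both updated configs
    print(cfg['input_shape'], cfg['encodings_len'], margin)
    # parse_net_params() forwards this margin to SiameseImageLoader(...) and the
    # SiameseNet model compiles its triplet loss with the same value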
10,877 changes: 10,767 additions & 110 deletions test_network.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions train.py
@@ -10,8 +10,8 @@
batch_size = 8
val_steps = 100

- config_name = 'resnext50_merged_dataset'
- model = SiameseNet('configs/road_signs_{}.yml'.format(config_name))
+ config_name = 'road_signs_simple2_merged_dataset'
+ model = SiameseNet('configs/{}.yml'.format(config_name))

initial_lr = 1e-4
decay_factor = 0.95
@@ -34,7 +34,7 @@
epochs=n_epochs,
callbacks = callbacks,
val_steps=100,
- n_classes=4,
+ n_classes=10,
n_samples=4,
negative_selection_mode='semihard')

