Skip to content

Commit

Permalink
add backbone model pretraining function
Browse files Browse the repository at this point in the history
  • Loading branch information
RocketFlash committed Nov 28, 2019
1 parent eb4c77a commit 184e677
Show file tree
Hide file tree
Showing 14 changed files with 18,407 additions and 9,202 deletions.
20 changes: 20 additions & 0 deletions configs/road_signs_resnet18_all.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
input_shape : [48, 48, 3]
encodings_len: 256
margin: 0.5
mode : 'triplet'
distance_type : 'l1'
backbone : 'resnet18'
backbone_weights : 'imagenet'
optimizer : 'radam'
learning_rate : 0.0001
project_name : 'road_signs/'
freeze_backbone : False
embeddings_normalization: True

#paths
dataset_path : '/home/rauf/datasets/road_signs/road_signs_all/'
tensorboard_log_path : 'tf_log/'
weights_save_path : 'weights/'
plots_path : 'plots/'
encodings_path : 'encodings/'
model_save_name : 'best_model_resnet18_all.h5'
20 changes: 20 additions & 0 deletions configs/road_signs_resnet18_max80_min30.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
input_shape : [48, 48, 3]
encodings_len: 256
margin: 0.5
mode : 'triplet'
distance_type : 'l1'
backbone : 'resnet18'
backbone_weights : 'imagenet'
optimizer : 'radam'
learning_rate : 0.0001
project_name : 'road_signs/'
freeze_backbone : False
embeddings_normalization: True

#paths
dataset_path : '/home/rauf/datasets/road_signs/road_signs_max80_min30/'
tensorboard_log_path : 'tf_log/'
weights_save_path : 'weights/'
plots_path : 'plots/'
encodings_path : 'encodings/'
model_save_name : 'best_model_resnet18_max80_min30.h5'
20 changes: 20 additions & 0 deletions configs/road_signs_resnet18_mini.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
input_shape : [48, 48, 3]
encodings_len: 256
margin: 0.3
mode : 'triplet'
distance_type : 'l1'
backbone : 'resnet18'
backbone_weights : 'imagenet'
optimizer : 'radam'
learning_rate : 0.0001
project_name : 'road_signs/'
freeze_backbone : False
embeddings_normalization: False

#paths
dataset_path : '/home/rauf/datasets/road_signs/road_signs_mini/'
tensorboard_log_path : 'tf_log/'
weights_save_path : 'weights/'
plots_path : 'plots/'
encodings_path : 'encodings/'
model_save_name : 'best_model_resnet18_mini.h5'
20 changes: 20 additions & 0 deletions configs/road_signs_resnet18_paper.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
input_shape : [48, 48, 3]
encodings_len: 256
margin: 0.5
mode : 'triplet'
distance_type : 'l1'
backbone : 'resnet18'
backbone_weights : 'imagenet'
optimizer : 'radam'
learning_rate : 0.0001
project_name : 'road_signs/'
freeze_backbone : False
embeddings_normalization: True

#paths
dataset_path : '/home/rauf/datasets/road_signs/road_signs_full/'
tensorboard_log_path : 'tf_log/'
weights_save_path : 'weights/'
plots_path : 'plots/'
encodings_path : 'encodings/'
model_save_name : 'best_model_resnet18_paper.h5'
20 changes: 20 additions & 0 deletions configs/road_signs_resnet18_paper_cutted.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
input_shape : [48, 48, 3]
encodings_len: 256
margin: 0.4
mode : 'triplet'
distance_type : 'l1'
backbone : 'resnet18'
backbone_weights : 'imagenet'
optimizer : 'radam'
learning_rate : 0.0001
project_name : 'road_signs/'
freeze_backbone : False
embeddings_normalization: True

#paths
dataset_path : '/home/rauf/datasets/road_signs/road_signs_full_cutted/'
tensorboard_log_path : 'tf_log/'
weights_save_path : 'weights/'
plots_path : 'plots/'
encodings_path : 'encodings/'
model_save_name : 'best_model_resnet18_paper.h5'
20 changes: 20 additions & 0 deletions configs/road_signs_resnet18_paper_remaining.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
input_shape : [48, 48, 3]
encodings_len: 256
margin: 0.5
mode : 'triplet'
distance_type : 'l1'
backbone : 'resnet18'
backbone_weights : 'imagenet'
optimizer : 'radam'
learning_rate : 0.0001
project_name : 'road_signs/'
freeze_backbone : False
embeddings_normalization: True

#paths
dataset_path : '/home/rauf/datasets/road_signs/road_signs_full_remaining/'
tensorboard_log_path : 'tf_log/'
weights_save_path : 'weights/'
plots_path : 'plots/'
encodings_path : 'encodings/'
model_save_name : 'best_model_resnet18_paper.h5'
20 changes: 20 additions & 0 deletions configs/road_signs_simple2_mini.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
input_shape : [48, 48, 3]
encodings_len: 256
margin: 0.5
mode : 'triplet'
distance_type : 'l1'
backbone : 'simple2'
backbone_weights : 'imagenet'
optimizer : 'radam'
learning_rate : 0.0001
project_name : 'road_signs/'
freeze_backbone : True
embeddings_normalization: True

#paths
dataset_path : '/home/rauf/datasets/road_signs/road_signs_mini/'
tensorboard_log_path : 'tf_log/'
weights_save_path : 'weights/'
plots_path : 'plots/'
encodings_path : 'encodings/'
model_save_name : 'best_model_simple2_mini.h5'
2 changes: 1 addition & 1 deletion embedding_net/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def get_backbone(input_shape,
base_model = Model(
inputs=[backbone_model.input], outputs=[encoded_output])

return base_model
return base_model, backbone_model
130 changes: 129 additions & 1 deletion embedding_net/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, dataset_path, input_shape=None, augmentations=None, data_subs
self.current_idx = {d: 0 for d in data_subsets}
self._load_images_paths()
self.classes = {
s: list(set(self.images_labels[s])) for s in data_subsets}
s: sorted(list(set(self.images_labels[s]))) for s in data_subsets}
self.n_classes = {s: len(self.classes[s]) for s in data_subsets}
self.n_samples = {d: len(self.images_paths[d]) for d in data_subsets}
self.indexes = {d: {cl: np.where(np.array(self.images_labels[d]) == cl)[
Expand Down Expand Up @@ -247,6 +247,9 @@ def generate_mining(self, embedding_model, n_classes, n_samples, margin = 0.5, n

def get_image(self, img_path):
img = cv2.imread(img_path)
if img is None:
print('image is not exist ' + img_path)
return None
if self.input_shape:
img = cv2.resize(
img, (self.input_shape[0], self.input_shape[1]))
Expand All @@ -270,3 +273,128 @@ def plot_batch(self, data, targets):
i += it_val

plt.show()




class SimpleNetImageLoader:
"""
Image loader for Embedding network
"""

def __init__(self, dataset_path, input_shape=None, augmentations=None, data_subsets=['train', 'val']):
self.dataset_path = dataset_path
self.data_subsets = data_subsets
self.images_paths = {}
self.images_labels = {}
self.input_shape = input_shape
self.augmentations = augmentations
self.current_idx = {d: 0 for d in data_subsets}
self._load_images_paths()
self.classes = {
s: sorted(list(set(self.images_labels[s]))) for s in data_subsets}
self.n_classes = {s: len(self.classes[s]) for s in data_subsets}
self.n_samples = {d: len(self.images_paths[d]) for d in data_subsets}
self.indexes = {d: {cl: np.where(np.array(self.images_labels[d]) == cl)[
0] for cl in self.classes[d]} for d in data_subsets}


def _load_images_paths(self):
for d in self.data_subsets:
self.images_paths[d] = []
self.images_labels[d] = []
for root, dirs, files in os.walk(self.dataset_path+d):
for f in files:
if f.endswith('.jpg') or f.endswith('.png'):
self.images_paths[d].append(root+'/'+f)
self.images_labels[d].append(root.split('/')[-1])


def _get_images_set(self, clsss, idxs, s='train', with_aug=True):
if type(clsss) is list:
indxs = [self.indexes[s][cl][idx] for cl, idx in zip(clsss, idxs)]
else:
indxs = [self.indexes[s][clsss][idx] for idx in idxs]
imgs = [cv2.imread(self.images_paths[s][idx]) for idx in indxs]

if self.input_shape:
imgs = [cv2.resize(
img, (self.input_shape[0], self.input_shape[1])) for img in imgs]

if with_aug:
imgs = [self.augmentations(image=img)['image'] for img in imgs]

return imgs


def get_batch(self, batch_size, s='train'):
images = [np.zeros((batch_size, self.input_shape[0], self.input_shape[1], 3))]
targets = np.zeros((batch_size,self.n_classes[s]))

count = 0
with_aug = s == 'train' and self.augmentations
for i in range(batch_size):
selected_class_idx = random.randrange(0, self.n_classes[s])
selected_class = self.classes[s][selected_class_idx]
selected_class_n_elements = len(self.indexes[s][selected_class])

indx = random.randrange(0, selected_class_n_elements)

img = self._get_images_set(
[selected_class], [indx], s=s, with_aug=with_aug)
images[0][count, :, :, :] = img[0]
targets[i][selected_class_idx] = 1
count+=1

return images, targets


def generate(self, batch_size, s='train'):
while True:
data, targets = self.get_batch(batch_size, s)
yield (data, targets)


def get_image(self, img_path):
img = cv2.imread(img_path)
if self.input_shape:
img = cv2.resize(
img, (self.input_shape[0], self.input_shape[1]))
return img


def plot_batch(self, data, targets):
num_imgs = data[0].shape[0]
it_val = len(data)
fig, axs = plt.subplots(num_imgs, it_val, figsize=(
30, 50), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=.5, wspace=.001)

axs = axs.ravel()
i = 0
for img_idx, targ in zip(range(num_imgs), targets):
for j in range(it_val):
img = cv2.cvtColor(data[j][img_idx].astype(
np.uint8), cv2.COLOR_BGR2RGB)
axs[i+j].imshow(img)
axs[i+j].set_title(targ)
i += it_val

plt.show()
num_imgs = data[0].shape[0]
it_val = len(data)
fig, axs = plt.subplots(num_imgs, it_val, figsize=(
30, 50), facecolor='w', edgecolor='k')
fig.subplots_adjust(hspace=.5, wspace=.001)

axs = axs.ravel()
i = 0
for img_idx, targ in zip(range(num_imgs), targets):
for j in range(it_val):
img = cv2.cvtColor(data[j][img_idx].astype(
np.uint8), cv2.COLOR_BGR2RGB)
axs[i+j].imshow(img)
axs[i+j].set_title(targ)
i += it_val

plt.show()
48 changes: 40 additions & 8 deletions embedding_net/losses_and_accuracies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import keras.backend as K
import keras


def contrastive_loss(y_true, y_pred):
Expand All @@ -10,7 +11,8 @@ def contrastive_loss(y_true, y_pred):
margin_square = K.square(K.maximum(margin - y_pred, 0))
return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)

def triplet_loss(margin = 0.5):

def triplet_loss(margin=0.5):
"""
Implementation of the triplet loss function
Arguments:
Expand All @@ -25,24 +27,54 @@ def triplet_loss(margin = 0.5):
def loss_function(y_true, y_pred):
total_lenght = y_pred.shape.as_list()[-1]

anchor = y_pred[:,0:int(total_lenght*1/3)]
positive = y_pred[:,int(total_lenght*1/3):int(total_lenght*2/3)]
negative = y_pred[:,int(total_lenght*2/3):int(total_lenght*3/3)]
anchor = y_pred[:, 0:int(total_lenght*1/3)]
positive = y_pred[:, int(total_lenght*1/3):int(total_lenght*2/3)]
negative = y_pred[:, int(total_lenght*2/3):int(total_lenght*3/3)]

# distance between the anchor and the positive
pos_dist = K.sum(K.square(anchor-positive),axis=1)
pos_dist = K.sum(K.square(anchor-positive), axis=1)

# distance between the anchor and the negative
neg_dist = K.sum(K.square(anchor-negative),axis=1)
neg_dist = K.sum(K.square(anchor-negative), axis=1)

# compute loss
basic_loss = pos_dist-neg_dist+margin
loss = K.maximum(basic_loss,0.0)
loss = K.maximum(basic_loss, 0.0)
return loss

return loss_function


def accuracy(y_true, y_pred):
'''Compute classification accuracy with a fixed threshold on distances.
'''
return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))


class tSNECallback(keras.callbacks.Callback):

def __init__(self, save_file_name='tSNE.gif'):
super(tSNECallback, self).__init__()
self.save_file_name = save_file_name

def on_train_begin(self, logs={}):
self.aucs = []
self.losses = []

def on_train_end(self, logs={}):
return

def on_epoch_begin(self, epoch, logs={}):
return

def on_epoch_end(self, epoch, logs={}):
self.losses.append(logs.get('loss'))
y_pred = self.model.predict(self.model.validation_data[0])
self.aucs.append(roc_auc_score(self.model.validation_data[1], y_pred))
return

def on_batch_begin(self, batch, logs={}):
return

def on_batch_end(self, batch, logs={}):
return
Loading

0 comments on commit 184e677

Please sign in to comment.