add triplet network
RocketFlash committed Aug 11, 2019
1 parent 683d234 commit d31e51b
Showing 4 changed files with 128 additions and 46 deletions.
6 changes: 3 additions & 3 deletions data_loader.py
@@ -129,11 +129,11 @@ def get_batch_triplets(self, batch_size, s='train'):

return triplets, targets

def generate(self, batch_size, mode='pair', s="train"):
def generate(self, batch_size, mode='siamese', s="train"):
while True:
if mode == 'pair':
if mode == 'siamese':
data, targets = self.get_batch_pairs(batch_size, s)
if mode == 'triplet':
elif mode == 'triplet':
data, targets = self.get_batch_triplets(batch_size, s)
yield (data, targets)

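With the new `mode` argument a single generator now serves both network types, and the default changes from 'pair' to 'siamese'. A rough usage sketch (illustration only -- the `loader` object and the exact batch structures are assumptions based on `get_batch_pairs`/`get_batch_triplets`, not code from this commit):

# Hypothetical sketch, not part of this commit.
siamese_gen = loader.generate(batch_size=8, mode='siamese', s='train')
pair_images, pair_targets = next(siamese_gen)        # two image batches + same/different labels

triplet_gen = loader.generate(batch_size=8, mode='triplet', s='train')
triplet_images, triplet_targets = next(triplet_gen)  # anchor/positive/negative batches + placeholder targets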
115 changes: 85 additions & 30 deletions model.py
@@ -3,35 +3,38 @@
import numpy as np
import keras.backend as K
import tensorflow as tf
import pickle
import cv2
import random
from keras.models import Model, load_model
from keras import optimizers
from keras.regularizers import l2
from keras.utils import plot_model
from keras.layers import Dense, Input, Lambda, Dropout, Flatten
from keras.layers import Conv2D, MaxPool2D, BatchNormalization
from keras.layers import Conv2D, MaxPool2D, BatchNormalization, concatenate
from classification_models import Classifiers

import utils
import pickle


class SiameseNet:
"""
SiameseNet for image classification
mode = 'l1' -> l1_loss
mode = 'l2' -> l2_loss
distance_type = 'l1' -> l1_loss
distance_type = 'l2' -> l2_loss
mode = 'siamese' -> Siamese network
mode = 'triplet' -> Triplet network
"""

def __init__(self, input_shape, image_loader, mode='l1', backbone='resnet50',
def __init__(self, input_shape, image_loader, mode='siamese', distance_type ='l1', backbone='resnet50',
backbone_weights = 'imagenet',
optimizer=optimizers.Adam(lr=1e-4), tensorboard_log_path='tf_log/',
weights_save_path='weights/', plots_path='plots/', encodings_path='encodings/',
project_name='', freeze_backbone=True):
self.input_shape = input_shape
self.backbone = backbone
self.backbone_weights = backbone_weights
self.distance_type = distance_type
self.mode = mode
self.project_name = project_name
self.optimizer = optimizer
@@ -54,15 +57,14 @@ def __init__(self, input_shape, image_loader, mode='l1', backbone='resnet50',
if self.weights_save_path:
os.makedirs(self.weights_save_path, exist_ok=True)

self._create_model()
if self.mode == 'siamese':
self._create_model_siamese()
elif self.mode == 'triplet':
self._create_model_triplet()
self.data_loader = image_loader
self.encoded_training_data = {}

def _create_model(self):

input_image_1 = Input(self.input_shape)
input_image_2 = Input(self.input_shape)

def _create_base_model(self):
if self.backbone == 'simple':
input_image = Input(self.input_shape)
x = Conv2D(64, (10, 10), activation='relu',
@@ -134,19 +136,28 @@ def _create_model(self):

self.base_model = Model(
inputs=[backbone_model.input], outputs=[encoded_output])
pass


def _create_model_siamese(self):

input_image_1 = Input(self.input_shape)
input_image_2 = Input(self.input_shape)

self._create_base_model()

image_encoding_1 = self.base_model(input_image_1)
image_encoding_2 = self.base_model(input_image_2)

if self.mode == 'l1':
if self.distance_type == 'l1':
L1_layer = Lambda(
lambda tensors: K.abs(tensors[0] - tensors[1]))
distance = L1_layer([image_encoding_1, image_encoding_2])

prediction = Dense(units=1, activation='sigmoid')(distance)
metric = 'binary_accuracy'

elif self.mode == 'l2':
elif self.distance_type == 'l2':

L2_layer = Lambda(
lambda tensors: K.sqrt(K.maximum(K.sum(K.square(tensors[0] - tensors[1]), axis=1, keepdims=True), K.epsilon())))
@@ -169,6 +180,23 @@ def _create_model(self):
self.model.compile(loss=self.contrastive_loss, metrics=[metric],
optimizer=self.optimizer)

def _create_model_triplet(self):
input_image_a = Input(self.input_shape)
input_image_p = Input(self.input_shape)
input_image_n = Input(self.input_shape)

self._create_base_model()

image_encoding_a = self.base_model(input_image_a)
image_encoding_p = self.base_model(input_image_p)
image_encoding_n = self.base_model(input_image_n)

merged_vector = concatenate([image_encoding_a, image_encoding_p, image_encoding_n],
axis=-1, name='merged_layer')
self.model = Model(inputs=[input_image_a,input_image_p, input_image_n],
outputs=merged_vector)
self.model.compile(loss=self.triplet_loss, optimizer=self.optimizer)


def contrastive_loss(self, y_true, y_pred):
'''Contrastive loss from Hadsell-et-al.'06
@@ -179,6 +207,40 @@ def contrastive_loss(self, y_true, y_pred):
margin_square = K.square(K.maximum(margin - y_pred, 0))
return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)
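For reference, the same contrastive loss spelled out in NumPy. This is only a sketch: it assumes the collapsed lines above define `margin = 1` and `sqaure_pred = K.square(y_pred)`, with `y_pred` being the predicted distance between the two encodings.

import numpy as np

def contrastive_loss_np(y_true, dist, margin=1.0):   # margin value is an assumption
    # y_true: 1 for same-class pairs, 0 for different-class pairs; dist: predicted distance
    square_dist = np.square(dist)
    margin_square = np.square(np.maximum(margin - dist, 0.0))
    return np.mean(y_true * square_dist + (1 - y_true) * margin_square)

print(contrastive_loss_np(np.array([1.0]), np.array([0.2])))  # 0.04 -- similar pair at small distance
print(contrastive_loss_np(np.array([0.0]), np.array([0.2])))  # 0.64 -- dissimilar pair penalised up to the margin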

def triplet_loss(self, y_true, y_pred, alpha = 0.4):
"""
Implementation of the triplet loss function
Arguments:
y_true -- true labels, required by the Keras loss signature; not used in this function
y_pred -- concatenation of three encodings along the last axis:
anchor -- the encodings for the anchor data
positive -- the encodings for the positive data (similar to anchor)
negative -- the encodings for the negative data (different from anchor)
Returns:
loss -- real number, value of the loss
"""
print('y_pred.shape = ',y_pred)

total_length = y_pred.shape.as_list()[-1]
# print('total_length=', total_length)
# total_length = 12
print(y_pred)
anchor = y_pred[:,0:int(total_length*1/3)]
positive = y_pred[:,int(total_length*1/3):int(total_length*2/3)]
negative = y_pred[:,int(total_length*2/3):int(total_length*3/3)]

# distance between the anchor and the positive
pos_dist = K.sum(K.square(anchor-positive),axis=1)

# distance between the anchor and the negative
neg_dist = K.sum(K.square(anchor-negative),axis=1)

# compute loss
basic_loss = pos_dist-neg_dist+alpha
loss = K.maximum(basic_loss,0.0)

return loss
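A compact NumPy mirror of the slicing and loss above, useful for sanity-checking the concatenated output of the triplet model (the toy encodings and dimensions are made up for illustration):

import numpy as np

def triplet_loss_np(y_pred, alpha=0.4):
    # y_pred: (batch, 3 * encoding_dim) laid out as [anchor | positive | negative]
    dim = y_pred.shape[-1] // 3
    anchor, positive, negative = y_pred[:, :dim], y_pred[:, dim:2*dim], y_pred[:, 2*dim:]
    pos_dist = np.sum(np.square(anchor - positive), axis=1)
    neg_dist = np.sum(np.square(anchor - negative), axis=1)
    return np.maximum(pos_dist - neg_dist + alpha, 0.0)

# positive close to the anchor, negative far away -> loss clipped to zero
batch = np.array([[0.0, 0.0,  0.1, 0.0,  2.0, 0.0]])
print(triplet_loss_np(batch))  # [0.] since 0.01 - 4.0 + 0.4 < 0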

def accuracy(self, y_true, y_pred):
'''Compute classification accuracy with a fixed threshold on distances.
'''
@@ -200,8 +262,8 @@ def validate_on_batch(self, batch_size=8, s="val"):

def train_generator(self, steps_per_epoch, epochs, callbacks = [], val_steps=100, with_val=True, batch_size=8, verbose=1):

train_generator = self.data_loader.generate(batch_size, s="train")
val_generator = self.data_loader.generate(batch_size, s="val")
train_generator = self.data_loader.generate(batch_size, mode=self.mode, s="train")
val_generator = self.data_loader.generate(batch_size, mode=self.mode, s="val")

history = self.model.fit_generator(train_generator, steps_per_epoch=steps_per_epoch, epochs=epochs,
verbose=verbose, validation_data = val_generator, validation_steps = val_steps, callbacks=callbacks)
@@ -213,13 +275,9 @@ def validate(self, number_of_comparisons=100, batch_size=4, s="val"):
val_losses_it = []
for _ in range(number_of_comparisons):
pairs, targets = next(generator)
# predictions = self.model.predict(pairs)

val_loss_it, val_accuracy_it = self.model.test_on_batch(
pairs, targets)
# print(predictions)
# print(targets)
# print('================================')
val_accuracies_it.append(val_accuracy_it)
val_losses_it.append(val_loss_it)
val_loss_epoch = sum(val_losses_it) / len(val_losses_it)
@@ -260,18 +318,15 @@ def generate_encodings(self, save_file_name='encodings.pkl', max_num_samples_of_
f.close()

def load_encodings(self, path_to_encodings):
try:
with open(path_to_encodings, 'rb') as f:
self.encoded_training_data = pickle.load(f)
except:
print("Problem with encodings file")
self.encoded_training_data = utils.load_encodings(path_to_encodings)

def load_model(self,file_path):
self.model = load_model(file_path,
custom_objects={'contrastive_loss': self.contrastive_loss,
'accuracy': self.accuracy})
self.base_model = Model(inputs=[self.model.layers[2].get_input_at(0)],
outputs=[self.model.layers[2].layers[-1].output])
'accuracy': self.accuracy,
'triplet_loss': self.triplet_loss})
self.base_model = Model(inputs=[self.model.layers[3].get_input_at(0)],
outputs=[self.model.layers[3].layers[-1].output])
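The shared encoder is now pulled out of `layers[3]` instead of `layers[2]`, presumably because the triplet graph places three `Input` layers ahead of the shared base model. A quick, hypothetical way to verify the right index after loading a model (not part of the commit):

# Hypothetical check -- print layer indices and look for the nested Model that is the shared encoder.
for i, layer in enumerate(model.model.layers):
    print(i, layer.name, type(layer).__name__)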

def calculate_distances(self, encoding):
training_encodings = self.encoded_training_data['encodings']
@@ -281,7 +336,7 @@ def calculate_distances(self, encoding):
def predict(self, image_path):
img = cv2.imread(image_path)
img = cv2.resize(img, (self.input_shape[0], self.input_shape[1]))
print(img.shape)

encoding = self.base_model.predict(np.expand_dims(img, axis=0))
distances = self.calculate_distances(encoding)
max_element = np.argmin(distances)
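Taken together, the changed methods might be driven like this for inference in triplet mode. Everything here is a hedged sketch: the paths, the `loader` object and the return value of `predict` are assumptions, not code from this commit.

# Hypothetical usage sketch.
model = SiameseNet(input_shape=(48, 48, 3), backbone='simple2', mode='triplet',
                   distance_type='l2', image_loader=loader, project_name='road_signs/')
model.load_model('weights/road_signs/best_model_3.h5')      # custom_objects now include triplet_loss
model.load_encodings('encodings/road_signs/encodings.pkl')  # delegated to utils.load_encodings
prediction = model.predict('some_road_sign.png')            # nearest stored encoding wins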
25 changes: 12 additions & 13 deletions test_net.py
@@ -20,18 +20,19 @@ def plot_grapth(values, y_label, title, project_name):

fig.savefig("plots/{}{}.png".format(project_name, y_label))


input_shape = (48, 48, 3)
project_name = 'road_signs/'
dataset_path = '/home/rauf/plates_competition/dataset/road_signs/road_signs_separated/'

# input_shape = (256, 256, 3)
# project_name = 'plates/'
# dataset_path = '/home/rauf/plates_competition/dataset/to_train/'

n_epochs = 1000
n_steps_per_epoch = 500
batch_size = 4
batch_size = 64
val_steps = 100
input_shape = (48, 48, 3)
# input_shape = (256, 256, 3)


# augmentations = A.Compose([
# A.RandomBrightnessContrast(p=0.4),
@@ -53,11 +54,9 @@ def plot_grapth(values, y_label, title, project_name):

optimizer = optimizers.Adam(lr=1e-4)
# optimizer = optimizers.RMSprop(lr=1e-5)
# model = SiameseNet(input_shape=(256, 256, 3), backbone='resnet50', mode='l2',
# image_loader=loader, optimizer=optimizer)

model = SiameseNet(input_shape=input_shape, backbone='simple2', backbone_weights='imagenet', mode='l2',
image_loader=loader, optimizer=optimizer, project_name=project_name,
model = SiameseNet(input_shape=input_shape, backbone='simple2', backbone_weights='imagenet', mode='triplet',
image_loader=loader, optimizer=optimizer, project_name=project_name, distance_type='l2',
freeze_backbone=False)


@@ -72,7 +71,7 @@ def plot_grapth(values, y_label, title, project_name):
TensorBoard(log_dir=model.tensorboard_log_path),
# ReduceLROnPlateau(factor=0.9, patience=50,
# min_lr=1e-12, verbose=1),
ModelCheckpoint(filepath=os.path.join(model.weights_save_path, 'best_model_2.h5'), verbose=1, monitor='loss',
ModelCheckpoint(filepath=os.path.join(model.weights_save_path, 'best_model_3.h5'), verbose=1, monitor='val_loss',
save_best_only=True)
]

@@ -82,14 +81,14 @@ def plot_grapth(values, y_label, title, project_name):
H = model.train_generator(steps_per_epoch=n_steps_per_epoch, callbacks=callbacks,
val_steps=val_steps, epochs=n_epochs)
train_losses = H.history['loss']
train_accuracies = H.history['accuracy']
# train_accuracies = H.history['accuracy']
val_losses = H.history['val_loss']
val_accuracies = H.history['val_accuracy']
# val_accuracies = H.history['val_accuracy']

plot_grapth(train_losses, 'train_loss', 'Losses on train', project_name)
plot_grapth(train_accuracies, 'train_acc', 'Accuracies on train', project_name)
# plot_grapth(train_accuracies, 'train_acc', 'Accuracies on train', project_name)
plot_grapth(val_losses, 'val_loss', 'Losses on val', project_name)
plot_grapth(val_accuracies, 'val_acc', 'Accuracies on val', project_name)
# plot_grapth(val_accuracies, 'val_acc', 'Accuracies on val', project_name)


model.generate_encodings()
28 changes: 28 additions & 0 deletions utils.py
@@ -0,0 +1,28 @@
from sklearn.manifold import TSNE
import pickle
import numpy as np
from matplotlib import pyplot as plt


def load_encodings(path_to_encodings):

with open(path_to_encodings, 'rb') as f:
encodings = pickle.load(f)
return encodings


def make_tsne(project_name, show=True):
encodings = load_encodings(
'encodings/{}encodings.pkl'.format(project_name))
labels = list(set(encodings['labels']))
tsne = TSNE()
tsne_train = tsne.fit_transform(encodings['encodings'])
fig, ax = plt.subplots(figsize=(16, 16))
for i, l in enumerate(labels):
ax.scatter(tsne_train[np.array(encodings['labels']) == l, 0],
tsne_train[np.array(encodings['labels']) == l, 1], label=l)
ax.legend()
if show:
fig.show()

fig.savefig("plots/{}{}.png".format(project_name, 'tsne.png'))
