Skip to content

Commit

Permalink
add classes centers based encodings generation
Browse files Browse the repository at this point in the history
  • Loading branch information
RocketFlash committed Dec 17, 2019
1 parent e4d6505 commit 7ad9127
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 10 deletions.
1 change: 1 addition & 0 deletions configs/plates.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ encodings_save_name: 'encodings_resnet18_plates.pkl'

# encodings parameters
save_encodings : True
centers_only: False
max_num_samples_of_each_class : 30
knn_k : 1
1 change: 1 addition & 0 deletions configs/road_signs_resnet18.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,6 @@ encodings_save_name: 'encodings_resnet18.pkl'

# encodings parameters
save_encodings : True
centers_only: False
max_num_samples_of_each_class : 30
knn_k : 1
1 change: 1 addition & 0 deletions configs/road_signs_resnet34.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,6 @@ encodings_save_name: 'encodings_resnet18.pkl'

# encodings parameters
save_encodings : True
centers_only: False
max_num_samples_of_each_class : 30
knn_k : 1
1 change: 1 addition & 0 deletions configs/road_signs_resnext50.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,6 @@ encodings_save_name: 'encodings_resnet18.pkl'

# encodings parameters
save_encodings : True
centers_only: False
max_num_samples_of_each_class : 30
knn_k : 1
34 changes: 26 additions & 8 deletions embedding_net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,20 @@ def validate(self, number_of_comparisons=100, batch_size=4, s="val"):
val_accuracies_it) / len(val_accuracies_it)
return val_loss_epoch, val_accuracy_epoch


def _generate_encoding(self, img_path):
img = self.data_loader.get_image(img_path)
if img is None:
return None
encoding = self.base_model.predict(np.expand_dims(img, axis=0))
return encoding

def generate_encodings(self, save_file_name='encodings.pkl', max_num_samples_of_each_class=10, knn_k=1, shuffle=True):

def generate_encodings(self, save_file_name='encodings.pkl', only_centers=False, max_num_samples_of_each_class=10, knn_k=1, shuffle=True):
data_paths, data_labels, data_encodings = [], [], []
classes_counter = {}
classes_encodings = {}
k_val = 1 if only_centers else knn_k

if shuffle:
c = list(zip(
Expand All @@ -216,21 +220,35 @@ def generate_encodings(self, save_file_name='encodings.pkl', max_num_samples_of_

for img_path, img_label in zip(self.data_loader.images_paths['train'],
self.data_loader.images_labels['train']):
if img_label not in classes_counter:
classes_counter[img_label] = 0
if only_centers:
if img_label not in classes_encodings:
classes_encodings[img_label] = []
else:
if img_label not in classes_counter:
classes_counter[img_label] = 0
if classes_counter[img_label] < max_num_samples_of_each_class:
encod = self._generate_encoding(img_path)

if encod is not None:
data_paths.append(img_path)
data_labels.append(img_label)
data_encodings.append(encod)
classes_counter[img_label] += 1
if only_centers:
classes_encodings[img_label].append(encod)
else:
data_paths.append(img_path)
data_labels.append(img_label)
data_encodings.append(encod)
classes_counter[img_label] += 1
if only_centers:
for class_i, encodings_i in classes_encodings.items():
encodings_i_np = np.array(encodings_i)
class_encoding = np.mean(encodings_i_np, axis = 0)
data_encodings.append(class_encoding)
data_labels.append(class_i)
self.encoded_training_data['paths'] = data_paths
self.encoded_training_data['labels'] = data_labels
self.encoded_training_data['encodings'] = np.squeeze(
np.array(data_encodings))
self.encoded_training_data['knn_classifier'] = KNeighborsClassifier(
n_neighbors=knn_k)
n_neighbors=k_val)
self.encoded_training_data['knn_classifier'].fit(self.encoded_training_data['encodings'],
self.encoded_training_data['labels'])
f = open(save_file_name, "wb")
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
image-classifiers
keras
tensorflow-gpu
keras==2.2.5
tensorflow-gpu==1.14.0
matplotlib
albumentations
scikit-learn
Expand Down

0 comments on commit 7ad9127

Please sign in to comment.