Skip to content

Commit

Permalink
update plot batch function
Browse files Browse the repository at this point in the history
  • Loading branch information
RocketFlash committed Jan 7, 2020
1 parent f38dfa5 commit 25fe9f6
Show file tree
Hide file tree
Showing 3 changed files with 455 additions and 5,619 deletions.
25 changes: 21 additions & 4 deletions embedding_net/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _load_images_paths(self):
random.seed(4)

for d in self.data_subsets:
print('=======================================')
print('{:5} ======================================'.format(d))
self.images_paths[d] = []
self.images_labels[d] = []
for root, dirs, files in os.walk(self.dataset_path+d):
Expand All @@ -61,7 +61,6 @@ def _load_images_paths(self):
if curr_class in skip_list:
continue

print('Class {:11}: total number of files {}'.format(curr_class, n_obj))
idx_list = list(range(n_obj))
random.shuffle(idx_list)
count = 0
Expand All @@ -75,7 +74,7 @@ def _load_images_paths(self):
n_files_selected+=count
if d == 'train':
n_classes_selected += 1
print('Class {:11}: selected number of files {}'.format(curr_class, count))
print('Class {:11}: total number of files {:6}, selected {:6}'.format(curr_class, n_obj, count))
print('Total number of files in dataset: {}'.format(n_files_dataset))
print('Number of selected files: {}'.format(n_files_selected))
print('Number of selected classes: {}'.format(n_classes_selected))
Expand Down Expand Up @@ -260,7 +259,7 @@ def get_batch_triplets_mining(self,

all_embeddings_list = []
all_images_list = []
# with_aug = s == 'train' and self.augmentations

with_aug = self.augmentations
for idx, cl_img_idxs in enumerate(selected_images):
images = self._get_images_set(
Expand Down Expand Up @@ -348,6 +347,24 @@ def get_image(self, img_path):
img, (self.input_shape[0], self.input_shape[1]))
return img


def plot_batch_simple(self, data, targets):
num_imgs = data[0].shape[0]
img_h = data[0].shape[1]
img_w = data[0].shape[2]
full_img = np.zeros((img_h,num_imgs*img_w,3), dtype=np.uint8)
indxs = np.argmax(targets, axis=1)
class_names = [self.classes['train'][i] for i in indxs]

for i in range(num_imgs):
full_img[:,i*img_w:(i+1)*img_w,:] = data[0][i,:,:,:]
cv2.putText(full_img, class_names[i], (img_w*i + 10, 20), cv2.FONT_HERSHEY_SIMPLEX,
0.5, (0, 255, 0), 1, cv2.LINE_AA)
plt.figure(figsize = (20,2))
plt.imshow(full_img)
plt.show()


def plot_batch(self, data, targets):
num_imgs = data[0].shape[0]
it_val = len(data)
Expand Down
53 changes: 0 additions & 53 deletions embedding_net/pretrain_backbone_softmax.py

This file was deleted.

5,996 changes: 434 additions & 5,562 deletions test_network.ipynb

Large diffs are not rendered by default.

0 comments on commit 25fe9f6

Please sign in to comment.