Commit
create new dataloader and generators
Rauf Yagfarov authored and committed Feb 17, 2020
1 parent 1181cc8 commit 7b6b20b
Showing 4 changed files with 335 additions and 296 deletions.
46 changes: 27 additions & 19 deletions configs/bengali.yml
@@ -8,11 +8,24 @@ MODEL:
freeze_backbone : False
embeddings_normalization: True

TRAIN:
DATALOADER:
dataset_path : '/home/rauf/datasets/bengali/pngs/train/'
csv_file : '/home/rauf/datasets/bengali/train_new.csv'
image_id_column : 'image_id'
label_column : 'label'
validate : True
val_ratio : 0.2

GENERATOR:
negatives_selection_mode : 'semihard'
mining_n_classes: 5
mining_n_samples: 3
margin: 0.5
batch_size : 8
n_batches : 200
augmentation_type : 'default'

TRAIN:
# optimizer parameters
optimizer : 'radam'
learning_rate : 0.0001
@@ -21,29 +34,24 @@ TRAIN:

# embeddings learning training parameters
n_epochs : 1000
n_batches : 200
val_batch_size : 8
val_steps : 200
negatives_selection_mode : 'semihard'
mining_n_classes: 5
mining_n_samples: 3

# plot training history
plot_history : True

SOFTMAX_PRETRAINING:
# softmax pretraining parameters
softmax_pretraining : True
softmax_batch_size : 8
softmax_val_steps : 200
softmax_steps_per_epoch : 500
softmax_epochs : 20
optimizer : 'radam'
learning_rate : 0.0001
decay_factor : 0.99
step_size : 1

# plot training history
plot_history : True
batch_size : 8
val_steps : 200
steps_per_epoch : 500
n_epochs : 20

PATHS:
SAVE_PATHS:
work_dir : 'work_dirs/road_signs_resnet18/'
dataset_path : '/home/rauf/datasets/bengali/pngs/train/'
csv_file : '/home/rauf/datasets/bengali/train.csv'
image_id_column : 'image_id'
label_column : 'label'
encodings_path : 'encodings/'
model_save_name : 'best_model_resnet18.h5'
encodings_save_name: 'encodings_resnet18.pkl'
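
For reference, a minimal sketch of how the new DATALOADER and GENERATOR sections might be wired up in a training script. The script itself is not part of this diff, and the variable embedding_model is assumed to be built elsewhere; the config keys and constructor parameters are the ones visible in this commit.

import yaml
from embedding_net.datagenerators import ENDataLoader, TripletsDataGenerator

with open('configs/bengali.yml') as f:
    cfg = yaml.safe_load(f)

dl_cfg, gen_cfg = cfg['DATALOADER'], cfg['GENERATOR']

# The loader owns the dataset paths and the train/val split;
# the generator only consumes the resulting file lists.
data_loader = ENDataLoader(dataset_path=dl_cfg['dataset_path'],
                           csv_file=dl_cfg['csv_file'],
                           image_id_column=dl_cfg['image_id_column'],
                           label_column=dl_cfg['label_column'],
                           validate=dl_cfg['validate'],
                           val_ratio=dl_cfg['val_ratio'])

train_generator = TripletsDataGenerator(
    embedding_model=embedding_model,  # assumed to exist; not defined in this diff
    class_files_paths=data_loader.train_data,
    class_names=data_loader.class_names,
    batch_size=gen_cfg['batch_size'],
    n_batches=gen_cfg['n_batches'],
    k_classes=gen_cfg['mining_n_classes'],
    k_samples=gen_cfg['mining_n_samples'],
    margin=gen_cfg['margin'],
    negative_selection_mode=gen_cfg['negatives_selection_mode'])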
103 changes: 80 additions & 23 deletions embedding_net/datagenerators.py
@@ -12,7 +12,9 @@ class ENDataLoader():
def __init__(self, dataset_path,
csv_file=None,
image_id_column = 'image_id',
label_column = 'label'):
label_column = 'label',
validate = True,
val_ratio = 0.1):

self.dataset_path = dataset_path
self.class_files_paths = {}
@@ -26,14 +28,23 @@ def __init__(self, dataset_path,
self.n_classes = len(self.class_names)
self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()}

self.validate = validate
self.val_ratio = val_ratio

if self.validate:
self.train_data, self.val_data = self.split_train_val(self.val_ratio)
else:
self.train_data = self.class_files_paths
self.val_data = {}

def split_train_val(self, val_ratio):
train_data = {}
val_data = {}
for k, v in self.class_files_paths.items():
train_d, val_d = train_test_split(v, test_size=val_ratio, random_state=42)
train_data[k] = train_d
val_data[k] = val_d
return train_data, val_data, self.class_names
return train_data, val_data
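
Note that the split runs independently per class, so every class keeps roughly (1 - val_ratio) of its files for training. A quick sanity check of that behavior with a hypothetical single-class file list:

from sklearn.model_selection import train_test_split

files = ['img_{}.png'.format(i) for i in range(10)]  # hypothetical file list for one class
train_d, val_d = train_test_split(files, test_size=0.2, random_state=42)
print(len(train_d), len(val_d))  # -> 8 2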

def _load_from_dataframe(self, csv_file, image_id_column, label_column):
dataframe = pd.read_csv(csv_file)
@@ -110,20 +121,22 @@ def _get_images_set(self, clsss, idxs, with_aug=True):
class TripletsDataGenerator(ENDataGenerator):

def __init__(self, embedding_model,
dataset_path,
class_files_paths,
class_names,
n_batches = 10,
input_shape=None,
batch_size = 32,
csv_file=None,
image_id_column = 'image_id',
label_column = 'label',
augmentations=None,
k_classes=5,
k_samples=5,
margin=0.5,
negative_selection_mode='semihard'):
super().__init__(dataset_path, input_shape, batch_size, n_batches, csv_file,
image_id_column,label_column, augmentations)
super().__init__(class_files_paths=class_files_paths,
clas_names=class_names,
input_shape=input_shape,
batch_size=batch_size,
n_batches=n_batches,
augmentations=augmentations)
modes = {'semihard' : self.semihard_negative,
'hardest': self.hardest_negative,
'random_hard': self.random_hard_negative}
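
The bodies of the three strategy methods are collapsed in this diff. A sketch of the usual pattern (not necessarily the exact collapsed bodies): each function receives an array of per-candidate triplet losses, loss = d(a, p) - d(a, n) + margin, and returns the index of the chosen negative, or None if no candidate qualifies.

import numpy as np

def hardest_negative(loss_values):
    # Take the negative that currently violates the margin the most.
    hard = np.argmax(loss_values)
    return hard if loss_values[hard] > 0 else None

def random_hard_negative(loss_values):
    # Any negative with positive loss, picked at random.
    hard = np.where(loss_values > 0)[0]
    return np.random.choice(hard) if len(hard) > 0 else None

def semihard_negative(loss_values, margin):
    # Negatives farther away than the positive but still inside the margin.
    semihard = np.where(np.logical_and(loss_values < margin, loss_values > 0))[0]
    return np.random.choice(semihard) if len(semihard) > 0 else None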
@@ -210,16 +223,18 @@ def __getitem__(self, index):


class SimpleTripletsDataGenerator(ENDataGenerator):
def __init__(self, dataset_path,
def __init__(self, class_files_paths,
class_names,
input_shape=None,
batch_size = 32,
n_batches = 10,
csv_file=None,
image_id_column = 'image_id',
label_column = 'label',
n_batches = 10,
augmentations=None):
super().__init__(dataset_path, input_shape, batch_size, n_batches, csv_file,
image_id_column,label_column, augmentations)
super().__init__(class_files_paths=class_files_paths,
clas_names=class_names,
input_shape=input_shape,
batch_size=batch_size,
n_batches=n_batches,
augmentations=augmentations)
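
get_batch_triplets below is collapsed after its first lines; the core of the "simple" strategy is plain random sampling without mining, roughly as in this hypothetical helper (it assumes every class holds at least two files):

import random

def sample_triplet_paths(class_files_paths, class_names):
    # Anchor and positive from one class, negative from a different class.
    pos_cls, neg_cls = random.sample(class_names, 2)
    anchor, positive = random.sample(class_files_paths[pos_cls], 2)
    negative = random.choice(class_files_paths[neg_cls])
    return anchor, positive, negative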

def get_batch_triplets(self):
triplets = [np.zeros((self.batch_size, self.input_shape[0], self.input_shape[1], 3)),
@@ -261,17 +276,19 @@ def __getitem__(self, index):

class SiameseDataGenerator(ENDataGenerator):

def __init__(self, dataset_path,
def __init__(self, class_files_paths,
class_names,
input_shape=None,
batch_size = 32,
n_batches = 10,
csv_file=None,
image_id_column = 'image_id',
label_column = 'label',
n_batches = 10,
augmentations=None):

super().__init__(dataset_path, input_shape, batch_size, n_batches, dataframe,
image_id_column,label_column, augmentations)
super().__init__(class_files_paths=class_files_paths,
clas_names=class_names,
input_shape=input_shape,
batch_size=batch_size,
n_batches=n_batches,
augmentations=augmentations)

def get_batch_pairs(self):
pairs = [np.zeros((self.batch_size, self.input_shape[0], self.input_shape[1], 3)), np.zeros(
@@ -317,4 +334,44 @@ def get_batch_pairs(self):
return pairs, targets

def __getitem__(self, index):
        return self.get_batch_pairs()
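
The body of get_batch_pairs is collapsed above. The sampling step it relies on is drawing same-class and different-class image pairs; a hypothetical helper illustrating one common convention (target 1 for a positive pair, 0 for a negative pair, assuming at least two files per class):

import random

def sample_pair(class_files_paths, class_names, positive):
    if positive:
        cls = random.choice(class_names)
        return random.sample(class_files_paths[cls], 2), 1
    cls_a, cls_b = random.sample(class_names, 2)
    return [random.choice(class_files_paths[cls_a]),
            random.choice(class_files_paths[cls_b])], 0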


class SimpleDataGenerator(ENDataGenerator):
def __init__(self, class_files_paths,
class_names,
input_shape=None,
batch_size = 32,
n_batches = 10,
augmentations=None):

super().__init__(class_files_paths=class_files_paths,
clas_names=class_names,
input_shape=input_shape,
batch_size=batch_size,
n_batches=n_batches,
augmentations=augmentations)

def get_batch(self):
images = [
np.zeros((self.batch_size, self.input_shape[0], self.input_shape[1], 3))]
targets = np.zeros((self.batch_size, self.n_classes))

count = 0
with_aug = self.augmentations
for i in range(self.batch_size):
selected_class_idx = random.randrange(0, self.n_classes)
selected_class = self.class_names[selected_class_idx]
selected_class_n_elements = len(self.class_files_paths[selected_class])

indx = random.randrange(0, selected_class_n_elements)

img = self._get_images_set([selected_class], [indx], with_aug=with_aug)
images[0][count, :, :, :] = img[0]
targets[i][selected_class_idx] = 1
count += 1

return images, targets

def __getitem__(self, index):
return self.get_batch()
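
Since each generator serves full batches through __getitem__ (the ENDataGenerator base definition is collapsed in this diff, but the interface matches the keras.utils.Sequence protocol), the new classes can be passed straight to model.fit. A sketch for the softmax-pretraining stage, reusing data_loader from the config example above; input_shape is a hypothetical (height, width) and model is assumed to be a compiled classifier:

train_gen = SimpleDataGenerator(class_files_paths=data_loader.train_data,
                                class_names=data_loader.class_names,
                                input_shape=(64, 64),  # hypothetical (height, width)
                                batch_size=8,
                                n_batches=200)
model.fit(train_gen, epochs=20)  # model assumed to be built and compiled elsewhere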