From 1181cc867084b9a0917ed7f2b25ba805c011b913 Mon Sep 17 00:00:00 2001 From: rauf Date: Fri, 14 Feb 2020 18:32:14 +0300 Subject: [PATCH] add dataloder --- embedding_net/datagenerators.py | 55 +++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/embedding_net/datagenerators.py b/embedding_net/datagenerators.py index 5cb2797..2a032ce 100644 --- a/embedding_net/datagenerators.py +++ b/embedding_net/datagenerators.py @@ -6,22 +6,15 @@ from itertools import combinations from sklearn.metrics import pairwise_distances from tensorflow.keras.utils import Sequence +from sklearn.model_selection import train_test_split -class ENDataGenerator(Sequence): +class ENDataLoader(): def __init__(self, dataset_path, - input_shape=None, - batch_size = 32, - n_batches = 10, csv_file=None, image_id_column = 'image_id', - label_column = 'label', - augmentations=None): - + label_column = 'label'): + self.dataset_path = dataset_path - self.input_shape = input_shape - self.augmentations = augmentations - self.batch_size = batch_size - self.n_batches = n_batches self.class_files_paths = {} self.class_names = [] @@ -33,11 +26,14 @@ def __init__(self, dataset_path, self.n_classes = len(self.class_names) self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()} - def __len__(self): - return self.n_batches - - def __getitem__(self, index): - pass + def split_train_val(self, val_ratio): + train_data = {} + val_data = {} + for k, v in self.class_files_paths.items(): + train_d, val_d = train_test_split(v, test_size=val_ratio, random_state=42) + train_data[k] = train_d + val_data[k] = val_d + return train_data, val_data, self.class_names def _load_from_dataframe(self, csv_file, image_id_column, label_column): dataframe = pd.read_csv(csv_file) @@ -56,7 +52,32 @@ def _load_from_directory(self): (f.name.endswith('.jpg') or f.name.endswith('.png') and not f.name.startswith('._'))] - self.class_files_paths[class_name] = class_image_paths + self.class_files_paths[class_name] = class_image_paths + + +class ENDataGenerator(Sequence): + def __init__(self, class_files_paths, + class_names, + input_shape=None, + batch_size = 32, + n_batches = 10, + augmentations=None): + + self.input_shape = input_shape + self.augmentations = augmentations + self.batch_size = batch_size + self.n_batches = n_batches + self.class_files_paths = class_files_paths + self.class_names = class_names + + self.n_classes = len(self.class_names) + self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()} + + def __len__(self): + return self.n_batches + + def __getitem__(self, index): + pass def get_image(self, img_path): img = cv2.imread(img_path)