Skip to content

Commit

Permalink
add dataloder
Browse files Browse the repository at this point in the history
  • Loading branch information
RocketFlash committed Feb 14, 2020
1 parent 2c1f026 commit 1181cc8
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions embedding_net/datagenerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,15 @@
from itertools import combinations
from sklearn.metrics import pairwise_distances
from tensorflow.keras.utils import Sequence
from sklearn.model_selection import train_test_split

class ENDataGenerator(Sequence):
class ENDataLoader():
def __init__(self, dataset_path,
input_shape=None,
batch_size = 32,
n_batches = 10,
csv_file=None,
image_id_column = 'image_id',
label_column = 'label',
augmentations=None):

label_column = 'label'):

self.dataset_path = dataset_path
self.input_shape = input_shape
self.augmentations = augmentations
self.batch_size = batch_size
self.n_batches = n_batches
self.class_files_paths = {}
self.class_names = []

Expand All @@ -33,11 +26,14 @@ def __init__(self, dataset_path,
self.n_classes = len(self.class_names)
self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()}

def __len__(self):
return self.n_batches

def __getitem__(self, index):
pass
def split_train_val(self, val_ratio):
train_data = {}
val_data = {}
for k, v in self.class_files_paths.items():
train_d, val_d = train_test_split(v, test_size=val_ratio, random_state=42)
train_data[k] = train_d
val_data[k] = val_d
return train_data, val_data, self.class_names

def _load_from_dataframe(self, csv_file, image_id_column, label_column):
dataframe = pd.read_csv(csv_file)
Expand All @@ -56,7 +52,32 @@ def _load_from_directory(self):
(f.name.endswith('.jpg') or
f.name.endswith('.png') and
not f.name.startswith('._'))]
self.class_files_paths[class_name] = class_image_paths
self.class_files_paths[class_name] = class_image_paths


class ENDataGenerator(Sequence):
def __init__(self, class_files_paths,
class_names,
input_shape=None,
batch_size = 32,
n_batches = 10,
augmentations=None):

self.input_shape = input_shape
self.augmentations = augmentations
self.batch_size = batch_size
self.n_batches = n_batches
self.class_files_paths = class_files_paths
self.class_names = class_names

self.n_classes = len(self.class_names)
self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()}

def __len__(self):
return self.n_batches

def __getitem__(self, index):
pass

def get_image(self, img_path):
img = cv2.imread(img_path)
Expand Down

0 comments on commit 1181cc8

Please sign in to comment.