-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
46 lines (34 loc) · 1.31 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import tensorflow as tf
from tensorflow.python.ops.gen_dataset_ops import MapDataset
from tensorflow.keras.preprocessing import image_dataset_from_directory
from config import DATA, IMAGES, GROUND_TRUTH, WIDTH, HEIGHT, BATCH_SIZE
def configure_for_performance(ds: tf.data.Dataset, cache_filename: str = "") -> tf.data.Dataset:
    """Cache *ds* and prefetch upcoming elements so input loading overlaps consumption.

    Args:
        ds: The dataset to wrap.
        cache_filename: Forwarded to ``tf.data.Dataset.cache``. The default ``""``
            caches in memory (the previous behaviour). A file path caches to disk
            instead, which is useful when the whole dataset does not fit in RAM.

    Returns:
        The cached, prefetching dataset. This is a ``tf.data.Dataset``; the
        concrete subclass comes from ``prefetch``, not ``map``.
    """
    # Cache first so upstream work (decoding, resizing, mapping) only runs on the
    # first pass. Prefetch last so the producer keeps running ahead of the
    # consumer on every epoch, including the cached ones.
    ds = ds.cache(cache_filename)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)
def _load_image_folder(directory: str, color_mode: str, interpolation: str) -> tf.data.Dataset:
    """Load all images under *directory* as an unlabelled, unshuffled, batched dataset."""
    # shuffle=False keeps a fixed file order. This is what lets the image and
    # mask folders be zipped element-by-element afterwards.
    return image_dataset_from_directory(
        directory,
        label_mode=None,
        color_mode=color_mode,
        image_size=(HEIGHT, WIDTH),
        batch_size=BATCH_SIZE,
        shuffle=False,
        interpolation=interpolation,
    )


def get_dataset(
    mask_color_mode: str = "rgb",
    mask_interpolation: str = "bilinear",
) -> "tf.data.Dataset | None":
    """Build the batched ``(image, mask)`` dataset from ``DATA/IMAGES`` and ``DATA/GROUND_TRUTH``.

    Both folders are loaded unshuffled, resized to ``(HEIGHT, WIDTH)`` and batched
    by ``BATCH_SIZE``. They are then paired batch-by-batch, scaled by 1/255 and
    passed through ``configure_for_performance`` (cache + prefetch).

    Args:
        mask_color_mode: ``color_mode`` used when loading masks. Defaults to
            ``"rgb"`` (the previous behaviour). Pass ``"grayscale"`` for
            single-channel masks.
        mask_interpolation: Resize interpolation used for masks. Defaults to
            ``"bilinear"`` (the previous behaviour). ``"nearest"`` avoids
            blending label values along mask edges.

    Returns:
        The dataset of ``(image, mask)`` batch pairs, or ``None`` if either
        directory is missing. In that case a message naming the missing
        path(s) is printed.
    """
    images_dir = os.path.join(DATA, IMAGES)
    masks_dir = os.path.join(DATA, GROUND_TRUTH)

    missing = [path for path in (images_dir, masks_dir) if not os.path.isdir(path)]
    if missing:
        # Keep the soft-failure contract (print + None) so existing callers that
        # test for None still work, but say which path is the problem.
        print(f"Directory does not exist: {', '.join(missing)}")
        return None

    train_images = _load_image_folder(images_dir, color_mode="rgb", interpolation="bilinear")
    train_ground_truth = _load_image_folder(
        masks_dir, color_mode=mask_color_mode, interpolation=mask_interpolation
    )

    # NOTE(review): zip pairs batches by position and stops at the shorter input.
    # This assumes both folders hold the same number of files with matching
    # sort order. TODO confirm file names line up.
    dataset = tf.data.Dataset.zip((train_images, train_ground_truth))

    normalization = tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255)
    # Parallel map; tf.data keeps output order deterministic by default.
    dataset = dataset.map(
        lambda img, mask: (normalization(img), normalization(mask)),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    return configure_for_performance(dataset)