-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathautoencoder_training.py
70 lines (54 loc) · 2.84 KB
/
autoencoder_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import datetime
import models
from utils import normalize, lr_schedule
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def main():
    """Train the convolutional autoencoder and save encoder/decoder weights.

    Chains ``models.Encoder_Gene()`` and ``models.Decoder_Gene()`` on a
    128x128 single-channel input into a model named ``ConvAE``. The model is
    trained to reconstruct its own input (``class_mode="input"``) using
    grayscale images from ``data/train``, with validation on ``data/val``.
    Both image streams are passed through ``utils.normalize``. Training logs
    go to TensorBoard under ``AE_logs/fit/<timestamp>``. The encoder and
    decoder weights are written to ``AE_weights/encoder.h5`` and
    ``AE_weights/decoder.h5``.
    """
    # Let TensorFlow grow GPU memory on demand instead of reserving it all up
    # front. TF requires this setting to be identical across all visible GPUs,
    # so apply it to every device, not just the first one.
    # Feel free to remove this when training.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    # The encoder and decoder are only optimized through the combined model
    # below, so they need no optimizer or compile step of their own. Saving
    # their weights does not require them to be compiled.
    encoder = models.Encoder_Gene()
    decoder = models.Decoder_Gene()

    ultra = tf.keras.layers.Input((128, 128, 1), name='Original_Image')
    x = encoder(ultra)
    out = decoder(x)
    model = Model(inputs=ultra, outputs=out, name='ConvAE')

    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras releases. The effective rate is
    # later overridden each epoch by the LearningRateScheduler callback.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer,
                  loss=tf.keras.losses.Huber(delta=1.0),
                  metrics=['mse'])

    train_datagen = ImageDataGenerator(horizontal_flip=True,
                                       vertical_flip=False,
                                       fill_mode='nearest',
                                       preprocessing_function=normalize)
    # NOTE(review): the training batch size of 1 differs from the validation
    # batch size of 32. Confirm this is intentional.
    train_generator = train_datagen.flow_from_directory("data/train",
                                                        target_size=(128, 128),
                                                        batch_size=1,
                                                        color_mode='grayscale',
                                                        class_mode="input")

    validation_datagen = ImageDataGenerator(preprocessing_function=normalize)
    validation_generator = validation_datagen.flow_from_directory("data/val",
                                                                  target_size=(128, 128),
                                                                  batch_size=32,
                                                                  color_mode='grayscale',
                                                                  class_mode="input")

    # Build the log path portably. The former hard-coded "\\" separators
    # produced a single oddly named directory on non-Windows systems.
    log_dir = os.path.join("AE_logs", "fit",
                           datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

    model.fit(train_generator,
              epochs=120,
              validation_data=validation_generator,
              callbacks=[tensorboard_callback, lr_callback])

    # makedirs(exist_ok=True) avoids the race between an exists() check and
    # mkdir(). It is a no-op if the directory is already present.
    weights_dir = "AE_weights"
    os.makedirs(weights_dir, exist_ok=True)
    encoder.save_weights(os.path.join(weights_dir, "encoder.h5"))
    decoder.save_weights(os.path.join(weights_dir, "decoder.h5"))
    print("Saved model to disk")
if __name__ == "__main__":
    # Kick off training only when this file is run as a script, not on import.
    main()