train.py — 76 lines (57 loc) · 2.58 KB
from glob import glob
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import Adam
import numpy as np
from models.unet import unet
from models.unetWavelet import unetWavelet
from utils import CustomGenerator, psnr
from numpy.random import seed
# Fix NumPy's global RNG so anything random downstream (e.g. crop positions
# drawn inside CustomGenerator) is reproducible across runs.
seed(1)
def train(FLAGS):
    """Train a denoising network (plain U-Net or wavelet U-Net) and evaluate it.

    The best epoch (lowest ``val_loss``) is checkpointed to ``temp.h5``,
    reloaded after training, saved under ``weights/``, and finally evaluated
    on the held-out test set.

    Args:
        FLAGS: parsed command-line flags exposing ``train_path`` (directory of
            training PNGs), ``test_path`` (directory of test PNGs) and
            ``architecture`` ("unet" or "wavelet").

    Raises:
        RuntimeError: if ``FLAGS.architecture`` is neither "unet" nor "wavelet".
    """
    ### Parameters of the training
    # Can be freely changed
    epochs = 100  # 50
    train_ratio = 0.8  # fraction of training images fitted on; rest is validation
    batch_size = 16
    # cropShape = (192, 192, 3)
    cropShape = (64, 64, 3)
    early_stop = 15  # 10 — EarlyStopping patience in epochs
    # The learning rate is decayed from 10e-4 to 10e-5 over 200 epochs.
    # NOTE(review): 10e-4 == 1e-3 (and epsilon 10e-8 == 1e-7); the comment
    # above suggests 1e-4 may have been intended — kept as-is to preserve
    # current behavior, confirm against the paper/original experiments.
    optimizer = Adam(lr=10e-4, beta_1=0.9, beta_2=0.999, epsilon=10e-8, decay=10e-3)

    train_path = FLAGS.train_path
    valid_path = FLAGS.test_path
    architecture = FLAGS.architecture

    # Select the architecture and the path the final model is saved to.
    if architecture == "unet":
        filepath = "weights/DenoisingUnet.h5"
        model = unet(cropShape)
    elif architecture == "wavelet":
        filepath = "weights/DenoisingWavelet.h5"
        model = unetWavelet(cropShape)
    else:
        # Fixed typo in the error message: "Unkwown" -> "Unknown".
        raise RuntimeError('Unknown architecture, please choose from "unet" or "wavelet".')

    model.compile(loss='mean_squared_error', optimizer=optimizer, metrics=[psnr])

    ### Preparation of the dataset, building the generators
    # glob() already returns a list — no need to wrap it in a comprehension.
    train_files = glob(train_path + '/*.png')
    test_files = glob(valid_path + '/*.png')
    split = int(len(train_files) * train_ratio)  # computed once, used for both slices
    train_generator = CustomGenerator(train_files[:split], batch_size, cropShape)
    valid_generator = CustomGenerator(train_files[split:], batch_size, cropShape)
    test_generator = CustomGenerator(test_files, batch_size, cropShape)

    ### Train the model
    # Checkpoint only the best weights so the saved model is the best epoch,
    # not whatever epoch training (or early stopping) happened to end on.
    model.fit_generator(
        train_generator,
        epochs=epochs,
        verbose=2,
        callbacks=[
            ModelCheckpoint("temp.h5", monitor='val_loss', verbose=0, save_best_only=True, mode='auto'),
            EarlyStopping(monitor='val_loss', patience=early_stop, verbose=0, mode='auto')
        ],
        validation_data=valid_generator
    )
    model.load_weights("temp.h5")
    model.save(filepath)
    print("Training done")

    ### Evaluate the model on the test dataset
    result = model.evaluate_generator(test_generator)
    print(result)