# train_model.py
# python train_model.py model={iphone,sony,blackberry} dped_dir=dped vgg_dir=vgg_pretrained/imagenet-vgg-verydeep-19.mat
import tensorflow as tf
import imageio
import numpy as np
import sys
tf.compat.v1.disable_v2_behavior()
from load_dataset import load_test_data, load_batch
from ssim import MultiScaleSSIM
import models
import utils
import vgg
# defining size of the training image patches (in pixels)
PATCH_WIDTH, PATCH_HEIGHT = 100, 100
# number of values in one flattened patch: height * width * 3 channels
PATCH_SIZE = 3 * PATCH_HEIGHT * PATCH_WIDTH
# processing command arguments: utils.process_command_args returns the
# twelve run settings, unpacked here in the order it provides them
(
    phone, batch_size, train_size, learning_rate, num_train_iters,
    w_content, w_color, w_texture, w_tv,
    dped_dir, vgg_dir, eval_step,
) = utils.process_command_args(sys.argv)

# seed NumPy's global RNG with a fixed value
np.random.seed(0)
# defining system architecture
# The whole graph (generator, discriminator, losses, optimizers) is built and
# the training loop is run inside a single TF1-style graph/session.
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:

    # placeholders for training data: flattened patches reshaped to
    # [batch, height, width, channels]
    phone_ = tf.compat.v1.placeholder(tf.float32, [None, PATCH_SIZE])
    phone_image = tf.reshape(phone_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])

    dslr_ = tf.compat.v1.placeholder(tf.float32, [None, PATCH_SIZE])
    dslr_image = tf.reshape(dslr_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])

    # per-sample switch: 1 selects the DSLR patch, 0 the enhanced patch,
    # as input for the discriminator
    adv_ = tf.compat.v1.placeholder(tf.float32, [None, 1])

    # get processed enhanced image
    enhanced = models.resnet(phone_image)

    # transform both dslr and enhanced images to grayscale
    enhanced_gray = tf.reshape(tf.image.rgb_to_grayscale(enhanced), [-1, PATCH_WIDTH * PATCH_HEIGHT])
    dslr_gray = tf.reshape(tf.image.rgb_to_grayscale(dslr_image), [-1, PATCH_WIDTH * PATCH_HEIGHT])

    # push randomly the enhanced or dslr image to an adversarial CNN-discriminator
    adversarial_ = tf.multiply(enhanced_gray, 1 - adv_) + tf.multiply(dslr_gray, adv_)
    adversarial_image = tf.reshape(adversarial_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 1])

    discrim_predictions = models.adversarial(adversarial_image)

    # losses
    # 1) texture (adversarial) loss
    # two-column target: [adv_, 1 - adv_]; predictions are clipped before the
    # log so that log(0) can never occur
    discrim_target = tf.concat([adv_, 1 - adv_], 1)

    loss_discrim = -tf.reduce_sum(
        discrim_target * tf.compat.v1.log(tf.clip_by_value(discrim_predictions, 1e-10, 1.0)))
    # the generator is trained against the negated discriminator loss
    loss_texture = -loss_discrim

    correct_predictions = tf.equal(tf.argmax(discrim_predictions, 1), tf.argmax(discrim_target, 1))
    discrim_accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

    # 2) content loss
    CONTENT_LAYER = 'relu5_4'

    # images are scaled by 255 before VGG preprocessing
    # (presumably patches are stored in [0, 1] -- confirm against load_dataset)
    enhanced_vgg = vgg.net(vgg_dir, vgg.preprocess(enhanced * 255))
    dslr_vgg = vgg.net(vgg_dir, vgg.preprocess(dslr_image * 255))

    content_size = utils._tensor_size(dslr_vgg[CONTENT_LAYER]) * batch_size
    loss_content = 2 * tf.nn.l2_loss(enhanced_vgg[CONTENT_LAYER] - dslr_vgg[CONTENT_LAYER]) / content_size

    # 3) color loss
    enhanced_blur = utils.blur(enhanced)
    dslr_blur = utils.blur(dslr_image)

    loss_color = tf.reduce_sum(tf.pow(dslr_blur - enhanced_blur, 2)) / (2 * batch_size)

    # 4) total variation loss
    # Differences between neighbouring rows (axis 1) and columns (axis 2).
    # The `1:` slices are paired with `:-1` slices so both operands always
    # have matching extents, whatever the patch height/width are.
    tv_y_size = utils._tensor_size(enhanced[:, 1:, :, :])
    tv_x_size = utils._tensor_size(enhanced[:, :, 1:, :])
    y_tv = tf.nn.l2_loss(enhanced[:, 1:, :, :] - enhanced[:, :-1, :, :])
    x_tv = tf.nn.l2_loss(enhanced[:, :, 1:, :] - enhanced[:, :, :-1, :])
    loss_tv = 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / batch_size

    # final loss
    loss_generator = w_content * loss_content + w_texture * loss_texture + w_color * loss_color + w_tv * loss_tv

    # psnr loss (computed with a peak value of 1.0; reported only, not optimized)
    enhanced_flat = tf.reshape(enhanced, [-1, PATCH_SIZE])

    loss_mse = tf.reduce_sum(tf.pow(dslr_ - enhanced_flat, 2)) / (PATCH_SIZE * batch_size)
    loss_psnr = 20 * utils.log10(1.0 / tf.sqrt(loss_mse))

    # optimize parameters of image enhancement (generator) and discriminator networks
    # variables are assigned to a network by their name prefix
    generator_vars = [v for v in tf.compat.v1.global_variables() if v.name.startswith("generator")]
    discriminator_vars = [v for v in tf.compat.v1.global_variables() if v.name.startswith("discriminator")]

    train_step_gen = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss_generator, var_list=generator_vars)
    train_step_disc = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss_discrim, var_list=discriminator_vars)

    # only generator variables are checkpointed; at most 100 checkpoints are kept
    saver = tf.compat.v1.train.Saver(var_list=generator_vars, max_to_keep=100)

    print('Initializing variables')
    sess.run(tf.compat.v1.global_variables_initializer())

    # loading training and test data
    print("Loading test data...")
    test_data, test_answ = load_test_data(phone, dped_dir, PATCH_SIZE)
    print("Test data was loaded\n")

    print("Loading training data...")
    train_data, train_answ = load_batch(phone, dped_dir, train_size, PATCH_SIZE)
    print("Training data was loaded\n")

    TEST_SIZE = test_data.shape[0]
    # trailing test samples that do not fill a whole batch are not evaluated
    num_test_batches = TEST_SIZE // batch_size

    print('Training network')

    train_loss_gen = 0.0
    train_acc_discrim = 0.0

    # adv_ value used for generator updates: the discriminator input is then
    # built from enhanced images only
    all_zeros = np.zeros((batch_size, 1))
    # five random test patches that are re-rendered at every evaluation step
    test_crops = test_data[np.random.randint(0, TEST_SIZE, 5), :]

    # create / truncate the log file for this run
    log_path = 'models/' + phone + '.txt'
    with open(log_path, "w"):
        pass

    for i in range(num_train_iters):

        # train generator
        idx_train = np.random.randint(0, train_size, batch_size)

        phone_images = train_data[idx_train]
        dslr_images = train_answ[idx_train]

        loss_temp, _ = sess.run(
            [loss_generator, train_step_gen],
            feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: all_zeros})
        train_loss_gen += loss_temp / eval_step

        # train discriminator
        idx_train = np.random.randint(0, train_size, batch_size)

        # generate image swaps (dslr or enhanced) for discriminator
        swaps = np.reshape(np.random.randint(0, 2, batch_size), [batch_size, 1])

        phone_images = train_data[idx_train]
        dslr_images = train_answ[idx_train]

        accuracy_temp, _ = sess.run(
            [discrim_accuracy, train_step_disc],
            feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
        train_acc_discrim += accuracy_temp / eval_step

        # NOTE: the check also fires at i == 0, when only a single iteration
        # has been accumulated, so the first logged "train" values are scaled
        # by 1 / eval_step.
        if i % eval_step == 0:

            # test generator and discriminator CNNs
            test_losses_gen = np.zeros((1, 6))
            test_accuracy_disc = 0.0
            loss_ssim = 0.0

            for j in range(num_test_batches):

                be = j * batch_size
                en = (j + 1) * batch_size

                swaps = np.reshape(np.random.randint(0, 2, batch_size), [batch_size, 1])

                phone_images = test_data[be:en]
                dslr_images = test_answ[be:en]

                enhanced_crops, accuracy_disc, losses = sess.run(
                    [enhanced, discrim_accuracy,
                     [loss_generator, loss_content, loss_color, loss_texture, loss_tv, loss_psnr]],
                    feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})

                # running means over all test batches
                test_losses_gen += np.asarray(losses) / num_test_batches
                test_accuracy_disc += accuracy_disc / num_test_batches

                loss_ssim += MultiScaleSSIM(
                    np.reshape(dslr_images * 255, [batch_size, PATCH_HEIGHT, PATCH_WIDTH, 3]),
                    enhanced_crops * 255) / num_test_batches

            logs_disc = "step %d, %s | discriminator accuracy | train: %.4g, test: %.4g" % \
                (i, phone, train_acc_discrim, test_accuracy_disc)

            logs_gen = "generator losses | train: %.4g, test: %.4g | content: %.4g, color: %.4g, texture: %.4g, tv: %.4g | psnr: %.4g, ms-ssim: %.4g\n" % \
                (train_loss_gen, test_losses_gen[0][0], test_losses_gen[0][1], test_losses_gen[0][2],
                 test_losses_gen[0][3], test_losses_gen[0][4], test_losses_gen[0][5], loss_ssim)

            print(logs_disc)
            print(logs_gen)

            # save the results to log file (closed even if a write fails)
            with open(log_path, "a") as logs:
                logs.write(logs_disc + '\n')
                logs.write(logs_gen + '\n')

            # save visual results for several test image crops
            # `enhanced` is computed from phone_ alone, so no other
            # placeholder has to be fed here
            enhanced_crops = sess.run(enhanced, feed_dict={phone_: test_crops})

            for idx, crop in enumerate(enhanced_crops):
                # input patch on the left, enhanced patch on the right
                before_after = np.hstack((np.reshape(test_crops[idx], [PATCH_HEIGHT, PATCH_WIDTH, 3]), crop))
                imageio.imwrite('results/' + str(phone) + "_" + str(idx) + '_iteration_' + str(i) + '.jpg', before_after)

            # restart the running training averages for the next interval
            train_loss_gen = 0.0
            train_acc_discrim = 0.0

            # save the model that corresponds to the current iteration
            saver.save(sess, 'models/' + str(phone) + '_iteration_' + str(i) + '.ckpt', write_meta_graph=False)

            # reload a different batch of training data
            # (the old arrays are released first to keep peak memory down)
            del train_data
            del train_answ
            train_data, train_answ = load_batch(phone, dped_dir, train_size, PATCH_SIZE)