forked from mrharicot/monodepth
-
Notifications
You must be signed in to change notification settings - Fork 1
/
monodepth_model.py
390 lines (317 loc) · 17 KB
/
monodepth_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
# Copyright UCL Business plc 2017. Patent Pending. All rights reserved.
#
# The MonoDepth Software is licensed under the terms of the UCLB ACP-A licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
#
# For any other use of the software not covered by the UCLB ACP-A Licence,
# please contact info@uclb.com
"""Fully convolutional model for monocular depth estimation
by Clement Godard, Oisin Mac Aodha and Gabriel J. Brostow
http://visual.cs.ucl.ac.uk/pubs/monoDepth/
"""
from __future__ import absolute_import, division, print_function
from collections import namedtuple
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from bilinear_sampler import *
# Immutable bundle of the settings read by MonodepthModel through attribute
# access (e.g. params.encoder, params.use_deconv, params.alpha_image_loss).
monodepth_parameters = namedtuple(
    'parameters',
    [
        'encoder',
        'height',
        'width',
        'batch_size',
        'num_threads',
        'num_epochs',
        'do_stereo',
        'wrap_mode',
        'use_deconv',
        'alpha_image_loss',
        'disp_gradient_loss_weight',
        'lr_loss_weight',
        'full_summary',
    ])
class MonodepthModel(object):
    """Monodepth model: builds the network graph, its outputs, losses and summaries.

    Constructing an instance builds everything: ``build_model`` and
    ``build_outputs`` always run; ``build_losses`` and ``build_summaries`` are
    skipped when ``mode == 'test'``.
    """

    def __init__(self, params, mode, left, right, reuse_variables=None, model_index=0):
        """Store the inputs and build the graph.

        Args:
            params: settings object; this class reads ``encoder``,
                ``do_stereo``, ``use_deconv``, ``alpha_image_loss``,
                ``disp_gradient_loss_weight``, ``lr_loss_weight`` and
                ``full_summary`` from it.
            mode: ``'test'`` stops after the disparities are built. The right
                image pyramid needed by the losses is only created when
                ``mode == 'train'``.
            left: left input image tensor (indexed as a 4-D tensor).
            right: right input image tensor; used for the right pyramid in
                train mode and concatenated to ``left`` when
                ``params.do_stereo`` is set.
            reuse_variables: forwarded as ``reuse`` to the ``'model'`` and
                ``'losses'`` variable scopes.
            model_index: suffix of the summary collection name
                (``'model_<index>'``).

        Raises:
            ValueError: if ``params.encoder`` is not a supported encoder.
        """
        self.params = params
        self.mode = mode
        self.left = left
        self.right = right
        self.model_collection = ['model_' + str(model_index)]

        self.reuse_variables = reuse_variables

        self.build_model()
        self.build_outputs()

        if self.mode == 'test':
            return

        self.build_losses()
        self.build_summaries()

    def gradient_x(self, img):
        """Return ``img[:, :, :-1, :] - img[:, :, 1:, :]`` (difference along axis 2)."""
        gx = img[:, :, :-1, :] - img[:, :, 1:, :]
        return gx

    def gradient_y(self, img):
        """Return ``img[:, :-1, :, :] - img[:, 1:, :, :]`` (difference along axis 1)."""
        gy = img[:, :-1, :, :] - img[:, 1:, :, :]
        return gy

    def upsample_nn(self, x, ratio):
        """Nearest-neighbour resize of ``x`` to ``ratio`` times its size.

        The target size is computed from the dynamic shape: axes 1 and 2 of
        ``x`` are multiplied by ``ratio``.
        """
        s = tf.shape(x)
        h = s[1]
        w = s[2]
        return tf.image.resize_nearest_neighbor(x, [h * ratio, w * ratio])

    def scale_pyramid(self, img, num_scales):
        """Return a list of ``num_scales`` versions of ``img``.

        Entry 0 is ``img`` itself; entry ``i`` is ``img`` area-resized to the
        dynamic size of axes 1 and 2 integer-divided by ``2 ** i``.
        """
        scaled_imgs = [img]
        s = tf.shape(img)
        h = s[1]
        w = s[2]
        for i in range(num_scales - 1):
            ratio = 2 ** (i + 1)
            nh = h // ratio
            nw = w // ratio
            scaled_imgs.append(tf.image.resize_area(img, [nh, nw]))
        return scaled_imgs

    def generate_image_left(self, img, disp):
        """Sample ``img`` with ``bilinear_sampler_1d_h`` using the negated ``disp``."""
        return bilinear_sampler_1d_h(img, -disp)

    def generate_image_right(self, img, disp):
        """Sample ``img`` with ``bilinear_sampler_1d_h`` using ``disp`` unchanged."""
        return bilinear_sampler_1d_h(img, disp)

    def SSIM(self, x, y):
        """Return the SSIM-based dissimilarity map ``clip((1 - SSIM) / 2, 0, 1)``.

        Local means and (co)variances are estimated with 3x3 average pooling,
        stride 1, ``'VALID'`` padding.
        """
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        mu_x = slim.avg_pool2d(x, 3, 1, 'VALID')
        mu_y = slim.avg_pool2d(y, 3, 1, 'VALID')

        # E[v^2] - E[v]^2 and E[xy] - E[x]E[y] over the pooling window.
        sigma_x = slim.avg_pool2d(x ** 2, 3, 1, 'VALID') - mu_x ** 2
        sigma_y = slim.avg_pool2d(y ** 2, 3, 1, 'VALID') - mu_y ** 2
        sigma_xy = slim.avg_pool2d(x * y, 3, 1, 'VALID') - mu_x * mu_y

        SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
        SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)

        SSIM = SSIM_n / SSIM_d

        return tf.clip_by_value((1 - SSIM) / 2, 0, 1)

    def get_disparity_smoothness(self, disp, pyramid):
        """Return image-weighted disparity gradients for four scales.

        Each disparity gradient is multiplied by ``exp(-mean(|image gradient|))``
        where the mean is taken over axis 3 of the image gradient.

        Args:
            disp: list of disparity tensors; the first four entries are used.
            pyramid: list of image tensors matching ``disp`` scale by scale.

        Returns:
            A list of eight tensors: the four x-direction terms followed by
            the four y-direction terms.
        """
        disp_gradients_x = [self.gradient_x(d) for d in disp]
        disp_gradients_y = [self.gradient_y(d) for d in disp]

        image_gradients_x = [self.gradient_x(img) for img in pyramid]
        image_gradients_y = [self.gradient_y(img) for img in pyramid]

        weights_x = [tf.exp(-tf.reduce_mean(tf.abs(g), 3, keep_dims=True)) for g in image_gradients_x]
        weights_y = [tf.exp(-tf.reduce_mean(tf.abs(g), 3, keep_dims=True)) for g in image_gradients_y]

        smoothness_x = [disp_gradients_x[i] * weights_x[i] for i in range(4)]
        smoothness_y = [disp_gradients_y[i] * weights_y[i] for i in range(4)]
        return smoothness_x + smoothness_y

    def get_disp(self, x):
        """Predict a 2-channel disparity map: ``0.3 * sigmoid(conv3x3(x))``."""
        disp = 0.3 * self.conv(x, 2, 3, 1, tf.nn.sigmoid)
        return disp

    def conv(self, x, num_out_layers, kernel_size, stride, activation_fn=tf.nn.elu):
        """Zero-pad axes 1 and 2 by ``floor((kernel_size - 1) / 2)`` and apply a 'VALID' conv.

        Args:
            x: input tensor.
            num_out_layers: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            activation_fn: activation passed to ``slim.conv2d`` (``None`` for linear).
        """
        p = np.floor((kernel_size - 1) / 2).astype(np.int32)
        p_x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]])
        return slim.conv2d(p_x, num_out_layers, kernel_size, stride, 'VALID', activation_fn=activation_fn)

    def conv_block(self, x, num_out_layers, kernel_size):
        """Apply two ``conv`` layers: stride 1 followed by stride 2."""
        conv1 = self.conv(x, num_out_layers, kernel_size, 1)
        conv2 = self.conv(conv1, num_out_layers, kernel_size, 2)
        return conv2

    def maxpool(self, x, kernel_size):
        """Zero-pad axes 1 and 2 by ``floor((kernel_size - 1) / 2)`` and max-pool.

        Stride and padding mode are left at ``slim.max_pool2d``'s defaults.
        """
        p = np.floor((kernel_size - 1) / 2).astype(np.int32)
        p_x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]])
        return slim.max_pool2d(p_x, kernel_size)

    def resconv(self, x, num_layers, stride):
        """Bottleneck residual unit: 1x1 -> 3x3 (``stride``) -> 1x1 with ``4 * num_layers`` outputs.

        The shortcut is either a linear 1x1 projection of ``x`` (same stride)
        or ``x`` itself; the sum goes through ELU.
        """
        # NOTE(review): tf.shape(x)[3] is a graph tensor, not a Python int, so
        # this comparison may not test the static channel count as it appears
        # to. Confirm on the TensorFlow version in use before relying on the
        # identity branch. Left unchanged because altering it would change
        # which variables the network creates.
        do_proj = tf.shape(x)[3] != num_layers or stride == 2
        conv1 = self.conv(x, num_layers, 1, 1)
        conv2 = self.conv(conv1, num_layers, 3, stride)
        conv3 = self.conv(conv2, 4 * num_layers, 1, 1, None)
        if do_proj:
            shortcut = self.conv(x, 4 * num_layers, 1, stride, None)
        else:
            shortcut = x
        return tf.nn.elu(conv3 + shortcut)

    def resblock(self, x, num_layers, num_blocks):
        """Chain ``num_blocks`` residual units; only the last one uses stride 2."""
        out = x
        for _ in range(num_blocks - 1):
            out = self.resconv(out, num_layers, 1)
        out = self.resconv(out, num_layers, 2)
        return out

    def upconv(self, x, num_out_layers, kernel_size, scale):
        """Nearest-neighbour upsample by ``scale`` followed by a stride-1 ``conv``."""
        upsample = self.upsample_nn(x, scale)
        conv = self.conv(upsample, num_out_layers, kernel_size, 1)
        return conv

    def deconv(self, x, num_out_layers, kernel_size, scale):
        """Transposed-convolution alternative to ``upconv``.

        Pads axes 1 and 2 of ``x`` by one, applies ``slim.conv2d_transpose``
        with stride ``scale`` and 'SAME' padding, then crops the result with
        ``[:, 3:-1, 3:-1, :]``.
        """
        p_x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]])
        conv = slim.conv2d_transpose(p_x, num_out_layers, kernel_size, scale, 'SAME')
        return conv[:, 3:-1, 3:-1, :]

    def build_vgg(self):
        """Build the VGG-style encoder/decoder and set ``self.disp1`` .. ``self.disp4``.

        Reads ``self.model_input``. The up-sampling layer is ``self.deconv``
        when ``params.use_deconv`` is set, otherwise ``self.upconv``. Layer
        creation order is kept fixed on purpose: slim assigns default layer
        names in creation order.
        """
        # set convenience functions
        conv = self.conv
        if self.params.use_deconv:
            upconv = self.deconv
        else:
            upconv = self.upconv

        with tf.variable_scope('encoder'):
            conv1 = self.conv_block(self.model_input, 32, 7)  # H/2
            conv2 = self.conv_block(conv1, 64, 5)  # H/4
            conv3 = self.conv_block(conv2, 128, 3)  # H/8
            conv4 = self.conv_block(conv3, 256, 3)  # H/16
            conv5 = self.conv_block(conv4, 512, 3)  # H/32
            conv6 = self.conv_block(conv5, 512, 3)  # H/64
            conv7 = self.conv_block(conv6, 512, 3)  # H/128

        with tf.variable_scope('skips'):
            skip1 = conv1
            skip2 = conv2
            skip3 = conv3
            skip4 = conv4
            skip5 = conv5
            skip6 = conv6

        with tf.variable_scope('decoder'):
            upconv7 = upconv(conv7, 512, 3, 2)  # H/64
            concat7 = tf.concat([upconv7, skip6], 3)
            iconv7 = conv(concat7, 512, 3, 1)

            upconv6 = upconv(iconv7, 512, 3, 2)  # H/32
            concat6 = tf.concat([upconv6, skip5], 3)
            iconv6 = conv(concat6, 512, 3, 1)

            upconv5 = upconv(iconv6, 256, 3, 2)  # H/16
            concat5 = tf.concat([upconv5, skip4], 3)
            iconv5 = conv(concat5, 256, 3, 1)

            upconv4 = upconv(iconv5, 128, 3, 2)  # H/8
            concat4 = tf.concat([upconv4, skip3], 3)
            iconv4 = conv(concat4, 128, 3, 1)
            self.disp4 = self.get_disp(iconv4)
            udisp4 = self.upsample_nn(self.disp4, 2)

            # From here on the upsampled coarser disparity is fed to the next scale.
            upconv3 = upconv(iconv4, 64, 3, 2)  # H/4
            concat3 = tf.concat([upconv3, skip2, udisp4], 3)
            iconv3 = conv(concat3, 64, 3, 1)
            self.disp3 = self.get_disp(iconv3)
            udisp3 = self.upsample_nn(self.disp3, 2)

            upconv2 = upconv(iconv3, 32, 3, 2)  # H/2
            concat2 = tf.concat([upconv2, skip1, udisp3], 3)
            iconv2 = conv(concat2, 32, 3, 1)
            self.disp2 = self.get_disp(iconv2)
            udisp2 = self.upsample_nn(self.disp2, 2)

            upconv1 = upconv(iconv2, 16, 3, 2)  # H
            concat1 = tf.concat([upconv1, udisp2], 3)
            iconv1 = conv(concat1, 16, 3, 1)
            self.disp1 = self.get_disp(iconv1)

    def build_resnet50(self):
        """Build the ResNet50-style encoder/decoder and set ``self.disp1`` .. ``self.disp4``.

        Reads ``self.model_input``. The up-sampling layer is ``self.deconv``
        when ``params.use_deconv`` is set, otherwise ``self.upconv``. Layer
        creation order is kept fixed on purpose: slim assigns default layer
        names in creation order.
        """
        # set convenience functions
        conv = self.conv
        if self.params.use_deconv:
            upconv = self.deconv
        else:
            upconv = self.upconv

        with tf.variable_scope('encoder'):
            conv1 = conv(self.model_input, 64, 7, 2)  # H/2  -   64D
            pool1 = self.maxpool(conv1, 3)  # H/4  -   64D
            conv2 = self.resblock(pool1, 64, 3)  # H/8  -  256D
            conv3 = self.resblock(conv2, 128, 4)  # H/16 -  512D
            conv4 = self.resblock(conv3, 256, 6)  # H/32 - 1024D
            conv5 = self.resblock(conv4, 512, 3)  # H/64 - 2048D

        with tf.variable_scope('skips'):
            skip1 = conv1
            skip2 = pool1
            skip3 = conv2
            skip4 = conv3
            skip5 = conv4

        # DECODING
        with tf.variable_scope('decoder'):
            upconv6 = upconv(conv5, 512, 3, 2)  # H/32
            concat6 = tf.concat([upconv6, skip5], 3)
            iconv6 = conv(concat6, 512, 3, 1)

            upconv5 = upconv(iconv6, 256, 3, 2)  # H/16
            concat5 = tf.concat([upconv5, skip4], 3)
            iconv5 = conv(concat5, 256, 3, 1)

            upconv4 = upconv(iconv5, 128, 3, 2)  # H/8
            concat4 = tf.concat([upconv4, skip3], 3)
            iconv4 = conv(concat4, 128, 3, 1)
            self.disp4 = self.get_disp(iconv4)
            udisp4 = self.upsample_nn(self.disp4, 2)

            # From here on the upsampled coarser disparity is fed to the next scale.
            upconv3 = upconv(iconv4, 64, 3, 2)  # H/4
            concat3 = tf.concat([upconv3, skip2, udisp4], 3)
            iconv3 = conv(concat3, 64, 3, 1)
            self.disp3 = self.get_disp(iconv3)
            udisp3 = self.upsample_nn(self.disp3, 2)

            upconv2 = upconv(iconv3, 32, 3, 2)  # H/2
            concat2 = tf.concat([upconv2, skip1, udisp3], 3)
            iconv2 = conv(concat2, 32, 3, 1)
            self.disp2 = self.get_disp(iconv2)
            udisp2 = self.upsample_nn(self.disp2, 2)

            upconv1 = upconv(iconv2, 16, 3, 2)  # H
            concat1 = tf.concat([upconv1, udisp2], 3)
            iconv1 = conv(concat1, 16, 3, 1)
            self.disp1 = self.get_disp(iconv1)

    def build_model(self):
        """Build the image pyramids, the network input and the selected network.

        Sets ``self.left_pyramid`` (and ``self.right_pyramid`` in train mode),
        ``self.model_input`` and, through the encoder builder,
        ``self.disp1`` .. ``self.disp4``.

        Raises:
            ValueError: if ``params.encoder`` is neither ``'vgg'`` nor
                ``'resnet50'``. Without this check no disparity attribute
                would be set and ``build_outputs`` would fail later with an
                unrelated ``AttributeError``.
        """
        # ELU is the default activation for every slim conv / transposed conv below.
        with slim.arg_scope([slim.conv2d, slim.conv2d_transpose], activation_fn=tf.nn.elu):
            with tf.variable_scope('model', reuse=self.reuse_variables):

                self.left_pyramid = self.scale_pyramid(self.left, 4)
                if self.mode == 'train':
                    self.right_pyramid = self.scale_pyramid(self.right, 4)

                if self.params.do_stereo:
                    self.model_input = tf.concat([self.left, self.right], 3)
                else:
                    self.model_input = self.left

                # build model
                if self.params.encoder == 'vgg':
                    self.build_vgg()
                elif self.params.encoder == 'resnet50':
                    self.build_resnet50()
                else:
                    raise ValueError(
                        "Unknown encoder {!r}: expected 'vgg' or 'resnet50'".format(
                            self.params.encoder))

    def build_outputs(self):
        """Derive the per-scale outputs from ``self.disp1`` .. ``self.disp4``.

        Always sets ``disp_est``, ``disp_left_est`` (channel 0) and
        ``disp_right_est`` (channel 1). Outside test mode it also sets the
        reconstructed images, the left-right projected disparities and the
        disparity smoothness terms.
        """
        # STORE DISPARITIES
        with tf.variable_scope('disparities'):
            self.disp_est = [self.disp1, self.disp2, self.disp3, self.disp4]
            self.disp_left_est = [tf.expand_dims(d[:, :, :, 0], 3) for d in self.disp_est]
            self.disp_right_est = [tf.expand_dims(d[:, :, :, 1], 3) for d in self.disp_est]

        if self.mode == 'test':
            return

        # GENERATE IMAGES
        with tf.variable_scope('images'):
            self.left_est = [self.generate_image_left(self.right_pyramid[i], self.disp_left_est[i]) for i in range(4)]
            self.right_est = [self.generate_image_right(self.left_pyramid[i], self.disp_right_est[i]) for i in range(4)]

        # LR CONSISTENCY
        with tf.variable_scope('left-right'):
            self.right_to_left_disp = [self.generate_image_left(self.disp_right_est[i], self.disp_left_est[i]) for i in range(4)]
            self.left_to_right_disp = [self.generate_image_right(self.disp_left_est[i], self.disp_right_est[i]) for i in range(4)]

        # DISPARITY SMOOTHNESS
        with tf.variable_scope('smoothness'):
            self.disp_left_smoothness = self.get_disparity_smoothness(self.disp_left_est, self.left_pyramid)
            self.disp_right_smoothness = self.get_disparity_smoothness(self.disp_right_est, self.right_pyramid)

    def build_losses(self):
        """Build the per-scale losses and ``self.total_loss``.

        ``total_loss = image_loss
        + disp_gradient_loss_weight * disp_gradient_loss
        + lr_loss_weight * lr_loss`` where the image loss mixes SSIM and L1
        with weight ``alpha_image_loss``.
        """
        alpha = self.params.alpha_image_loss

        with tf.variable_scope('losses', reuse=self.reuse_variables):
            # IMAGE RECONSTRUCTION
            # L1
            self.l1_left = [tf.abs(self.left_est[i] - self.left_pyramid[i]) for i in range(4)]
            self.l1_reconstruction_loss_left = [tf.reduce_mean(l) for l in self.l1_left]
            self.l1_right = [tf.abs(self.right_est[i] - self.right_pyramid[i]) for i in range(4)]
            self.l1_reconstruction_loss_right = [tf.reduce_mean(l) for l in self.l1_right]

            # SSIM
            self.ssim_left = [self.SSIM(self.left_est[i], self.left_pyramid[i]) for i in range(4)]
            self.ssim_loss_left = [tf.reduce_mean(s) for s in self.ssim_left]
            self.ssim_right = [self.SSIM(self.right_est[i], self.right_pyramid[i]) for i in range(4)]
            self.ssim_loss_right = [tf.reduce_mean(s) for s in self.ssim_right]

            # WEIGHTED SUM
            self.image_loss_right = [alpha * self.ssim_loss_right[i] + (1 - alpha) * self.l1_reconstruction_loss_right[i] for i in range(4)]
            self.image_loss_left = [alpha * self.ssim_loss_left[i] + (1 - alpha) * self.l1_reconstruction_loss_left[i] for i in range(4)]
            self.image_loss = tf.add_n(self.image_loss_left + self.image_loss_right)

            # DISPARITY SMOOTHNESS
            # NOTE(review): the smoothness lists hold four x-direction terms
            # followed by four y-direction terms (see get_disparity_smoothness),
            # so indexing 0..3 only uses the x-direction terms. Confirm this is
            # intended before changing it: doing so alters the training loss.
            self.disp_left_loss = [tf.reduce_mean(tf.abs(self.disp_left_smoothness[i])) / 2 ** i for i in range(4)]
            self.disp_right_loss = [tf.reduce_mean(tf.abs(self.disp_right_smoothness[i])) / 2 ** i for i in range(4)]
            self.disp_gradient_loss = tf.add_n(self.disp_left_loss + self.disp_right_loss)

            # LR CONSISTENCY
            self.lr_left_loss = [tf.reduce_mean(tf.abs(self.right_to_left_disp[i] - self.disp_left_est[i])) for i in range(4)]
            self.lr_right_loss = [tf.reduce_mean(tf.abs(self.left_to_right_disp[i] - self.disp_right_est[i])) for i in range(4)]
            self.lr_loss = tf.add_n(self.lr_left_loss + self.lr_right_loss)

            # TOTAL LOSS
            self.total_loss = (self.image_loss
                               + self.params.disp_gradient_loss_weight * self.disp_gradient_loss
                               + self.params.lr_loss_weight * self.lr_loss)

    def build_summaries(self):
        """Register scalar and image summaries in ``self.model_collection``.

        Per-scale loss scalars and disparity images are always added; the
        reconstructed images, SSIM / L1 maps and the input images are added
        only when ``params.full_summary`` is set.
        """
        # SUMMARIES
        with tf.device('/cpu:0'):
            for i in range(4):
                tf.summary.scalar('ssim_loss_' + str(i), self.ssim_loss_left[i] + self.ssim_loss_right[i], collections=self.model_collection)
                tf.summary.scalar('l1_loss_' + str(i), self.l1_reconstruction_loss_left[i] + self.l1_reconstruction_loss_right[i], collections=self.model_collection)
                tf.summary.scalar('image_loss_' + str(i), self.image_loss_left[i] + self.image_loss_right[i], collections=self.model_collection)
                tf.summary.scalar('disp_gradient_loss_' + str(i), self.disp_left_loss[i] + self.disp_right_loss[i], collections=self.model_collection)
                tf.summary.scalar('lr_loss_' + str(i), self.lr_left_loss[i] + self.lr_right_loss[i], collections=self.model_collection)
                tf.summary.image('disp_left_est_' + str(i), self.disp_left_est[i], max_outputs=4, collections=self.model_collection)
                tf.summary.image('disp_right_est_' + str(i), self.disp_right_est[i], max_outputs=4, collections=self.model_collection)

                if self.params.full_summary:
                    tf.summary.image('left_est_' + str(i), self.left_est[i], max_outputs=4, collections=self.model_collection)
                    tf.summary.image('right_est_' + str(i), self.right_est[i], max_outputs=4, collections=self.model_collection)
                    tf.summary.image('ssim_left_' + str(i), self.ssim_left[i], max_outputs=4, collections=self.model_collection)
                    tf.summary.image('ssim_right_' + str(i), self.ssim_right[i], max_outputs=4, collections=self.model_collection)
                    tf.summary.image('l1_left_' + str(i), self.l1_left[i], max_outputs=4, collections=self.model_collection)
                    tf.summary.image('l1_right_' + str(i), self.l1_right[i], max_outputs=4, collections=self.model_collection)

            # The input images do not depend on the scale, so they are logged once.
            if self.params.full_summary:
                tf.summary.image('left', self.left, max_outputs=4, collections=self.model_collection)
                tf.summary.image('right', self.right, max_outputs=4, collections=self.model_collection)