-
Notifications
You must be signed in to change notification settings - Fork 0
/
inpaint_ops.py
executable file
·553 lines (471 loc) · 17.9 KB
/
inpaint_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
import logging
import math
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import add_arg_scope
from PIL import Image, ImageDraw
from neuralgym.ops.layers import resize
from neuralgym.ops.layers import *
from neuralgym.ops.loss_ops import *
from neuralgym.ops.gan_ops import *
from neuralgym.ops.summary_ops import *
# Module logger; getLogger() with no name returns the root logger.
logger = logging.getLogger()
# Seed NumPy's global RNG at import time so NumPy-based random draws (e.g. the
# mask generators in this module) repeat across runs. Note this affects every
# user of np.random in the process.
np.random.seed(2018)
@add_arg_scope
def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv',
             padding='SAME', activation=tf.nn.elu, training=True):
    """Define gated conv for generator.

    A conv with ``cnum`` output channels is applied. Unless this is the
    output layer (``cnum == 3``) or ``activation`` is None, the result is split
    in half along the channel axis. One half goes through ``activation`` and
    the other through a sigmoid gate, and the two are multiplied. Hidden layers
    therefore return ``cnum // 2`` channels.

    Args:
        x: Input tensor.
        cnum: Number of conv output channels.
        ksize: Kernel size.
        stride: Convolution stride.
        rate: Dilation rate.
        name: Name of the conv layer.
        padding: One of 'SAME' (default), 'SYMMETRIC' or 'REFLECT'. The legacy
            misspelling 'REFELECT' is accepted and treated as 'REFLECT'.
        activation: Activation for the feature half; ``None`` returns the raw
            conv output without gating.
        training: Not used by this function; kept for arg_scope compatibility.

    Returns:
        tf.Tensor: output

    Raises:
        ValueError: If ``padding`` is not a supported mode.
    """
    if padding == 'REFELECT':
        # Legacy misspelling kept for backward compatibility; tf.pad only
        # understands 'REFLECT'.
        padding = 'REFLECT'
    if padding not in ('SYMMETRIC', 'SAME', 'REFLECT'):
        raise ValueError('Unsupported padding mode: {}'.format(padding))
    if padding in ('SYMMETRIC', 'REFLECT'):
        # Pad explicitly by half the dilated kernel extent, then use a VALID
        # conv so (for stride 1 and odd ksize) the spatial size is preserved.
        p = int(rate*(ksize-1)/2)
        x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], mode=padding)
        padding = 'VALID'
    x = tf.layers.conv2d(
        x, cnum, ksize, stride, dilation_rate=rate,
        activation=None, padding=padding, name=name)
    if cnum == 3 or activation is None:
        # conv for output: no gating
        return x
    x, y = tf.split(x, 2, 3)
    x = activation(x)
    y = tf.nn.sigmoid(y)
    return x * y
@add_arg_scope
def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True):
    """Upsample then convolve ("deconv") for the generator.

    Instead of a transposed convolution, the input is resized with nearest
    neighbour interpolation. The scale is neuralgym ``resize``'s default,
    described as x2 for this layer. A 3x3, stride-1 ``gen_conv`` follows.

    Args:
        x: Input.
        cnum: Channel number passed to the conv.
        name: Variable scope name; the conv is named ``name + '_conv'``.
        padding: Padding mode forwarded to ``gen_conv``.
        training: Forwarded to ``gen_conv``.

    Returns:
        tf.Tensor: output
    """
    with tf.variable_scope(name):
        upsampled = resize(x, func=tf.image.resize_nearest_neighbor)
        conv_name = name + '_conv'
        out = gen_conv(upsampled, cnum, 3, 1, name=conv_name,
                       padding=padding, training=training)
    return out
@add_arg_scope
def dis_conv(x, cnum, ksize=5, stride=2, name='conv', training=True):
    """Discriminator conv: neuralgym ``conv2d_spectral_norm`` + leaky ReLU.

    Args:
        x: Input.
        cnum: Channel number.
        ksize: Kernel size (default 5).
        stride: Convolution stride (default 2), with 'SAME' padding.
        name: Name of the conv layer.
        training: Not used by this function; kept for arg_scope compatibility.

    Returns:
        tf.Tensor: output
    """
    features = conv2d_spectral_norm(x, cnum, ksize, stride, 'SAME', name=name)
    return tf.nn.leaky_relu(features)
def random_bbox(FLAGS):
    """Sample a random fixed-size box that respects the image margins.

    Args:
        FLAGS: Config providing ``img_shapes`` (height, width, ...), the box
            size ``height``/``width`` and ``vertical_margin``/``horizontal_margin``.

    Returns:
        tuple: (top, left, height, width). top/left are random int32 scalar
        tensors; height/width are constants built from FLAGS.

    Raises:
        ValueError: If the box plus margins does not fit in the image.
    """
    img_shape = FLAGS.img_shapes
    img_height = img_shape[0]
    img_width = img_shape[1]
    # Largest valid top/left: the box plus the far-side margin must still fit.
    maxt = img_height - FLAGS.vertical_margin - FLAGS.height
    maxl = img_width - FLAGS.horizontal_margin - FLAGS.width
    if maxt < FLAGS.vertical_margin or maxl < FLAGS.horizontal_margin:
        raise ValueError(
            'Box of size {}x{} with margins ({}, {}) does not fit in image {}x{}'
            .format(FLAGS.height, FLAGS.width, FLAGS.vertical_margin,
                    FLAGS.horizontal_margin, img_height, img_width))
    # Integer tf.random_uniform excludes maxval, so +1 makes maxt/maxl
    # reachable and avoids minval == maxval when the box fits exactly.
    t = tf.random_uniform(
        [], minval=FLAGS.vertical_margin, maxval=maxt + 1, dtype=tf.int32)
    l = tf.random_uniform(
        [], minval=FLAGS.horizontal_margin, maxval=maxl + 1, dtype=tf.int32)
    h = tf.constant(FLAGS.height)
    w = tf.constant(FLAGS.width)
    return (t, l, h, w)
def bbox2mask(FLAGS, bbox, name='mask'):
    """Generate a rectangular mask tensor from a bbox with random shrinking.

    The box's top and bottom edges are each moved inward by the same random
    amount in [0, max_delta_height // 2]. The left and right edges are moved
    inward by an amount in [0, max_delta_width // 2].

    Args:
        FLAGS: Config providing ``img_shapes`` (height, width, ...),
            ``max_delta_height`` and ``max_delta_width``.
        bbox: tuple, (top, left, height, width)
        name: Variable scope name.

    Returns:
        tf.Tensor: float32 output with shape [1, H, W, 1], 1 inside the
        (shrunk) box and 0 elsewhere.
    """
    def npmask(bbox, height, width, delta_h, delta_w):
        mask = np.zeros((1, height, width, 1), np.float32)
        h = np.random.randint(delta_h//2+1)
        w = np.random.randint(delta_w//2+1)
        mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,
             bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
        return mask

    with tf.variable_scope(name), tf.device('/cpu:0'):
        img_shape = FLAGS.img_shapes
        height = img_shape[0]
        width = img_shape[1]
        # stateful=True: npmask draws from NumPy's global RNG, so TF must not
        # treat it as a pure function (e.g. eliminate it as a common
        # subexpression).
        mask = tf.py_func(
            npmask,
            [bbox, height, width,
             FLAGS.max_delta_height, FLAGS.max_delta_width],
            tf.float32, stateful=True)
        mask.set_shape([1] + [height, width] + [1])
    return mask
def brush_stroke_mask(FLAGS, name='mask'):
    """Generate a random free-form (brush stroke) mask tensor.

    Draws 1-3 random polyline strokes with round joints onto a blank canvas.
    The canvas is then randomly flipped horizontally and/or vertically.

    Args:
        FLAGS: Config providing ``img_shapes`` (height, width, ...).
        name: Variable scope name.

    Returns:
        tf.Tensor: float32 output with shape [1, H, W, 1], 1 on strokes and
        0 elsewhere.
    """
    min_num_vertex = 4
    max_num_vertex = 12
    mean_angle = 2*math.pi / 5
    angle_range = 2*math.pi / 15
    min_width = 12
    max_width = 40

    def generate_mask(H, W):
        average_radius = math.sqrt(H*H+W*W) / 8
        mask = Image.new('L', (W, H), 0)
        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            vertex = []
            for i in range(num_vertex):
                if i % 2 == 0:
                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))

            # PIL's Image.size is (width, height); x runs along the width.
            w, h = mask.size
            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius//2),
                    0, 2*average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            draw = ImageDraw.Draw(mask)
            stroke_width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=1, width=stroke_width)
            # Filled circles at every vertex round off joints and ends.
            for v in vertex:
                draw.ellipse((v[0] - stroke_width//2,
                              v[1] - stroke_width//2,
                              v[0] + stroke_width//2,
                              v[1] + stroke_width//2),
                             fill=1)

        # Image.transpose returns a new image; without assigning it the flip
        # would be silently discarded.
        if np.random.normal() > 0:
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        if np.random.normal() > 0:
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
        mask = np.asarray(mask, np.float32)
        mask = np.reshape(mask, (1, H, W, 1))
        return mask

    with tf.variable_scope(name), tf.device('/cpu:0'):
        img_shape = FLAGS.img_shapes
        height = img_shape[0]
        width = img_shape[1]
        mask = tf.py_func(
            generate_mask,
            [height, width],
            tf.float32, stateful=True)
        mask.set_shape([1] + [height, width] + [1])
    return mask
def local_patch(x, bbox):
    """Crop the region described by ``bbox`` out of ``x``.

    Args:
        x: input image tensor accepted by ``tf.image.crop_to_bounding_box``.
        bbox: (top, left, height, width)

    Returns:
        tf.Tensor: local patch
    """
    top, left = bbox[0], bbox[1]
    patch_h, patch_w = bbox[2], bbox[3]
    return tf.image.crop_to_bounding_box(x, top, left, patch_h, patch_w)
def resize_mask_like(mask, x):
    """Resize ``mask`` to the spatial size of ``x`` (nearest neighbour).

    Args:
        mask: Original mask.
        x: Tensor whose static dims 1 and 2 give the target size.

    Returns:
        tf.Tensor: resized mask
    """
    target_hw = x.get_shape().as_list()[1:3]
    return resize(mask, to_shape=target_hw,
                  func=tf.image.resize_nearest_neighbor)
def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                         fuse_k=3, softmax_scale=10., training=True, fuse=True):
    """Contextual attention layer implementation.

    Contextual attention is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    Each foreground location is scored against L2-normalized ksize x ksize
    background patches by convolution. A scaled softmax over those scores
    then weights full-resolution background patches that are pasted back with
    a transposed convolution.

    Args:
        f: Input feature to match (foreground). Assumed [B, H, W, C] with a
            static batch size, since it is split per batch entry.
        b: Input feature for match (background). Same batch size and channels
            as ``f``.
        mask: Input mask for b, indicating patches not available. A patch is
            used only if its mask values are all zero. The mask is reshaped
            to a single batch entry (presumably [1, H, W, 1]).
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from b.
            NOTE(review): the score reshapes to bs[1]*bs[2] assume stride == 1.
        rate: Dilation for matching. f and b are downscaled by this factor
            for matching, and full-resolution patches of b are pasted back.
        fuse_k: Size of the identity kernel used to fuse neighbouring scores.
        softmax_scale: Scaled softmax for attention.
        training: Indicating if current graph is training or inference
            (not used by this implementation).
        fuse: Whether to apply the score-fusion step.

    Returns:
        tuple: (y, flow). ``y`` is the reconstruction with ``f``'s static
        shape. ``flow`` is a color visualization of the attention offsets.
    """
    # get shapes
    raw_fs = tf.shape(f)
    raw_int_fs = f.get_shape().as_list()
    raw_int_bs = b.get_shape().as_list()
    # extract patches from background with stride and rate
    kernel = 2*rate
    raw_w = tf.extract_image_patches(
        b, [1, kernel, kernel, 1], [1, rate*stride, rate*stride, 1], [1, 1, 1, 1],
        padding='SAME')
    raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    # downscaling foreground option: downscaling both foreground and
    # background for matching and use original background for reconstruction.
    f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
    # https://github.com/tensorflow/tensorflow/issues/11651
    b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)],
               func=tf.image.resize_nearest_neighbor)
    if mask is not None:
        mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()
    f_groups = tf.split(f, int_fs[0], axis=0)
    # from t(H*W*C) to w(b*k*k*c*h*w)
    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()
    w = tf.extract_image_patches(
        b, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    w = tf.transpose(w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    # process mask
    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])
    m = tf.extract_image_patches(
        mask, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    m = m[0]
    # 1.0 for background patches whose mask is entirely zero (usable), else 0.0.
    mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0, 1, 2], keep_dims=True), 0.), tf.float32)
    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
        # conv for compare
        wi = wi[0]
        wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME")
        # conv implementation for fuse scores to encourage large patches
        if fuse:
            yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1, 1, 1, 1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1, 1, 1, 1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])
        # softmax to match
        yi *= mm  # mask
        yi = tf.nn.softmax(yi*scale, 3)
        yi *= mm  # mask
        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        # The argmax indexes the flattened background grid (bs[1]*bs[2]), so
        # decode (row, col) with the background width, not the foreground's.
        offset = tf.stack([offset // bs[2], offset % bs[2]], axis=-1)
        # deconv for patch pasting
        # 3.1 paste center
        wi_center = raw_wi[0]
        # With kernel = 2*rate and deconv stride = rate, each output pixel is
        # covered by (kernel / rate)**2 = 4 pasted patches; divide to average.
        yi = tf.nn.conv2d_transpose(
            yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0),
            strides=[1, rate, rate, 1]) / 4.
        y.append(yi)
        offsets.append(offset)
    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)
    offsets = tf.concat(offsets, axis=0)
    # One (row, col) offset per location of the (downscaled) foreground grid.
    offsets.set_shape(int_fs[:3] + [2])
    # case1: visualize optical flow: minus current (foreground) position
    h_add = tf.tile(tf.reshape(tf.range(fs[1]), [1, fs[1], 1, 1]), [fs[0], 1, fs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(fs[2]), [1, 1, fs[2], 1]), [fs[0], fs[1], 1, 1])
    offsets = offsets - tf.concat([h_add, w_add], axis=3)
    # to flow image
    flow = flow_to_image_tf(offsets)
    # # case2: visualize which pixels are attended
    # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
    if rate != 1:
        flow = resize(flow, scale=rate, func=tf.image.resize_bilinear)
    return y, flow
def test_contextual_attention(args):
    """Test contextual attention layer with 3-channel image input
    (instead of n-channel feature).

    Reconstructs ``args.imageB`` from patches of ``args.imageA``. imageA is
    first downscaled by 0.5. The result is written to ``args.imageOut``.

    Args:
        args: Namespace with ``imageA``, ``imageB`` and ``imageOut`` paths.

    Raises:
        IOError: If an input image cannot be read.
    """
    import os

    def _read_image(path):
        # cv2.imread reports failure by returning None instead of raising.
        img = cv2.imread(path)
        if img is None:
            raise IOError('Could not read image: {}'.format(path))
        return img

    # Expose only GPU 0 to TensorFlow (set to '' to force CPU-only).
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    rate = 2
    stride = 1
    # Both images are cropped to a multiple of rate*stride, presumably so the
    # downscaled matching grid lines up with the full-resolution patches.
    grid = rate*stride
    b = _read_image(args.imageA)
    b = cv2.resize(b, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
    h, w, _ = b.shape
    b = b[:h//grid*grid, :w//grid*grid, :]
    b = np.expand_dims(b, 0)
    logger.info('Size of imageA: {}'.format(b.shape))
    f = _read_image(args.imageB)
    h, w, _ = f.shape
    f = f[:h//grid*grid, :w//grid*grid, :]
    f = np.expand_dims(f, 0)
    logger.info('Size of imageB: {}'.format(f.shape))
    with tf.Session() as sess:
        bt = tf.constant(b, dtype=tf.float32)
        ft = tf.constant(f, dtype=tf.float32)
        yt, _ = contextual_attention(
            ft, bt, stride=stride, rate=rate,
            training=False, fuse=False)
        y = sess.run(yt)
    # Round and saturate to 8-bit explicitly rather than relying on imwrite's
    # implicit depth conversion of float data.
    cv2.imwrite(args.imageOut, np.clip(np.rint(y[0]), 0, 255).astype(np.uint8))
def make_color_wheel():
    """Build the Middlebury optical-flow color wheel.

    The wheel is a table of RGB colors (0..255 stored as floats). It walks
    red -> yellow -> green -> cyan -> blue -> magenta -> back toward red,
    using a fixed number of steps per transition.

    Returns:
        np.ndarray: float array of shape [55, 3].
    """
    # Steps per transition: R->Y, Y->G, G->C, C->B, B->M, M->R.
    n_ry, n_yg, n_gc, n_cb, n_bm, n_mr = (15, 6, 4, 11, 13, 6)
    wheel = np.zeros([n_ry + n_yg + n_gc + n_cb + n_bm + n_mr, 3])

    def ramp(n):
        # floor(255 * k / n) for k = 0 .. n-1
        return np.floor(255 * np.arange(0, n) / n)

    red, green, blue = 0, 1, 2
    # (steps, channel held at 255, channel that changes, rising?)
    segments = [
        (n_ry, red, green, True),
        (n_yg, green, red, False),
        (n_gc, green, blue, True),
        (n_cb, blue, green, False),
        (n_bm, blue, red, True),
        (n_mr, red, blue, False),
    ]
    start = 0
    for steps, held, moving, rising in segments:
        stop = start + steps
        values = ramp(steps)
        wheel[start:stop, held] = 255
        wheel[start:stop, moving] = values if rising else 255 - values
        start = stop
    return wheel
# Module-level color wheel, built once at import time by make_color_wheel().
COLORWHEEL = make_color_wheel()
def compute_color(u, v, colorwheel=None):
    """Map flow components (u, v) to Middlebury-style RGB colors.

    Hue comes from the flow direction, interpolated on the color wheel.
    Inside the unit circle the color fades toward white as the magnitude
    shrinks; outside it the color is darkened by 0.75. NaN entries become
    black.

    Args:
        u: 2-D array of horizontal flow components (NaNs allowed).
        v: 2-D array of vertical flow components, same shape as ``u``.
        colorwheel: Optional [ncols, 3] wheel with values in 0..255. Defaults
            to the module-level ``COLORWHEEL``.

    Returns:
        np.ndarray: float array of shape [H, W, 3] holding integer values
        in 0..255. The inputs ``u`` and ``v`` are not modified.
    """
    if colorwheel is None:
        colorwheel = COLORWHEEL
    h, w = u.shape
    img = np.zeros([h, w, 3])
    nanIdx = np.isnan(u) | np.isnan(v)
    # Zero NaNs on fresh arrays instead of writing into the caller's data.
    u = np.where(nanIdx, 0, u)
    v = np.where(nanIdx, 0, v)
    ncols = np.size(colorwheel, 0)
    rad = np.sqrt(u**2+v**2)
    # Angle in [-1, 1] (units of pi) mapped to a 1-based wheel index in [1, ncols].
    a = np.arctan2(-v, -u) / np.pi
    fk = (a+1) / 2 * (ncols - 1) + 1
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols+1] = 1
    f = fk - k0
    for i in range(np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0-1] / 255
        col1 = tmp[k1-1] / 255
        col = (1-f) * col0 + f * col1
        idx = rad <= 1
        col[idx] = 1-rad[idx]*(1-col[idx])
        notidx = np.logical_not(idx)
        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
    return img
def flow_to_image(flow):
    """Transfer a batch of flow maps to color images.

    Part of code forked from flownet.

    Args:
        flow: Array indexed as flow[i, :, :, 0] (u) and flow[i, :, :, 1] (v),
            presumably [B, H, W, 2]. It is not modified.

    Returns:
        np.ndarray: float32 array of uint8-valued color images, one per
        batch entry.
    """
    out = []
    maxrad = -1
    for i in range(flow.shape[0]):
        # Copy so zeroing unknown entries does not write into the input
        # buffer (which may be shared with TensorFlow when run via py_func).
        u = np.array(flow[i, :, :, 0], dtype=np.float64)
        v = np.array(flow[i, :, :, 1], dtype=np.float64)
        idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
        u[idxunknow] = 0
        v[idxunknow] = 0
        rad = np.sqrt(u ** 2 + v ** 2)
        # NOTE(review): maxrad is a running max over the images processed so
        # far, so image i is normalized by the max of images 0..i, not per
        # image or over the whole batch. Preserved as-is; confirm intent.
        maxrad = max(maxrad, np.max(rad))
        u = u/(maxrad + np.finfo(float).eps)
        v = v/(maxrad + np.finfo(float).eps)
        out.append(compute_color(u, v))
    return np.float32(np.uint8(out))
def flow_to_image_tf(flow, name='flow_to_image'):
    """Wrap ``flow_to_image`` as a TensorFlow op placed on the CPU.

    Args:
        flow: Offset/flow tensor whose last axis holds the two components.
        name: Variable scope name.

    Returns:
        tf.Tensor: float32 image with the last axis replaced by 3 channels,
        rescaled from [0, 255] to [-1, 1].
    """
    with tf.variable_scope(name), tf.device('/cpu:0'):
        colored = tf.py_func(flow_to_image, [flow], tf.float32, stateful=False)
        leading_dims = flow.get_shape().as_list()[:-1]
        colored.set_shape(leading_dims + [3])
        normalized = colored / 127.5 - 1.
    return normalized
def highlight_flow(flow):
    """Mark every pixel that some location's offset points at.

    Each batch entry starts as a gray (144) canvas. Every position
    (flow[i, h, w, 0], flow[i, h, w, 1]) referenced by any (h, w) is painted
    white (255).

    Args:
        flow: Integer-valued array indexed as flow[i, h, w, 0/1], presumably
            [B, H, W, 2] holding (row, col) targets.

    Returns:
        np.ndarray: float32 array of shape [B, H, W, 3].
    """
    out = []
    s = flow.shape
    for i in range(s[0]):
        img = np.ones((s[1], s[2], 3)) * 144.
        # Visit the full H x W grid; a width loop bounded by s[1] (the height)
        # skipped columns of non-square inputs.
        rows = np.asarray(flow[i, :, :, 0]).astype(int).ravel()
        cols = np.asarray(flow[i, :, :, 1]).astype(int).ravel()
        img[rows, cols, :] = 255.
        out.append(img)
    return np.float32(np.uint8(out))
def highlight_flow_tf(flow, name='flow_to_image'):
    """Wrap ``highlight_flow`` as a TensorFlow op placed on the CPU.

    Args:
        flow: Offset tensor whose last axis holds (row, col) targets.
        name: Variable scope name (note: the default is 'flow_to_image').

    Returns:
        tf.Tensor: float32 image with the last axis replaced by 3 channels,
        rescaled from [0, 255] to [-1, 1].
    """
    with tf.variable_scope(name), tf.device('/cpu:0'):
        highlighted = tf.py_func(highlight_flow, [flow], tf.float32, stateful=False)
        out_shape = flow.get_shape().as_list()[:-1] + [3]
        highlighted.set_shape(out_shape)
        scaled = highlighted / 127.5 - 1.
    return scaled
def image2edge(image):
    """Convert a batch of images to Laplacian edge-magnitude maps.

    Args:
        image: Array indexed as image[i, :, :, :], presumably [B, H, W, C].

    Returns:
        np.ndarray: float32 array of uint8-valued edge magnitudes.
    """
    out = []
    for i in range(image.shape[0]):
        lap = cv2.Laplacian(image[i, :, :, :], cv2.CV_64F, ksize=3, scale=2)
        # The CV_64F Laplacian is signed and unbounded. Take its magnitude and
        # saturate to [0, 255] (as cv2.convertScaleAbs does) so the uint8
        # cast below cannot wrap around.
        out.append(np.clip(np.abs(lap), 0, 255))
    return np.float32(np.uint8(out))
if __name__ == "__main__":
    # Command-line entry point: run the contextual attention smoke test.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--imageA', default='', type=str, help='Image A as background patches to reconstruct image B.')
    parser.add_argument('--imageB', default='', type=str, help='Image B is reconstructed with image A.')
    parser.add_argument('--imageOut', default='result.png', type=str, help='Path to save the reconstruction of image B.')
    args = parser.parse_args()
    test_contextual_attention(args)