# test_background-matting_image.py
# Forked from senguptaumd/Background-Matting.
from __future__ import print_function
import os, glob, time, argparse, pdb, cv2
#import matplotlib.pyplot as plt
import numpy as np
from skimage.measure import label
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from functions import *
from networks import ResnetConditionHR
# Limit PyTorch to a single thread for this (CPU) inference script.
torch.set_num_threads(1)
#os.environ["CUDA_VISIBLE_DEVICES"]="4"
#print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])

# Parse command-line arguments.
parser = argparse.ArgumentParser(description='Background Matting.')
parser.add_argument(
    '-m', '--trained_model', type=str, default='real-fixed-cam',
    choices=['real-fixed-cam', 'real-hand-held', 'syn-comp-adobe'],
    help='Trained background matting model')
parser.add_argument(
    '-o', '--output_dir', type=str, required=True,
    help='Directory to save the output results. (required)')
parser.add_argument(
    '-i', '--input_dir', type=str, required=True,
    help='Directory to load input images. (required)')
# The script calls os.path.isdir() on this value unconditionally, so it is
# mandatory in practice; declaring it required yields a usage error instead
# of a TypeError when it is omitted.
parser.add_argument(
    '-tb', '--target_back', type=str, required=True,
    help='Target background: an image file (image mode) or a directory of '
         'background frames (video mode). (required)')
parser.add_argument(
    '-b', '--back', type=str, default=None,
    help='Captured background image. (only use for inference on videos with fixed camera)')
args = parser.parse_args()
#input model
model_main_dir = 'Models/' + args.trained_model + '/'
#input data path
data_path = args.input_dir

# A directory as target background selects video mode; the per-frame target
# backgrounds are then loaded inside the main loop. Anything else is treated
# as a single target image that is loaded once here (image mode).
args.video = os.path.isdir(args.target_back)
if args.video:
    print('Using video mode')
else:
    print('Using image mode')
    #target background path
    back_img10 = cv2.imread(args.target_back)
    if back_img10 is None:
        # cv2.imread signals failure by returning None; fail here with the
        # offending path instead of inside cvtColor.
        raise IOError('Could not read target background image: ' + str(args.target_back))
    back_img10 = cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB)
    #Green-screen background: constant colour (120, 255, 155), same shape as the target
    back_img20 = np.zeros(back_img10.shape)
    back_img20[..., 0] = 120
    back_img20[..., 1] = 255
    back_img20[..., 2] = 155
#initialize network
fo = glob.glob(model_main_dir + 'netG_epoch_*')
if not fo:
    raise IOError('No checkpoint matching "netG_epoch_*" found in ' + model_main_dir)
model_name1 = fo[0]

netM = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3)
# Wrapped before loading the weights; presumably the checkpoint was saved from
# a DataParallel model and its keys carry the matching prefix -- TODO confirm.
netM = nn.DataParallel(netM)
# map_location is an argument of torch.load (it remaps the stored tensors to
# the CPU); load_state_dict does not accept it.
netM.load_state_dict(torch.load(model_name1, map_location=torch.device('cpu')))
#netM.cuda();
netM.eval()
cudnn.benchmark = True
reso = (512, 512)  #input resolution to the network
#load captured background for video mode, fixed camera
if args.back is not None:
    bg_im0 = cv2.imread(args.back)
    if bg_im0 is None:
        # cv2.imread signals failure by returning None; report the path
        # instead of failing inside cvtColor.
        raise IOError('Could not read captured background image: ' + args.back)
    bg_im0 = cv2.cvtColor(bg_im0, cv2.COLOR_BGR2RGB)

#Create a sorted list of test images: regular files in data_path named *_img.png
test_imgs = sorted(
    f for f in os.listdir(data_path)
    if f.endswith('_img.png') and os.path.isfile(os.path.join(data_path, f))
)

#output directory (created if missing)
result_path = args.output_dir
os.makedirs(result_path, exist_ok=True)
def _read_image(path, flags=cv2.IMREAD_COLOR):
    """Read an image with OpenCV, failing loudly if it cannot be loaded.

    cv2.imread returns None rather than raising when a file is missing or
    unreadable; raising here reports the offending path directly.
    """
    image = cv2.imread(path, flags)
    if image is None:
        raise IOError('Could not read image: ' + path)
    return image


# Loop-invariant settings for the per-frame processing below.
kernel_er = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
kernel_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
K = 25    # extra rows temporarily appended to the bottom of the mask
gap = 20  # frame offset between the neighbouring frames used in video mode
num_imgs = len(test_imgs)

for i, filename in enumerate(test_imgs):
    #original image, converted to RGB
    bgr_img = cv2.cvtColor(_read_image(os.path.join(data_path, filename)), cv2.COLOR_BGR2RGB)

    if args.back is None:
        #captured background image stored next to the frame (*_back)
        bg_im0 = cv2.cvtColor(
            _read_image(os.path.join(data_path, filename.replace('_img', '_back'))),
            cv2.COLOR_BGR2RGB)

    #segmentation mask (*_masksDL), read as single-channel
    rcnn = _read_image(os.path.join(data_path, filename.replace('_img', '_masksDL')),
                       cv2.IMREAD_GRAYSCALE)

    if args.video:  #if video mode, load target background frames
        #target background frame with the same base name as the input frame
        back_img10 = cv2.cvtColor(
            _read_image(os.path.join(args.target_back, filename.replace('_img.png', '.png'))),
            cv2.COLOR_BGR2RGB)
        #Green-screen background
        back_img20 = np.zeros(back_img10.shape)
        back_img20[..., 0] = 120
        back_img20[..., 1] = 255
        back_img20[..., 2] = 155

        #create multiple frames with adjoining frames (grayscale, one per channel)
        multi_fr_w = np.zeros((bgr_img.shape[0], bgr_img.shape[1], 4))
        for t, offset in enumerate((-2 * gap, -gap, gap, 2 * gap)):
            # Wrap around the sequence; the modulo also keeps the index valid
            # when the sequence is shorter than the offsets.
            file_tmp = test_imgs[(i + offset) % num_imgs]
            bgr_img_mul = _read_image(os.path.join(data_path, file_tmp))
            multi_fr_w[..., t] = cv2.cvtColor(bgr_img_mul, cv2.COLOR_BGR2GRAY)
    else:
        ## create the multi-frame: the current frame's grayscale repeated 4 times
        # NOTE(review): bgr_img is already RGB here, so COLOR_BGR2GRAY applies
        # the R/B weights swapped, unlike the video branch which converts raw
        # BGR frames. Left unchanged -- confirm against the training pipeline.
        gray = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY)
        multi_fr_w = np.zeros((bgr_img.shape[0], bgr_img.shape[1], 4))
        multi_fr_w[...] = gray[..., np.newaxis]

    #crop tightly
    bgr_img0 = bgr_img  # full-size frame, kept for un-cropping later
    R0, C0 = bgr_img0.shape[0], bgr_img0.shape[1]
    bbox = get_bbox(rcnn, R=R0, C=C0)

    # Fit the target backgrounds to this frame WITHOUT overwriting
    # back_img10/back_img20: in image mode those are loaded once before the
    # loop, and overwriting them would make every later frame resample an
    # already-resized background.
    back_fit1 = cv2.resize(back_img10, (C0, R0))
    back_fit2 = cv2.resize(back_img20, (C0, R0))

    crop_list = crop_images([bgr_img, bg_im0, rcnn, back_fit1, back_fit2, multi_fr_w], reso, bbox)
    bgr_img = crop_list[0]
    bg_im = crop_list[1]
    rcnn = crop_list[2]
    multi_fr = crop_list[5]

    #process segmentation mask
    rcnn = rcnn.astype(np.float32) / 255
    rcnn[rcnn > 0.2] = 1

    # Remove all-zero mask rows below row 250 (plus the two rows above the
    # first of them), then pad the bottom by replicating the last remaining
    # row: this restores the height and adds K extra rows that are dropped
    # again after filtering.
    zero_id = np.nonzero(np.sum(rcnn, axis=1) == 0)
    del_id = zero_id[0][zero_id[0] > 250]
    if len(del_id) > 0:
        del_id = [del_id[0] - 2, del_id[0] - 1, *del_id]
        rcnn = np.delete(rcnn, del_id, 0)
    rcnn = cv2.copyMakeBorder(rcnn, 0, K + len(del_id), 0, 0, cv2.BORDER_REPLICATE)

    rcnn = cv2.erode(rcnn, kernel_er, iterations=10)
    rcnn = cv2.dilate(rcnn, kernel_dil, iterations=5)
    rcnn = cv2.GaussianBlur(rcnn.astype(np.float32), (31, 31), 0)
    rcnn = (255 * rcnn).astype(np.uint8)
    rcnn = np.delete(rcnn, range(reso[0], reso[0] + K), 0)

    #convert to torch: add a batch dimension and map 0..255 to -1..1
    img = torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0)
    img = 2 * img.float().div(255) - 1
    bg = torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0)
    bg = 2 * bg.float().div(255) - 1
    rcnn_al = torch.from_numpy(rcnn).unsqueeze(0).unsqueeze(0)
    rcnn_al = 2 * rcnn_al.float().div(255) - 1
    multi_fr = torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0)
    multi_fr = 2 * multi_fr.float().div(255) - 1

    with torch.no_grad():
        alpha_pred, fg_pred_tmp = netM(img, bg, rcnn_al, multi_fr)

        al_mask = (alpha_pred > 0.95).type(torch.FloatTensor)

        # for regions with alpha>0.95, simply use the image as fg
        fg_pred = img * al_mask + fg_pred_tmp * (1 - al_mask)

        alpha_out = to_image(alpha_pred[0, ...])

        #refine alpha with connected component: keep only the largest one
        labels = label((alpha_out > 0.05).astype(int))
        if labels.max() == 0:
            # No foreground component at all: nothing to write for this frame.
            print('Skipping ' + filename + ': no foreground found in predicted alpha')
            continue
        largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1
        alpha_out = alpha_out * largestCC

        alpha_out = (255 * alpha_out[..., 0]).astype(np.uint8)

        # Zero the foreground wherever the refined alpha is (almost) zero.
        fg_out = to_image(fg_pred[0, ...])
        fg_out = fg_out * np.expand_dims((alpha_out.astype(float) / 255 > 0.01).astype(float), axis=2)
        fg_out = (255 * fg_out).astype(np.uint8)

        #Uncrop back to the full frame size
        alpha_out0 = uncrop(alpha_out, bbox, R0, C0)
        fg_out0 = uncrop(fg_out, bbox, R0, C0)

    #compose over the target background and over the green screen
    comp_im_tr1 = composite4(fg_out0, back_fit1, alpha_out0)
    comp_im_tr2 = composite4(fg_out0, back_fit2, alpha_out0)

    # Images are RGB in memory; swap channels back to BGR for cv2.imwrite.
    cv2.imwrite(os.path.join(result_path, filename.replace('_img', '_out')), alpha_out0)
    cv2.imwrite(os.path.join(result_path, filename.replace('_img', '_fg')),
                cv2.cvtColor(fg_out0, cv2.COLOR_BGR2RGB))
    cv2.imwrite(os.path.join(result_path, filename.replace('_img', '_compose')),
                cv2.cvtColor(comp_im_tr1, cv2.COLOR_BGR2RGB))
    cv2.imwrite(os.path.join(result_path, filename.replace('_img', '_matte')),
                cv2.cvtColor(comp_im_tr2, cv2.COLOR_BGR2RGB))

    print('Done: ' + str(i+1) + '/' + str(len(test_imgs)))