-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSolver.py
270 lines (229 loc) · 13.3 KB
/
Solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# @author : Bingyu Xin
# @Institute : CS@Rutgers
import os
from os.path import join
import time
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.utils.data as Data
from tensorboardX import SummaryWriter
from skimage.metrics import structural_similarity as cal_ssim
from skimage.metrics import peak_signal_noise_ratio as cal_psnr
from skimage.metrics import normalized_root_mse as cal_nrmse
from loss import CompoundLoss
from utils import output2complex
from read_data import MyData
from model.DCCNN import DCCNN
from model.LPDNet import LPDNet
from model.HQSNet import HQSNet
from model.ISTANet_plus import ISTANetplus
import numpy as np
class Solver():
    """Training / evaluation driver for undersampled-MRI reconstruction models.

    Builds one of several unrolled reconstruction networks (DC-CNN,
    ISTA-Net+, LPD-Net, HQS-Net, HQS-Net-Unet), then provides:

    * ``train()``      -- full training loop with periodic validation,
      tensorboard logging, and best-PSNR checkpointing.
    * ``test()``       -- loads the best checkpoint and reports PSNR / SSIM /
      NRMSE (mean and std) plus inference speed on the test split.

    NOTE: all paths assume a CUDA device is available (``.cuda()`` throughout).
    """

    def __init__(self, args):
        # Anomaly detection helps locate the known backward-pass bug in
        # HQS-Net-Unet (see train()); it adds noticeable overhead, so
        # consider disabling it for production runs.
        torch.autograd.set_detect_anomaly(True)
        self.args = args
        ################ experiment settings ################
        self.model_name = self.args.model
        self.acc = self.args.acc                        # acceleration factor of the k-space mask
        self.imageDir_train = self.args.train_path      # train path
        self.imageDir_val = self.args.val_path          # val path while training
        self.imageDir_test = self.args.test_path        # test path
        self.num_epoch = self.args.num_epoch            # training epochs
        self.batch_size = self.args.batch_size          # batch size
        self.val_on_epochs = self.args.val_on_epochs    # validate on every val_on_epochs
        self.resume = self.args.resume                  # resume training
        ## settings for optimizer
        self.lr = self.args.lr
        ## settings for data preprocessing
        self.img_size = (192, 160)
        self.saveDir = 'weight'  # model save path while training
        if not os.path.isdir(self.saveDir):
            os.makedirs(self.saveDir)
        # Task name encodes the hyper-parameters so runs don't overwrite each other.
        self.task_name = self.model_name + '_acc_' + str(self.acc) + '_bs_' + str(self.batch_size) \
                         + '_lr_' + str(self.lr)
        print('task_name: ', self.task_name)
        self.model_path = 'weight/' + self.task_name + '_best.pth'  # model load path for test
        ############################################ Specify network ############################################
        if self.model_name == 'dc-cnn':
            self.net = DCCNN(n_iter=8)
        elif self.model_name == 'ista-net-plus':
            self.net = ISTANetplus(n_iter=8)
        elif self.model_name == 'lpd-net':
            self.net = LPDNet(n_iter=8)
        elif self.model_name == 'hqs-net':
            self.net = HQSNet(block_type='cnn', buffer_size=5, n_iter=8)
        elif self.model_name == 'hqs-net-unet':
            # HQS-Net-Unet is for best reconstruction quality, so we enlarge the model;
            # it is not a fair comparison to the other models.
            self.net = HQSNet(block_type='unet', buffer_size=5, n_iter=10)
        else:
            # BUG FIX: the original used `assert "wrong model name !"`, which is a
            # non-empty (truthy) string and therefore never raises; an unknown model
            # name silently fell through leaving self.net undefined.
            raise ValueError('wrong model name: {!r}'.format(self.model_name))
        print('Total # of model params: %.5fM' % (sum(p.numel() for p in self.net.parameters()) / 10. ** 6))
        self.net.cuda()

    def train(self):
        """Run the full training loop: optimize, validate every
        ``val_on_epochs`` epochs, log to tensorboard, and keep the
        checkpoint with the best validation PSNR."""
        ############################################ Specify loss ############################################
        ## Notice:
        ## there is an unknown backward gradient bug when training HQS-Net-Unet, which may interrupt the
        ## training; you can simply resume the training by setting --resume 1 in the scripts.
        self.criterion = CompoundLoss('ms-ssim')
        ############################################ Specify optimizer ########################################
        self.optimizer_G = optim.Adam(self.net.parameters(), lr=self.lr, eps=1e-3, weight_decay=1e-10)
        ############################################ load data ############################################
        dataset_train = MyData(self.imageDir_train, self.acc, self.img_size, is_training='train')
        dataset_val = MyData(self.imageDir_val, self.acc, self.img_size, is_training='val')
        num_workers = 4
        use_pin_memory = True
        loader_train = Data.DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, drop_last=True,
                                       num_workers=num_workers, pin_memory=use_pin_memory)
        loader_val = Data.DataLoader(dataset_val, batch_size=self.batch_size, shuffle=False, drop_last=False,
                                     num_workers=num_workers, pin_memory=use_pin_memory)
        self.slices_val = len(dataset_val)
        print("slices of 2d train data: ", len(dataset_train))
        print("slices of 2d validation data: ", len(dataset_val))
        ############################################ setting for tensorboard ###################################
        self.writer = SummaryWriter('log/' + self.task_name)
        ############################################ start to run epochs #######################################
        start_epoch = 0
        best_val_psnr = 0
        if self.resume:
            # Resume from the best checkpoint of this task (same hyper-parameters).
            best_name = self.task_name + '_best.pth'
            checkpoint = torch.load(join(self.saveDir, best_name))
            self.net.load_state_dict(checkpoint['net'])
            start_epoch = checkpoint['epoch'] + 1
            best_val_psnr = checkpoint['val_psnr']
            print('load pretrained model---, start epoch at, ', start_epoch, ', star_psnr_val is: ', best_val_psnr)
        for epoch in range(start_epoch, self.num_epoch):
            ####################### 1. training #######################
            loss_g = self._train_cnn(loader_train)
            ####################### 2. validate #######################
            if epoch == start_epoch:
                # Zero-filled baseline metrics are constant, so compute them once.
                base_psnr, base_ssim = self._validate_base(loader_val)
            if epoch % self.val_on_epochs == 0:
                val_psnr, val_ssim = self._validate(loader_val)
                ########################## 3. print and tensorboard ########################
                # NOTE: printing/logging/checkpointing must stay inside this branch,
                # since val_psnr / val_ssim only exist on validation epochs.
                print("Epoch {}/{}".format(epoch + 1, self.num_epoch))
                print(" base PSNR:\t\t{:.6f}".format(base_psnr))
                print(" test PSNR:\t\t{:.6f}".format(val_psnr))
                print(" base SSIM:\t\t{:.6f}".format(base_ssim))
                print(" test SSIM:\t\t{:.6f}".format(val_ssim))
                ## write to tensorboard
                self.writer.add_scalar("loss/train_loss", loss_g, epoch)
                self.writer.add_scalar("metric/base_psnr", base_psnr, epoch)
                self.writer.add_scalar("metric/val_psnr", val_psnr, epoch)
                self.writer.add_scalar("metric/base_ssim", base_ssim, epoch)
                self.writer.add_scalar("metric/val_ssim", val_ssim, epoch)
                ## save the best model according to validation psnr
                if best_val_psnr < val_psnr:
                    best_val_psnr = val_psnr
                    best_name = self.task_name + '_best.pth'
                    state = {'net': self.net.state_dict(), 'epoch': epoch, 'val_psnr': val_psnr,
                             'val_ssim': val_ssim}
                    torch.save(state, join(self.saveDir, best_name))
        self.writer.close()

    def test(self):
        """Evaluate the best checkpoint on the test split.

        Reports inference speed plus mean/std of PSNR, SSIM and NRMSE for
        both the zero-filled baseline (``base_*``) and the network output
        (``test_*``).
        """
        ############################################ load data ################################
        dataset_val = MyData(self.imageDir_test, self.acc, self.img_size, is_training='test')
        loader_val = Data.DataLoader(dataset_val, batch_size=self.batch_size, shuffle=False, drop_last=False,
                                     num_workers=2, pin_memory=False)
        len_data = len(dataset_val)
        print("slices of 2d test data: ", len_data)
        checkpoint = torch.load(self.model_path)
        print("best epoch at : {}, val_psnr: {:.6f}, val_ssim: {:.6f}" \
              .format(checkpoint['epoch'], checkpoint['val_psnr'], checkpoint['val_ssim']))
        self.net.load_state_dict(checkpoint['net'])
        self.net.cuda()
        self.net.eval()
        base_psnr = []
        test_psnr = []
        base_ssim = []
        test_ssim = []
        base_nrmse = []
        test_nrmse = []
        with torch.no_grad():
            time_0 = time.time()
            for data_dict in tqdm(loader_val):
                im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict['im_A_und'].float().cuda(), \
                                                data_dict['k_A_und'].float().cuda(), \
                                                data_dict['mask_A'].float().cuda()
                T1 = self.net(im_A_und, k_A_und, mask)
                ############## convert model output to complex value in original range
                T1 = output2complex(T1)
                im_A = output2complex(im_A)
                im_A_und = output2complex(im_A_und)
                ########################### calculate metrics ###################################
                for T1_i, im_A_i, im_A_und_i in zip(T1.cpu().numpy(), im_A.cpu().numpy(), im_A_und.cpu().numpy()):
                    ## for skimage.metrics, input is (im_true, im_pred)
                    base_nrmse.append(cal_nrmse(im_A_i, im_A_und_i))
                    test_nrmse.append(cal_nrmse(im_A_i, T1_i))
                    base_ssim.append(cal_ssim(im_A_i, im_A_und_i, data_range=im_A_i.max()))
                    test_ssim.append(cal_ssim(im_A_i, T1_i, data_range=im_A_i.max()))
                    base_psnr.append(cal_psnr(im_A_i, im_A_und_i, data_range=im_A_i.max()))
                    test_psnr.append(cal_psnr(im_A_i, T1_i, data_range=im_A_i.max()))
            time_1 = time.time()
        ## comment metric calculation code for more precise inference speed
        print('inference speed: {:.5f} ms/slice'.format(1000 * (time_1 - time_0) / len_data))
        print(" base PSNR:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_psnr), np.std(base_psnr)))
        print(" test PSNR:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_psnr), np.std(test_psnr)))
        print(" base SSIM:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_ssim), np.std(base_ssim)))
        print(" test SSIM:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_ssim), np.std(test_ssim)))
        print(" base NRMSE:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_nrmse), np.std(base_nrmse)))
        print(" test NRMSE:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_nrmse), np.std(test_nrmse)))

    def _train_cnn(self, loader_train):
        """Run one training epoch over ``loader_train``.

        Returns the loss of the LAST batch (a tensor), which is what gets
        logged to tensorboard — it is a noisy per-batch value, not an
        epoch average.
        """
        self.net.train()
        for data_dict in tqdm(loader_train):
            im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict[
                'im_A_und'].float().cuda(), data_dict['k_A_und'].float().cuda(), data_dict['mask_A'].float().cuda()
            if self.model_name == 'ista-net-plus':
                # ISTA-Net+ also returns per-layer symmetry residuals used
                # as an auxiliary constraint below.
                T1, loss_layers_sym = self.net(im_A_und, k_A_und, mask)
            else:
                T1 = self.net(im_A_und, k_A_und, mask)
            T1 = output2complex(T1)
            im_A = output2complex(im_A)
            ############################################# 1.2 update generator #############################################
            loss_g = self.criterion(T1, im_A, data_range=im_A.max())
            if self.model_name == 'ista-net-plus':
                # Symmetry constraint: mean squared residual summed over all layers.
                loss_constraint = torch.mean(torch.pow(loss_layers_sym[0], 2))
                for k in range(len(loss_layers_sym) - 1):
                    loss_constraint += torch.mean(torch.pow(loss_layers_sym[k + 1], 2))
                loss_g = loss_g + 0.01 * loss_constraint
            self.optimizer_G.zero_grad()
            loss_g.backward()
            self.optimizer_G.step()
        return loss_g

    def _validate_base(self, loader_val):
        """Compute mean PSNR/SSIM of the zero-filled reconstruction
        (undersampled input vs. ground truth) — the no-network baseline."""
        base_psnr = 0
        base_ssim = 0
        for data_dict in loader_val:
            im_A, im_A_und, = data_dict['im_A'].float().cuda(), data_dict['im_A_und'].float().cuda()
            ############## convert model output to complex value in original range
            im_A = output2complex(im_A)
            im_A_und = output2complex(im_A_und)
            ########################### cal metrics ###################################
            for im_A_i, im_A_und_i in zip(im_A.cpu().numpy(),
                                          im_A_und.cpu().numpy()):
                ## for skimage.metrics, input is (im_true, im_pred)
                base_ssim += cal_ssim(im_A_i, im_A_und_i, data_range=im_A_i.max())
                base_psnr += cal_psnr(im_A_i, im_A_und_i, data_range=im_A_i.max())
        # Metrics are accumulated per slice, so normalize by slice count.
        base_psnr /= self.slices_val
        base_ssim /= self.slices_val
        return base_psnr, base_ssim

    def _validate(self, loader_val):
        """Compute mean PSNR/SSIM of the network reconstruction on the
        validation split (gradient-free, eval mode)."""
        test_psnr = 0
        test_ssim = 0
        self.net.eval()
        with torch.no_grad():
            for data_dict in tqdm(loader_val):
                im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict[
                    'im_A_und'].float().cuda(), data_dict['k_A_und'].float().cuda(), data_dict[
                    'mask_A'].float().cuda()
                T1 = self.net(im_A_und, k_A_und, mask)
                ############## convert model output to complex value in original range
                T1 = output2complex(T1)
                im_A = output2complex(im_A)
                ########################### cal metrics ###################################
                for T1_i, im_A_i in zip(T1.cpu().numpy(), im_A.cpu().numpy()):
                    ## for skimage.metrics, input is (im_true, im_pred)
                    test_ssim += cal_ssim(im_A_i, T1_i, data_range=im_A_i.max())
                    test_psnr += cal_psnr(im_A_i, T1_i, data_range=im_A_i.max())
        # Per-slice accumulation, normalized by number of validation slices.
        test_psnr /= self.slices_val
        test_ssim /= self.slices_val
        return test_psnr, test_ssim