-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_run.py
94 lines (88 loc) · 5.02 KB
/
test_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import matplotlib
matplotlib.use('Agg')
import torch
from scipy import ndimage
import os
import SimpleITK as sitk
from metrics import dice_eval, assd_eval
def _format_mean_std(mean, std):
    """Build the two report strings used for both Dice and ASSD.

    Args:
        mean: Per-class mean values; the first three entries are reported.
        std: Per-class standard deviations, indexed like ``mean``.

    Returns:
        A ``(per_class, average)`` tuple of strings. ``per_class`` holds three
        space-separated ``mean±std`` pairs; ``average`` holds the mean of
        ``mean`` and the mean of ``std``. All numbers are rounded to two
        decimals.
    """
    per_class = ' '.join(
        '{}±{}'.format(np.round(mean[i], 2), np.round(std[i], 2))
        for i in range(3)
    )
    average = '{}±{}'.format(np.round(np.mean(mean, axis=0), 2),
                             np.round(np.mean(std, axis=0), 2))
    return per_class, average


def test(config, iplc_model, valid_loader, test_loader, list_data, target, save_path):
    """Evaluate ``iplc_model`` on ``test_loader`` and report Dice / ASSD.

    For every case the prediction (argmax of the model output) and the label
    are resized in-plane to the size of the label image on disk using
    order-0 interpolation, then scored with ``dice_eval`` (scaled by 100) and
    ``assd_eval``. Per-class mean and standard deviation over all cases are
    printed and appended to ``list_data``.

    Args:
        config: Nested mapping; reads ``config['train']['dataset']``,
            ``config['train']['gpu']``, ``config['network']['n_classes_mms']``
            and ``config['test']['save_result']``.
        iplc_model: Model exposing ``eval()`` and ``test_with_name(x)``.
        valid_loader: Unused; kept so existing callers keep working.
        test_loader: Iterable yielding ``(image, label, name, label_path)``
            batches. The indexing below (``name[0]``, ``label_path[0]``)
            assumes batches of one case -- TODO confirm against the dataset.
        list_data: List that receives four strings: Dice per class, Dice
            average, ASSD per class, ASSD average (in that order).
        target: Identifier used as the sub-directory name for saved results.
        save_path: Root output directory; predictions are written under
            ``save_path + '/nii/' + str(target)`` when saving is enabled.

    Returns:
        The same ``list_data`` object, after the four strings were appended.

    Raises:
        ValueError: If the configured dataset is not ``'mms'``.
    """
    dataset = config['train']['dataset']
    if dataset != 'mms':
        # Only the 'mms' class count and report layout are defined here; fail
        # early instead of hitting an unbound ``num_classes`` mid-evaluation.
        raise ValueError(
            "unsupported dataset {!r}: only 'mms' is handled".format(dataset))
    num_classes = config['network']['n_classes_mms']
    device = torch.device('cuda:{}'.format(config['train']['gpu']))

    save_result = config['test']['save_result']
    results_dir = save_path + '/nii/' + str(target)
    if save_result:
        # Created once up front; exist_ok avoids the exists()/makedirs() race.
        os.makedirs(results_dir, exist_ok=True)

    all_batch_dice = []
    all_batch_assd = []
    iplc_model.eval()
    with torch.no_grad():
        for xt, xt_label, xt_name, lab_Imag_dir in test_loader:
            label = xt_label.numpy().squeeze().astype(np.uint8)

            net_out = iplc_model.test_with_name(xt.to(device)).squeeze(0)
            pred_raw = torch.argmax(net_out, dim=1).cpu().numpy()
            pred = pred_raw.squeeze()

            if save_result:
                ref_image = sitk.ReadImage(lab_Imag_dir[0])
                # Only the in-plane size (last two axes) is needed, so this
                # works for any label rank instead of just 3-D / 4-D.
                ref_h, ref_w = sitk.GetArrayFromImage(ref_image).shape[-2:]
                _, raw_h, raw_w = pred_raw.shape
                volume = ndimage.zoom(
                    pred_raw, [1, ref_h / raw_h, ref_w / raw_w], order=0)
                volume = volume.astype(np.float64)
                # Take the name directly from the batch rather than slicing
                # the repr of the collated tuple.
                case_name = str(xt_name[0])
                out_image = sitk.GetImageFromArray(volume)
                out_image.CopyInformation(ref_image)
                sitk.WriteImage(out_image, os.path.join(results_dir, case_name))

            # NOTE(review): the whole ``lab_Imag_dir`` batch (not element 0)
            # is passed here, exactly as before; presumably SimpleITK reads it
            # as a one-file series -- confirm before merging the two reads.
            label_image = sitk.ReadImage(lab_Imag_dir)
            ref_h, ref_w = sitk.GetArrayFromImage(label_image).shape[-2:]
            _, pred_h, pred_w = pred.shape
            zoom = [1, ref_h / pred_h, ref_w / pred_w]
            # The label is resized with the factors derived from the
            # prediction's shape, so both arrays are scaled identically.
            pred = ndimage.zoom(pred, zoom, order=0)
            label = ndimage.zoom(label, zoom, order=0)

            all_batch_dice.append(dice_eval(pred, label, num_classes) * 100)
            all_batch_assd.append(assd_eval(pred, label, num_classes))

    all_batch_dice = np.array(all_batch_dice)
    all_batch_assd = np.array(all_batch_assd)
    mean_dice = np.mean(all_batch_dice, axis=0)
    std_dice = np.std(all_batch_dice, axis=0)
    mean_assd = np.mean(all_batch_assd, axis=0)
    std_assd = np.std(all_batch_assd, axis=0)
    print(mean_dice, std_dice, mean_assd, std_assd)

    dice_per_class, dice_average = _format_mean_std(mean_dice, std_dice)
    print(dice_per_class)
    print(dice_average)
    list_data.append(dice_per_class)
    list_data.append(dice_average)

    print('ASSD:')
    assd_per_class, assd_average = _format_mean_std(mean_assd, std_assd)
    print(assd_per_class)
    print(assd_average)
    list_data.append(assd_per_class)
    list_data.append(assd_average)
    return list_data