# forked from baegwangbin/surface_normal_uncertainty
# test.py -- 98 lines (80 loc), 3.7 KB
import argparse
import os
import sys

import numpy as np
import torch
from tqdm import tqdm
from PIL import Image

import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt

from data.dataloader_vcc import VCC_Loader, VCC_DatasetParams
from models.NNET import NNET
import funcs.utils as utils
def test(model, test_loader, device, results_dir, alpha_max=90):
    """Run inference over ``test_loader`` and save per-image visualizations.

    For each batch the model predicts surface normals plus a concentration
    (kappa) channel. The normals are written as an RGB PNG; kappa is converted
    to an angular uncertainty (alpha, in degrees) and written as an 8-bit
    grayscale PNG. Only the first element of each batch is saved.

    Args:
        model: network returning ``(norm_out_list, _, _)`` whose last list
            element is a (B, 4, H, W) tensor -- 3 normal channels followed by
            1 kappa channel.
        test_loader: iterable yielding dicts with keys 'img' (tensor) and
            'img_name' (sequence of file paths) -- assumed from usage below;
            confirm against the dataloader.
        device: torch device to move input images to.
        results_dir: existing directory that receives the output PNGs.
        alpha_max: clip ceiling for the uncertainty map, in degrees; this
            value maps to gray level 255 in the saved image.
    """
    with torch.no_grad():
        for data_dict in tqdm(test_loader):
            img = data_dict['img'].to(device)
            norm_out_list, _, _ = model(img)
            norm_out = norm_out_list[-1]

            pred_norm = norm_out[:, :3, :, :]
            pred_kappa = norm_out[:, 3:, :, :]

            # to numpy arrays, channel-last: (B, H, W, C)
            pred_norm = pred_norm.detach().cpu().permute(0, 2, 3, 1).numpy()
            pred_kappa = pred_kappa.detach().cpu().permute(0, 2, 3, 1).numpy()

            # base name for the saved results
            img_name = os.path.basename(data_dict['img_name'][0]).replace('.png', '')

            # 2. predicted normal
            # The model predicts in a left-handed frame (x-left, y-up,
            # z-backward); negating converts to the camera frame
            # (x-right, y-down, z-forward).
            pred_norm *= -1
            pred_norm_rgb = ((pred_norm + 1) * 0.5) * 255  # [-1, 1] -> [0, 255]
            pred_norm_rgb = np.clip(pred_norm_rgb, a_min=0, a_max=255)
            pred_norm_rgb = pred_norm_rgb.astype(np.uint8)  # (B, H, W, 3)
            target_path = f'{results_dir}/{img_name}_pred_norm.png'
            plt.imsave(target_path, pred_norm_rgb[0, :, :, :])

            # 4. predicted uncertainty
            # kappa -> alpha (degrees), clipped to [0, alpha_max] and rescaled
            # so alpha_max maps to gray level 255 (was hard-coded 90 before).
            pred_alpha = utils.kappa_to_alpha(pred_kappa)
            pred_alpha = np.clip(pred_alpha, 0, alpha_max)
            pred_alpha_gray = (pred_alpha * (255 / alpha_max)).astype(np.uint8)
            target_path = f'{results_dir}/{img_name}_pred_alpha.png'
            Image.fromarray(pred_alpha_gray[0, :, :, 0]).save(target_path)
if __name__ == '__main__':
    # Arguments #################################################################################
    parser = argparse.ArgumentParser(fromfile_prefix_chars='@', conflict_handler='resolve')
    parser.convert_arg_line_to_args = utils.convert_arg_line_to_args

    parser.add_argument('--architecture', required=True, type=str, help='{BN, GN}')
    parser.add_argument("--pretrained", required=True, type=str, help="{pretrained model path}")
    parser.add_argument('--sampling_ratio', type=float, default=0.4)
    parser.add_argument('--importance_ratio', type=float, default=0.7)
    parser.add_argument('--input_height', default=480, type=int)
    parser.add_argument('--input_width', default=640, type=int)
    parser.add_argument('--result_dir', type=str, default='results')

    # A single '.txt' argument is treated as an @-file containing the real arguments.
    if len(sys.argv) == 2 and '.txt' in sys.argv[1]:
        args = parser.parse_args(['@' + sys.argv[1]])
    else:
        args = parser.parse_args()

    device = torch.device('cuda:0')

    # Build the network and restore the pretrained weights.
    checkpoint = args.pretrained
    print(f'loading checkpoint... {checkpoint}')
    model = NNET(args).to(device)
    model = utils.load_checkpoint(checkpoint, model)
    model.eval()
    print('loading checkpoint... / done')

    # Prepare the output directory, configure the dataset, and run inference.
    results_dir = f'{args.result_dir}/results'
    os.makedirs(results_dir, exist_ok=True)

    params = VCC_DatasetParams()
    params.mode = 'test'
    params.input_height = args.input_height
    params.input_width = args.input_width
    params.data_record_file = f'./data_split/data.txt'
    params.need_scene = True
    params.need_normal = False
    params.need_depth = False

    test_loader = VCC_Loader(params).data
    test(model, test_loader, device, results_dir)