-
Notifications
You must be signed in to change notification settings - Fork 14
/
test.py
107 lines (89 loc) · 3.44 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
'''
Copyright (c) 2020 NVIDIA
Author: Wentao Yuan
'''
import argparse
import numpy as np
import os
import torch
from time import time
from torch.utils.data import DataLoader
from tqdm import tqdm
from data import TestData
from model import DeepGMR
def evaluate(model, loader, rmse_thresh, save_results=False, results_dir=None):
    """Run a registration model over a test loader and print summary metrics.

    For every batch ``(pts1, pts2, T_gt)`` yielded by ``loader`` the model is
    called as ``model(pts1, pts2, T_gt)`` under ``torch.no_grad()``. The
    returned loss, rotation error, translation error and RMSE are accumulated.
    A single log line is printed with:

    * per-sample inference and preprocessing time;
    * mean loss per batch;
    * mean rotation error, translation error and RMSE per sample;
    * recall, i.e. the fraction of samples whose RMSE is below ``rmse_thresh``.

    Args:
        model: Module whose forward returns ``(loss, r_err, t_err, rmse)``.
            ``loss`` must hold a single value (``.item()`` is called on it).
            The other three are summed and divided by the sample count, so
            presumably they hold one entry per sample; confirm against the
            model. When ``save_results`` is set, ``model.T_12`` is read after
            each forward pass as the predicted transforms. It is indexed as
            ``[:, :3, :3]`` (rotation) and ``[:, :3, 3]`` (translation).
        loader: Iterable of ``(pts1, pts2, T_gt)`` tensor batches. The batch
            size is taken from ``pts1.shape[0]``.
        rmse_thresh: A sample counts toward recall when its RMSE is strictly
            below this value.
        save_results: If True, write predicted and ground-truth rotations and
            translations to ``.npy`` files in ``results_dir``.
        results_dir: Output directory. Required when ``save_results`` is True;
            it is created if missing.

    Raises:
        ValueError: If ``save_results`` is True but ``results_dir`` is None,
            or if ``loader`` yields no samples.
    """
    if save_results and results_dir is None:
        # Fail before the (possibly long) evaluation loop instead of crashing
        # in os.makedirs(None) after all the work has been done.
        raise ValueError('results_dir must be provided when save_results is True')

    use_cuda = torch.cuda.is_available()
    model.eval()
    log_fmt = 'Test: inference time {:.3f}, preprocessing time {:.3f}, loss {:.4f}, ' + \
        'rotation error {:.2f}, translation error {:.4f}, RMSE {:.4f}, Recall {:.3f}'

    inference_time = 0
    preprocess_time = 0
    losses = 0
    r_errs = 0
    t_errs = 0
    rmses = 0
    n_correct = 0
    n_batches = 0
    N = 0
    # Only filled when save_results is True.
    rotations = []
    translations = []
    rotations_gt = []
    translations_gt = []

    start = time()
    for pts1, pts2, T_gt in tqdm(loader, leave=False):
        if use_cuda:
            pts1 = pts1.cuda()
            pts2 = pts2.cuda()
            T_gt = T_gt.cuda()
        # Preprocessing time covers data loading plus the device transfer.
        preprocess_time += time() - start
        N += pts1.shape[0]
        n_batches += 1

        start = time()
        with torch.no_grad():
            loss, r_err, t_err, rmse = model(pts1, pts2, T_gt)
        if use_cuda:
            # CUDA kernels are launched asynchronously. Without a sync the
            # timer would mostly measure launch overhead, not the forward pass.
            torch.cuda.synchronize()
        inference_time += time() - start

        losses += loss.item()
        r_errs += r_err.sum().item()
        t_errs += t_err.sum().item()
        rmses += rmse.sum().item()
        n_correct += (rmse < rmse_thresh).sum().item()
        if save_results:
            rotations.append(model.T_12[:, :3, :3].cpu().numpy())
            translations.append(model.T_12[:, :3, 3].cpu().numpy())
            rotations_gt.append(T_gt[:, :3, :3].cpu().numpy())
            translations_gt.append(T_gt[:, :3, 3].cpu().numpy())
        start = time()

    if N == 0:
        raise ValueError('loader yielded no samples; nothing to evaluate')

    # Loss is averaged per batch; every other metric is averaged per sample.
    log_str = log_fmt.format(
        inference_time / N, preprocess_time / N, losses / n_batches,
        r_errs / N, t_errs / N, rmses / N, n_correct / N
    )
    print(log_str)

    if save_results:
        os.makedirs(results_dir, exist_ok=True)
        outputs = {
            'rotations.npy': rotations,
            'translations.npy': translations,
            'rotations_gt.npy': rotations_gt,
            'translations_gt.npy': translations_gt,
        }
        for filename, chunks in outputs.items():
            np.save(os.path.join(results_dir, filename), np.concatenate(chunks, 0))
def build_parser():
    """Build the command-line argument parser for this test script.

    Returns:
        argparse.ArgumentParser with the general, dataset and model options
        consumed by ``DeepGMR``, ``TestData`` and ``evaluate``.
    """
    parser = argparse.ArgumentParser()
    # general
    parser.add_argument('--data_file')
    parser.add_argument('--results_dir')
    parser.add_argument('--checkpoint')
    parser.add_argument('--save_results', action='store_true')
    # The threshold is fractional (the default is 0.2), so it must be parsed
    # as a float. type=int rejected values such as "0.1".
    parser.add_argument('--rmse_thresh', type=float, default=0.2)
    # dataset
    parser.add_argument('--n_points', type=int, default=1024)
    parser.add_argument('--batch_size', type=int, default=32)
    # model
    parser.add_argument('--d_model', type=int, default=1024)
    parser.add_argument('--n_clusters', type=int, default=16)
    parser.add_argument('--use_rri', action='store_true')
    parser.add_argument('--use_tnet', action='store_true')
    parser.add_argument('--k', type=int, default=20)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    model = DeepGMR(args)
    if torch.cuda.is_available():
        model.cuda()
    test_data = TestData(args.data_file, args)
    test_loader = DataLoader(test_data, args.batch_size)
    # map_location='cpu' allows a checkpoint saved on a GPU to be loaded on a
    # CPU-only machine. load_state_dict then copies the weights into the
    # model's parameters on whatever device the model lives on.
    model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
    evaluate(model, test_loader, args.rmse_thresh, args.save_results, args.results_dir)