-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmetric.py
124 lines (101 loc) · 4.1 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import sys
this_dir = os.path.dirname(__file__)
current_path = os.path.join(this_dir)
sys.path.append(current_path)
import numpy as np
# from pyquaternion import Quaternion
def add_err(gt_pose, est_pose, model):
    """Compute the average distance between corresponding model points.

    The model points are transformed once with the ground-truth pose and once
    with the estimated pose; the result is the mean Euclidean distance between
    each point's two images.

    Args
    - gt_pose: (np.array) pose matrix whose top three rows are [R | t]
      (e.g. [4 x 4])
    - est_pose: (np.array) pose matrix with the same layout as gt_pose
    - model: (np.array) [N x 3] model 3d vertices

    Returns
    - mean per-point distance (numpy float); nan for a model with no points
    """
    def apply_pose(points_3d, mat):
        # Rotate with the upper-left 3x3 block, then add the translation column.
        rotated = np.matmul(mat[:3, :3], points_3d.transpose())
        return rotated.transpose() + mat[:3, 3]

    pts_gt = apply_pose(model, gt_pose)
    pts_est = apply_pose(model, est_pose)
    # Operate on the [N x 3] arrays directly: rebuilding them from a Python
    # list adds nothing and loses the second axis when N == 0.
    return np.mean(np.linalg.norm(pts_gt - pts_est, axis=1))
def adds_err(gt_pose, est_pose, model, sample_num=100):
    """Compute the symmetric average closest-point distance on a point subset.

    The first ``sample_num`` model points are transformed with both poses.
    For every ground-truth point the distance to its nearest estimated point
    (within the same subset) is taken, and those distances are averaged.

    Args
    - gt_pose: (np.array) pose matrix whose top three rows are [R | t]
      (e.g. [4 x 4])
    - est_pose: (np.array) pose matrix with the same layout as gt_pose
    - model: (np.array) [N x 3] model 3d vertices
    - sample_num: (int) maximum number of leading model points to use

    Returns
    - mean nearest-neighbour distance (numpy float); nan when no points are
      selected
    """
    def apply_pose(points_3d, mat):
        # Rotate with the upper-left 3x3 block, then add the translation column.
        rotated = np.matmul(mat[:3, :3], points_3d.transpose())
        return rotated.transpose() + mat[:3, 3]

    total = model.shape[0]
    # Use exactly `sample_num` points (the old `idx > sample_num` test let one
    # extra point through). `not (a < b)` keeps "use everything" semantics for
    # oversized or non-finite limits; negative limits select nothing.
    limit = total if not (sample_num < total) else max(int(sample_num), 0)
    subset = model[:limit]
    if subset.shape[0] == 0:
        return float("nan")

    pts_gt = apply_pose(subset, gt_pose)
    pts_est = apply_pose(subset, est_pose)

    # [M x M] matrix of distances between every gt / estimated point pair.
    pairwise = np.linalg.norm(
        pts_gt[:, np.newaxis, :] - pts_est[np.newaxis, :, :], axis=2
    )
    # A true minimum per gt point -- no fixed sentinel capping large distances.
    return np.mean(pairwise.min(axis=1))
def rot_error(gt_pose, est_pose):
    """Compute the rotation error between two poses in degrees.

    Both rotation blocks are converted to quaternions and the angle of the
    relative rotation ``q_gt * q_est^-1`` is returned. Only numpy is used, so
    the function does not depend on an external quaternion class.

    Args
    - gt_pose: (np.array) pose matrix whose upper-left [3 x 3] block is a
      rotation matrix
    - est_pose: (np.array) pose matrix with the same layout as gt_pose

    Returns
    - absolute rotation angle in degrees, in the range [0, 180]
    """
    def matrix2quaternion(m):
        """Convert a [3 x 3] rotation matrix to a (w, x, y, z) quaternion."""
        tr = m[0, 0] + m[1, 1] + m[2, 2]
        # Branch on the largest of trace / diagonal entries so that S, the
        # divisor, stays well away from zero.
        if tr > 0:
            S = np.sqrt(tr + 1.0) * 2
            qw = 0.25 * S
            qx = (m[2, 1] - m[1, 2]) / S
            qy = (m[0, 2] - m[2, 0]) / S
            qz = (m[1, 0] - m[0, 1]) / S
        elif (m[0, 0] > m[1, 1]) and (m[0, 0] > m[2, 2]):
            S = np.sqrt(1. + m[0, 0] - m[1, 1] - m[2, 2]) * 2
            qw = (m[2, 1] - m[1, 2]) / S
            qx = 0.25 * S
            qy = (m[0, 1] + m[1, 0]) / S
            qz = (m[0, 2] + m[2, 0]) / S
        elif m[1, 1] > m[2, 2]:
            S = np.sqrt(1. + m[1, 1] - m[0, 0] - m[2, 2]) * 2
            qw = (m[0, 2] - m[2, 0]) / S
            qx = (m[0, 1] + m[1, 0]) / S
            qy = 0.25 * S
            qz = (m[1, 2] + m[2, 1]) / S
        else:
            S = np.sqrt(1. + m[2, 2] - m[0, 0] - m[1, 1]) * 2
            qw = (m[1, 0] - m[0, 1]) / S
            qx = (m[0, 2] + m[2, 0]) / S
            qy = (m[1, 2] + m[2, 1]) / S
            qz = 0.25 * S
        return np.array([qw, qx, qy, qz])

    q_gt = matrix2quaternion(gt_pose[:3, :3])
    q_est = matrix2quaternion(est_pose[:3, :3])
    # Normalise so slightly non-orthonormal inputs still give a valid angle.
    q_gt = q_gt / np.linalg.norm(q_gt)
    q_est = q_est / np.linalg.norm(q_est)

    # Hamilton product q_gt * conj(q_est); for unit quaternions the conjugate
    # is the inverse.
    w_gt, v_gt = q_gt[0], q_gt[1:]
    w_est, v_est = q_est[0], q_est[1:]
    rel_w = np.dot(q_gt, q_est)
    rel_v = w_est * v_gt - w_gt * v_est - np.cross(v_gt, v_est)

    # q and -q describe the same rotation; using |w| selects the
    # representative whose angle lies in [0, pi].
    angle = 2.0 * np.arctan2(np.linalg.norm(rel_v), np.abs(rel_w))
    return np.degrees(angle)
def trans_error(gt_pose, est_pose):
    """Compute the translation error between two poses.

    Args
    - gt_pose: (np.array) pose matrix with the translation in rows 0-2 of
      column 3
    - est_pose: (np.array) pose matrix with the same layout as gt_pose

    Returns
    - tuple (norm, per_axis): the Euclidean length of the translation
      difference and its absolute value along each of the three axes
    """
    offset = gt_pose[:3, 3] - est_pose[:3, 3]
    return np.linalg.norm(offset), np.abs(offset)
def iou(gt_box, est_box):
    """Compute the intersection over union of two axis-aligned boxes.

    Args
    - gt_box: sequence indexed as [x_min, y_min, x_max, y_max]
    - est_box: sequence with the same layout as gt_box

    Returns
    - intersection area divided by union area, or 0. when the boxes share no
      region of positive area (touching edges count as no overlap)
    """
    x_lo = max(gt_box[0], est_box[0])
    y_lo = max(gt_box[1], est_box[1])
    x_hi = min(gt_box[2], est_box[2])
    y_hi = min(gt_box[3], est_box[3])

    # Empty or degenerate intersection rectangle.
    if x_lo >= x_hi or y_lo >= y_hi:
        return 0.

    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    overlap = (x_hi - x_lo) * (y_hi - y_lo)
    # Union = both areas minus the overlap, which would otherwise count twice.
    union = area(gt_box) + area(est_box) - overlap
    return overlap / float(union)
def projection_error_2d(gt_pose, est_pose, model, cam):
    """Compute 2d projection error

    Projects the model vertices into the image with both poses and returns
    the mean distance between the two sets of projected points.

    Args
    - gt_pose: (np.array) [4 x 4] pose matrix
    - est_pose: (np.array) [4 x 4] pose matrix
    - model: (np.array) [N x 3] model 3d vertices
    - cam: (np.array) [3 x 3] camera matrix

    Returns
    - mean Euclidean distance between corresponding projected points
    """
    # Homogeneous [N x 4] vertices so a [3 x 4] pose applies R and t at once.
    ones_column = np.ones((model.shape[0], 1))
    homogeneous = np.concatenate((model, ones_column), axis=1)

    def project(pose):
        # Only the top three rows [R | t] of the pose take part.
        projected = np.matmul(np.matmul(cam, pose[:3]), homogeneous.T)
        # Perspective divide by the third row, returned as [N x 2].
        return (projected[:2, :] / projected[2, :]).T

    offsets = project(gt_pose) - project(est_pose)
    return np.mean(np.linalg.norm(offsets, axis=1))