-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo_onnx.py
84 lines (72 loc) · 2.56 KB
/
demo_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import time
import cv2
import numpy as np
import onnxruntime as ort
# Execution providers handed to every InferenceSession created by main().
# onnxruntime tries providers in list order, so CUDA (when this build of
# onnxruntime exposes it) is listed first and CPU is always appended last.
available_providers = ort.get_available_providers()
providers = (
    ["CUDAExecutionProvider"]
    if "CUDAExecutionProvider" in available_providers
    else []
)
providers += ["CPUExecutionProvider"]
def _load_image(path: str, width: int, height: int) -> np.ndarray:
    """Load ``path`` as a batched, channel-first RGB float32 array.

    The image is read with OpenCV (BGR order), converted to RGB, resized to
    ``(width, height)``, cast to float32 and divided by 255, then rearranged
    from HWC to a 1xCxHxW array.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # cv2.resize takes (width, height) -- the reverse of NumPy's (rows, cols).
    rgb = cv2.resize(rgb, (width, height)).astype(np.float32) / 255.0
    # HWC -> CHW, prepend a batch axis, and hand the runtime a contiguous buffer
    # (np.transpose alone only returns a strided view).
    return np.ascontiguousarray(np.transpose(rgb, (2, 0, 1))[np.newaxis, :, :, :])


def _colorize_disparity(disp: np.ndarray) -> np.ndarray:
    """Min-max normalize a 2-D disparity map to uint8 and apply the plasma colormap."""
    d_min = float(disp.min())
    d_range = float(disp.max()) - d_min
    # A constant map would divide by zero and produce NaNs, whose uint8 cast is
    # undefined; render it as all zeros instead.
    if d_range > 0:
        norm = ((disp - d_min) / d_range * 255).astype(np.uint8)
    else:
        norm = np.zeros(disp.shape, dtype=np.uint8)
    return cv2.applyColorMap(norm, cv2.COLORMAP_PLASMA)


def main(model_path: str,
         left_image: str, right_image: str, output_path: str) -> None:
    """Run the ONNX model on one stereo pair and save a colored disparity map.

    The session is created with the module-level ``providers`` list. Both
    images are resized to the model's input height/width (read from the first
    input's shape), inference is run five times while printing per-run
    latency, and the first output of the last run is min-max normalized and
    written to ``output_path`` with OpenCV's plasma colormap.

    Args:
        model_path: Path to the ``.onnx`` model file.
        left_image: Path to the left image (fed to the model's first input).
        right_image: Path to the right image (fed to the model's second input).
        output_path: Destination path for the colored disparity image.

    Raises:
        FileNotFoundError: if either input image cannot be read.
        ValueError: if the model's input height/width are not fixed integers.
        OSError: if the output image cannot be written.
    """
    sess = ort.InferenceSession(
        model_path,
        providers=providers,
    )
    inputs = sess.get_inputs()
    input_names = [i.name for i in inputs]
    output_names = [o.name for o in sess.get_outputs()]

    # Indices 2/3 are H/W, consistent with the NCHW arrays built by _load_image.
    input_height = inputs[0].shape[2]
    input_width = inputs[0].shape[3]
    print(f"{input_height=}")
    print(f"{input_width=}")
    # Dynamic axes are reported as names/None; cv2.resize needs concrete ints.
    if not (isinstance(input_height, int) and isinstance(input_width, int)):
        raise ValueError(
            f"Model input has non-fixed spatial dims "
            f"({input_height!r}, {input_width!r}); a fixed-size model is required."
        )

    left = _load_image(left_image, input_width, input_height)
    right = _load_image(right_image, input_width, input_height)
    feed = {
        input_names[0]: left,
        input_names[1]: right,
    }

    # Repeated runs, presumably so first-run warm-up cost is visible next to
    # steady-state latency.
    for _ in range(5):
        start = time.perf_counter()  # monotonic, high-resolution; suited to intervals
        output = sess.run(output_names, feed)
        dt = time.perf_counter() - start
        print(f"\033[34mElapsed: {dt:.3f} sec, {1/dt:.3f} FPS\033[0m")

    disp = output[0][0]  # first output tensor, first batch element
    # NOTE(review): if the model emits (N, 1, H, W), drop the singleton channel
    # so applyColorMap receives a 2-D map.
    if disp.ndim == 3 and disp.shape[0] == 1:
        disp = disp[0]
    colored = _colorize_disparity(disp)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(output_path, colored):
        raise OSError(f"Failed to write output image: {output_path}")
    print(f"\033[32moutput: {output_path}\033[0m")
if __name__ == "__main__":
    # Command-line entry point: parse the paths (all have defaults) and run main().
    import argparse

    parser = argparse.ArgumentParser(
        description="Run an ONNX stereo model on a left/right image pair "
                    "and save a colored disparity image.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        default="cgi_stereo_sceneflow_480x640/cgi_stereo_sceneflow_480x640.onnx",
        help="ONNX model file path.",
    )
    parser.add_argument(
        "-l",
        "--left_image",
        type=str,
        default="data/left.png",
        help="input left image.",
    )
    parser.add_argument(
        "-r",
        "--right_image",
        type=str,
        default="data/right.png",
        help="input right image.",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        default="output.png",
        help="output colored disparity image path.",
    )
    args = parser.parse_args()
    main(
        args.model_path,
        args.left_image,
        args.right_image,
        args.output_path,
    )