forked from chequanghuy/TwinLiteNet
test_trt.py
import numpy as np
import os
import torch
import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
from collections import OrderedDict, namedtuple
import torch.nn as nn
import shutil
import cv2
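
# TwinLiteNet TensorRT demo: load a serialized engine, run drivable-area and
# lane-line segmentation on every image in ./images, and write the colour
# overlays to ./results.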
def Run(model, img):
    # Resize to the network input size (width x height = 640 x 360) and keep a
    # copy of the resized frame for visualisation.
    img = cv2.resize(img, (640, 360))
    img_rs = img.copy()
    # BGR -> RGB, HWC -> CHW, then make the array contiguous for torch.from_numpy.
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img)
    img = torch.unsqueeze(img, 0)  # add a batch dimension
    img = img.cuda().float() / 255.0  # move to GPU and normalise to [0, 1]
    img_out = model(img)
    x0 = img_out[0]  # drivable-area output
    x1 = img_out[1]  # lane-line output
    _, da_predict = torch.max(x0, 1)
    _, ll_predict = torch.max(x1, 1)
    DA = da_predict.byte().cpu().data.numpy()[0] * 255
    LL = ll_predict.byte().cpu().data.numpy()[0] * 255
    # Overlay the masks on the resized frame: drivable area in blue, lane lines in green (BGR).
    img_rs[DA > 100] = [255, 0, 0]
    img_rs[LL > 100] = [0, 255, 0]
    return img_rs

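# The serialized engine is assumed to be built ahead of time, e.g. by exporting
# the trained checkpoint to ONNX and converting it with trtexec (file names
# here are illustrative, not fixed by this script):
#   trtexec --onnx=best.onnx --saveEngine=best.engine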
class TRT(nn.Module):
    """Runs a serialized TensorRT engine behind a torch.nn.Module-style call interface."""
    def __init__(self, weight='model_best.engine'):
        super().__init__()
        device = torch.device('cuda:0')
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
        logger = trt.Logger(trt.Logger.INFO)
        # Deserialize the engine from disk.
        with open(weight, 'rb') as f, trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(f.read())
        self.context = model.create_execution_context()
        self.bindings = OrderedDict()
        self.output_names = []
        # NOTE: this uses the binding-index API (num_bindings / get_binding_*),
        # which is deprecated in newer TensorRT releases.
        for i in range(model.num_bindings):
            name = model.get_binding_name(i)
            print(name)  # log binding names as they are discovered
            dtype = trt.nptype(model.get_binding_dtype(i))
            if model.binding_is_input(i):
                if -1 in tuple(model.get_binding_shape(i)):  # dynamic input: use the max profile shape
                    self.dynamic = True
                    self.context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
            else:  # output
                self.output_names.append(name)
            # Pre-allocate a device tensor for every binding and record its address.
            shape = tuple(self.context.get_binding_shape(i))
            im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
            self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())

    def forward(self, im):
        # Point the 'images' input binding at the incoming tensor, execute, and
        # return the pre-allocated output tensors sorted by binding name.
        self.binding_addrs['images'] = int(im.data_ptr())
        self.context.execute_v2(list(self.binding_addrs.values()))
        return [self.bindings[x].data for x in sorted(self.output_names)]
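
# The forward() above assumes the engine exposes an input binding named 'images'.
# An illustrative sanity check (commented out; assumes a float32 input with the
# 1x3x360x640 shape produced by Run above):
#   dummy = torch.zeros((1, 3, 360, 640), dtype=torch.float32, device='cuda:0')
#   print([o.shape for o in TRT('best.engine')(dummy)])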
model = TRT('best.engine')  # engine wrapper (avoid shadowing the tensorrt 'trt' module)
image_list = os.listdir('images')
# Recreate the output directory from scratch on every run.
if os.path.exists('results'):
    shutil.rmtree('results')
os.mkdir('results')
for imgName in image_list:
    img = cv2.imread(os.path.join('images', imgName))
    img = Run(model, img)
    cv2.imwrite(os.path.join('results', imgName), img)