TensorRT support #12

Open
Egrt opened this issue Feb 7, 2024 · 6 comments

Egrt commented Feb 7, 2024

Could you add support for TensorRT?


Egrt commented Feb 7, 2024

[02/07/2024-21:51:33] [TRT] [E] ModelImporter.cpp:729: --- End node ---
[02/07/2024-21:51:33] [TRT] [E] ModelImporter.cpp:732: ERROR: ModelImporter.cpp:168 In function parseGraph:
[6] Invalid Node - TopK_587
This version of TensorRT only supports input K as an initializer. Try applying constant folding on the model using Polygraphy: https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/examples/cli/surgeon/02_folding_constants
Traceback (most recent call last):
  File "d:/Notebook/rtmlib/rtm.py", line 41, in <module>
    build_model(rtmdet_onnx_model)
  File "d:/Notebook/rtmlib/rtm.py", line 16, in build_model
    build_engine(onnx_file_path, engine_file_path, True)
  File "d:/Notebook/rtmlib/rtm.py", line 28, in build_engine
    raise RuntimeError(f'failed to load ONNX file: {onnx_file_path}')
RuntimeError: failed to load ONNX file: rtmdet_nano_8xb32-300e_hand-267f9c8f.onnx
My environment:
TensorRT 8.5.1.7
This error appears when I convert the ONNX model to an engine. How can I resolve it?
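
A minimal sketch of the constant-folding fix the error message suggests, assuming onnx-simplifier is installed (pip install onnx onnx-simplifier) in place of the Polygraphy CLI it links; the output filename is illustrative:

import onnx
from onnxsim import simplify

# fold constant subgraphs so TopK's K input becomes an initializer
model = onnx.load('rtmdet_nano_8xb32-300e_hand-267f9c8f.onnx')
folded, ok = simplify(model)
assert ok, 'simplified model failed validation'
onnx.save(folded, 'rtmdet_nano_folded.onnx')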


Tau-J commented Feb 8, 2024

Hi @Egrt, for RTMPose's TensorRT conversion workflow, please refer to the official documentation. TensorRT inference support is planned, but since aligning TensorRT versions is relatively difficult, it will not land in the short term.

@chenscottus

Please update to TensorRT 8.6.x


Egrt commented Feb 13, 2024

> Please update to TensorRT 8.6.x

Thanks, it worked.


Egrt commented Feb 20, 2024

I've implemented TensorRT acceleration for RTMDet; note that the ONNX model must be exported as a static model. For reference only @Tau-J

from typing import List, Tuple
import os
import numpy as np
import tensorrt as trt
import cv2
import time

def build_model(onnx_file_path):
    engine_file_path = onnx_file_path.replace('.onnx', '.engine')

    if not os.path.exists(engine_file_path):
        print('Building the engine; the first run takes a while, a message will follow when it finishes')
        build_engine(onnx_file_path, engine_file_path, True)

def build_engine(onnx_file_path, engine_file_path, half=True):
    """Takes an ONNX file and creates a TensorRT engine to run inference with"""
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = 4 << 30  # 4 GiB workspace
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx_file_path)):
        raise RuntimeError(f'failed to load ONNX file: {onnx_file_path}')
    half &= builder.platform_has_fast_fp16
    if half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(engine_file_path, 'wb') as t:
        t.write(engine.serialize())
    return engine_file_path
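
# Note (hedged): max_workspace_size and builder.build_engine() above are
# deprecated on TensorRT >= 8.4; a sketch of the equivalent newer API:
#
#   config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)
#   plan = builder.build_serialized_network(network, config)
#   with open(engine_file_path, 'wb') as t:
#       t.write(plan)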

def draw_bbox(img, bboxes, color=(0, 255, 0)):
    for bbox in bboxes:
        img = cv2.rectangle(img, (int(bbox[0]), int(bbox[1])),
                            (int(bbox[2]), int(bbox[3])), color, 2)
    return img

def nms(boxes, scores, nms_thr):
    """Single class NMS implemented in Numpy."""
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= nms_thr)[0]
        order = order[inds + 1]

    return keep

def multiclass_nms(boxes, scores, nms_thr, score_thr):
    """Multiclass NMS implemented in Numpy.

    Class-aware version.
    """
    final_dets = []
    num_classes = scores.shape[1]
    for cls_ind in range(num_classes):
        cls_scores = scores[:, cls_ind]
        valid_score_mask = cls_scores > score_thr
        if valid_score_mask.sum() == 0:
            continue
        else:
            valid_scores = cls_scores[valid_score_mask]
            valid_boxes = boxes[valid_score_mask]
            keep = nms(valid_boxes, valid_scores, nms_thr)
            if len(keep) > 0:
                cls_inds = np.ones((len(keep), 1)) * cls_ind
                dets = np.concatenate(
                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1)
                final_dets.append(dets)
    if len(final_dets) == 0:
        return None
    return np.concatenate(final_dets, 0)

RTMLIB_SETTINGS = {
    'opencv': {
        'cpu': (cv2.dnn.DNN_BACKEND_OPENCV, cv2.dnn.DNN_TARGET_CPU),

        # You need to manually build OpenCV through cmake
        'cuda': (cv2.dnn.DNN_BACKEND_CUDA, cv2.dnn.DNN_TARGET_CUDA)
    },
    'onnxruntime': {
        'cpu': 'CPUExecutionProvider',
        'cuda': 'CUDAExecutionProvider'
    },
}


class BaseTool:

    def __init__(self,
                 onnx_model: str = None,
                 model_input_size: tuple = None,
                 mean: tuple = None,
                 std: tuple = None,
                 nms_thr: float = 0.5,
                 score_thr: float = 0.3,
                 backend: str = 'tensorrt',
                 device: str = 'cuda'):

        if backend == 'opencv':
            try:
                providers = RTMLIB_SETTINGS[backend][device]

                session = cv2.dnn.readNetFromONNX(onnx_model)
                session.setPreferableBackend(providers[0])
                session.setPreferableTarget(providers[1])
                self.session = session
            except Exception:
                raise RuntimeError(
                    'This model is not supported by OpenCV'
                    ' backend, please use `pip install'
                    ' onnxruntime` or `pip install'
                    ' onnxruntime-gpu` to install onnxruntime'
                    ' backend. Then specify `backend=onnxruntime`.')  # noqa

        elif backend == 'onnxruntime':
            import onnxruntime as ort
            providers = RTMLIB_SETTINGS[backend][device]

            self.session = ort.InferenceSession(path_or_bytes=onnx_model,
                                                providers=[providers])
        elif backend == 'tensorrt':
            import tensorrt as trt
            import pycuda.driver as cuda
            import pycuda.autoinit
            
            engine_path = onnx_model.replace('.onnx', '.engine')
            logger = trt.Logger(trt.Logger.WARNING)
            logger.min_severity = trt.Logger.Severity.ERROR
            runtime = trt.Runtime(logger)
            trt.init_libnvinfer_plugins(logger, '')  # initialize TensorRT plugins
            with open(engine_path, "rb") as f:
                serialized_engine = f.read()
            engine = runtime.deserialize_cuda_engine(serialized_engine)
            self.imgsz = engine.get_binding_shape(0)[2:]  # read the model's real input size, in case the user passed the wrong one
            self.context = engine.create_execution_context()
            self.inputs, self.outputs, self.bindings = [], [], []
            self.stream = cuda.Stream()
            for binding in engine:
                dims = engine.get_binding_shape(binding)
                size = trt.volume(dims)
                if dims[1] < 0:
                    # dynamic dimensions are reported as -1, which makes the
                    # volume negative; flip the sign to get a usable size
                    size *= -1
                dtype = trt.nptype(engine.get_binding_dtype(binding))
                host_mem = cuda.pagelocked_empty(size, dtype)
                device_mem = cuda.mem_alloc(host_mem.nbytes)
                self.bindings.append(int(device_mem))
                if engine.binding_is_input(binding):
                    self.inputs.append({'host': host_mem, 'device': device_mem})
                else:
                    self.outputs.append({'host': host_mem, 'device': device_mem})
        else:
            raise NotImplementedError

        print(f'load {onnx_model} with {backend} backend')

        self.onnx_model = onnx_model
        self.model_input_size = model_input_size
        self.mean = mean
        self.std = std
        self.nms_thr = nms_thr
        self.score_thr = score_thr
        self.backend = backend
        self.device = device
    
    def inference(self, img: np.ndarray):
        """Inference model.

        Args:
            img (np.ndarray): Input image in (H, W, C) format.

        Returns:
            outputs (np.ndarray): Output of the model.
        """
        # build input to (1, 3, H, W)
        img = img.transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        input = img[None, :, :, :]

        # run model
        if self.backend == 'onnxruntime':
            sess_input = {self.session.get_inputs()[0].name: input}
            sess_output = []
            for out in self.session.get_outputs():
                sess_output.append(out.name)

            outputs = self.session.run(sess_output, sess_input)
        elif self.backend == 'tensorrt':
            import pycuda.driver as cuda
            self.inputs[0]['host'] = np.ravel(img)
            # transfer data to the gpu
            for inp in self.inputs:
                cuda.memcpy_htod_async(inp['device'], inp['host'], self.stream)
            # run inference
            self.context.execute_async_v2(
                bindings=self.bindings,
                stream_handle=self.stream.handle)
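            # Note (hedged): execute_async_v2 and the binding-index buffers
            # above are TensorRT 8.x APIs; TensorRT 10 removes them in favor
            # of name-based I/O (context.set_tensor_address(...) per tensor,
            # then context.execute_async_v3(stream_handle)).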
            # fetch outputs from gpu
            for out in self.outputs:
                cuda.memcpy_dtoh_async(out['host'], out['device'], self.stream)
            # synchronize stream
            self.stream.synchronize()

            outputs = [out['host'] for out in self.outputs]
            # reshape the flat host buffer to this static 640x640 RTMDet
            # export's output layout: (1, 1, 8400 proposals, 6 values)
            outputs = np.array(outputs).reshape(1, 1, 8400, 6)
        return outputs
    
class RTMDet(BaseTool):

    def __init__(self,
                 onnx_model: str,
                 model_input_size: tuple = (640, 640),
                 mean: tuple = (103.5300, 116.2800, 123.6750),
                 std: tuple = (57.3750, 57.1200, 58.3950),
                 nms_thr: float = 0.5,
                 score_thr: float = 0.3,
                 backend: str = 'tensorrt',
                 device: str = 'cpu'):
        super().__init__(onnx_model,
                         model_input_size,
                         mean,
                         std,
                         nms_thr=nms_thr,
                         score_thr=score_thr,
                         backend=backend,
                         device=device)

    def __call__(self, image: np.ndarray):
        image, ratio = self.preprocess(image)
        outputs = self.inference(image)[0]
        results = self.postprocess(outputs, ratio)
        return results

    def preprocess(self, img: np.ndarray):
        """Do preprocessing for RTMPose model inference.

        Args:
            img (np.ndarray): Input image in shape.

        Returns:
            tuple:
            - resized_img (np.ndarray): Preprocessed image.
            - center (np.ndarray): Center of image.
            - scale (np.ndarray): Scale of image.
        """
        if len(img.shape) == 3:
            padded_img = np.ones(
                (self.model_input_size[0], self.model_input_size[1], 3),
                dtype=np.uint8) * 114
        else:
            padded_img = np.ones(self.model_input_size, dtype=np.uint8) * 114

        ratio = min(self.model_input_size[0] / img.shape[0],
                    self.model_input_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * ratio), int(img.shape[0] * ratio)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        padded_shape = (int(img.shape[0] * ratio), int(img.shape[1] * ratio))
        padded_img[:padded_shape[0], :padded_shape[1]] = resized_img

        # normalize image
        if self.mean is not None:
            self.mean = np.array(self.mean)
            self.std = np.array(self.std)
            padded_img = (padded_img - self.mean) / self.std

        return padded_img, ratio

    def postprocess(
        self,
        outputs: List[np.ndarray],
        ratio: float = 1.,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Do postprocessing for RTMDet model inference.

        Args:
            outputs (List[np.ndarray]): Outputs of RTMDet model.
            ratio (float): Ratio of preprocessing.

        Returns:
            tuple:
            - final_boxes (np.ndarray): Final bounding boxes.
            - final_scores (np.ndarray): Final scores.
        """

        final_boxes = np.empty((0, 4))  # fallback when nothing survives NMS

        if outputs.shape[-1] == 4:
            # onnx without nms module

            grids = []
            expanded_strides = []
            strides = [8, 16, 32]

            hsizes = [self.model_input_size[0] // stride for stride in strides]
            wsizes = [self.model_input_size[1] // stride for stride in strides]

            for hsize, wsize, stride in zip(hsizes, wsizes, strides):
                xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
                grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
                grids.append(grid)
                shape = grid.shape[:2]
                expanded_strides.append(np.full((*shape, 1), stride))

            grids = np.concatenate(grids, 1)
            expanded_strides = np.concatenate(expanded_strides, 1)
            outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
            outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides

            predictions = outputs[0]
            boxes = predictions[:, :4]
            scores = predictions[:, 4:5] * predictions[:, 5:]

            boxes_xyxy = np.ones_like(boxes)
            boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
            boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
            boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
            boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
            boxes_xyxy /= ratio
            dets = multiclass_nms(boxes_xyxy,
                                  scores,
                                  nms_thr=self.nms_thr,
                                  score_thr=self.score_thr)
            if dets is not None:
                pack_dets = (dets[:, :4], dets[:, 4], dets[:, 5])
                final_boxes, final_scores, final_cls_inds = pack_dets
                isscore = final_scores > self.score_thr
                iscat = final_cls_inds == 0
                isbbox = [i and j for (i, j) in zip(isscore, iscat)]
                final_boxes = final_boxes[isbbox]

        elif outputs.shape[-1] == 5:
            # onnx contains nms module

            pack_dets = (outputs[0, :, :4], outputs[0, :, 4])
            final_boxes, final_scores = pack_dets
            final_boxes /= ratio
            isscore = final_scores > self.score_thr
            final_boxes = final_boxes[isscore]

        elif outputs.shape[-1] == 6:
            # onnx static
            dets = multiclass_nms(outputs[0, :, :4], outputs[0, :, 4:6],
                                  nms_thr=self.nms_thr,
                                  score_thr=self.score_thr)
            if dets is not None:
                pack_dets = (dets[:, :4], dets[:, 4], dets[:, 5])
                final_boxes, final_scores, final_cls_inds = pack_dets
                final_boxes /= ratio

        return final_boxes


if __name__ == '__main__':
    mode = 'image'
    image_path = 'test.jpg'
    rtmdet_onnx_model = 'rtmdet_tiny_uav.onnx'
    rtmpose_onnx_model = ''
    build_model(rtmdet_onnx_model)

    rtmdet = RTMDet(onnx_model=rtmdet_onnx_model,
                    model_input_size=(640, 640),
                    backend='tensorrt',
                    device='cuda'
                    )
    if mode == 'video':
        cap = cv2.VideoCapture(0)

        frame_idx = 0

        while cap.isOpened():
            success, frame = cap.read()
            frame_idx += 1

            if not success:
                break
            s = time.time()
            bboxes = rtmdet(frame)
            # keypoints, scores = rtmpose(frame, bboxes=bboxes)
            det_time = time.time() - s
            print('det: ', det_time)

            img_show = frame.copy()

            img_show = draw_bbox(img_show, bboxes, (0, 255, 0))
            # img_show = draw_skeleton(img_show,
            #                      keypoints,
            #                      scores,
            #                      False,
            #                      kpt_thr=0.2,
            #                      line_width=3)
            img_show = cv2.resize(img_show, (960, 640))
            cv2.imshow('img', img_show)
            key = cv2.waitKey(25)
            if key == ord('q'):
                cap.release()
                break
        cv2.destroyAllWindows()

    elif mode == 'image':
        from PIL import Image
        frame = cv2.imread(image_path)
        s = time.time()
        bboxes = rtmdet(frame)
        img_show = frame.copy()
        img_show = draw_bbox(img_show, bboxes, (0, 255, 0))
        det_time = time.time() - s
        print('det: ', det_time)
        image = Image.fromarray(img_show)
        image.show()

@vieenrose

Alternatively, you can enable the TensorrtExecutionProvider in ONNX Runtime to use TensorRT as the inference backend. Note that you may have to run shape inference on the ONNX model first using symbolic_shape_infer.py to prepare your model. Also, for TensorRT 8.2-8.4, building the custom TensorRT ops plugin from MMDeploy (and loading the plugin into the TensorRT Execution Provider as described in its usage docs) is also required.
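
A minimal sketch of that route, assuming onnxruntime-gpu built with TensorRT support; the FP16 and engine-cache settings are standard TensorRT Execution Provider options, and the model filename reuses the one from the script above:

import onnxruntime as ort

# prefer TensorRT, fall back to CUDA and then CPU when unavailable
providers = [
    ('TensorrtExecutionProvider', {
        'trt_fp16_enable': True,          # build FP16 engines when supported
        'trt_engine_cache_enable': True,  # reuse built engines across runs
        'trt_engine_cache_path': './trt_cache',
    }),
    'CUDAExecutionProvider',
    'CPUExecutionProvider',
]
session = ort.InferenceSession('rtmdet_tiny_uav.onnx', providers=providers)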
