This version of TensorRT only supports input K as an initializer (Detectron2->ONNX->TensorRT) #2721

Closed
niqbal996 opened this issue Mar 1, 2023 · 2 comments

@niqbal996

Description

I am trying to export Detectron2 models such as FCOS and RetinaNet (bounding-box detectors only; I will also try FasterRCNN) through ONNX to TensorRT and then test them on Jetson devices. I have done this successfully with some transformer-based detectors, but FCOS and RetinaNet fail with errors on the TopK node. I have tried the script below, which uses onnx-graphsurgeon to simplify the model, but the TopK nodes remain and the error persists when parsing the ONNX model. There is an ongoing discussion about this in #2678 (comment) as well. In the script below, based on your examples, I iteratively apply onnx-graphsurgeon constant folding until the node count no longer decreases. I have also checked that opset 17 supports the TopK operation. How can I fix this error so that I can generate the TensorRT engine file? Thank you.
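
As far as I understand, the parser requires the second input of TopK (the K value) to be a constant initializer rather than a tensor computed at runtime. If K were known at export time, an onnx-graphsurgeon edit along the following lines might pin it as a constant. This is only a rough sketch: K_VALUE = 1000 is a made-up placeholder, and forcing a fixed K could change the model's behavior whenever fewer than K candidates are available.

import numpy as np
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("fcos_opset17.onnx"))
for node in graph.nodes:
    if node.op == "TopK" and not isinstance(node.inputs[1], gs.Constant):
        # K_VALUE is a placeholder; the real value would have to match the
        # detector's pre-NMS top-k setting.
        K_VALUE = 1000
        # Replace the computed K tensor with a 1-element int64 initializer,
        # which is what the TensorRT ONNX parser expects.
        node.inputs[1] = gs.Constant(
            name=node.name + "_K",
            values=np.array([K_VALUE], dtype=np.int64),
        )
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "fcos_opset17_pinned.onnx")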

Environment

TensorRT Version: 8.4.1.5 on host system
NVIDIA GPU: Jetson Orin
NVIDIA Driver Version:
CUDA Version: 11.4.239
CUDNN Version: 8.4.1.50
Operating System: Ubuntu 20.04 LTS
Python Version (if applicable): 3.8.10
Tensorflow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if so, version): r8.5.2.2-devel l4t docker container
ONNX Version: 1.13.0
ONNX GraphSurgeon Version: 0.3.26

Relevant Files

You can download the ONNX model from this link: onnx model FCOS

Steps To Reproduce

You can use the trtexec tool like this to reproduce the error:

trtexec --onnx=fcos_opset17.onnx --useCudaGraph

Alternatively, run the following Python script, providing the path to the model:
import os
import sys
import logging
import onnx_graphsurgeon as gs
import onnx
import tensorrt as trt
from onnx import shape_inference

logging.basicConfig(level=logging.INFO)
logging.getLogger("ModelHelper").setLevel(logging.INFO)
log = logging.getLogger("ModelHelper")

def save(graph, output_path):
    """
    Save the ONNX model to the given location.
    :param output_path: Path pointing to the location where to write out the updated ONNX model.
    """
    graph.cleanup().toposort()
    model = gs.export_onnx(graph)
    output_path = os.path.realpath(output_path)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    onnx.save(model, output_path)
    log.info("Saved ONNX model to {}".format(output_path))

class EngineBuilder:
    """
    Parses an ONNX graph and builds a TensorRT engine from it.
    """

    def __init__(self, verbose=False, workspace=8):
        """
        :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger.
        :param workspace: Max memory workspace to allow, in GB.
        """
        self.trt_logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE

        trt.init_libnvinfer_plugins(self.trt_logger, namespace="")

        self.builder = trt.Builder(self.trt_logger)
        self.config = self.builder.create_builder_config()
        self.config.max_workspace_size = workspace * (2 ** 30)

        self.batch_size = None
        self.network = None
        self.parser = None

    def create_network(self, onnx_path):
        """
        Parse the ONNX graph and create the corresponding TensorRT network definition.
        :param onnx_path: The path to the ONNX graph to load.
        """
        network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

        self.network = self.builder.create_network(network_flags)
        self.parser = trt.OnnxParser(self.network, self.trt_logger)

        onnx_path = os.path.realpath(onnx_path)
        with open(onnx_path, "rb") as f:
            if not self.parser.parse(f.read()):
                log.error("Failed to load ONNX file: {}".format(onnx_path))
                for error in range(self.parser.num_errors):
                    log.error(self.parser.get_error(error))
                sys.exit(1)

        inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
        outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

        log.info("Network Description")
        for input in inputs:
            self.batch_size = input.shape[0]
            log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype))
        for output in outputs:
            log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype))
        assert self.batch_size > 0
        self.builder.max_batch_size = self.batch_size

onnx_model_path = '/opt/git/fcos_opset17.onnx'
onnx_simp_path = '/opt/git/fcos_opset17_simp_python.onnx'
graph = gs.import_onnx(onnx.load(onnx_model_path))
assert graph
log.info("ONNX graph loaded successfully")
graph.fold_constants()
for i in range(20):
    count_before = len(graph.nodes)
    graph.cleanup().toposort()
    try:
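        # Clear recorded output shapes so ONNX shape inference recomputes them.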
        for node in graph.nodes:
            for o in node.outputs:
                o.shape = None
        model = gs.export_onnx(graph)
        model = shape_inference.infer_shapes(model)
        graph = gs.import_onnx(model)
    except Exception as e:
        log.info("Shape inference could not be performed at this time:\n{}".format(e))
    try:
        graph.fold_constants(fold_shapes=True)
    except TypeError as e:
        log.error("This version of ONNX GraphSurgeon does not support folding shapes, please upgrade your "
                    "onnx_graphsurgeon module. Error:\n{}".format(e))
        raise

    count_after = len(graph.nodes)
    log.info("Reduced model nodes from {} to {} in iteration number {}".format(count_before, count_after, i))
    if count_before == count_after:
        # No new folding occurred in this iteration, so we can stop for now.
        log.info("Model has not been simplified any further. Saving model now")
        model = save(graph, onnx_simp_path)
        break
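    # Save an intermediate copy partway through, in case folding never converges.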
    if i == 11:
        model = save(graph, onnx_simp_path)

builder = EngineBuilder(verbose=True, workspace=1)
builder.create_network(onnx_simp_path)

Error output from trtexec:

[03/01/2023-12:42:04] [E] [TRT] ModelImporter.cpp:726: While parsing node number 547 [TopK -> "/model/TopK_2_output_0"]:
[03/01/2023-12:42:04] [E] [TRT] ModelImporter.cpp:727: --- Begin node ---
[03/01/2023-12:42:04] [E] [TRT] ModelImporter.cpp:728: input: "/model/GatherND_2_output_0"
input: "/model/Reshape_50_output_0"
output: "/model/TopK_2_output_0"
output: "/model/TopK_2_output_1"
name: "/model/TopK_2"
op_type: "TopK"
attribute {
  name: "axis"
  i: -1
  type: INT
}
attribute {
  name: "largest"
  i: 1
  type: INT
}
attribute {
  name: "sorted"
  i: 1
  type: INT
}

[03/01/2023-12:42:04] [E] [TRT] ModelImporter.cpp:729: --- End node ---
[03/01/2023-12:42:04] [E] [TRT] ModelImporter.cpp:731: ERROR: ModelImporter.cpp:168 In function parseGraph:
[6] Invalid Node - /model/TopK_2
This version of TensorRT only supports input K as an initializer. Try applying constant folding on the model using Polygraphy: https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/examples/cli/surgeon/02_folding_constants
[03/01/2023-12:42:04] [E] Failed to parse onnx file
[03/01/2023-12:42:04] [I] Finish parsing network model
[03/01/2023-12:42:04] [E] Parsing model failed
[03/01/2023-12:42:04] [E] Failed to create engine from model or file.
[03/01/2023-12:42:04] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8502] # trtexec --onnx=../fcos_opset17_simp.onnx --useCudaGraph

Error output from the Python script:

ERROR:ModelHelper:Failed to load ONNX file: /opt/git/fcos_opset17_simp_python.onnx
ERROR:ModelHelper:In node 547 (parseGraph): INVALID_NODE: Invalid Node - /model/TopK_2
This version of TensorRT only supports input K as an initializer. Try applying constant folding on the model using Polygraphy: https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/examples/cli/surgeon/02_folding_constants
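
For reference, the constant-folding step suggested in the parser message corresponds, as far as I understand, to a Polygraphy invocation along these lines (flags taken from the Polygraphy constant-folding example linked in the error):

polygraphy surgeon sanitize fcos_opset17.onnx --fold-constants -o fcos_opset17_folded.onnx

My script above attempts the equivalent folding through onnx-graphsurgeon's fold_constants(), but the K input of TopK still does not become an initializer.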

Any help is appreciated. Thank you.

@zerollzeng
Collaborator

I've asked about this internally. Since this is a duplicate of #2678, how about we close this bug and follow up there? Thanks!

@zerollzeng zerollzeng self-assigned this Mar 5, 2023
@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Mar 5, 2023
@ttyio
Collaborator

ttyio commented Mar 28, 2023

Closing since there has been no activity for more than 3 weeks, thank you!
