
Error Code 3: API Usage Error (Parameter check failed at: runtime/api/executionContext.cpp::enqueueV3::2666, condition: mContext.profileObliviousBindings.at(profileObliviousIndex) || getPtrOrNull(mOutputAllocators, profileObliviousIndex)) #4224

Open · fgias opened this issue Oct 25, 2024 · 16 comments

Labels: Engine Build (Issues with engine build), internal-bug-tracked (Tracked internally, will be fixed in a future release), triaged (Issue has been triaged by maintainers)

fgias commented Oct 25, 2024

Description

We have a PyTorch GNN model that we run on an NVIDIA GPU with TensorRT (TRT). For the scatter_add operation we use a ScatterElements plugin for TRT. We are now trying to quantize the model.

We are following the same procedure that worked for quantizing a simple multilayer perceptron. After quantizing to INT8 with pytorch-quantization and exporting to ONNX, I pass the model to TRT with precision=INT8 without errors. However, at runtime I get the error:

3: [executionContext.cpp::enqueueV3::2666] Error Code 3: API Usage Error (Parameter check failed at: runtime/api/executionContext.cpp::enqueueV3::2666, condition: mContext.profileObliviousBindings.at(profileObliviousIndex) || getPtrOrNull(mOutputAllocators, profileObliviousIndex)
)

The plugin states that it does not support INT8, but I do not see why it cannot be left at FP32 precision while the rest of the model is quantized. Any ideas about what is causing the problem?
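For reference, one way to test the FP32-fallback idea is to pin the plugin layer's precision when building the engine. A minimal sketch with the TensorRT Python builder API follows; the network/config objects and the "scatter" layer-name filter are assumptions for illustration, not taken from the original post.

# Sketch: build an INT8 engine but force the scatter plugin layer to run in FP32.
# Assumes `network` and `config` come from the usual ONNX-parser build flow;
# the "scatter" substring match is a guess at how the plugin layer is named.
import tensorrt as trt

config.set_flag(trt.BuilderFlag.INT8)
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

for i in range(network.num_layers):
    layer = network.get_layer(i)
    if "scatter" in layer.name.lower():
        layer.precision = trt.float32                 # keep the plugin in FP32
        for j in range(layer.num_outputs):
            layer.set_output_type(j, trt.float32)

With trtexec, the --layerPrecisions and --precisionConstraints=obey flags cover the same idea.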

@lix19937

Can you try running trtexec --int8 --fp16 --onnx=spec --verbose --plugins=spec and upload the build log here?


fgias commented Oct 28, 2024

Okay, running the command, I get the error below at the end, even though running the model normally only gives the error I mentioned previously and not a segfault.

[fgiasemi@n4050501 build]$ /usr/src/tensorrt/bin/trtexec --int8 --fp16 --onnx=../etx4velo/onnx_export/gnn/gnn_q.onnx --verbose --plugins=device/tensorrt_scatter/libtensorrt_scatter.so
&&&& RUNNING TensorRT.trtexec [TensorRT v8601] # /usr/src/tensorrt/bin/trtexec --int8 --fp16 --onnx=../etx4velo/onnx_export/gnn/gnn_q.onnx --verbose --plugins=device/tensorrt_scatter/libtensorrt_scatter.so
[10/28/2024-16:21:35] [W] --plugins flag has been deprecated, use --staticPlugins flag instead.
[10/28/2024-16:21:35] [I] === Model Options ===
[10/28/2024-16:21:35] [I] Format: ONNX
[10/28/2024-16:21:35] [I] Model: ../etx4velo/onnx_export/gnn/gnn_q.onnx
[10/28/2024-16:21:35] [I] Output:
[10/28/2024-16:21:35] [I] === Build Options ===
[10/28/2024-16:21:35] [I] Max batch: explicit batch
[10/28/2024-16:21:35] [I] Memory Pools: workspace: default, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default
[10/28/2024-16:21:35] [I] minTiming: 1
[10/28/2024-16:21:35] [I] avgTiming: 8
[10/28/2024-16:21:35] [I] Precision: FP32+FP16+INT8
[10/28/2024-16:21:35] [I] LayerPrecisions: 
[10/28/2024-16:21:35] [I] Layer Device Types: 
[10/28/2024-16:21:35] [I] Calibration: Dynamic
[10/28/2024-16:21:35] [I] Refit: Disabled
[10/28/2024-16:21:35] [I] Version Compatible: Disabled

...
...
...


 000000000003e8, Reformatted Output Tensor 2 to {ForeignNode[model.output_edge_classifier.9.weight.../Squeeze]} (Half[1]) -> 740 (Float[1])
Layer(Reformat): Reformatting CopyNode for Output Tensor 1 to {ForeignNode[model.output_edge_classifier.9.weight.../Squeeze]}, Tactic: 0x00000000000003e8, Reformatted Output Tensor 1 to {ForeignNode[model.output_edge_classifier.9.weight.../Squeeze]} (Half[1,256]) -> e (Float[1,256])
Layer(Reformat): Reformatting CopyNode for Output Tensor 0 to {ForeignNode[model.output_edge_classifier.9.weight.../Squeeze]}, Tactic: 0x00000000000003e8, Reformatted Output Tensor 0 to {ForeignNode[model.output_edge_classifier.9.weight.../Squeeze]} (Half[1,256]) -> edge_score (Float[1,256])
[10/28/2024-16:21:49] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +5, now: CPU 0, GPU 5 (MiB)
[10/28/2024-16:21:49] [V] [TRT] Adding 1 engine(s) to plan file.
[10/28/2024-16:21:49] [I] Engine built in 14.364 sec.
[10/28/2024-16:21:49] [I] [TRT] Loaded engine size: 6 MiB
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Local registry did not find scatter_add creator. Will try parent registry if enabled.
[10/28/2024-16:21:49] [V] [TRT] Global registry found scatter_add creator.
[10/28/2024-16:21:49] [V] [TRT] Trying to load shared library libcublas.so.12
[10/28/2024-16:21:49] [V] [TRT] Loaded shared library libcublas.so.12
[10/28/2024-16:21:49] [V] [TRT] Using cublas as plugin tactic source
[10/28/2024-16:21:49] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 2877, GPU 765 (MiB)
[10/28/2024-16:21:49] [V] [TRT] Trying to load shared library libcudnn.so.8
[10/28/2024-16:21:49] [V] [TRT] Loaded shared library libcudnn.so.8
[10/28/2024-16:21:49] [V] [TRT] Using cuDNN as plugin tactic source
[10/28/2024-16:21:49] [I] [TRT] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 2877, GPU 773 (MiB)
[10/28/2024-16:21:49] [V] [TRT] Deserialization required 9841 microseconds.
[10/28/2024-16:21:49] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +4, now: CPU 0, GPU 4 (MiB)
[10/28/2024-16:21:49] [I] Engine deserialized in 0.0112783 sec.
[10/28/2024-16:21:49] [V] [TRT] Trying to load shared library libcublas.so.12
[10/28/2024-16:21:49] [V] [TRT] Loaded shared library libcublas.so.12
[10/28/2024-16:21:49] [V] [TRT] Using cublas as plugin tactic source
[10/28/2024-16:21:49] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 2877, GPU 765 (MiB)
[10/28/2024-16:21:49] [V] [TRT] Trying to load shared library libcudnn.so.8
[10/28/2024-16:21:49] [V] [TRT] Loaded shared library libcudnn.so.8
[10/28/2024-16:21:49] [V] [TRT] Using cuDNN as plugin tactic source
[10/28/2024-16:21:49] [I] [TRT] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 2877, GPU 773 (MiB)
[10/28/2024-16:21:49] [V] [TRT] Total per-runner device persistent memory is 0
[10/28/2024-16:21:49] [V] [TRT] Total per-runner host persistent memory is 2912
[10/28/2024-16:21:49] [V] [TRT] Allocated activation device memory of size 3417088
[10/28/2024-16:21:49] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +4, now: CPU 0, GPU 8 (MiB)
[10/28/2024-16:21:49] [W] [TRT] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[10/28/2024-16:21:49] [I] Setting persistentCacheLimit to 0 bytes.
[10/28/2024-16:21:49] [V] Using enqueueV3.
[10/28/2024-16:21:49] [I] Using random values for input x
[10/28/2024-16:21:49] [I] Input binding for x with dimensions 1x3 is created.
[10/28/2024-16:21:49] [I] Using random values for input start
[10/28/2024-16:21:49] [I] Input binding for start with dimensions 1 is created.
[10/28/2024-16:21:49] [I] Using random values for input end
[10/28/2024-16:21:49] [I] Input binding for end with dimensions 1 is created.
[10/28/2024-16:21:49] [I] Output binding for edge_score with dimensions 1x256 is created.
[10/28/2024-16:21:49] [I] Output binding for e with dimensions 1x256 is created.
[10/28/2024-16:21:49] [I] Output binding for 740 with dimensions 1 is created.
[10/28/2024-16:21:49] [I] Starting inference
Cuda failure: an illegal memory access was encountered
Aborted (core dumped)


lix19937 commented Oct 29, 2024

It seems that libtensorrt_scatter.so has a bug. You can modify the enqueue function to be a dummy (return 0 immediately on entry), then recompile the plugin and rerun trtexec.


fgias commented Oct 29, 2024

I did that and the error is exactly the same, so I guess the bug is not in the enqueue function?

@kevinch-nv (Collaborator)

The error message in the OP usually indicates that a buffer was not properly assigned to an output binding of the engine. Given that trtexec doesn't throw the same error, it may be related to data-dependent shapes, where an OutputAllocator class needs to be provided and assigned to the data-dependent output binding.

On the trtexec side, are you able to repro the same issue with the unquantized model? If the model can be shared for us to debug that would be useful to have as well.
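For context, wiring an output allocator to a data-dependent output binding through the TensorRT Python API looks roughly like the sketch below; the cuda-python allocation and the "edge_score" tensor name are assumptions for illustration.

# Sketch: attach an IOutputAllocator so TensorRT can request the real output
# buffer size for a data-dependent binding at enqueue time.
import tensorrt as trt
from cuda import cudart

class GrowingAllocator(trt.IOutputAllocator):
    def __init__(self):
        trt.IOutputAllocator.__init__(self)
        self.ptr, self.size, self.shape = 0, 0, None

    def reallocate_output(self, tensor_name, memory, size, alignment):
        # Called once the actual output size is known; grow the buffer if needed.
        if size > self.size:
            if self.ptr:
                cudart.cudaFree(self.ptr)
            _, self.ptr = cudart.cudaMalloc(size)
            self.size = size
        return self.ptr

    def notify_shape(self, tensor_name, shape):
        self.shape = tuple(shape)   # final shape of the data-dependent output

alloc = GrowingAllocator()
context.set_output_allocator("edge_score", alloc)   # before calling enqueue_v3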

kevinch-nv added the triaged (Issue has been triaged by maintainers) label Oct 29, 2024
kevinch-nv self-assigned this Oct 29, 2024

fgias commented Oct 30, 2024

With the unquantized model, there is no issue at all. The issue in the OP appears when I replace the unquantized model with its INT8 version. I attach the ONNX files for the quantized (gnn_q.onnx) and unquantized (gnn.onnx) models, as well as our scatter_add plugin shared library, in models.zip. Thank you.

@kevinch-nv (Collaborator)

Finally got around to taking a look at the issue. Using TensorRT 8.6, even with the non-quantized model I'm seeing a similar CUDA error:

/ssd/TRT_VERSIONS/TensorRT-8.6.2.3/bin/trtexec --onnx=gnn.onnx --plugins=/ssd/github_issues/scatterint8/models/libtensorrt_scatter.so
&&&& RUNNING TensorRT.trtexec [TensorRT v8602] # /ssd/TRT_VERSIONS/TensorRT-8.6.2.3/bin/trtexec --onnx=gnn.onnx --plugins=/ssd/github_issues/scatterint8/models/libtensorrt_scatter.so
...
[11/06/2024-00:19:12] [I] Setting persistentCacheLimit to 0 bytes.
[11/06/2024-00:19:12] [I] Using random values for input x
[11/06/2024-00:19:12] [I] Input binding for x with dimensions 1x3 is created.
[11/06/2024-00:19:12] [I] Using random values for input start
[11/06/2024-00:19:12] [I] Input binding for start with dimensions 1 is created.
[11/06/2024-00:19:12] [I] Using random values for input end
[11/06/2024-00:19:12] [I] Input binding for end with dimensions 1 is created.
[11/06/2024-00:19:12] [I] Output binding for edge_score with dimensions 1 is created.
[11/06/2024-00:19:12] [I] Starting inference
[11/06/2024-00:19:12] [E] Error[1]: [deviceToShapeHostRunner.cpp::execute::38] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
Cuda failure: an illegal memory access was encountered
Aborted (core dumped)

I'm debugging further (trying to repro with TensorRT 10.X). Were you able to successfully run the non-quantized model with the same command?


fgias commented Nov 7, 2024

You're right, I get the same error with trtexec with the non-quantized model:

...
[11/07/2024-10:57:37] [I] Input binding for start with dimensions 1 is created.
[11/07/2024-10:57:37] [I] Using random values for input end
[11/07/2024-10:57:37] [I] Input binding for end with dimensions 1 is created.
[11/07/2024-10:57:37] [I] Output binding for edge_score with dimensions 1 is created.
[11/07/2024-10:57:37] [I] Starting inference
Cuda failure: an illegal memory access was encountered
Aborted (core dumped)

but when I run the model normally, I don't get any error; it runs properly.

@kevinch-nv (Collaborator)

What do you mean when you say you run the model "normally"? Is it in your inference script?

I've modified the model to use the native ScatterElements op so that it works with TensorRT 10.X, whose plugin library includes the ScatterElements plugin.

After doing so, I'm seeing the following error:

[11/12/2024-15:48:53] [E] Error[1]: IBuilder::buildSerializedNetwork: Error Code 1: Internal Error (Constant does not support output type Bool.)

This looks to be an internal builder issue. I've filed an internal bug to track this.
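For reference, the kind of graph edit described above (swapping a custom scatter node for native ScatterElements with an add reduction) can be done with onnx-graphsurgeon roughly as follows; the custom op name "TRT_ScatterAdd" and the file names are assumptions for illustration.

# Sketch: replace a custom scatter_add plugin node with native
# ScatterElements(reduction="add"); the reduction attribute needs opset >= 16.
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("gnn.onnx"))
for node in graph.nodes:
    if node.op == "TRT_ScatterAdd":                   # hypothetical custom op name
        node.op = "ScatterElements"
        node.attrs = {"axis": 0, "reduction": "add"}

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "gnn_native_scatter.onnx")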

kevinch-nv added the Engine Build (Issues with engine build) and internal-bug-tracked (Tracked internally, will be fixed in a future release) labels Nov 12, 2024

fgias commented Nov 14, 2024

Yes, indeed, by "normally" I mean running inference inside our pipeline. There is no error and the results are as expected, so with the unquantized GNN the plugin seems to be running as expected. The problem (the OP) starts when I try passing the quantized GNN ONNX file with INT8 precision.

Regarding the error, what does it mean that it's an internal builder issue? Does it mean there's nothing we can do at this point?


fgias commented Dec 2, 2024

Hi, is there any news on this?


kevinch-nv commented Dec 2, 2024

I'm wondering what the difference is between the code in your pipeline versus trtexec. We use trtexec as our baseline, and I'm unable to run inference through trtexec with your attached model.

We've made some progress internally and believe that the root cause of the illegal cuda memory access stems from the scatter nodes.

I've modified the custom scatter nodes back to regular ScatterElements and ran shape analysis on the model. The rank of the updates tensor != the rank of the indices tensor, which violates the operator specification. What are the expected shapes in the original model?

[Image attached: shape analysis of the scatter node inputs]
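For anyone reproducing that check, ONNX shape inference exposes the inferred shapes feeding the scatter node directly; a small sketch (the lowercase "scatter" op-type filter is an assumption about how the node appears in the exported graph):

# Sketch: print the inferred shapes of the inputs to each scatter node.
import onnx
from onnx import shape_inference

model = shape_inference.infer_shapes(onnx.load("gnn.onnx"))
shapes = {}
for vi in list(model.graph.value_info) + list(model.graph.input) + list(model.graph.output):
    dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
    shapes[vi.name] = dims

for node in model.graph.node:
    if "scatter" in node.op_type.lower():
        print(node.op_type, [(name, shapes.get(name)) for name in node.input])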


fgias commented Dec 9, 2024

In our case, we use the scatter_add operation for the message passing of the GNN. The data are the node encoding features, which in our case are 32-dimensional, and the aggregation is over neighbors, i.e. a specific set of nodes given by the indices.

@kevinch-nv (Collaborator)

The ScatterAdd operation doesn't work if the ranks of the inputs do not match, and in ONNX the shapes of indices and updates must match as well. This looks to be an export issue where the shape of indices is not exported properly.
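If the mismatch originates on the PyTorch side, the usual fix is to expand the index tensor to the same rank and shape as the messages before scatter_add, so the exported indices and updates agree. A minimal sketch, with illustrative names and the 32-dimensional features mentioned earlier in the thread:

# Sketch: scatter_add message passing with the index expanded to match the
# updates tensor, so indices and updates share the same rank/shape on export.
import torch

def aggregate(messages: torch.Tensor,  # [num_edges, 32] message per edge
              dst: torch.Tensor,       # [num_edges] destination node index
              num_nodes: int) -> torch.Tensor:
    index = dst.unsqueeze(-1).expand(-1, messages.size(-1))    # [num_edges, 32]
    out = torch.zeros(num_nodes, messages.size(-1),
                      dtype=messages.dtype, device=messages.device)
    return out.scatter_add_(0, index, messages)                # [num_nodes, 32]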


fgias commented Dec 10, 2024

Is this due to something on our side?

@kevinch-nv (Collaborator)

Yes, it looks to be a model definition issue. Can you double-check the scatter module in your PyTorch model?
