diff --git a/CHANGELOG.md b/CHANGELOG.md index c81d2225..f1d4f86a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,30 @@ # TensorRT OSS Release Changelog -## 10.4.0 GA - 2024-09-11 + +## 10.5.0 GA - 2024-09-30 +Key Features and Updates: + +- Demo changes + - Added [Flux.1-dev](demo/Diffusion) pipeline +- Sample changes + - None +- Plugin changes + - Migrated `IPluginV2`-descendent versions of `bertQKVToContextPlugin` (1, 2, 3) to newer versions (4, 5, 6 respectively) which implement `IPluginV3`. + - Note: + - The newer versions preserve the attributes and I/O of the corresponding older plugin version + - The older plugin versions are deprecated and will be removed in a future release +- Quickstart guide + - None +- Parser changes + - Added support for real-valued `STFT` operations + - Improved error handling in `IParser` + +Known issues: + +- Demos: + - TensorRT engine might not be build successfully when using `--fp8` flag on H100 GPUs. + +## 10.4.0 GA - 2024-09-18 Key Features and Updates: - Demo changes diff --git a/CMakeLists.txt b/CMakeLists.txt index 2928f5ef..c0a0cf60 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -158,7 +158,11 @@ endif() find_library_create_target(nvinfer ${nvinfer_lib_name} SHARED ${TRT_LIB_DIR}) -find_library(CUDART_LIB cudart_static HINTS ${CUDA_TOOLKIT_ROOT_DIR} PATH_SUFFIXES lib lib/x64 lib64) +if (DEFINED USE_CUGFX) + find_library(CUDART_LIB cugfx_dll HINTS ${CUDA_TOOLKIT_ROOT_DIR} PATH_SUFFIXES lib lib/x64 lib64) +else() + find_library(CUDART_LIB cudart_static HINTS ${CUDA_TOOLKIT_ROOT_DIR} PATH_SUFFIXES lib lib/x64 lib64) +endif() if (NOT MSVC) find_library(RT_LIB rt) @@ -174,7 +178,6 @@ if (DEFINED GPU_ARCHS) separate_arguments(GPU_ARCHS) else() list(APPEND GPU_ARCHS - 70 75 ) @@ -202,7 +205,7 @@ endif() set(BERT_GENCODES) # Generate SASS for each architecture foreach(arch ${GPU_ARCHS}) - if (${arch} GREATER_EQUAL 70) + if (${arch} GREATER_EQUAL 75) set(BERT_GENCODES "${BERT_GENCODES} -gencode arch=compute_${arch},code=sm_${arch}") endif() set(GENCODES "${GENCODES} -gencode arch=compute_${arch},code=sm_${arch}") @@ -211,7 +214,7 @@ endforeach() # Generate PTX for the last architecture in the list. list(GET GPU_ARCHS -1 LATEST_SM) set(GENCODES "${GENCODES} -gencode arch=compute_${LATEST_SM},code=compute_${LATEST_SM}") -if (${LATEST_SM} GREATER_EQUAL 70) +if (${LATEST_SM} GREATER_EQUAL 75) set(BERT_GENCODES "${BERT_GENCODES} -gencode arch=compute_${LATEST_SM},code=compute_${LATEST_SM}") endif() diff --git a/README.md b/README.md index cdd6bdbb..bb2f2e41 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ You can skip the **Build** section to enjoy TensorRT with Python. To build the TensorRT-OSS components, you will first need the following software packages. **TensorRT GA build** -* TensorRT v10.4.0.26 +* TensorRT v10.5.0.18 * Available from direct download links listed below **System Packages** @@ -73,25 +73,25 @@ To build the TensorRT-OSS components, you will first need the following software If using the TensorRT OSS build container, TensorRT libraries are preinstalled under `/usr/lib/x86_64-linux-gnu` and you may skip this step. Else download and extract the TensorRT GA build from [NVIDIA Developer Zone](https://developer.nvidia.com) with the direct links below: - - [TensorRT 10.4.0.26 for CUDA 11.8, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz) - - [TensorRT 10.4.0.26 for CUDA 12.6, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz) - - [TensorRT 10.4.0.26 for CUDA 11.8, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-11.8.zip) - - [TensorRT 10.4.0.26 for CUDA 12.6, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip) + - [TensorRT 10.5.0.18 for CUDA 11.8, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz) + - [TensorRT 10.5.0.18 for CUDA 12.6, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz) + - [TensorRT 10.5.0.18 for CUDA 11.8, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/zip/TensorRT-10.5.0.18.Windows.win10.cuda-11.8.zip) + - [TensorRT 10.5.0.18 for CUDA 12.6, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/zip/TensorRT-10.5.0.18.Windows.win10.cuda-12.6.zip) **Example: Ubuntu 20.04 on x86-64 with cuda-12.6** ```bash cd ~/Downloads - tar -xvzf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz - export TRT_LIBPATH=`pwd`/TensorRT-10.4.0.26 + tar -xvzf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz + export TRT_LIBPATH=`pwd`/TensorRT-10.5.0.18 ``` **Example: Windows on x86-64 with cuda-12.6** ```powershell - Expand-Archive -Path TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip - $env:TRT_LIBPATH="$pwd\TensorRT-10.4.0.26\lib" + Expand-Archive -Path TensorRT-10.5.0.18.Windows.win10.cuda-12.6.zip + $env:TRT_LIBPATH="$pwd\TensorRT-10.5.0.18\lib" ``` ## Setting Up The Build Environment diff --git a/VERSION b/VERSION index 9de5818c..5099d096 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -10.4.0.26 +10.5.0.18 diff --git a/demo/BERT/README.md b/demo/BERT/README.md index 2aa91952..074848c3 100755 --- a/demo/BERT/README.md +++ b/demo/BERT/README.md @@ -75,7 +75,7 @@ The following software version configuration has been tested: |Software|Version| |--------|-------| |Python|>=3.8| -|TensorRT|10.4.0.26| +|TensorRT|10.5.0.18| |CUDA|12.6| ## Setup @@ -430,244 +430,245 @@ The following sections provide details on how we achieved our performance and in Results were obtained by running `scripts/inference_benchmark.sh --gpu Ampere` on NVIDIA A100 (40G). -##### BERT base +##### BERT Base | Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 0.69 | 0.69 | 0.55 | 0.79 | 0.79 | 0.63 | -| 128 | 2 | 0.60 | 0.76 | 0.60 | 0.72 | 0.91 | 0.72 | -| 128 | 4 | 0.73 | 0.93 | 0.73 | 1.09 | 1.09 | 0.94 | -| 128 | 8 | 1.21 | 1.21 | 0.95 | 1.31 | 1.31 | 1.30 | -| 128 | 12 | 1.40 | 1.40 | 1.21 | 1.72 | 1.72 | 1.72 | -| 128 | 16 | 1.34 | 1.71 | 1.34 | 2.08 | 2.08 | 2.06 | -| 128 | 24 | 1.82 | 1.83 | 1.82 | 3.05 | 3.06 | 3.03 | -| 128 | 32 | 2.23 | 2.24 | 2.23 | 3.95 | 3.99 | 3.91 | -| 128 | 64 | 4.19 | 4.20 | 4.14 | 7.82 | 7.83 | 7.69 | -| 128 | 128 | 8.14 | 8.19 | 8.09 | 15.37 | 15.42 | 15.32 | -| 384 | 1 | 1.13 | 1.45 | 1.14 | 1.25 | 1.60 | 1.26 | -| 384 | 2 | 1.32 | 1.69 | 1.32 | 1.55 | 1.98 | 1.55 | -| 384 | 4 | 1.66 | 2.12 | 1.66 | 2.12 | 2.13 | 2.12 | -| 384 | 8 | 2.21 | 2.21 | 2.20 | 3.37 | 3.40 | 3.33 | -| 384 | 12 | 3.31 | 3.31 | 3.31 | 4.82 | 4.83 | 4.78 | -| 384 | 16 | 4.00 | 4.00 | 4.00 | 6.38 | 6.43 | 6.37 | -| 384 | 24 | 5.70 | 5.75 | 5.70 | 9.44 | 9.49 | 9.35 | -| 384 | 32 | 7.72 | 7.74 | 7.66 | 13.02 | 13.02 | 12.91 | -| 384 | 64 | 14.88 | 14.90 | 14.84 | 25.17 | 25.25 | 24.88 | -| 384 | 128 | 29.00 | 29.01 | 28.83 | 49.03 | 49.22 | 48.77 | - -##### BERT large +| 128 | 1 | 0.53 | 0.68 | 0.54 | 0.79 | 0.79 | 0.64 | +| 128 | 2 | 0.76 | 0.76 | 0.60 | 0.72 | 0.91 | 0.72 | +| 128 | 4 | 0.73 | 0.92 | 0.73 | 1.03 | 1.04 | 0.93 | +| 128 | 8 | 0.94 | 1.20 | 0.95 | 1.31 | 1.31 | 1.31 | +| 128 | 12 | 1.19 | 1.20 | 1.19 | 1.72 | 1.73 | 1.72 | +| 128 | 16 | 1.33 | 1.71 | 1.34 | 2.07 | 2.08 | 2.05 | +| 128 | 24 | 1.82 | 1.82 | 1.81 | 3.04 | 3.07 | 3.01 | +| 128 | 32 | 2.23 | 2.24 | 2.23 | 3.90 | 3.93 | 3.86 | +| 128 | 64 | 4.15 | 4.17 | 4.12 | 7.62 | 7.70 | 7.57 | +| 128 | 128 | 8.11 | 8.12 | 8.03 | 15.34 | 15.35 | 15.20 | +| 384 | 1 | 1.13 | 1.45 | 1.13 | 1.24 | 1.25 | 1.24 | +| 384 | 2 | 1.31 | 1.31 | 1.31 | 1.54 | 1.98 | 1.55 | +| 384 | 4 | 1.66 | 1.66 | 1.66 | 2.12 | 2.12 | 2.12 | +| 384 | 8 | 2.21 | 2.21 | 2.20 | 3.34 | 3.36 | 3.32 | +| 384 | 12 | 3.32 | 3.32 | 3.31 | 4.78 | 4.82 | 4.77 | +| 384 | 16 | 4.01 | 4.01 | 4.00 | 6.37 | 6.44 | 6.35 | +| 384 | 24 | 5.71 | 5.71 | 5.70 | 9.47 | 9.49 | 9.39 | +| 384 | 32 | 7.64 | 7.64 | 7.63 | 13.00 | 13.04 | 12.85 | +| 384 | 64 | 14.87 | 14.88 | 14.73 | 25.12 | 25.14 | 24.78 | +| 384 | 128 | 28.96 | 28.97 | 28.70 | 48.93 | 49.13 | 48.57 | + +##### BERT Large | Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 1.24 | 1.24 | 1.24 | 1.54 | 1.55 | 1.54 | -| 128 | 2 | 1.42 | 1.79 | 1.42 | 1.82 | 1.82 | 1.82 | -| 128 | 4 | 1.78 | 1.79 | 1.78 | 2.53 | 2.53 | 2.52 | -| 128 | 8 | 2.64 | 2.64 | 2.64 | 4.07 | 4.10 | 4.06 | -| 128 | 12 | 3.11 | 3.12 | 3.11 | 5.08 | 5.10 | 5.03 | -| 128 | 16 | 4.03 | 4.03 | 4.03 | 6.95 | 6.95 | 6.90 | -| 128 | 24 | 5.32 | 5.34 | 5.30 | 9.80 | 9.90 | 9.72 | -| 128 | 32 | 7.07 | 7.07 | 7.00 | 13.08 | 13.08 | 12.93 | -| 128 | 64 | 12.94 | 13.01 | 12.82 | 24.83 | 24.99 | 24.69 | -| 128 | 128 | 25.29 | 25.29 | 25.09 | 49.70 | 49.72 | 49.06 | -| 384 | 1 | 2.55 | 2.56 | 2.55 | 2.96 | 2.96 | 2.96 | +| 128 | 1 | 1.22 | 1.23 | 1.22 | 1.54 | 1.91 | 1.55 | +| 128 | 2 | 1.42 | 1.42 | 1.41 | 1.82 | 1.82 | 1.82 | +| 128 | 4 | 1.78 | 2.06 | 1.79 | 2.50 | 2.50 | 2.50 | +| 128 | 8 | 2.64 | 2.64 | 2.64 | 3.98 | 3.98 | 3.98 | +| 128 | 12 | 3.09 | 3.09 | 3.08 | 5.02 | 5.07 | 4.99 | +| 128 | 16 | 4.09 | 4.09 | 4.08 | 6.93 | 6.94 | 6.86 | +| 128 | 24 | 5.28 | 5.28 | 5.27 | 9.64 | 9.68 | 9.56 | +| 128 | 32 | 7.01 | 7.01 | 6.95 | 12.92 | 13.07 | 12.85 | +| 128 | 64 | 12.86 | 12.86 | 12.73 | 24.79 | 25.07 | 24.59 | +| 128 | 128 | 25.03 | 25.26 | 24.99 | 49.12 | 49.28 | 48.83 | +| 384 | 1 | 2.55 | 2.55 | 2.55 | 2.96 | 2.96 | 2.95 | | 384 | 2 | 3.04 | 3.04 | 3.03 | 3.90 | 3.90 | 3.90 | -| 384 | 4 | 4.01 | 4.01 | 4.01 | 5.74 | 5.79 | 5.71 | -| 384 | 8 | 7.16 | 7.16 | 7.15 | 11.15 | 11.24 | 11.09 | -| 384 | 12 | 9.15 | 9.23 | 9.14 | 15.46 | 15.47 | 15.40 | -| 384 | 16 | 12.40 | 12.40 | 12.29 | 21.17 | 21.18 | 21.05 | -| 384 | 24 | 17.72 | 17.85 | 17.64 | 31.09 | 31.36 | 30.81 | -| 384 | 32 | 23.29 | 23.31 | 23.15 | 41.32 | 41.34 | 40.86 | -| 384 | 64 | 45.38 | 45.40 | 45.02 | 79.95 | 80.27 | 79.31 | -| 384 | 128 | 87.97 | 87.99 | 87.89 | 156.97 | 157.56 | 155.84 | +| 384 | 4 | 4.01 | 4.02 | 4.01 | 5.68 | 5.74 | 5.67 | +| 384 | 8 | 7.18 | 7.18 | 7.17 | 11.13 | 11.13 | 11.01 | +| 384 | 12 | 9.14 | 9.15 | 9.13 | 15.43 | 15.44 | 15.32 | +| 384 | 16 | 12.28 | 12.28 | 12.27 | 21.14 | 21.15 | 20.90 | +| 384 | 24 | 17.68 | 17.68 | 17.54 | 30.98 | 31.02 | 30.68 | +| 384 | 32 | 23.24 | 23.24 | 23.02 | 41.11 | 41.20 | 40.58 | +| 384 | 64 | 44.86 | 45.13 | 44.78 | 79.25 | 79.68 | 79.10 | +| 384 | 128 | 87.82 | 87.84 | 87.69 | 156.70 | 157.02 | 155.61 | ##### Megatron Large with Sparsity | Sequence Length | Batch Size | INT8 QAT Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 1.24 | 1.56 | 1.24 | -| 128 | 2 | 1.42 | 1.42 | 1.42 | -| 128 | 4 | 1.78 | 1.79 | 1.78 | -| 128 | 8 | 2.64 | 2.65 | 2.64 | -| 128 | 12 | 3.11 | 3.12 | 3.11 | -| 128 | 16 | 4.03 | 4.03 | 4.02 | -| 128 | 24 | 5.32 | 5.34 | 5.31 | -| 128 | 32 | 7.07 | 7.09 | 7.02 | -| 128 | 64 | 12.98 | 13.01 | 12.86 | -| 128 | 128 | 25.40 | 25.55 | 25.17 | -| 384 | 1 | 2.55 | 2.55 | 2.55 | -| 384 | 2 | 3.03 | 3.04 | 3.03 | -| 384 | 4 | 4.01 | 4.01 | 4.01 | -| 384 | 8 | 7.16 | 7.16 | 7.16 | -| 384 | 12 | 9.14 | 9.23 | 9.14 | -| 384 | 16 | 12.31 | 12.41 | 12.29 | -| 384 | 24 | 17.85 | 17.90 | 17.68 | -| 384 | 32 | 23.41 | 23.51 | 23.23 | -| 384 | 64 | 45.39 | 45.40 | 45.09 | -| 384 | 128 | 88.73 | 88.79 | 88.11 | +| 128 | 1 | 1.11 | 1.40 | 1.11 | +| 128 | 2 | 1.33 | 1.33 | 1.33 | +| 128 | 4 | 1.78 | 1.78 | 1.78 | +| 128 | 8 | 2.54 | 2.54 | 2.53 | +| 128 | 12 | 2.97 | 2.97 | 2.97 | +| 128 | 16 | 3.99 | 3.99 | 3.98 | +| 128 | 24 | 4.91 | 4.91 | 4.90 | +| 128 | 32 | 7.13 | 7.13 | 7.12 | +| 128 | 64 | 11.61 | 11.62 | 11.60 | +| 128 | 128 | 21.22 | 21.32 | 21.09 | +| 384 | 1 | 1.71 | 2.15 | 1.71 | +| 384 | 2 | 2.21 | 2.21 | 2.21 | +| 384 | 4 | 3.47 | 3.48 | 3.47 | +| 384 | 8 | 5.74 | 5.74 | 5.74 | +| 384 | 12 | 8.21 | 8.21 | 8.20 | +| 384 | 16 | 10.33 | 10.34 | 10.32 | +| 384 | 24 | 14.68 | 14.69 | 14.67 | +| 384 | 32 | 18.73 | 18.74 | 18.72 | +| 384 | 64 | 35.77 | 35.78 | 35.49 | +| 384 | 128 | 67.78 | 67.95 | 67.63 | ### Inference Performance NVIDIA L4 Results were obtained by running `scripts/inference_benchmark.sh --gpu Ampere` on NVIDIA L4. -##### BERT base +##### BERT Base | Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 0.62 | 0.62 | 0.60 | 1.03 | 1.03 | 1.01 | -| 128 | 2 | 0.79 | 0.80 | 0.77 | 1.31 | 1.35 | 1.30 | -| 128 | 4 | 1.14 | 1.15 | 1.12 | 2.23 | 2.23 | 2.15 | -| 128 | 8 | 1.97 | 1.97 | 1.92 | 3.68 | 3.69 | 3.63 | -| 128 | 12 | 2.66 | 2.67 | 2.61 | 5.34 | 5.35 | 5.27 | -| 128 | 16 | 3.39 | 3.39 | 3.34 | 6.62 | 6.69 | 6.58 | -| 128 | 24 | 4.84 | 4.85 | 4.76 | 10.49 | 10.55 | 10.32 | -| 128 | 32 | 6.20 | 6.29 | 6.14 | 13.92 | 13.92 | 13.75 | -| 128 | 64 | 13.42 | 13.42 | 13.26 | 31.28 | 31.48 | 31.07 | -| 128 | 128 | 28.48 | 28.64 | 28.19 | 66.10 | 66.23 | 65.36 | -| 384 | 1 | 1.29 | 1.30 | 1.29 | 2.08 | 2.09 | 2.08 | -| 384 | 2 | 1.83 | 1.84 | 1.82 | 3.15 | 3.19 | 3.11 | -| 384 | 4 | 2.99 | 2.99 | 2.92 | 5.75 | 5.81 | 5.68 | -| 384 | 8 | 5.53 | 5.54 | 5.42 | 11.28 | 11.33 | 11.08 | -| 384 | 12 | 8.26 | 8.29 | 8.09 | 17.19 | 17.22 | 16.99 | -| 384 | 16 | 11.00 | 11.08 | 10.85 | 23.38 | 23.38 | 22.90 | -| 384 | 24 | 16.79 | 16.89 | 16.60 | 37.90 | 38.29 | 37.18 | -| 384 | 32 | 23.08 | 23.31 | 22.74 | 50.70 | 50.94 | 50.27 | -| 384 | 64 | 49.43 | 49.86 | 48.56 | 103.88 | 104.19 | 102.89 | -| 384 | 128 | 104.55 | 104.97 | 103.74 | 211.09 | 211.67 | 209.85 | - -##### BERT large +| 128 | 1 | 0.61 | 0.61 | 0.60 | 1.01 | 1.01 | 1.00 | +| 128 | 2 | 0.79 | 0.80 | 0.77 | 1.32 | 1.35 | 1.31 | +| 128 | 4 | 1.14 | 1.15 | 1.12 | 2.22 | 2.23 | 2.14 | +| 128 | 8 | 1.94 | 1.96 | 1.90 | 3.66 | 3.67 | 3.63 | +| 128 | 12 | 2.67 | 2.67 | 2.61 | 5.34 | 5.34 | 5.26 | +| 128 | 16 | 3.37 | 3.38 | 3.32 | 6.69 | 6.69 | 6.64 | +| 128 | 24 | 4.84 | 4.84 | 4.75 | 10.53 | 10.64 | 10.50 | +| 128 | 32 | 6.21 | 6.28 | 6.13 | 13.91 | 13.91 | 13.72 | +| 128 | 64 | 13.40 | 13.60 | 13.20 | 31.48 | 31.53 | 31.01 | +| 128 | 128 | 28.42 | 28.68 | 27.84 | 70.60 | 71.10 | 69.25 | +| 384 | 1 | 1.27 | 1.27 | 1.27 | 2.08 | 2.09 | 2.07 | +| 384 | 2 | 1.84 | 1.84 | 1.82 | 3.15 | 3.19 | 3.11 | +| 384 | 4 | 2.94 | 2.94 | 2.91 | 5.68 | 5.75 | 5.63 | +| 384 | 8 | 5.53 | 5.55 | 5.42 | 11.45 | 11.59 | 11.32 | +| 384 | 12 | 8.21 | 8.31 | 8.07 | 17.16 | 17.36 | 17.00 | +| 384 | 16 | 10.96 | 11.07 | 10.80 | 23.20 | 23.50 | 22.81 | +| 384 | 24 | 16.71 | 16.74 | 16.55 | 39.82 | 40.46 | 38.15 | +| 384 | 32 | 22.82 | 23.00 | 22.63 | 50.56 | 50.89 | 50.14 | +| 384 | 64 | 49.66 | 50.18 | 48.40 | 104.90 | 105.55 | 103.81 | +| 384 | 128 | 104.78 | 105.09 | 103.96 | 208.20 | 208.70 | 206.93 | + +##### BERT Large | Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 1.78 | 1.79 | 1.76 | 3.11 | 3.11 | 3.10 | -| 128 | 2 | 2.50 | 2.51 | 2.44 | 4.35 | 4.45 | 4.31 | -| 128 | 4 | 3.60 | 3.63 | 3.54 | 6.83 | 6.86 | 6.77 | -| 128 | 8 | 6.27 | 6.31 | 6.25 | 12.98 | 13.01 | 12.80 | -| 128 | 12 | 8.40 | 8.41 | 8.27 | 18.45 | 18.66 | 18.22 | -| 128 | 16 | 11.22 | 11.23 | 11.12 | 25.18 | 25.19 | 25.14 | -| 128 | 24 | 15.95 | 16.10 | 15.82 | 35.67 | 35.68 | 35.59 | -| 128 | 32 | 21.30 | 21.35 | 20.90 | 49.02 | 49.26 | 48.33 | -| 128 | 64 | 44.08 | 44.32 | 43.93 | 107.89 | 108.30 | 107.11 | -| 128 | 128 | 93.69 | 94.36 | 92.69 | 215.00 | 215.46 | 213.84 | -| 384 | 1 | 3.43 | 3.44 | 3.41 | 6.58 | 6.66 | 6.40 | -| 384 | 2 | 5.55 | 5.55 | 5.49 | 10.56 | 10.59 | 10.44 | -| 384 | 4 | 9.80 | 9.88 | 9.58 | 20.55 | 20.94 | 19.93 | -| 384 | 8 | 18.04 | 18.11 | 17.86 | 38.87 | 39.47 | 37.69 | -| 384 | 12 | 26.44 | 26.61 | 26.14 | 59.28 | 59.85 | 56.90 | -| 384 | 16 | 36.37 | 36.48 | 36.04 | 82.93 | 83.33 | 81.95 | -| 384 | 24 | 53.60 | 53.73 | 53.15 | 122.78 | 123.06 | 122.05 | -| 384 | 32 | 75.52 | 75.84 | 74.45 | 164.55 | 164.98 | 163.68 | -| 384 | 64 | 157.71 | 158.27 | 155.68 | 345.90 | 346.53 | 344.57 | -| 384 | 128 | 331.37 | 332.44 | 329.06 | 663.75 | 664.69 | 661.89 | +| 128 | 1 | 1.79 | 1.80 | 1.77 | 3.11 | 3.11 | 3.09 | +| 128 | 2 | 2.49 | 2.49 | 2.43 | 4.35 | 4.37 | 4.33 | +| 128 | 4 | 3.62 | 3.70 | 3.60 | 6.86 | 6.89 | 6.78 | +| 128 | 8 | 6.26 | 6.31 | 6.24 | 12.85 | 12.91 | 12.73 | +| 128 | 12 | 8.40 | 8.41 | 8.28 | 18.42 | 18.43 | 18.33 | +| 128 | 16 | 11.23 | 11.24 | 11.12 | 25.18 | 25.19 | 25.10 | +| 128 | 24 | 15.95 | 16.09 | 15.90 | 35.67 | 35.67 | 35.47 | +| 128 | 32 | 21.26 | 21.31 | 20.91 | 48.92 | 49.21 | 48.26 | +| 128 | 64 | 44.10 | 44.11 | 43.92 | 108.81 | 109.12 | 107.18 | +| 128 | 128 | 94.22 | 95.02 | 92.65 | 217.32 | 219.58 | 212.68 | +| 384 | 1 | 3.41 | 3.43 | 3.39 | 6.55 | 6.57 | 6.36 | +| 384 | 2 | 5.55 | 5.56 | 5.46 | 10.34 | 10.35 | 10.18 | +| 384 | 4 | 9.69 | 9.79 | 9.53 | 20.66 | 20.95 | 19.94 | +| 384 | 8 | 18.08 | 18.19 | 17.92 | 38.41 | 39.30 | 37.62 | +| 384 | 12 | 26.20 | 26.44 | 26.11 | 60.38 | 60.91 | 58.67 | +| 384 | 16 | 36.33 | 36.41 | 36.02 | 81.66 | 82.16 | 80.52 | +| 384 | 24 | 53.54 | 53.61 | 53.08 | 123.01 | 123.34 | 122.10 | +| 384 | 32 | 75.01 | 75.43 | 74.40 | 170.40 | 171.03 | 169.12 | +| 384 | 64 | 157.97 | 158.62 | 155.87 | 349.25 | 351.53 | 344.76 | +| 384 | 128 | 330.88 | 331.87 | 328.27 | 632.85 | 633.88 | 629.74 | ##### Megatron Large with Sparsity | Sequence Length | Batch Size | INT8 QAT Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 1.78 | 1.79 | 1.76 | -| 128 | 2 | 2.50 | 2.51 | 2.44 | -| 128 | 4 | 3.56 | 3.57 | 3.54 | -| 128 | 8 | 6.27 | 6.31 | 6.26 | -| 128 | 12 | 8.40 | 8.41 | 8.29 | -| 128 | 16 | 11.23 | 11.23 | 11.16 | -| 128 | 24 | 16.06 | 16.12 | 15.90 | -| 128 | 32 | 21.31 | 21.34 | 20.98 | -| 128 | 64 | 44.15 | 44.66 | 43.88 | -| 128 | 128 | 94.19 | 94.93 | 92.81 | -| 384 | 1 | 3.39 | 3.43 | 3.37 | -| 384 | 2 | 5.56 | 5.56 | 5.48 | -| 384 | 4 | 9.81 | 9.90 | 9.61 | -| 384 | 8 | 18.07 | 18.25 | 17.94 | -| 384 | 12 | 26.47 | 26.57 | 26.27 | -| 384 | 16 | 36.78 | 37.14 | 36.37 | -| 384 | 24 | 54.16 | 54.53 | 53.65 | -| 384 | 32 | 75.33 | 75.62 | 74.69 | -| 384 | 64 | 158.72 | 159.55 | 156.72 | -| 384 | 128 | 333.24 | 334.26 | 330.67 | +| 128 | 1 | 1.49 | 1.49 | 1.48 | +| 128 | 2 | 2.03 | 2.03 | 1.99 | +| 128 | 4 | 2.99 | 3.00 | 2.93 | +| 128 | 8 | 5.00 | 5.07 | 4.99 | +| 128 | 12 | 6.69 | 6.72 | 6.58 | +| 128 | 16 | 8.77 | 8.84 | 8.66 | +| 128 | 24 | 13.28 | 13.30 | 13.14 | +| 128 | 32 | 17.41 | 17.44 | 17.26 | +| 128 | 64 | 35.73 | 36.07 | 35.49 | +| 128 | 128 | 79.03 | 79.15 | 78.47 | +| 384 | 1 | 2.78 | 2.79 | 2.72 | +| 384 | 2 | 4.10 | 4.12 | 4.06 | +| 384 | 4 | 7.57 | 7.58 | 7.45 | +| 384 | 8 | 15.03 | 15.10 | 14.86 | +| 384 | 12 | 21.52 | 21.69 | 21.31 | +| 384 | 16 | 28.29 | 28.33 | 28.10 | +| 384 | 24 | 46.83 | 47.09 | 46.29 | +| 384 | 32 | 60.29 | 60.47 | 59.37 | +| 384 | 64 | 125.58 | 125.64 | 125.24 | +| 384 | 128 | 253.46 | 253.90 | 252.28 | ### Inference Performance NVIDIA L40S Results were obtained by running `scripts/inference_benchmark.sh --gpu Ampere` on NVIDIA L40S. -##### BERT base +##### BERT Base | Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 0.34 | 0.34 | 0.34 | 0.48 | 0.48 | 0.48 | -| 128 | 2 | 0.41 | 0.41 | 0.41 | 0.56 | 0.56 | 0.55 | -| 128 | 4 | 0.50 | 0.50 | 0.50 | 0.77 | 0.77 | 0.77 | -| 128 | 8 | 0.67 | 0.67 | 0.67 | 1.30 | 1.30 | 1.29 | -| 128 | 12 | 0.91 | 0.91 | 0.91 | 1.68 | 1.68 | 1.67 | -| 128 | 16 | 1.09 | 1.10 | 1.09 | 2.22 | 2.23 | 2.22 | -| 128 | 24 | 1.50 | 1.50 | 1.48 | 3.23 | 3.24 | 3.20 | -| 128 | 32 | 1.82 | 1.83 | 1.82 | 3.94 | 3.94 | 3.93 | -| 128 | 64 | 3.47 | 3.47 | 3.45 | 8.24 | 8.26 | 8.14 | -| 128 | 128 | 7.74 | 7.91 | 7.66 | 17.73 | 17.86 | 17.56 | -| 384 | 1 | 0.73 | 0.73 | 0.73 | 1.02 | 1.02 | 1.02 | -| 384 | 2 | 0.88 | 0.89 | 0.88 | 1.38 | 1.38 | 1.37 | -| 384 | 4 | 1.17 | 1.17 | 1.16 | 2.16 | 2.17 | 2.15 | -| 384 | 8 | 1.72 | 1.73 | 1.72 | 3.45 | 3.46 | 3.45 | -| 384 | 12 | 2.73 | 2.73 | 2.72 | 5.07 | 5.07 | 5.05 | -| 384 | 16 | 3.28 | 3.28 | 3.27 | 7.41 | 7.44 | 7.37 | -| 384 | 24 | 4.93 | 4.94 | 4.90 | 10.16 | 10.19 | 10.09 | -| 384 | 32 | 6.33 | 6.34 | 6.29 | 14.07 | 14.11 | 13.96 | -| 384 | 64 | 13.74 | 13.76 | 13.57 | 30.65 | 30.82 | 30.14 | -| 384 | 128 | 28.25 | 28.41 | 27.87 | 62.48 | 62.67 | 61.70 | - -##### BERT large +| 128 | 1 | 0.33 | 0.33 | 0.33 | 0.48 | 0.48 | 0.48 | +| 128 | 2 | 0.41 | 0.41 | 0.41 | 0.57 | 0.57 | 0.57 | +| 128 | 4 | 0.50 | 0.51 | 0.50 | 0.78 | 0.78 | 0.78 | +| 128 | 8 | 0.67 | 0.67 | 0.67 | 1.33 | 1.33 | 1.32 | +| 128 | 12 | 0.91 | 0.91 | 0.91 | 1.75 | 1.76 | 1.73 | +| 128 | 16 | 1.10 | 1.10 | 1.09 | 2.29 | 2.29 | 2.28 | +| 128 | 24 | 1.48 | 1.49 | 1.47 | 3.30 | 3.31 | 3.27 | +| 128 | 32 | 1.84 | 1.84 | 1.83 | 3.98 | 3.99 | 3.97 | +| 128 | 64 | 3.61 | 3.66 | 3.56 | 8.64 | 8.70 | 8.51 | +| 128 | 128 | 7.92 | 7.99 | 7.82 | 18.78 | 18.82 | 18.45 | +| 384 | 1 | 0.73 | 0.73 | 0.73 | 1.11 | 1.12 | 1.10 | +| 384 | 2 | 0.88 | 0.88 | 0.88 | 1.39 | 1.39 | 1.38 | +| 384 | 4 | 1.17 | 1.17 | 1.17 | 2.19 | 2.20 | 2.19 | +| 384 | 8 | 1.74 | 1.74 | 1.73 | 3.53 | 3.53 | 3.50 | +| 384 | 12 | 2.75 | 2.75 | 2.73 | 5.32 | 5.33 | 5.29 | +| 384 | 16 | 3.33 | 3.33 | 3.31 | 7.62 | 7.64 | 7.57 | +| 384 | 24 | 4.97 | 4.98 | 4.95 | 10.53 | 10.57 | 10.40 | +| 384 | 32 | 6.55 | 6.57 | 6.48 | 14.36 | 14.47 | 14.20 | +| 384 | 64 | 14.27 | 14.37 | 14.07 | 33.31 | 33.51 | 32.65 | +| 384 | 128 | 30.38 | 30.52 | 29.73 | 67.34 | 68.04 | 66.06 | + +##### BERT Large | Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 0.89 | 0.89 | 0.88 | 1.30 | 1.30 | 1.30 | -| 128 | 2 | 0.98 | 0.98 | 0.98 | 1.47 | 1.47 | 1.47 | -| 128 | 4 | 1.34 | 1.35 | 1.33 | 2.29 | 2.29 | 2.28 | -| 128 | 8 | 1.92 | 1.92 | 1.91 | 3.82 | 3.83 | 3.79 | -| 128 | 12 | 2.77 | 2.78 | 2.75 | 5.73 | 5.73 | 5.71 | -| 128 | 16 | 3.22 | 3.22 | 3.19 | 6.72 | 6.73 | 6.67 | -| 128 | 24 | 4.54 | 4.55 | 4.52 | 10.38 | 10.39 | 10.31 | -| 128 | 32 | 5.67 | 5.68 | 5.63 | 12.87 | 12.90 | 12.74 | -| 128 | 64 | 11.92 | 11.96 | 11.77 | 28.21 | 28.40 | 27.89 | -| 128 | 128 | 25.44 | 25.49 | 25.12 | 61.85 | 62.01 | 61.20 | -| 384 | 1 | 1.68 | 1.68 | 1.68 | 2.74 | 2.75 | 2.74 | -| 384 | 2 | 2.32 | 2.32 | 2.30 | 3.87 | 3.87 | 3.85 | -| 384 | 4 | 3.27 | 3.28 | 3.27 | 6.32 | 6.35 | 6.29 | -| 384 | 8 | 5.09 | 5.09 | 5.06 | 10.76 | 10.77 | 10.60 | -| 384 | 12 | 8.06 | 8.07 | 8.02 | 18.73 | 18.77 | 18.64 | -| 384 | 16 | 9.70 | 9.75 | 9.61 | 22.15 | 22.26 | 21.95 | -| 384 | 24 | 15.02 | 15.04 | 14.88 | 35.43 | 35.48 | 35.15 | -| 384 | 32 | 20.37 | 20.49 | 20.00 | 46.36 | 46.37 | 45.86 | -| 384 | 64 | 43.50 | 43.65 | 43.02 | 105.84 | 106.08 | 104.92 | -| 384 | 128 | 86.30 | 86.48 | 85.58 | 195.18 | 195.98 | 192.90 | +| 128 | 1 | 0.89 | 0.89 | 0.88 | 1.30 | 1.30 | 1.29 | +| 128 | 2 | 0.97 | 0.98 | 0.97 | 1.45 | 1.45 | 1.44 | +| 128 | 4 | 1.36 | 1.36 | 1.35 | 2.30 | 2.30 | 2.29 | +| 128 | 8 | 1.94 | 1.96 | 1.93 | 3.89 | 3.90 | 3.88 | +| 128 | 12 | 2.82 | 2.82 | 2.80 | 5.89 | 5.90 | 5.85 | +| 128 | 16 | 3.26 | 3.27 | 3.24 | 6.85 | 6.86 | 6.80 | +| 128 | 24 | 4.62 | 4.63 | 4.59 | 10.72 | 10.73 | 10.64 | +| 128 | 32 | 5.74 | 5.76 | 5.70 | 13.22 | 13.23 | 13.04 | +| 128 | 64 | 12.18 | 12.20 | 11.97 | 29.42 | 29.59 | 28.89 | +| 128 | 128 | 26.68 | 26.86 | 26.23 | 68.72 | 69.05 | 67.12 | +| 384 | 1 | 1.68 | 1.68 | 1.68 | 2.78 | 2.78 | 2.77 | +| 384 | 2 | 2.31 | 2.31 | 2.30 | 3.95 | 3.95 | 3.94 | +| 384 | 4 | 3.29 | 3.30 | 3.29 | 6.57 | 6.58 | 6.50 | +| 384 | 8 | 5.16 | 5.17 | 5.13 | 10.89 | 10.90 | 10.79 | +| 384 | 12 | 8.16 | 8.17 | 8.10 | 19.81 | 19.91 | 19.31 | +| 384 | 16 | 9.90 | 9.93 | 9.80 | 23.34 | 23.51 | 23.10 | +| 384 | 24 | 15.60 | 15.62 | 15.39 | 37.37 | 37.48 | 36.93 | +| 384 | 32 | 20.66 | 20.73 | 20.33 | 50.13 | 50.34 | 49.52 | +| 384 | 64 | 46.31 | 46.53 | 45.39 | 111.74 | 111.98 | 110.14 | +| 384 | 128 | 93.80 | 94.04 | 92.33 | 213.05 | 214.15 | 210.25 | ##### Megatron Large with Sparsity | Sequence Length | Batch Size | INT8 QAT Latency (ms) | | | |-----------------|------------|-----------------|-----------------|---------| | | | 95th Percentile | 99th Percentile | Average | -| 128 | 1 | 0.89 | 0.89 | 0.88 | -| 128 | 2 | 0.98 | 0.98 | 0.98 | -| 128 | 4 | 1.34 | 1.36 | 1.33 | -| 128 | 8 | 1.93 | 1.95 | 1.91 | -| 128 | 12 | 2.79 | 2.82 | 2.77 | -| 128 | 16 | 3.24 | 3.24 | 3.22 | -| 128 | 24 | 4.59 | 4.59 | 4.57 | -| 128 | 32 | 5.68 | 5.68 | 5.65 | -| 128 | 64 | 11.81 | 11.87 | 11.71 | -| 128 | 128 | 26.21 | 26.24 | 25.86 | -| 384 | 1 | 1.68 | 1.68 | 1.68 | -| 384 | 2 | 2.31 | 2.32 | 2.31 | -| 384 | 4 | 3.29 | 3.29 | 3.28 | -| 384 | 8 | 5.14 | 5.15 | 5.10 | -| 384 | 12 | 8.05 | 8.06 | 8.01 | -| 384 | 16 | 9.78 | 9.80 | 9.66 | -| 384 | 24 | 15.14 | 15.15 | 15.01 | -| 384 | 32 | 20.34 | 20.42 | 19.99 | -| 384 | 64 | 43.81 | 43.97 | 43.39 | -| 384 | 128 | 88.37 | 88.64 | 87.38 | \ No newline at end of file +| 128 | 1 | 0.76 | 0.76 | 0.76 | +| 128 | 2 | 0.91 | 0.91 | 0.91 | +| 128 | 4 | 1.13 | 1.13 | 1.13 | +| 128 | 8 | 1.70 | 1.70 | 1.70 | +| 128 | 12 | 2.26 | 2.26 | 2.25 | +| 128 | 16 | 2.72 | 2.72 | 2.71 | +| 128 | 24 | 4.54 | 4.55 | 4.52 | +| 128 | 32 | 5.14 | 5.16 | 5.10 | +| 128 | 64 | 10.07 | 10.08 | 10.01 | +| 128 | 128 | 21.57 | 21.67 | 21.21 | +| 384 | 1 | 1.13 | 1.13 | 1.13 | +| 384 | 2 | 1.64 | 1.65 | 1.62 | +| 384 | 4 | 2.51 | 2.51 | 2.50 | +| 384 | 8 | 5.02 | 5.03 | 4.99 | +| 384 | 12 | 6.43 | 6.43 | 6.41 | +| 384 | 16 | 8.47 | 8.49 | 8.41 | +| 384 | 24 | 12.62 | 12.65 | 12.54 | +| 384 | 32 | 16.88 | 16.91 | 16.74 | +| 384 | 64 | 36.62 | 36.71 | 36.12 | +| 384 | 128 | 79.88 | 80.18 | 77.33 | + diff --git a/demo/Diffusion/README.md b/demo/Diffusion/README.md index dd9aa50d..778767a5 100755 --- a/demo/Diffusion/README.md +++ b/demo/Diffusion/README.md @@ -7,7 +7,7 @@ This demo application ("demoDiffusion") showcases the acceleration of Stable Dif ### Clone the TensorRT OSS repository ```bash -git clone git@github.com:NVIDIA/TensorRT.git -b release/10.4 --single-branch +git clone git@github.com:NVIDIA/TensorRT.git -b release/10.5 --single-branch cd TensorRT ``` @@ -16,7 +16,7 @@ cd TensorRT Install nvidia-docker using [these intructions](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker). ```bash -docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:24.05-py3 /bin/bash +docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:24.07-py3 /bin/bash ``` NOTE: The demo supports CUDA>=11.8 @@ -43,19 +43,18 @@ pip3 install -r requirements.txt > NOTE: demoDiffusion has been tested on systems with NVIDIA H100, A100, L40, T4, and RTX4090 GPUs, and the following software configuration. ``` -diffusers 0.29.2 +diffusers 0.30.2 onnx 1.15.0 onnx-graphsurgeon 0.5.2 onnxruntime 1.16.3 polygraphy 0.49.9 -tensorrt 10.4.0.26 +tensorrt 10.5.0.18 tokenizers 0.13.3 torch 2.2.0 -transformers 4.33.1 +transformers 4.42.2 controlnet-aux 0.0.6 nvidia-modelopt 0.15.1 ``` -> NOTE: optionally install HuggingFace [accelerate](https://pypi.org/project/accelerate/) package for faster and less memory-intense model loading. Note that installing accelerate is known to cause failures while running certain pipelines in Torch Compile mode ([known issue](https://github.com/huggingface/diffusers/issues/9091)) # Running demoDiffusion @@ -72,7 +71,6 @@ python3 demo_txt2img_xl.py --help ### HuggingFace user access token To download model checkpoints for the Stable Diffusion pipelines, obtain a `read` access token to HuggingFace Hub. See [instructions](https://huggingface.co/docs/hub/security-tokens). -> NOTE: This step isn't required for many models now. ```bash export HF_TOKEN= @@ -158,7 +156,7 @@ python3 demo_txt2img_xl.py "Picture of a rustic Italian village with Olive trees Run the below command to generate an image with Stable Diffusion XL in INT8 ```bash -python3 demo_txt2img_xl.py "a photo of an astronaut riding a horse on mars" --version xl-1.0 --onnx-dir onnx-sdxl --engine-dir engine-sdxl --int8 +python3 demo_txt2img_xl.py "a photo of an astronaut riding a horse on mars" --version xl-1.0 --onnx-dir onnx-sdxl --engine-dir engine-sdxl --int8 ``` Run the below command to generate an image with Stable Diffusion XL in FP8. (FP8 is only supppoted on Hopper.) @@ -202,7 +200,7 @@ python3 demo_txt2img_sd3.py "dog wearing a sweater and a blue collar" --version Note that a denosing-percentage is applied to the number of denoising-steps when an input image conditioning is provided. Its default value is set to 0.6. This parameter can be updated using `--denoising-percentage` -### Image-to-video using SVD (Stable Video Diffusion) +### Generate a video guided by an initial image using Stable Video Diffusion Download the pre-exported ONNX model @@ -224,7 +222,7 @@ python3 demo_img2vid.py --version svd-xt-1.1 --onnx-dir onnx-svd-xt-1-1 --engine NOTE: The min and max guidance scales are configured using --min-guidance-scale and --max-guidance-scale respectively. -### Generate an image using Stable Cascade guided by a text prompt +### Generate an image guided by a text prompt using Stable Cascade Run the below command to generate an image using Stable Cascade ```bash @@ -242,11 +240,22 @@ python3 demo_stable_cascade.py --onnx-opset=16 "Anthropomorphic cat dressed as a > NOTE: The denoising steps and guidance scale for the Prior and Decoder models are configured using --prior-denoising-steps, --prior-guidance-scale, --decoder-denoising-steps, and --decoder-guidance-scale respectively. +### Generate an image guided by a text prompt using Flux + +```bash +python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN +``` + +NOTE: Running the Flux pipeline requires 80GB of GPU memory or higher + ## Configuration options - Noise scheduler can be set using `--scheduler `. Note: not all schedulers are available for every version. - To accelerate engine building time use `--timing-cache `. The cache file will be created if it does not already exist. Note that performance may degrade if cache files are used across multiple GPU targets. It is recommended to use timing caches only during development. To achieve the best perfromance in deployment, please build engines without timing cache. - Specify new directories for storing onnx and engine files when switching between versions, LoRAs, ControlNets, etc. This can be done using `--onnx-dir ` and `--engine-dir `. - Inference performance can be improved by enabling [CUDA graphs](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs) using `--use-cuda-graph`. Enabling CUDA graphs requires fixed input shapes, so this flag must be combined with `--build-static-batch` and cannot be combined with `--build-dynamic-shape`. +## Known Issues +- LoRA adapter functionality is compatible with diffusers version 0.26.3. To run the LoRA pipeline, we recommend installing this specific version. However, the Stable Cascade pipeline requires diffusers version 0.29.2 or higher and will not be compatible if diffusers is downgraded. + diff --git a/demo/Diffusion/demo_img2img.py b/demo/Diffusion/demo_img2img.py old mode 100755 new mode 100644 diff --git a/demo/Diffusion/demo_inpaint.py b/demo/Diffusion/demo_inpaint.py old mode 100755 new mode 100644 diff --git a/demo/Diffusion/demo_txt2img_flux.py b/demo/Diffusion/demo_txt2img_flux.py new file mode 100644 index 00000000..067c2c8f --- /dev/null +++ b/demo/Diffusion/demo_txt2img_flux.py @@ -0,0 +1,137 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse + +from cuda import cudart + +from flux_pipeline import FluxPipeline +from utilities import PIPELINE_TYPE, add_arguments, process_pipeline_args + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Options for Flux Txt2Img Demo", conflict_handler="resolve" + ) + parser = add_arguments(parser) + parser.add_argument( + "--version", + type=str, + default="flux.1-dev", + choices=["flux.1-dev"], + help="Version of Flux", + ) + parser.add_argument( + "--prompt2", + default=None, + nargs="*", + help="Text prompt(s) to be sent to the T5 tokenizer and text encoder. If not defined, prompt will be used instead", + ) + parser.add_argument( + "--height", + type=int, + default=1024, + help="Height of image to generate (must be multiple of 8)", + ) + parser.add_argument( + "--width", + type=int, + default=1024, + help="Width of image to generate (must be multiple of 8)", + ) + parser.add_argument( + "--denoising-steps", type=int, default=50, help="Number of denoising steps" + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=3.5, + help="Value of classifier-free guidance scale (must be greater than 1)", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with the prompt", + ) + return parser.parse_args() + + +def process_demo_args(args): + batch_size = args.batch_size + prompt = args.prompt + # If prompt2 is not defined, use prompt instead + prompt2 = args.prompt2 or prompt + + # Process input args + if not isinstance(prompt, list): + raise ValueError(f"`prompt` must be of type `str` list, but is {type(prompt)}") + prompt = prompt * batch_size + + if not isinstance(prompt2, list): + raise ValueError( + f"`prompt2` must be of type `str` list, but is {type(prompt2)}" + ) + if len(prompt2) == 1: + prompt2 = prompt2 * batch_size + + if args.max_sequence_length is not None and args.max_sequence_length > 512: + raise ValueError( + f"`max_sequence_length` cannot be greater than 512 but is {args.max_sequence_length}" + ) + + args_run_demo = ( + prompt, + prompt2, + args.height, + args.width, + args.batch_count, + args.num_warmup_runs, + args.use_cuda_graph, + ) + + return args_run_demo + + +if __name__ == "__main__": + print("[I] Initializing Flux txt2img demo using TensorRT") + args = parse_args() + + kwargs_init_pipeline, kwargs_load_engine, _ = process_pipeline_args(args) + args_run_demo = process_demo_args(args) + + # Initialize demo + demo = FluxPipeline( + pipeline_type=PIPELINE_TYPE.TXT2IMG, + max_sequence_length=args.max_sequence_length, + **kwargs_init_pipeline, + ) + + # Load TensorRT engines and pytorch modules + demo.load_engines( + args.engine_dir, args.framework_model_dir, args.onnx_dir, **kwargs_load_engine + ) + + # Load resources + _, shared_device_memory = cudart.cudaMalloc(demo.calculate_max_device_memory()) + demo.activate_engines(shared_device_memory) + demo.load_resources(args.height, args.width, args.batch_size, args.seed) + + # Run inference + demo.run(*args_run_demo) + + demo.teardown() diff --git a/demo/Diffusion/diffusion_pipeline.py b/demo/Diffusion/diffusion_pipeline.py new file mode 100644 index 00000000..02c7a631 --- /dev/null +++ b/demo/Diffusion/diffusion_pipeline.py @@ -0,0 +1,648 @@ + +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from cuda import cudart +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + LCMScheduler, LMSDiscreteScheduler, + PNDMScheduler, + UniPCMultistepScheduler, + DDPMWuerstchenScheduler, + FlowMatchEulerDiscreteScheduler +) +from hashlib import md5 +from models import make_scheduler, LoraLoader +import nvtx +import json +import os +import pathlib +import tensorrt as trt +import torch +from utilities import ( + PIPELINE_TYPE, + Engine, + get_refit_weights, + load_calib_prompts, + merge_loras, + unload_model, + save_image +) +from typing import Optional, List +from utils_modelopt import ( + filter_func, + quantize_lvl, + get_int8_config, + check_lora, + set_fmha, + generate_fp8_scales, + SD_FP8_FP16_DEFAULT_CONFIG, + SD_FP8_FP32_DEFAULT_CONFIG, +) + +class DiffusionPipeline(ABC): + """ + Application showcasing the acceleration of Stable Diffusion pipelines using NVidia TensorRT. + """ + VALID_DIFFUSION_PIPELINES = ( + "1.4", + "1.5", + "dreamshaper-7", + "2.0-base", + "2.0", + "2.1-base", + "2.1", + "xl-1.0", + "xl-turbo", + "svd-xt-1.1", + "sd3", + "cascade", + "flux.1-dev" + ) + SCHEDULER_DEFAULTS = { + "1.4": "PNDM", + "1.5": "PNDM", + "dreamshaper-7": "PNDM", + "2.0-base": "DDIM", + "2.0": "DDIM", + "2.1-base": "PNDM", + "2.1": "DDIM", + "xl-1.0" : "Euler", + "xl-turbo": "EulerA", + "svd-xt-1.1": "Euler", + "cascade": "DDPMWuerstchen", + "flux.1-dev": "FlowMatchEuler" + } + + def __init__( + self, + version='1.5', + pipeline_type=PIPELINE_TYPE.TXT2IMG, + max_batch_size=16, + denoising_steps=30, + scheduler=None, + lora_scale: Optional[List[int]] = None, + lora_path: Optional[List[str]] = None, + device='cuda', + output_dir='.', + hf_token=None, + verbose=False, + nvtx_profile=False, + use_cuda_graph=False, + framework_model_dir='pytorch_model', + return_latents=False, + torch_inference='', + ): + """ + Initializes the Diffusion pipeline. + + Args: + version (str): + The version of the pipeline. Should be one of the values listed in DiffusionPipeline.VALID_DIFFUSION_PIPELINES. + pipeline_type (PIPELINE_TYPE): + Task performed by the current pipeline. Should be one of PIPELINE_TYPE.__members__. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + denoising_steps (int): + The number of denoising steps. + More denoising steps usually lead to a higher quality image at the expense of slower inference. + scheduler (str): + The scheduler to guide the denoising process. Must be one of the values listed in DiffusionPipeline.SCHEDULER_DEFAULTS.values(). + lora_scale (float): + Scale of LoRA weights, default 1 (must between 0 and 1). + lora_path (str): + Path to LoRA adaptor. Ex: 'latent-consistency/lcm-lora-sdv1-5'. + device (str): + PyTorch device to run inference. Default: 'cuda'. + output_dir (str): + Output directory for log files and image artifacts. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + verbose (bool): + Enable verbose logging. + nvtx_profile (bool): + Insert NVTX profiling markers. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference. + framework_model_dir (str): + cache directory for framework checkpoints. + return_latents (bool): + Skip decoding the image and return latents instead. + torch_inference (str): + Run inference with PyTorch (using specified compilation mode) instead of TensorRT. + """ + self.denoising_steps = denoising_steps + self.max_batch_size = max_batch_size + + self.framework_model_dir = framework_model_dir + self.output_dir = output_dir + for directory in [self.framework_model_dir, self.output_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.hf_token = hf_token + self.device = device + self.verbose = verbose + self.nvtx_profile = nvtx_profile + + self.version = version + self.pipeline_type = pipeline_type + self.return_latents = return_latents + + if not scheduler: + scheduler = 'UniPC' if self.pipeline_type.is_controlnet() else self.SCHEDULER_DEFAULTS.get(version, 'DDIM') + print(f"[I] Autoselected scheduler: {scheduler}") + + scheduler_class_map = { + "DDIM" : DDIMScheduler, + "DDPM" : DDPMScheduler, + "EulerA" : EulerAncestralDiscreteScheduler, + "Euler" : EulerDiscreteScheduler, + "LCM" : LCMScheduler, + "LMSD" : LMSDiscreteScheduler, + "PNDM" : PNDMScheduler, + "UniPC" : UniPCMultistepScheduler, + "DDPMWuerstchen" : DDPMWuerstchenScheduler, + "FlowMatchEuler": FlowMatchEulerDiscreteScheduler + } + try: + scheduler_class = scheduler_class_map[scheduler] + except KeyError: + raise ValueError(f"Unsupported scheduler {scheduler}. Should be one of {list(scheduler_class.keys())}.") + self.scheduler = make_scheduler(scheduler_class, version, pipeline_type, hf_token, framework_model_dir) + + self.torch_inference = torch_inference + if self.torch_inference: + torch._inductor.config.conv_1x1_as_mm = True + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.coordinate_descent_check_all_directions = True + self.use_cuda_graph = use_cuda_graph + + # initialized in load_engines() + self.models = {} + self.torch_models = {} + self.engine = {} + self.shared_device_memory = None + + # initialize lora loader and scales + self.lora_loader = None + self.lora_scales = dict() + if lora_path: + self.lora_loader = LoraLoader(lora_path) + assert len(lora_path) == len(lora_scale) + for i, path in enumerate(lora_path): + self.lora_scales[path] = lora_scale[i] + + # initialized in load_resources() + self.events = {} + self.generator = None + self.markers = {} + self.seed = None + self.stream = None + self.tokenizer = None + + # config to store additional info + self.config = {} + + def profile_start(self, name, color='blue'): + if self.nvtx_profile: + self.markers[name] = nvtx.start_range(message=name, color=color) + if name in self.events: + cudart.cudaEventRecord(self.events[name][0], 0) + + def profile_stop(self, name): + if name in self.events: + cudart.cudaEventRecord(self.events[name][1], 0) + if self.nvtx_profile: + nvtx.end_range(self.markers[name]) + + def load_resources(self, image_height, image_width, batch_size, seed): + # Initialize noise generator + if seed is not None: + self.seed = seed + self.generator = torch.Generator(device="cuda").manual_seed(seed) + + # Create CUDA events and stream + for stage in self.stages: + self.events[stage] = [cudart.cudaEventCreate()[1], cudart.cudaEventCreate()[1]] + self.stream = cudart.cudaStreamCreate()[1] + + # Allocate TensorRT I/O buffers + if not self.torch_inference: + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + self.engine[model_name].allocate_buffers(shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.device) + + def _create_directories(self, engine_dir, onnx_dir): + # Create directories if missing + for directory in [engine_dir, onnx_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + def _cached_model_name(self, model_name): + if self.pipeline_type.is_inpaint(): + model_name += '_inpaint' + return model_name + + def _get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=''): + onnx_model_dir = os.path.join(onnx_dir, self._cached_model_name(model_name)+suffix+('.opt' if opt else '')) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, 'model.onnx') + + def _get_engine_path(self, model_name, engine_dir, enable_refit=False, suffix=''): + return os.path.join(engine_dir, self._cached_model_name(model_name)+suffix+('.refit' if enable_refit else '')+'.trt'+trt.__version__+'.plan') + + def _get_weights_map_path(self, model_name, onnx_dir): + onnx_model_dir = os.path.join(onnx_dir, self._cached_model_name(model_name)+'.opt') + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, 'weights_map.json') + + def _get_refit_nodes_path(self, model_name, onnx_dir, suffix=''): + onnx_model_dir = os.path.join(onnx_dir, self._cached_model_name(model_name)+'.opt') + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, 'refit'+suffix+'.json') + + def _get_state_dict_path(self, model_name, onnx_dir, suffix=''): + onnx_model_dir = os.path.join(onnx_dir, self._cached_model_name(model_name)+suffix) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, 'state_dict.pt') + + @abstractmethod + def _initialize_models(self): + raise NotImplementedError("Please Implement the _initialize_models method") + + def _get_lora_suffix(self): + if self.lora_loader: + return '-' + '-'.join([str(md5(path.encode('utf-8')).hexdigest()) + '-' + ('%.2f' % self.lora_scales[path]) for path in sorted(self.lora_loader.paths)]) + return '' + + def _prepare_model_configs(self, onnx_dir, engine_dir, enable_refit, int8, fp8, quantization_level, quantization_percentile, quantization_alpha, calibration_size): + model_names = self.models.keys() + lora_suffix = self._get_lora_suffix() + self.torch_fallback = dict(zip(model_names, [self.torch_inference or self.config.get(model_name.replace('-','_')+'_torch_fallback', False) for model_name in model_names])) + + configs = {} + for model_name in model_names: + config = { + 'do_engine_refit': not self.pipeline_type.is_sd_xl_refiner() and enable_refit and model_name.startswith('unet'), + 'do_lora_merge': not enable_refit and self.lora_loader and model_name.startswith('unet'), + 'use_int8': False, + 'use_fp8': False, + } + config['model_suffix'] = lora_suffix if config['do_lora_merge'] else '' + + if int8: + assert self.pipeline_type.is_sd_xl_base() or self.version in ["1.5", "2.1", "2.1-base"], "int8 quantization only supported for SDXL, SD1.5 and SD2.1 pipeline" + if model_name == ('unetxl' if self.pipeline_type.is_sd_xl() else 'unet'): + config['use_int8'] = True + config['model_suffix'] += f"-int8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" + elif fp8: + assert self.pipeline_type.is_sd_xl() or self.version in ["1.5", "2.1", "2.1-base"], "fp8 quantization only supported for SDXL, SD1.5 and SD2.1 pipeline" + if model_name == ('unetxl' if self.pipeline_type.is_sd_xl() else 'unet'): + config['use_fp8'] = True + config['model_suffix'] += f"-fp8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" + + config['onnx_path'] = self._get_onnx_path(model_name, onnx_dir, opt=False, suffix=config['model_suffix']) + config['onnx_opt_path'] = self._get_onnx_path(model_name, onnx_dir, suffix=config['model_suffix']) + config['engine_path'] = self._get_engine_path(model_name, engine_dir, config['do_engine_refit'], suffix=config['model_suffix']) + config['weights_map_path'] = self._get_weights_map_path(model_name, onnx_dir) if config['do_engine_refit'] else None + config['state_dict_path'] = self._get_state_dict_path(model_name, onnx_dir, suffix=config['model_suffix']) + config['refit_weights_path'] = self._get_refit_nodes_path(model_name, onnx_dir, suffix=lora_suffix) + + configs[model_name] = config + + return configs + + def _calibrate_and_save_model(self, pipeline, model, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size): + print(f"[I] Calibrated weights not found, generating {model_config['state_dict_path']}") + calibration_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'calibration-prompts.txt') + calibration_prompts = load_calib_prompts(calib_batch_size, calibration_file) + + # TODO check size > calibration_size + def do_calibrate(pipeline, calibration_prompts, **kwargs): + for i_th, prompts in enumerate(calibration_prompts): + if i_th >= kwargs["calib_size"]: + return + pipeline( + prompt=prompts, + num_inference_steps=kwargs["n_steps"], + negative_prompt=[ + "normal quality, low quality, worst quality, low res, blurry, nsfw, nude" + ] + * len(prompts), + ).images + + def forward_loop(model): + pipeline.unet = model + do_calibrate( + pipeline=pipeline, + calibration_prompts=calibration_prompts, + calib_size=calibration_size // calib_batch_size, + n_steps=self.denoising_steps, + ) + + print(f"[I] Performing calibration for {calibration_size} steps.") + if model_config['use_int8']: + quant_config = get_int8_config( + model, + quantization_level, + quantization_alpha, + quantization_percentile, + self.denoising_steps + ) + elif model_config['use_fp8']: + quant_config = SD_FP8_FP32_DEFAULT_CONFIG if self.version == "2.1" else SD_FP8_FP16_DEFAULT_CONFIG + check_lora(model) + mtq.quantize(model, quant_config, forward_loop) + mto.save(model, model_config['state_dict_path']) + + def _get_quantized_model(self, obj, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size): + pipeline = obj.get_pipeline() + model = pipeline.unet + if model_config['use_fp8'] and quantization_level == 4.0: + set_fmha(model) + + if not os.path.exists(model_config['state_dict_path']): + self._calibrate_and_save_model(pipeline, model, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size) + else: + mto.restore(model, model_config['state_dict_path']) + + if not os.path.exists(model_config['onnx_path']): + quantize_lvl(model, quantization_level) + mtq.disable_quantizer(model, filter_func) + if model_config['use_fp8']: + generate_fp8_scales(model) + else: + model = None + + return model + + def _export_onnx(self, obj, model_config, opt_image_height, opt_image_width, static_shape, onnx_opset, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size): + do_export_onnx = not os.path.exists(model_config['engine_path']) and not os.path.exists(model_config['onnx_opt_path']) + do_export_weights_map = model_config['weights_map_path'] and not os.path.exists(model_config['weights_map_path']) + + if do_export_onnx or do_export_weights_map: + if not model_config['use_int8'] and not model_config['use_fp8']: + obj.export_onnx(model_config['onnx_path'], model_config['onnx_opt_path'], onnx_opset, opt_image_height, opt_image_width, enable_lora_merge=model_config['do_lora_merge'], static_shape=static_shape) + else: + print(f"[I] Generating quantized ONNX model: {model_config['onnx_path']}") + quantized_model = self._get_quantized_model(obj, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size) + obj.export_onnx(model_config['onnx_path'], model_config['onnx_opt_path'], onnx_opset, opt_image_height, opt_image_width, custom_model=quantized_model, static_shape=static_shape) + + # FIXME do_export_weights_map needs ONNX graph + if do_export_weights_map: + print(f"[I] Saving weights map: {model_config['weights_map_path']}") + obj.export_weights_map(model_config['onnx_opt_path'], model_config['weights_map_path']) + + + def _build_engine(self, obj, engine, model_config, opt_batch_size, opt_image_height, opt_image_width, optimization_level, static_batch, static_shape, enable_all_tactics, timing_cache): + update_output_names = obj.get_output_names() + obj.extra_output_names if obj.extra_output_names else None + fp16amp = False if (model_config['use_fp8'] or getattr(obj, 'build_strongly_typed', False)) else obj.fp16 + tf32amp = obj.tf32 + bf16amp = False if (model_config['use_fp8'] or getattr(obj, 'build_strongly_typed', False)) else obj.bf16 + strongly_typed = True if (model_config['use_fp8'] or getattr(obj, 'build_strongly_typed', False)) else False + extra_build_args = {'verbose': self.verbose} + extra_build_args['builder_optimization_level'] = optimization_level + if model_config['use_int8']: + extra_build_args['int8'] = True + extra_build_args['precision_constraints'] = 'prefer' + engine.build(model_config['onnx_opt_path'], + strongly_typed=strongly_typed, + fp16=fp16amp, + tf32=tf32amp, + bf16=bf16amp, + input_profile=obj.get_input_profile( + opt_batch_size, opt_image_height, opt_image_width, + static_batch=static_batch, static_shape=static_shape + ), + enable_refit=model_config['do_engine_refit'], + enable_all_tactics=enable_all_tactics, + timing_cache=timing_cache, + update_output_names=update_output_names, + **extra_build_args) + + def _refit_engine(self, obj, model_name, model_config): + assert model_config['weights_map_path'] + with open(model_config['weights_map_path'], 'r') as fp_wts: + print(f"[I] Loading weights map: {model_config['weights_map_path']} ") + [weights_name_mapping, weights_shape_mapping] = json.load(fp_wts) + + if not os.path.exists(model_config['refit_weights_path']): + print(f"[I] Saving refit weights: {model_config['refit_weights_path']}") + model = merge_loras(obj.get_model(), obj.lora_dict, obj.lora_alphas, obj.lora_scales) + refit_weights = get_refit_weights(model.state_dict(), model_config['onnx_opt_path'], weights_name_mapping, weights_shape_mapping) + torch.save(refit_weights, model_config['refit_weights_path']) + unload_model(model) + else: + print(f"[I] Loading refit weights: {model_config['refit_weights_path']}") + refit_weights = torch.load(model_config['refit_weights_path']) + self.engine[model_name].refit(refit_weights, obj.fp16) + + def _load_torch_models(self): + # Load torch models + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + self.torch_models[model_name] = obj.get_model(torch_inference=self.torch_inference) + + def load_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + optimization_level=3, + static_batch=False, + static_shape=True, + enable_refit=False, + enable_all_tactics=False, + timing_cache=None, + int8=False, + fp8=False, + quantization_level=2.5, + quantization_percentile=1.0, + quantization_alpha=0.8, + calibration_size=32, + calib_batch_size=2, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. + + Args: + engine_dir (str): + Directory to store the TensorRT engines. + framework_model_dir (str): + Directory to store the framework model ckpt. + onnx_dir (str): + Directory to store the ONNX models. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + optimization_level (int): + Optimization level to build the TensorRT engine with. + static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. + enable_refit (bool): + Build engines with refit option enabled. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to speed up TensorRT build. + int8 (bool): + Whether to quantize to int8 format or not (SDXL, SD15 and SD21 only). + fp8 (bool): + Whether to quantize to fp8 format or not (SDXL, SD15 and SD21 only). + quantization_level (float): + Controls which layers to quantize. 1: CNN, 2: CNN+FFN, 2.5: CNN+FFN+QKV, 3: CNN+FC + quantization_percentile (float): + Control quantization scaling factors (amax) collecting range, where the minimum amax in + range(n_steps * percentile) will be collected. Recommendation: 1.0 + quantization_alpha (float): + The alpha parameter for SmoothQuant quantization used for linear layers. + Recommendation: 0.8 for SDXL + calibration_size (int): + The number of steps to use for calibrating the model for quantization. + Recommendation: 32, 64, 128 for SDXL + calib_batch_size (int): + The batch size to use for calibration. Defaults to 2. + """ + self._create_directories(engine_dir, onnx_dir) + self._initialize_models(framework_model_dir, int8, fp8) + + model_configs = self._prepare_model_configs(onnx_dir, engine_dir, enable_refit, int8, fp8, quantization_level, quantization_percentile, quantization_alpha, calibration_size) + + # Export models to ONNX + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + self._export_onnx(obj, model_configs[model_name], opt_image_height, opt_image_width, static_shape, onnx_opset, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size) + + # Build TensorRT engines + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + + model_config = model_configs[model_name] + engine = Engine(model_config['engine_path']) + if not os.path.exists(model_config['engine_path']): + self._build_engine(obj, engine, model_config, opt_batch_size, opt_image_height, opt_image_width, optimization_level, static_batch, static_shape, enable_all_tactics, timing_cache) + self.engine[model_name] = engine + + # Load and refit TensorRT engines + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + self.engine[model_name].load() + + model_config = model_configs[model_name] + if model_config['do_engine_refit'] and obj.lora_dict: + self._refit_engine(obj, model_name, model_config) + + # Load PyTorch models if torch-inference mode is enabled + self._load_torch_models() + + # Reclaim GPU memory from torch cache + torch.cuda.empty_cache() + + def calculate_max_device_memory(self): + max_device_memory = 0 + for model_name, engine in self.engine.items(): + max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + return max_device_memory + + def activate_engines(self, shared_device_memory=None): + if shared_device_memory is None: + max_device_memory = self.calculate_max_device_memory() + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + self.shared_device_memory = shared_device_memory + # Load and activate TensorRT engines + for engine in self.engine.values(): + engine.activate(device_memory=self.shared_device_memory) + + def run_engine(self, model_name, feed_dict): + engine = self.engine[model_name] + return engine.infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) + + def teardown(self): + for e in self.events.values(): + cudart.cudaEventDestroy(e[0]) + cudart.cudaEventDestroy(e[1]) + + for engine in self.engine.values(): + del engine + + if self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + + for torch_model in self.torch_models.values(): + del torch_model + + cudart.cudaStreamDestroy(self.stream) + del self.stream + + def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width, latents_dtype=torch.float32): + latents_dtype = latents_dtype # text_embeddings.dtype + latents_shape = (batch_size, unet_channels, latent_height, latent_width) + latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator) + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def save_image(self, images, pipeline, prompt, seed): + # Save image + prompt_prefix = ''.join(set([prompt[i].replace(' ','_')[:10] for i in range(len(prompt))])) + image_name_prefix = '-'.join([pipeline, prompt_prefix, str(seed)]) + image_name_suffix = 'torch' if self.torch_inference else 'trt' + save_image(images, self.output_dir, image_name_prefix, image_name_suffix) + + @abstractmethod + def print_summary(self): + """Print a summary of the pipeline's configuration.""" + raise NotImplementedError("Please Implement the print_summary method") + + + @abstractmethod + def infer(self): + """Perform inference using the pipeline.""" + raise NotImplementedError("Please Implement the infer method") + + + @abstractmethod + def run(self): + """Run the pipeline.""" + raise NotImplementedError("Please Implement the run method") + diff --git a/demo/Diffusion/flux_pipeline.py b/demo/Diffusion/flux_pipeline.py new file mode 100644 index 00000000..4de8e0f7 --- /dev/null +++ b/demo/Diffusion/flux_pipeline.py @@ -0,0 +1,554 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +import numpy as np +from cuda import cudart +import inspect +from models import ( + get_clip_embedding_dim, + make_tokenizer, + CLIPModel, + T5Model, + FluxTransformerModel, + VAEModel, +) +import tensorrt as trt +import time +import torch +from utilities import ( + PIPELINE_TYPE, + TRT_LOGGER, +) +from diffusion_pipeline import DiffusionPipeline + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class FluxPipeline(DiffusionPipeline): + """ + Application showcasing the acceleration of Flux pipelines using Nvidia TensorRT. + """ + + def __init__( + self, + version="flux.1-dev", + pipeline_type=PIPELINE_TYPE.TXT2IMG, + guidance_scale=3.5, + max_sequence_length=512, + **kwargs, + ): + """ + Initializes the Flux pipeline. + + Args: + guidance_scale (`float`, defaults to 3.5): + Guidance scale is enabled by setting as > 1. + Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length to use with the `prompt`. + """ + super().__init__(version=version, pipeline_type=pipeline_type, **kwargs) + self.guidance_scale = guidance_scale + self.max_sequence_length = max_sequence_length + + # Pipeline type + self.stages = ["clip", "t5", "transformer", "vae"] + + def _initialize_models(self, framework_model_dir, int8, fp8): + # Load text tokenizer(s) + self.tokenizer = make_tokenizer( + self.version, self.pipeline_type, self.hf_token, framework_model_dir + ) + self.tokenizer2 = make_tokenizer( + self.version, + self.pipeline_type, + self.hf_token, + framework_model_dir, + subfolder="tokenizer_2", + tokenizer_type="t5", + ) + + # Load pipeline models + models_args = { + "version": self.version, + "pipeline": self.pipeline_type, + "device": self.device, + "hf_token": self.hf_token, + "verbose": self.verbose, + "framework_model_dir": framework_model_dir, + "max_batch_size": self.max_batch_size, + } + + self.fp16 = True + self.tf32 = True + if "clip" in self.stages: + self.models["clip"] = CLIPModel( + **models_args, + fp16=self.fp16, + tf32=self.tf32, + embedding_dim=get_clip_embedding_dim(self.version, self.pipeline_type), + keep_pooled_output=True, + subfolder="text_encoder", + ) + + if "t5" in self.stages: + # Known accuracy issues with FP16 + self.models["t5"] = T5Model( + **models_args, + fp16=False, + tf32=self.tf32, + subfolder="text_encoder_2", + text_maxlen=self.max_sequence_length, + ) + + if "transformer" in self.stages: + self.models["transformer"] = FluxTransformerModel( + **models_args, + fp16=self.fp16, + tf32=self.tf32, + text_maxlen=self.max_sequence_length, + build_strongly_typed=False, + ) + + if "vae" in self.stages: + # Accuracy issues with FP16 + self.models["vae"] = VAEModel(**models_args, fp16=False, tf32=self.tf32) + + self.vae_scale_factor = ( + 2 ** (len(self.models["vae"].config["block_out_channels"])) + if "vae" in self.stages and self.models["vae"] is not None + else 16 + ) + + # Copied from https://github.com/huggingface/diffusers/blob/v0.30.1/src/diffusers/pipelines/flux/pipeline_flux.py#L436 + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + """ + Reshapes latents from (B, C, H, W) to (B, H/2, W/2, C*4) as expected by the denoiser + """ + latents = latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + return latents + + # Copied from https://github.com/huggingface/diffusers/blob/v0.30.1/src/diffusers/pipelines/flux/pipeline_flux.py#L444 + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + """ + Reshapes denoised latents to the format (B, C, H, W) + """ + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape( + batch_size, channels // (2 * 2), height * 2, width * 2 + ) + + return latents + + # Copied from https://github.com/huggingface/diffusers/blob/v0.30.1/src/diffusers/pipelines/flux/pipeline_flux.py#L421 + @staticmethod + def _prepare_latent_image_ids(height, width, dtype, device): + """ + Prepares latent image indices + """ + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + ) + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = ( + latent_image_ids.shape + ) + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + def initialize_latents( + self, + batch_size, + num_channels_latents, + latent_height, + latent_width, + latents_dtype=torch.float32, + ): + latents_dtype = latents_dtype # text_embeddings.dtype + latents_shape = (batch_size, num_channels_latents, latent_height, latent_width) + latents = torch.randn( + latents_shape, + device=self.device, + dtype=latents_dtype, + generator=self.generator, + ) + + latents = self._pack_latents( + latents, batch_size, num_channels_latents, latent_height, latent_width + ) + + latent_image_ids = self._prepare_latent_image_ids( + latent_height, latent_width, latents_dtype, self.device + ) + + return latents, latent_image_ids + + def encode_prompt( + self, prompt, encoder="clip", max_sequence_length=None, pooled_output=False + ): + self.profile_start(encoder, color="green") + + tokenizer = self.tokenizer2 if encoder == "t5" else self.tokenizer + max_sequence_length = ( + tokenizer.model_max_length + if max_sequence_length is None + else max_sequence_length + ) + + def tokenize(prompt, max_sequence_length): + text_input_ids = ( + tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids.type(torch.int32).to(self.device) + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, max_sequence_length - 1 : -1] + ) + warnings.warn( + "The following part of your input was truncated because `max_sequence_length` is set to " + f"{max_sequence_length} tokens: {removed_text}" + ) + + if self.torch_inference or self.torch_fallback[encoder]: + outputs = self.torch_models[encoder]( + text_input_ids, output_hidden_states=False + ) + text_encoder_output = ( + outputs[0].clone() + if pooled_output == False + else outputs.pooler_output.clone() + ) + else: + # NOTE: output tensor for the encoder must be cloned because it will be overwritten when called again for prompt2 + outputs = self.run_engine(encoder, {"input_ids": text_input_ids}) + output_name = ( + "text_embeddings" if not pooled_output else "pooled_embeddings" + ) + text_encoder_output = outputs[output_name].clone() + + return text_encoder_output + + # Tokenize prompt + text_encoder_output = tokenize(prompt, max_sequence_length) + + self.profile_stop(encoder) + return ( + text_encoder_output.to(torch.float16) if self.fp16 else text_encoder_output + ) + + def denoise_latent( + self, + latents, + timesteps, + text_embeddings, + pooled_embeddings, + text_ids, + latent_image_ids, + denoiser="transformer", + guidance=None, + ): + do_autocast = self.torch_inference != "" and self.models[denoiser].fp16 + with torch.autocast("cuda", enabled=do_autocast): + self.profile_start(denoiser, color="blue") + # handle guidance + if self.models[denoiser].config["guidance_embeds"] and guidance is None: + guidance = torch.full( + [1], self.guidance_scale, device=self.device, dtype=torch.float32 + ) + guidance = guidance.expand(latents.shape[0]) + + for step_index, timestep in enumerate(timesteps): + # prepare inputs + timestep_inp = timestep.expand(latents.shape[0]).to(latents.dtype) + params = { + "hidden_states": latents, + "timestep": timestep_inp / 1000, + "pooled_projections": pooled_embeddings, + "encoder_hidden_states": text_embeddings, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + } + if guidance is not None: + params.update({"guidance": guidance}) + + # Predict the noise residual + if self.torch_inference or self.torch_fallback[denoiser]: + noise_pred = self.torch_models[denoiser](**params)["sample"] + else: + noise_pred = self.run_engine(denoiser, params)["latent"] + + latents = self.scheduler.step( + noise_pred, timestep, latents, return_dict=False + )[0] + + self.profile_stop(denoiser) + return latents.to(dtype=torch.float32) + + def decode_latent(self, latents, decoder="vae"): + self.profile_start(decoder, color="red") + if self.torch_inference or self.torch_fallback[decoder]: + images = self.torch_models[decoder](latents, return_dict=False)[0] + else: + images = self.run_engine(decoder, {"latent": latents})["images"] + self.profile_stop(decoder) + return images + + def print_summary(self, denoising_steps, walltime_ms, batch_size): + print("|-----------------|--------------|") + print("| {:^15} | {:^12} |".format("Module", "Latency")) + print("|-----------------|--------------|") + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP", + cudart.cudaEventElapsedTime( + self.events["clip"][0], self.events["clip"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "T5", + cudart.cudaEventElapsedTime(self.events["t5"][0], self.events["t5"][1])[ + 1 + ], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "Transformer x " + str(denoising_steps), + cudart.cudaEventElapsedTime( + self.events["transformer"][0], self.events["transformer"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE-Dec", + cudart.cudaEventElapsedTime( + self.events["vae"][0], self.events["vae"][1] + )[1], + ) + ) + print("|-----------------|--------------|") + print("| {:^15} | {:>9.2f} ms |".format("Pipeline", walltime_ms)) + print("|-----------------|--------------|") + print("Throughput: {:.2f} image/s".format(batch_size * 1000.0 / walltime_ms)) + + def infer( + self, + prompt, + prompt2, + image_height, + image_width, + warmup=False, + save_image=True, + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + prompt2 (str): + The prompt to be sent to the T5 tokenizer and text encoder + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + warmup (bool): + Indicate if this is a warmup run. + save_image (bool): + Save the generated image (if applicable) + """ + assert len(prompt) == len(prompt2) + batch_size = len(prompt) + + # Spatial dimensions of latent tensor + latent_height = 2 * (int(image_height) // self.vae_scale_factor) + latent_width = 2 * (int(image_width) // self.vae_scale_factor) + + num_inference_steps = self.denoising_steps + + with torch.inference_mode(), trt.Runtime(TRT_LOGGER): + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # Initialize latents + latents, latent_image_ids = self.initialize_latents( + batch_size=batch_size, + num_channels_latents=self.models["transformer"].config["in_channels"] + // 4, + latent_height=latent_height, + latent_width=latent_width, + latents_dtype=torch.float16 if self.fp16 else torch.float32, + ) + + # CLIP and T5 text encoder(s) + pooled_embeddings = self.encode_prompt(prompt, pooled_output=True) + text_embeddings = self.encode_prompt( + prompt2, encoder="t5", max_sequence_length=self.max_sequence_length + ) + text_ids = torch.zeros(text_embeddings.shape[1], 3).to( + device=self.device, dtype=text_embeddings.dtype + ) + + # Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps = None + # TODO: support custom timesteps + if timesteps is not None: + if ( + "timesteps" + not in inspect.signature(self.scheduler.set_timesteps).parameters + ): + raise ValueError( + f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + self.scheduler.set_timesteps(timesteps=timesteps, device=self.device) + assert self.denoising_steps == len(self.scheduler.timesteps) + else: + self.scheduler.set_timesteps(sigmas=sigmas, mu=mu, device=self.device) + timesteps = self.scheduler.timesteps.to(self.device) + num_inference_steps = len(timesteps) + + # DiT denoiser + latents = self.denoise_latent( + latents, + timesteps, + text_embeddings, + pooled_embeddings, + text_ids, + latent_image_ids, + ) + + # VAE decode latent + latents = self._unpack_latents( + latents, image_height, image_width, self.vae_scale_factor + ) + latents = ( + latents / self.models["vae"].config["scaling_factor"] + ) + self.models["vae"].config["shift_factor"] + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + walltime_ms = (e2e_toc - e2e_tic) * 1000.0 + if not warmup: + self.print_summary(num_inference_steps, walltime_ms, batch_size) + if not self.return_latents and save_image: + # post-process images + images = ( + ((images + 1) * 255 / 2) + .clamp(0, 255) + .detach() + .permute(0, 2, 3, 1) + .round() + .type(torch.uint8) + .cpu() + .numpy() + ) + self.save_image( + images, self.pipeline_type.name.lower(), prompt, self.seed + ) + + return (latents, walltime_ms) if self.return_latents else (images, walltime_ms) + + def run( + self, + prompt, + prompt2, + height, + width, + batch_count, + num_warmup_runs, + use_cuda_graph, + **kwargs, + ): + num_warmup_runs = max(1, num_warmup_runs) if use_cuda_graph else num_warmup_runs + if num_warmup_runs > 0: + print("[I] Warming up ..") + for _ in range(num_warmup_runs): + self.infer(prompt, prompt2, height, width, warmup=True, **kwargs) + + for _ in range(batch_count): + print("[I] Running Flux pipeline") + if self.nvtx_profile: + cudart.cudaProfilerStart() + self.infer(prompt, prompt2, height, width, warmup=False, **kwargs) + if self.nvtx_profile: + cudart.cudaProfilerStop() diff --git a/demo/Diffusion/models.py b/demo/Diffusion/models.py old mode 100644 new mode 100755 index cdd2300c..241ca821 --- a/demo/Diffusion/models.py +++ b/demo/Diffusion/models.py @@ -17,14 +17,6 @@ from diffusers import DiffusionPipeline from diffusers.loaders import LoraLoaderMixin -from diffusers.models import ( - AutoencoderKL, - AutoencoderKLTemporalDecoder, - ControlNetModel, - UNet2DConditionModel, - UNetSpatioTemporalConditionModel, - StableCascadeUNet -) from diffusers.pipelines.wuerstchen import PaellaVQModel import json import numpy as np @@ -39,14 +31,17 @@ import torch import torch.nn.functional as F from transformers import ( + AutoConfig, CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast ) from huggingface_hub import hf_hub_download -from utilities import merge_loras +from utilities import merge_loras, import_from_diffusers from utils_sd3.sd3_impls import BaseModel as BaseModelSD3 from utils_sd3.sd3_impls import SDVAE from utils_sd3.other_impls import load_into, SDClipModel, SDXLClipG, T5XXLModel @@ -58,6 +53,19 @@ ) from onnxmltools.utils.float16_converter import convert_float_to_float16 +# List of models to import from diffusers.models +models_to_import = [ + 'AutoencoderKL', + 'AutoencoderKLTemporalDecoder', + 'ControlNetModel', + 'UNet2DConditionModel', + 'UNetSpatioTemporalConditionModel', + 'StableCascadeUNet', + 'FluxTransformer2DModel' +] +for model in models_to_import: + globals()[model] = import_from_diffusers(model, 'diffusers.models') + class Optimizer(): def __init__( self, @@ -201,7 +209,7 @@ def get_path(version, pipeline, controlnets=None): elif version == "1.4": return "CompVis/stable-diffusion-v1-4" elif version == "1.5": - return "benjamin-paine/stable-diffusion-v1-5" + return "KiwiXR/stable-diffusion-v1-5" elif version == 'dreamshaper-7': return 'Lykon/dreamshaper-7' elif version in ("2.0-base", "2.0") and pipeline.is_inpaint(): @@ -230,11 +238,13 @@ def get_path(version, pipeline, controlnets=None): return "stabilityai/stable-cascade" else: return "stabilityai/stable-cascade-prior" + elif version == 'flux.1-dev': + return "black-forest-labs/FLUX.1-dev" else: raise ValueError(f"Unsupported version {version} + pipeline {pipeline.name}") def get_clip_embedding_dim(version, pipeline): - if version in ("1.4", "1.5", "dreamshaper-7"): + if version in ("1.4", "1.5", "dreamshaper-7", "flux.1-dev"): return 768 elif version in ("2.0", "2.0-base", "2.1", "2.1-base"): return 1024 @@ -335,6 +345,7 @@ def __init__(self, verbose=True, framework_model_dir='pytorch_model', fp16=False, + tf32=False, bf16=False, int8=False, fp8=False, @@ -355,6 +366,7 @@ def __init__(self, self.framework_model_dir = framework_model_dir self.fp16 = fp16 + self.tf32 = tf32 self.bf16 = bf16 self.int8 = int8 self.fp8 = fp8 @@ -363,7 +375,7 @@ def __init__(self, self.min_batch = 1 self.max_batch = max_batch_size self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.max_image_shape = 1344 # max image resolution: 1344x1344 self.min_latent_shape = self.min_image_shape // self.compression_factor self.max_latent_shape = self.max_image_shape // self.compression_factor @@ -380,7 +392,7 @@ def get_pipeline(self): return DiffusionPipeline.from_pretrained( self.path, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token, + token=self.hf_token, **model_opts, ).to(self.device) @@ -534,7 +546,6 @@ def optimize(self, onnx_graph, return_onnx=True, **kwargs): def check_dims(self, batch_size, image_height, image_width): assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 latent_height = image_height // self.compression_factor latent_width = image_width // self.compression_factor assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape @@ -568,31 +579,36 @@ def __init__(self, max_batch_size, embedding_dim, fp16=False, + tf32=False, bf16=False, output_hidden_states=False, + keep_pooled_output=False, subfolder="text_encoder", lora_dict=None, lora_alphas=None, ): - super(CLIPModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, bf16=bf16, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + super(CLIPModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, bf16=bf16, max_batch_size=max_batch_size, embedding_dim=embedding_dim) self.subfolder = subfolder self.hidden_layer_offset = 0 if pipeline.is_cascade() else -1 + self.keep_pooled_output = keep_pooled_output # Output the final hidden state if output_hidden_states: self.extra_output_names = ['hidden_states'] def get_model(self, torch_inference=''): + model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} clip_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) if not os.path.exists(clip_model_dir): model = CLIPTextModel.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token).to(self.device) - model.save_pretrained(clip_model_dir) + token=self.hf_token, + **model_opts).to(self.device) + model.save_pretrained(clip_model_dir, **model_opts) else: print(f"[I] Load CLIPTextModel model from: {clip_model_dir}") - model = CLIPTextModel.from_pretrained(clip_model_dir).to(self.device) + model = CLIPTextModel.from_pretrained(clip_model_dir, **model_opts).to(self.device) model = optimize_checkpoint(model, torch_inference) return model @@ -600,13 +616,19 @@ def get_input_names(self): return ['input_ids'] def get_output_names(self): - return ['text_embeddings'] + output_names = ['text_embeddings'] + if self.keep_pooled_output: + output_names += ['pooled_embeddings'] + return output_names def get_dynamic_axes(self): - return { + dynamic_axes = { 'input_ids': {0: 'B'}, - 'text_embeddings': {0: 'B'} + 'text_embeddings': {0: 'B'}, } + if self.keep_pooled_output: + dynamic_axes['pooled_embeddings'] = {0: 'B'} + return dynamic_axes def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): self.check_dims(batch_size, image_height, image_width) @@ -619,10 +641,12 @@ def get_shape_dict(self, batch_size, image_height, image_width): self.check_dims(batch_size, image_height, image_width) output = { 'input_ids': (batch_size, self.text_maxlen), - 'text_embeddings': (batch_size, self.text_maxlen, self.embedding_dim) + 'text_embeddings': (batch_size, self.text_maxlen, self.embedding_dim), } + if self.keep_pooled_output: + output['pooled_embeddings'] = (batch_size, self.embedding_dim) if 'hidden_states' in self.extra_output_names: - output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + output['hidden_states'] = (batch_size, self.text_maxlen, self.embedding_dim) return output def get_sample_input(self, batch_size, image_height, image_width, static_shape): @@ -632,15 +656,15 @@ def get_sample_input(self, batch_size, image_height, image_width, static_shape): def optimize(self, onnx_graph): opt = Optimizer(onnx_graph, verbose=self.verbose) opt.info(self.name + ': original') - opt.select_outputs([0]) # delete graph output#1 + keep_outputs = [0, 1] if self.keep_pooled_output else [0] + opt.select_outputs(keep_outputs) opt.cleanup() - opt.info(self.name + ': remove output[1]') opt.fold_constants() opt.info(self.name + ': fold constants') opt.infer_shapes() opt.info(self.name + ': shape inference') - opt.select_outputs([0], names=['text_embeddings']) # rename network output - opt.info(self.name + ': remove output[0]') + opt.select_outputs(keep_outputs, names=self.get_output_names()) # rename network outputs + opt.info(self.name + ': rename network output(s)') opt_onnx_graph = opt.cleanup(return_onnx=True) if 'hidden_states' in self.extra_output_names: opt_onnx_graph = opt.clip_add_hidden_states(self.hidden_layer_offset, return_onnx=True) @@ -677,7 +701,7 @@ def get_model(self, torch_inference=''): model = CLIPTextModelWithProjection.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token, + token=self.hf_token, **model_opts).to(self.device) model.save_pretrained(clip_model_dir, **model_opts) else: @@ -725,6 +749,72 @@ def get_sample_input(self, batch_size, image_height, image_width, static_shape): torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) ) +class T5Model(BaseModel): + def __init__(self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size, + fp16=False, + tf32=False, + bf16=False, + subfolder="text_encoder", + text_maxlen=512, + ): + super(T5Model, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, bf16=bf16, max_batch_size=max_batch_size, text_maxlen=text_maxlen) + self.subfolder = subfolder + self.config = AutoConfig.from_pretrained(self.path, subfolder=self.subfolder, token=self.hf_token) + + def get_model(self, torch_inference=''): + model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} + t5_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) + if not os.path.exists(t5_model_dir): + model = T5EncoderModel.from_pretrained(self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts).to(self.device) + model.save_pretrained(t5_model_dir, **model_opts) + else: + print(f"[I] Load T5EncoderModel model from: {t5_model_dir}") + model = T5EncoderModel.from_pretrained(t5_model_dir, **model_opts).to(self.device) + model = optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ['input_ids'] + + def get_output_names(self): + return ['text_embeddings'] + + def get_dynamic_axes(self): + return { + 'input_ids': {0: 'B'}, + 'text_embeddings': {0: 'B'} + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + 'input_ids': [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + 'input_ids': (batch_size, self.text_maxlen), + 'text_embeddings': (batch_size, self.text_maxlen, self.config.d_model) + } + return output + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + class SD3_CLIPGModel(CLIPModel): def __init__(self, version, @@ -897,7 +987,7 @@ def get_model(self, torch_inference=''): model = CLIPVisionModelWithProjection.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token).to(self.device) + token=self.hf_token).to(self.device) model.save_pretrained(clip_model_dir) else: print(f"[I] Load CLIPVisionModelWithProjection model from: {clip_model_dir}") @@ -928,7 +1018,7 @@ def get_model(self, torch_inference=''): model = CLIPImageProcessor.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token) + token=self.hf_token) model.save_pretrained(clip_model_dir) else: print(f"[I] Load CLIPImageProcessor model from: {clip_model_dir}") @@ -1014,7 +1104,7 @@ def get_model(self, torch_inference=''): unet_model = UNet2DConditionModel.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token, + token=self.hf_token, **model_opts).to(self.device) cnet_model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} controlnets = torch.nn.ModuleList([ControlNetModel.from_pretrained(path, **cnet_model_opts).to(self.device) for path in self.controlnets]) @@ -1027,7 +1117,7 @@ def get_model(self, torch_inference=''): model = UNet2DConditionModel.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token, + token=self.hf_token, **model_opts).to(self.device) model.save_pretrained(unet_model_dir, **model_opts) else: @@ -1173,7 +1263,7 @@ def get_model(self, torch_inference=''): model = UNet2DConditionModel.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token, + token=self.hf_token, **model_opts).to(self.device) # Use default attention processor for ONNX export if not torch_inference: @@ -1358,14 +1448,14 @@ def __init__(self, self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier def get_model(self, torch_inference=''): - model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if self.fp16 else {} + model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} unet_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) unet_path = self.get_model_path(unet_model_dir, model_opts) if not os.path.exists(unet_path): model = UNetSpatioTemporalConditionModel.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token, + token=self.hf_token, **model_opts).to(self.device) model.save_pretrained(unet_model_dir, **model_opts) else: @@ -1463,7 +1553,7 @@ def get_model(self, torch_inference=''): model = StableCascadeUNet.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token, + token=self.hf_token, **model_opts).to(self.device) model.save_pretrained(unet_model_dir, **model_opts) else: @@ -1566,6 +1656,106 @@ def get_sample_input(self, batch_size, image_height, image_width, static_shape): } ) +class FluxTransformerModel(BaseModel): + def __init__(self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16 = False, + tf32=False, + int8 = False, + fp8 = False, + max_batch_size = 16, + text_maxlen = 77, + build_strongly_typed=False + ): + + super(FluxTransformerModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, int8=int8, fp8=fp8, max_batch_size=max_batch_size, text_maxlen=text_maxlen) + self.subfolder = 'transformer' + self.config = FluxTransformer2DModel.load_config(self.path, subfolder=self.subfolder, token=self.hf_token) + self.build_strongly_typed = build_strongly_typed + + def get_model(self, torch_inference=''): + model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} + transformer_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) + transformer_path = self.get_model_path(transformer_model_dir, model_opts) + if not os.path.exists(transformer_path): + model = FluxTransformer2DModel.from_pretrained(self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts).to(self.device) + model.save_pretrained(transformer_model_dir, **model_opts) + else: + print(f"[I] Load FluxTransformer2DModel model from: {transformer_path}") + model = FluxTransformer2DModel.from_pretrained(transformer_model_dir, **model_opts).to(self.device) + if torch_inference: + model.to(memory_format=torch.channels_last) + model = optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ['hidden_states', 'encoder_hidden_states', 'pooled_projections', 'timestep', 'img_ids', 'txt_ids', 'guidance'] + + def get_output_names(self): + return ['latent'] + + def get_dynamic_axes(self): + return { + 'hidden_states': {0: 'B', 1: 'latent_dim'}, + 'encoder_hidden_states': {0: 'B'}, + 'pooled_projections': {0: 'B'}, + 'timestep': {0: 'B'}, + 'img_ids': {0: 'latent_dim'}, + 'guidance': {0: 'B'}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \ + self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + 'hidden_states': [(min_batch, (min_latent_height // 2) * (min_latent_width // 2), self.config['in_channels']), (batch_size, (latent_height // 2) * (latent_width // 2), self.config['in_channels']), (max_batch, (max_latent_height // 2) * (max_latent_width // 2), self.config['in_channels'])], + 'encoder_hidden_states': [(min_batch, self.text_maxlen, self.config['joint_attention_dim']), (batch_size, self.text_maxlen, self.config['joint_attention_dim']), (max_batch, self.text_maxlen, self.config['joint_attention_dim'])], + 'pooled_projections': [(min_batch, self.config['pooled_projection_dim']), (batch_size, self.config['pooled_projection_dim']), (max_batch, self.config['pooled_projection_dim'])], + 'timestep': [(min_batch,), (batch_size,), (max_batch,)], + 'img_ids': [((min_latent_height // 2) * (min_latent_width // 2), 3), ((latent_height // 2) * (latent_width // 2), 3), ((max_latent_height // 2) * (max_latent_width // 2), 3)], + 'txt_ids': [(self.text_maxlen, 3), (self.text_maxlen, 3), (self.text_maxlen, 3)], + 'guidance': [(min_batch,), (batch_size,), (max_batch,)], + } + + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + 'hidden_states': (batch_size, (latent_height // 2) * (latent_width // 2), self.config['in_channels']), + 'encoder_hidden_states': (batch_size, self.text_maxlen, self.config['joint_attention_dim']), + 'pooled_projections': (batch_size, self.config['pooled_projection_dim']), + 'timestep': (batch_size,), + 'img_ids': ((latent_height // 2) * (latent_width // 2), 3), + 'txt_ids': (self.text_maxlen, 3), + 'latent': (batch_size, (latent_height // 2) * (latent_width // 2), self.config['in_channels']), + 'guidance': (batch_size,), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn(batch_size, (latent_height // 2) * (latent_width // 2), self.config['in_channels'], dtype=dtype, device=self.device), + torch.randn(batch_size, self.text_maxlen, self.config['joint_attention_dim'], dtype=dtype, device=self.device), + torch.randn(batch_size, self.config['pooled_projection_dim'], dtype=dtype, device=self.device), + torch.tensor([1.]*batch_size, dtype=dtype, device=self.device), + torch.randn((latent_height // 2) * (latent_width // 2), 3, dtype=dtype, device=self.device), + torch.randn(self.text_maxlen, 3, dtype=dtype, device=self.device), + { + 'guidance': torch.tensor([1.]*batch_size, dtype=dtype, device=self.device), + } + ) + class VAEModel(BaseModel): def __init__(self, version, @@ -1575,10 +1765,12 @@ def __init__(self, verbose, framework_model_dir, fp16=False, + tf32=False, max_batch_size=16, ): - super(VAEModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, max_batch_size=max_batch_size) + super(VAEModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, max_batch_size=max_batch_size) self.subfolder = 'vae' + self.config = AutoencoderKL.load_config(self.path, subfolder=self.subfolder, token=self.hf_token) def get_model(self, torch_inference=''): vae_decoder_model_path = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) @@ -1586,7 +1778,7 @@ def get_model(self, torch_inference=''): model = AutoencoderKL.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token).to(self.device) + token=self.hf_token).to(self.device) model.save_pretrained(vae_decoder_model_path) else: print(f"[I] Load AutoencoderKL (decoder) model from: {vae_decoder_model_path}") @@ -1612,21 +1804,24 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \ self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) return { - 'latent': [(min_batch, 4, min_latent_height, min_latent_width), (batch_size, 4, latent_height, latent_width), (max_batch, 4, max_latent_height, max_latent_width)] + 'latent': [(min_batch, self.config['latent_channels'], min_latent_height, min_latent_width), + (batch_size, self.config['latent_channels'], latent_height, latent_width), + (max_batch, self.config['latent_channels'], max_latent_height, max_latent_width) + ] } def get_shape_dict(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) return { - 'latent': (batch_size, 4, latent_height, latent_width), + 'latent': (batch_size, self.config['latent_channels'], latent_height, latent_width), 'images': (batch_size, 3, image_height, image_width) } def get_sample_input(self, batch_size, image_height, image_width, static_shape): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) + return torch.randn(batch_size, self.config['latent_channels'], latent_height, latent_width, dtype=torch.float32, device=self.device) -class SD3_VAEDecoderModel(VAEModel): +class SD3_VAEDecoderModel(BaseModel): def __init__(self, version, pipeline, @@ -1657,6 +1852,18 @@ def get_model(self, torch_inference=''): model = optimize_checkpoint(model, torch_inference) return model + def get_input_names(self): + return ['latent'] + + def get_output_names(self): + return ['images'] + + def get_dynamic_axes(self): + return { + 'latent': {0: 'B', 2: 'H', 3: 'W'}, + 'images': {0: 'B', 2: '8H', 3: '8W'} + } + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \ @@ -1698,7 +1905,7 @@ def get_model(self, torch_inference=''): model = AutoencoderKLTemporalDecoder.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token).to(self.device) + token=self.hf_token).to(self.device) model.save_pretrained(vae_decoder_model_path) else: print(f"[I] Load AutoencoderKLTemporalDecoder model from: {vae_decoder_model_path}") @@ -1755,7 +1962,7 @@ def __init__(self, version, pipeline, hf_token, device, path, framework_model_di self.vae_encoder = AutoencoderKL.from_pretrained(path, subfolder='vae', use_safetensors=hf_safetensor, - use_auth_token=hf_token).to(device) + token=hf_token).to(device) self.vae_encoder.save_pretrained(vae_encoder_model_dir) else: print(f"[I] Load AutoencoderKL (encoder) model from: {vae_encoder_model_dir}") @@ -1894,7 +2101,7 @@ def get_model(self, torch_inference=''): model = PaellaVQModel.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - use_auth_token=self.hf_token, + token=self.hf_token, **model_opts).to(self.device) model.save_pretrained(vqgan_model_dir, **model_opts) else: @@ -1950,15 +2157,31 @@ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, s max_latent_width = int(max_latent_width * self.latent_dim_scale) return (min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width) -def make_tokenizer(version, pipeline, hf_token, framework_model_dir, subfolder="tokenizer", **kwargs): +def make_tokenizer(version, pipeline, hf_token, framework_model_dir, subfolder="tokenizer", tokenizer_type="clip"): + if tokenizer_type == "clip": + tokenizer_class = CLIPTokenizer + elif tokenizer_type == "t5": + tokenizer_class = T5TokenizerFast + else: + raise ValueError(f"Unsupported tokenizer_type {tokenizer_type}. Only tokenizer_type clip and t5 are currently supported") tokenizer_model_dir = get_checkpoint_dir(framework_model_dir, version, pipeline.name, subfolder) if not os.path.exists(tokenizer_model_dir): - model = CLIPTokenizer.from_pretrained(get_path(version, pipeline), + model = tokenizer_class.from_pretrained(get_path(version, pipeline), subfolder=subfolder, use_safetensors=pipeline.is_sd_xl(), - use_auth_token=hf_token) + token=hf_token) model.save_pretrained(tokenizer_model_dir) else: - print(f"[I] Load CLIPTokenizer model from: {tokenizer_model_dir}") - model = CLIPTokenizer.from_pretrained(tokenizer_model_dir) + print(f"[I] Load {tokenizer_class.__name__} model from: {tokenizer_model_dir}") + model = tokenizer_class.from_pretrained(tokenizer_model_dir) return model + +def make_scheduler(cls, version, pipeline, hf_token, framework_model_dir, subfolder="scheduler"): + scheduler_dir = os.path.join(framework_model_dir, version, pipeline.name, next(iter({cls.__name__})).lower(), subfolder) + if not os.path.exists(scheduler_dir): + scheduler = cls.from_pretrained(get_path(version, pipeline), subfolder=subfolder, token=hf_token) + scheduler.save_pretrained(scheduler_dir) + else: + print(f"[I] Load Scheduler {cls.__name__} from: {scheduler_dir}") + scheduler = cls.from_pretrained(scheduler_dir) + return scheduler diff --git a/demo/Diffusion/requirements.txt b/demo/Diffusion/requirements.txt index 280fe886..4e494185 100755 --- a/demo/Diffusion/requirements.txt +++ b/demo/Diffusion/requirements.txt @@ -1,7 +1,9 @@ +accelerate colored controlnet_aux==0.0.6 cuda-python -diffusers==0.29.2 +# TODO: Pin Diffusers version after the next release +git+https://github.com/huggingface/diffusers.git # Install from source for the latest changes in main ftfy matplotlib nvtx @@ -9,7 +11,7 @@ onnx==1.15.0 onnxruntime==1.17.3 opencv-python==4.8.0.74 scipy -transformers==4.36.2 +transformers==4.42.2 --extra-index-url https://pypi.nvidia.com nvidia-modelopt[torch,onnx]==0.15.1 onnx-graphsurgeon diff --git a/demo/Diffusion/stable_diffusion_pipeline.py b/demo/Diffusion/stable_diffusion_pipeline.py index a9500f9d..332b5572 100644 --- a/demo/Diffusion/stable_diffusion_pipeline.py +++ b/demo/Diffusion/stable_diffusion_pipeline.py @@ -41,6 +41,7 @@ UNetXLModel, VAEModel, VAEEncoderModel, + make_scheduler ) import numpy as np import nvtx @@ -74,6 +75,19 @@ ) class StableDiffusionPipeline: + SCHEDULER_DEFAULTS = { + "1.4": "PNDM", + "1.5": "PNDM", + "dreamshaper-7": "PNDM", + "2.0-base": "DDIM", + "2.0": "DDIM", + "2.1-base": "PNDM", + "2.1": "DDIM", + "xl-1.0" : "Euler", + "xl-turbo": "EulerA", + "svd-xt-1.1": "Euler", + "cascade": "DDPMWuerstchen" + } """ Application showcasing the acceleration of Stable Diffusion pipelines using NVidia TensorRT. """ @@ -185,48 +199,26 @@ def __init__( raise ValueError(f"Unsupported pipeline {self.pipeline_type.name}.") self.return_latents = return_latents - # Schedulers - map_version_scheduler = { - '1.4': 'PNDM', - '1.5': 'PNDM', - 'dreamshaper-7': 'PNDM', - '2.0-base': 'DDIM', - '2.0': 'DDIM', - '2.1-base': 'PNDM', - '2.1': 'DDIM', - 'xl-1.0' : 'Euler', - 'xl-turbo': 'EulerA', - 'svd-xt-1.1': 'Euler', - 'cascade': 'DDPMWuerstchen' - } - if not scheduler: - scheduler = 'UniPC' if self.pipeline_type.is_controlnet() else map_version_scheduler.get(version, 'DDIM') + scheduler = 'UniPC' if self.pipeline_type.is_controlnet() else self.SCHEDULER_DEFAULTS.get(version, 'DDIM') print(f"[I] Autoselected scheduler: {scheduler}") - def makeScheduler(cls, subfolder="scheduler", **kwargs): - return cls.from_pretrained(get_path(self.version, self.pipeline_type), subfolder=subfolder) - - if scheduler == "DDIM": - self.scheduler = makeScheduler(DDIMScheduler) - elif scheduler == "DDPM": - self.scheduler = makeScheduler(DDPMScheduler) - elif scheduler == "EulerA": - self.scheduler = makeScheduler(EulerAncestralDiscreteScheduler) - elif scheduler == "Euler": - self.scheduler = makeScheduler(EulerDiscreteScheduler) - elif scheduler == "LCM": - self.scheduler = makeScheduler(LCMScheduler) - elif scheduler == "LMSD": - self.scheduler = makeScheduler(LMSDiscreteScheduler) - elif scheduler == "PNDM": - self.scheduler = makeScheduler(PNDMScheduler) - elif scheduler == "UniPC": - self.scheduler = makeScheduler(UniPCMultistepScheduler) - elif scheduler == "DDPMWuerstchen": - self.scheduler = makeScheduler(DDPMWuerstchenScheduler) - else: - raise ValueError(f"Unsupported scheduler {scheduler}. Should be either DDIM, DDPM, EulerA, Euler, LCM, LMSD, PNDM, or UniPC.") + scheduler_class_map = { + "DDIM" : DDIMScheduler, + "DDPM" : DDPMScheduler, + "EulerA" : EulerAncestralDiscreteScheduler, + "Euler" : EulerDiscreteScheduler, + "LCM" : LCMScheduler, + "LMSD" : LMSDiscreteScheduler, + "PNDM" : PNDMScheduler, + "UniPC" : UniPCMultistepScheduler, + "DDPMWuerstchen" : DDPMWuerstchenScheduler, + } + try: + scheduler_class = scheduler_class_map[scheduler] + except KeyError: + raise ValueError(f"Unsupported scheduler {scheduler}. Should be one of {list(scheduler_class.keys())}.") + self.scheduler = make_scheduler(scheduler_class, version, pipeline_type, hf_token, framework_model_dir) self.config = {} if self.pipeline_type.is_sd_xl(): @@ -375,6 +367,7 @@ def loadEngines( opt_batch_size, opt_image_height, opt_image_width, + optimization_level=3, static_batch=False, static_shape=True, enable_refit=False, @@ -407,6 +400,8 @@ def loadEngines( Image height to optimize for during engine building. Must be a multiple of 8. opt_image_width (int): Image width to optimize for during engine building. Must be a multiple of 8. + optimization_level (int): + Optimization level to build the TensorRT engine with. static_batch (bool): Build engine only for specified opt_batch_size. static_shape (bool): @@ -424,7 +419,7 @@ def loadEngines( quantization_level (float): Controls which layers to quantize. 1: CNN, 2: CNN+FFN, 2.5: CNN+FFN+QKV, 3: CNN+FC quantization_percentile (float): - Control quantization scaling factors (amax) collecting range, where the minimum amax in + Control quantization scaling factors (amax) collecting range, where the minimum amax in range(n_steps * percentile) will be collected. Recommendation: 1.0 quantization_alpha (float): The alpha parameter for SmoothQuant quantization used for linear layers. @@ -456,12 +451,12 @@ def loadEngines( use_int8 = dict.fromkeys(model_names, False) use_fp8 = dict.fromkeys(model_names, False) if int8: - assert self.pipeline_type.is_sd_xl_base() or self.version in ["1.5", "2.1", "2.1-base"], "int8 quantization only supported for SDXL, SD1.5 and SD2.1 pipeline" + assert self.pipeline_type.is_sd_xl_base() or self.version in ["1.4", "1.5", "2.1", "2.1-base"], "int8 quantization only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipeline" model_name = 'unetxl' if self.pipeline_type.is_sd_xl() else 'unet' use_int8[model_name] = True - model_suffix[model_name] += f"-int8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" + model_suffix[model_name] += f"-int8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" elif fp8: - assert self.pipeline_type.is_sd_xl() or self.version in ["1.5", "2.1", "2.1-base"], "fp8 quantization only supported for SDXL, SD1.5 and SD2.1 pipeline" + assert self.pipeline_type.is_sd_xl() or self.version in ["1.4", "1.5", "2.1", "2.1-base"], "fp8 quantization only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipeline" model_name = 'unetxl' if self.pipeline_type.is_sd_xl() else 'unet' use_fp8[model_name] = True model_suffix[model_name] += f"-fp8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" @@ -513,7 +508,7 @@ def forward_loop(model): calib_size=calibration_size // calib_batch_size, n_steps=self.denoising_steps, ) - + print(f"[I] Performing calibration for {calibration_size} steps.") if use_int8[model_name]: quant_config = get_int8_config( @@ -524,8 +519,8 @@ def forward_loop(model): self.denoising_steps ) elif use_fp8[model_name]: - check_lora(model) quant_config = SD_FP8_FP32_DEFAULT_CONFIG if self.version == "2.1" else SD_FP8_FP16_DEFAULT_CONFIG + check_lora(model) mtq.quantize(model, quant_config, forward_loop) mto.save(model, state_dict_path) else: @@ -538,7 +533,7 @@ def forward_loop(model): if use_fp8[model_name]: generate_fp8_scales(model) else: - model = None + model = None obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width, custom_model=model, static_shape=static_shape) # FIXME do_export_weights_map needs ONNX graph @@ -557,12 +552,10 @@ def forward_loop(model): bf16amp = obj.bf16 if not use_fp8[model_name] else False strongly_typed = False if not use_fp8[model_name] else True extra_build_args = {'verbose': self.verbose} + extra_build_args['builder_optimization_level'] = optimization_level if use_int8[model_name]: extra_build_args['int8'] = True extra_build_args['precision_constraints'] = 'prefer' - extra_build_args['builder_optimization_level'] = 4 - elif use_fp8[model_name]: - extra_build_args['builder_optimization_level'] = 4 engine.build(onnx_opt_path[model_name], strongly_typed=strongly_typed, fp16=fp16amp, @@ -845,7 +838,7 @@ def encode_image(self, input_image): def decode_latent(self, latents): self.profile_start('vae', color='red') if self.torch_inference: - images = self.torch_models['vae'](latents)['sample'] + images = self.torch_models['vae'](latents, return_dict=False)[0] else: images = self.runEngine('vae', {'latent': latents})['images'] self.profile_stop('vae') diff --git a/demo/Diffusion/utilities.py b/demo/Diffusion/utilities.py old mode 100644 new mode 100755 index 911139b6..92c9c089 --- a/demo/Diffusion/utilities.py +++ b/demo/Diffusion/utilities.py @@ -16,6 +16,8 @@ # limitations under the License. # +import warnings +from importlib import import_module from collections import OrderedDict from cuda import cudart from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear @@ -61,6 +63,19 @@ def GiB(val): trt.DataType.BF16 : torch.bfloat16 } +# Define valid optimization levels for TensorRT engine build +VALID_OPTIMIZATION_LEVELS = list(range(6)) + +def import_from_diffusers(model_name, module_name): + try: + module = import_module(module_name) + return getattr(module, model_name) + except ImportError: + warnings.warn(f"Failed to import {module_name}. The {model_name} model will not be available.", ImportWarning) + except AttributeError: + warnings.warn(f"The {model_name} model is not available in the installed version of diffusers.", ImportWarning) + return None + def unload_model(model): if model: del model @@ -266,8 +281,8 @@ def build(self, if native_instancenorm: flags.append(trt.OnnxParserFlag.NATIVE_INSTANCENORM) network = network_from_onnx_path( - onnx_path, - flags=flags, + onnx_path, + flags=flags, strongly_typed=strongly_typed ) if update_output_names: @@ -584,7 +599,7 @@ def append(self, item): def add_arguments(parser): # Stable Diffusion configuration - parser.add_argument('--version', type=str, default="1.5", choices=["1.4", "1.5", "dreamshaper-7", "2.0-base", "2.0", "2.1-base", "2.1", "xl-1.0", "xl-turbo", "svd-xt-1.1", "sd3", "cascade"], help="Version of Stable Diffusion") + parser.add_argument('--version', type=str, default="1.5", choices=("1.4", "1.5", "dreamshaper-7", "2.0-base", "2.0", "2.1-base", "2.1", "xl-1.0", "xl-turbo", "svd-xt-1.1", "sd3", "cascade", "flux.1-dev"), help="Version of Stable Diffusion") parser.add_argument('prompt', nargs = '*', help="Text prompt(s) to guide image generation") parser.add_argument('--negative-prompt', nargs = '*', default=[''], help="The negative prompt(s) to guide the image generation.") parser.add_argument('--batch-size', type=int, default=1, choices=[1, 2, 4], help="Batch size (repeat prompt)") @@ -592,7 +607,7 @@ def add_arguments(parser): parser.add_argument('--height', type=int, default=512, help="Height of image to generate (must be multiple of 8)") parser.add_argument('--width', type=int, default=512, help="Height of image to generate (must be multiple of 8)") parser.add_argument('--denoising-steps', type=int, default=30, help="Number of denoising steps") - parser.add_argument('--scheduler', type=str, default=None, choices=["DDIM", "DDPM", "EulerA", "Euler", "LCM", "LMSD", "PNDM", "UniPC"], help="Scheduler for diffusion process") + parser.add_argument('--scheduler', type=str, default=None, choices=("DDIM", "DDPM", "EulerA", "Euler", "LCM", "LMSD", "PNDM", "UniPC", "DDPMWuerstchen", "FlowMatchEuler"), help="Scheduler for diffusion process") parser.add_argument('--guidance-scale', type=float, default=7.5, help="Value of classifier-free guidance scale (must be greater than 1)") parser.add_argument('--lora-scale', type=float, nargs='+', default=None, help="Scale of LoRA weights, default 1 (must between 0 and 1)") parser.add_argument('--lora-path', type=str, nargs='+', default=None, help="Path to LoRA adaptor. Ex: 'latent-consistency/lcm-lora-sdv1-5'") @@ -609,6 +624,7 @@ def add_arguments(parser): parser.add_argument('--int8', action='store_true', help="Apply int8 quantization.") parser.add_argument('--fp8', action='store_true', help="Apply fp8 quantization.") parser.add_argument('--quantization-level', type=float, default=0.0, choices=[0.0, 1.0, 2.0, 2.5, 3.0, 4.0], help="int8/fp8 quantization level, 1: CNN, 2: CNN + FFN, 2.5: CNN + FFN + QKV, 3: CNN + Almost all Linear (Including FFN, QKV, Proj and others), 4: CNN + Almost all Linear + fMHA, 0: Default to 2.5 for int8 and 4.0 for fp8.") + parser.add_argument('--optimization-level', type=int, default=None, help=f"Set the builder optimization level to build the engine with. A higher level allows TensorRT to spend more building time for more optimization options. Must be one of {VALID_OPTIMIZATION_LEVELS}.") parser.add_argument('--build-static-batch', action='store_true', help="Build TensorRT engines with fixed batch size.") parser.add_argument('--build-dynamic-shape', action='store_true', help="Build TensorRT engines with dynamic image shapes.") parser.add_argument('--build-enable-refit', action='store_true', help="Enable Refit option in TensorRT engines during build.") @@ -638,28 +654,42 @@ def process_pipeline_args(args): if args.use_cuda_graph and (not args.build_static_batch or args.build_dynamic_shape): raise ValueError(f"Using CUDA graph requires static dimensions. Enable `--build-static-batch` and do not specify `--build-dynamic-shape`") - if args.int8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.5', '2.1']): - raise ValueError(f"int8 quantization is only supported for SDXL, SD1.5 and SD2.1 pipelines.") + if args.optimization_level is None: + if args.int8 or args.fp8: + args.optimization_level = 4 + else: + args.optimization_level = 3 + + if args.optimization_level not in VALID_OPTIMIZATION_LEVELS: + raise ValueError(f"Optimization level {args.optimization_level} not valid. Valid values are: {VALID_OPTIMIZATION_LEVELS}") + + if args.int8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.4', '1.5', '2.1']): + raise ValueError(f"int8 quantization is only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipelines.") + + if args.fp8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.4', '1.5', '2.1']): + raise ValueError(f"fp8 quantization is only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipelines.") - if args.fp8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.5', '2.1']): - raise ValueError(f"fp8 quantization is only supported for SDXL, SD1.5 and SD2.1 pipelines.") - if args.fp8 and args.int8: raise ValueError(f"Cannot apply both int8 and fp8 quantization, please choose only one.") - + if args.fp8: device_info = torch.cuda.get_device_properties(0) version = device_info.major * 10 + device_info.minor if version < 90: # if Ada or older raise ValueError(f"Cannot apply FP8 quantization for GPU with compute capability {version / 10.0}. Only Hopper is supported.") - + if args.quantization_level == 0.0: + def override_quant_level(level : float, dtype_str : str): + args.quantization_level = level + print(f"The default quantization level has been set to {level} for {dtype_str}.") + if args.fp8: - args.quantization_level = 4.0 - print("The default quantization level has been set to 4.0 for FP8.") + override_quant_level(3.0 if args.version in ("1.4", "1.5") else 4.0, "FP8") elif args.int8: - args.quantization_level = 2.5 - print("The default quantization level has been set to 2.5 for INT8.") + override_quant_level(3.0, "INT8") + + if args.lora_path and not any(args.version.startswith(prefix) for prefix in ('1.5', '2.1', 'xl')): + raise ValueError(f"LoRA adapter support is only supported for SD1.5, SD2.1 and SDXL pipelines") if args.lora_scale: for lora_scale in (lora_scale for lora_scale in args.lora_scale if not 0 <= lora_scale <= 1): @@ -687,6 +717,7 @@ def process_pipeline_args(args): 'opt_batch_size': args.batch_size, 'opt_image_height': args.height, 'opt_image_width': args.width, + 'optimization_level': args.optimization_level, 'static_batch': args.build_static_batch, 'static_shape': not args.build_dynamic_shape, 'enable_all_tactics': args.build_all_tactics, diff --git a/demo/Diffusion/utils_modelopt.py b/demo/Diffusion/utils_modelopt.py index fbaaed97..9e8755c5 100644 --- a/demo/Diffusion/utils_modelopt.py +++ b/demo/Diffusion/utils_modelopt.py @@ -133,6 +133,7 @@ def quantize_lvl(unet, quant_level=2.5, linear_only=False): module.input_quantizer.disable() module.weight_quantizer.disable() elif isinstance(module, Attention): + # TRT only supports FP8 MHA with head_size % 16 == 0. head_size = int(module.inner_dim / module.heads) if quant_level >= 4 and head_size % 16 == 0: module.q_bmm_quantizer.enable() diff --git a/docker/rockylinux8.Dockerfile b/docker/rockylinux8.Dockerfile index 24cd4bce..ffe26034 100644 --- a/docker/rockylinux8.Dockerfile +++ b/docker/rockylinux8.Dockerfile @@ -25,7 +25,7 @@ ENV NV_CUDNN_VERSION 8.9.6.50-1 ENV NV_CUDNN_PACKAGE libcudnn8-${NV_CUDNN_VERSION}.cuda12.2 ENV NV_CUDNN_PACKAGE_DEV libcudnn8-devel-${NV_CUDNN_VERSION}.cuda12.2 -ENV TRT_VERSION 10.4.0.26 +ENV TRT_VERSION 10.5.0.18 SHELL ["/bin/bash", "-c"] RUN dnf install -y \ @@ -62,15 +62,15 @@ RUN dnf install -y python38 python38-devel &&\ # Install TensorRT RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib64 \ - && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp38-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib64 \ + && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp38-none-linux_x86_64.whl ;\ elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib64 \ - && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp38-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib64 \ + && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp38-none-linux_x86_64.whl ;\ else \ echo "Invalid CUDA_VERSION"; \ exit 1; \ diff --git a/docker/rockylinux9.Dockerfile b/docker/rockylinux9.Dockerfile index 95a87cce..bed779a7 100644 --- a/docker/rockylinux9.Dockerfile +++ b/docker/rockylinux9.Dockerfile @@ -25,7 +25,7 @@ ENV NV_CUDNN_VERSION 8.9.6.50-1 ENV NV_CUDNN_PACKAGE libcudnn8-${NV_CUDNN_VERSION}.cuda12.2 ENV NV_CUDNN_PACKAGE_DEV libcudnn8-devel-${NV_CUDNN_VERSION}.cuda12.2 -ENV TRT_VERSION 10.4.0.26 +ENV TRT_VERSION 10.5.0.18 SHELL ["/bin/bash", "-c"] RUN dnf install -y \ @@ -67,15 +67,15 @@ RUN dnf -y install \ # Install TensorRT RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib64 \ - && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp39-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib64 \ + && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp39-none-linux_x86_64.whl ;\ elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib64 \ - && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp39-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib64 \ + && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp39-none-linux_x86_64.whl ;\ else \ echo "Invalid CUDA_VERSION"; \ exit 1; \ diff --git a/docker/ubuntu-20.04.Dockerfile b/docker/ubuntu-20.04.Dockerfile index 88e504f4..a2ebb605 100644 --- a/docker/ubuntu-20.04.Dockerfile +++ b/docker/ubuntu-20.04.Dockerfile @@ -28,7 +28,7 @@ ENV CUDA_VERSION_MAJOR_MINOR=12.2 ENV NV_CUDNN_PACKAGE "libcudnn8=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" ENV NV_CUDNN_PACKAGE_DEV "libcudnn8-dev=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" -ENV TRT_VERSION 10.4.0.26 +ENV TRT_VERSION 10.5.0.18 SHELL ["/bin/bash", "-c"] RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -84,15 +84,15 @@ RUN apt-get install -y --no-install-recommends \ # Install TensorRT RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \ - && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp38-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp38-none-linux_x86_64.whl ;\ elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \ - && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp38-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp38-none-linux_x86_64.whl ;\ else \ echo "Invalid CUDA_VERSION"; \ exit 1; \ diff --git a/docker/ubuntu-22.04-aarch64.Dockerfile b/docker/ubuntu-22.04-aarch64.Dockerfile index 47836ddf..e28f058c 100644 --- a/docker/ubuntu-22.04-aarch64.Dockerfile +++ b/docker/ubuntu-22.04-aarch64.Dockerfile @@ -20,7 +20,7 @@ ARG CUDA_VERSION=12.6.0 # Multi-arch container support available in non-cudnn containers. FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 -ENV TRT_VERSION 10.4.0.26 +ENV TRT_VERSION 10.5.0.18 SHELL ["/bin/bash", "-c"] # Setup user account diff --git a/docker/ubuntu-22.04.Dockerfile b/docker/ubuntu-22.04.Dockerfile index c218693d..0bb09b2d 100644 --- a/docker/ubuntu-22.04.Dockerfile +++ b/docker/ubuntu-22.04.Dockerfile @@ -28,7 +28,7 @@ ENV CUDA_VERSION_MAJOR_MINOR=12.2 ENV NV_CUDNN_PACKAGE "libcudnn8=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" ENV NV_CUDNN_PACKAGE_DEV "libcudnn8-dev=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" -ENV TRT_VERSION 10.4.0.26 +ENV TRT_VERSION 10.5.0.18 SHELL ["/bin/bash", "-c"] RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -84,15 +84,15 @@ RUN apt-get install -y --no-install-recommends \ # Install TensorRT RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \ - && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp310-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp310-none-linux_x86_64.whl ;\ elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \ - && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp310-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp310-none-linux_x86_64.whl ;\ else \ echo "Invalid CUDA_VERSION"; \ exit 1; \ diff --git a/docker/ubuntu-cross-aarch64.Dockerfile b/docker/ubuntu-cross-aarch64.Dockerfile index 7a105f69..3243b4a7 100644 --- a/docker/ubuntu-cross-aarch64.Dockerfile +++ b/docker/ubuntu-cross-aarch64.Dockerfile @@ -21,7 +21,7 @@ ARG OS_VERSION=22.04 FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${OS_VERSION} LABEL maintainer="NVIDIA CORPORATION" -ENV TRT_VERSION 10.4.0.26 +ENV TRT_VERSION 10.5.0.18 ENV DEBIAN_FRONTEND=noninteractive ARG uid=1000 diff --git a/include/NvInfer.h b/include/NvInfer.h index 18e10be4..e0231d4d 100644 --- a/include/NvInfer.h +++ b/include/NvInfer.h @@ -1011,8 +1011,8 @@ struct EnumMaxImpl //! //! \brief A convolution layer in a network definition. //! -//! This layer performs a correlation operation between 3-dimensional filter with a 4-dimensional tensor to produce -//! another 4-dimensional tensor. +//! This layer performs a correlation operation between 3 or 4 dimensional filter with a 4 or 5 dimensional tensor to +//! produce another 4 or 5 dimensional tensor. //! //! An optional bias argument is supported, which adds a per-channel constant to each value in the output. //! @@ -9675,7 +9675,9 @@ class IBuilder : public INoCopy //! //! \brief Determine whether the platform has fast native fp16. //! - bool platformHasFastFp16() const noexcept + //! \deprecated Deprecated in TensorRT 10.5. Please query data type support from CUDA directly. + //! + TRT_DEPRECATED bool platformHasFastFp16() const noexcept { return mImpl->platformHasFastFp16(); } @@ -9683,7 +9685,9 @@ class IBuilder : public INoCopy //! //! \brief Determine whether the platform has fast native int8. //! - bool platformHasFastInt8() const noexcept + //! \deprecated Deprecated in TensorRT 10.5. Please query data type support from CUDA directly. + //! + TRT_DEPRECATED bool platformHasFastInt8() const noexcept { return mImpl->platformHasFastInt8(); } @@ -9817,7 +9821,9 @@ class IBuilder : public INoCopy //! //! \brief Determine whether the platform has TF32 support. //! - bool platformHasTf32() const noexcept + //! \deprecated Deprecated in TensorRT 10.5. Please query data type support from CUDA directly. + //! + TRT_DEPRECATED bool platformHasTf32() const noexcept { return mImpl->platformHasTf32(); } diff --git a/include/NvInferConsistency.h b/include/NvInferConsistency.h deleted file mode 100644 index 0d1b8b40..00000000 --- a/include/NvInferConsistency.h +++ /dev/null @@ -1,157 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef NV_INFER_CONSISTENCY_H -#define NV_INFER_CONSISTENCY_H - -#include "NvInferConsistencyImpl.h" -#define NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE 1 -#include "NvInferRuntimeBase.h" -#undef NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE -#include "NvInferRuntimePlugin.h" - -//! -//! \file NvInferConsistency.h -//! - -namespace nvinfer1 -{ - -namespace consistency -{ - -//! -//! \class IConsistencyChecker -//! -//! \brief Validates a serialized engine blob. -//! -//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. -//! -//! \deprecated Deprecated in TensorRT 10.4. -//! -class TRT_DEPRECATED IConsistencyChecker -{ -public: - //! - //! \brief Check that a blob that was input to createConsistencyChecker method represents a valid engine. - //! - //! \return true if the original blob encoded an engine that belongs to valid engine domain with - //! target capability EngineCapability::kSAFETY, false otherwise. - //! - bool validate() const noexcept - { - return mImpl->validate(); - } - - //! - //! \brief De-allocates any internally allocated memory. - //! - virtual ~IConsistencyChecker() = default; - -protected: - apiv::VConsistencyChecker* mImpl; - IConsistencyChecker() = default; - IConsistencyChecker(IConsistencyChecker const& other) = delete; - IConsistencyChecker& operator=(IConsistencyChecker const& other) = delete; - IConsistencyChecker(IConsistencyChecker&& other) = delete; - IConsistencyChecker& operator=(IConsistencyChecker&& other) = delete; -}; - -//! -//! \class IPluginChecker -//! -//! \brief Consistency Checker plugin class for user implemented Plugins. -//! -//! Plugins are a mechanism for applications to implement custom layers. It provides a -//! mechanism to register Consistency plugins and -//! look up the Plugin Registry during validate. -//! -//! Supported IPlugin inferfaces are limited to IPluginV2IOExt only. -//! -//! \deprecated Deprecated in TensorRT 10.4. -//! -class TRT_DEPRECATED IPluginChecker : public IPluginCreator -{ -public: - //! - //! \brief Called during IConsistencyChecker::validate. Allows users to provide - //! custom validation of serialized Plugin data. Returns boolean that indicates - //! whether or not the Plugin passed validation. - //! - //! \param name The plugin name - //! \param serialData The memory that holds the plugin serialized data. - //! \param serialLength The size of the plugin serialized data. - //! \param in The input tensors attributes. - //! \param nbInputs The number of input tensors. - //! \param out The output tensors attributes. - //! \param nbOutputs The number of output tensors. - //! \param workspaceSize The size of workspace provided during enqueue. - //! - virtual bool validate(char const* name, void const* serialData, size_t serialLength, PluginTensorDesc const* in, - size_t nbInputs, PluginTensorDesc const* out, size_t nbOutputs, int64_t workspaceSize) const noexcept = 0; - - IPluginChecker() = default; - virtual ~IPluginChecker() override = default; - -protected: - IPluginChecker(IPluginChecker const&) = default; - IPluginChecker(IPluginChecker&&) = default; - IPluginChecker& operator=(IPluginChecker const&) & = default; - IPluginChecker& operator=(IPluginChecker&&) & = default; -}; - -} // namespace consistency - -} // namespace nvinfer1 - -//! -//! \deprecated Deprecated in TensorRT 10.4. -//! -extern "C" TENSORRTAPI void* createConsistencyChecker_INTERNAL(void* logger, void const* blob, size_t size, - int32_t version); //!< Internal C entry point for creating IConsistencyChecker. - -namespace nvinfer1 -{ - -namespace consistency -{ -//! -//! \brief Create an instance of an IConsistencyChecker class. -//! -//! ILogger is the logging class for the consistency checker. -//! -//! anonymous namespace avoids linkage surprises when linking objects built with different versions of this header. -//! -namespace // anonymous -{ - -//! -//! \deprecated Deprecated in TensorRT 10.4. -//! -TRT_DEPRECATED inline IConsistencyChecker* createConsistencyChecker(ILogger& logger, void const* blob, size_t size) -{ - return static_cast( - createConsistencyChecker_INTERNAL(&logger, blob, size, NV_TENSORRT_VERSION)); -} - -} // namespace - -} // namespace consistency - -} // namespace nvinfer1 - -#endif // NV_INFER_CONSISTENCY_H diff --git a/include/NvInferConsistencyImpl.h b/include/NvInferConsistencyImpl.h deleted file mode 100644 index 0b4e8dd3..00000000 --- a/include/NvInferConsistencyImpl.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef NV_INFER_CONSISTENCY_IMPL_H -#define NV_INFER_CONSISTENCY_IMPL_H - -namespace nvinfer1 -{ - -//! -//! \file NvInferConsistencyImpl.h -//! -//! This file contains definitions for API methods that cross the shared library boundary. These -//! methods must not be called directly by applications; they should only be called through the -//! API classes. -//! - -namespace apiv -{ - -//! -//! \deprecated Deprecated in TensorRT 10.4. -//! -class TRT_DEPRECATED VConsistencyChecker -{ -public: - virtual ~VConsistencyChecker() noexcept = default; - virtual bool validate() const noexcept = 0; -}; - -} // namespace apiv -} // namespace nvinfer1 - -#endif // NV_INFER_CONSISTENCY_IMPL_H diff --git a/include/NvInferRuntimeBase.h b/include/NvInferRuntimeBase.h index 984ce869..b6652c07 100644 --- a/include/NvInferRuntimeBase.h +++ b/include/NvInferRuntimeBase.h @@ -258,7 +258,7 @@ enum class TensorFormat : int32_t //! Vector-major format with two scalars per vector. //! Vector dimension is third to last. //! - //! This format requires FP16 or BF16 and at least three dimensions. + //! This format requires FP16 and at least three dimensions. kCHW2 = 1, //! Vector-minor format with eight scalars per vector. @@ -284,7 +284,7 @@ enum class TensorFormat : int32_t //! Vector-major format with 16 scalars per vector. //! Vector dimension is third to last. //! - //! This format requires INT8 or FP16 and at least three dimensions. + //! This format requires FP16 and at least three dimensions. //! //! For DLA usage, this format maps to the native feature format for FP16, //! and the tensor sizes are limited to C,H,W in the range [1,8192]. @@ -313,7 +313,7 @@ enum class TensorFormat : int32_t //! Vector-minor format where channel dimension is third to last and unpadded. //! - //! This format requires either FP32 or UINT8 and at least three dimensions. + //! This format requires either FP32, FP16, UINT8, INT64 or BF16 and at least three dimensions. kHWC = 8, //! DLA planar format. For a tensor with dimension {N, C, H, W}, the W axis @@ -344,7 +344,7 @@ enum class TensorFormat : int32_t //! Vector-minor format with 16 scalars per vector. //! Vector dimension is third to last. //! - //! This requires FP16 and at least three dimensions. + //! This requires FP16 or INT8 and at least three dimensions. kHWC16 = 11, //! Vector-minor format with one scalar per vector. diff --git a/include/NvInferSafeRuntime.h b/include/NvInferSafeRuntime.h deleted file mode 100644 index 6dc503e0..00000000 --- a/include/NvInferSafeRuntime.h +++ /dev/null @@ -1,1021 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef NV_INFER_SAFE_RUNTIME_H -#define NV_INFER_SAFE_RUNTIME_H - -#define NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE 1 -#include "NvInferRuntimeBase.h" -#undef NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE -#include "NvInferRuntimePlugin.h" -#include -#include - -//! -//! \file NvInferSafeRuntime.h -//! -//! The functionality in this file is only supported in NVIDIA Drive(R) products. - -//! -//! \namespace nvinfer1 -//! -//! \brief The TensorRT API version 1 namespace. -//! -namespace nvinfer1 -{ -//! -//! \namespace nvinfer1::safe -//! -//! \brief The safety subset of TensorRT's API version 1 namespace. -//! -namespace safe -{ -//! Forward declare safe::ICudaEngine for use in other interfaces. -class ICudaEngine; -//! Forward declare safe::IExecutionContext for use in other interfaces. -class IExecutionContext; - -//! -//! \class IRuntime -//! -//! \brief Allows a serialized functionally safe engine to be deserialized. -//! -//! \warning In the safety runtime the application is required to set the error reporter for correct error handling. -//! \see setErrorRecorder() -//! -//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. -//! -class IRuntime -{ -public: - //! - //! \brief Deserialize an engine from a byte array. - //! - //! If the serialized engine requires plugins the plugin creator must be registered by calling - //! IPluginRegistry::registerCreator() before calling deserializeCudaEngine(). - //! - //! \param blob The memory that holds the serialized engine. The content must be a copy of - //! the result of calling IHostMemory::data() on a serialized plan that was created via calling - //! IBuilder::buildSerializedNetwork() on a network within the supported safety scope. - //! Additionally, it must have been validated via IConsistencyChecker::validate(). - //! - //! \param size The size of the memory in bytes. This must be the result of calling IHostMemory::size() - //! on the same IHostMemory object that is associated with the blob parameter. - //! - //! \return The engine, or nullptr if it could not be deserialized. - //! - //! \see IPluginRegistry::registerCreator() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes, if called from different instances of safe::IRuntime. Calling deserializeCudaEngine - //! of the same safety runtime from multiple threads is not guaranteed to be thread safe. - //! - virtual ICudaEngine* deserializeCudaEngine(void const* const blob, std::size_t const size) noexcept = 0; - - //! - //! \brief Set the GPU allocator. - //! - //! \param allocator The GPU allocator to be used by the runtime. All GPU memory acquired will use this - //! allocator. If nullptr is passed, the default allocator will be used, which calls cudaMalloc and cudaFree. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual void setGpuAllocator(IGpuAllocator* const allocator) noexcept = 0; - - //! - //! \brief Set the ErrorRecorder for this interface. - //! - //! Assigns the ErrorRecorder to this interface. The ErrorRecorder will track all errors during execution. - //! This function will call incRefCount of the registered ErrorRecorder at least once. If the recorder is set to - //! nullptr, an error code of ErrorCode::kINVALID_ARGUMENT will be emitted if the recorder has already been - //! registered, or ILogger::Severity::kERROR will be logged if the recorder has not yet been registered. - //! - //! \param recorder The error recorder to register with this interface, or nullptr to deregister the current - //! error recorder. - //! - //! \see getErrorRecorder() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual void setErrorRecorder(IErrorRecorder* const recorder) noexcept = 0; - - //! - //! \brief Get the ErrorRecorder assigned to this interface. - //! - //! Retrieves the assigned error recorder object for the given class. A default error recorder does not exist, - //! so a nullptr will be returned if setErrorRecorder has not been called or a previously assigned error recorder - //! has been deregistered. - //! - //! \return A pointer to the IErrorRecorder object that has been registered, or nullptr if no error recorder is set. - //! - //! \see setErrorRecorder() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; - - IRuntime() = default; - virtual ~IRuntime() noexcept = default; - IRuntime(IRuntime const&) = delete; - IRuntime(IRuntime&&) = delete; - IRuntime& operator=(IRuntime const&) & = delete; - IRuntime& operator=(IRuntime&&) & = delete; -}; - -//! -//! \class ICudaEngine -//! -//! \brief A functionally safe engine for executing inference on a built network. -//! -//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. -//! -class ICudaEngine -{ -public: - //! - //! \brief Create an execution context. - //! - //! \see safe::IExecutionContext. - //! - //! \return An execution context object if it can be constructed, or nullptr if the construction fails. - //! - //! \details Reasons for failure may include but not be limited to: - //! - Heap memory exhaustion - //! - Device memory exhaustion - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes; if createExecutionContext fails, users must treat this as a critical - //! error and not perform any subsequent TensorRT operations apart from outputting - //! the error logs. - //! - virtual IExecutionContext* createExecutionContext() noexcept = 0; - - //! - //! \brief Create an execution context without any device memory allocated. - //! - //! The memory for execution of this device context must be supplied by the application by calling - //! safe::IExecutionContext::setDeviceMemory(). - //! - //! \see getDeviceMemorySize() safe::IExecutionContext::setDeviceMemory() - //! - //! \return An execution context object if it can be constructed, or nullptr if the construction fails. - //! - //! \details Reasons for failure may include but not be limited to heap memory exhaustion. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes; if createExecutionContext fails, users must treat this as a critical - //! error and not perform any subsequent TensorRT operations apart from outputting - //! the error logs. - //! - virtual IExecutionContext* createExecutionContextWithoutDeviceMemory() noexcept = 0; - - //! - //! \brief Return the amount of device memory required by an execution context. - //! - //! \see safe::IExecutionContext::setDeviceMemory() - //! - //! \return Size of a contiguous memory buffer (in bytes) that users need to provide to - //! safe::IExecutionContext::setDeviceMemory() if the execution context has been created by calling - //! createExecutionContextWithoutDeviceMemory(). - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual size_t getDeviceMemorySize() const noexcept = 0; - - //! - //! \brief Returns the name of the network associated with the engine. - //! - //! The name is set during network creation and is retrieved after - //! building or deserialization. - //! - //! \see INetworkDefinition::setName(), INetworkDefinition::getName() - //! - //! \return A NULL-terminated C-style string representing the name of the network, which will have a length of - //! 1024 bytes or less including the NULL terminator. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual AsciiChar const* getName() const noexcept = 0; - - //! - //! \brief Set the ErrorRecorder for this interface. - //! - //! Assigns the ErrorRecorder to this interface. The ErrorRecorder will track all errors during execution. - //! This function will call incRefCount of the registered ErrorRecorder at least once. If the recorder is set to - //! nullptr, the error code ErrorCode::kINVALID_ARGUMENT will be emitted if the recorder has been registered. - //! - //! \param recorder The error recorder to register with this interface, or nullptr to deregister the current. - //! error recorder. - //! - //! \see getErrorRecorder() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual void setErrorRecorder(IErrorRecorder* const recorder) noexcept = 0; - - //! - //! \brief Get the ErrorRecorder assigned to this interface. - //! - //! Retrieves the assigned error recorder object for the given class. A - //! nullptr will be returned if an error reporter has not been inherited - //! from the IRuntime, and setErrorReporter() has not been called. - //! - //! \return A pointer to the IErrorRecorder object that has been registered, or nullptr if none has been - //! registered. - //! - //! \see setErrorRecorder() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; - - ICudaEngine() = default; - virtual ~ICudaEngine() noexcept = default; - ICudaEngine(ICudaEngine const&) = delete; - ICudaEngine(ICudaEngine&&) = delete; - ICudaEngine& operator=(ICudaEngine const&) & = delete; - ICudaEngine& operator=(ICudaEngine&&) & = delete; - - //! - //! \brief Get the extent of an input or output tensor. - //! - //! \param tensorName The name of an input or output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less including the - //! NULL terminator. - //! - //! \return Extent of the tensor. The invalid value Dims{-1, {}} will be returned if - //! - name is not the name of an input or output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual Dims getTensorShape(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Determine the required data type for a buffer from its tensor name. - //! - //! \param tensorName The name of an input or output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less including the - //! NULL terminator. - //! - //! \return The type of the data in the buffer. The default value DataType::kFLOAT will be returned if - //! - name is not the name of an input or output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual DataType getTensorDataType(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Determine whether a tensor is an input or output tensor. - //! - //! \param tensorName The name of an input or output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less including the - //! NULL terminator. - //! - //! \return kINPUT if tensorName is the name of an input tensor, kOUTPUT if tensorName is the name of an output - //! tensor. The invalid value kNONE is returned if - //! - tensorName exceeds the string length limit, or - //! - tensorName is nullptr, or - //! - tensorName does not correspond to any input or output tensor. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual TensorIOMode getTensorIOMode(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Return the size of the tensor data type in bytes for a vectorized tensor. - //! - //! \param tensorName The name of an input or output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less including the - //! NULL terminator. - //! - //! \return The size of the tensor data type in bytes if the tensor is vectorized (4 for float and int32, - //! 2 for half, 1 for int8). 0 will be returned if - //! - name is not the name of an input or output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit, or - //! - the tensor of the given name is not vectorized. - //! - //! \see safe::ICudaEngine::getTensorVectorizedDim() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual std::int32_t getTensorBytesPerComponent(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Return the number of components included in one element for a vectorized tensor. - //! - //! \param tensorName The name of an input or output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less including the - //! NULL terminator. - //! - //! \return The vector length (in scalars) for a vectorized tensor, or 1 for a scalar tensor. - //! The invalid value -1 will be returned if - //! - name is not the name of an input or output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit. - //! - //! \see safe::ICudaEngine::getTensorVectorizedDim() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual std::int32_t getTensorComponentsPerElement(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Return the tensor format. - //! - //! \param tensorName The name of an input or output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less including the - //! NULL terminator. - //! - //! \return The tensor format. TensorFormat::kLINEAR will be returned if - //! - name is not the name of an input or output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual TensorFormat getTensorFormat(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Return the dimension index along which the buffer is vectorized. - //! - //! Specifically, -1 is returned if the tensor is scalar. - //! - //! \param tensorName The name of an input or output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less including the - //! NULL terminator. - //! - //! \return The dimension index along which the buffer is vectorized. -1 will be returned if - //! - name is not the name of an input or output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit (1024 bytes or less including the NULL terminator), or - //! - the tensor of given name is not vectorized. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual std::int32_t getTensorVectorizedDim(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Return the number of input and output tensors for the network from which the engine was built. - //! - //! \return The number of IO tensors. - //! - //! \see getIOTensorName() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual std::int32_t getNbIOTensors() const noexcept = 0; - - //! - //! \brief Return the name of an IO tensor. - //! - //! If the index does not fall between 0 and getNbIOTensors()-1, the function will fail with an error code - //! of ErrorCode::kINVALID_ARGUMENT(3) that is emitted to the registered IErrorRecorder. - //! - //! \param index The IO tensor index. - //! - //! \return The name of an IO tensor, which will be a NULL-terminated string of 1024 bytes or less (including the - //! NULL terminator) if the index is in the range (between 0 and getNbIOTensors()-1). nullptr will be returned if - //! the index is not in range. - //! - //! \see getNbIOTensors() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual AsciiChar const* getIOTensorName(std::int32_t const index) const noexcept = 0; -}; - -//! -//! \brief Space to record information about runtime errors. -//! -//! kNAN_CONSUMED errors occur when NAN values are stored in an INT8 quantized datatype. -//! kINF_CONSUMED errors occur when +-INF values are stored in an INT8 quantized datatype. -//! kGATHER_OOB errors occur when a gather index tensor contains a value that is outside of the data tensor. -//! kSCATTER_OOB and kSCATTER_RACE are reserved for future use. -//! -//! Mark the RuntimeErrorType that occurs during asynchronous kernel execution. -struct RuntimeErrorInformation -{ - //! Each bit represents a RuntimeErrorType that has occurred during kernel execution. - uint64_t bitMask; -}; - -//! -//! \brief Enum to represent runtime error types. -//! -enum class RuntimeErrorType : uint64_t -{ - //! NaN floating-point value was silently consumed - kNAN_CONSUMED = 1ULL << 0, - //! Inf floating-point value was silently consumed - kINF_CONSUMED = 1ULL << 1, - //! Out-of-bounds access in gather operation - kGATHER_OOB = 1ULL << 2, - //! Out-of-bounds access in scatter operation - kSCATTER_OOB = 1ULL << 3, - //! Race condition in scatter operation - kSCATTER_RACE = 1ULL << 4, -}; - -//! -//! \class IExecutionContext -//! -//! \brief Functionally safe context for executing inference using an engine. -//! -//! Multiple safe execution contexts may exist for one safe::ICudaEngine instance, allowing the same -//! engine to be used for the execution of multiple inputs simultaneously. -//! -//! \warning Do not call the APIs of the same IExecutionContext from multiple threads at any given time. -//! Each concurrent execution must have its own instance of an IExecutionContext. -//! -//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. -//! -class IExecutionContext -{ -public: - //! - //! \brief Get the associated engine. - //! - //! \see safe::ICudaEngine - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual ICudaEngine const& getEngine() const noexcept = 0; - - //! - //! \brief Set the name of the execution context. - //! - //! This method copies the name string. - //! - //! \warning Strings passed to the runtime must be NULL terminated and have a length of 1024 bytes or less - //! including the NULL terminator. Otherwise, the operation will not change the execution context name, and - //! an error message will be recorded via the error recorder. - //! - //! \see getName() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual void setName(AsciiChar const* const name) noexcept = 0; - - //! - //! \brief Return the name of the execution context. - //! - //! \return The name that was passed to setName(), as a NULL-terminated string of 1024 bytes or less including - //! the NULL terminator. An empty string will be returned as the default value. - //! - //! \see setName() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual AsciiChar const* getName() const noexcept = 0; - - //! - //! \brief Set the device memory for use by this execution context. - //! - //! \param memory The start address of a device memory buffer whose size in bytes must be at least the value - //! returned by getEngine().getDeviceMemorySize(). - //! - //! If using enqueueV2() to run the network, The memory is in use - //! from the invocation of enqueueV2() until network execution is complete. - //! Releasing or otherwise using the memory for other purposes during this time will result in undefined behavior. - //! - //! \warning Do not release or use for other purposes the memory set here during network execution. - //! - //! \warning If the execution context has been created by calling createExecutionContext(), this - //! function must not be used and will fail with an error message if called. - //! - //! \see safe::ICudaEngine::getDeviceMemorySize() safe::ICudaEngine::createExecutionContextWithoutDeviceMemory() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual void setDeviceMemory(void* const memory) noexcept = 0; - - //! - //! \brief Set the ErrorRecorder for this interface. - //! - //! Assigns the ErrorRecorder to this interface. The ErrorRecorder will track all errors during execution. - //! This function will call incRefCount of the registered ErrorRecorder at least once. If the recorder is set to - //! nullptr, the error code ErrorCode::kINVALID_ARGUMENT will be emitted if the recorder has been registered. The - //! lifetime of the error recorder object must exceed the lifetime of the execution context. - //! - //! \param recorder Either a pointer to a valid error recorder object to register with this interface, - //! or nullptr to deregister the current recorder. - //! - //! \see getErrorRecorder() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual void setErrorRecorder(IErrorRecorder* const recorder) noexcept = 0; - - //! - //! \brief get the ErrorRecorder assigned to this interface. - //! - //! Retrieves the assigned error recorder object for the given class. A default error recorder does not exist, - //! so a nullptr will be returned if setErrorRecorder has not been called. - //! - //! \return A pointer to the IErrorRecorder object that has been registered, or nullptr if the error recorder - //! has been deregistered or not set. - //! - //! \see setErrorRecorder() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; - - IExecutionContext() = default; - virtual ~IExecutionContext() noexcept = default; - IExecutionContext(IExecutionContext const&) = delete; - IExecutionContext(IExecutionContext&&) = delete; - IExecutionContext& operator=(IExecutionContext const&) & = delete; - IExecutionContext& operator=(IExecutionContext&&) & = delete; - - //! - //! \brief Set error buffer output for runtime errors. - //! - //! The error buffer output must be allocated in device memory and will be used for subsequent - //! calls to enqueueV2() or enqueueV3(). Checking the contents of the error buffer after inference is the - //! responsibility of the application. The pointer passed here must have alignment adequate for the - //! RuntimeErrorInformation struct. - //! - //! \warning The buffer is written if reportable errors are encountered during network execution. Releasing the - //! buffer before network execution is complete will result in undefined behavior. Accessing the memory before - //! network execution is complete may not correctly capture the error state. - //! - //! \param buffer The device memory address of the runtime error information buffer. - //! - //! \see getErrorBuffer() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual void setErrorBuffer(RuntimeErrorInformation* const buffer) noexcept = 0; - - //! - //! \brief Get error buffer output for runtime errors. - //! - //! \return Pointer to device memory to use as runtime error buffer or nullptr if not set. - //! - //! \see setErrorBuffer() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual RuntimeErrorInformation* getErrorBuffer() const noexcept = 0; - - //! - //! \brief Return the strides of the buffer for the given tensor name. - //! - //! The strides are in units of elements, not components or bytes. - //! Elements are vectors (for a vectorized format) or scalars (for a scalar format). - //! For example, for TensorFormat::kHWC8, a stride of one spans 8 scalars. - //! - //! \param tensorName The name of an input or output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less - //! including the NULL terminator. - //! - //! \return The strides of the buffer for the given tensor name. Dims{-1, {}} will be returned if - //! - name is not the name of an input or output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual Dims getTensorStrides(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Set memory address for given input tensor. - //! - //! An address defaults to nullptr. - //! - //! Before calling enqueueV3(), each input must have a non-null address. - //! - //! \param tensorName The name of an input tensor. - //! \param data The pointer (void const*) to the input tensor data, which is device memory owned by the user. - //! Users are responsible for ensuring that the buffer size has at least the expected length, which is - //! the product of the tensor dimensions (with the vectorized dimension padded to a multiple of the vector length) - //! times the data type size. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less - //! including the NULL terminator. - //! - //! \warning The data pointer must have 256-byte alignment. - //! - //! \return True on success, false if - //! - name is not the name of an input tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit, or - //! - pointer to the const data is nullptr or not correctly aligned. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual bool setInputTensorAddress(AsciiChar const* const tensorName, void const* const data) noexcept = 0; - - //! - //! \brief Set memory address for given output tensor. - //! - //! An address defaults to nullptr. - //! - //! Before calling enqueueV3(), each output must have a non-null address. - //! - //! \param tensorName The name of an output tensor. - //! \param data The pointer (void*) to the output tensor data, which is device memory owned by the user. - //! Users are responsible for ensuring that the buffer size has at least the expected length, which is - //! the product of the tensor dimensions (with the vectorized dimension padded to a multiple of the vector length) - //! times the data type size. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less - //! including the NULL terminator. - //! \warning The data pointer must have 256-byte alignment. - //! - //! \return True on success. Return false if - //! - name is not the name of an output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit, or - //! - pointer to data is nullptr or not aligned. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual bool setOutputTensorAddress(AsciiChar const* const tensorName, void* const data) noexcept = 0; - - //! - //! \brief Set the event to mark inputs as consumed. - //! - //! Passing event==nullptr removes whatever event was set, if any. - //! - //! \param event The CUDA event that is signaled after all input tensors have been consumed, or nullptr to remove - //! an event that was previously set. - //! - //! \return True on success, false if an error occurred. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual bool setInputConsumedEvent(cudaEvent_t const event) noexcept = 0; - - //! - //! \brief Return the event associated with consuming the input. - //! - //! \return The CUDA event that was passed to setInputConsumedEvent(). nullptr will be returned if the event is - //! not set. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual cudaEvent_t getInputConsumedEvent() const noexcept = 0; - - //! - //! \brief Get memory address for given input tensor. - //! - //! \param tensorName The name of an input tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less - //! including the NULL terminator. - //! - //! \return The device memory address for the given input tensor. nullptr will be returned if - //! - name is not the name of an input tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit, or - //! - the memory address for the given input tensor is not set. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual void const* getInputTensorAddress(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Get memory address for given output tensor. - //! - //! \param tensorName The name of an output tensor. - //! - //! \warning The string tensorName must be NULL terminated and have a length of 1024 bytes or less - //! including the NULL terminator. - //! - //! \return The device memory address for the given output tensor. Return nullptr if - //! - name is not the name of an output tensor, or - //! - name is nullptr, or - //! - name exceeds the string length limit, or - //! - the memory address for the given output tensor is not set. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual void* getOutputTensorAddress(AsciiChar const* const tensorName) const noexcept = 0; - - //! - //! \brief Enqueue inference on a stream. - //! - //! Modifying or releasing memory that has been registered for the tensors before stream - //! synchronization or the event passed to setInputConsumedEvent has been signaled results in undefined - //! behavior. - //! - //! \param stream A CUDA stream on which the inference kernels will be enqueued. - //! - //! \return True on success, false if any execution error occurred. - //! Errors may include but not be limited to: - //! - Internal errors during executing one engine layer - //! - CUDA errors - //! - Some input or output tensor addresses have not been set. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual bool enqueueV3(cudaStream_t const stream) noexcept = 0; -}; - -//! -//! \class IPluginRegistry -//! -//! \brief Single registration point for all plugins in an application. It is -//! used to find plugin implementations during engine deserialization. -//! Internally, the plugin registry is considered to be a singleton so all -//! plugins in an application are part of the same global registry. -//! Note that the plugin registry is only supported for plugins of type -//! IPluginV2 and must also have a corresponding IPluginCreator implementation. -//! -//! \see IPluginV2 and IPluginCreator -//! -//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. -//! -//! \warning IPluginRegistry::setErrorRecorder() must be called to register -//! an error recorder with the registry before using other methods in the registry. -//! - -class IPluginRegistry -{ -public: - //! - //! \brief Register a plugin creator. - //! - //! \param creator The plugin creator to be registered. - //! - //! \param pluginNamespace A NULL-terminated namespace string, which must be 1024 bytes or less including the NULL - //! terminator. It must be identical with the result of calling - //! IPluginCreator::getPluginNamespace() on the creator object. - //! - //! \return True if the registration succeeded, else false. - //! - //! \details Registration may fail for any of the following reasons: - //! - The pluginNamespace string is a nullptr. - //! - The pluginNamespace string exceeds the maximum length. - //! - The pluginNamespace string does not match the result of creator.getPluginNamespace(). - //! - There have already been 100 plugin creators registered (maximum number of plugins exceeded). - //! - Another plugin creator with the same combination of plugin name, version and namespace has already been - //! registered. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes; calls to this method will be synchronized by a mutex. - //! - virtual bool registerCreator(IPluginCreator& creator, AsciiChar const* const pluginNamespace) noexcept = 0; - - //! - //! \brief Return all the registered plugin creators and the number of - //! registered plugin creators. Returns nullptr if none is found. - //! - //! \param[out] numCreators If the call completes successfully, the number of registered plugin creators (which - //! will be an integer between 0 and 100 inclusive) - //! \return The start address of an IPluginCreator* array of length numCreators if at least one plugin creator - //! has been registered, or nullptr if there are no registered plugin creators. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual IPluginCreator* const* getPluginCreatorList(int32_t* const numCreators) const noexcept = 0; - - //! - //! \brief Return plugin creator based on plugin name, version, and - //! namespace associated with plugin during network creation. - //! - //! \warning The strings pluginName, pluginVersion, and pluginNamespace must be NULL terminated and have a length - //! of 1024 bytes or less including the NULL terminator. - //! - //! \param pluginName The plugin name string - //! \param pluginVersion The plugin version string - //! \param pluginNamespace The plugin namespace (by default empty string) - //! - //! \return If a plugin creator corresponding to the passed name, version and namespace can be found in the - //! registry, it is returned. nullptr is returned in the following situations: - //! - Any of the input arguments is nullptr. - //! - Any of the input arguments exceeds the string length limit. - //! - No plugin creator corresponding to the input arguments can be found in the registry. - //! - A plugin creator can be found, but its stored namespace attribute does not match the pluginNamespace. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual IPluginCreator* getPluginCreator(AsciiChar const* const pluginName, AsciiChar const* const pluginVersion, - AsciiChar const* const pluginNamespace = "") noexcept = 0; - - //! - //! \brief Set the ErrorRecorder for this interface - //! - //! Assigns the ErrorRecorder to this interface. The ErrorRecorder will track all errors during execution. - //! This function will call incRefCount of the registered ErrorRecorder at least once. If the recorder is set to - //! nullptr, the error code ErrorCode::kINVALID_ARGUMENT will be emitted if the recorder has been registered. - //! - //! \param recorder The error recorder to register with this interface, or nullptr to deregister the current - //! recorder. - //! - //! \see getErrorRecorder() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No - //! - virtual void setErrorRecorder(IErrorRecorder* const recorder) noexcept = 0; - - //! - //! \brief Get the ErrorRecorder assigned to this interface. - //! - //! Retrieves the assigned error recorder object for the given class. A default error recorder does not exist, - //! so a nullptr will be returned if setErrorRecorder has not been called, or an ErrorRecorder has not been - //! inherited. - //! - //! \return A pointer to the IErrorRecorder object that has been registered, or nullptr if: - //! - no error recorder has been set, or - //! - the last error recorder has been deregistered via setErrorRecorder(nullptr). - //! - //! \see setErrorRecorder() - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; - - //! - //! \brief Deregister a previously registered plugin creator. - //! - //! Since there may be a desire to limit the number of plugins, - //! this function provides a mechanism for removing plugin creators registered in TensorRT. - //! The plugin creator that is specified by \p creator is removed from TensorRT and no longer tracked. - //! - //! \param creator The plugin creator to deregister. - //! - //! \return True if the plugin creator was deregistered, false if it was not found in the registry or otherwise - //! could - //! not be deregistered. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes - //! - virtual bool deregisterCreator(IPluginCreator const& creator) noexcept = 0; - - // @cond SuppressDoxyWarnings - IPluginRegistry() = default; - IPluginRegistry(IPluginRegistry const&) = delete; - IPluginRegistry(IPluginRegistry&&) = delete; - IPluginRegistry& operator=(IPluginRegistry const&) & = delete; - IPluginRegistry& operator=(IPluginRegistry&&) & = delete; - // @endcond - -protected: - virtual ~IPluginRegistry() noexcept = default; -}; - -//! -//! \brief Create an instance of a safe::IRuntime class. -//! -//! \param logger A logger object whose lifetime must exceed that of the returned runtime. -//! Loggers must be thread-safe. -//! -//! \return A safe runtime object that can be used for safe plan file deserialization. -//! -//! This class is the logging class for the runtime. -//! -//! \usage -//! - Allowed context for the API call -//! - Thread-safe: Yes -//! -IRuntime* createInferRuntime(ILogger& logger) noexcept; - -//! -//! \brief Return the safe plugin registry -//! -//! \usage -//! - Allowed context for the API call -//! - Thread-safe: Yes -//! -extern "C" TENSORRTAPI IPluginRegistry* getSafePluginRegistry() noexcept; - -//! -//! \brief Register the plugin creator to the registry -//! The static registry object will be instantiated when the plugin library is -//! loaded. This static object will register all creators available in the -//! library to the registry. -//! -//! \warning Statically registering plugins must be avoided in the automotive -//! safety context as the application developer must first register an error recorder -//! with the plugin registry via IPluginRegistry::setErrorRecorder() before using -//! IPluginRegistry::registerCreator() or other methods. -//! -template -class PluginRegistrar -{ -public: - PluginRegistrar() - { - getSafePluginRegistry()->registerCreator(instance, ""); - } - -private: - //! Plugin instance. - T instance{}; -}; - -} // namespace safe - -} // namespace nvinfer1 - -#define REGISTER_SAFE_TENSORRT_PLUGIN(name) \ - static nvinfer1::safe::PluginRegistrar pluginRegistrar##name {} -#endif // NV_INFER_SAFE_RUNTIME_H diff --git a/include/NvInferVersion.h b/include/NvInferVersion.h index 477c5719..084b4ee1 100644 --- a/include/NvInferVersion.h +++ b/include/NvInferVersion.h @@ -24,9 +24,9 @@ #define NV_INFER_VERSION_H #define NV_TENSORRT_MAJOR 10 //!< TensorRT major version. -#define NV_TENSORRT_MINOR 4 //!< TensorRT minor version. +#define NV_TENSORRT_MINOR 5 //!< TensorRT minor version. #define NV_TENSORRT_PATCH 0 //!< TensorRT patch version. -#define NV_TENSORRT_BUILD 26 //!< TensorRT build number. +#define NV_TENSORRT_BUILD 18 //!< TensorRT build number. #define NV_TENSORRT_LWS_MAJOR 0 //!< TensorRT LWS major version. #define NV_TENSORRT_LWS_MINOR 0 //!< TensorRT LWS minor version. diff --git a/parsers/CMakeLists.txt b/parsers/CMakeLists.txt index 6b4858ba..75023fe7 100644 --- a/parsers/CMakeLists.txt +++ b/parsers/CMakeLists.txt @@ -16,8 +16,7 @@ # ############################# GENERATE C++ PROTO FILES ################################### -add_custom_target(parsers DEPENDS - nvonnxparser) +add_custom_target(parsers DEPENDS nvonnxparser) add_definitions("-D_PROTOBUF_INSTALL_DIR=${Protobuf_INSTALL_DIR}") add_compile_options("-Dgoogle=google_private") diff --git a/parsers/onnx b/parsers/onnx index 3775e499..886aff91 160000 --- a/parsers/onnx +++ b/parsers/onnx @@ -1 +1 @@ -Subproject commit 3775e499322eee17c837e27bff6d07af4261767a +Subproject commit 886aff917b63f10a81c5f31e89752a3b46169623 diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index 32ec4ed5..55825fab 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -97,6 +97,7 @@ endforeach(PLUGIN_ITER) # Add common add_subdirectory(common) +add_subdirectory(vc) # Set gencodes set_source_files_properties(${PLUGIN_CU_SOURCES} PROPERTIES COMPILE_FLAGS "${GENCODES} ${ENABLED_SMS}") diff --git a/plugin/README.md b/plugin/README.md index 1a826743..8f024104 100644 --- a/plugin/README.md +++ b/plugin/README.md @@ -7,7 +7,8 @@ | [batchTilePlugin](batchTilePlugin) [DEPRECATED] | BatchTilePlugin_TRT | 1 | | [batchedNMSPlugin](batchedNMSPlugin) [DEPRECATED] | BatchedNMS_TRT | 1 | | [batchedNMSDynamicPlugin](batchedNMSPlugin) [DEPRECATED] | BatchedNMSDynamic_TRT | 1 | -| [bertQKVToContextPlugin](bertQKVToContextPlugin) | CustomQKVToContextPluginDynamic | 1, 2, 3 | +| [bertQKVToContextPlugin](bertQKVToContextPlugin) [DEPRECATED] | CustomQKVToContextPluginDynamic | 1, 2, 3 | +| [bertQKVToContextPlugin](bertQKVToContextPlugin) | CustomQKVToContextPluginDynamic | 4, 5, 6 | | [clipPlugin](clipPlugin) [DEPRECATED] | Clip_TRT | 1 | | [coordConvACPlugin](coordConvACPlugin) [DEPRECATED] | CoordConvAC | 1 | | [cropAndResizePlugin](cropAndResizePlugin) [DEPRECATED] | CropAndResize | 1 | diff --git a/plugin/bertQKVToContextPlugin/CustomQKVToContextPluginDynamic_PluginConfig.yaml b/plugin/bertQKVToContextPlugin/CustomQKVToContextPluginDynamic_PluginConfig.yaml index 781d8b10..e14df037 100644 --- a/plugin/bertQKVToContextPlugin/CustomQKVToContextPluginDynamic_PluginConfig.yaml +++ b/plugin/bertQKVToContextPlugin/CustomQKVToContextPluginDynamic_PluginConfig.yaml @@ -16,9 +16,9 @@ # --- name: CustomQKVToContextPluginDynamic -interface: "IPluginV2DynamicExt" +interface: "IPluginV3" versions: - "1": + "4": inputs: - input outputs: @@ -148,7 +148,7 @@ versions: shape: "1" output_types: output: float32 - "2": + "5": inputs: - input - input_mask @@ -205,7 +205,7 @@ versions: - 1 var_seqlen: - 0 - # - 1 Disabled since SM specific tests are not supported yet + # - 1 # Disabled since SM specific tests are not supported yet attributes_required: - type_id - hidden_size diff --git a/plugin/bertQKVToContextPlugin/README.md b/plugin/bertQKVToContextPlugin/README.md index 81f6f2f3..45f07a45 100644 --- a/plugin/bertQKVToContextPlugin/README.md +++ b/plugin/bertQKVToContextPlugin/README.md @@ -46,16 +46,18 @@ The parameters are defined below and consists of the following attributes: | Type | Parameter | Version | Description |----------|-----------------------------------------|-----------------------------------|------------------------------------------------------------------- -|`int` |`type_id` | 1, 2 |Integer encoding the DataType (0: FP32, 1: FP16, 2: INT8) -|`int` |`hidden_size` | 1, 2, 3 |The hidden size, denoted by `E` above. -|`int` |`num_heads` | 1, 2, 3 |The number of self-attention heads. -|`bool` |`has_mask` | 1, 2 |Whether to use the input_mask input. -|`float` |`dq_probs` | 1, 2, 3 |inner layer scale factor when run in int8 precision, default 1.f/127.f. -|`int` |`var_seqlen` | 2 |Whether to use variable sequence length (0: disable, 1: enable), default 0. -|`int` |`use_int8_scale_max` | 2, 3 |Whether to use INT8 scale factors to optimize softmax MAX reduction. Only active when `type_id==2`. (0: disable, 1: enable), default 1. -|`int` |`use_explicit_int8` | 3 |Whether to use explicit INT8, (0: disable, 1: enable), default 0. -|`float` |`input_qkv_scale` | 3 |The int8 scale for the input qkv tensor when explicit precision is used, default 1.f. -|`float` |`output_ctx_scale` | 3 |The int8 scale for the output context tensor when explicit precision is used, default 1.f. +|`int` |`type_id` | 1, 2, 4, 5 |Integer encoding the DataType (0: FP32, 1: FP16, 2: INT8) +|`int` |`hidden_size` | 1, 2, 3, 4, 5, 6 |The hidden size, denoted by `E` above. +|`int` |`num_heads` | 1, 2, 3, 4, 5, 6 |The number of self-attention heads. +|`bool` |`has_mask` | 1, 2, 4, 5 |Whether to use the input_mask input. +|`float` |`dq_probs` | 1, 2, 3, 4, 5, 6 |inner layer scale factor when run in int8 precision, default 1.f/127.f. +|`int` |`var_seqlen` | 2, 5 |Whether to use variable sequence length (0: disable, 1: enable), default 0. +|`int` |`use_int8_scale_max` | 2, 3, 5, 6 |Whether to use INT8 scale factors to optimize softmax MAX reduction. Only active when `type_id==2`. (0: disable, 1: enable), default 1. +|`int` |`use_explicit_int8` | 3, 6 |Whether to use explicit INT8, (0: disable, 1: enable), default 0. +|`float` |`input_qkv_scale` | 3, 6 |The int8 scale for the input qkv tensor when explicit precision is used, default 1.f. +|`float` |`output_ctx_scale` | 3, 6 |The int8 scale for the output context tensor when explicit precision is used, default 1.f. + +Note: version 1, 2, 3 are deprecated and will be removed in a future release; please use their corresponding updated versions: 4, 5, 6 respectively. ## Additional resources @@ -70,6 +72,11 @@ documentation. ## Changelog +Aug 2024 +Ported v3 variant (IPluginV2 design) to v6 (IpluginV3 design), but preserves identical attributes and I/O to v3. +Ported v2 variant (IPluginV2 design) to v5 (IpluginV3 design), but preserves identical attributes and I/O to v2. +Ported v1 variant (IPluginV2 design) to v4 (IpluginV3 design), but preserves identical attributes and I/O to v1. + Feb 2024 The issue of the V2 plugin not supporting head sizes of 32 or less and variable sequences of 64, 96, and 384 has been resolved. diff --git a/plugin/bertQKVToContextPlugin/fused_multihead_attention_v2/include/fused_multihead_attention_v2.h b/plugin/bertQKVToContextPlugin/fused_multihead_attention_v2/include/fused_multihead_attention_v2.h index ecc3684d..bdc143b3 100644 --- a/plugin/bertQKVToContextPlugin/fused_multihead_attention_v2/include/fused_multihead_attention_v2.h +++ b/plugin/bertQKVToContextPlugin/fused_multihead_attention_v2/include/fused_multihead_attention_v2.h @@ -329,33 +329,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 uint32_t mUnrollStep; bool mInterleaved; } sMhaKernelMetaInfosV2[] = { -#if defined(ENABLE_SM72) - // Xavier - {DATA_TYPE_INT8, 128, 64, kSM_72, fused_multihead_attention_v2_int8_128_64_kernel_cubin, - fused_multihead_attention_v2_int8_128_64_kernel_cubin_len, - "fused_multihead_attention_v2_int8_128_64_kernel_sm72_interleaved", 24576, 128, 0, true}, - {DATA_TYPE_INT8, 128, 64, kSM_72, fused_multihead_attention_v2_int8_128_64_kernel_cubin, - fused_multihead_attention_v2_int8_128_64_kernel_cubin_len, - "fused_multihead_attention_v2_int8_128_64_kernel_sm72", 32768, 128, 0, false}, - {DATA_TYPE_INT8, 192, 64, kSM_72, fused_multihead_attention_v2_int8_192_64_kernel_cubin, - fused_multihead_attention_v2_int8_192_64_kernel_cubin_len, - "fused_multihead_attention_v2_int8_192_64_kernel_sm72_interleaved", 28672, 128, 0, true}, - {DATA_TYPE_INT8, 192, 64, kSM_72, fused_multihead_attention_v2_int8_192_64_kernel_cubin, - fused_multihead_attention_v2_int8_192_64_kernel_cubin_len, - "fused_multihead_attention_v2_int8_192_64_kernel_sm72", 45056, 128, 0, false}, - {DATA_TYPE_INT8, 256, 64, kSM_72, fused_multihead_attention_v2_int8_256_64_kernel_cubin, - fused_multihead_attention_v2_int8_256_64_kernel_cubin_len, - "fused_multihead_attention_v2_int8_256_64_kernel_sm72_interleaved", 36864, 128, 0, true}, - {DATA_TYPE_INT8, 256, 64, kSM_72, fused_multihead_attention_v2_int8_256_64_kernel_cubin, - fused_multihead_attention_v2_int8_256_64_kernel_cubin_len, - "fused_multihead_attention_v2_int8_256_64_kernel_sm72", 57344, 128, 0, false}, - {DATA_TYPE_INT8, 384, 64, kSM_72, fused_multihead_attention_v2_int8_384_64_kernel_cubin, - fused_multihead_attention_v2_int8_384_64_kernel_cubin_len, - "fused_multihead_attention_v2_int8_384_64_kernel_sm72_interleaved", 51200, 128, 0, true}, - {DATA_TYPE_INT8, 384, 64, kSM_72, fused_multihead_attention_v2_int8_384_64_kernel_cubin, - fused_multihead_attention_v2_int8_384_64_kernel_cubin_len, - "fused_multihead_attention_v2_int8_384_64_kernel_sm72", 77824, 128, 0, false}, -#endif // defined(ENABLE_SM72) #if defined(ENABLE_SM75) // Turing {DATA_TYPE_FP16, 64, 64, kSM_75, fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin, diff --git a/plugin/bertQKVToContextPlugin/qkvToContext.cu b/plugin/bertQKVToContextPlugin/mhaRunner.cu similarity index 98% rename from plugin/bertQKVToContextPlugin/qkvToContext.cu rename to plugin/bertQKVToContextPlugin/mhaRunner.cu index 2b89d428..1e10fe6c 100644 --- a/plugin/bertQKVToContextPlugin/qkvToContext.cu +++ b/plugin/bertQKVToContextPlugin/mhaRunner.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,7 +19,6 @@ #include "common/bertCommon.h" #include "common/common.cuh" #include "common/serialize.hpp" -#include "qkvToContextPlugin.h" #include #include @@ -28,6 +27,7 @@ #include #include "bertQKVToContextPlugin/fused_multihead_attention_v2/include/fused_multihead_attention_v2.h" +#include "mhaRunner.h" using namespace nvinfer1; using namespace nvinfer1::pluginInternal; @@ -391,7 +391,7 @@ std::pair tuneBatchedGemm( float ms1 = 1000000; float ms2 = 1000000; - PLUGIN_ASSERT(smVersion >= kSM_53); + PLUGIN_ASSERT(smVersion >= kSM_75); for (int a = startAlgo; a <= endAlgo; a++) { cublasGemmAlgo_t algo = static_cast(a); @@ -925,7 +925,7 @@ public: , sm(mhaInterface->mSm) , xmmaKernel(getXMMAKernelsV2(DATA_TYPE_FP16, sm)) { - assert((sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_87 || sm == kSM_89 || sm == kSM_90) + assert((sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_87 || sm == kSM_89 || sm == kSM_90) && "Unsupported architecture"); params.clear(); } @@ -1096,7 +1096,7 @@ public: , xmmas_n(0U) , threads_per_cta(1U) { - assert((sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_87 || sm == kSM_89 || sm == kSM_90) + assert((sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_87 || sm == kSM_89 || sm == kSM_90) && "Unsupported architecture"); params.clear(); } diff --git a/plugin/bertQKVToContextPlugin/mhaRunner.h b/plugin/bertQKVToContextPlugin/mhaRunner.h new file mode 100644 index 00000000..1cbcb4dc --- /dev/null +++ b/plugin/bertQKVToContextPlugin/mhaRunner.h @@ -0,0 +1,268 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRT_MHA_RUNNER_H +#define TRT_MHA_RUNNER_H + +// Need 10.1 for cublasGemmStridedBatchedEx +#include +#if CUDA_VERSION >= 10010 + +#include "NvInferPlugin.h" +#include "common/cublasWrapper.h" +#include "zeroPadding2d.h" +#include +#include +#include + +using namespace nvinfer1::pluginInternal; + +namespace nvinfer1 +{ +namespace plugin +{ +namespace bert +{ + +// Multi Head Attention runner +class MHARunner +{ +public: + MHARunner(nvinfer1::DataType const type, int32_t const numHeads) + : mType(type) + , mS(0) + , mB(0) + , mOmatSize(0) + , mNumMats(0) + , mNumHeads(numHeads) + , mHeadSize(0) + , mWordSize(getElementSize(type)) + , mLdQKV(0) + , mStrideQKV(0) + , mLdOut(0) + , mStrideOut(0) + , mRsqrtHeadSize(0) + { + } + + virtual ~MHARunner() = default; + + virtual void setup(int32_t S, int32_t B, int32_t headSize) + { + PLUGIN_ASSERT(S); + PLUGIN_ASSERT(B); + mB = B; + mS = S; + mHeadSize = headSize; + mRsqrtHeadSize = 1.F / std::sqrt(headSize); + + mLdQKV = 3 * B * mNumHeads * mHeadSize; + mStrideQKV = 3 * mHeadSize; + + mLdOut = B * mNumHeads * mHeadSize; + mStrideOut = mHeadSize; + mOmatSize = S * S; + mNumMats = B * mNumHeads; + } + + virtual void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, + void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) + = 0; + + virtual void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) + = 0; + + virtual size_t getSerializationSize() const noexcept; + virtual void serialize(void* buffer) const noexcept; + virtual void deserialize(void const* data, size_t length); + + virtual size_t getWorkspaceSize() const = 0; + + virtual bool isValid(int32_t headSize, int32_t s) const = 0; + +protected: + nvinfer1::DataType mType; + + int32_t mS; + int32_t mB; + int32_t mOmatSize; + int32_t mNumMats; + int32_t mNumHeads; + int32_t mHeadSize; + int32_t mWordSize; + int32_t mLdQKV; + int32_t mStrideQKV; + int32_t mLdOut; + int32_t mStrideOut; + + float mRsqrtHeadSize; +}; + +class UnfusedMHARunner : public MHARunner +{ +public: + UnfusedMHARunner(nvinfer1::DataType const type, int32_t const numHeads, int32_t const smVersion); + virtual ~UnfusedMHARunner(); + + virtual void setup(int32_t S, int32_t B, int32_t headSize) override; + + void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, + void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + size_t getWorkspaceSize() const override; + + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void deserialize(void const* data, size_t length) override; + bool isValid(int32_t headSize, int32_t s) const override; + +private: + bool mIsBestAlgoFound{}; + int32_t mAlgoBatchedEx1{}; + int32_t mAlgoBatchedEx2{}; + int32_t mSm{}; +}; + +class FusedMHARunnerFP16 : public MHARunner +{ +public: + FusedMHARunnerFP16(int32_t const numHeads, int32_t const sm); + ~FusedMHARunnerFP16() = default; // for pimpl + + virtual void setup(int32_t S, int32_t B, int32_t headSize) override; + + void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, + void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + size_t getWorkspaceSize() const override; + + void deserialize(void const* data, size_t length) override; + + bool isValid(int32_t headSize, int32_t s) const override; + +private: + int32_t mSm; + class mhaImpl; + std::unique_ptr pimpl; +}; + +class FusedMHARunnerInt8 : public MHARunner +{ +public: + FusedMHARunnerInt8(int32_t const numHeads, int32_t const sm, float const dqProbs); + ~FusedMHARunnerInt8() = default; // for pimpl + + virtual void setup(int32_t S, int32_t B, int32_t headSize) override; + + void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, + void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + size_t getWorkspaceSize() const override; + + void deserialize(void const* data, size_t length) override; + + bool isValid(int32_t headSize, int32_t s) const override; + +private: + float mDqProbs; + int32_t mSm; + class mhaImpl; + std::unique_ptr pimpl; +}; + +class FusedMHARunnerFP16v2 : public MHARunner +{ +public: + FusedMHARunnerFP16v2(int32_t const numHeads, int32_t const sm); + ~FusedMHARunnerFP16v2() = default; // for pimpl + + virtual void setup(int32_t S, int32_t B, int32_t headSize) override; + + void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, + void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + size_t getWorkspaceSize() const override; + + void deserialize(void const* data, size_t length) override; + + bool isValid(int32_t headSize, int32_t s) const override; + +private: + int32_t mSm; + class mhaImpl; + std::unique_ptr pimpl; +}; + +class FusedMHARunnerInt8v2 : public MHARunner +{ +public: + FusedMHARunnerInt8v2(int32_t const numHeads, int32_t const sm, float const dqProbs, bool const useInt8ScaleMax); + ~FusedMHARunnerInt8v2() = default; // for pimpl + + virtual void setup(int32_t S, int32_t B, int32_t headSize) override; + + void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, + void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, + nvinfer1::pluginInternal::cublasHandle_t cublas) override; + + size_t getWorkspaceSize() const override; + + void deserialize(void const* data, size_t length) override; + + bool isValid(int32_t headSize, int32_t s) const override; + +private: + float mDqProbs; + int32_t mSm; + class mhaImpl; + std::unique_ptr pimpl; + bool mUseInt8ScaleMax{true}; +}; + +} // namespace bert +} // namespace plugin +} // namespace nvinfer1 +#endif // TRT_MHA_RUNNER_H + +#endif // CUDA_VERSION >= 10010 diff --git a/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.cpp b/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.cpp index 40a42af0..45d589bd 100644 --- a/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.cpp +++ b/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -35,7 +36,7 @@ using namespace nvinfer1::plugin::bert; namespace { -char const* const kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_VERSION{"3"}; +char const* const kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_VERSION{"6"}; char const* const kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_NAME{"CustomQKVToContextPluginDynamic"}; } // namespace @@ -45,23 +46,21 @@ std::vector QKVToContextInterleavedPluginCreator::mPluginAttributes REGISTER_TENSORRT_PLUGIN(QKVToContextInterleavedPluginCreator); -constexpr uint32_t IIDX = 0; // index of the input tensor +constexpr uint32_t kIIDX = 0; // index of the input tensor QKVToContextInterleavedPlugin::QKVToContextInterleavedPlugin(std::string const& name, int32_t hiddenSize, int32_t numHeads, float dqProbs, bool useInt8ScaleMax, bool useExplicitInt8, float qkvScale, float ctxScale) : mLayerName(name) - , mS(0) - , mB(0) , mHeadSize(hiddenSize / numHeads) , mHiddenSize(hiddenSize) , mNumHeads(numHeads) , mDqProbs(dqProbs) - , mUseInt8ScaleMax(useInt8ScaleMax) - , mUseExplicitInt8(useExplicitInt8) , mQkvScale(qkvScale) , mCtxScale(ctxScale) { mSM = getSMVersion(); + mUseInt8ScaleMax = static_cast(useInt8ScaleMax); + mUseExplicitInt8 = static_cast(useExplicitInt8); // variable sequence length is only supported with the fused MHA kernels // we should not override mS! PLUGIN_VALIDATE((mSM == kSM_AMPERE_100 || mSM == kSM_AMPERE_10X || mSM == kSM_AMPERE_10B || mSM == kSM_TURING @@ -71,33 +70,36 @@ QKVToContextInterleavedPlugin::QKVToContextInterleavedPlugin(std::string const& mXmmaKernel = getXMMAKernelsV2(DATA_TYPE_INT8, mSM); } -QKVToContextInterleavedPlugin::QKVToContextInterleavedPlugin(std::string const& name, void const* data, size_t length) - : mLayerName(name) +QKVToContextInterleavedPlugin::~QKVToContextInterleavedPlugin() {} + +IPluginV3* QKVToContextInterleavedPlugin::attachToContext(IPluginResourceContext* context) noexcept { - deserialize_value(&data, &length, &mNumHeads); - deserialize_value(&data, &length, &mHeadSize); - deserialize_value(&data, &length, &mHiddenSize); - deserialize_value(&data, &length, &mSM); - deserialize_value(&data, &length, &mS); - deserialize_value(&data, &length, &mB); - deserialize_value(&data, &length, &mDqProbs); - deserialize_value(&data, &length, &mUseInt8ScaleMax); - deserialize_value(&data, &length, &mUseExplicitInt8); - deserialize_value(&data, &length, &mQkvScale); - deserialize_value(&data, &length, &mCtxScale); + return clone(); } -int32_t QKVToContextInterleavedPlugin::getSMVersion() const noexcept +IPluginCapability* QKVToContextInterleavedPlugin::getCapabilityInterface(PluginCapabilityType type) noexcept { - int32_t device{-1}; - PLUGIN_CHECK(cudaGetDevice(&device)); - cudaDeviceProp props; - PLUGIN_CHECK(cudaGetDeviceProperties(&props, device)); - return getTrtSMVersionDec(props.major, props.minor); + try + { + if (type == PluginCapabilityType::kBUILD) + { + return static_cast(this); + } + if (type == PluginCapabilityType::kRUNTIME) + { + return static_cast(this); + } + PLUGIN_ASSERT(type == PluginCapabilityType::kCORE); + return static_cast(this); + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; } -// IPluginV2DynamicExt Methods -nvinfer1::IPluginV2DynamicExt* QKVToContextInterleavedPlugin::clone() const noexcept +IPluginV3* QKVToContextInterleavedPlugin::clone() noexcept { try { @@ -114,27 +116,41 @@ nvinfer1::IPluginV2DynamicExt* QKVToContextInterleavedPlugin::clone() const noex return nullptr; } -DimsExprs QKVToContextInterleavedPlugin::getOutputDimensions( - int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept +int32_t QKVToContextInterleavedPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, + DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, + IExprBuilder& exprBuilder) noexcept { - // Input SHAPE is 1x(3*N*H)xTotalx1 (NCHW) - // Output SHAPE is 1x(N*H)xTotalx1 - // In SupportsFormatCombination, we force the layout to be CHW, i.e. - // Input: 3xNx(H/32)xsumSx32, Output: 1xNx(H/32)xsumSx32 - PLUGIN_ASSERT(outputIndex == 0); - // Copy over everything - DimsExprs output(inputs[IIDX]); - // output.d[0] = exprBuilder.constant(1); - // Divide last dim by three - auto const* three = exprBuilder.constant(3); - output.d[1] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[IIDX].d[1], *three); - return output; + try + { + // Input SHAPE is 1x(3*N*H)xTotalx1 (NCHW) + // Output SHAPE is 1x(N*H)xTotalx1 + // In SupportsFormatCombination, we force the layout to be CHW, i.e. + // Input: 3xNx(H/32)xsumSx32, Output: 1xNx(H/32)xsumSx32 + PLUGIN_ASSERT(inputs != nullptr); + PLUGIN_ASSERT(nbInputs == 3); + PLUGIN_ASSERT(nbShapeInputs == 0); + PLUGIN_ASSERT(outputs != nullptr); + PLUGIN_ASSERT(nbOutputs == 1); + outputs[kIIDX] = inputs[kIIDX]; + // Divide last dim by three + auto const* three = exprBuilder.constant(3); + outputs[kIIDX].d[1] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[kIIDX].d[1], *three); + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; } + bool QKVToContextInterleavedPlugin::supportsFormatCombination( - int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept + int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t /*nbOutputs*/) noexcept { + PLUGIN_ASSERT(pos >= 0); PLUGIN_ASSERT(nbInputs == 3); - PLUGIN_ASSERT(nbOutputs == 1); + PLUGIN_ASSERT(pos <= nbInputs); + PLUGIN_ASSERT(inOut != nullptr); // 3 inputs: // 0: qkv // 1: cu_seqlens @@ -142,46 +158,55 @@ bool QKVToContextInterleavedPlugin::supportsFormatCombination( // 1 output if (pos == 0 || pos == nbInputs) { - return (inOut[pos].type == DataType::kINT8) && (inOut[pos].format == TensorFormat::kCHW32); + return (inOut[pos].desc.type == DataType::kINT8) && (inOut[pos].desc.format == TensorFormat::kCHW32); } if (pos == 1) { // cuSeqlens is a int32_t array of size B+1 - auto const* seqlens = &inOut[pos]; + auto const* seqlens = &inOut[pos].desc; return (seqlens->type == DataType::kINT32) && (seqlens->format == TensorFormat::kLINEAR); } if (pos == 2) { // this is the dummy input - return inOut[pos].dims.nbDims == 1; + return inOut[pos].desc.dims.nbDims == 1; } return false; } -void QKVToContextInterleavedPlugin::configurePlugin( - DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept +int32_t QKVToContextInterleavedPlugin::onShapeChange( + PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept { + return pluginStatus_t::STATUS_SUCCESS; } -size_t QKVToContextInterleavedPlugin::getWorkspaceSize( - PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept +int32_t QKVToContextInterleavedPlugin::configurePlugin( + DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept { - return 0; + return pluginStatus_t::STATUS_SUCCESS; } -// IPluginV2Ext Methods -DataType QKVToContextInterleavedPlugin::getOutputDataType( - int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept +size_t QKVToContextInterleavedPlugin::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { - PLUGIN_ASSERT(index == 0); - return DataType::kINT8; + return 0; } -// IPluginV2 Methods -char const* QKVToContextInterleavedPlugin::getPluginType() const noexcept +int32_t QKVToContextInterleavedPlugin::getOutputDataTypes( + DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept { - return kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_NAME; + try + { + PLUGIN_ASSERT(nbOutputs == 1); + outputTypes[0] = DataType::kINT8; + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; } char const* QKVToContextInterleavedPlugin::getPluginVersion() const noexcept @@ -194,40 +219,6 @@ int32_t QKVToContextInterleavedPlugin::getNbOutputs() const noexcept return 1; } -int32_t QKVToContextInterleavedPlugin::initialize() noexcept -{ - return 0; -} - -void QKVToContextInterleavedPlugin::terminate() noexcept {} - -size_t QKVToContextInterleavedPlugin::getSerializationSize() const noexcept -{ - return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(mHiddenSize) + sizeof(mSM) + sizeof(mS) + sizeof(mB) - + sizeof(mDqProbs) + sizeof(mUseInt8ScaleMax) + sizeof(mUseExplicitInt8) + sizeof(mQkvScale) - + sizeof(mCtxScale); -} - -void QKVToContextInterleavedPlugin::serialize(void* buffer) const noexcept -{ - serialize_value(&buffer, mNumHeads); - serialize_value(&buffer, mHeadSize); - serialize_value(&buffer, mHiddenSize); - serialize_value(&buffer, mSM); - serialize_value(&buffer, mS); - serialize_value(&buffer, mB); - serialize_value(&buffer, mDqProbs); - serialize_value(&buffer, mUseInt8ScaleMax); - serialize_value(&buffer, mUseExplicitInt8); - serialize_value(&buffer, mQkvScale); - serialize_value(&buffer, mCtxScale); -} - -void QKVToContextInterleavedPlugin::destroy() noexcept -{ - delete this; -} - void QKVToContextInterleavedPlugin::setPluginNamespace(char const* libNamespace) noexcept { mNamespace = libNamespace; @@ -238,6 +229,11 @@ char const* QKVToContextInterleavedPlugin::getPluginNamespace() const noexcept return mNamespace.c_str(); } +char const* QKVToContextInterleavedPlugin::getPluginName() const noexcept +{ + return kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_NAME; +} + int32_t QKVToContextInterleavedPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* /* workspace */, cudaStream_t stream) noexcept { @@ -301,8 +297,34 @@ int32_t QKVToContextInterleavedPlugin::enqueue(PluginTensorDesc const* inputDesc } } +PluginFieldCollection const* QKVToContextInterleavedPlugin::getFieldsToSerialize() noexcept +{ + mDataToSerialize.clear(); + + mDataToSerialize.emplace_back("hidden_size", &mHiddenSize, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("num_heads", &mNumHeads, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("use_int8_scale_max", &mUseInt8ScaleMax, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("use_explicit_int8", &mUseExplicitInt8, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("input_qkv_scale", &mQkvScale, PluginFieldType::kFLOAT32, 1); + mDataToSerialize.emplace_back("output_ctx_scale", &mCtxScale, PluginFieldType::kFLOAT32, 1); + + if (mDqProbs >= 0) + { + mDataToSerialize.emplace_back("dq_probs", &mDqProbs, PluginFieldType::kFLOAT32, 1); + } + + mFCToSerialize.nbFields = mDataToSerialize.size(); + mFCToSerialize.fields = mDataToSerialize.data(); + + return &mFCToSerialize; +} + +///////////////////////// Creator methods //////////////////////// + QKVToContextInterleavedPluginCreator::QKVToContextInterleavedPluginCreator() { + static std::mutex sMutex; + std::lock_guard lock(sMutex); mPluginAttributes.clear(); mPluginAttributes.emplace_back(PluginField("hidden_size", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, 1)); @@ -331,25 +353,34 @@ PluginFieldCollection const* QKVToContextInterleavedPluginCreator::getFieldNames return &mFC; } -IPluginV2* QKVToContextInterleavedPluginCreator::createPlugin( - char const* name, PluginFieldCollection const* fc) noexcept +IPluginV3* QKVToContextInterleavedPluginCreator::createPlugin( + char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept { try { - int32_t hiddenSize = 0; // Since numHeads must always exist or validateRequiredAttributes will fail, // we can set numHeads to -1 so that static analysis tools don't warn about // a division by zero in QKVToContextInterleavedPlugin constructor. int32_t numHeads{-1}; + int32_t hiddenSize{0}; + std::optional useInt8ScaleMax; + std::optional useExplicitInt8; + std::optional qkvScale; + std::optional ctxScale; + std::optional dqProbs; + + if (phase == TensorRTPhase::kBUILD) + { - float dqProbs = -1; - int32_t useInt8ScaleMax{-1}; - - int32_t useExplicitInt8{}; - float qkvScale{1.F}; - float ctxScale{1.F}; - - plugin::validateRequiredAttributesExist({"hidden_size", "num_heads"}, fc); + plugin::validateRequiredAttributesExist({"hidden_size", "num_heads"}, fc); + } + else + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + plugin::validateRequiredAttributesExist({"hidden_size", "num_heads", "use_int8_scale_max", + "use_explicit_int8", "input_qkv_scale", "output_ctx_scale"}, + fc); + } for (int32_t i = 0; i < fc->nbFields; i++) { @@ -370,68 +401,68 @@ IPluginV2* QKVToContextInterleavedPluginCreator::createPlugin( else if (field_name.compare("dq_probs") == 0) { dqProbs = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(dqProbs > 0.0F, ("QKV: Invalid dqProbs " + std::to_string(dqProbs)).c_str()); - BERT_DEBUG_VALUE("Building dqProbs: ", dqProbs); + PLUGIN_VALIDATE( + dqProbs.value() > 0.0F, ("QKV: Invalid dqProbs " + std::to_string(dqProbs.value())).c_str()); + BERT_DEBUG_VALUE("Building dqProbs: ", dqProbs.value()); } else if (field_name.compare("use_int8_scale_max") == 0) { useInt8ScaleMax = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(useInt8ScaleMax == 0 || useInt8ScaleMax == 1, - ("QKV: Invalid useInt8ScaleMax " + std::to_string(useInt8ScaleMax)).c_str()); - BERT_DEBUG_VALUE("Building useInt8ScaleMax: ", useInt8ScaleMax); + PLUGIN_VALIDATE(useInt8ScaleMax.value() == 0 || useInt8ScaleMax.value() == 1, + ("QKV: Invalid useInt8ScaleMax " + std::to_string(useInt8ScaleMax.value())).c_str()); + BERT_DEBUG_VALUE("Building useInt8ScaleMax: ", useInt8ScaleMax.value()); } else if (field_name.compare("use_explicit_int8") == 0) { useExplicitInt8 = *static_cast(fc->fields[i].data); - BERT_DEBUG_VALUE("Building use_explicit_int8: ", useExplicitInt8); + PLUGIN_VALIDATE(useExplicitInt8.value() == 0 || useExplicitInt8.value() == 1, + ("QKV: Invalid useExplicitInt8 " + std::to_string(useExplicitInt8.value())).c_str()); + BERT_DEBUG_VALUE("Building use_explicit_int8: ", useExplicitInt8.value()); } else if (field_name.compare("input_qkv_scale") == 0) { qkvScale = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(qkvScale > 0, ("QKV: Invalid input_qkv_scale" + std::to_string(qkvScale)).c_str()); - BERT_DEBUG_VALUE("Building input_qkv_scale: ", qkvScale); + PLUGIN_VALIDATE( + qkvScale.value() > 0, ("QKV: Invalid input_qkv_scale" + std::to_string(qkvScale.value())).c_str()); + BERT_DEBUG_VALUE("Building input_qkv_scale: ", qkvScale.value()); } else if (field_name.compare("output_ctx_scale") == 0) { ctxScale = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(ctxScale > 0, ("QKV: Invalid output_ctx_scale " + std::to_string(ctxScale)).c_str()); - BERT_DEBUG_VALUE("Building output_ctx_scale: ", ctxScale); + PLUGIN_VALIDATE(ctxScale.value() > 0, + ("QKV: Invalid output_ctx_scale " + std::to_string(ctxScale.value())).c_str()); + BERT_DEBUG_VALUE("Building output_ctx_scale: ", ctxScale.value()); } } - if (dqProbs < 0) + if (!dqProbs.has_value()) { - gLogInfo << "Using default scale factor\n"; + gLogInfo << "Using default scale factor: 1.F/127.F" << std::endl; dqProbs = 1.F / 127.F; } - - if (useInt8ScaleMax < 0) + if (!useInt8ScaleMax.has_value()) { - gLogInfo << "Using default for use_int8_scale_max: true" << std::endl; + gLogInfo << "Using default for use_int8_scale_max: 1" << std::endl; useInt8ScaleMax = 1; } + if (!useExplicitInt8.has_value()) + { + gLogInfo << "Using default for use_explicit_int8: 0" << std::endl; + useExplicitInt8 = 0; + } + if (!qkvScale.has_value()) + { + gLogInfo << "Using default for qkvScale: 1.F" << std::endl; + qkvScale = 1.F; + } + if (!ctxScale.has_value()) + { + gLogInfo << "Using default for ctxScale: 1.F" << std::endl; + ctxScale = 1.F; + } - auto const useInt8ScaleMaxFlag = static_cast(useInt8ScaleMax); - - QKVToContextInterleavedPlugin* p = new QKVToContextInterleavedPlugin( - name, hiddenSize, numHeads, dqProbs, useInt8ScaleMaxFlag, useExplicitInt8 != 0, qkvScale, ctxScale); - return p; - } - catch (std::exception const& e) - { - caughtError(e); - } - return nullptr; -} - -IPluginV2* QKVToContextInterleavedPluginCreator::deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept -{ - try - { - // This object will be deleted when the network is destroyed, which will - // call QKVToContextInterleavedPlugin::destroy() noexcept - return new QKVToContextInterleavedPlugin(name, serialData, serialLength); + return new QKVToContextInterleavedPlugin(name, hiddenSize, numHeads, dqProbs.value(), + useInt8ScaleMax.value() != 0, useExplicitInt8.value() != 0, qkvScale.value(), ctxScale.value()); } catch (std::exception const& e) { diff --git a/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.h b/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.h index 33f9de6e..7aab506d 100644 --- a/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.h +++ b/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.h @@ -38,89 +38,106 @@ static constexpr int32_t kSM_AMPERE_10B = 87; static constexpr int32_t kSM_ADA_10X = 89; static constexpr int32_t kSM_HOPPER_100 = 90; -class QKVToContextInterleavedPlugin : public nvinfer1::IPluginV2DynamicExt +class QKVToContextInterleavedPlugin : public IPluginV3, + public IPluginV3OneCore, + public IPluginV3OneBuild, + public IPluginV3OneRuntime { public: QKVToContextInterleavedPlugin(std::string const& name, int32_t hiddenSize, int32_t numHeads, float dqProbs, bool useInt8ScaleMax, bool useExplicitInt8, float qkvScale, float ctxScale); - QKVToContextInterleavedPlugin(std::string const& name, void const* data, size_t length); - // It doesn't make sense to make QKVToContextInterleavedPlugin without arguments, so we // delete default constructor. QKVToContextInterleavedPlugin() = delete; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, - nvinfer1::IExprBuilder& exprBuilder) noexcept override; + ~QKVToContextInterleavedPlugin() override; + + // IPluginV3 Methods + IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override; + // end of IPluginV3 Methods + + // IPluginV3OneCore Methods + char const* getPluginName() const noexcept override; + + char const* getPluginNamespace() const noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept; + + char const* getPluginVersion() const noexcept override; + // end of IPluginV3OneCore Methods + + // IPluginV3Build Methods bool supportsFormatCombination( - int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; - void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, - nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; - size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, - nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + + int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + + size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + + int32_t getOutputDataTypes( + DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override; + + int32_t getNbOutputs() const noexcept override; + + int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, + int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override; + // end IPluginV3Build Methods + + // IPluginV3Runtime Methods + IPluginV3* clone() noexcept; + + int32_t onShapeChange( + PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType( - int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; + IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override; - // IPluginV2 Methods - char const* getPluginType() const noexcept override; - char const* getPluginVersion() const noexcept override; - int32_t getNbOutputs() const noexcept override; - int32_t initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(char const* pluginNamespace) noexcept override; - char const* getPluginNamespace() const noexcept override; + PluginFieldCollection const* getFieldsToSerialize() noexcept override; + // end IPluginV3Runtime Methods protected: - void createMHARunner() noexcept; - int32_t getSMVersion() const noexcept; - private: std::string const& mLayerName; std::string mNamespace; - int32_t mS{}; - int32_t mB{}; int32_t mSM{}; int32_t mHeadSize{}; int32_t mHiddenSize{}; int32_t mNumHeads{}; + int32_t mUseInt8ScaleMax{1}; + int32_t mUseExplicitInt8{}; FusedMultiHeadAttentionXMMAKernelV2 const* mXmmaKernel; float mDqProbs{}; - bool mUseInt8ScaleMax{true}; - - bool mUseExplicitInt8{}; float mQkvScale{}; float mCtxScale{}; + + // IPluginV3 serialization related + std::vector mDataToSerialize; + nvinfer1::PluginFieldCollection mFCToSerialize; }; -class QKVToContextInterleavedPluginCreator : public nvinfer1::IPluginCreator +class QKVToContextInterleavedPluginCreator : public nvinfer1::IPluginCreatorV3One { public: QKVToContextInterleavedPluginCreator(); + ~QKVToContextInterleavedPluginCreator() override = default; char const* getPluginName() const noexcept override; char const* getPluginVersion() const noexcept override; - nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - - nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept override; + IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override; - void setPluginNamespace(char const* pluginNamespace) noexcept override; + void setPluginNamespace(char const* libNamespace) noexcept; char const* getPluginNamespace() const noexcept override; diff --git a/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPluginLegacy.cpp b/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPluginLegacy.cpp new file mode 100644 index 00000000..748d4b9c --- /dev/null +++ b/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPluginLegacy.cpp @@ -0,0 +1,444 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "qkvToContextInt8InterleavedPluginLegacy.h" +#include "NvInfer.h" +#include "common/bertCommon.h" +#include "common/plugin.h" +#include "common/serialize.hpp" + +#include +#include +#include +#include +#include + +#include "bertQKVToContextPlugin/fused_multihead_attention_v2/include/fused_multihead_attention_v2.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; + +namespace +{ +char const* const kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_VERSION{"3"}; +char const* const kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_NAME{"CustomQKVToContextPluginDynamic"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection QKVToContextInterleavedPluginLegacyCreator::mFC{}; +std::vector QKVToContextInterleavedPluginLegacyCreator::mPluginAttributes; + +REGISTER_TENSORRT_PLUGIN(QKVToContextInterleavedPluginLegacyCreator); + +constexpr uint32_t kIIDX = 0; // index of the input tensor + +QKVToContextInterleavedPluginLegacy::QKVToContextInterleavedPluginLegacy(std::string const& name, int32_t hiddenSize, + int32_t numHeads, float dqProbs, bool useInt8ScaleMax, bool useExplicitInt8, float qkvScale, float ctxScale) + : mLayerName(name) + , mS(0) + , mB(0) + , mHeadSize(hiddenSize / numHeads) + , mHiddenSize(hiddenSize) + , mNumHeads(numHeads) + , mDqProbs(dqProbs) + , mUseInt8ScaleMax(useInt8ScaleMax) + , mUseExplicitInt8(useExplicitInt8) + , mQkvScale(qkvScale) + , mCtxScale(ctxScale) +{ + mSM = getSMVersion(); + // variable sequence length is only supported with the fused MHA kernels + // we should not override mS! + PLUGIN_VALIDATE((mSM == kSM_AMPERE_100 || mSM == kSM_AMPERE_10X || mSM == kSM_AMPERE_10B || mSM == kSM_TURING + || mSM == kSM_XAVIER || mSM == kSM_ADA_10X || mSM == kSM_HOPPER_100) + && "requesting maxSeqlen not compatible with GPU arch"); + // the layout changes: SxB will be a combined \sum_i s_i and hdim will be the 2nd dimension instead of the third + mXmmaKernel = getXMMAKernelsV2(DATA_TYPE_INT8, mSM); +} + +QKVToContextInterleavedPluginLegacy::QKVToContextInterleavedPluginLegacy( + std::string const& name, void const* data, size_t length) + : mLayerName(name) +{ + deserialize_value(&data, &length, &mNumHeads); + deserialize_value(&data, &length, &mHeadSize); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mSM); + deserialize_value(&data, &length, &mS); + deserialize_value(&data, &length, &mB); + deserialize_value(&data, &length, &mDqProbs); + deserialize_value(&data, &length, &mUseInt8ScaleMax); + deserialize_value(&data, &length, &mUseExplicitInt8); + deserialize_value(&data, &length, &mQkvScale); + deserialize_value(&data, &length, &mCtxScale); +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* QKVToContextInterleavedPluginLegacy::clone() const noexcept +{ + try + { + QKVToContextInterleavedPluginLegacy* ret = new QKVToContextInterleavedPluginLegacy( + mLayerName, mHiddenSize, mNumHeads, mDqProbs, mUseInt8ScaleMax, mUseExplicitInt8, mQkvScale, mCtxScale); + + ret->setPluginNamespace(mNamespace.c_str()); + return ret; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +DimsExprs QKVToContextInterleavedPluginLegacy::getOutputDimensions( + int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept +{ + // Input SHAPE is 1x(3*N*H)xTotalx1 (NCHW) + // Output SHAPE is 1x(N*H)xTotalx1 + // In SupportsFormatCombination, we force the layout to be CHW, i.e. + // Input: 3xNx(H/32)xsumSx32, Output: 1xNx(H/32)xsumSx32 + PLUGIN_ASSERT(outputIndex == 0); + // Copy over everything + DimsExprs output(inputs[kIIDX]); + // output.d[0] = exprBuilder.constant(1); + // Divide last dim by three + auto const* three = exprBuilder.constant(3); + output.d[1] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[kIIDX].d[1], *three); + return output; +} +bool QKVToContextInterleavedPluginLegacy::supportsFormatCombination( + int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept +{ + PLUGIN_ASSERT(nbInputs == 3); + PLUGIN_ASSERT(nbOutputs == 1); + // 3 inputs: + // 0: qkv + // 1: cu_seqlens + // 2: dummy + // 1 output + if (pos == 0 || pos == nbInputs) + { + return (inOut[pos].type == DataType::kINT8) && (inOut[pos].format == TensorFormat::kCHW32); + } + + if (pos == 1) + { + // cuSeqlens is a int32_t array of size B+1 + auto const* seqlens = &inOut[pos]; + return (seqlens->type == DataType::kINT32) && (seqlens->format == TensorFormat::kLINEAR); + } + if (pos == 2) + { + // this is the dummy input + return inOut[pos].dims.nbDims == 1; + } + return false; +} + +void QKVToContextInterleavedPluginLegacy::configurePlugin( + DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept +{ +} + +size_t QKVToContextInterleavedPluginLegacy::getWorkspaceSize( + PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept +{ + return 0; +} + +// IPluginV2Ext Methods +DataType QKVToContextInterleavedPluginLegacy::getOutputDataType( + int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept +{ + PLUGIN_ASSERT(index == 0); + return DataType::kINT8; +} + +// IPluginV2 Methods +char const* QKVToContextInterleavedPluginLegacy::getPluginType() const noexcept +{ + return kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_NAME; +} + +char const* QKVToContextInterleavedPluginLegacy::getPluginVersion() const noexcept +{ + return kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_VERSION; +} + +int32_t QKVToContextInterleavedPluginLegacy::getNbOutputs() const noexcept +{ + return 1; +} + +int32_t QKVToContextInterleavedPluginLegacy::initialize() noexcept +{ + return 0; +} + +void QKVToContextInterleavedPluginLegacy::terminate() noexcept {} + +size_t QKVToContextInterleavedPluginLegacy::getSerializationSize() const noexcept +{ + return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(mHiddenSize) + sizeof(mSM) + sizeof(mS) + sizeof(mB) + + sizeof(mDqProbs) + sizeof(mUseInt8ScaleMax) + sizeof(mUseExplicitInt8) + sizeof(mQkvScale) + + sizeof(mCtxScale); +} + +void QKVToContextInterleavedPluginLegacy::serialize(void* buffer) const noexcept +{ + serialize_value(&buffer, mNumHeads); + serialize_value(&buffer, mHeadSize); + serialize_value(&buffer, mHiddenSize); + serialize_value(&buffer, mSM); + serialize_value(&buffer, mS); + serialize_value(&buffer, mB); + serialize_value(&buffer, mDqProbs); + serialize_value(&buffer, mUseInt8ScaleMax); + serialize_value(&buffer, mUseExplicitInt8); + serialize_value(&buffer, mQkvScale); + serialize_value(&buffer, mCtxScale); +} + +void QKVToContextInterleavedPluginLegacy::destroy() noexcept +{ + delete this; +} + +void QKVToContextInterleavedPluginLegacy::setPluginNamespace(char const* libNamespace) noexcept +{ + mNamespace = libNamespace; +} + +char const* QKVToContextInterleavedPluginLegacy::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +int32_t QKVToContextInterleavedPluginLegacy::enqueue(PluginTensorDesc const* inputDesc, + PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* /* workspace */, + cudaStream_t stream) noexcept +{ + PLUGIN_VALIDATE(inputDesc != nullptr && outputDesc != nullptr && inputs != nullptr && outputs != nullptr); + + int32_t const total = inputDesc[0].dims.d[2]; + int32_t const B = inputDesc[1].dims.d[0] - 1; + int32_t const maxS = inputDesc[2].dims.d[0]; + int32_t S = 384; + if (maxS <= 128) + { + S = 128; + } + else if (maxS <= 192) + { + S = 192; + } + else if (maxS <= 256) + { + S = 256; + } + Fused_multihead_attention_params_v2 params{}; + params.b = B; + params.s = S; + params.h = mNumHeads; + params.d = mHeadSize; + + params.interleaved = true; + + params.o_ptr = outputs[0]; + params.qkv_ptr = const_cast(inputs[0]); + params.cu_seqlens = static_cast(const_cast(inputs[1])); + + float scaleQkv = mUseExplicitInt8 ? mQkvScale : inputDesc[0].scale; + float scaleCtx = mUseExplicitInt8 ? mCtxScale : outputDesc[0].scale; + + float scaleBmm1 = scaleQkv * scaleQkv * 0.125; // 1 / sqrt(64) + float scaleBmm2 = mDqProbs * scaleQkv / scaleCtx; + float scaleSoftmax = 1.F / mDqProbs; + + params.scale_bmm1 = reinterpret_cast(scaleBmm1); + params.scale_bmm2 = reinterpret_cast(scaleBmm2); + params.scale_softmax = reinterpret_cast(scaleSoftmax); + + params.qkv_stride_in_bytes = total; + params.o_stride_in_bytes = total; + + params.use_int8_scale_max = mUseInt8ScaleMax; + params.enable_i2f_trick + = -double(1 << 22) * double(scaleBmm2) <= -128.F && double(1 << 22) * double(scaleBmm2) >= 127.F; + + try + { + mXmmaKernel->run(params, stream); + return cudaPeekAtLastError(); + } + catch (std::exception const& e) + { + caughtError(e); + return -1; + } +} + +QKVToContextInterleavedPluginLegacyCreator::QKVToContextInterleavedPluginLegacyCreator() +{ + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("hidden_size", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("dq_probs", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("use_int8_scale_max", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("use_explicit_int8", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("input_qkv_scale", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("output_ctx_scale", nullptr, PluginFieldType::kFLOAT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* QKVToContextInterleavedPluginLegacyCreator::getPluginName() const noexcept +{ + return kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_NAME; +} + +char const* QKVToContextInterleavedPluginLegacyCreator::getPluginVersion() const noexcept +{ + return kQKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_VERSION; +} + +PluginFieldCollection const* QKVToContextInterleavedPluginLegacyCreator::getFieldNames() noexcept +{ + return &mFC; +} + +IPluginV2* QKVToContextInterleavedPluginLegacyCreator::createPlugin( + char const* name, PluginFieldCollection const* fc) noexcept +{ + try + { + int32_t hiddenSize = 0; + // Since numHeads must always exist or validateRequiredAttributes will fail, + // we can set numHeads to -1 so that static analysis tools don't warn about + // a division by zero in QKVToContextInterleavedPluginLegacy constructor. + int32_t numHeads{-1}; + + float dqProbs = -1; + int32_t useInt8ScaleMax{-1}; + + int32_t useExplicitInt8{}; + float qkvScale{1.F}; + float ctxScale{1.F}; + + plugin::validateRequiredAttributesExist({"hidden_size", "num_heads"}, fc); + + for (int32_t i = 0; i < fc->nbFields; i++) + { + std::string field_name(fc->fields[i].name); + + if (field_name.compare("hidden_size") == 0) + { + hiddenSize = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(hiddenSize > 0, ("QKV: Invalid hiddenSize " + std::to_string(hiddenSize)).c_str()); + BERT_DEBUG_VALUE("Building hiddenSize: ", hiddenSize); + } + else if (field_name.compare("num_heads") == 0) + { + numHeads = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(numHeads > 0, ("QKV: Invalid numHeads " + std::to_string(numHeads)).c_str()); + BERT_DEBUG_VALUE("Building numHeads: ", numHeads); + } + else if (field_name.compare("dq_probs") == 0) + { + dqProbs = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(dqProbs > 0.0F, ("QKV: Invalid dqProbs " + std::to_string(dqProbs)).c_str()); + BERT_DEBUG_VALUE("Building dqProbs: ", dqProbs); + } + else if (field_name.compare("use_int8_scale_max") == 0) + { + useInt8ScaleMax = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(useInt8ScaleMax == 0 || useInt8ScaleMax == 1, + ("QKV: Invalid useInt8ScaleMax " + std::to_string(useInt8ScaleMax)).c_str()); + BERT_DEBUG_VALUE("Building useInt8ScaleMax: ", useInt8ScaleMax); + } + else if (field_name.compare("use_explicit_int8") == 0) + { + useExplicitInt8 = *static_cast(fc->fields[i].data); + BERT_DEBUG_VALUE("Building use_explicit_int8: ", useExplicitInt8); + } + else if (field_name.compare("input_qkv_scale") == 0) + { + qkvScale = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(qkvScale > 0, ("QKV: Invalid input_qkv_scale" + std::to_string(qkvScale)).c_str()); + BERT_DEBUG_VALUE("Building input_qkv_scale: ", qkvScale); + } + else if (field_name.compare("output_ctx_scale") == 0) + { + ctxScale = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(ctxScale > 0, ("QKV: Invalid output_ctx_scale " + std::to_string(ctxScale)).c_str()); + BERT_DEBUG_VALUE("Building output_ctx_scale: ", ctxScale); + } + } + + if (dqProbs < 0) + { + gLogInfo << "Using default scale factor\n"; + dqProbs = 1.F / 127.F; + } + + if (useInt8ScaleMax < 0) + { + gLogInfo << "Using default for use_int8_scale_max: true" << std::endl; + useInt8ScaleMax = 1; + } + + auto const useInt8ScaleMaxFlag = static_cast(useInt8ScaleMax); + + QKVToContextInterleavedPluginLegacy* p = new QKVToContextInterleavedPluginLegacy( + name, hiddenSize, numHeads, dqProbs, useInt8ScaleMaxFlag, useExplicitInt8 != 0, qkvScale, ctxScale); + return p; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +IPluginV2* QKVToContextInterleavedPluginLegacyCreator::deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept +{ + try + { + // This object will be deleted when the network is destroyed, which will + // call QKVToContextInterleavedPluginLegacy::destroy() noexcept + return new QKVToContextInterleavedPluginLegacy(name, serialData, serialLength); + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +void QKVToContextInterleavedPluginLegacyCreator::setPluginNamespace(char const* libNamespace) noexcept +{ + mNamespace = libNamespace; +} + +char const* QKVToContextInterleavedPluginLegacyCreator::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} diff --git a/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPluginLegacy.h b/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPluginLegacy.h new file mode 100644 index 00000000..dfc7c89f --- /dev/null +++ b/plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPluginLegacy.h @@ -0,0 +1,135 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRT_QKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_H +#define TRT_QKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_H + +#include "NvInferPlugin.h" +#include "fused_multihead_attention_v2/include/fused_multihead_attention_v2.h" +#include +#include +#include + +namespace nvinfer1 +{ +namespace plugin +{ +namespace bert +{ +static constexpr int32_t kSM_XAVIER = 72; +static constexpr int32_t kSM_TURING = 75; +static constexpr int32_t kSM_AMPERE_100 = 80; +static constexpr int32_t kSM_AMPERE_10X = 86; +static constexpr int32_t kSM_AMPERE_10B = 87; +static constexpr int32_t kSM_ADA_10X = 89; +static constexpr int32_t kSM_HOPPER_100 = 90; + +class QKVToContextInterleavedPluginLegacy : public nvinfer1::IPluginV2DynamicExt +{ +public: + QKVToContextInterleavedPluginLegacy(std::string const& name, int32_t hiddenSize, int32_t numHeads, float dqProbs, + bool useInt8ScaleMax, bool useExplicitInt8, float qkvScale, float ctxScale); + + QKVToContextInterleavedPluginLegacy(std::string const& name, void const* data, size_t length); + + // It doesn't make sense to make QKVToContextInterleavedPluginLegacy without arguments, so we + // delete default constructor. + QKVToContextInterleavedPluginLegacy() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination( + int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType( + int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + +protected: + void createMHARunner() noexcept; + +private: + std::string const& mLayerName; + std::string mNamespace; + + int32_t mS{}; + int32_t mB{}; + int32_t mSM{}; + int32_t mHeadSize{}; + int32_t mHiddenSize{}; + int32_t mNumHeads{}; + + FusedMultiHeadAttentionXMMAKernelV2 const* mXmmaKernel; + + float mDqProbs{}; + bool mUseInt8ScaleMax{true}; + + bool mUseExplicitInt8{}; + float mQkvScale{}; + float mCtxScale{}; +}; + +class QKVToContextInterleavedPluginLegacyCreator : public nvinfer1::IPluginCreator +{ +public: + QKVToContextInterleavedPluginLegacyCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + +private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace plugin +} // namespace nvinfer1 +#endif // TRT_QKV_TO_CONTEXT_INTERLEAVED_PLUGIN_LEGACY_H diff --git a/plugin/bertQKVToContextPlugin/qkvToContextPlugin.cpp b/plugin/bertQKVToContextPlugin/qkvToContextPlugin.cpp index 0562e2ea..294d020c 100644 --- a/plugin/bertQKVToContextPlugin/qkvToContextPlugin.cpp +++ b/plugin/bertQKVToContextPlugin/qkvToContextPlugin.cpp @@ -24,6 +24,7 @@ #include "bertQKVToContextPlugin/fused_multihead_attention_v2/include/fused_multihead_attention_v2.h" #include "common/bertCommon.h" #include "common/serialize.hpp" +#include "mhaRunner.h" #include "qkvToContextPlugin.h" #include @@ -39,8 +40,8 @@ using namespace nvinfer1::pluginInternal; namespace { -char const* const kQKV_TO_CONTEXT_PLUGIN_VERSION{"1"}; -char const* const kQKV_TO_CONTEXT_VAR_SEQLEN_PLUGIN_VERSION{"2"}; +char const* const kQKV_TO_CONTEXT_PLUGIN_VERSION{"4"}; +char const* const kQKV_TO_CONTEXT_VAR_SEQLEN_PLUGIN_VERSION{"5"}; char const* const kQKV_TO_CONTEXT_PLUGIN_NAME{"CustomQKVToContextPluginDynamic"}; } // namespace @@ -50,14 +51,15 @@ std::vector QKVToContextPluginDynamicCreator::mPluginAttributes; REGISTER_TENSORRT_PLUGIN(QKVToContextPluginDynamicCreator); -constexpr uint32_t IIDX = 0; // index of the input tensor -constexpr uint32_t MIDX = 1; // index of the mask +constexpr uint32_t kIIDX = 0; // index of the input tensor +constexpr uint32_t kMIDX = 1; // index of the mask // Static class fields initialization PluginFieldCollection QKVToContextVarSeqlenPluginCreator::mFC{}; std::vector QKVToContextVarSeqlenPluginCreator::mPluginAttributes; REGISTER_TENSORRT_PLUGIN(QKVToContextVarSeqlenPluginCreator); +QKVToContextPluginDynamic::~QKVToContextPluginDynamic() {} QKVToContextPluginDynamic::QKVToContextPluginDynamic(const std::string name, const DataType type, const int32_t hiddenSize, const int32_t numHeads, float const dqProbs, bool hasImask) @@ -67,40 +69,66 @@ QKVToContextPluginDynamic::QKVToContextPluginDynamic(const std::string name, con , mHeadSize(hiddenSize / numHeads) , mHiddenSize(hiddenSize) , mNumHeads(numHeads) - , mHasImask(hasImask) , mType(type) , mDqProbs(dqProbs) { + mHasImask = static_cast(hasImask); mSM = getSMVersion(); } -QKVToContextPluginDynamic::QKVToContextPluginDynamic(const std::string name, void const* data, size_t length) +QKVToContextPluginDynamic::QKVToContextPluginDynamic(const std::string name, const DataType type, const int32_t S, + const int32_t B, const int32_t SM, const int32_t hiddenSize, const int32_t numHeads, float const dqProbs, + bool hasImask, bool hasUnfusedDispatcher, void const* runnerStateBuffer) : mLayerName(name) + , mS(S) + , mB(B) + , mSM(SM) + , mHeadSize(hiddenSize / numHeads) + , mHiddenSize(hiddenSize) + , mNumHeads(numHeads) + , mType(type) + , mDqProbs(dqProbs) + { - BERT_DEBUG_MSG("QKV Deser Start"); - deserialize_value(&data, &length, &mType); - deserialize_value(&data, &length, &mNumHeads); - deserialize_value(&data, &length, &mHeadSize); - deserialize_value(&data, &length, &mHasImask); - deserialize_value(&data, &length, &mHiddenSize); - deserialize_value(&data, &length, &mSM); - deserialize_value(&data, &length, &mS); - deserialize_value(&data, &length, &mB); - - deserialize_value(&data, &length, &mDqProbs); + BERT_DEBUG_MSG("MHA Runner Deser"); + mHasImask = static_cast(hasImask); + mHasUnfusedDispatcher = static_cast(hasUnfusedDispatcher); createMHARunner(); - int32_t hasUnfusedRunner = 0; - deserialize_value(&data, &length, &hasUnfusedRunner); - if (hasUnfusedRunner) + if (hasUnfusedDispatcher) { PLUGIN_ASSERT(unfusedDispatcher.get()); - unfusedDispatcher->deserialize(data, length); + PLUGIN_ASSERT(runnerStateBuffer != nullptr); + auto length = unfusedDispatcher->getSerializationSize(); + unfusedDispatcher->deserialize(runnerStateBuffer, length); } - BERT_DEBUG_MSG("QKV Deser done"); + BERT_DEBUG_MSG("MHA Runner Deser Done"); +} + + +IPluginCapability* QKVToContextPluginDynamic::getCapabilityInterface(PluginCapabilityType type) noexcept +{ + try + { + if (type == PluginCapabilityType::kBUILD) + { + return static_cast(this); + } + if (type == PluginCapabilityType::kRUNTIME) + { + return static_cast(this); + } + PLUGIN_ASSERT(type == PluginCapabilityType::kCORE); + return static_cast(this); + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; } void QKVToContextPluginDynamic::createMHARunner() @@ -123,52 +151,64 @@ void QKVToContextPluginDynamic::createMHARunner() } } -// IPluginV2DynamicExt Methods -nvinfer1::IPluginV2DynamicExt* QKVToContextPluginDynamic::clone() const noexcept +IPluginV3* QKVToContextPluginDynamic::clone() noexcept { BERT_DEBUG_MSG("QKV Clone"); QKVToContextPluginDynamic* ret = nullptr; + mHasUnfusedDispatcher = 0; + char* bufferData = nullptr; // the workspacesize is 0 if we have not call setup the dispatcher yet. if (unfusedDispatcher.get() && unfusedDispatcher->getWorkspaceSize()) { - std::vector buff; - buff.resize(getSerializationSize()); - serialize(buff.data()); - - ret = new QKVToContextPluginDynamic(mLayerName, buff.data(), buff.size()); - } - else - { - ret = new QKVToContextPluginDynamic(mLayerName, mType, mHiddenSize, mNumHeads, mDqProbs, mHasImask); + mHasUnfusedDispatcher = 1; + mRunnerStateBuffer.resize(unfusedDispatcher->getSerializationSize()); + unfusedDispatcher->serialize(mRunnerStateBuffer.data()); + bufferData = mRunnerStateBuffer.data(); } + ret = new QKVToContextPluginDynamic(mLayerName, mType, mS, mB, mSM, mHiddenSize, mNumHeads, mDqProbs, + static_cast(mHasImask), mHasUnfusedDispatcher, static_cast(bufferData)); ret->setPluginNamespace(mNamespace.c_str()); BERT_DEBUG_MSG("QKV Clone done"); return ret; } -DimsExprs QKVToContextPluginDynamic::getOutputDimensions( - int32_t outputIndex, DimsExprs const* inputs, int32_t /*nbInputs*/, IExprBuilder& exprBuilder) noexcept +int32_t QKVToContextPluginDynamic::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, + DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, + IExprBuilder& exprBuilder) noexcept { - // Input is BxSx3*N*H, output should be BxSxN*H - PLUGIN_ASSERT(outputIndex == 0); - // Copy over everything - DimsExprs output(inputs[IIDX]); - // Divide last dim by three - auto const* three = exprBuilder.constant(3); - output.d[HDIM] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[IIDX].d[HDIM], *three); - return output; + try + { + PLUGIN_ASSERT(inputs != nullptr); + PLUGIN_ASSERT(nbInputs == 1 + mHasImask); + PLUGIN_ASSERT(nbShapeInputs == 0); + PLUGIN_ASSERT(outputs != nullptr); + PLUGIN_ASSERT(nbOutputs == 1); + // Input is BxSx3*N*H, output should be BxSxN*H + // Copy over everything + outputs[kIIDX] = inputs[kIIDX]; + // Divide last dim by three + auto const* three = exprBuilder.constant(3); + outputs[kIIDX].d[HDIM] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[kIIDX].d[HDIM], *three); + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; } + bool QKVToContextPluginDynamic::supportsFormatCombination( - int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t /*nbOutputs*/) noexcept + int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t /*nbOutputs*/) noexcept { PLUGIN_ASSERT(pos >= 0); PLUGIN_ASSERT(pos < 2 + mHasImask); PLUGIN_ASSERT(nbInputs == 1 + mHasImask); auto const* in = inOut; auto const* out = inOut + nbInputs; - int32_t packedSize = getMHAMaskPackedSize(mSM, mType, in->dims.d[SDIM]); + int32_t packedSize = getMHAMaskPackedSize(mSM, mType, in->desc.dims.d[SDIM]); // we only support int8 IO in fused mha runner, and we only support fused mha runner on Xavier, Turing and Ampere if (mType == DataType::kINT8) @@ -179,7 +219,7 @@ bool QKVToContextPluginDynamic::supportsFormatCombination( << kQKV_TO_CONTEXT_PLUGIN_NAME << std::endl; return false; } - if (in->dims.d[SDIM] == -1) + if (in->desc.dims.d[SDIM] == -1) { gLogError << "INT8 IO not support dynamic shape in sequence dimension for plugin " << kQKV_TO_CONTEXT_PLUGIN_NAME << std::endl; @@ -192,46 +232,36 @@ bool QKVToContextPluginDynamic::supportsFormatCombination( return false; } } - if (mType == DataType::kHALF) - { - if (mSM < kSM_53) - { - gLogError - << "Half-precision floating-point is only supported on compute capability 5.3 and later for plugin " - << kQKV_TO_CONTEXT_PLUGIN_NAME << std::endl; - return false; - } - } if (pos == 0) { - bool isFormatSupported = in->format == TensorFormat::kLINEAR; + bool isFormatSupported = in->desc.format == TensorFormat::kLINEAR; if (mType == DataType::kINT8) { - if (in->dims.d[HDIM] % 32U == 0) + if (in->desc.dims.d[HDIM] % 32U == 0) { - isFormatSupported = in->format == TensorFormat::kCHW32; + isFormatSupported = in->desc.format == TensorFormat::kCHW32; } else { - isFormatSupported = in->format == TensorFormat::kCHW4; + isFormatSupported = in->desc.format == TensorFormat::kCHW4; } } // must not check descriptions > pos - return (in->type == mType) && // precision - isFormatSupported && // format - (in->dims.nbDims == 5) && // num dims - ((in->dims.d[HDIM] % 3U) == 0) && // see getOutputDimensions - ((in->dims.d[3]) == 1) && // for fc - ((in->dims.d[4]) == 1) // for fc + return (in->desc.type == mType) && // precision + isFormatSupported && // format + (in->desc.dims.nbDims == 5) && // num dims + ((in->desc.dims.d[HDIM] % 3U) == 0) && // see getOutputDimensions + ((in->desc.dims.d[3]) == 1) && // for fc + ((in->desc.dims.d[4]) == 1) // for fc ; } // pos==1 if ((mHasImask && pos == 1)) // pos 1 is the mask { - auto const* inMask = &inOut[1]; + auto const* inMask = &inOut[1].desc; if (inMask->dims.d[1] != -1 && inMask->dims.d[1] != packedSize) { gLogError << "CustomEmbLayerNormPluginDynamic returned mask with pack size " << inMask->dims.d[1] @@ -244,106 +274,167 @@ bool QKVToContextPluginDynamic::supportsFormatCombination( return (inMask->type == DataType::kINT32) && // precision (inMask->format == TensorFormat::kLINEAR) && // format (inMask->dims.nbDims == 2) && // Bx2*maskSize - (inMask->dims.d[0] == in->dims.d[BDIM]); + (inMask->dims.d[0] == in->desc.dims.d[BDIM]); } if (!mHasImask || pos == 2) // output pos { - bool isFormatSupported = out->format == TensorFormat::kLINEAR; + bool isFormatSupported = out->desc.format == TensorFormat::kLINEAR; if (mType == DataType::kINT8) { - if (out->dims.d[HDIM] % 32U == 0) + if (out->desc.dims.d[HDIM] % 32U == 0) { - isFormatSupported = out->format == TensorFormat::kCHW32; + isFormatSupported = out->desc.format == TensorFormat::kCHW32; } else { - isFormatSupported = out->format == TensorFormat::kCHW4; + isFormatSupported = out->desc.format == TensorFormat::kCHW4; } } - return (in->type == out->type) && // precision - isFormatSupported && // format - (out->dims.nbDims == 5) && // num dims - ((in->dims.d[HDIM] / 3) == (out->dims.d[HDIM])) && // div 3 - ((out->dims.d[3]) == 1) && // for fc - ((out->dims.d[4]) == 1) && // for fc - ((out->dims.d[BDIM]) == in->dims.d[BDIM]) && // check B - ((out->dims.d[SDIM]) == in->dims.d[SDIM]) // check S + return (in->desc.type == out->desc.type) && // precision + isFormatSupported && // format + (out->desc.dims.nbDims == 5) && // num dims + ((in->desc.dims.d[HDIM] / 3) == (out->desc.dims.d[HDIM])) && // div 3 + ((out->desc.dims.d[3]) == 1) && // for fc + ((out->desc.dims.d[4]) == 1) && // for fc + ((out->desc.dims.d[BDIM]) == in->desc.dims.d[BDIM]) && // check B + ((out->desc.dims.d[SDIM]) == in->desc.dims.d[SDIM]) // check S ; } return false; } -void QKVToContextPluginDynamic::configurePlugin( - DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept + +int32_t QKVToContextPluginDynamic::onShapeChange( + PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept { - PLUGIN_ASSERT(nbInputs == 1 + mHasImask); - PLUGIN_ASSERT(nbOutputs == 1); - PluginTensorDesc const& inDesc = in[IIDX].desc; - TRT_UNUSED inDesc; - PluginTensorDesc const& outDesc = out->desc; - TRT_UNUSED outDesc; - PLUGIN_ASSERT(mType == inDesc.type); - PLUGIN_ASSERT(mType == outDesc.type); - PLUGIN_ASSERT(inDesc.dims.d[BDIM] == outDesc.dims.d[BDIM]); - PLUGIN_ASSERT(inDesc.dims.d[SDIM] == outDesc.dims.d[SDIM]); - PLUGIN_ASSERT(inDesc.dims.d[HDIM] == 3 * outDesc.dims.d[HDIM]); - if (mHasImask) + try { - PluginTensorDesc const& maskDesc = in[MIDX].desc; - TRT_UNUSED maskDesc; - PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[BDIM]); - } + PLUGIN_ASSERT(in != nullptr); + PLUGIN_ASSERT(nbInputs == 1 + mHasImask); + PLUGIN_ASSERT(nbOutputs == 1); + PluginTensorDesc const& inDesc = in[kIIDX]; + TRT_UNUSED inDesc; + PLUGIN_ASSERT(out != nullptr); + PluginTensorDesc const& outDesc = out[0]; + TRT_UNUSED outDesc; + PLUGIN_ASSERT(mType == inDesc.type); + PLUGIN_ASSERT(mType == outDesc.type); + PLUGIN_ASSERT(inDesc.dims.d[BDIM] == outDesc.dims.d[BDIM]); + PLUGIN_ASSERT(inDesc.dims.d[SDIM] == outDesc.dims.d[SDIM]); + PLUGIN_ASSERT(inDesc.dims.d[HDIM] == 3 * outDesc.dims.d[HDIM]); + if (mHasImask) + { + PluginTensorDesc const& maskDesc = in[kMIDX]; + TRT_UNUSED maskDesc; + PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[BDIM]); + } - createMHARunner(); + createMHARunner(); + + // during build, configurePlugin() should have set mS, mB appropriately + // during inference, the engine should have mS, mB information + PLUGIN_ASSERT(mS); + PLUGIN_ASSERT(mB); + if (fusedDispatcher.get() && fusedDispatcher->isValid(mHeadSize, mS)) + { + fusedDispatcher->setup(mS, mB, mHeadSize); + } + else + { + unfusedDispatcher->setup(mS, mB, mHeadSize); + } - const int32_t S = inDesc.dims.d[SDIM]; - const int32_t B = inDesc.dims.d[BDIM] <= 0 ? in->max.d[BDIM] : inDesc.dims.d[BDIM]; - if (S <= 0) + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; +} + +int32_t QKVToContextPluginDynamic::configurePlugin( + DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept +{ + try { - // in dynamic shape build stage, we setup with max sequence that cannot fused - const int32_t Smin = in->min.d[SDIM]; - const int32_t Smax = in->max.d[SDIM]; + PLUGIN_ASSERT(in != nullptr); + PLUGIN_ASSERT(nbInputs == 1 + mHasImask); + PLUGIN_ASSERT(nbOutputs == 1); + PluginTensorDesc const& inDesc = in[kIIDX].desc; + TRT_UNUSED inDesc; + PLUGIN_ASSERT(out != nullptr); + PluginTensorDesc const& outDesc = out->desc; + TRT_UNUSED outDesc; + PLUGIN_ASSERT(mType == inDesc.type); + PLUGIN_ASSERT(mType == outDesc.type); + PLUGIN_ASSERT(inDesc.dims.d[BDIM] == outDesc.dims.d[BDIM]); + PLUGIN_ASSERT(inDesc.dims.d[SDIM] == outDesc.dims.d[SDIM]); + PLUGIN_ASSERT(inDesc.dims.d[HDIM] == 3 * outDesc.dims.d[HDIM]); + if (mHasImask) + { + PluginTensorDesc const& maskDesc = in[kMIDX].desc; + TRT_UNUSED maskDesc; + PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[BDIM]); + } + + createMHARunner(); - if (fusedDispatcher.get()) + const int32_t S = inDesc.dims.d[SDIM]; + const int32_t B = inDesc.dims.d[BDIM] <= 0 ? in->max.d[BDIM] : inDesc.dims.d[BDIM]; + if (S <= 0) { - for (int32_t i = Smax; i >= Smin; --i) + // in dynamic shape build stage, we setup with max sequence that cannot fused + const int32_t Smin = in->min.d[SDIM]; + const int32_t Smax = in->max.d[SDIM]; + + if (fusedDispatcher.get()) { - if (!fusedDispatcher->isValid(mHeadSize, i)) + for (int32_t i = Smax; i >= Smin; --i) { - unfusedDispatcher->setup(i, B, mHeadSize); - mS = i; - mB = B; - break; + if (!fusedDispatcher->isValid(mHeadSize, i)) + { + unfusedDispatcher->setup(i, B, mHeadSize); + mS = i; + mB = B; + break; + } } } + else + { + unfusedDispatcher->setup(Smax, B, mHeadSize); + mS = Smax; + mB = B; + } } else { - unfusedDispatcher->setup(Smax, B, mHeadSize); - mS = Smax; + // in inference stage or in static shape build stage + if (fusedDispatcher.get() && fusedDispatcher->isValid(mHeadSize, S)) + { + fusedDispatcher->setup(S, B, mHeadSize); + } + else + { + unfusedDispatcher->setup(S, B, mHeadSize); + } + mS = S; mB = B; } + return pluginStatus_t::STATUS_SUCCESS; } - else + catch (std::exception const& e) { - // in inference stage or in static shape build stage - if (fusedDispatcher.get() && fusedDispatcher->isValid(mHeadSize, S)) - { - fusedDispatcher->setup(S, B, mHeadSize); - } - else - { - unfusedDispatcher->setup(S, B, mHeadSize); - } - mS = S; - mB = B; + caughtError(e); } + return pluginStatus_t::STATUS_FAILURE; } -size_t QKVToContextPluginDynamic::getWorkspaceSize(PluginTensorDesc const* /*inputs*/, int32_t /*nbInputs*/, - PluginTensorDesc const* /*outputs*/, int32_t /*nbOutputs*/) const noexcept +size_t QKVToContextPluginDynamic::getWorkspaceSize(DynamicPluginTensorDesc const* /*inputs*/, int32_t /*nbInputs*/, + DynamicPluginTensorDesc const* /*outputs*/, int32_t /*nbOutputs*/) const noexcept { // only unfused kernel need workspace, and we need larger workspace for larger sequence length // we have already setup unfusedDispatcher with max sequence in configurePlugin @@ -353,34 +444,44 @@ size_t QKVToContextPluginDynamic::getWorkspaceSize(PluginTensorDesc const* /*inp } // IPluginV2Ext Methods -DataType QKVToContextPluginDynamic::getOutputDataType( - int32_t index, nvinfer1::DataType const* inputTypes, int32_t /*nbInputs*/) const noexcept +int32_t QKVToContextPluginDynamic::getOutputDataTypes( + DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept { - PLUGIN_ASSERT(index == 0); - PLUGIN_ASSERT( - inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF || inputTypes[0] == DataType::kINT8); - return inputTypes[0]; + try + { + PLUGIN_ASSERT( + inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF || inputTypes[0] == DataType::kINT8); + outputTypes[0] = inputTypes[0]; + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; +} + +void QKVToContextPluginDynamic::setCublasHandle(cublasContext* cublasHandle) +{ + mCublas = cublasHandle; } -void QKVToContextPluginDynamic::attachToContext( - cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) noexcept +IPluginV3* QKVToContextPluginDynamic::attachToContext(IPluginResourceContext* context) noexcept { try { - mCublasWrapper = createPluginCublasWrapper(allocator); - mCublas = mCublasWrapper->getCublasHandle(); - PLUGIN_VALIDATE(mCublas != nullptr); + auto p = static_cast(clone()); + mCublasWrapper = createPluginCublasWrapper(context); + auto cublasHandle = mCublasWrapper->getCublasHandle(); + PLUGIN_VALIDATE(cublasHandle != nullptr); + p->setCublasHandle(cublasHandle); + return p; } catch (const std::exception& e) { caughtError(e); } -} - -// IPluginV2 Methods -char const* QKVToContextPluginDynamic::getPluginType() const noexcept -{ - return kQKV_TO_CONTEXT_PLUGIN_NAME; + return nullptr; } char const* QKVToContextPluginDynamic::getPluginVersion() const noexcept @@ -393,50 +494,11 @@ int32_t QKVToContextPluginDynamic::getNbOutputs() const noexcept return 1; } -int32_t QKVToContextPluginDynamic::initialize() noexcept -{ - return 0; -} - -void QKVToContextPluginDynamic::terminate() noexcept {} - -size_t QKVToContextPluginDynamic::getSerializationSize() const noexcept -{ - PLUGIN_ASSERT(unfusedDispatcher.get()); - return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(DataType) + sizeof(mHasImask) + sizeof(mHiddenSize) - + sizeof(mSM) + sizeof(mS) + sizeof(mB) + sizeof(mDqProbs) + sizeof(int32_t) - + unfusedDispatcher->getSerializationSize(); -} - -void QKVToContextPluginDynamic::serialize(void* buffer) const noexcept +char const* QKVToContextPluginDynamic::getPluginName() const noexcept { - serialize_value(&buffer, mType); - serialize_value(&buffer, mNumHeads); - serialize_value(&buffer, mHeadSize); - serialize_value(&buffer, mHasImask); - serialize_value(&buffer, mHiddenSize); - serialize_value(&buffer, mSM); - serialize_value(&buffer, mS); - serialize_value(&buffer, mB); - - serialize_value(&buffer, mDqProbs); - if (unfusedDispatcher.get() && unfusedDispatcher->getWorkspaceSize()) - { - int32_t hasUnfusedRunner = 1; - serialize_value(&buffer, hasUnfusedRunner); - unfusedDispatcher->serialize(buffer); - } - else - { - int32_t hasUnfusedRunner = 0; - serialize_value(&buffer, hasUnfusedRunner); - } + return kQKV_TO_CONTEXT_PLUGIN_NAME; } -void QKVToContextPluginDynamic::destroy() noexcept -{ - delete this; -} void QKVToContextPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { @@ -479,6 +541,44 @@ int32_t QKVToContextPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, Pl return 0; } +PluginFieldCollection const* QKVToContextPluginDynamic::getFieldsToSerialize() noexcept +{ + mDataToSerialize.clear(); + + mDataToSerialize.emplace_back("type_id", &mType, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("hidden_size", &mHiddenSize, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("num_heads", &mNumHeads, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("has_mask", &mHasImask, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("S", &mS, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("B", &mB, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("SM", &mSM, PluginFieldType::kINT32, 1); + + if (unfusedDispatcher.get() && unfusedDispatcher->getWorkspaceSize()) + { + mHasUnfusedDispatcher = 1; + mRunnerStateBuffer.resize(unfusedDispatcher->getSerializationSize()); + unfusedDispatcher->serialize(mRunnerStateBuffer.data()); + mDataToSerialize.emplace_back("runnerStateBuffer", (void const*) mRunnerStateBuffer.data(), + PluginFieldType::kUNKNOWN, mRunnerStateBuffer.size()); + } + else + { + mHasUnfusedDispatcher = 0; + } + + mDataToSerialize.emplace_back("hasUnfusedDispatcher", &mHasUnfusedDispatcher, PluginFieldType::kINT32, 1); + + if (mDqProbs >= 0) + { + mDataToSerialize.emplace_back("dq_probs", &mDqProbs, PluginFieldType::kFLOAT32, 1); + } + + mFCToSerialize.nbFields = mDataToSerialize.size(); + mFCToSerialize.fields = mDataToSerialize.data(); + + return &mFCToSerialize; +} + QKVToContextPluginDynamicCreator::QKVToContextPluginDynamicCreator() { mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); @@ -506,7 +606,8 @@ PluginFieldCollection const* QKVToContextPluginDynamicCreator::getFieldNames() n return &mFC; } -IPluginV2* QKVToContextPluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept +IPluginV3* QKVToContextPluginDynamicCreator::createPlugin( + char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept { try { @@ -519,11 +620,24 @@ IPluginV2* QKVToContextPluginDynamicCreator::createPlugin(char const* name, Plug int32_t numHeads{-1}; bool hasMask = false; int32_t typeId = -1; - + int32_t s = -1; + int32_t b = -1; + int32_t sm = -1; + bool hasUnfusedDispatcher = false; + void const* runnerStateBuffer; float dqProbs = -1; PLUGIN_VALIDATE(fc->fields != nullptr); - plugin::validateRequiredAttributesExist({"type_id", "hidden_size", "num_heads", "has_mask"}, fc); + if (phase == TensorRTPhase::kBUILD) + { + plugin::validateRequiredAttributesExist({"type_id", "hidden_size", "num_heads", "has_mask"}, fc); + } + else + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + plugin::validateRequiredAttributesExist( + {"type_id", "S", "B", "hidden_size", "num_heads", "has_mask", "SM", "hasUnfusedDispatcher"}, fc); + } for (int32_t i = 0; i < fc->nbFields; i++) { @@ -537,19 +651,19 @@ IPluginV2* QKVToContextPluginDynamicCreator::createPlugin(char const* name, Plug PLUGIN_VALIDATE(typeId >= 0 && typeId <= 2, ("QKV: Invalid TypeId " + std::to_string(typeId)).c_str()); BERT_DEBUG_VALUE("Building typeId: ", typeId); } - if (field_name.compare("hidden_size") == 0) + else if (field_name.compare("hidden_size") == 0) { hiddenSize = *static_cast(fc->fields[i].data); PLUGIN_VALIDATE(hiddenSize > 0, ("QKV: Invalid hiddenSize " + std::to_string(hiddenSize)).c_str()); BERT_DEBUG_VALUE("Building hiddenSize: ", hiddenSize); } - if (field_name.compare("num_heads") == 0) + else if (field_name.compare("num_heads") == 0) { numHeads = *static_cast(fc->fields[i].data); PLUGIN_VALIDATE(numHeads > 0, ("QKV: Invalid numHeads " + std::to_string(numHeads)).c_str()); BERT_DEBUG_VALUE("Building numHeads: ", numHeads); } - if (field_name.compare("has_mask") == 0) + else if (field_name.compare("has_mask") == 0) { auto hasMaskValue = *static_cast(fc->fields[i].data); PLUGIN_VALIDATE(hasMaskValue == 0 || hasMaskValue == 1, @@ -557,13 +671,44 @@ IPluginV2* QKVToContextPluginDynamicCreator::createPlugin(char const* name, Plug hasMask = static_cast(hasMaskValue); BERT_DEBUG_VALUE("Building hasMask: ", hasMask); } - - if (field_name.compare("dq_probs") == 0) + else if (field_name.compare("dq_probs") == 0) { dqProbs = *static_cast(fc->fields[i].data); PLUGIN_VALIDATE(dqProbs > 0.0F, ("QKV: Invalid dqProbs " + std::to_string(dqProbs)).c_str()); BERT_DEBUG_VALUE("Building dqProbs: ", dqProbs); } + else if (field_name.compare("S") == 0) + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + s = *static_cast(fc->fields[i].data); + BERT_DEBUG_VALUE("Building S: ", s); + } + else if (field_name.compare("B") == 0) + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + b = *static_cast(fc->fields[i].data); + BERT_DEBUG_VALUE("Building B: ", b); + } + else if (field_name.compare("SM") == 0) + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + sm = *static_cast(fc->fields[i].data); + BERT_DEBUG_VALUE("Building SM: ", sm); + } + else if (field_name.compare("hasUnfusedDispatcher") == 0) + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + auto hasUnfusedDispatcherValue = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(hasUnfusedDispatcherValue == 0 || hasUnfusedDispatcherValue == 1, + ("QKV: Invalid hasUnfusedDispatcher " + std::to_string(hasUnfusedDispatcherValue)).c_str()); + hasUnfusedDispatcher = static_cast(hasUnfusedDispatcherValue); + BERT_DEBUG_VALUE("Building hasUnfusedDispatcher: ", hasUnfusedDispatcher); + } + else if (field_name.compare("runnerStateBuffer") == 0) + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + runnerStateBuffer = static_cast(fc->fields[i].data); + } } BERT_DEBUG_MSG("Building the Plugin..."); @@ -574,8 +719,25 @@ IPluginV2* QKVToContextPluginDynamicCreator::createPlugin(char const* name, Plug dqProbs = 1.F / 127.F; } - auto* p = new QKVToContextPluginDynamic(name, type, hiddenSize, numHeads, dqProbs, hasMask); - return p; + if (phase == TensorRTPhase::kBUILD) + { + return new QKVToContextPluginDynamic(name, type, hiddenSize, numHeads, dqProbs, hasMask); + } + + PLUGIN_VALIDATE(s != -1, "invalid S during runtime plugin creation"); + PLUGIN_VALIDATE(b != -1, "invalid B during runtime plugin creation"); + PLUGIN_VALIDATE(sm != -1, "invalid SM during runtime plugin creation"); + if (hasUnfusedDispatcher == 1) + { + PLUGIN_VALIDATE(runnerStateBuffer != nullptr, "invalid runnerStateBuffer during runtime plugin creation"); + } + else + { + PLUGIN_VALIDATE(runnerStateBuffer == nullptr, "invalid runnerStateBuffer during runtime plugin creation"); + } + + return new QKVToContextPluginDynamic( + name, type, s, b, sm, hiddenSize, numHeads, dqProbs, hasMask, hasUnfusedDispatcher, runnerStateBuffer); } catch (std::exception const& e) { @@ -584,13 +746,6 @@ IPluginV2* QKVToContextPluginDynamicCreator::createPlugin(char const* name, Plug return nullptr; } -IPluginV2* QKVToContextPluginDynamicCreator::deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept -{ - // This object will be deleted when the network is destroyed, which will - // call QKVToContextPluginDynamic::destroy() - return new QKVToContextPluginDynamic(name, serialData, serialLength); -} void QKVToContextPluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept { @@ -602,6 +757,11 @@ char const* QKVToContextPluginDynamicCreator::getPluginNamespace() const noexcep return mNamespace.c_str(); } + +///// QKVToContextVarSeqlenPlugin (CustomQKVToContextPluginDynamic v5) //// + +QKVToContextVarSeqlenPlugin::~QKVToContextVarSeqlenPlugin() {} + QKVToContextVarSeqlenPlugin::QKVToContextVarSeqlenPlugin(std::string const name, DataType const type, int32_t const hiddenSize, int32_t const numHeads, float const dqProbs, bool hasImask, bool varSeqlen, bool useInt8ScaleMax) @@ -611,21 +771,21 @@ QKVToContextVarSeqlenPlugin::QKVToContextVarSeqlenPlugin(std::string const name, , mHeadSize(hiddenSize / numHeads) , mHiddenSize(hiddenSize) , mNumHeads(numHeads) - , mHasImask(hasImask) , mType(type) , mDqProbs(dqProbs) , mHdim(HDIM) - , mUseVarSeqlen(varSeqlen) - , mUseInt8ScaleMax(useInt8ScaleMax) { mSM = getSMVersion(); + mUseVarSeqlen = static_cast(varSeqlen); + mUseInt8ScaleMax = static_cast(useInt8ScaleMax); + mHasImask = static_cast(hasImask); if (varSeqlen) { // variable sequence length is only supported with the fused MHA kernels // we should not override mS! - PLUGIN_ASSERT((mSM == kSM_90 || mSM == kSM_87 || mSM == kSM_86 || mSM == kSM_89 || mSM == kSM_80 - || mSM == kSM_75 || mSM == kSM_72) + PLUGIN_ASSERT( + (mSM == kSM_90 || mSM == kSM_87 || mSM == kSM_86 || mSM == kSM_89 || mSM == kSM_80 || mSM == kSM_75) && (type == DataType::kINT8 || type == DataType::kHALF) && "requesting maxSeqlen not compatible with GPU arch"); // the layout changes: SxB will be a combined \sum_i s_i and hdim will be the 2nd dimension instead of the third @@ -633,29 +793,64 @@ QKVToContextVarSeqlenPlugin::QKVToContextVarSeqlenPlugin(std::string const name, } } -QKVToContextVarSeqlenPlugin::QKVToContextVarSeqlenPlugin(const std::string name, void const* data, size_t length) +QKVToContextVarSeqlenPlugin::QKVToContextVarSeqlenPlugin(std::string const name, int32_t const S, int32_t const B, + DataType const type, int32_t const hiddenSize, int32_t const numHeads, float const dqProbs, bool hasImask, + bool varSeqlen, bool useInt8ScaleMax, void const* runnerStateBuffer) : mLayerName(name) + , mS(S) + , mB(B) + , mHeadSize(hiddenSize / numHeads) + , mHiddenSize(hiddenSize) + , mNumHeads(numHeads) + , mType(type) + , mDqProbs(dqProbs) + , mHdim(HDIM) { - BERT_DEBUG_MSG("QKV Deser Start"); - deserialize_value(&data, &length, &mType); - deserialize_value(&data, &length, &mNumHeads); - deserialize_value(&data, &length, &mHeadSize); - deserialize_value(&data, &length, &mHasImask); - deserialize_value(&data, &length, &mHiddenSize); - deserialize_value(&data, &length, &mSM); - deserialize_value(&data, &length, &mS); - deserialize_value(&data, &length, &mB); - - deserialize_value(&data, &length, &mDqProbs); - - deserialize_value(&data, &length, &mUseVarSeqlen); - deserialize_value(&data, &length, &mHdim); - deserialize_value(&data, &length, &mUseInt8ScaleMax); + mSM = getSMVersion(); + mUseVarSeqlen = static_cast(varSeqlen); + mUseInt8ScaleMax = static_cast(useInt8ScaleMax); + mHasImask = static_cast(hasImask); + + if (varSeqlen) + { + // variable sequence length is only supported with the fused MHA kernels + // we should not override mS! + PLUGIN_ASSERT( + (mSM == kSM_90 || mSM == kSM_87 || mSM == kSM_86 || mSM == kSM_89 || mSM == kSM_80 || mSM == kSM_75) + && (type == DataType::kINT8 || type == DataType::kHALF) + && "requesting maxSeqlen not compatible with GPU arch"); + // the layout changes: SxB will be a combined \sum_i s_i and hdim will be the 2nd dimension instead of the third + mHdim = 1; + } createMHARunner(); - mDispatcher->deserialize(data, length); - BERT_DEBUG_MSG("QKV Deser done"); + PLUGIN_ASSERT(runnerStateBuffer != nullptr); + auto length = mDispatcher->getSerializationSize(); + mDispatcher->deserialize(runnerStateBuffer, length); +} + + +IPluginCapability* QKVToContextVarSeqlenPlugin::getCapabilityInterface(PluginCapabilityType type) noexcept +{ + try + { + if (type == PluginCapabilityType::kBUILD) + { + return static_cast(this); + } + if (type == PluginCapabilityType::kRUNTIME) + { + return static_cast(this); + } + PLUGIN_ASSERT(type == PluginCapabilityType::kCORE); + return static_cast(this); + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; } void QKVToContextVarSeqlenPlugin::createMHARunner() @@ -690,22 +885,27 @@ void QKVToContextVarSeqlenPlugin::createMHARunner() } } -// IPluginV2DynamicExt Methods -nvinfer1::IPluginV2DynamicExt* QKVToContextVarSeqlenPlugin::clone() const noexcept + +IPluginV3* QKVToContextVarSeqlenPlugin::clone() noexcept { BERT_DEBUG_MSG("QKV Clone"); QKVToContextVarSeqlenPlugin* ret = nullptr; + + char* bufferData = nullptr; + // the workspacesize is 0 if we have not call setup the dispatcher yet. if (mDispatcher.get()) { - std::vector buff; - buff.resize(getSerializationSize()); - serialize(buff.data()); + mRunnerStateBuffer.resize(mDispatcher->getSerializationSize()); + mDispatcher->serialize(mRunnerStateBuffer.data()); + bufferData = mRunnerStateBuffer.data(); - ret = new QKVToContextVarSeqlenPlugin(mLayerName, buff.data(), buff.size()); + ret = new QKVToContextVarSeqlenPlugin(mLayerName, mS, mB, mType, mHiddenSize, mNumHeads, mDqProbs, mHasImask, + mUseVarSeqlen, mUseInt8ScaleMax, static_cast(bufferData)); } else { + // dispatcher not setup yet, use type 1 constructor ret = new QKVToContextVarSeqlenPlugin( mLayerName, mType, mHiddenSize, mNumHeads, mDqProbs, mHasImask, mUseVarSeqlen, mUseInt8ScaleMax); } @@ -715,26 +915,41 @@ nvinfer1::IPluginV2DynamicExt* QKVToContextVarSeqlenPlugin::clone() const noexce return ret; } -DimsExprs QKVToContextVarSeqlenPlugin::getOutputDimensions( - int32_t outputIndex, DimsExprs const* inputs, int32_t /*nbInputs*/, IExprBuilder& exprBuilder) noexcept +int32_t QKVToContextVarSeqlenPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, + DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, + IExprBuilder& exprBuilder) noexcept { - // Input is BxSx3*N*H, output should be BxSxN*H - PLUGIN_ASSERT(outputIndex == 0); - // Copy over everything - DimsExprs output(inputs[IIDX]); - // Divide last dim by three - auto const* three = exprBuilder.constant(3); - output.d[mHdim] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[IIDX].d[mHdim], *three); - return output; + try + { + PLUGIN_ASSERT(inputs != nullptr); + PLUGIN_ASSERT(nbInputs == 1 + mHasImask + 2 * mUseVarSeqlen); + PLUGIN_ASSERT(nbShapeInputs == 0); + PLUGIN_ASSERT(outputs != nullptr); + PLUGIN_ASSERT(nbOutputs == 1); + // Input is BxSx3*N*H, output should be BxSxN*H + // Copy over everything + outputs[kIIDX] = inputs[kIIDX]; + // Divide last dim by three + auto const* three = exprBuilder.constant(3); + // mHdim is 2 for fixed seqlen and is 1 for varseqlen + outputs[kIIDX].d[mHdim] + = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[kIIDX].d[mHdim], *three); + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; } bool QKVToContextVarSeqlenPlugin::supportsFormatCombination( - int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept + int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept { // we only support variable sequence and int8 IO in fused mha runner, and we only support fused mha runner on // Turing, Ampere and Hopper - bool const hasV2Kernels = (mSM == kSM_90 || mSM == kSM_89 || mSM == kSM_87 || mSM == kSM_86 || mSM == kSM_80 - || mSM == kSM_75 || mSM == kSM_72); + bool const hasV2Kernels + = (mSM == kSM_90 || mSM == kSM_89 || mSM == kSM_87 || mSM == kSM_86 || mSM == kSM_80 || mSM == kSM_75); PLUGIN_ASSERT( (mType != DataType::kINT8 || hasV2Kernels) && "INT8 IO is only supported on Xavier, Turing, Ampere and Hopper"); PLUGIN_ASSERT( @@ -754,12 +969,8 @@ bool QKVToContextVarSeqlenPlugin::supportsFormatCombination( PLUGIN_ASSERT(nbInputs == 4 && "for varseqlen, expected 4 inputs"); } - auto const inDims = in->dims; - // const auto inType = in->type; - // const auto inFmt = in->format; - // const auto outType = out->type; - auto const outDims = out->dims; - // const auto outFmt = out->format; + auto const inDims = in->desc.dims; + auto const outDims = out->desc.dims; auto supportedFormat = TensorFormat::kLINEAR; if (mType == DataType::kINT8) @@ -777,7 +988,7 @@ bool QKVToContextVarSeqlenPlugin::supportsFormatCombination( if (pos == 0 || pos == nbInputs) { // check input and output - auto const& desc = inOut[pos]; + auto const& desc = inOut[pos].desc; return (desc.type == mType) && // check type (desc.format == supportedFormat) && // check format (desc.dims.nbDims == supportedNbDims) && // check dims: @@ -790,7 +1001,7 @@ bool QKVToContextVarSeqlenPlugin::supportsFormatCombination( PLUGIN_ASSERT(mHasImask); if (pos == 1) { // must be input mask - auto const* mask = &inOut[pos]; + auto const* mask = &inOut[pos].desc; if (mUseVarSeqlen) { // dummy input @@ -798,106 +1009,179 @@ bool QKVToContextVarSeqlenPlugin::supportsFormatCombination( } return mask->format == TensorFormat::kLINEAR && (mask->type == DataType::kINT32) && // precision - (mask->dims.nbDims == 1) // num dims - ; + (mask->dims.nbDims == 1); // num dims } PLUGIN_ASSERT(mUseVarSeqlen); if (pos == 2) { // must be cuSeqlens // cuSeqlens is a int32_t array of size B+1 - auto const* seqlens = &inOut[pos]; + auto const* seqlens = &inOut[pos].desc; return (seqlens->type == DataType::kINT32) && (seqlens->format == TensorFormat::kLINEAR); } if (pos == 3) { // this is the dummy input - return inOut[pos].dims.nbDims == 1; + return inOut[pos].desc.dims.nbDims == 1; } return false; } -void QKVToContextVarSeqlenPlugin::configurePlugin( - DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept +int32_t QKVToContextVarSeqlenPlugin::onShapeChange( + PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept { - PLUGIN_ASSERT(nbInputs == 1 + mHasImask + 2 * mUseVarSeqlen); - PLUGIN_ASSERT(nbOutputs == 1); - PluginTensorDesc const& inDesc = in[IIDX].desc; - TRT_UNUSED inDesc; - PluginTensorDesc const& outDesc = out->desc; - TRT_UNUSED outDesc; - PLUGIN_ASSERT(mType == inDesc.type); - PLUGIN_ASSERT(mType == outDesc.type); - if (!mUseVarSeqlen) + try { - PLUGIN_ASSERT(inDesc.dims.d[BDIM] == outDesc.dims.d[BDIM]); - PLUGIN_ASSERT(inDesc.dims.d[SDIM] == outDesc.dims.d[SDIM]); - PLUGIN_ASSERT(inDesc.dims.d[mHdim] == 3 * outDesc.dims.d[mHdim]); - if (mHasImask) + PLUGIN_ASSERT(in != nullptr); + PLUGIN_ASSERT(nbInputs == 1 + mHasImask + 2 * mUseVarSeqlen); + PLUGIN_ASSERT(nbOutputs == 1); + PluginTensorDesc const& inDesc = in[kIIDX]; + TRT_UNUSED inDesc; + PluginTensorDesc const& outDesc = out[0]; + TRT_UNUSED outDesc; + PLUGIN_ASSERT(mType == inDesc.type); + PLUGIN_ASSERT(mType == outDesc.type); + if (!mUseVarSeqlen) { - PluginTensorDesc const& maskDesc = in[MIDX].desc; - TRT_UNUSED maskDesc; - PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[BDIM]); + PLUGIN_ASSERT(inDesc.dims.d[BDIM] == outDesc.dims.d[BDIM]); + PLUGIN_ASSERT(inDesc.dims.d[SDIM] == outDesc.dims.d[SDIM]); + PLUGIN_ASSERT(inDesc.dims.d[mHdim] == 3 * outDesc.dims.d[mHdim]); + if (mHasImask) + { + PluginTensorDesc const& maskDesc = in[kMIDX]; + TRT_UNUSED maskDesc; + PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[BDIM]); + } + + // during build, configurePlugin() should have set mS, mB appropriately + // during inference, the engine should have mS, mB information + PLUGIN_ASSERT(mS); + PLUGIN_ASSERT(mB); + + BERT_DEBUG_MSG("setting up MHA runner for single sequence length"); + createMHARunner(); + this->mDispatcher->setup(mS, mB, mHeadSize); + } + else + { + BERT_DEBUG_MSG("setting up MHA runner for variable sequence length"); + createMHARunner(); + // need to initialize S and B with somewhat useful values, they will be reset at enqueue for the actual + // batchsize + this->mDispatcher->setup(256, 1, mHeadSize); } - const int32_t S = inDesc.dims.d[SDIM] <= 0 ? in->max.d[SDIM] : inDesc.dims.d[SDIM]; - const int32_t B = inDesc.dims.d[BDIM] <= 0 ? in->max.d[BDIM] : inDesc.dims.d[BDIM]; + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; +} - if (S != mS || B != mB) +int32_t QKVToContextVarSeqlenPlugin::configurePlugin( + DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept +{ + try + { + PLUGIN_ASSERT(in != nullptr); + PLUGIN_ASSERT(nbInputs == 1 + mHasImask + 2 * mUseVarSeqlen); + PLUGIN_ASSERT(nbOutputs == 1); + PluginTensorDesc const& inDesc = in[kIIDX].desc; + TRT_UNUSED inDesc; + PluginTensorDesc const& outDesc = out->desc; + TRT_UNUSED outDesc; + PLUGIN_ASSERT(mType == inDesc.type); + PLUGIN_ASSERT(mType == outDesc.type); + if (!mUseVarSeqlen) { - BERT_DEBUG_MSG("setting up MHA runner for single sequence length"); + PLUGIN_ASSERT(inDesc.dims.d[BDIM] == outDesc.dims.d[BDIM]); + PLUGIN_ASSERT(inDesc.dims.d[SDIM] == outDesc.dims.d[SDIM]); + PLUGIN_ASSERT(inDesc.dims.d[mHdim] == 3 * outDesc.dims.d[mHdim]); + if (mHasImask) + { + PluginTensorDesc const& maskDesc = in[kMIDX].desc; + TRT_UNUSED maskDesc; + PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[BDIM]); + } + + const int32_t S = inDesc.dims.d[SDIM] <= 0 ? in->max.d[SDIM] : inDesc.dims.d[SDIM]; + const int32_t B = inDesc.dims.d[BDIM] <= 0 ? in->max.d[BDIM] : inDesc.dims.d[BDIM]; + + if (S != mS || B != mB) + { + BERT_DEBUG_MSG("setting up MHA runner for single sequence length"); + createMHARunner(); + this->mDispatcher->setup(S, B, mHeadSize); + mS = S; + mB = B; + } + } + else + { + BERT_DEBUG_MSG("setting up MHA runner for variable sequence length"); createMHARunner(); - this->mDispatcher->setup(S, B, mHeadSize); - mS = S; - mB = B; + // need to initialize S and B with somewhat useful values, they will be reset at enqueue for the actual + // batchsize + this->mDispatcher->setup(256, 1, mHeadSize); } + + return pluginStatus_t::STATUS_SUCCESS; } - else + catch (std::exception const& e) { - BERT_DEBUG_MSG("setting up MHA runner for variable sequence length"); - createMHARunner(); - // need to initialize S and B with somewhat useful values, they will be reset at enqueue for the actual - // batchsize - this->mDispatcher->setup(256, 1, mHeadSize); + caughtError(e); } + return pluginStatus_t::STATUS_FAILURE; } -size_t QKVToContextVarSeqlenPlugin::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t /* nbInputs */, - PluginTensorDesc const* /* outputs */, int32_t /* nbOutputs */) const noexcept +size_t QKVToContextVarSeqlenPlugin::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t /* nbInputs */, + DynamicPluginTensorDesc const* /* outputs */, int32_t /* nbOutputs */) const noexcept { - size_t paddingWorkpaceSize = mPatcher ? mPatcher->getWorkspaceSize(inputs[0].dims.d[0], mNumHeads) : 0; + size_t paddingWorkpaceSize = mPatcher ? mPatcher->getWorkspaceSize(inputs[0].desc.dims.d[0], mNumHeads) : 0; return mDispatcher->getWorkspaceSize() + paddingWorkpaceSize; } -// IPluginV2Ext Methods -DataType QKVToContextVarSeqlenPlugin::getOutputDataType( - int32_t index, nvinfer1::DataType const* inputTypes, int32_t /*nbInputs*/) const noexcept +int32_t QKVToContextVarSeqlenPlugin::getOutputDataTypes( + DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept { - PLUGIN_ASSERT(index == 0); - PLUGIN_ASSERT( - inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF || inputTypes[0] == DataType::kINT8); - return inputTypes[0]; + try + { + PLUGIN_ASSERT( + inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF || inputTypes[0] == DataType::kINT8); + outputTypes[0] = inputTypes[0]; + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; +} + +void QKVToContextVarSeqlenPlugin::setCublasHandle(cublasContext* cublasHandle) +{ + mCublas = cublasHandle; } -void QKVToContextVarSeqlenPlugin::attachToContext( - cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) noexcept +IPluginV3* QKVToContextVarSeqlenPlugin::attachToContext(IPluginResourceContext* context) noexcept { try { - mCublasWrapper = createPluginCublasWrapper(allocator); - mCublas = mCublasWrapper->getCublasHandle(); - PLUGIN_VALIDATE(mCublas != nullptr); + auto p = static_cast(clone()); + mCublasWrapper = createPluginCublasWrapper(context); + auto cublasHandle = mCublasWrapper->getCublasHandle(); + PLUGIN_VALIDATE(cublasHandle != nullptr); + p->setCublasHandle(cublasHandle); + return p; } catch (const std::exception& e) { caughtError(e); } + return nullptr; } -// IPluginV2 Methods -char const* QKVToContextVarSeqlenPlugin::getPluginType() const noexcept -{ - return kQKV_TO_CONTEXT_PLUGIN_NAME; -} char const* QKVToContextVarSeqlenPlugin::getPluginVersion() const noexcept { @@ -909,42 +1193,11 @@ int32_t QKVToContextVarSeqlenPlugin::getNbOutputs() const noexcept return 1; } -int32_t QKVToContextVarSeqlenPlugin::initialize() noexcept -{ - return 0; -} - -void QKVToContextVarSeqlenPlugin::terminate() noexcept {} - -size_t QKVToContextVarSeqlenPlugin::getSerializationSize() const noexcept -{ - return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(DataType) + sizeof(mHasImask) + sizeof(mHiddenSize) - + sizeof(mSM) + sizeof(mS) + sizeof(mB) + sizeof(mDqProbs) + mDispatcher->getSerializationSize() - + sizeof(mUseVarSeqlen) + sizeof(mHdim) + sizeof(mUseInt8ScaleMax); -} - -void QKVToContextVarSeqlenPlugin::serialize(void* buffer) const noexcept +char const* QKVToContextVarSeqlenPlugin::getPluginName() const noexcept { - serialize_value(&buffer, mType); - serialize_value(&buffer, mNumHeads); - serialize_value(&buffer, mHeadSize); - serialize_value(&buffer, mHasImask); - serialize_value(&buffer, mHiddenSize); - serialize_value(&buffer, mSM); - serialize_value(&buffer, mS); - serialize_value(&buffer, mB); - - serialize_value(&buffer, mDqProbs); - serialize_value(&buffer, mUseVarSeqlen); - serialize_value(&buffer, mHdim); - serialize_value(&buffer, mUseInt8ScaleMax); - mDispatcher->serialize(buffer); + return kQKV_TO_CONTEXT_PLUGIN_NAME; } -void QKVToContextVarSeqlenPlugin::destroy() noexcept -{ - delete this; -} void QKVToContextVarSeqlenPlugin::setPluginNamespace(char const* libNamespace) noexcept { @@ -1072,6 +1325,35 @@ int32_t QKVToContextVarSeqlenPlugin::enqueue(nvinfer1::PluginTensorDesc const* i return cudaGetLastError(); } +PluginFieldCollection const* QKVToContextVarSeqlenPlugin::getFieldsToSerialize() noexcept +{ + mDataToSerialize.clear(); + + mDataToSerialize.emplace_back("type_id", &mType, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("hidden_size", &mHiddenSize, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("num_heads", &mNumHeads, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("has_mask", &mHasImask, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("var_seqlen", &mUseVarSeqlen, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("use_int8_scale_max", &mUseInt8ScaleMax, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("S", &mS, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("B", &mB, PluginFieldType::kINT32, 1); + + mRunnerStateBuffer.resize(mDispatcher->getSerializationSize()); + mDispatcher->serialize(mRunnerStateBuffer.data()); + mDataToSerialize.emplace_back("runnerStateBuffer", (void const*) mRunnerStateBuffer.data(), + PluginFieldType::kUNKNOWN, mRunnerStateBuffer.size()); + + if (mDqProbs >= 0) + { + mDataToSerialize.emplace_back("dq_probs", &mDqProbs, PluginFieldType::kFLOAT32, 1); + } + + mFCToSerialize.nbFields = mDataToSerialize.size(); + mFCToSerialize.fields = mDataToSerialize.data(); + + return &mFCToSerialize; +} + QKVToContextVarSeqlenPluginCreator::QKVToContextVarSeqlenPluginCreator() { mPluginAttributes.clear(); @@ -1102,100 +1384,143 @@ PluginFieldCollection const* QKVToContextVarSeqlenPluginCreator::getFieldNames() return &mFC; } -IPluginV2* QKVToContextVarSeqlenPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept +IPluginV3* QKVToContextVarSeqlenPluginCreator::createPlugin( + char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept { - BERT_DEBUG_MSG("Creating QKV2ContextPlugin..."); - - int32_t hiddenSize = 0; - // Since numHeads must always exist or validateRequiredAttributes will fail, - // we can set numHeads to -1 so that static analysis tools don't warn about - // a division by zero in QKVToContextVarSeqelnPlugin constructor. - int32_t numHeads{-1}; - bool hasMask = false; - int32_t typeId = -1; - - int32_t varSeqlen = 0; - - float dqProbs = -1; - int32_t useInt8ScaleMax{-1}; - - plugin::validateRequiredAttributesExist({"type_id", "hidden_size", "num_heads", "has_mask"}, fc); - for (int32_t i = 0; i < fc->nbFields; i++) + try { - std::string field_name(fc->fields[i].name); + BERT_DEBUG_MSG("Creating QKV2ContextPlugin..."); + PLUGIN_VALIDATE(fc != nullptr); + int32_t hiddenSize = 0; + // Since numHeads must always exist or validateRequiredAttributes will fail, + // we can set numHeads to -1 so that static analysis tools don't warn about + // a division by zero in QKVToContextVarSeqelnPlugin constructor. + int32_t numHeads = -1; + bool hasMask = false; + int32_t typeId = -1; + int32_t s = -1; + int32_t b = -1; + void const* runnerStateBuffer; + int32_t varSeqlen = 0; + float dqProbs = -1; + int32_t useInt8ScaleMax = -1; - if (field_name.compare("type_id") == 0) - { - typeId = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(typeId >= 0 && typeId <= 2, ("QKV: Invalid TypeId " + std::to_string(typeId)).c_str()); - BERT_DEBUG_VALUE("Building typeId: ", typeId); - } - if (field_name.compare("hidden_size") == 0) + PLUGIN_VALIDATE(fc->fields != nullptr); + + if (phase == TensorRTPhase::kBUILD) { - hiddenSize = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(hiddenSize > 0, ("QKV: Invalid hiddenSize " + std::to_string(hiddenSize)).c_str()); - BERT_DEBUG_VALUE("Building hiddenSize: ", hiddenSize); + plugin::validateRequiredAttributesExist({"type_id", "hidden_size", "num_heads", "has_mask"}, fc); } - if (field_name.compare("num_heads") == 0) + else { - numHeads = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(numHeads > 0, ("QKV: Invalid numHeads " + std::to_string(numHeads)).c_str()); - BERT_DEBUG_VALUE("Building numHeads: ", numHeads); + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + // since fc is from a deserialized engine, + // we expect all attributes (except dq_probs) to be present during runtime + plugin::validateRequiredAttributesExist({"type_id", "S", "B", "hidden_size", "num_heads", "has_mask", + "var_seqlen", "use_int8_scale_max", "runnerStateBuffer"}, + fc); } - if (field_name.compare("has_mask") == 0) + for (int32_t i = 0; i < fc->nbFields; i++) { - hasMask = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(hasMask == 0 || hasMask == 1, ("QKV: Invalid hasMask " + std::to_string(hasMask)).c_str()); - BERT_DEBUG_VALUE("Building hasMask: ", hasMask); + std::string field_name(fc->fields[i].name); + + if (field_name.compare("type_id") == 0) + { + typeId = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(typeId >= 0 && typeId <= 2, ("QKV: Invalid TypeId " + std::to_string(typeId)).c_str()); + BERT_DEBUG_VALUE("Building typeId: ", typeId); + } + else if (field_name.compare("hidden_size") == 0) + { + hiddenSize = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(hiddenSize > 0, ("QKV: Invalid hiddenSize " + std::to_string(hiddenSize)).c_str()); + BERT_DEBUG_VALUE("Building hiddenSize: ", hiddenSize); + } + else if (field_name.compare("num_heads") == 0) + { + numHeads = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(numHeads > 0, ("QKV: Invalid numHeads " + std::to_string(numHeads)).c_str()); + BERT_DEBUG_VALUE("Building numHeads: ", numHeads); + } + else if (field_name.compare("has_mask") == 0) + { + hasMask = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE( + hasMask == 0 || hasMask == 1, ("QKV: Invalid hasMask " + std::to_string(hasMask)).c_str()); + BERT_DEBUG_VALUE("Building hasMask: ", hasMask); + } + + else if (field_name.compare("dq_probs") == 0) + { + dqProbs = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(dqProbs > 0.0F, ("QKV: Invalid dqProbs " + std::to_string(dqProbs)).c_str()); + BERT_DEBUG_VALUE("Building dqProbs: ", dqProbs); + } + else if (field_name.compare("var_seqlen") == 0) + { + varSeqlen = *static_cast(fc->fields[i].data); + BERT_DEBUG_VALUE("Building var_seqlen: ", varSeqlen); + } + else if (field_name.compare("use_int8_scale_max") == 0) + { + useInt8ScaleMax = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(useInt8ScaleMax == 0 || useInt8ScaleMax == 1, + ("QKV: Invalid useInt8ScaleMax " + std::to_string(useInt8ScaleMax)).c_str()); + BERT_DEBUG_VALUE("Building useInt8ScaleMax: ", useInt8ScaleMax); + } + else if (field_name.compare("S") == 0) + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + s = *static_cast(fc->fields[i].data); + BERT_DEBUG_VALUE("Building S: ", s); + } + else if (field_name.compare("B") == 0) + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + b = *static_cast(fc->fields[i].data); + BERT_DEBUG_VALUE("Building B: ", b); + } + else if (field_name.compare("runnerStateBuffer") == 0) + { + PLUGIN_ASSERT(phase == TensorRTPhase::kRUNTIME); + runnerStateBuffer = static_cast(fc->fields[i].data); + } } - if (field_name.compare("dq_probs") == 0) + if (useInt8ScaleMax < 0) { - dqProbs = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(dqProbs > 0.0F, ("QKV: Invalid dqProbs " + std::to_string(dqProbs)).c_str()); - BERT_DEBUG_VALUE("Building dqProbs: ", dqProbs); + gLogInfo << "Using default for use_int8_scale_max: true" << std::endl; + useInt8ScaleMax = 1; } - if (field_name.compare("var_seqlen") == 0) + + BERT_DEBUG_MSG("Building the Plugin..."); + DataType type = static_cast(typeId); + if (type == DataType::kINT8 && dqProbs < 0) { - varSeqlen = *static_cast(fc->fields[i].data); - BERT_DEBUG_VALUE("Building var_seqlen: ", varSeqlen); + gLogInfo << "Using default scale factor\n"; + dqProbs = 1.F / 127.F; } - if (field_name.compare("use_int8_scale_max") == 0) + + auto const useInt8ScaleMaxFlag = static_cast(useInt8ScaleMax); + + if (phase == TensorRTPhase::kBUILD) { - useInt8ScaleMax = *static_cast(fc->fields[i].data); - PLUGIN_VALIDATE(useInt8ScaleMax == 0 || useInt8ScaleMax == 1, - ("QKV: Invalid useInt8ScaleMax " + std::to_string(useInt8ScaleMax)).c_str()); - BERT_DEBUG_VALUE("Building useInt8ScaleMax: ", useInt8ScaleMax); + return new QKVToContextVarSeqlenPlugin( + name, type, hiddenSize, numHeads, dqProbs, hasMask, varSeqlen, useInt8ScaleMaxFlag); } - } - if (useInt8ScaleMax < 0) - { - gLogInfo << "Using default for use_int8_scale_max: true" << std::endl; - useInt8ScaleMax = 1; - } + PLUGIN_VALIDATE(s != -1, "invalid S during runtime plugin creation"); + PLUGIN_VALIDATE(b != -1, "invalid B during runtime plugin creation"); + PLUGIN_VALIDATE(runnerStateBuffer != nullptr, "invalid runnerStateBuffer during runtime plugin creation"); - BERT_DEBUG_MSG("Building the Plugin..."); - DataType type = static_cast(typeId); - if (type == DataType::kINT8 && dqProbs < 0) + return new QKVToContextVarSeqlenPlugin(name, s, b, type, hiddenSize, numHeads, dqProbs, hasMask, varSeqlen, + useInt8ScaleMaxFlag, runnerStateBuffer); + } + catch (std::exception const& e) { - gLogInfo << "Using default scale factor\n"; - dqProbs = 1.F / 127.F; + caughtError(e); } - - auto const useInt8ScaleMaxFlag = static_cast(useInt8ScaleMax); - - QKVToContextVarSeqlenPlugin* p = new QKVToContextVarSeqlenPlugin( - name, type, hiddenSize, numHeads, dqProbs, hasMask, varSeqlen, useInt8ScaleMaxFlag); - return p; -} - -IPluginV2* QKVToContextVarSeqlenPluginCreator::deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept -{ - // This object will be deleted when the network is destroyed, which will - // call QKVToContextVarSeqlenPlugin::destroy() - return new QKVToContextVarSeqlenPlugin(name, serialData, serialLength); + return nullptr; } void QKVToContextVarSeqlenPluginCreator::setPluginNamespace(char const* libNamespace) noexcept diff --git a/plugin/bertQKVToContextPlugin/qkvToContextPlugin.h b/plugin/bertQKVToContextPlugin/qkvToContextPlugin.h index 725f430d..b2aa0ab8 100644 --- a/plugin/bertQKVToContextPlugin/qkvToContextPlugin.h +++ b/plugin/bertQKVToContextPlugin/qkvToContextPlugin.h @@ -24,6 +24,7 @@ #include "NvInferPlugin.h" #include "common/cublasWrapper.h" +#include "mhaRunner.h" #include "zeroPadding2d.h" #include #include @@ -36,83 +37,6 @@ namespace plugin namespace bert { -// Multi Head Attention runner -class MHARunner -{ -public: - MHARunner(const nvinfer1::DataType type, const int32_t numHeads) - : mType(type) - , mS(0) - , mB(0) - , mOmatSize(0) - , mNumMats(0) - , mNumHeads(numHeads) - , mHeadSize(0) - , mWordSize(getElementSize(type)) - , mLdQKV(0) - , mStrideQKV(0) - , mLdOut(0) - , mStrideOut(0) - , mRsqrtHeadSize(0) - { - } - - virtual ~MHARunner() = default; - - virtual void setup(int32_t S, int32_t B, int32_t headSize) - { - PLUGIN_ASSERT(S); - PLUGIN_ASSERT(B); - mB = B; - mS = S; - mHeadSize = headSize; - mRsqrtHeadSize = 1.F / std::sqrt(headSize); - - mLdQKV = 3 * B * mNumHeads * mHeadSize; - mStrideQKV = 3 * mHeadSize; - - mLdOut = B * mNumHeads * mHeadSize; - mStrideOut = mHeadSize; - mOmatSize = S * S; - mNumMats = B * mNumHeads; - } - - virtual void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, - void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) - = 0; - - virtual void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, - void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) - = 0; - - virtual size_t getSerializationSize() const noexcept; - virtual void serialize(void* buffer) const noexcept; - virtual void deserialize(void const* data, size_t length); - - virtual size_t getWorkspaceSize() const = 0; - - virtual bool isValid(int32_t headSize, int32_t s) const = 0; - -protected: - nvinfer1::DataType mType; - - int32_t mS; - int32_t mB; - int32_t mOmatSize; - int32_t mNumMats; - int32_t mNumHeads; - int32_t mHeadSize; - int32_t mWordSize; - int32_t mLdQKV; - int32_t mStrideQKV; - int32_t mLdOut; - int32_t mStrideOut; - - float mRsqrtHeadSize; -}; - std::pair tuneBatchedGemm( const int32_t B, const int32_t S, const int32_t numHeads, const int32_t headSize); @@ -128,51 +52,78 @@ int32_t computeMaskedScaledSoftmax(cudaStream_t stream, const int32_t ld, const // our custom layer requires extending IPluginV2 and IPluginCreator classes. // For requirements for overriden functions, check TensorRT API docs. -class QKVToContextPluginDynamic : public nvinfer1::IPluginV2DynamicExt +class QKVToContextPluginDynamic : public IPluginV3, + public IPluginV3OneCore, + public IPluginV3OneBuild, + public IPluginV3OneRuntime { public: QKVToContextPluginDynamic(const std::string name, const nvinfer1::DataType type, const int32_t hiddenSize, const int32_t numHeads, float const dqProbs, bool hasImask = false); - QKVToContextPluginDynamic(const std::string name, void const* data, size_t length); + // constructor that also takes in MHARunner state + // this constructor should only be called during runtime plugin creation after engine deserialization + QKVToContextPluginDynamic(const std::string name, const DataType type, const int32_t S, const int32_t B, + const int32_t SM, const int32_t hiddenSize, const int32_t numHeads, float const dqProbs, bool hasImask, + bool hasUnfusedDispatcher, void const* runnerStateBuffer); // It doesn't make sense to make QKVToContextPluginDynamic without arguments, so we // delete default constructor. QKVToContextPluginDynamic() = delete; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, - nvinfer1::IExprBuilder& exprBuilder) noexcept override; + ~QKVToContextPluginDynamic() override; + + // IPluginV3 Methods + // NOTE: since this is itself is an abstract class, the rest of virtual methods defined in its children classes + IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override; + // end of IPluginV3 Methods + + // IPluginV3OneCore Methods + char const* getPluginName() const noexcept override; + + char const* getPluginNamespace() const noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept; + + char const* getPluginVersion() const noexcept override; + // end of IPluginV3OneCore Methods + + // IPluginV3Build Methods bool supportsFormatCombination( - int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; - void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, - nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; - size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, - nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + + int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + + size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + + int32_t getOutputDataTypes( + DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override; + + int32_t getNbOutputs() const noexcept override; + + int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, + int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override; + // end IPluginV3Build Methods + + // IPluginV3Runtime Methods + IPluginV3* clone() noexcept; + + int32_t onShapeChange( + PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType( - int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; - void attachToContext( - cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) noexcept override; + IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override; - // IPluginV2 Methods - char const* getPluginType() const noexcept override; - char const* getPluginVersion() const noexcept override; - int32_t getNbOutputs() const noexcept override; - int32_t initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(char const* pluginNamespace) noexcept override; - char const* getPluginNamespace() const noexcept override; + PluginFieldCollection const* getFieldsToSerialize() noexcept override; + // end IPluginV3Runtime Methods protected: void createMHARunner(); + void setCublasHandle(cublasContext*); private: const std::string mLayerName; @@ -190,36 +141,35 @@ class QKVToContextPluginDynamic : public nvinfer1::IPluginV2DynamicExt int32_t mHeadSize{}; int32_t mHiddenSize; int32_t mNumHeads{}; - bool mHasImask{}; + int32_t mHasImask{}; + int32_t mHasUnfusedDispatcher{}; nvinfer1::DataType mType{}; float mDqProbs{}; nvinfer1::pluginInternal::cublasHandle_t mCublas{}; // the wrapper pointer is shared among all plugins attached to the same context. std::shared_ptr mCublasWrapper; - using IPluginV2::getOutputDimensions; - using IPluginV2::getWorkspaceSize; - using IPluginV2::enqueue; - using IPluginV2Ext::configurePlugin; + // IPluginV3 serialization related + std::vector mDataToSerialize; + nvinfer1::PluginFieldCollection mFCToSerialize; + std::vector mRunnerStateBuffer; // memory management of this is handled automatically by class destructor }; -class QKVToContextPluginDynamicCreator : public nvinfer1::IPluginCreator +class QKVToContextPluginDynamicCreator : public nvinfer1::IPluginCreatorV3One { public: QKVToContextPluginDynamicCreator(); + ~QKVToContextPluginDynamicCreator() override = default; char const* getPluginName() const noexcept override; char const* getPluginVersion() const noexcept override; - nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - - nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept override; + IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override; - void setPluginNamespace(char const* pluginNamespace) noexcept override; + void setPluginNamespace(char const* libNamespace) noexcept; char const* getPluginNamespace() const noexcept override; @@ -229,52 +179,76 @@ class QKVToContextPluginDynamicCreator : public nvinfer1::IPluginCreator std::string mNamespace; }; -class QKVToContextVarSeqlenPlugin : public nvinfer1::IPluginV2DynamicExt +class QKVToContextVarSeqlenPlugin : public IPluginV3, + public IPluginV3OneCore, + public IPluginV3OneBuild, + public IPluginV3OneRuntime { public: QKVToContextVarSeqlenPlugin(std::string const name, nvinfer1::DataType const type, int32_t const hiddenSize, int32_t const numHeads, float const dqProbs, bool hasImask = false, bool varSeqlen = false, bool const useInt8ScaleMax = true); - QKVToContextVarSeqlenPlugin(const std::string name, void const* data, size_t length); + QKVToContextVarSeqlenPlugin(std::string const name, int32_t const s, int32_t b, nvinfer1::DataType const type, + int32_t const hiddenSize, int32_t const numHeads, float const dqProbs, bool hasImask, bool varSeqlen, + bool const useInt8ScaleMax, void const* runnerStateBuffer); // It doesn't make sense to make QKVToContextVarSeqlenPlugin without arguments, so we // delete default constructor. QKVToContextVarSeqlenPlugin() = delete; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, - nvinfer1::IExprBuilder& exprBuilder) noexcept override; + ~QKVToContextVarSeqlenPlugin() override; + + // IPluginV3 Methods + // NOTE: since this is itself is an abstract class, the rest of virtual methods defined in its children classes + IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override; + // end of IPluginV3 Methods + + // IPluginV3OneCore Methods + char const* getPluginName() const noexcept override; + + char const* getPluginNamespace() const noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept; + + char const* getPluginVersion() const noexcept override; + // end of IPluginV3OneCore Methods + + // IPluginV3Build Methods bool supportsFormatCombination( - int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; - void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, - nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; - size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, - nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + + int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + + size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + + int32_t getOutputDataTypes( + DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override; + + int32_t getNbOutputs() const noexcept override; + + int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, + int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override; + // end IPluginV3Build Methods + + // IPluginV3Runtime Methods + IPluginV3* clone() noexcept; + + int32_t onShapeChange( + PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType( - int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; - void attachToContext( - cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) noexcept override; + IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override; - // IPluginV2 Methods - char const* getPluginType() const noexcept override; - char const* getPluginVersion() const noexcept override; - int32_t getNbOutputs() const noexcept override; - int32_t initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(char const* pluginNamespace) noexcept override; - char const* getPluginNamespace() const noexcept override; + PluginFieldCollection const* getFieldsToSerialize() noexcept override; protected: void createMHARunner(); + void setCublasHandle(cublasContext*); private: const std::string mLayerName; @@ -290,41 +264,39 @@ class QKVToContextVarSeqlenPlugin : public nvinfer1::IPluginV2DynamicExt int32_t mHeadSize{}; int32_t mHiddenSize{}; int32_t mNumHeads{}; - bool mHasImask{}; + int32_t mHasImask{}; nvinfer1::DataType mType{}; float mDqProbs{}; int32_t mHdim{}; - bool mUseVarSeqlen{}; - bool mUseInt8ScaleMax{true}; + int32_t mUseVarSeqlen{}; + int32_t mUseInt8ScaleMax{true}; nvinfer1::pluginInternal::cublasHandle_t mCublas{}; // the wrapper pointer is shared among all plugins attached to the same context. std::shared_ptr mCublasWrapper; - using IPluginV2::getOutputDimensions; - using IPluginV2::getWorkspaceSize; - using IPluginV2::enqueue; - using IPluginV2Ext::configurePlugin; + // serialization data structures + std::vector mRunnerStateBuffer; // memory management of this is handled automatically by class destructor + std::vector mDataToSerialize; + nvinfer1::PluginFieldCollection mFCToSerialize; }; -class QKVToContextVarSeqlenPluginCreator : public nvinfer1::IPluginCreator +class QKVToContextVarSeqlenPluginCreator : public nvinfer1::IPluginCreatorV3One { public: QKVToContextVarSeqlenPluginCreator(); + ~QKVToContextVarSeqlenPluginCreator() override = default; char const* getPluginName() const noexcept override; char const* getPluginVersion() const noexcept override; - nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override; - nvinfer1::IPluginV2* deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept override; - - void setPluginNamespace(char const* pluginNamespace) noexcept override; + void setPluginNamespace(char const* libNamespace) noexcept; char const* getPluginNamespace() const noexcept override; @@ -334,151 +306,6 @@ class QKVToContextVarSeqlenPluginCreator : public nvinfer1::IPluginCreator std::string mNamespace; }; -class UnfusedMHARunner : public MHARunner -{ -public: - UnfusedMHARunner(const nvinfer1::DataType type, const int32_t numHeads, const int32_t smVersion); - virtual ~UnfusedMHARunner(); - - virtual void setup(int32_t S, int32_t B, int32_t headSize) override; - - void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, - void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, - void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - size_t getWorkspaceSize() const override; - - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - void deserialize(void const* data, size_t length) override; - bool isValid(int32_t headSize, int32_t s) const override; - -private: - bool mIsBestAlgoFound{}; - int32_t mAlgoBatchedEx1{}; - int32_t mAlgoBatchedEx2{}; - int32_t mSm{}; -}; - -class FusedMHARunnerFP16 : public MHARunner -{ -public: - FusedMHARunnerFP16(const int32_t numHeads, const int32_t sm); - ~FusedMHARunnerFP16() = default; // for pimpl - - virtual void setup(int32_t S, int32_t B, int32_t headSize) override; - - void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, - void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, - void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - size_t getWorkspaceSize() const override; - - void deserialize(void const* data, size_t length) override; - - bool isValid(int32_t headSize, int32_t s) const override; - -private: - int32_t mSm; - class mhaImpl; - std::unique_ptr pimpl; -}; - -class FusedMHARunnerInt8 : public MHARunner -{ -public: - FusedMHARunnerInt8(const int32_t numHeads, const int32_t sm, float const dqProbs); - ~FusedMHARunnerInt8() = default; // for pimpl - - virtual void setup(int32_t S, int32_t B, int32_t headSize) override; - - void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, - void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, - void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - size_t getWorkspaceSize() const override; - - void deserialize(void const* data, size_t length) override; - - bool isValid(int32_t headSize, int32_t s) const override; - -private: - float mDqProbs; - int32_t mSm; - class mhaImpl; - std::unique_ptr pimpl; -}; - -class FusedMHARunnerFP16v2 : public MHARunner -{ -public: - FusedMHARunnerFP16v2(const int32_t numHeads, const int32_t sm); - ~FusedMHARunnerFP16v2() = default; // for pimpl - - virtual void setup(int32_t S, int32_t B, int32_t headSize) override; - - void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, - void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, - void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - size_t getWorkspaceSize() const override; - - void deserialize(void const* data, size_t length) override; - - bool isValid(int32_t headSize, int32_t s) const override; - -private: - int32_t mSm; - class mhaImpl; - std::unique_ptr pimpl; -}; - -class FusedMHARunnerInt8v2 : public MHARunner -{ -public: - FusedMHARunnerInt8v2(int32_t const numHeads, int32_t const sm, float const dqProbs, bool const useInt8ScaleMax); - ~FusedMHARunnerInt8v2() = default; // for pimpl - - virtual void setup(int32_t S, int32_t B, int32_t headSize) override; - - void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc, - void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, - void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream, - nvinfer1::pluginInternal::cublasHandle_t cublas) override; - - size_t getWorkspaceSize() const override; - - void deserialize(void const* data, size_t length) override; - - bool isValid(int32_t headSize, int32_t s) const override; - -private: - float mDqProbs; - int32_t mSm; - class mhaImpl; - std::unique_ptr pimpl; - bool mUseInt8ScaleMax{true}; -}; - } // namespace bert } // namespace plugin } // namespace nvinfer1 diff --git a/plugin/bertQKVToContextPlugin/qkvToContextPluginLegacy.cpp b/plugin/bertQKVToContextPlugin/qkvToContextPluginLegacy.cpp new file mode 100644 index 00000000..850e8217 --- /dev/null +++ b/plugin/bertQKVToContextPlugin/qkvToContextPluginLegacy.cpp @@ -0,0 +1,1202 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Need 10.1 for cublasGemmStridedBatchedEx +#include +#if CUDA_VERSION >= 10010 + +#include "NvInfer.h" +#include "bertQKVToContextPlugin/fused_multihead_attention/include/fused_multihead_attention.h" +#include "bertQKVToContextPlugin/fused_multihead_attention_v2/include/fused_multihead_attention_v2.h" +#include "common/bertCommon.h" +#include "common/serialize.hpp" +#include "mhaRunner.h" +#include "qkvToContextPluginLegacy.h" + +#include +#include +#include +#include +#include + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; +using namespace nvinfer1::pluginInternal; + +namespace +{ +char const* const kQKV_TO_CONTEXT_PLUGIN_LEGACY_VERSION{"1"}; +char const* const kQKV_TO_CONTEXT_VAR_SEQLEN_LEGACY_PLUGIN_VERSION{"2"}; +char const* const kQKV_TO_CONTEXT_PLUGIN_NAME{"CustomQKVToContextPluginDynamic"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection QKVToContextPluginDynamicLegacyCreator::mFC{}; +std::vector QKVToContextPluginDynamicLegacyCreator::mPluginAttributes; + +REGISTER_TENSORRT_PLUGIN(QKVToContextPluginDynamicLegacyCreator); + +constexpr uint32_t kIIDX = 0; // index of the input tensor +constexpr uint32_t kMIDX = 1; // index of the mask + +// Static class fields initialization +PluginFieldCollection QKVToContextVarSeqlenPluginLegacyCreator::mFC{}; +std::vector QKVToContextVarSeqlenPluginLegacyCreator::mPluginAttributes; + +REGISTER_TENSORRT_PLUGIN(QKVToContextVarSeqlenPluginLegacyCreator); + +QKVToContextPluginDynamicLegacy::QKVToContextPluginDynamicLegacy(std::string const name, DataType const type, + int32_t const hiddenSize, int32_t const numHeads, float const dqProbs, bool hasImask) + : mLayerName(name) + , mS(0) + , mB(0) + , mHeadSize(hiddenSize / numHeads) + , mHiddenSize(hiddenSize) + , mNumHeads(numHeads) + , mHasImask(hasImask) + , mType(type) + , mDqProbs(dqProbs) + +{ + mSM = getSMVersion(); +} + +QKVToContextPluginDynamicLegacy::QKVToContextPluginDynamicLegacy( + std::string const name, void const* data, size_t length) + : mLayerName(name) +{ + BERT_DEBUG_MSG("QKV Deser Start"); + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mNumHeads); + deserialize_value(&data, &length, &mHeadSize); + deserialize_value(&data, &length, &mHasImask); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mSM); + deserialize_value(&data, &length, &mS); + deserialize_value(&data, &length, &mB); + + deserialize_value(&data, &length, &mDqProbs); + + createMHARunner(); + + int32_t hasUnfusedRunner = 0; + deserialize_value(&data, &length, &hasUnfusedRunner); + if (hasUnfusedRunner) + { + PLUGIN_ASSERT(unfusedDispatcher.get()); + unfusedDispatcher->deserialize(data, length); + } + + BERT_DEBUG_MSG("QKV Deser done"); +} + +void QKVToContextPluginDynamicLegacy::createMHARunner() +{ + if (!fusedDispatcher.get()) + { + if (mType == DataType::kHALF) + { + fusedDispatcher.reset(new FusedMHARunnerFP16(mNumHeads, mSM)); + } + else if (mType == DataType::kINT8) + { + fusedDispatcher.reset(new FusedMHARunnerInt8(mNumHeads, mSM, mDqProbs)); + } + } + + if (!unfusedDispatcher.get()) + { + unfusedDispatcher.reset(new UnfusedMHARunner(mType, mNumHeads, mSM)); + } +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* QKVToContextPluginDynamicLegacy::clone() const noexcept +{ + BERT_DEBUG_MSG("QKV Clone"); + + QKVToContextPluginDynamicLegacy* ret = nullptr; + // the workspacesize is 0 if we have not call setup the dispatcher yet. + if (unfusedDispatcher.get() && unfusedDispatcher->getWorkspaceSize()) + { + std::vector buff; + buff.resize(getSerializationSize()); + serialize(buff.data()); + + ret = new QKVToContextPluginDynamicLegacy(mLayerName, buff.data(), buff.size()); + } + else + { + ret = new QKVToContextPluginDynamicLegacy(mLayerName, mType, mHiddenSize, mNumHeads, mDqProbs, mHasImask); + } + + ret->setPluginNamespace(mNamespace.c_str()); + BERT_DEBUG_MSG("QKV Clone done"); + return ret; +} + +DimsExprs QKVToContextPluginDynamicLegacy::getOutputDimensions( + int32_t outputIndex, DimsExprs const* inputs, int32_t /*nbInputs*/, IExprBuilder& exprBuilder) noexcept +{ + // Input is BxSx3*N*H, output should be BxSxN*H + PLUGIN_ASSERT(outputIndex == 0); + // Copy over everything + DimsExprs output(inputs[kIIDX]); + // Divide last dim by three + auto const* three = exprBuilder.constant(3); + output.d[HDIM] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[kIIDX].d[HDIM], *three); + return output; +} +bool QKVToContextPluginDynamicLegacy::supportsFormatCombination( + int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t /*nbOutputs*/) noexcept +{ + PLUGIN_ASSERT(pos >= 0); + PLUGIN_ASSERT(pos < 2 + mHasImask); + PLUGIN_ASSERT(nbInputs == 1 + mHasImask); + auto const* in = inOut; + auto const* out = inOut + nbInputs; + int32_t packedSize = getMHAMaskPackedSize(mSM, mType, in->dims.d[SDIM]); + + // we only support int8 IO in fused mha runner, and we only support fused mha runner on Xavier, Turing and Ampere + if (mType == DataType::kINT8) + { + if (mSM != kSM_75 && mSM != kSM_80 && mSM != kSM_86 && mSM != kSM_87 && mSM != kSM_89 && mSM != kSM_90) + { + gLogError << "INT8 IO is only supported on Turing, Ampere and Hopper for plugin " + << kQKV_TO_CONTEXT_PLUGIN_NAME << std::endl; + return false; + } + if (in->dims.d[SDIM] == -1) + { + gLogError << "INT8 IO not support dynamic shape in sequence dimension for plugin " + << kQKV_TO_CONTEXT_PLUGIN_NAME << std::endl; + return false; + } + if (packedSize == unfusedMaskSize) + { + gLogError << "INT8 IO only support sequence length 128,384 for plugin " << kQKV_TO_CONTEXT_PLUGIN_NAME + << std::endl; + return false; + } + } + + if (pos == 0) + { + bool isFormatSupported = in->format == TensorFormat::kLINEAR; + if (mType == DataType::kINT8) + { + if (in->dims.d[HDIM] % 32U == 0) + { + isFormatSupported = in->format == TensorFormat::kCHW32; + } + else + { + isFormatSupported = in->format == TensorFormat::kCHW4; + } + } + + // must not check descriptions > pos + return (in->type == mType) && // precision + isFormatSupported && // format + (in->dims.nbDims == 5) && // num dims + ((in->dims.d[HDIM] % 3U) == 0) && // see getOutputDimensions + ((in->dims.d[3]) == 1) && // for fc + ((in->dims.d[4]) == 1) // for fc + ; + } + + // pos==1 + if ((mHasImask && pos == 1)) // pos 1 is the mask + { + auto const* inMask = &inOut[1]; + if (inMask->dims.d[1] != -1 && inMask->dims.d[1] != packedSize) + { + gLogError << "CustomEmbLayerNormPluginDynamic returned mask with pack size " << inMask->dims.d[1] + << ", but " << kQKV_TO_CONTEXT_PLUGIN_NAME << " expects mask pack size " << packedSize + << std::endl; + return false; + } + + // detect full mask and check that it was produced + return (inMask->type == DataType::kINT32) && // precision + (inMask->format == TensorFormat::kLINEAR) && // format + (inMask->dims.nbDims == 2) && // Bx2*maskSize + (inMask->dims.d[0] == in->dims.d[BDIM]); + } + + if (!mHasImask || pos == 2) // output pos + { + bool isFormatSupported = out->format == TensorFormat::kLINEAR; + if (mType == DataType::kINT8) + { + if (out->dims.d[HDIM] % 32U == 0) + { + isFormatSupported = out->format == TensorFormat::kCHW32; + } + else + { + isFormatSupported = out->format == TensorFormat::kCHW4; + } + } + + return (in->type == out->type) && // precision + isFormatSupported && // format + (out->dims.nbDims == 5) && // num dims + ((in->dims.d[HDIM] / 3) == (out->dims.d[HDIM])) && // div 3 + ((out->dims.d[3]) == 1) && // for fc + ((out->dims.d[4]) == 1) && // for fc + ((out->dims.d[BDIM]) == in->dims.d[BDIM]) && // check B + ((out->dims.d[SDIM]) == in->dims.d[SDIM]) // check S + ; + } + + return false; +} +void QKVToContextPluginDynamicLegacy::configurePlugin( + DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept +{ + PLUGIN_ASSERT(nbInputs == 1 + mHasImask); + PLUGIN_ASSERT(nbOutputs == 1); + PluginTensorDesc const& inDesc = in[kIIDX].desc; + TRT_UNUSED inDesc; + PluginTensorDesc const& outDesc = out->desc; + TRT_UNUSED outDesc; + PLUGIN_ASSERT(mType == inDesc.type); + PLUGIN_ASSERT(mType == outDesc.type); + PLUGIN_ASSERT(inDesc.dims.d[BDIM] == outDesc.dims.d[BDIM]); + PLUGIN_ASSERT(inDesc.dims.d[SDIM] == outDesc.dims.d[SDIM]); + PLUGIN_ASSERT(inDesc.dims.d[HDIM] == 3 * outDesc.dims.d[HDIM]); + if (mHasImask) + { + PluginTensorDesc const& maskDesc = in[kMIDX].desc; + TRT_UNUSED maskDesc; + PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[BDIM]); + } + + createMHARunner(); + + int32_t const S = inDesc.dims.d[SDIM]; + int32_t const B = inDesc.dims.d[BDIM] <= 0 ? in->max.d[BDIM] : inDesc.dims.d[BDIM]; + if (S <= 0) + { + // in dynamic shape build stage, we setup with max sequence that cannot fused + int32_t const Smin = in->min.d[SDIM]; + int32_t const Smax = in->max.d[SDIM]; + + if (fusedDispatcher.get()) + { + for (int32_t i = Smax; i >= Smin; --i) + { + if (!fusedDispatcher->isValid(mHeadSize, i)) + { + unfusedDispatcher->setup(i, B, mHeadSize); + mS = i; + mB = B; + break; + } + } + } + else + { + unfusedDispatcher->setup(Smax, B, mHeadSize); + mS = Smax; + mB = B; + } + } + else + { + // in inference stage or in static shape build stage + if (fusedDispatcher.get() && fusedDispatcher->isValid(mHeadSize, S)) + { + fusedDispatcher->setup(S, B, mHeadSize); + } + else + { + unfusedDispatcher->setup(S, B, mHeadSize); + } + mS = S; + mB = B; + } +} + +size_t QKVToContextPluginDynamicLegacy::getWorkspaceSize(PluginTensorDesc const* /*inputs*/, int32_t /*nbInputs*/, + PluginTensorDesc const* /*outputs*/, int32_t /*nbOutputs*/) const noexcept +{ + // only unfused kernel need workspace, and we need larger workspace for larger sequence length + // we have already setup unfusedDispatcher with max sequence in configurePlugin + // if unfusedDispatcher is not initialized in configurePlugin + PLUGIN_ASSERT(unfusedDispatcher.get()); + return unfusedDispatcher->getWorkspaceSize(); +} + +// IPluginV2Ext Methods +DataType QKVToContextPluginDynamicLegacy::getOutputDataType( + int32_t index, nvinfer1::DataType const* inputTypes, int32_t /*nbInputs*/) const noexcept +{ + PLUGIN_ASSERT(index == 0); + PLUGIN_ASSERT( + inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF || inputTypes[0] == DataType::kINT8); + return inputTypes[0]; +} + +void QKVToContextPluginDynamicLegacy::attachToContext( + cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) noexcept +{ + try + { + mCublasWrapper = createPluginCublasWrapper(allocator); + mCublas = mCublasWrapper->getCublasHandle(); + PLUGIN_VALIDATE(mCublas != nullptr); + } + catch (std::exception const& e) + { + caughtError(e); + } +} + +// IPluginV2 Methods +char const* QKVToContextPluginDynamicLegacy::getPluginType() const noexcept +{ + return kQKV_TO_CONTEXT_PLUGIN_NAME; +} + +char const* QKVToContextPluginDynamicLegacy::getPluginVersion() const noexcept +{ + return kQKV_TO_CONTEXT_PLUGIN_LEGACY_VERSION; +} + +int32_t QKVToContextPluginDynamicLegacy::getNbOutputs() const noexcept +{ + return 1; +} + +int32_t QKVToContextPluginDynamicLegacy::initialize() noexcept +{ + return 0; +} + +void QKVToContextPluginDynamicLegacy::terminate() noexcept {} + +size_t QKVToContextPluginDynamicLegacy::getSerializationSize() const noexcept +{ + PLUGIN_ASSERT(unfusedDispatcher.get()); + return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(DataType) + sizeof(mHasImask) + sizeof(mHiddenSize) + + sizeof(mSM) + sizeof(mS) + sizeof(mB) + sizeof(mDqProbs) + sizeof(int32_t) + + unfusedDispatcher->getSerializationSize(); +} + +void QKVToContextPluginDynamicLegacy::serialize(void* buffer) const noexcept +{ + serialize_value(&buffer, mType); + serialize_value(&buffer, mNumHeads); + serialize_value(&buffer, mHeadSize); + serialize_value(&buffer, mHasImask); + serialize_value(&buffer, mHiddenSize); + serialize_value(&buffer, mSM); + serialize_value(&buffer, mS); + serialize_value(&buffer, mB); + + serialize_value(&buffer, mDqProbs); + if (unfusedDispatcher.get() && unfusedDispatcher->getWorkspaceSize()) + { + int32_t hasUnfusedRunner = 1; + serialize_value(&buffer, hasUnfusedRunner); + unfusedDispatcher->serialize(buffer); + } + else + { + int32_t hasUnfusedRunner = 0; + serialize_value(&buffer, hasUnfusedRunner); + } +} + +void QKVToContextPluginDynamicLegacy::destroy() noexcept +{ + delete this; +} + +void QKVToContextPluginDynamicLegacy::setPluginNamespace(char const* libNamespace) noexcept +{ + mNamespace = libNamespace; +} + +char const* QKVToContextPluginDynamicLegacy::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +int32_t QKVToContextPluginDynamicLegacy::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +{ + PLUGIN_VALIDATE(inputDesc != nullptr && outputDesc != nullptr && inputs != nullptr && outputs != nullptr); + PLUGIN_ASSERT(mS == inputDesc->dims.d[SDIM]); + PLUGIN_ASSERT(mB == inputDesc->dims.d[BDIM]); + + try + { + void const* const maskPtr = mHasImask ? inputs[1] : nullptr; + if (mHasImask && fusedDispatcher.get() && fusedDispatcher->isValid(mHeadSize, inputDesc->dims.d[SDIM])) + { + fusedDispatcher->run( + inputDesc[0], outputDesc[0], inputs[0], maskPtr, outputs[0], workspace, stream, mCublas); + } + else + { + PLUGIN_VALIDATE(unfusedDispatcher.get(), "The Unfused MHARunner is uninitialized, no MHARunner available!"); + PLUGIN_VALIDATE(mType != DataType::kINT8, "The Unfused MHARunner does not support INT8!"); + unfusedDispatcher->run( + inputDesc[0], outputDesc[0], inputs[0], maskPtr, outputs[0], workspace, stream, mCublas); + } + } + catch (std::exception const& e) + { + caughtError(e); + return -1; + } + return 0; +} + +QKVToContextPluginDynamicLegacyCreator::QKVToContextPluginDynamicLegacyCreator() +{ + mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("hidden_size", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("has_mask", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("dq_probs", nullptr, PluginFieldType::kFLOAT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* QKVToContextPluginDynamicLegacyCreator::getPluginName() const noexcept +{ + return kQKV_TO_CONTEXT_PLUGIN_NAME; +} + +char const* QKVToContextPluginDynamicLegacyCreator::getPluginVersion() const noexcept +{ + return kQKV_TO_CONTEXT_PLUGIN_LEGACY_VERSION; +} + +PluginFieldCollection const* QKVToContextPluginDynamicLegacyCreator::getFieldNames() noexcept +{ + return &mFC; +} + +IPluginV2* QKVToContextPluginDynamicLegacyCreator::createPlugin( + char const* name, PluginFieldCollection const* fc) noexcept +{ + try + { + BERT_DEBUG_MSG("Creating QKV2ContextPlugin..."); + PLUGIN_VALIDATE(fc != nullptr); + int32_t hiddenSize = 0; + // Since numHeads must always exist or validateRequiredAttributes will fail, + // we can set numHeads to -1 so that static analysis tools don't warn about + // a division by zero in QKVToContextPluginDynamicLegacy constructor. + int32_t numHeads{-1}; + bool hasMask = false; + int32_t typeId = -1; + + float dqProbs = -1; + + PLUGIN_VALIDATE(fc->fields != nullptr); + plugin::validateRequiredAttributesExist({"type_id", "hidden_size", "num_heads", "has_mask"}, fc); + + for (int32_t i = 0; i < fc->nbFields; i++) + { + PLUGIN_VALIDATE(fc->fields[i].name != nullptr); + PLUGIN_VALIDATE(fc->fields[i].data != nullptr); + std::string field_name(fc->fields[i].name); + + if (field_name.compare("type_id") == 0) + { + typeId = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(typeId >= 0 && typeId <= 2, ("QKV: Invalid TypeId " + std::to_string(typeId)).c_str()); + BERT_DEBUG_VALUE("Building typeId: ", typeId); + } + if (field_name.compare("hidden_size") == 0) + { + hiddenSize = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(hiddenSize > 0, ("QKV: Invalid hiddenSize " + std::to_string(hiddenSize)).c_str()); + BERT_DEBUG_VALUE("Building hiddenSize: ", hiddenSize); + } + if (field_name.compare("num_heads") == 0) + { + numHeads = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(numHeads > 0, ("QKV: Invalid numHeads " + std::to_string(numHeads)).c_str()); + BERT_DEBUG_VALUE("Building numHeads: ", numHeads); + } + if (field_name.compare("has_mask") == 0) + { + auto hasMaskValue = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(hasMaskValue == 0 || hasMaskValue == 1, + ("QKV: Invalid hasMask " + std::to_string(hasMaskValue)).c_str()); + hasMask = static_cast(hasMaskValue); + BERT_DEBUG_VALUE("Building hasMask: ", hasMask); + } + + if (field_name.compare("dq_probs") == 0) + { + dqProbs = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(dqProbs > 0.0F, ("QKV: Invalid dqProbs " + std::to_string(dqProbs)).c_str()); + BERT_DEBUG_VALUE("Building dqProbs: ", dqProbs); + } + } + + BERT_DEBUG_MSG("Building the Plugin..."); + auto type = static_cast(typeId); + if (type == DataType::kINT8 && dqProbs < 0) + { + BERT_DEBUG_MSG("Using default scale factor"); + dqProbs = 1.F / 127.F; + } + + auto* p = new QKVToContextPluginDynamicLegacy(name, type, hiddenSize, numHeads, dqProbs, hasMask); + return p; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +IPluginV2* QKVToContextPluginDynamicLegacyCreator::deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept +{ + // This object will be deleted when the network is destroyed, which will + // call QKVToContextPluginDynamicLegacy::destroy() + return new QKVToContextPluginDynamicLegacy(name, serialData, serialLength); +} + +void QKVToContextPluginDynamicLegacyCreator::setPluginNamespace(char const* libNamespace) noexcept +{ + mNamespace = libNamespace; +} + +char const* QKVToContextPluginDynamicLegacyCreator::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +QKVToContextVarSeqlenPluginLegacy::QKVToContextVarSeqlenPluginLegacy(std::string const name, DataType const type, + int32_t const hiddenSize, int32_t const numHeads, float const dqProbs, bool hasImask, bool varSeqlen, + bool useInt8ScaleMax) + : mLayerName(name) + , mS(0) + , mB(0) + , mHeadSize(hiddenSize / numHeads) + , mHiddenSize(hiddenSize) + , mNumHeads(numHeads) + , mHasImask(hasImask) + , mType(type) + , mDqProbs(dqProbs) + , mHdim(HDIM) + , mUseVarSeqlen(varSeqlen) + , mUseInt8ScaleMax(useInt8ScaleMax) +{ + mSM = getSMVersion(); + + if (varSeqlen) + { + // variable sequence length is only supported with the fused MHA kernels + // we should not override mS! + PLUGIN_ASSERT( + (mSM == kSM_90 || mSM == kSM_87 || mSM == kSM_86 || mSM == kSM_89 || mSM == kSM_80 || mSM == kSM_75) + && (type == DataType::kINT8 || type == DataType::kHALF) + && "requesting maxSeqlen not compatible with GPU arch"); + // the layout changes: SxB will be a combined \sum_i s_i and hdim will be the 2nd dimension instead of the third + mHdim = 1; + } +} + +QKVToContextVarSeqlenPluginLegacy::QKVToContextVarSeqlenPluginLegacy( + std::string const name, void const* data, size_t length) + : mLayerName(name) +{ + BERT_DEBUG_MSG("QKV Deser Start"); + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mNumHeads); + deserialize_value(&data, &length, &mHeadSize); + deserialize_value(&data, &length, &mHasImask); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mSM); + deserialize_value(&data, &length, &mS); + deserialize_value(&data, &length, &mB); + + deserialize_value(&data, &length, &mDqProbs); + + deserialize_value(&data, &length, &mUseVarSeqlen); + deserialize_value(&data, &length, &mHdim); + deserialize_value(&data, &length, &mUseInt8ScaleMax); + + createMHARunner(); + mDispatcher->deserialize(data, length); + + BERT_DEBUG_MSG("QKV Deser done"); +} + +void QKVToContextVarSeqlenPluginLegacy::createMHARunner() +{ + if (mDispatcher.get()) + { + return; + } + + if (mUseVarSeqlen) + { + PLUGIN_ASSERT(mHeadSize <= 64); + { + if (mHeadSize != 64) + { + mPatcher.reset(new QkvPaddingRunner(mType)); + } + if (mType == DataType::kHALF) + { + mDispatcher.reset(new FusedMHARunnerFP16v2(mNumHeads, mSM)); + } + else if (mType == DataType::kINT8) + { + mDispatcher.reset(new FusedMHARunnerInt8v2(mNumHeads, mSM, mDqProbs, mUseInt8ScaleMax)); + } + } + } + else + { + PLUGIN_ASSERT(mType != DataType::kINT8); + mDispatcher.reset(new UnfusedMHARunner(mType, mNumHeads, mSM)); + } +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* QKVToContextVarSeqlenPluginLegacy::clone() const noexcept +{ + BERT_DEBUG_MSG("QKV Clone"); + + QKVToContextVarSeqlenPluginLegacy* ret = nullptr; + if (mDispatcher.get()) + { + std::vector buff; + buff.resize(getSerializationSize()); + serialize(buff.data()); + + ret = new QKVToContextVarSeqlenPluginLegacy(mLayerName, buff.data(), buff.size()); + } + else + { + ret = new QKVToContextVarSeqlenPluginLegacy( + mLayerName, mType, mHiddenSize, mNumHeads, mDqProbs, mHasImask, mUseVarSeqlen, mUseInt8ScaleMax); + } + + ret->setPluginNamespace(mNamespace.c_str()); + BERT_DEBUG_MSG("QKV Clone done"); + return ret; +} + +DimsExprs QKVToContextVarSeqlenPluginLegacy::getOutputDimensions( + int32_t outputIndex, DimsExprs const* inputs, int32_t /*nbInputs*/, IExprBuilder& exprBuilder) noexcept +{ + // Input is BxSx3*N*H, output should be BxSxN*H + PLUGIN_ASSERT(outputIndex == 0); + // Copy over everything + DimsExprs output(inputs[kIIDX]); + // Divide last dim by three + auto const* three = exprBuilder.constant(3); + output.d[mHdim] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[kIIDX].d[mHdim], *three); + return output; +} + +bool QKVToContextVarSeqlenPluginLegacy::supportsFormatCombination( + int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept +{ + // we only support variable sequence and int8 IO in fused mha runner, and we only support fused mha runner on + // Turing, Ampere and Hopper + bool const hasV2Kernels + = (mSM == kSM_90 || mSM == kSM_89 || mSM == kSM_87 || mSM == kSM_86 || mSM == kSM_80 || mSM == kSM_75); + PLUGIN_ASSERT( + (mType != DataType::kINT8 || hasV2Kernels) && "INT8 IO is only supported on Xavier, Turing, Ampere and Hopper"); + PLUGIN_ASSERT( + (!mUseVarSeqlen || hasV2Kernels) && "Variable sequence is only supported on Xavier, Turing, Ampere and Hopper"); + + PLUGIN_ASSERT(pos >= 0); + PLUGIN_ASSERT(pos < 2 + mHasImask + 2 * mUseVarSeqlen); + PLUGIN_ASSERT(nbInputs == 1 + mHasImask + 2 * mUseVarSeqlen); + PLUGIN_ASSERT(nbOutputs == 1); + auto const* in = inOut; + auto const* out = inOut + nbInputs; + if (mUseVarSeqlen) + { + PLUGIN_ASSERT((mType == DataType::kHALF || mType == DataType::kINT8) + && "Conditions for variable seqlen support not fulfilled"); + // qkv, mask, cu_seqlens, dummy + PLUGIN_ASSERT(nbInputs == 4 && "for varseqlen, expected 4 inputs"); + } + + auto const inDims = in->dims; + auto const outDims = out->dims; + + auto supportedFormat = TensorFormat::kLINEAR; + if (mType == DataType::kINT8) + { + supportedFormat = (inDims.d[mHdim] % 32U == 0) ? TensorFormat::kCHW32 : TensorFormat::kCHW4; + } + + int32_t supportedNbDims = 5; + if (mUseVarSeqlen) + { + supportedNbDims = 4; + } + + bool supportedHdim = (pos == 0) ? (inDims.d[mHdim] % 3U == 0) : (inDims.d[mHdim] / 3 == outDims.d[mHdim]); + + if (pos == 0 || pos == nbInputs) + { // check input and output + auto const& desc = inOut[pos]; + return (desc.type == mType) && // check type + (desc.format == supportedFormat) && // check format + (desc.dims.nbDims == supportedNbDims) && // check dims: + (supportedHdim) && // - hidden dims multiple of 3 for qkv + (desc.dims.d[mHdim + 1] == 1) && // - dummy 1 or h + (desc.dims.d[mHdim + 2] == 1) // - dummy 1 or w + ; + } + + PLUGIN_ASSERT(mHasImask); + if (pos == 1) + { // must be input mask + auto const* mask = &inOut[pos]; + if (mUseVarSeqlen) + { + // dummy input + return true; + } + + return mask->format == TensorFormat::kLINEAR && (mask->type == DataType::kINT32) && // precision + (mask->dims.nbDims == 1) // num dims + ; + } + PLUGIN_ASSERT(mUseVarSeqlen); + if (pos == 2) + { // must be cuSeqlens + // cuSeqlens is a int32_t array of size B+1 + auto const* seqlens = &inOut[pos]; + return (seqlens->type == DataType::kINT32) && (seqlens->format == TensorFormat::kLINEAR); + } + if (pos == 3) + { + // this is the dummy input + return inOut[pos].dims.nbDims == 1; + } + return false; +} + +void QKVToContextVarSeqlenPluginLegacy::configurePlugin( + DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept +{ + PLUGIN_ASSERT(nbInputs == 1 + mHasImask + 2 * mUseVarSeqlen); + PLUGIN_ASSERT(nbOutputs == 1); + PluginTensorDesc const& inDesc = in[kIIDX].desc; + TRT_UNUSED inDesc; + PluginTensorDesc const& outDesc = out->desc; + TRT_UNUSED outDesc; + PLUGIN_ASSERT(mType == inDesc.type); + PLUGIN_ASSERT(mType == outDesc.type); + if (!mUseVarSeqlen) + { + PLUGIN_ASSERT(inDesc.dims.d[BDIM] == outDesc.dims.d[BDIM]); + PLUGIN_ASSERT(inDesc.dims.d[SDIM] == outDesc.dims.d[SDIM]); + PLUGIN_ASSERT(inDesc.dims.d[mHdim] == 3 * outDesc.dims.d[mHdim]); + if (mHasImask) + { + PluginTensorDesc const& maskDesc = in[kMIDX].desc; + TRT_UNUSED maskDesc; + PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[BDIM]); + } + + int32_t const S = inDesc.dims.d[SDIM] <= 0 ? in->max.d[SDIM] : inDesc.dims.d[SDIM]; + int32_t const B = inDesc.dims.d[BDIM] <= 0 ? in->max.d[BDIM] : inDesc.dims.d[BDIM]; + + if (S != mS || B != mB) + { + BERT_DEBUG_MSG("setting up MHA runner for single sequence length"); + createMHARunner(); + this->mDispatcher->setup(S, B, mHeadSize); + mS = S; + mB = B; + } + } + else + { + BERT_DEBUG_MSG("setting up MHA runner for variable sequence length"); + createMHARunner(); + // need to initialize S and B with somewhat useful values, they will be reset at enqueue for the actual + // batchsize + this->mDispatcher->setup(256, 1, mHeadSize); + } +} + +size_t QKVToContextVarSeqlenPluginLegacy::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t /* nbInputs */, + PluginTensorDesc const* /* outputs */, int32_t /* nbOutputs */) const noexcept +{ + size_t paddingWorkpaceSize = mPatcher ? mPatcher->getWorkspaceSize(inputs[0].dims.d[0], mNumHeads) : 0; + return mDispatcher->getWorkspaceSize() + paddingWorkpaceSize; +} + +// IPluginV2Ext Methods +DataType QKVToContextVarSeqlenPluginLegacy::getOutputDataType( + int32_t index, nvinfer1::DataType const* inputTypes, int32_t /*nbInputs*/) const noexcept +{ + PLUGIN_ASSERT(index == 0); + PLUGIN_ASSERT( + inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF || inputTypes[0] == DataType::kINT8); + return inputTypes[0]; +} + +void QKVToContextVarSeqlenPluginLegacy::attachToContext( + cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) noexcept +{ + try + { + mCublasWrapper = createPluginCublasWrapper(allocator); + mCublas = mCublasWrapper->getCublasHandle(); + PLUGIN_VALIDATE(mCublas != nullptr); + } + catch (std::exception const& e) + { + caughtError(e); + } +} + +// IPluginV2 Methods +char const* QKVToContextVarSeqlenPluginLegacy::getPluginType() const noexcept +{ + return kQKV_TO_CONTEXT_PLUGIN_NAME; +} + +char const* QKVToContextVarSeqlenPluginLegacy::getPluginVersion() const noexcept +{ + return kQKV_TO_CONTEXT_VAR_SEQLEN_LEGACY_PLUGIN_VERSION; +} + +int32_t QKVToContextVarSeqlenPluginLegacy::getNbOutputs() const noexcept +{ + return 1; +} + +int32_t QKVToContextVarSeqlenPluginLegacy::initialize() noexcept +{ + return 0; +} + +void QKVToContextVarSeqlenPluginLegacy::terminate() noexcept {} + +size_t QKVToContextVarSeqlenPluginLegacy::getSerializationSize() const noexcept +{ + return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(DataType) + sizeof(mHasImask) + sizeof(mHiddenSize) + + sizeof(mSM) + sizeof(mS) + sizeof(mB) + sizeof(mDqProbs) + mDispatcher->getSerializationSize() + + sizeof(mUseVarSeqlen) + sizeof(mHdim) + sizeof(mUseInt8ScaleMax); +} + +void QKVToContextVarSeqlenPluginLegacy::serialize(void* buffer) const noexcept +{ + serialize_value(&buffer, mType); + serialize_value(&buffer, mNumHeads); + serialize_value(&buffer, mHeadSize); + serialize_value(&buffer, mHasImask); + serialize_value(&buffer, mHiddenSize); + serialize_value(&buffer, mSM); + serialize_value(&buffer, mS); + serialize_value(&buffer, mB); + + serialize_value(&buffer, mDqProbs); + serialize_value(&buffer, mUseVarSeqlen); + serialize_value(&buffer, mHdim); + serialize_value(&buffer, mUseInt8ScaleMax); + mDispatcher->serialize(buffer); +} + +void QKVToContextVarSeqlenPluginLegacy::destroy() noexcept +{ + delete this; +} + +void QKVToContextVarSeqlenPluginLegacy::setPluginNamespace(char const* libNamespace) noexcept +{ + mNamespace = libNamespace; +} + +char const* QKVToContextVarSeqlenPluginLegacy::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +int32_t QKVToContextVarSeqlenPluginLegacy::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept +{ + PLUGIN_VALIDATE(inputDesc != nullptr && outputDesc != nullptr && inputs != nullptr && outputs != nullptr); + + if (mUseVarSeqlen) + { + int32_t const B = inputDesc[2].dims.d[0] - 1; + int32_t const maxS = inputDesc[3].dims.d[0]; + PLUGIN_ASSERT((maxS <= 512) + && "No implementation for variable sequence length multi-head attention plugin with sequence > 512."); + + int32_t S = 512; + if (DataType::kHALF == mType && maxS <= 64) + { + S = 64; + } + else if (DataType::kHALF == mType && maxS <= 96) + { + S = 96; + } + else if (maxS <= 128) + { + S = 128; + } + else if (maxS <= 192) + { + S = 192; + if (mType == DataType::kHALF) + { + S = 256; + } + } + else if (maxS <= 256) + { + S = 256; + } + else if (maxS <= 384) + { + S = 384; + } + + auto runV2Kernel = [this, &S, &B, &workspace, &inputDesc, &outputDesc, &stream, &inputs, &outputs]( + MHARunner* dispatcher, QkvPaddingRunner* patcher, int32_t padSize) { + PLUGIN_ASSERT(dispatcher); + // Validate that we can padding to the dispatch required head size also there is kernel exist for this + // sequence length. + if (mHeadSize > padSize || !dispatcher->isValid(padSize, S)) + { + return false; + } + dispatcher->setup(S, B, padSize); + + // Need pad and unpad to run the V2 kernel. + if (mHeadSize < padSize) + { + PLUGIN_ASSERT(patcher); + PLUGIN_ASSERT(padSize <= patcher->getMaxPaddingHeadSize()); + auto sumSeqLen = inputDesc[0].dims.d[0]; + auto paddingWorkspace = patcher->get16BytesAlignedPointer(workspace, dispatcher->getWorkspaceSize()); + auto ret = mPatcher->pad(inputs[0], paddingWorkspace, sumSeqLen, mNumHeads, mHeadSize, padSize, stream); + if (ret != cudaSuccess) + { + return false; + } + + MhaRunParameter paddingArgs = patcher->patchMhaArgs( + inputDesc, outputDesc, inputs, outputs, paddingWorkspace, sumSeqLen, mNumHeads, padSize); + try + { + dispatcher->run(paddingArgs.inputDesc, paddingArgs.outputDesc, paddingArgs.inputs, + paddingArgs.outputs, workspace, stream, mCublas); + } + catch (std::exception const& e) + { + caughtError(e); + return false; + } + + ret = patcher->unpad( + paddingArgs.outputs[0], outputs[0], sumSeqLen, mNumHeads, mHeadSize, padSize, stream); + return ret == cudaSuccess; + } + else + { + // No pad/unpad is needed. + try + { + dispatcher->run(inputDesc, outputDesc, inputs, outputs, workspace, stream, mCublas); + } + catch (std::exception const& e) + { + caughtError(e); + return false; + } + return true; + } + }; + // Try pad head size to 32 first, if it failed, then try to pad head size to 64. + if (!runV2Kernel(mDispatcher.get(), mPatcher.get(), 32) && !runV2Kernel(mDispatcher.get(), mPatcher.get(), 64)) + { + return false; + } + + return cudaGetLastError(); + } + + PLUGIN_ASSERT(mS == inputDesc->dims.d[SDIM]); + PLUGIN_ASSERT(mB == inputDesc->dims.d[BDIM]); + + void const* maskPtr = mHasImask ? inputs[1] : nullptr; + mDispatcher->run(inputDesc[0], outputDesc[0], inputs[0], maskPtr, outputs[0], workspace, stream, mCublas); + return cudaGetLastError(); +} + +QKVToContextVarSeqlenPluginLegacyCreator::QKVToContextVarSeqlenPluginLegacyCreator() +{ + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("hidden_size", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("has_mask", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("dq_probs", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("var_seqlen", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("use_int8_scale_max", nullptr, PluginFieldType::kINT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* QKVToContextVarSeqlenPluginLegacyCreator::getPluginName() const noexcept +{ + return kQKV_TO_CONTEXT_PLUGIN_NAME; +} + +char const* QKVToContextVarSeqlenPluginLegacyCreator::getPluginVersion() const noexcept +{ + return kQKV_TO_CONTEXT_VAR_SEQLEN_LEGACY_PLUGIN_VERSION; +} + +PluginFieldCollection const* QKVToContextVarSeqlenPluginLegacyCreator::getFieldNames() noexcept +{ + return &mFC; +} + +IPluginV2* QKVToContextVarSeqlenPluginLegacyCreator::createPlugin( + char const* name, PluginFieldCollection const* fc) noexcept +{ + BERT_DEBUG_MSG("Creating QKV2ContextPlugin..."); + + int32_t hiddenSize = 0; + // Since numHeads must always exist or validateRequiredAttributes will fail, + // we can set numHeads to -1 so that static analysis tools don't warn about + // a division by zero in QKVToContextVarSeqelnPlugin constructor. + int32_t numHeads{-1}; + bool hasMask = false; + int32_t typeId = -1; + + int32_t varSeqlen = 0; + + float dqProbs = -1; + int32_t useInt8ScaleMax{-1}; + + plugin::validateRequiredAttributesExist({"type_id", "hidden_size", "num_heads", "has_mask"}, fc); + for (int32_t i = 0; i < fc->nbFields; i++) + { + std::string field_name(fc->fields[i].name); + + if (field_name.compare("type_id") == 0) + { + typeId = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(typeId >= 0 && typeId <= 2, ("QKV: Invalid TypeId " + std::to_string(typeId)).c_str()); + BERT_DEBUG_VALUE("Building typeId: ", typeId); + } + if (field_name.compare("hidden_size") == 0) + { + hiddenSize = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(hiddenSize > 0, ("QKV: Invalid hiddenSize " + std::to_string(hiddenSize)).c_str()); + BERT_DEBUG_VALUE("Building hiddenSize: ", hiddenSize); + } + if (field_name.compare("num_heads") == 0) + { + numHeads = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(numHeads > 0, ("QKV: Invalid numHeads " + std::to_string(numHeads)).c_str()); + BERT_DEBUG_VALUE("Building numHeads: ", numHeads); + } + if (field_name.compare("has_mask") == 0) + { + hasMask = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(hasMask == 0 || hasMask == 1, ("QKV: Invalid hasMask " + std::to_string(hasMask)).c_str()); + BERT_DEBUG_VALUE("Building hasMask: ", hasMask); + } + + if (field_name.compare("dq_probs") == 0) + { + dqProbs = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(dqProbs > 0.0F, ("QKV: Invalid dqProbs " + std::to_string(dqProbs)).c_str()); + BERT_DEBUG_VALUE("Building dqProbs: ", dqProbs); + } + if (field_name.compare("var_seqlen") == 0) + { + varSeqlen = *static_cast(fc->fields[i].data); + BERT_DEBUG_VALUE("Building var_seqlen: ", varSeqlen); + } + if (field_name.compare("use_int8_scale_max") == 0) + { + useInt8ScaleMax = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(useInt8ScaleMax == 0 || useInt8ScaleMax == 1, + ("QKV: Invalid useInt8ScaleMax " + std::to_string(useInt8ScaleMax)).c_str()); + BERT_DEBUG_VALUE("Building useInt8ScaleMax: ", useInt8ScaleMax); + } + } + + if (useInt8ScaleMax < 0) + { + gLogInfo << "Using default for use_int8_scale_max: true" << std::endl; + useInt8ScaleMax = 1; + } + + BERT_DEBUG_MSG("Building the Plugin..."); + DataType type = static_cast(typeId); + if (type == DataType::kINT8 && dqProbs < 0) + { + gLogInfo << "Using default scale factor\n"; + dqProbs = 1.F / 127.F; + } + + auto const useInt8ScaleMaxFlag = static_cast(useInt8ScaleMax); + + QKVToContextVarSeqlenPluginLegacy* p = new QKVToContextVarSeqlenPluginLegacy( + name, type, hiddenSize, numHeads, dqProbs, hasMask, varSeqlen, useInt8ScaleMaxFlag); + return p; +} + +IPluginV2* QKVToContextVarSeqlenPluginLegacyCreator::deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept +{ + // This object will be deleted when the network is destroyed, which will + // call QKVToContextVarSeqlenPluginLegacy::destroy() + return new QKVToContextVarSeqlenPluginLegacy(name, serialData, serialLength); +} + +void QKVToContextVarSeqlenPluginLegacyCreator::setPluginNamespace(char const* libNamespace) noexcept +{ + mNamespace = libNamespace; +} + +char const* QKVToContextVarSeqlenPluginLegacyCreator::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +#endif // CUDA_VERSION >= 10010 diff --git a/plugin/bertQKVToContextPlugin/qkvToContextPluginLegacy.h b/plugin/bertQKVToContextPlugin/qkvToContextPluginLegacy.h new file mode 100644 index 00000000..6f3038b8 --- /dev/null +++ b/plugin/bertQKVToContextPlugin/qkvToContextPluginLegacy.h @@ -0,0 +1,266 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Need 10.1 for cublasGemmStridedBatchedEx +#include +#if CUDA_VERSION >= 10010 + +#ifndef TRT_QKV_TO_CONTEXT_PLUGIN_H +#define TRT_QKV_TO_CONTEXT_PLUGIN_H + +#include "NvInferPlugin.h" +#include "common/cublasWrapper.h" +#include "mhaRunner.h" +#include "zeroPadding2d.h" +#include +#include +#include + +namespace nvinfer1 +{ +namespace plugin +{ +namespace bert +{ + +std::pair tuneBatchedGemm( + int32_t const B, int32_t const S, int32_t const numHeads, int32_t const headSize); + +template +int32_t computeScaledSoftmax(cudaStream_t stream, int32_t const ld, int32_t const B, int32_t const N, + float const rsqrtHeadSize, T const* input, T* output); + +template +int32_t computeMaskedScaledSoftmax(cudaStream_t stream, int32_t const ld, int32_t const B, int32_t const N, + float const rsqrtHeadSize, int32_t const* maskIdx, T const* input, T* output); + +// One of the preferred ways of making TensorRT to be able to see +// our custom layer requires extending IPluginV2 and IPluginCreator classes. +// For requirements for overriden functions, check TensorRT API docs. + +class QKVToContextPluginDynamicLegacy : public nvinfer1::IPluginV2DynamicExt +{ +public: + QKVToContextPluginDynamicLegacy(std::string const name, nvinfer1::DataType const type, int32_t const hiddenSize, + int32_t const numHeads, float const dqProbs, bool hasImask = false); + + QKVToContextPluginDynamicLegacy(std::string const name, void const* data, size_t length); + + // It doesn't make sense to make QKVToContextPluginDynamicLegacy without arguments, so we + // delete default constructor. + QKVToContextPluginDynamicLegacy() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination( + int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType( + int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; + void attachToContext( + cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) noexcept override; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + +protected: + void createMHARunner(); + +private: + std::string const mLayerName; + std::string mNamespace; + + // used for sequence len 128, 384, precision int8 + // used for sequence len 64, 96, 128, 384, precision fp16 + std::unique_ptr fusedDispatcher; + // used for other sequence, precision fp32 and fp16 + std::unique_ptr unfusedDispatcher; + + int32_t mS{}; + int32_t mB{}; + int32_t mSM{}; + int32_t mHeadSize{}; + int32_t mHiddenSize; + int32_t mNumHeads{}; + bool mHasImask{}; + nvinfer1::DataType mType{}; + float mDqProbs{}; + nvinfer1::pluginInternal::cublasHandle_t mCublas{}; + // the wrapper pointer is shared among all plugins attached to the same context. + std::shared_ptr mCublasWrapper; + + using IPluginV2::getOutputDimensions; + using IPluginV2::getWorkspaceSize; + using IPluginV2::enqueue; + using IPluginV2Ext::configurePlugin; +}; + +class QKVToContextPluginDynamicLegacyCreator : public nvinfer1::IPluginCreator +{ +public: + QKVToContextPluginDynamicLegacyCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + +private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +class QKVToContextVarSeqlenPluginLegacy : public nvinfer1::IPluginV2DynamicExt +{ +public: + QKVToContextVarSeqlenPluginLegacy(std::string const name, nvinfer1::DataType const type, int32_t const hiddenSize, + int32_t const numHeads, float const dqProbs, bool hasImask = false, bool varSeqlen = false, + bool const useInt8ScaleMax = true); + + QKVToContextVarSeqlenPluginLegacy(std::string const name, void const* data, size_t length); + + // It doesn't make sense to make QKVToContextVarSeqlenPluginLegacy without arguments, so we + // delete default constructor. + QKVToContextVarSeqlenPluginLegacy() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination( + int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType( + int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; + void attachToContext( + cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) noexcept override; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + +protected: + void createMHARunner(); + +private: + std::string const mLayerName; + std::string mNamespace; + + // Used for kernels with header size equals to 32. + std::unique_ptr mDispatcher; + std::unique_ptr mPatcher; + + int32_t mS{}; + int32_t mB{}; + int32_t mSM{}; + int32_t mHeadSize{}; + int32_t mHiddenSize{}; + int32_t mNumHeads{}; + bool mHasImask{}; + nvinfer1::DataType mType{}; + + float mDqProbs{}; + + int32_t mHdim{}; + bool mUseVarSeqlen{}; + bool mUseInt8ScaleMax{true}; + nvinfer1::pluginInternal::cublasHandle_t mCublas{}; + // the wrapper pointer is shared among all plugins attached to the same context. + std::shared_ptr mCublasWrapper; + + using IPluginV2::getOutputDimensions; + using IPluginV2::getWorkspaceSize; + using IPluginV2::enqueue; + using IPluginV2Ext::configurePlugin; +}; + +class QKVToContextVarSeqlenPluginLegacyCreator : public nvinfer1::IPluginCreator +{ +public: + QKVToContextVarSeqlenPluginLegacyCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + +private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace plugin +} // namespace nvinfer1 +#endif // TRT_QKV_TO_CONTEXT_PLUGIN_H + +#endif // CUDA_VERSION >= 10010 diff --git a/plugin/common/bertCommon.h b/plugin/common/bertCommon.h index 16c3e54a..a52b380f 100644 --- a/plugin/common/bertCommon.h +++ b/plugin/common/bertCommon.h @@ -56,9 +56,6 @@ constexpr uint32_t BDIM = 1; // batch dimension constexpr uint32_t SDIM = 0; // seq len dimension constexpr uint32_t HDIM = 2; // hidden dimension -constexpr int32_t kSM_53 = 53; -constexpr int32_t kSM_70 = 70; -constexpr int32_t kSM_72 = 72; constexpr int32_t kSM_75 = 75; constexpr int32_t kSM_80 = 80; constexpr int32_t kSM_86 = 86; diff --git a/plugin/common/checkMacrosPlugin.cpp b/plugin/common/checkMacrosPlugin.cpp index f685c1fc..1fd255d4 100644 --- a/plugin/common/checkMacrosPlugin.cpp +++ b/plugin/common/checkMacrosPlugin.cpp @@ -17,7 +17,7 @@ #include "common/checkMacrosPlugin.h" #include "common/cublasWrapper.h" -#include "common/vfcCommon.h" +#include "vc/vfcCommon.h" #include #include @@ -28,41 +28,6 @@ namespace nvinfer1 namespace plugin { -// This will be populated by the logger supplied by the user to initLibNvInferPlugins() -ILogger* gLogger{}; - -template -int32_t LogStream::Buf::sync() -{ - std::string s = str(); - while (!s.empty() && s.back() == '\n') - { - s.pop_back(); - } - if (gLogger != nullptr) - { - gLogger->log(kSeverity, s.c_str()); - } - str(""); - return 0; -} - -// These use gLogger, and therefore require initLibNvInferPlugins() to be called with a logger -// (otherwise, it will not log) -LogStream gLogError; -LogStream gLogWarning; -LogStream gLogInfo; -LogStream gLogVerbose; - -// break-pointable -void throwCudaError(char const* file, char const* function, int32_t line, int32_t status, char const* msg) -{ - CudaError error(file, function, line, status, msg); - error.log(gLogError); - // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference) - throw error; -} - // break-pointable void throwCublasError(char const* file, char const* function, int32_t line, int32_t status, char const* msg) { @@ -98,67 +63,6 @@ void throwCudnnError(char const* file, char const* function, int32_t line, int32 throw error; } -// break-pointable -void throwPluginError(char const* file, char const* function, int32_t line, int32_t status, char const* msg) -{ - PluginError error(file, function, line, status, msg); - reportValidationFailure(msg, file, line); - // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference) - throw error; -} - -void logError(char const* msg, char const* file, char const* fn, int32_t line) -{ - gLogError << "Parameter check failed at: " << file << "::" << fn << "::" << line; - gLogError << ", condition: " << msg << std::endl; -} - -void reportValidationFailure(char const* msg, char const* file, int32_t line) -{ - std::ostringstream stream; - stream << "Validation failed: " << msg << "\n" << file << ':' << line << "\n"; -#ifdef COMPILE_VFC_PLUGIN - ILogger* logger = getPluginLogger(); - if (logger != nullptr) - { - logger->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, stream.str().c_str()); - } -#else - getLogger()->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, stream.str().c_str()); -#endif -} - -// break-pointable -void reportAssertion(char const* msg, char const* file, int32_t line) -{ - std::ostringstream stream; - stream << "Assertion failed: " << msg << "\n" - << file << ':' << line << "\n" - << "Aborting..." - << "\n"; -#ifdef COMPILE_VFC_PLUGIN - ILogger* logger = getPluginLogger(); - if (logger != nullptr) - { - logger->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, stream.str().c_str()); - } -#else - getLogger()->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, stream.str().c_str()); -#endif - PLUGIN_CUASSERT(cudaDeviceReset()); - exit(EXIT_FAILURE); -} - -void TRTException::log(std::ostream& logStream) const -{ - logStream << file << " (" << line << ") - " << name << " Error in " << function << ": " << status; - if (message != nullptr) - { - logStream << " (" << message << ")"; - } - logStream << std::endl; -} - } // namespace plugin } // namespace nvinfer1 diff --git a/plugin/common/checkMacrosPlugin.h b/plugin/common/checkMacrosPlugin.h index 90b6e263..976c1a8e 100644 --- a/plugin/common/checkMacrosPlugin.h +++ b/plugin/common/checkMacrosPlugin.h @@ -19,6 +19,7 @@ #include "NvInfer.h" #include "common/cudnnWrapper.h" +#include "vc/checkMacrosPlugin.h" #include #include @@ -32,110 +33,11 @@ namespace nvinfer1 { namespace plugin { -template -class LogStream : public std::ostream -{ - class Buf : public std::stringbuf - { - public: - int32_t sync() override; - }; - - Buf buffer; - std::mutex mLogStreamMutex; - -public: - std::mutex& getMutex() - { - return mLogStreamMutex; - } - LogStream() - : std::ostream(&buffer){}; -}; - -// Use mutex to protect multi-stream write to buffer -template -LogStream& operator<<(LogStream& stream, T const& msg) -{ - std::lock_guard guard(stream.getMutex()); - auto& os = static_cast(stream); - os << msg; - return stream; -} - -// Special handling static numbers -template -inline LogStream& operator<<(LogStream& stream, int32_t num) -{ - std::lock_guard guard(stream.getMutex()); - auto& os = static_cast(stream); - os << num; - return stream; -} -// Special handling std::endl -template -inline LogStream& operator<<(LogStream& stream, std::ostream& (*f)(std::ostream&) ) -{ - std::lock_guard guard(stream.getMutex()); - auto& os = static_cast(stream); - os << f; - return stream; -} - -extern LogStream gLogError; -extern LogStream gLogWarning; -extern LogStream gLogInfo; -extern LogStream gLogVerbose; - -void reportValidationFailure(char const* msg, char const* file, int32_t line); -void reportAssertion(char const* msg, char const* file, int32_t line); -void logError(char const* msg, char const* file, char const* fn, int32_t line); - -[[noreturn]] void throwCudaError( - char const* file, char const* function, int32_t line, int32_t status, char const* msg = nullptr); [[noreturn]] void throwCudnnError( char const* file, char const* function, int32_t line, int32_t status, char const* msg = nullptr); [[noreturn]] void throwCublasError( char const* file, char const* function, int32_t line, int32_t status, char const* msg = nullptr); -[[noreturn]] void throwPluginError( - char const* file, char const* function, int32_t line, int32_t status, char const* msg = nullptr); - -class TRTException : public std::exception -{ -public: - TRTException(char const* fl, char const* fn, int32_t ln, int32_t st, char const* msg, char const* nm) - : file(fl) - , function(fn) - , line(ln) - , status(st) - , message(msg) - , name(nm) - { - } - virtual void log(std::ostream& logStream) const; - void setMessage(char const* msg) - { - message = msg; - } - -protected: - char const* file{nullptr}; - char const* function{nullptr}; - int32_t line{0}; - int32_t status{0}; - char const* message{nullptr}; - char const* name{nullptr}; -}; - -class CudaError : public TRTException -{ -public: - CudaError(char const* fl, char const* fn, int32_t ln, int32_t stat, char const* msg = nullptr) - : TRTException(fl, fn, ln, stat, msg, "Cuda") - { - } -}; class CudnnError : public TRTException { @@ -155,55 +57,10 @@ class CublasError : public TRTException } }; -class PluginError : public TRTException -{ -public: - PluginError(char const* fl, char const* fn, int32_t ln, int32_t stat, char const* msg = nullptr) - : TRTException(fl, fn, ln, stat, msg, "Plugin") - { - } -}; - -inline void caughtError(std::exception const& e) -{ - gLogError << e.what() << std::endl; -} } // namespace plugin } // namespace nvinfer1 -#define PLUGIN_API_CHECK(condition) \ - { \ - if ((condition) == false) \ - { \ - nvinfer1::plugin::logError(#condition, __FILE__, FN_NAME, __LINE__); \ - return; \ - } \ - } - -#define PLUGIN_API_CHECK_RETVAL(condition, retval) \ - { \ - if ((condition) == false) \ - { \ - nvinfer1::plugin::logError(#condition, __FILE__, FN_NAME, __LINE__); \ - return retval; \ - } \ - } - -#define PLUGIN_API_CHECK_ENUM_RANGE(Type, val) PLUGIN_API_CHECK(int32_t(val) >= 0 && int32_t(val) < EnumMax()) -#define PLUGIN_API_CHECK_ENUM_RANGE_RETVAL(Type, val, retval) \ - PLUGIN_API_CHECK_RETVAL(int32_t(val) >= 0 && int32_t(val) < EnumMax(), retval) - -#define PLUGIN_CHECK_CUDA(call) \ - do \ - { \ - cudaError_t status = call; \ - if (status != cudaSuccess) \ - { \ - return status; \ - } \ - } while (0) - #define PLUGIN_CHECK_CUDNN(call) \ do \ { \ @@ -234,64 +91,4 @@ inline void caughtError(std::exception const& e) } \ } -#define PLUGIN_CUASSERT(status_) \ - { \ - auto s_ = status_; \ - if (s_ != cudaSuccess) \ - { \ - const char* msg = cudaGetErrorString(s_); \ - nvinfer1::plugin::throwCudaError(__FILE__, FN_NAME, __LINE__, s_, msg); \ - } \ - } - -#define GET_MACRO(_1, _2, NAME, ...) NAME -#define PLUGIN_VALIDATE(...) GET_MACRO(__VA_ARGS__, PLUGIN_VALIDATE_MSG, PLUGIN_VALIDATE_DEFAULT, )(__VA_ARGS__) - -// Logs failed condition and throws a PluginError. -// PLUGIN_ASSERT will eventually perform this function, at which point PLUGIN_VALIDATE -// will be removed. -#define PLUGIN_VALIDATE_DEFAULT(condition) \ - { \ - if (!(condition)) \ - { \ - nvinfer1::plugin::throwPluginError(__FILE__, FN_NAME, __LINE__, 0, #condition); \ - } \ - } - -#define PLUGIN_VALIDATE_MSG(condition, msg) \ - { \ - if (!(condition)) \ - { \ - nvinfer1::plugin::throwPluginError(__FILE__, FN_NAME, __LINE__, 0, msg); \ - } \ - } - -// Logs failed assertion and aborts. -// Aborting is undesirable and will be phased-out from the plugin module, at which point -// PLUGIN_ASSERT will perform the same function as PLUGIN_VALIDATE. -#define PLUGIN_ASSERT(assertion) \ - { \ - if (!(assertion)) \ - { \ - nvinfer1::plugin::reportAssertion(#assertion, __FILE__, __LINE__); \ - } \ - } - -#define PLUGIN_FAIL(msg) \ - { \ - nvinfer1::plugin::reportAssertion(msg, __FILE__, __LINE__); \ - } - -#define PLUGIN_ERROR(msg) \ - { \ - nvinfer1::plugin::throwPluginError(__FILE__, FN_NAME, __LINE__, 0, msg); \ - } - -#define PLUGIN_CUERROR(status_) \ - { \ - auto s_ = status_; \ - if (s_ != 0) \ - nvinfer1::plugin::logError(#status_ " failure.", __FILE__, FN_NAME, __LINE__); \ - } - #endif /*CHECK_MACROS_PLUGIN_H*/ diff --git a/plugin/proposalPlugin/proposalPlugin.cpp b/plugin/proposalPlugin/proposalPlugin.cpp index 6a6e48c0..fee880c1 100644 --- a/plugin/proposalPlugin/proposalPlugin.cpp +++ b/plugin/proposalPlugin/proposalPlugin.cpp @@ -514,7 +514,7 @@ void ProposalPlugin::setPluginNamespace(char const* libNamespace) noexcept { try { - PLUGIN_VALIDATE(libNamespace == nullptr); + PLUGIN_VALIDATE(libNamespace != nullptr); mNamespace = libNamespace; } catch (std::exception const& e) @@ -527,7 +527,7 @@ void ProposalDynamicPlugin::setPluginNamespace(char const* libNamespace) noexcep { try { - PLUGIN_VALIDATE(libNamespace == nullptr); + PLUGIN_VALIDATE(libNamespace != nullptr); mNamespace = libNamespace; } catch (std::exception const& e) diff --git a/plugin/proposalPlugin/proposalPlugin.h b/plugin/proposalPlugin/proposalPlugin.h index 025f1dee..fc5585ef 100644 --- a/plugin/proposalPlugin/proposalPlugin.h +++ b/plugin/proposalPlugin/proposalPlugin.h @@ -148,7 +148,7 @@ class ProposalDynamicPlugin : public IPluginV2DynamicExt char const* getPluginNamespace() const noexcept override; // IPluginV2Ext methods - DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputType, int32_t nbInputs) const noexcept override; + DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; // IPluginV2DynamicExt methods IPluginV2DynamicExt* clone() const noexcept override; diff --git a/plugin/roiAlignPlugin/roiAlignKernel.h b/plugin/roiAlignPlugin/roiAlignKernel.h index 3be3faaa..056a8c59 100644 --- a/plugin/roiAlignPlugin/roiAlignKernel.h +++ b/plugin/roiAlignPlugin/roiAlignKernel.h @@ -17,7 +17,7 @@ #ifndef TRT_ROIALIGN_KERNEL_H #define TRT_ROIALIGN_KERNEL_H -#include "common/checkMacrosPlugin.h" +#include "vc/checkMacrosPlugin.h" #include #include diff --git a/plugin/scatterElementsPlugin/scatterElementsPluginKernel.h b/plugin/scatterElementsPlugin/scatterElementsPluginKernel.h index 652ff2b5..09bfa408 100644 --- a/plugin/scatterElementsPlugin/scatterElementsPluginKernel.h +++ b/plugin/scatterElementsPlugin/scatterElementsPluginKernel.h @@ -24,8 +24,7 @@ #define TRT_SCATTER_ELEMENTS_KERNEL_PLUGIN_H #include "common/plugin.h" -#include "scatterElementsPlugin.h" -#include "scatterElementsPluginLegacy.h" +#include "scatterElementsCommon.h" namespace nvinfer1 { diff --git a/plugin/scatterElementsPlugin/scatterElementsPluginLegacy.h b/plugin/scatterElementsPlugin/scatterElementsPluginLegacy.h index a2fc4778..5c06f67f 100644 --- a/plugin/scatterElementsPlugin/scatterElementsPluginLegacy.h +++ b/plugin/scatterElementsPlugin/scatterElementsPluginLegacy.h @@ -36,45 +36,37 @@ class ScatterElementsPluginV2 final : public nvinfer1::IPluginV2DynamicExt ScatterElementsPluginV2(void const* serialData, size_t serialLength); ~ScatterElementsPluginV2() override = default; + // IPluginV2 methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int32_t getNbOutputs() const noexcept override; - - nvinfer1::DimsExprs getOutputDimensions(int32_t index, nvinfer1::DimsExprs const* inputs, int32_t nbInputDims, - nvinfer1::IExprBuilder& exprBuilder) noexcept override; - int32_t initialize() noexcept override; - void terminate() noexcept override; - - size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, - nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; - - int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, - void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - - bool supportsFormatCombination( - int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; - - char const* getPluginType() const noexcept override; - - char const* getPluginVersion() const noexcept override; - - nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - void destroy() noexcept override; + void setPluginNamespace(char const* libNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + void setClipParam(bool clip) noexcept; + void setScoreBits(int32_t scoreBits) noexcept; + void setCaffeSemantics(bool caffeSemantics) noexcept; + // IPluginV2Ext methods nvinfer1::DataType getOutputDataType( int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; - void setPluginNamespace(char const* pluginNamespace) noexcept override; - - char const* getPluginNamespace() const noexcept override; - - void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, - nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + // IPluginV2DynamicExt methods + IPluginV2DynamicExt* clone() const noexcept override; + DimsExprs getOutputDimensions( + int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination( + int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + void configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept override; + int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; private: ReductionType mReduction; diff --git a/plugin/vc/CMakeLists.txt b/plugin/vc/CMakeLists.txt new file mode 100644 index 00000000..a2ac13d7 --- /dev/null +++ b/plugin/vc/CMakeLists.txt @@ -0,0 +1,26 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +file(GLOB SRCS *.cpp) +set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS}) +set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE) +set(VFC_PLUGIN_SOURCES ${VFC_PLUGIN_SOURCES} ${SRCS}) +set(VFC_PLUGIN_SOURCES ${VFC_PLUGIN_SOURCES} PARENT_SCOPE) +file(GLOB CU_SRCS *.cu) +set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS}) +set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE) +set(VFC_PLUGIN_CU_SOURCES ${VFC_PLUGIN_CU_SOURCES} ${CU_SRCS}) +set(VFC_PLUGIN_CU_SOURCES ${VFC_PLUGIN_CU_SOURCES} PARENT_SCOPE) diff --git a/plugin/vc/checkMacrosPlugin.cpp b/plugin/vc/checkMacrosPlugin.cpp new file mode 100644 index 00000000..4b767c66 --- /dev/null +++ b/plugin/vc/checkMacrosPlugin.cpp @@ -0,0 +1,122 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "vc/checkMacrosPlugin.h" +#include "vc/vfcCommon.h" +#include +#include + +namespace nvinfer1::plugin +{ + +// This will be populated by the logger supplied by the user to initLibNvInferPlugins() +ILogger* gLogger{}; + +template +int32_t LogStream::Buf::sync() +{ + std::string s = str(); + while (!s.empty() && s.back() == '\n') + { + s.pop_back(); + } + if (gLogger != nullptr) + { + gLogger->log(tSeverity, s.c_str()); + } + str(""); + return 0; +} + +// These use gLogger, and therefore require initLibNvInferPlugins() to be called with a logger +// (otherwise, it will not log) +LogStream gLogError; +LogStream gLogWarning; +LogStream gLogInfo; +LogStream gLogVerbose; + +// break-pointable +void throwCudaError(char const* file, char const* function, int32_t line, int32_t status, char const* msg) +{ + CudaError error(file, function, line, status, msg); + error.log(gLogError); + // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference) + throw error; +} + +// break-pointable +void throwPluginError(char const* file, char const* function, int32_t line, int32_t status, char const* msg) +{ + PluginError error(file, function, line, status, msg); + reportValidationFailure(msg, file, line); + // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference) + throw error; +} + +void logError(char const* msg, char const* file, char const* fn, int32_t line) +{ + gLogError << "Parameter check failed at: " << file << "::" << fn << "::" << line; + gLogError << ", condition: " << msg << std::endl; +} + +void reportValidationFailure(char const* msg, char const* file, int32_t line) +{ + std::ostringstream stream; + stream << "Validation failed: " << msg << "\n" << file << ':' << line << "\n"; +#ifdef COMPILE_VFC_PLUGIN + ILogger* logger = getPluginLogger(); + if (logger != nullptr) + { + logger->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, stream.str().c_str()); + } +#else + getLogger()->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, stream.str().c_str()); +#endif +} + +// break-pointable +void reportAssertion(char const* msg, char const* file, int32_t line) +{ + std::ostringstream stream; + stream << "Assertion failed: " << msg << "\n" + << file << ':' << line << "\n" + << "Aborting..." + << "\n"; +#ifdef COMPILE_VFC_PLUGIN + ILogger* logger = getPluginLogger(); + if (logger != nullptr) + { + logger->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, stream.str().c_str()); + } +#else + getLogger()->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, stream.str().c_str()); +#endif + PLUGIN_CUASSERT(cudaDeviceReset()); + exit(EXIT_FAILURE); +} + +void TRTException::log(std::ostream& logStream) const +{ + logStream << file << " (" << line << ") - " << name << " Error in " << function << ": " << status; + if (message != nullptr) + { + logStream << " (" << message << ")"; + } + logStream << std::endl; +} + +} // namespace nvinfer1::plugin diff --git a/plugin/vc/checkMacrosPlugin.h b/plugin/vc/checkMacrosPlugin.h new file mode 100644 index 00000000..e5f4f66d --- /dev/null +++ b/plugin/vc/checkMacrosPlugin.h @@ -0,0 +1,244 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef VC_CHECK_MACROS_PLUGIN_H +#define VC_CHECK_MACROS_PLUGIN_H + +#include "NvInfer.h" +#include +#include + +#ifdef _MSC_VER +#define FN_NAME __FUNCTION__ +#else +#define FN_NAME __func__ +#endif + +namespace nvinfer1 +{ +namespace plugin +{ +template +class LogStream : public std::ostream +{ + class Buf : public std::stringbuf + { + public: + int32_t sync() override; + }; + + Buf buffer; + std::mutex mLogStreamMutex; + +public: + std::mutex& getMutex() + { + return mLogStreamMutex; + } + LogStream() + : std::ostream(&buffer){}; +}; + +// Use mutex to protect multi-stream write to buffer +template +LogStream& operator<<(LogStream& stream, T const& msg) +{ + std::lock_guard guard(stream.getMutex()); + auto& os = static_cast(stream); + os << msg; + return stream; +} + +// Special handling static numbers +template +inline LogStream& operator<<(LogStream& stream, int32_t num) +{ + std::lock_guard guard(stream.getMutex()); + auto& os = static_cast(stream); + os << num; + return stream; +} + +// Special handling std::endl +template +inline LogStream& operator<<(LogStream& stream, std::ostream& (*f)(std::ostream&) ) +{ + std::lock_guard guard(stream.getMutex()); + auto& os = static_cast(stream); + os << f; + return stream; +} + +extern LogStream gLogError; +extern LogStream gLogWarning; +extern LogStream gLogInfo; +extern LogStream gLogVerbose; + +void reportValidationFailure(char const* msg, char const* file, int32_t line); +void reportAssertion(char const* msg, char const* file, int32_t line); +void logError(char const* msg, char const* file, char const* fn, int32_t line); + +[[noreturn]] void throwCudaError( + char const* file, char const* function, int32_t line, int32_t status, char const* msg = nullptr); +[[noreturn]] void throwPluginError( + char const* file, char const* function, int32_t line, int32_t status, char const* msg = nullptr); + +class TRTException : public std::exception +{ +public: + TRTException(char const* fl, char const* fn, int32_t ln, int32_t st, char const* msg, char const* nm) + : file(fl) + , function(fn) + , line(ln) + , status(st) + , message(msg) + , name(nm) + { + } + virtual void log(std::ostream& logStream) const; + void setMessage(char const* msg) + { + message = msg; + } + +protected: + char const* file{nullptr}; + char const* function{nullptr}; + int32_t line{0}; + int32_t status{0}; + char const* message{nullptr}; + char const* name{nullptr}; +}; + +class CudaError : public TRTException +{ +public: + CudaError(char const* fl, char const* fn, int32_t ln, int32_t stat, char const* msg = nullptr) + : TRTException(fl, fn, ln, stat, msg, "Cuda") + { + } +}; + +class PluginError : public TRTException +{ +public: + PluginError(char const* fl, char const* fn, int32_t ln, int32_t stat, char const* msg = nullptr) + : TRTException(fl, fn, ln, stat, msg, "Plugin") + { + } +}; + +inline void caughtError(std::exception const& e) +{ + gLogError << e.what() << std::endl; +} +} // namespace plugin + +} // namespace nvinfer1 + +#define PLUGIN_API_CHECK(condition) \ + { \ + if ((condition) == false) \ + { \ + nvinfer1::plugin::logError(#condition, __FILE__, FN_NAME, __LINE__); \ + return; \ + } \ + } + +#define PLUGIN_API_CHECK_RETVAL(condition, retval) \ + { \ + if ((condition) == false) \ + { \ + nvinfer1::plugin::logError(#condition, __FILE__, FN_NAME, __LINE__); \ + return retval; \ + } \ + } + +#define PLUGIN_API_CHECK_ENUM_RANGE(Type, val) PLUGIN_API_CHECK(int32_t(val) >= 0 && int32_t(val) < EnumMax()) +#define PLUGIN_API_CHECK_ENUM_RANGE_RETVAL(Type, val, retval) \ + PLUGIN_API_CHECK_RETVAL(int32_t(val) >= 0 && int32_t(val) < EnumMax(), retval) + +#define PLUGIN_CHECK_CUDA(call) \ + do \ + { \ + cudaError_t status = call; \ + if (status != cudaSuccess) \ + { \ + return status; \ + } \ + } while (0) + +#define PLUGIN_CUASSERT(status_) \ + { \ + auto s_ = status_; \ + if (s_ != cudaSuccess) \ + { \ + char const* msg = cudaGetErrorString(s_); \ + nvinfer1::plugin::throwCudaError(__FILE__, FN_NAME, __LINE__, s_, msg); \ + } \ + } + +#define GET_MACRO(_1, _2, NAME, ...) NAME +#define PLUGIN_VALIDATE(...) GET_MACRO(__VA_ARGS__, PLUGIN_VALIDATE_MSG, PLUGIN_VALIDATE_DEFAULT, )(__VA_ARGS__) + +// Logs failed condition and throws a PluginError. +// PLUGIN_ASSERT will eventually perform this function, at which point PLUGIN_VALIDATE +// will be removed. +#define PLUGIN_VALIDATE_DEFAULT(condition) \ + { \ + if (!(condition)) \ + { \ + nvinfer1::plugin::throwPluginError(__FILE__, FN_NAME, __LINE__, 0, #condition); \ + } \ + } + +#define PLUGIN_VALIDATE_MSG(condition, msg) \ + { \ + if (!(condition)) \ + { \ + nvinfer1::plugin::throwPluginError(__FILE__, FN_NAME, __LINE__, 0, msg); \ + } \ + } + +// Logs failed assertion and aborts. +// Aborting is undesirable and will be phased-out from the plugin module, at which point +// PLUGIN_ASSERT will perform the same function as PLUGIN_VALIDATE. +#define PLUGIN_ASSERT(assertion) \ + { \ + if (!(assertion)) \ + { \ + nvinfer1::plugin::reportAssertion(#assertion, __FILE__, __LINE__); \ + } \ + } + +#define PLUGIN_FAIL(msg) \ + { \ + nvinfer1::plugin::reportAssertion(msg, __FILE__, __LINE__); \ + } + +#define PLUGIN_ERROR(msg) \ + { \ + nvinfer1::plugin::throwPluginError(__FILE__, FN_NAME, __LINE__, 0, msg); \ + } + +#define PLUGIN_CUERROR(status_) \ + { \ + auto s_ = status_; \ + if (s_ != 0) \ + nvinfer1::plugin::logError(#status_ " failure.", __FILE__, FN_NAME, __LINE__); \ + } + +#endif /*VC_CHECK_MACROS_PLUGIN_H*/ diff --git a/plugin/common/vfcCommon.cpp b/plugin/vc/vfcCommon.cpp similarity index 100% rename from plugin/common/vfcCommon.cpp rename to plugin/vc/vfcCommon.cpp diff --git a/plugin/common/vfcCommon.h b/plugin/vc/vfcCommon.h similarity index 100% rename from plugin/common/vfcCommon.h rename to plugin/vc/vfcCommon.h diff --git a/python/docstrings/infer/pyCoreDoc.h b/python/docstrings/infer/pyCoreDoc.h index 8227c87f..541e66b8 100644 --- a/python/docstrings/infer/pyCoreDoc.h +++ b/python/docstrings/infer/pyCoreDoc.h @@ -713,8 +713,7 @@ constexpr char const* descr = R"trtdoc( :ivar streamable_weights_size: Returns the size of the streamable weights in the engine. This may not include all the weights. :ivar weight_streaming_budget_v2: Set and get the current weight streaming budget for inference. The budget may be set any non-negative value. A value of 0 streams the most weights. Values equal to streamable_weights_size (default) or larger will disable weight streaming. :ivar weight_streaming_scratch_memory_size: The amount of scratch memory required by a TensorRT ExecutionContext to perform inference. This value may change based on the current weight streaming budget. Please use the V2 memory APIs, engine.device_memory_size_v2 and ExecutionContext.set_device_memory() to provide memory which includes the current weight streaming scratch memory. Not specifying these APIs or using the V1 APIs will not include this memory, so TensorRT will resort to allocating itself. - )trtdoc" - ; + )trtdoc"; // Documentation bug with parameters on these three functions because they are overloaded. constexpr char const* serialize = R"trtdoc( diff --git a/python/docstrings/infer/pyGraphDoc.h b/python/docstrings/infer/pyGraphDoc.h index 7281174b..74a651d3 100644 --- a/python/docstrings/infer/pyGraphDoc.h +++ b/python/docstrings/infer/pyGraphDoc.h @@ -102,14 +102,14 @@ constexpr const char* CHW2 = R"trtdoc( constexpr const char* HWC8 = R"trtdoc( Eight channel format where C is padded to a multiple of 8. - This format is bound to FP16. It is only available for dimensions >= 3. + This format is bound to FP16 and BF16. It is only available for dimensions >= 3. For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to the array with dimensions [N][H][W][(C+7)/8*8], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][h][w][c]. )trtdoc"; constexpr const char* CHW4 = R"trtdoc( Four wide channel vectorized row major format. - This format is bound to INT8. It is only available for dimensions >= 3. + This format is bound to FP16 and INT8. It is only available for dimensions >= 3. For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+3)/4][H][W][4], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/4][h][w][c%4]. )trtdoc"; @@ -133,7 +133,7 @@ constexpr const char* CHW32 = R"trtdoc( constexpr const char* DHWC8 = R"trtdoc( Eight channel format where C is padded to a multiple of 8. - This format is bound to FP16, and it is only available for dimensions >= 4. + This format is bound to FP16 and BF16, and it is only available for dimensions >= 4. For a tensor with dimensions {N, C, D, H, W}, the memory layout is equivalent to an array with dimensions [N][D][H][W][(C+7)/8*8], with the tensor coordinates (n, c, d, h, w) mapping to array subscript [n][d][h][w][c]. )trtdoc"; @@ -148,7 +148,7 @@ constexpr const char* CDHW32 = R"trtdoc( constexpr const char* HWC = R"trtdoc( Non-vectorized channel-last format. - This format is bound to FP32 and is only available for dimensions >= 3. + This format is bound to FP32, FP16, INT8, INT64 and BF16, and is only available for dimensions >= 3. )trtdoc"; constexpr const char* DLA_LINEAR = R"trtdoc( @@ -168,7 +168,7 @@ constexpr const char* DLA_HWC4 = R"trtdoc( )trtdoc"; constexpr const char* HWC16 = R"trtdoc( - Sixteen channel format where C is padded to a multiple of 16. This format is bound to FP16. It is only available for dimensions >= 3. + Sixteen channel format where C is padded to a multiple of 16. This format is bound to FP16 and INT8. It is only available for dimensions >= 3. For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to the array with dimensions [N][H][W][(C+15)/16*16], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][h][w][c]. )trtdoc"; @@ -347,7 +347,7 @@ namespace IConvolutionLayerDoc constexpr const char* descr = R"trtdoc( A convolution layer in an :class:`INetworkDefinition` . - This layer performs a correlation operation between 3-dimensional filter with a 4-dimensional tensor to produce another 4-dimensional tensor. + This layer performs a correlation operation between 3 or 4 dimensional filter with a 4 or 5 dimensional tensor to produce another 4 or 5 dimensional tensor. An optional bias argument is supported, which adds a per-channel constant to each value in the output. @@ -666,7 +666,8 @@ constexpr const char* SIGN constexpr const char* ROUND = R"trtdoc(Round to nearest even for floating-point data type.)trtdoc"; constexpr const char* ISINF = R"trtdoc(Return true if the input value equals +/- infinity for floating-point data type.)trtdoc"; -constexpr const char* ISNAN = R"trtdoc(Return true if the input value equals NaN for floating-point data type.)trtdoc"; +constexpr const char* ISNAN + = R"trtdoc(Return true if the input value equals NaN for floating-point data type.)trtdoc"; } // namespace UnaryOperationDoc namespace IUnaryLayerDoc diff --git a/python/packaging/frontend_sdist/setup.py b/python/packaging/frontend_sdist/setup.py index ec711fc5..bb4d9172 100644 --- a/python/packaging/frontend_sdist/setup.py +++ b/python/packaging/frontend_sdist/setup.py @@ -63,8 +63,10 @@ def find_pip(): raise RuntimeError("TensorRT currently only builds wheels for Linux and Windows") if sys.implementation.name != "cpython": raise RuntimeError("TensorRT currently only builds wheels for CPython") -if platform.machine() not in ("x86_64", "AMD64"): - raise RuntimeError("TensorRT currently only builds wheels for x86_64 processors") +if platform.machine() not in ("x86_64", "AMD64", "aarch64"): + raise RuntimeError("TensorRT currently only builds wheels for x86_64 and ARM SBSA processors") +if "tegra" in platform.release(): + raise RuntimeError("TensorRT does not currently build wheels for Tegra systems") class InstallCommand(install): diff --git a/python/packaging/requirements.txt b/python/packaging/requirements.txt index 623fd70d..5100dd2b 100644 --- a/python/packaging/requirements.txt +++ b/python/packaging/requirements.txt @@ -1,4 +1,5 @@ # Required for building wheel files wheel==0.37.1 setuptools~=59.6.0; python_version<"3.8" -setuptools~=69.0.3; python_version>="3.8" +setuptools~=69.0.3; python_version>="3.8" and python_version < "3.12" +setuptools~=73.0.1; python_version>="3.12" diff --git a/python/requirements.txt b/python/requirements.txt index 3a69d335..e0ebce33 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -3,6 +3,7 @@ numpy==1.18.1; python_version < "3.8" and platform_system == "Windows" numpy==1.19.4; python_version < "3.8" and platform_system != "Windows" numpy==1.23.0; python_version >= "3.8" and python_version < "3.10" -numpy==1.23.1; python_version >= "3.10" +numpy==1.23.1; python_version >= "3.10" and python_version < "3.12" +numpy==1.26.4; python_version >= "3.12" Pillow; python_version<"3.6" ##PYTHON_BUILDDIR##/tensorrt_bindings-py3.##PYTHON3_MINOR##/dist/tensorrt-##TENSORRT_PYTHON_VERSION##-cp3##PYTHON3_MINOR##-none-linux_##TARGET##.whl ; python_version=="3.##PYTHON3_MINOR##" diff --git a/python/src/infer/pyCore.cpp b/python/src/infer/pyCore.cpp index 58c17892..c284d1d1 100644 --- a/python/src/infer/pyCore.cpp +++ b/python/src/infer/pyCore.cpp @@ -1096,8 +1096,7 @@ void bindCore(py::module& m) IExecutionContextDoc::set_tensor_debug_state) .def("get_debug_state", &IExecutionContext::getDebugState, "name"_a, IExecutionContextDoc::get_debug_state) .def("set_all_tensors_debug_state", &IExecutionContext::setAllTensorsDebugState, "flag"_a, - IExecutionContextDoc::set_all_tensors_debug_state) - ; + IExecutionContextDoc::set_all_tensors_debug_state); py::enum_(m, "ExecutionContextAllocationStrategy", py::arithmetic{}, ExecutionContextAllocationStrategyDoc::descr, py::module_local()) @@ -1293,7 +1292,6 @@ void bindCore(py::module& m) "weight_streaming_scratch_memory_size", &ICudaEngine::getWeightStreamingScratchMemorySize) // End weight streaming APIs .def("is_debug_tensor", &ICudaEngine::isDebugTensor, "name"_a, ICudaEngineDoc::is_debug_tensor) - .def("__del__", &utils::doNothingDel); py::enum_(m, "AllocatorFlag", py::arithmetic{}, AllocatorFlagDoc::descr, py::module_local()) diff --git a/python/src/infer/pyPlugin.cpp b/python/src/infer/pyPlugin.cpp index b1e37967..0396fee7 100644 --- a/python/src/infer/pyPlugin.cpp +++ b/python/src/infer/pyPlugin.cpp @@ -740,7 +740,7 @@ class IPluginCreatorImpl : public IPluginCreator } private: - nvinfer1::PluginFieldCollection mFC; + nvinfer1::PluginFieldCollection mFC{}; std::string mNamespace; std::string mName; std::string mPluginVersion; @@ -1824,7 +1824,7 @@ class IPluginCreatorV3OneImpl : public IPluginCreatorV3One } private: - nvinfer1::PluginFieldCollection mFC; + nvinfer1::PluginFieldCollection mFC{}; std::string mNamespace; std::string mName; std::string mPluginVersion; diff --git a/quickstart/IntroNotebooks/onnx_helper.py b/quickstart/IntroNotebooks/onnx_helper.py index d5b898cf..13b24804 100644 --- a/quickstart/IntroNotebooks/onnx_helper.py +++ b/quickstart/IntroNotebooks/onnx_helper.py @@ -69,12 +69,12 @@ def predict(self, batch): # result gets copied into output return self.output -def convert_onnx_to_engine(onnx_filename, engine_filename = None, max_batch_size = 32, max_workspace_size = 1 << 30, fp16_mode = True): +def convert_onnx_to_engine(onnx_filename, engine_filename = None, max_workspace_size = 1 << 30, fp16_mode = True): logger = trt.Logger(trt.Logger.WARNING) - with trt.Builder(logger) as builder, builder.create_network() as network, trt.OnnxParser(network, logger) as parser: - builder.max_workspace_size = max_workspace_size - builder.fp16_mode = fp16_mode - builder.max_batch_size = max_batch_size + with trt.Builder(logger) as builder, builder.create_network() as network, trt.OnnxParser(network, logger) as parser, builder.create_builder_config() as builder_config: + builder_config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size) + if (fp16_mode): + builder_config.set_flag(trt.BuilderFlag.FP16) print("Parsing ONNX file.") with open(onnx_filename, 'rb') as model: @@ -83,10 +83,10 @@ def convert_onnx_to_engine(onnx_filename, engine_filename = None, max_batch_size print(parser.get_error(error)) print("Building TensorRT engine. This may take a few minutes.") - engine = builder.build_cuda_engine(network) + serialized_engine = builder.build_serialized_network(network, builder_config) if engine_filename: with open(engine_filename, 'wb') as f: - f.write(engine.serialize()) + f.write(serialized_engine) - return engine, logger + return serialized_engine, logger diff --git a/samples/common/safeCommon.h b/samples/common/safeCommon.h index f10aad18..9bffdb2d 100644 --- a/samples/common/safeCommon.h +++ b/samples/common/safeCommon.h @@ -43,6 +43,8 @@ #include #endif // IS_QNX_SAFE +using namespace nvinfer1; + #undef CHECK #define CHECK(status) \ do \ @@ -66,6 +68,52 @@ } \ } while (0) +#define LWE_CALL(api_call, recorder) \ + do \ + { \ + const ErrorCode ret = (api_call); \ + if (ret != ErrorCode::kSUCCESS) \ + { \ + std::cerr << "LWE Error: [" << #api_call << "]: " << toString(ret); \ + throw ret; \ + } \ + std::cout << "LWE:[" << #api_call << "]: PASSED"; \ + } while (0) + +#define CUDA_CALL(cuda_api_call, recorder) \ + do \ + { \ + cudaError_t error = (cuda_api_call); \ + if (error != cudaSuccess) \ + { \ + std::cerr << "CUDA Error: [" << #cuda_api_call << "]: " << cudaGetErrorString(error); \ + throw ErrorCode::kFAILED_EXECUTION; \ + } \ + std::cout << "CUDA:[" << #cuda_api_call << "]: PASSED"; \ + } while (0) + +inline std::string toString(ErrorCode ec) +{ + static const auto ecStrings = [] { + std::unordered_map result; +#define INSERT_ELEMENT(p, s) result.emplace(p, s); + INSERT_ELEMENT(ErrorCode::kSUCCESS, "SUCCESS") + INSERT_ELEMENT(ErrorCode::kUNSPECIFIED_ERROR, "UNSPECIFIED_ERROR") + INSERT_ELEMENT(ErrorCode::kINTERNAL_ERROR, "INTERNAL_ERROR") + INSERT_ELEMENT(ErrorCode::kINVALID_ARGUMENT, "INVALID_ARGUMENT") + INSERT_ELEMENT(ErrorCode::kINVALID_CONFIG, "INVALID_CONFIG") + INSERT_ELEMENT(ErrorCode::kFAILED_ALLOCATION, "FAILED_ALLOCATION") + INSERT_ELEMENT(ErrorCode::kFAILED_INITIALIZATION, "FAILED_INITIALIZATION") + INSERT_ELEMENT(ErrorCode::kFAILED_EXECUTION, "FAILED_EXECUTION") + INSERT_ELEMENT(ErrorCode::kFAILED_COMPUTATION, "FAILED_COMPUTATION") + INSERT_ELEMENT(ErrorCode::kINVALID_STATE, "INVALID_STATE") + INSERT_ELEMENT(ErrorCode::kUNSUPPORTED_STATE, "UNSUPPORTED_STATE") +#undef INSERT_ELEMENT + return result; + }(); + return ecStrings.at(ec); +} + //! Locate path to file, given its filename or filepath suffix and possible dirs it might lie in. //! Function will also walk back MAX_DEPTH dirs from CWD to check for such a file path. inline std::string locateFile( @@ -219,8 +267,7 @@ inline int32_t getSMVersion() inline bool isSMSafe() { const int32_t smVersion = getSMVersion(); - return smVersion == 0x0700 || smVersion == 0x0705 || smVersion == 0x0800 || smVersion == 0x0806 - || smVersion == 0x0807; + return smVersion == 0x0705 || smVersion == 0x0800 || smVersion == 0x0806 || smVersion == 0x0807; } inline int32_t calculateSoftmax(float* const prob, int32_t const numDigits) diff --git a/samples/python/aliased_io_plugin/aliased_io_plugin.py b/samples/python/aliased_io_plugin/aliased_io_plugin.py old mode 100644 new mode 100755 diff --git a/samples/python/efficientnet/README.md b/samples/python/efficientnet/README.md index cd5dc399..15989257 100644 --- a/samples/python/efficientnet/README.md +++ b/samples/python/efficientnet/README.md @@ -24,6 +24,8 @@ August 2023: ## Setup +Note: The sample is not compatible with Python-3.12 because tensorflow-addons does not support Python-3.12. + For best results, we recommend running these scripts on an environment with TensorRT >= 8.0.1 and TensorFlow 2.12.0. Install TensorRT as per the [TensorRT Install Guide](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html). You will need to make sure the Python bindings for TensorRT are also installed correctly, these are available by installing the `python3-libnvinfer` and `python3-libnvinfer-dev` packages on your TensorRT download. diff --git a/samples/python/network_api_pytorch_mnist/requirements.txt b/samples/python/network_api_pytorch_mnist/requirements.txt index f96acc70..153cbfbd 100644 --- a/samples/python/network_api_pytorch_mnist/requirements.txt +++ b/samples/python/network_api_pytorch_mnist/requirements.txt @@ -1,10 +1,10 @@ Pillow>=10.0.0 -f https://download.pytorch.org/whl/torch_stable.html torch==2.0.0; (platform_machine=="aarch64" and sys.platform=="linux") -torch==2.0.0+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") +torch==2.2.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") -f https://download.pytorch.org/whl/torch_stable.html torchvision==0.15.1; (platform_machine=="aarch64" and sys.platform=="linux") -torchvision==0.15.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") +torchvision==0.17.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") cuda-python==12.2.0; python_version <= "3.10" cuda-python==12.5.0; python_version >= "3.11" pywin32; platform_system == "Windows" diff --git a/samples/python/onnx_packnet/requirements.txt b/samples/python/onnx_packnet/requirements.txt index e88f8646..1cb16357 100644 --- a/samples/python/onnx_packnet/requirements.txt +++ b/samples/python/onnx_packnet/requirements.txt @@ -3,10 +3,10 @@ onnx==1.16.0 onnx-graphsurgeon>=0.3.20 -f https://download.pytorch.org/whl/torch_stable.html torch==2.0.0; (platform_machine=="aarch64" and sys.platform=="linux") -torch==2.0.0+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") +torch==2.2.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") -f https://download.pytorch.org/whl/torch_stable.html torchvision==0.15.1; (platform_machine=="aarch64" and sys.platform=="linux") -torchvision==0.15.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") +torchvision==0.17.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") pyyaml==6.0.1 requests==2.32.2 tqdm==4.66.4 diff --git a/samples/python/python_plugin/circ_pad_plugin_cpp.py b/samples/python/python_plugin/circ_pad_plugin_cpp.py old mode 100644 new mode 100755 diff --git a/samples/python/python_plugin/circ_pad_plugin_cuda_python.py b/samples/python/python_plugin/circ_pad_plugin_cuda_python.py old mode 100644 new mode 100755 diff --git a/samples/python/python_plugin/circ_pad_plugin_cupy.py b/samples/python/python_plugin/circ_pad_plugin_cupy.py old mode 100644 new mode 100755 diff --git a/samples/python/tensorflow_object_detection_api/README.md b/samples/python/tensorflow_object_detection_api/README.md index f566304b..789eead3 100644 --- a/samples/python/tensorflow_object_detection_api/README.md +++ b/samples/python/tensorflow_object_detection_api/README.md @@ -7,6 +7,8 @@ Support for [TensorFlow Object Detection (TFOD) API](https://github.com/tensorfl ### TensorFlow and TensorRT Environment +Note: The sample is not compatible with Python-3.12 because tensorflow-addons does not support Python-3.12. + In order for scripts to work we suggest an environment with TensorRT >= 8.0.1 and TensorFlow 2.13.1. Install TensorRT as per the [TensorRT Install Guide](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html). You will need to make sure the Python bindings for TensorRT are also installed correctly, these are available by installing the `python3-libnvinfer` and `python3-libnvinfer-dev` packages on your TensorRT download. diff --git a/tools/experimental/trt-engine-explorer/bin/trex b/tools/experimental/trt-engine-explorer/bin/trex old mode 100644 new mode 100755 diff --git a/tools/experimental/trt-engine-explorer/trex/interactive.py b/tools/experimental/trt-engine-explorer/trex/interactive.py index ae7b429a..53d0b80e 100644 --- a/tools/experimental/trt-engine-explorer/trex/interactive.py +++ b/tools/experimental/trt-engine-explorer/trex/interactive.py @@ -77,4 +77,4 @@ def _render(self, title, renderer): def dropdown_state_eventhandler(self, user_choice): renderer = self.choices[user_choice] - self._render(user_choice, renderer) + self._render(user_choice, renderer) \ No newline at end of file diff --git a/tools/experimental/trt-engine-explorer/trex/misc.py b/tools/experimental/trt-engine-explorer/trex/misc.py index 7aa00e5a..92f8d36c 100644 --- a/tools/experimental/trt-engine-explorer/trex/misc.py +++ b/tools/experimental/trt-engine-explorer/trex/misc.py @@ -81,4 +81,4 @@ def stack_dataframes( # A list of values lists. values = [df[values_col].tolist() for df in df_list] - return _merge_keys_values(names, values, empty_placeholder) + return _merge_keys_values(names, values, empty_placeholder) \ No newline at end of file diff --git a/tools/experimental/trt-engine-explorer/utils/config_gpu.py b/tools/experimental/trt-engine-explorer/utils/config_gpu.py old mode 100644 new mode 100755 diff --git a/tools/experimental/trt-engine-explorer/utils/device_info.py b/tools/experimental/trt-engine-explorer/utils/device_info.py old mode 100644 new mode 100755 diff --git a/tools/experimental/trt-engine-explorer/utils/draw_engine.py b/tools/experimental/trt-engine-explorer/utils/draw_engine.py old mode 100644 new mode 100755 diff --git a/tools/experimental/trt-engine-explorer/utils/parse_trtexec_log.py b/tools/experimental/trt-engine-explorer/utils/parse_trtexec_log.py old mode 100644 new mode 100755 index f01d97e3..ab93b159 --- a/tools/experimental/trt-engine-explorer/utils/parse_trtexec_log.py +++ b/tools/experimental/trt-engine-explorer/utils/parse_trtexec_log.py @@ -162,4 +162,4 @@ def post_process_device_info(device_info: dict): parser = argparse.ArgumentParser() parser.add_argument('input', help="name of engine build log file to parse.") args = parser.parse_args() - parse_build_log(args.input) + parse_build_log(args.input) \ No newline at end of file diff --git a/tools/experimental/trt-engine-explorer/utils/process_engine.py b/tools/experimental/trt-engine-explorer/utils/process_engine.py old mode 100644 new mode 100755