"Null void* workspace Parameter in Custom Plugin's enqueue Function in TensorRT 10 with CUDA 11.6" #4263

Closed
make-a opened this issue Nov 27, 2024 · 2 comments

Comments

@make-a

make-a commented Nov 27, 2024

Description

I am using C++ with TensorRT 10 and CUDA 11.6, and I have created a custom plugin. However, when using it, I found that the parameter void* workspace in the function FpsamplePlugin::enqueue is null. Why is this happening?

At the same time, I ran the TensorRT sample code sample_non_zero_plugin, and in that case, the void* workspace parameter is fine. So, I think the issue might be with the definition of my plugin. However, I've searched for a long time and couldn't resolve it.


########## header file:

#pragma once
#include <cstdint>
#include <vector>

#include <NvInferPlugin.h>
#include <cuda_runtime_api.h>

using namespace nvinfer1;

namespace nvinfer1 {

class FpsamplePlugin : public nvinfer1::IPluginV3,
public nvinfer1::IPluginV3OneCore,
public nvinfer1::IPluginV3OneBuild,
public nvinfer1::IPluginV3OneRuntime {

public:
FpsamplePlugin( int32_t nsample );

// Inherited from IPluginV3
IPluginCapability* getCapabilityInterface( PluginCapabilityType type ) noexcept override;
IPluginV3*         clone() noexcept override;

// Inherited from IPluginV3OneCore
AsciiChar const* getPluginName() const noexcept override;
AsciiChar const* getPluginVersion() const noexcept override;
AsciiChar const* getPluginNamespace() const noexcept override;

// Inherited from IPluginV3OneBuild
int32_t configurePlugin( DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out,
                         int32_t nbOutputs ) noexcept override;
int32_t getOutputDataTypes( DataType* outputTypes, int32_t nbOutputs, const DataType* inputTypes,
                            int32_t nbInputs ) const noexcept override;
int32_t getOutputShapes( DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
                         int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
                         IExprBuilder& exprBuilder ) noexcept override;
bool    supportsFormatCombination( int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs,
                                   int32_t nbOutputs ) noexcept override;
int32_t getNbOutputs() const noexcept override;

// Inherited from IPluginV3OneRuntime
int32_t onShapeChange( PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out,
                       int32_t nbOutputs ) noexcept override;
int32_t enqueue( PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs,
                 void* const* outputs, void* workspace, cudaStream_t stream ) noexcept override;
IPluginV3*                   attachToContext( IPluginResourceContext* context ) noexcept override;
PluginFieldCollection const* getFieldsToSerialize() noexcept override;

private:
int32_t m_nsample;
PluginFieldCollection m_pfc;
std::vector<PluginField> m_pluginAttributes;

int32_t m_in_n  = 1;
int32_t m_out_n = 1;

};

class FpsamplePluginCreator : public nvinfer1::IPluginCreatorV3One {
public:
FpsamplePluginCreator();

// Inherited from IPluginCreatorV3One
IPluginV3* createPlugin( AsciiChar const* name, PluginFieldCollection const* fc,
                         TensorRTPhase phase ) noexcept override;

PluginFieldCollection const* getFieldNames() noexcept override;
AsciiChar const*             getPluginName() const noexcept override;
AsciiChar const*             getPluginVersion() const noexcept override;
AsciiChar const*             getPluginNamespace() const noexcept override;

private:
PluginFieldCollection m_pfc;
std::vector<PluginField> m_pluginAttributes;
};
} // namespace nvinfer1

########## cpp file:

#include "FpsamplePlugin.h"

#include <cassert>

#include "../cuda_impl/cuda_impl.h"

namespace {
char const* const FPSAMPLE_PLUGIN_VERSION{ "1" };
char const* const FPSAMPLE_PLUGIN_NAME{ "KDTreeFpsample" };
char const* const FPSAMPLE_PLUGIN_NAMESPACE{ "" };
} // namespace

nvinfer1::FpsamplePlugin::FpsamplePlugin( int32_t nsample )
: m_nsample( nsample ) {
m_pluginAttributes.clear();

PluginField pf_nsample = { "nsample", &m_nsample, PluginFieldType::kINT32, 1 };
m_pluginAttributes.emplace_back( pf_nsample );

m_pfc.nbFields = m_pluginAttributes.size();
m_pfc.fields   = m_pluginAttributes.data();

}

IPluginCapability* FpsamplePlugin::getCapabilityInterface( PluginCapabilityType type ) noexcept {
try {
if ( type == PluginCapabilityType::kBUILD ) {
return static_cast<IPluginV3OneBuild*>( this );
}
if ( type == PluginCapabilityType::kRUNTIME ) {
return static_cast<IPluginV3OneRuntime*>( this );
}
assert( type == PluginCapabilityType::kCORE );
return static_cast<IPluginV3OneCore*>( this );
}
catch ( ... ) {
// log error
}
return nullptr;
}

IPluginV3* FpsamplePlugin::clone() noexcept {
return new FpsamplePlugin( m_nsample );
}

AsciiChar const* FpsamplePlugin::getPluginName() const noexcept {
return FPSAMPLE_PLUGIN_NAME;
}

AsciiChar const* FpsamplePlugin::getPluginVersion() const noexcept {
return FPSAMPLE_PLUGIN_VERSION;
}

AsciiChar const* FpsamplePlugin::getPluginNamespace() const noexcept {
return FPSAMPLE_PLUGIN_NAMESPACE;
}

int32_t FpsamplePlugin::configurePlugin( DynamicPluginTensorDesc const* in, int32_t nbInputs,
DynamicPluginTensorDesc const* out, int32_t nbOutputs ) noexcept {
return 0;
}

int32_t FpsamplePlugin::getOutputDataTypes( DataType* outputTypes, int32_t nbOutputs, const DataType* inputTypes,
int32_t nbInputs ) const noexcept {
// Ensure the input/output counts match expectations
assert( nbInputs == m_in_n );
assert( nbOutputs == m_out_n );

// Check that the input type is float
assert( inputTypes[ 0 ] == DataType::kFLOAT );

// Set the output type to int64
outputTypes[ 0 ] = DataType::kINT64;

return 0;

}

int32_t FpsamplePlugin::getOutputShapes( const DimsExprs* inputs, int32_t nbInputs, const DimsExprs* shapeInputs,
int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
IExprBuilder& exprBuilder ) noexcept {
// Ensure the input/output counts match expectations
assert( nbInputs == m_in_n );
assert( nbOutputs == m_out_n );

// Shape of the input tensor
const DimsExprs& inputShape = inputs[ 0 ];

// Set the output tensor shape, assumed to be [BATCH, NUM_POINTS]
outputs[ 0 ].nbDims = 2;                                  // two output dimensions
outputs[ 0 ].d[ 0 ] = inputShape.d[ 0 ];                  // BATCH
outputs[ 0 ].d[ 1 ] = exprBuilder.constant( m_nsample );  // NUM_POINTS

return 0;  // 0 indicates success

}

bool FpsamplePlugin::supportsFormatCombination( int32_t pos, const DynamicPluginTensorDesc* inOut, int32_t nbInputs,
int32_t nbOutputs ) noexcept {
// Ensure pos is within the valid range
assert( pos < ( nbInputs + nbOutputs ) );

// Check the input
if ( pos == 0 ) {
    // The input must be float with linear format
    return inOut[ pos ].desc.type == DataType::kFLOAT && inOut[ pos ].desc.format == TensorFormat::kLINEAR;
}
// Check the output
else if ( pos == 1 ) {
    // The output must be int64 with linear format
    return inOut[ pos ].desc.type == DataType::kINT64 && inOut[ pos ].desc.format == TensorFormat::kLINEAR;
}

return false;  // any other combination is unsupported

}

int32_t FpsamplePlugin::getNbOutputs() const noexcept {
return m_out_n;
}

int32_t FpsamplePlugin::onShapeChange( PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out,
int32_t nbOutputs ) noexcept {
return 0;
}

int32_t FpsamplePlugin::enqueue( PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream ) noexcept {

const float* points = static_cast<const float*>( inputs[ 0 ] );
int64_t*     idxs   = static_cast<int64_t*>( outputs[ 0 ] );

// Extract dimension information from the PluginTensorDescs
int b = inputDesc[ 0 ].dims.d[ 0 ];
int n = inputDesc[ 0 ].dims.d[ 1 ];
int m = outputDesc[ 0 ].dims.d[ 1 ];

// Compute the required buffer sizes (in bytes)
size_t tempSize  = b * n * sizeof( float );  // size of temp
size_t idxsSize  = b * m * sizeof( int );    // size of the intermediate int indices
size_t totalSize = tempSize + idxsSize;

if ( workspace == nullptr ) {
    return -1;
}
// Carve the temp buffer out of the workspace
float* temp = static_cast<float*>( workspace );
// Note: cudaMemsetAsync fills bytes, not floats
cudaError_t err = cudaMemsetAsync( temp, 110, 10 * sizeof( float ), stream );
if ( err != cudaSuccess ) {
    return -1;
}

// Copy the temp buffer back to the host for debugging (tempSize is already in bytes)
float* debug_data = new float[ b * n ];
cudaMemcpy( debug_data, temp, tempSize, cudaMemcpyDeviceToHost );
delete[] debug_data;

// Launch the kernel
furthest_point_sampling_kernel_wrapper( b, n, m, points, temp, idxs, stream );

// Copy the indices back to the host for debugging (device -> host, so the host buffer is the destination)
int64_t* idxs_d = new int64_t[ m * b ];
cudaMemcpy( idxs_d, idxs, b * m * sizeof( int64_t ), cudaMemcpyDeviceToHost );
delete[] idxs_d;

return 0;

}

IPluginV3* FpsamplePlugin::attachToContext( IPluginResourceContext* context ) noexcept {
return clone();
}

PluginFieldCollection const* FpsamplePlugin::getFieldsToSerialize() noexcept {
return &m_pfc;
}

/// /////////////////////////////////////////////////////
/// FpsamplePluginCreator

FpsamplePluginCreator::FpsamplePluginCreator() {
m_pluginAttributes.clear();

PluginField pf_nsample = { "nsample", nullptr, PluginFieldType::kINT32, 1 };
m_pluginAttributes.emplace_back( pf_nsample );

m_pfc.nbFields = m_pluginAttributes.size();
m_pfc.fields   = m_pluginAttributes.data();

}

IPluginV3* FpsamplePluginCreator::createPlugin( AsciiChar const* name, PluginFieldCollection const* fc,
TensorRTPhase phase ) noexcept {
assert( fc->nbFields == 1 );
assert( fc->fields[ 0 ].type == PluginFieldType::kINT32 );

FpsamplePlugin* plugin = new FpsamplePlugin( *static_cast<int32_t const*>( fc->fields[ 0 ].data ) );

return plugin;

}

PluginFieldCollection const* FpsamplePluginCreator::getFieldNames() noexcept {
return &m_pfc;
}

AsciiChar const* FpsamplePluginCreator::getPluginName() const noexcept {
return FPSAMPLE_PLUGIN_NAME;
}

AsciiChar const* FpsamplePluginCreator::getPluginVersion() const noexcept {
return FPSAMPLE_PLUGIN_VERSION;
}

AsciiChar const* FpsamplePluginCreator::getPluginNamespace() const noexcept {
return FPSAMPLE_PLUGIN_NAMESPACE;
}

Environment

TensorRT Version: TensorRT-10.6.0.26
NVIDIA GPU: NVIDIA GeForce RTX 4050 Laptop GPU
NVIDIA Driver Version: 31.0.15.3227
Driver Date: 2023/7/9
DirectX Version: 12 (FL 12.1)
CUDA Version: CUDA 11.6
Operating System: Windows 11

@lix19937

You should implement size_t getWorkspaceSize(int32_t) const noexcept override;

@make-a
Author

make-a commented Nov 28, 2024

You should implement size_t getWorkspaceSize(int32_t) const noexcept override;

OK, thanks.
If you use

class FpsamplePlugin : public nvinfer1::IPluginV3,
                       public nvinfer1::IPluginV3OneCore,
                       public nvinfer1::IPluginV3OneBuild,
                       public nvinfer1::IPluginV3OneRuntime

then you need this override:

    size_t  getWorkspaceSize( DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
                              DynamicPluginTensorDesc const* outputs, int32_t nbOutputs ) const noexcept override;

ヾ(๑╹◡╹)ノ"
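
For reference, a minimal sketch of what that override could look like for this plugin. The buffer layout simply mirrors the sizes computed in enqueue above (one float temp buffer of shape [B, N] plus an int index buffer of shape [B, M]); sizing from the max dimensions is an assumption for illustration, not the actual fix used:

    size_t FpsamplePlugin::getWorkspaceSize( DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
                                             DynamicPluginTensorDesc const* outputs, int32_t nbOutputs ) const noexcept {
        // Size the workspace for the largest shapes the plugin may see at runtime (assumption).
        int64_t b = inputs[ 0 ].max.d[ 0 ];
        int64_t n = inputs[ 0 ].max.d[ 1 ];
        int64_t m = outputs[ 0 ].max.d[ 1 ];
        return static_cast<size_t>( b * n ) * sizeof( float )   // temp distance buffer
             + static_cast<size_t>( b * m ) * sizeof( int );    // intermediate index buffer
    }

Once the reported workspace size is non-zero, TensorRT allocates that buffer at runtime and passes a valid pointer into enqueue.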

@make-a make-a closed this as completed Nov 28, 2024