This repository shows:
(1) how to add a custom TensorRT plugin in C++,
(2) how to build and serialize a network that uses the custom plugin in Python, and
(3) how to load the serialized network and run a forward pass in C++.
We follow TensorRT's built-in FlattenConcat plugin to create our own flattenConcat plugin.
Because a FlattenConcat plugin already ships with TensorRT, we renamed the class (the custom plugin is registered as FlattenConcatCustom) to avoid a name clash. The source code is in flattenConcatCustom.cpp and flattenConcatCustom.h, and CMakeLists.txt builds it into the shared library libflatten_concat.so.
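The plugin flattens every input and concatenates the flattened results. The Python sketch below models the output dimensions such a plugin reports, assuming flattenConcatCustom.cpp keeps the behaviour of TensorRT's stock FlattenConcat: sum the C*H*W sizes of all inputs and place that sum on the concat axis, with 1 on the other two axes. `flatten_concat_output_dims` is a hypothetical helper, not code from this repository.

```python
from typing import Sequence, Tuple

# (C, H, W); the batch dimension is excluded, as in TensorRT's implicit-batch mode
Dims = Tuple[int, int, int]


def flatten_concat_output_dims(inputs: Sequence[Dims], axis: int = 1) -> Dims:
    """Hypothetical model of a FlattenConcat-style plugin's output dims.

    Each CHW input is flattened to C*H*W values; the flattened sizes are
    summed and the sum is placed on `axis` (1=C, 2=H, 3=W). Assumption:
    this mirrors the stock FlattenConcat shape logic.
    """
    if not inputs:
        raise ValueError("need at least one input")
    if axis not in (1, 2, 3):
        raise ValueError("axis must be 1, 2 or 3")
    total = 0
    for dims in inputs:
        if len(dims) != 3:
            raise ValueError(f"expected CHW dims, got {dims}")
        c, h, w = dims
        total += c * h * w  # size of this input once flattened
    out = [1, 1, 1]
    out[axis - 1] = total  # concatenated size goes on the chosen axis
    return (out[0], out[1], out[2])
```

With the shapes used in builder.py, (4, 2, 2) and (2, 2, 2) with axis=1, this gives (24, 1, 1): 16 + 8 flattened values on the channel axis.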
To build and serialize the network in Python, please follow builder.py. Its main steps are excerpted below.
```python
import ctypes

import numpy as np
import tensorrt as trt

# You should configure the path to libnvinfer_plugin.so
nvinfer = ctypes.CDLL("/path-to-tensorrt/TensorRT-6.0.1.5/lib/libnvinfer_plugin.so", mode=ctypes.RTLD_GLOBAL)
print('load nvinfer')
# Loading the custom library makes the FlattenConcatCustom creator available to TensorRT
pg = ctypes.CDLL("./libflatten_concat.so", mode=ctypes.RTLD_GLOBAL)
print('load custom plugin')

# TensorRT initialization
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
plg_registry = trt.get_plugin_registry()

# Creating the plugin below calls the constructor, see
# https://github.com/YirongMao/TensorRT-Custom-Plugin/blob/master/flattenConcatCustom.cpp#L36
plg_creator = plg_registry.get_plugin_creator("FlattenConcatCustom", "1", "")
print(plg_creator)
axis_pf = trt.PluginField("axis", np.array([1], np.int32), trt.PluginFieldType.INT32)
batch_pf = trt.PluginField("ignoreBatch", np.array([0], np.int32), trt.PluginFieldType.INT32)
pfc = trt.PluginFieldCollection([axis_pf, batch_pf])
fn = plg_creator.create_plugin("FlattenConcatCustom1", pfc)
print(fn)

builder = trt.Builder(TRT_LOGGER)
network = builder.create_network()
input_1 = network.add_input(name="input_1", dtype=trt.float32, shape=(4, 2, 2))
input_2 = network.add_input(name="input_2", dtype=trt.float32, shape=(2, 2, 2))
inputs = [input_1, input_2]
# Add the plugin layer; TensorRT later calls its configurePlugin(), see
# https://github.com/YirongMao/TensorRT-Custom-Plugin/blob/master/flattenConcatCustom.cpp#L258
emb_layer = network.add_plugin_v2(inputs, fn)
```
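Once an engine containing this layer is built, a host-side reference is handy for checking its output. The NumPy sketch below is a hypothetical helper (`flatten_concat_reference` is not part of this repository) and assumes that, with `ignoreBatch=0` as set above, each batch item is flattened and concatenated independently, as in TensorRT's stock FlattenConcat.

```python
import numpy as np


def flatten_concat_reference(inputs, axis=1):
    """Hypothetical NumPy reference for a FlattenConcat-style plugin.

    Each input has shape (N, C, H, W). Every batch item is flattened to
    C*H*W values, the flattened inputs are concatenated per batch item,
    and the concatenated size is placed on `axis` (1=C, 2=H, 3=W).
    """
    if axis not in (1, 2, 3):
        raise ValueError("axis must be 1, 2 or 3")
    batch = inputs[0].shape[0]
    for x in inputs:
        if x.ndim != 4 or x.shape[0] != batch:
            raise ValueError("all inputs must be NCHW with the same batch size")
    flat = [x.reshape(batch, -1) for x in inputs]  # (N, C*H*W) per input
    out = np.concatenate(flat, axis=1)             # (N, sum of C*H*W)
    shape = [batch, 1, 1, 1]
    shape[axis] = out.shape[1]                     # concatenated size on the chosen axis
    return out.reshape(shape)
```

For example, with batch size 1 and the input_1/input_2 shapes above, the reference output has shape (1, 24, 1, 1), holding the 16 values of input_1 followed by the 8 values of input_2.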
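To load the serialized engine and run it in C++, please follow load_trt_engine.cpp. To deserialize an engine that contains the custom plugin, the loader must include the plugin header (flattenConcatCustom.h) so that the plugin is available to the TensorRT runtime.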
Environment: TensorRT-6.0.1.5, CUDA 9.0.
If you encounter any problems, feel free to open an issue.