Skip to content

Commit

Permalink
Restore lReLUPlugin in OSS
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
  • Loading branch information
rajeevsrao committed Sep 4, 2020
1 parent 300cd20 commit 275eefc
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 0 deletions.
1 change: 1 addition & 0 deletions plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ set(PLUGIN_LISTS
instanceNormalizationPlugin
groupNormalizationPlugin
coordConvACPlugin
leakyReluPlugin
)

# Add BERT sources if ${BERT_GENCODES} was populated
Expand Down
2 changes: 2 additions & 0 deletions plugin/InferPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ using namespace nvinfer1::plugin;
#include "generateDetectionPlugin.h"
#include "gridAnchorPlugin.h"
#include "instanceNormalizationPlugin.h"
#include "lReluPlugin.h"
#include "multilevelCropAndResizePlugin.h"
#include "multilevelProposeROIPlugin.h"
#include "nmsPlugin.h"
Expand Down Expand Up @@ -167,6 +168,7 @@ bool initLibNvInferPlugins(void* logger, const char* libNamespace)
initializePlugin<nvinfer1::plugin::GridAnchorPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::GridAnchorRectPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::InstanceNormalizationPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::LReluPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::MultilevelCropAndResizePluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::MultilevelProposeROIPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::NMSPluginCreator>(logger, libNamespace);
Expand Down
1 change: 1 addition & 0 deletions plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
| [gridAnchorRectPlugin](gridAnchorPlugin) | GridAnchorRect_TRT | 1 |
| [groupNormalizationPlugin](groupNormalizationPlugin) | GroupNormalizationPlugin | 1 |
| [instanceNormalizationPlugin](instanceNormalizationPlugin) | InstanceNormalization_TRT | 1 |
| [leakyReluPlugin](leakyReluPlugin) | LReLU_TRT | 1 |
| [multilevelCropAndResizePlugin](multilevelCropAndResizePlugin) | MultilevelCropAndResize_TRT | 1 |
| [multilevelProposeROI](multilevelProposeROI) | MultilevelProposeROI_TRT | 1 |
| [nmsPlugin](nmsPlugin) | NMS_TRT | 1 |
Expand Down
3 changes: 3 additions & 0 deletions plugin/common/kernels/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ pluginStatus_t normalizeInference(cudaStream_t stream, cublasHandle_t handle, bo
pluginStatus_t priorBoxInference(cudaStream_t stream, PriorBoxParameters param, int H, int W, int numPriors,
int numAspectRatios, const void* minSize, const void* maxSize, const void* aspectRatios, void* outputData);

// Leaky ReLU over `n` contiguous float elements, enqueued on `stream`:
//   output[i] = input[i] > 0 ? input[i] : input[i] * negativeSlope
// `input` and `output` are reinterpreted as float arrays and dereferenced by the GPU kernel,
// so both must be accessible from device code.
pluginStatus_t lReLUInference(
cudaStream_t stream, int n, float negativeSlope, const void* input, void* output);

pluginStatus_t reorgInference(
cudaStream_t stream, int batch, int C, int H, int W, int stride, const void* input, void* output);

Expand Down
56 changes: 56 additions & 0 deletions plugin/common/kernels/lReLU.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel.h"

// Element-wise leaky ReLU: output[i] = input[i] > 0 ? input[i] : input[i] * negativeSlope.
//
// Launch layout: 1-D grid of 1-D blocks. The block size MUST equal the template argument
// nthdsPerCTA, because indices are derived from nthdsPerCTA rather than blockDim.x.
// Uses no shared memory. Any grid size is valid: a grid-stride loop covers all n elements.
//
// n              number of elements to process
// negativeSlope  multiplier applied to non-positive inputs
// input          source array of at least n floats
// output         destination array of at least n floats
template <unsigned nthdsPerCTA>
__launch_bounds__(nthdsPerCTA)
__global__ void pReLUKernel(
    const int n,
    const float negativeSlope,
    const float* input,
    float* output)
{
    // 64-bit indexing: with an int index, `i += gridDim.x * nthdsPerCTA` can wrap to a negative
    // value for large n, which still satisfies `i < n` and would index out of bounds.
    const long long stride = static_cast<long long>(gridDim.x) * nthdsPerCTA;
    for (long long i = static_cast<long long>(blockIdx.x) * nthdsPerCTA + threadIdx.x; i < n; i += stride)
    {
        const float x = input[i]; // read once instead of up to three times
        output[i] = x > 0.0f ? x : x * negativeSlope;
    }
}

// Host-side launcher for pReLUKernel.
//
// Enqueues a leaky-ReLU pass over `n` floats on `stream` using 512-thread blocks and one
// thread per element. The launch is asynchronous; only launch-time errors are reported.
//
// stream         CUDA stream the kernel is enqueued on
// n              number of float elements; n <= 0 is treated as "nothing to do"
// negativeSlope  multiplier applied to non-positive inputs
// input/output   buffers reinterpreted as float arrays of at least n elements
//
// Returns STATUS_SUCCESS when the kernel was enqueued (or there was no work),
// STATUS_FAILURE when the launch itself was rejected by the CUDA runtime.
pluginStatus_t lReLUGPU(
    cudaStream_t stream,
    const int n,
    const float negativeSlope,
    const void* input,
    void* output)
{
    // A zero-sized grid is an invalid launch configuration, so skip the launch entirely.
    if (n <= 0)
    {
        return STATUS_SUCCESS;
    }

    const int BS = 512;
    // Ceil-div written so that it cannot overflow for n close to INT_MAX.
    const int GS = (n - 1) / BS + 1;
    pReLUKernel<BS><<<GS, BS, 0, stream>>>(n, negativeSlope,
                                           static_cast<const float*>(input),
                                           static_cast<float*>(output));

    // Kernel launches do not return a status; configuration errors surface here.
    if (cudaGetLastError() != cudaSuccess)
    {
        return STATUS_FAILURE;
    }
    return STATUS_SUCCESS;
}

// Public entry point declared in kernel.h: forwards every argument unchanged to lReLUGPU
// and returns whatever status it reports.
pluginStatus_t lReLUInference(
    cudaStream_t stream,
    const int n,
    const float negativeSlope,
    const void* input,
    void* output)
{
    const pluginStatus_t status = lReLUGPU(stream, n, negativeSlope, input, output);
    return status;
}
10 changes: 10 additions & 0 deletions plugin/common/plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ namespace nvinfer1
namespace plugin
{

// Convenience base for IPluginV2 plugins: stores the plugin namespace string and
// implements the namespace setter/getter on top of it.
class BasePlugin : public IPluginV2
{
protected:
    // Remembers a copy of the namespace string handed in by the caller.
    void setPluginNamespace(const char* libNamespace) override
    {
        mNamespace = libNamespace;
    }

    // Returns the stored namespace; the pointer refers to mNamespace's internal buffer.
    const char* getPluginNamespace() const override
    {
        return mNamespace.c_str();
    }

    std::string mNamespace;
};

class BaseCreator : public IPluginCreator
{
public:
Expand Down
18 changes: 18 additions & 0 deletions plugin/leakyReluPlugin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Collect every .cpp file in this directory.
file(GLOB SRCS *.cpp)
# Append them to the PLUGIN_SOURCES list accumulated so far.
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
# Publish the extended list to the parent scope so the including CMakeLists sees it.
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
170 changes: 170 additions & 0 deletions plugin/leakyReluPlugin/lReluPlugin.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "lReluPlugin.h"
#include "checkMacrosPlugin.h"
#include "kernel.h"

using namespace nvinfer1;
using nvinfer1::PluginType;
using nvinfer1::plugin::LReLU;
using nvinfer1::plugin::LReluPluginCreator;

// Version string returned by both LReLU::getPluginVersion and LReluPluginCreator::getPluginVersion.
static const char* LRELU_PLUGIN_VERSION{"1"};
// Type name returned by both LReLU::getPluginType and LReluPluginCreator::getPluginName.
static const char* LRELU_PLUGIN_NAME{"LReLU_TRT"};
// Definitions of the creator's static members; they are filled in by the LReluPluginCreator constructor.
PluginFieldCollection LReluPluginCreator::mFC{};
std::vector<PluginField> LReluPluginCreator::mPluginAttributes;

// LeakyReLU {{{
LReLU::LReLU(float negSlope)
: mNegSlope(negSlope)
, mBatchDim(1)
{
}

// Deserializing constructor: restores the state written by serialize().
//
// buffer  serialized plugin data laid out as [float mNegSlope][int mBatchDim]
// length  size of `buffer` in bytes; must match that layout exactly
//
// The size is validated BEFORE any read so that a truncated buffer is rejected
// instead of being read past its end.
LReLU::LReLU(const void* buffer, size_t length)
{
    ASSERT(buffer != nullptr);
    ASSERT(length == sizeof(float) + sizeof(int));

    const char* const begin = reinterpret_cast<const char*>(buffer);
    const char* cursor = begin;
    mNegSlope = read<float>(cursor);
    mBatchDim = read<int>(cursor);
    // Sanity check that the reads consumed exactly the serialized payload.
    ASSERT(cursor == begin + length);
}

// The plugin produces exactly one output tensor.
int LReLU::getNbOutputs() const
{
    constexpr int kOUTPUT_COUNT = 1;
    return kOUTPUT_COUNT;
}

// The single output has the same dimensions as the single input.
// Asserts that exactly one input is supplied and that output index 0 is requested.
Dims LReLU::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
{
    ASSERT(nbInputDims == 1);
    ASSERT(index == 0);
    const Dims& inputShape = inputs[0];
    return inputShape;
}

int LReLU::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream)
{
const void* inputData = inputs[0];
void* outputData = outputs[0];
pluginStatus_t status = lReLUInference(stream, mBatchDim * batchSize, mNegSlope, inputData, outputData);
ASSERT(status == STATUS_SUCCESS);
return status;
}

size_t LReLU::getSerializationSize() const
{
// mNegSlope, mBatchDim
return sizeof(float) + sizeof(int);
}

// Writes mNegSlope then mBatchDim into `buffer`, and asserts that exactly
// getSerializationSize() bytes were produced.
void LReLU::serialize(void* buffer) const
{
    char* const begin = reinterpret_cast<char*>(buffer);
    char* cursor = begin;
    write(cursor, mNegSlope);
    write(cursor, mBatchDim);
    ASSERT(cursor == begin + getSerializationSize());
}

// Records the per-sample element count of the single input tensor.
//
// inputDims/nbInputs   input shapes; exactly one input is expected
// outputDims/nbOutputs output shapes; exactly one output is expected
// type/format          must be kFLOAT / kNCHW (the only combination supportsFormat accepts)
//
// The volume is computed into a local and then assigned, so calling this on an instance
// that already holds a volume (deserialized, or configured earlier) recomputes the value
// instead of multiplying into the old one.
void LReLU::configureWithFormat(
    const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int)
{
    ASSERT(type == DataType::kFLOAT && format == PluginFormat::kNCHW);
    ASSERT(nbInputs == 1); // matches the single-input contract asserted in getOutputDimensions
    ASSERT(nbOutputs == 1);

    int volume = 1;
    for (int i = 0; i < inputDims[0].nbDims; ++i)
    {
        volume *= inputDims[0].d[i];
    }
    mBatchDim = volume;
}

// Only 32-bit float data in NCHW layout is accepted.
bool LReLU::supportsFormat(DataType type, PluginFormat format) const
{
    const bool isFloat = (type == DataType::kFLOAT);
    const bool isNchw = (format == PluginFormat::kNCHW);
    return isFloat && isNchw;
}

// Nothing to set up; always returns 0.
int LReLU::initialize()
{
    constexpr int kINIT_RESULT = 0;
    return kINIT_RESULT;
}

// Counterpart of initialize(); intentionally does nothing.
void LReLU::terminate()
{
}

// Reports that no scratch workspace is needed, regardless of batch size.
size_t LReLU::getWorkspaceSize(int maxBatchSize) const
{
    return size_t{0};
}

// Returns the plugin type name (LRELU_PLUGIN_NAME).
const char* LReLU::getPluginType() const
{
    const char* const typeName = LRELU_PLUGIN_NAME;
    return typeName;
}

// Returns the plugin version string (LRELU_PLUGIN_VERSION).
const char* LReLU::getPluginVersion() const
{
    const char* const version = LRELU_PLUGIN_VERSION;
    return version;
}

// Self-deletes the plugin; the object must not be used after this call.
void LReLU::destroy()
{
    delete this;
}

// Returns a heap-allocated copy carrying the same negative slope and plugin namespace.
// NOTE(review): the copy is built with the float constructor, so its mBatchDim starts at 1
// rather than this object's value - confirm callers always configure clones before enqueue.
IPluginV2* LReLU::clone() const
{
    IPluginV2* copy = new LReLU(mNegSlope);
    copy->setPluginNamespace(mNamespace.c_str());
    return copy;
}

// Registers the single creation attribute, "negSlope" (one FLOAT32 value), and points
// the static field collection at it.
//
// mPluginAttributes is static and therefore shared by every creator instance; it is
// cleared first so that constructing more than one creator does not append duplicate
// entries (which would make mFC.nbFields grow past 1).
LReluPluginCreator::LReluPluginCreator()
{
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(PluginField("negSlope", nullptr, PluginFieldType::kFLOAT32, 1));

    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

// Name under which this creator is looked up (LRELU_PLUGIN_NAME).
const char* LReluPluginCreator::getPluginName() const
{
    const char* const pluginName = LRELU_PLUGIN_NAME;
    return pluginName;
}

// Version of the plugin this creator builds (LRELU_PLUGIN_VERSION).
const char* LReluPluginCreator::getPluginVersion() const
{
    const char* const version = LRELU_PLUGIN_VERSION;
    return version;
}

// Exposes the static field collection describing the attributes createPlugin() expects.
const PluginFieldCollection* LReluPluginCreator::getFieldNames()
{
    const PluginFieldCollection* const fieldCollection = &mFC;
    return fieldCollection;
}

// Builds a new LReLU plugin from a field collection.
//
// name  unused
// fc    must contain exactly one FLOAT32 field whose data points at the negative slope
//
// Returns a heap-allocated LReLU. All pointers are validated before being dereferenced:
// the collection handed out by getFieldNames() has a null `data` pointer, so passing it
// back without filling it in would otherwise dereference null.
IPluginV2* LReluPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
    ASSERT(fc != nullptr);
    ASSERT(fc->nbFields == 1);
    const PluginField* fields = fc->fields;
    ASSERT(fields != nullptr);
    ASSERT(fields[0].type == PluginFieldType::kFLOAT32);
    ASSERT(fields[0].data != nullptr);
    negSlope = *(static_cast<const float*>(fields[0].data));

    return new LReLU(negSlope);
}

// Reconstructs an LReLU plugin from the bytes produced by LReLU::serialize().
// The returned object is heap-allocated; it is released through LReLU::destroy().
IPluginV2* LReluPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
{
    IPluginV2* plugin = new LReLU(serialData, serialLength);
    return plugin;
}
// LeakyReLU }}}
88 changes: 88 additions & 0 deletions plugin/leakyReluPlugin/lReluPlugin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TRT_L_RELU_PLUGIN_H
#define TRT_L_RELU_PLUGIN_H
#include "NvInfer.h"
#include "kernel.h"
#include "plugin.h"
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

namespace nvinfer1
{
namespace plugin
{

// Leaky ReLU plugin: output = input > 0 ? input : input * negSlope, element-wise.
// Supports a single FLOAT/NCHW input and produces a single output of the same shape.
class LReLU : public BasePlugin
{
public:
    // Creates an unconfigured plugin with the given negative slope.
    LReLU(float negSlope);
    // Restores a plugin from `length` bytes previously produced by serialize().
    LReLU(const void* buffer, size_t length);
    ~LReLU() override = default;

    // IPluginV2 Methods

    // Shape inference.
    int getNbOutputs() const override;
    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;

    // Lifecycle and execution.
    int initialize() override;
    void terminate() override;
    size_t getWorkspaceSize(int maxBatchSize) const override;
    int enqueue(
        int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override;

    // Serialization: the payload is mNegSlope followed by mBatchDim.
    size_t getSerializationSize() const override;
    void serialize(void* buffer) const override;

    // Configuration and format negotiation.
    void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type,
        PluginFormat format, int maxBatchSize) override;
    bool supportsFormat(DataType type, PluginFormat format) const override;

    // Identity and ownership.
    const char* getPluginType() const override;
    const char* getPluginVersion() const override;
    void destroy() override;
    IPluginV2* clone() const override;

private:
    // Multiplier applied to non-positive inputs.
    float mNegSlope;
    // Element count of one input sample; multiplied by batchSize in enqueue().
    int mBatchDim;
};

// Factory for LReLU plugins: creates them from a PluginFieldCollection holding a single
// FLOAT32 "negSlope" field, or from serialized data.
class LReluPluginCreator : public BaseCreator
{
public:
    // Registers the "negSlope" attribute in the static field collection.
    LReluPluginCreator();

    ~LReluPluginCreator() override = default;

    // Plugin name and version this creator is registered under.
    const char* getPluginName() const override;

    const char* getPluginVersion() const override;

    // Describes the fields createPlugin() expects.
    const PluginFieldCollection* getFieldNames() override;

    // Builds a new plugin from user-supplied fields.
    IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override;

    // Rebuilds a plugin from bytes produced by LReLU::serialize().
    IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;

private:
    // Field collection returned by getFieldNames(); shared by all creator instances.
    static PluginFieldCollection mFC;
    // Slope value most recently parsed by createPlugin().
    float negSlope{};
    // Backing storage for mFC.fields; shared by all creator instances.
    static std::vector<PluginField> mPluginAttributes;
};

// Temporary alias kept for backward compatibility with code that still uses the PReLU name.
using PReLU = LReLU;
} // namespace plugin
} // namespace nvinfer1

#endif // TRT_L_RELU_PLUGIN_H

0 comments on commit 275eefc

Please sign in to comment.