diff --git a/parsers/caffe/caffeParser/opParsers/parsePReLU.cpp b/parsers/caffe/caffeParser/opParsers/parsePReLU.cpp index 85289a2b..f692f548 100644 --- a/parsers/caffe/caffeParser/opParsers/parsePReLU.cpp +++ b/parsers/caffe/caffeParser/opParsers/parsePReLU.cpp @@ -37,13 +37,15 @@ ILayer* parsePReLU(INetworkDefinition& network, const trtcaffe::LayerParameter& { return nullptr; } - int nWeights = channelShared ? 1 : inputDims.d[1]; // Caffe treats second input dimension as channels - Dims slopesDims{inputDims.nbDims, {1}, {DimensionType::kSPATIAL}}; - slopesDims.d[1] = nWeights; + + int nWeights = channelShared ? 1 : inputDims.d[0]; // Caffe treats second input dimension as channels + Dims slopesDims{inputDims.nbDims, {}, {}}; + std::fill(slopesDims.d, slopesDims.d + slopesDims.nbDims, 1); + slopesDims.d[0] = nWeights; Weights w = weightFactory.isInitialized() ? weightFactory(msg.name(), WeightType::kGENERIC) : weightFactory.allocateWeights(nWeights, std::uniform_real_distribution(0.F, 1.F)); auto constLayer = network.addConstant(slopesDims, w); return network.addParametricReLU(*tensors[msg.bottom(0)], *constLayer->getOutput(0)); } -} //namespace nvcaffeparser1 \ No newline at end of file +} //namespace nvcaffeparser1