diff --git a/source/tnn/network/tensorrt/layer_builder/strided_slice_v2_layer_builder.cc b/source/tnn/network/tensorrt/layer_builder/strided_slice_v2_layer_builder.cc index 80c14fe21..b26943a53 100644 --- a/source/tnn/network/tensorrt/layer_builder/strided_slice_v2_layer_builder.cc +++ b/source/tnn/network/tensorrt/layer_builder/strided_slice_v2_layer_builder.cc @@ -75,6 +75,13 @@ ILayer* StrideSliceV2TRTLayerBuilder::AddToNetwork(INetworkDefinition* network) end_dim[param->axes[i]] = param->begins[param->axes[i]] + dim[param->axes[i]]; } } + + if (stride_dim.size() == 0 && input_tensors.size() < 5) { // stride defaults to 1 + stride_dim = DimsVector(axes.size() > 0 ? axes.size() : dims.size(), 1); + } + // NOTE & TODO: when input_tensors[3] exists, should we accept it as axes? + // NOTE & TODO: when input_tensors[4] exists, should we accept it as strides? + axes = ShapeTensor(1, std::move(axes_dim)); strides = ShapeTensor(1, std::move(stride_dim)); diff --git a/source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc b/source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc index 469fe3085..a58c8b34a 100644 --- a/source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc +++ b/source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc @@ -119,14 +119,17 @@ ILayer* UpsampleTRTPluginLayerBuilder::AddToNetwork(INetworkDefinition* network) } } } else { - if (output_dims.size() == 4) { + auto trt_dim = input_tensor->getDimensions(); + if (output_dims.size() == 4 || (output_dims.size() == 0 && trt_dim.nbDims == 4)) { + // NOTE: keep 2nd condition until dims or rank of blob can be tracked in dynamic shape scenario float scale[4]; scale[0] = 1; scale[1] = 1; scale[2] = paramlist->scales[1]; scale[3] = paramlist->scales[0]; layer->setScales(scale, 4); - } else if (output_dims.size() == 5) { + } else if (output_dims.size() == 5 || (output_dims.size() == 0 && trt_dim.nbDims == 5)) { + // NOTE: keep 2nd condition until dims or rank of blob can be tracked in dynamic shape scenario float scale[5]; scale[0] = 1; scale[1] = 1;