diff --git a/paddle2onnx/mapper/nn/interpolate.cc b/paddle2onnx/mapper/nn/interpolate.cc index 49523a8a6..721a6adb8 100755 --- a/paddle2onnx/mapper/nn/interpolate.cc +++ b/paddle2onnx/mapper/nn/interpolate.cc @@ -17,6 +17,7 @@ namespace paddle2onnx { REGISTER_MAPPER(bilinear_interp, InterpolateMapper) REGISTER_MAPPER(bilinear_interp_v2, InterpolateMapper) +REGISTER_MAPPER(nearest_interp, InterpolateMapper) REGISTER_MAPPER(nearest_interp_v2, InterpolateMapper) REGISTER_MAPPER(bicubic_interp_v2, InterpolateMapper) REGISTER_MAPPER(linear_interp_v2, InterpolateMapper) @@ -95,12 +96,21 @@ void InterpolateMapper::Opset11() { out_size.push_back(out_w_); size = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, out_size); } else { - std::vector scale_; - GetAttr("scale", &scale_); + std::vector scale_vector; float padding = 1.0; - scale_.insert(scale_.begin(), padding); - scale_.insert(scale_.begin(), padding); - scale = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, scale_); + GetAttr("scale", &scale_vector); + if (scale_vector.size() != 0){ + scale_vector.insert(scale_vector.begin(), padding); + scale_vector.insert(scale_vector.begin(), padding); + }else{ + float scale; + GetAttr("scale", &scale); + scale_vector.emplace_back(padding); + scale_vector.emplace_back(padding); + scale_vector.emplace_back(scale); + scale_vector.emplace_back(scale); + } + scale = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, scale_vector); } } std::string roi = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, std::vector()); diff --git a/paddle2onnx/mapper/nn/interpolate.h b/paddle2onnx/mapper/nn/interpolate.h index d3b8a7a7a..2159eb217 100755 --- a/paddle2onnx/mapper/nn/interpolate.h +++ b/paddle2onnx/mapper/nn/interpolate.h @@ -32,6 +32,7 @@ class InterpolateMapper : public Mapper { resize_mapper_["bilinear_interp"] = "linear"; resize_mapper_["bilinear_interp_v2"] = "linear"; + resize_mapper_["nearest_interp"] = "nearest"; resize_mapper_["nearest_interp_v2"] = "nearest"; resize_mapper_["bicubic_interp_v2"] = "cubic"; resize_mapper_["linear_interp_v2"] = "linear"; @@ -40,7 +41,6 @@ class InterpolateMapper : public Mapper { int32_t GetMinOpset(bool verbose = false); void Opset11(); - private: std::string ComputeOutSize(); std::string ComputeScale();