Skip to content

Commit

Permalink
修改获取Param的方法
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Apr 9, 2024
1 parent cca1178 commit ce3b881
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 19 deletions.
22 changes: 22 additions & 0 deletions include/runtime/runtime_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <map>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -87,8 +88,29 @@ struct RuntimeOperatorBase {

/// Operator attributes like weights
std::map<std::string, std::shared_ptr<RuntimeAttribute>> attribute;

bool has_parameter(const std::string& param_name);

bool has_attribute(const std::string& attr_name);
};

template <typename T>
bool RuntimeOperatorBase<T>::has_attribute(const std::string& attr_name) {
return false;
}

template <typename T>
bool RuntimeOperatorBase<T>::has_parameter(const std::string& param_name) {
if (this->params.empty()) {
return false;
}
if (this->params.find(param_name) == this->params.end()) {
return false;
} else {
return true;
}
}

using RuntimeOperator = RuntimeOperatorBase<float>;

using RuntimeOperatorQuantized = RuntimeOperatorBase<int8_t>;
Expand Down
32 changes: 16 additions & 16 deletions source/layer/details/base_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
return StatusCode::kParseParameterError;
}

if (params.find("dilation") == params.end()) {
if (op->has_parameter("dilation")) {
LOG(ERROR) << "Can not find the dilation parameter";
return StatusCode::kParseParameterError;
}
Expand All @@ -249,7 +249,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
const uint32_t dilation_h = dilation_param->value.at(0);
const uint32_t dilation_w = dilation_param->value.at(1);

if (params.find("in_channels") == params.end()) {
if (op->has_parameter("in_channels")) {
LOG(ERROR) << "Can not find the in channel parameter";
return StatusCode::kParseParameterError;
}
Expand All @@ -259,7 +259,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
return StatusCode::kParseParameterError;
}

if (params.find("out_channels") == params.end()) {
if (op->has_parameter("out_channels")) {
LOG(ERROR) << "Can not find the out channel parameter";
return StatusCode::kParseParameterError;
}
Expand All @@ -270,7 +270,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
return StatusCode::kParseParameterError;
}

if (params.find("padding") == params.end()) {
if (op->has_parameter("padding")) {
LOG(ERROR) << "Can not find the padding parameter";
return StatusCode::kParseParameterError;
}
Expand All @@ -281,7 +281,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
return StatusCode::kParseParameterError;
}

if (params.find("bias") == params.end()) {
if (op->has_parameter("bias")) {
LOG(ERROR) << "Can not find the bias parameter";
return StatusCode::kParseParameterError;
}
Expand All @@ -291,7 +291,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
return StatusCode::kParseParameterError;
}

if (params.find("stride") == params.end()) {
if (op->has_parameter("stride")) {
LOG(ERROR) << "Can not find the stride parameter";
return StatusCode::kParseParameterError;
}
Expand All @@ -301,7 +301,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
return StatusCode::kParseParameterError;
}

if (params.find("kernel_size") == params.end()) {
if (op->has_parameter("kernel_size")) {
LOG(ERROR) << "Can not find the kernel parameter";
return StatusCode::kParseParameterError;
}
Expand All @@ -312,7 +312,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
}

if (op->type == "nn.Conv2d") {
if (params.find("padding_mode") != params.end()) {
if (op->has_parameter("padding_mode")) {
auto padding_mode =
std::dynamic_pointer_cast<RuntimeParameterString>(params.at("padding_mode"));
if (padding_mode == nullptr) {
Expand All @@ -331,7 +331,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
}
}

if (params.find("groups") == params.end()) {
if (op->has_parameter("groups")) {
LOG(ERROR) << "Can not find the groups parameter";
return StatusCode::kParseParameterError;
}
Expand All @@ -342,29 +342,29 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
return StatusCode::kParseParameterError;
}

const uint32_t dims = 2;
const uint32_t conv_dims = 2;
const std::vector<int32_t>& kernels = kernel->value;
const std::vector<int32_t>& paddings = padding->value;
const std::vector<int32_t>& strides = stride->value;
if (paddings.size() != dims) {
if (paddings.size() != conv_dims) {
LOG(ERROR) << "Can not find the right padding parameter";
return StatusCode::kParseParameterError;
}

if (strides.size() != dims) {
if (strides.size() != conv_dims) {
LOG(ERROR) << "Can not find the right stride parameter";
return StatusCode::kParseParameterError;
}

if (kernels.size() != dims) {
if (kernels.size() != conv_dims) {
LOG(ERROR) << "Can not find the right kernel size parameter";
return StatusCode::kParseParameterError;
}

uint32_t output_padding_h = 0;
uint32_t output_padding_w = 0;
if (op->type == "nn.ConvTranspose2d") {
if (params.find("output_padding") != params.end()) {
if (op->has_parameter("output_padding")) {
auto output_padding_arr =
std::dynamic_pointer_cast<RuntimeParameterIntArray>(params.at("output_padding"));
if (!output_padding_arr) {
Expand Down Expand Up @@ -406,7 +406,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
// load weights
const std::map<std::string, std::shared_ptr<RuntimeAttribute>>& attrs = op->attribute;
if (use_bias->value) {
if (attrs.find("bias") == attrs.end()) {
if (op->has_attribute("bias")) {
LOG(ERROR) << "Can not find the bias attribute";
return StatusCode::kParseWeightError;
}
Expand All @@ -421,7 +421,7 @@ StatusCode BaseConvolutionLayer::CreateInstance(const std::shared_ptr<RuntimeOpe
conv_layer->set_bias(bias_values);
}

if (attrs.find("weight") == attrs.end()) {
if (op->has_attribute("weight")) {
LOG(ERROR) << "Can not find the weight attribute";
return StatusCode::kParseWeightError;
}
Expand Down
6 changes: 3 additions & 3 deletions source/layer/details/linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ StatusCode LinearLayer::CreateInstance(const std::shared_ptr<RuntimeOperator>& o
return StatusCode::kParseParameterError;
}

if (params.find("bias") == params.end()) {
if (op->has_parameter("bias")) {
LOG(ERROR) << "Can not find the use bias parameter in the parameter list.";
return StatusCode::kParseParameterError;
}
Expand All @@ -160,13 +160,13 @@ StatusCode LinearLayer::CreateInstance(const std::shared_ptr<RuntimeOperator>& o
return StatusCode::kParseWeightError;
}

if (attr.find("weight") == attr.end()) {
if (op->has_attribute("weight")) {
LOG(ERROR) << "Can not find the weight parameter in the parameter list.";
return StatusCode::kParseWeightError;
}

if (use_bias_param->value) {
if (attr.find("bias") == attr.end()) {
if (op->has_attribute("bias")) {
LOG(ERROR) << "Can not find the bias parameter in the parameter list.";
return StatusCode::kParseWeightError;
}
Expand Down

0 comments on commit ce3b881

Please sign in to comment.