diff --git a/drivers/plugins/extra-params/config.go b/drivers/plugins/extra-params/config.go index 02bf9e24..30baca77 100644 --- a/drivers/plugins/extra-params/config.go +++ b/drivers/plugins/extra-params/config.go @@ -11,10 +11,10 @@ type Config struct { } type ExtraParam struct { - Name string `json:"name" label:"参数名"` - Position string `json:"position" enum:"header,query,body" label:"参数位置"` - Value interface{} `json:"value" label:"参数值"` - Conflict string `json:"conflict" label:"参数冲突时的处理方式" enum:"origin,convert,error"` + Name string `json:"name" label:"参数名"` + Position string `json:"position" enum:"header,query,body" label:"参数位置"` + Value string `json:"value" label:"参数值"` + Conflict string `json:"conflict" label:"参数冲突时的处理方式" enum:"origin,convert,error"` } func (c *Config) doCheck() error { diff --git a/drivers/plugins/extra-params/extra-params.go b/drivers/plugins/extra-params/extra-params.go index 8e74c4ce..6bbbf137 100644 --- a/drivers/plugins/extra-params/extra-params.go +++ b/drivers/plugins/extra-params/extra-params.go @@ -47,10 +47,16 @@ func (e *ExtraParams) access(ctx http_service.IHttpContext) (int, error) { headers := ctx.Proxy().Header().Headers() // 先判断参数类型 for _, param := range e.params { + var paramValue interface{} + err = json.Unmarshal([]byte(param.Value), ¶mValue) + if err != nil { + paramValue = param.Value + } switch param.Position { case "query": { - value, err := getQueryValue(ctx, param) + v, _ := json.Marshal(paramValue) + value, err := getQueryValue(ctx, param, string(v)) if err != nil { err = encodeErr(e.errorType, err.Error(), clientErrStatusCode) return clientErrStatusCode, err @@ -59,7 +65,8 @@ func (e *ExtraParams) access(ctx http_service.IHttpContext) (int, error) { } case "header": { - value, err := getHeaderValue(headers, param) + v, _ := json.Marshal(paramValue) + value, err := getHeaderValue(headers, param, string(v)) if err != nil { err = encodeErr(e.errorType, err.Error(), clientErrStatusCode) return clientErrStatusCode, err @@ -68,7 +75,7 @@ func (e *ExtraParams) access(ctx http_service.IHttpContext) (int, error) { } case "body": { - value, err := getBodyValue(bodyParams, formParams, param, contentType) + value, err := getBodyValue(bodyParams, formParams, param, contentType, paramValue) if err != nil { err = encodeErr(e.errorType, err.Error(), clientErrStatusCode) return clientErrStatusCode, err diff --git a/drivers/plugins/extra-params/util.go b/drivers/plugins/extra-params/util.go index bdd1b3ab..b1045ac3 100644 --- a/drivers/plugins/extra-params/util.go +++ b/drivers/plugins/extra-params/util.go @@ -60,12 +60,9 @@ func parseBodyParams(ctx http_service.IHttpContext, body []byte, contentType str return bodyParams, formParams, nil } -func getHeaderValue(headers map[string][]string, param *ExtraParam) (string, error) { +func getHeaderValue(headers map[string][]string, param *ExtraParam, value string) (string, error) { paramName := ConvertHeaderKey(param.Name) - if _, ok := param.Value.(string); !ok { - errInfo := "[extra_params] Header param " + param.Name + " must be a string" - return "", errors.New(errInfo) - } + if param.Conflict == "" { param.Conflict = paramConvert } @@ -79,12 +76,7 @@ func getHeaderValue(headers map[string][]string, param *ExtraParam) (string, err } if param.Conflict == paramConvert { - if value, ok := param.Value.(string); ok { - paramValue = value - } else { - errInfo := `[extra_params] Illegal "paramValue" in "` + param.Name + `"` - return "", errors.New(errInfo) - } + paramValue = value } else if param.Conflict == paramError { errInfo := `[extra_params] "` + param.Name + `" has a conflict.` return "", errors.New(errInfo) @@ -115,12 +107,8 @@ func hasQueryValue(rawQuery string, paramName string) bool { return false } -func getQueryValue(ctx http_service.IHttpContext, param *ExtraParam) (string, error) { - if _, ok := param.Value.(string); !ok { - errInfo := "[extra_params] Query param " + param.Name + " must be a string" - return "", errors.New(errInfo) - } - value := "" +func getQueryValue(ctx http_service.IHttpContext, param *ExtraParam, value string) (string, error) { + paramValue := "" if param.Conflict == "" { param.Conflict = paramConvert } @@ -129,49 +117,45 @@ func getQueryValue(ctx http_service.IHttpContext, param *ExtraParam) (string, er if !hasQueryValue(ctx.Proxy().URI().RawQuery(), param.Name) { param.Conflict = paramConvert } else { - value = ctx.Proxy().URI().GetQuery(param.Name) + paramValue = ctx.Proxy().URI().GetQuery(param.Name) } if param.Conflict == paramConvert { - value = param.Value.(string) + paramValue = value } else if param.Conflict == paramError { errInfo := `[extra_params] "` + param.Name + `" has a conflict.` return "", errors.New(errInfo) } - return value, nil + return paramValue, nil } -func getBodyValue(bodyParams map[string]interface{}, formParams map[string][]string, param *ExtraParam, contentType string) (interface{}, error) { - var value interface{} = nil +func getBodyValue(bodyParams map[string]interface{}, formParams map[string][]string, param *ExtraParam, contentType string, value interface{}) (interface{}, error) { + var paramValue interface{} = nil if param.Conflict == "" { param.Conflict = paramConvert } if strings.Contains(contentType, FormParamType) { - if _, ok := param.Value.(string); !ok { - errInfo := "[extra_params] Body param " + param.Name + " must be a string" - return "", errors.New(errInfo) - } if _, ok := formParams[param.Name]; !ok { param.Conflict = paramConvert } else { - value = formParams[param.Name][0] + paramValue = formParams[param.Name][0] } } else if strings.Contains(contentType, JsonType) { if _, ok := bodyParams[param.Name]; !ok { param.Conflict = paramConvert } else { - value = bodyParams[param.Name] + paramValue = bodyParams[param.Name] } } if param.Conflict == paramConvert { - value = param.Value + paramValue = value } else if param.Conflict == paramError { errInfo := `[extra_params] "` + param.Name + `" has a conflict.` return "", errors.New(errInfo) } - return value, nil + return paramValue, nil } func ConvertHeaderKey(header string) string {