From c368232f502a7682dd89cc40c9d777ae6ab42a51 Mon Sep 17 00:00:00 2001 From: wanthigh <31470660+wanthigh@users.noreply.github.com> Date: Sat, 9 Nov 2024 23:31:46 +0800 Subject: [PATCH] fix: changeoptional field to pointer type (#1907) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix:修复在渠道配置中设置模型重定向时,temperature为0被忽略的问题 * fix: set optional fields to pointer type --------- Co-authored-by: JustSong --- common/helper/helper.go | 20 ++++++++++++++++++++ relay/adaptor/ali/main.go | 4 +--- relay/adaptor/ali/model.go | 4 ++-- relay/adaptor/anthropic/model.go | 4 ++-- relay/adaptor/aws/claude/model.go | 4 ++-- relay/adaptor/aws/llama3/model.go | 8 ++++---- relay/adaptor/baidu/main.go | 6 +++--- relay/adaptor/cloudflare/model.go | 2 +- relay/adaptor/cohere/main.go | 2 +- relay/adaptor/cohere/model.go | 8 ++++---- relay/adaptor/gemini/model.go | 4 ++-- relay/adaptor/ollama/model.go | 16 ++++++++-------- relay/adaptor/palm/model.go | 10 +++++----- relay/adaptor/tencent/main.go | 4 ++-- relay/adaptor/vertexai/claude/model.go | 4 ++-- relay/adaptor/xunfei/model.go | 10 +++++----- relay/adaptor/zhipu/adaptor.go | 14 +++++++------- relay/adaptor/zhipu/model.go | 4 ++-- relay/model/general.go | 8 ++++---- 19 files changed, 77 insertions(+), 59 deletions(-) diff --git a/common/helper/helper.go b/common/helper/helper.go index e06dfb6e64..df7b0a5f9c 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -137,3 +137,23 @@ func String2Int(str string) int { } return num } + +func Float64PtrMax(p *float64, maxValue float64) *float64 { + if p == nil { + return nil + } + if *p > maxValue { + return &maxValue + } + return p +} + +func Float64PtrMin(p *float64, minValue float64) *float64 { + if p == nil { + return nil + } + if *p < minValue { + return &minValue + } + return p +} diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index ec5848ce09..6a73c7072f 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -36,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { enableSearch = true aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } - if request.TopP >= 1 { - request.TopP = 0.9999 - } + request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) return &ChatRequest{ Model: aliModel, Input: Input{ diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go index 450b5f5292..a680c7e24b 100644 --- a/relay/adaptor/ali/model.go +++ b/relay/adaptor/ali/model.go @@ -16,13 +16,13 @@ type Input struct { } type Parameters struct { - TopP float64 `json:"top_p,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Seed uint64 `json:"seed,omitempty"` EnableSearch bool `json:"enable_search,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` ResultFormat string `json:"result_format,omitempty"` Tools []model.Tool `json:"tools,omitempty"` } diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go index 47f766291d..47f193faa0 100644 --- a/relay/adaptor/anthropic/model.go +++ b/relay/adaptor/anthropic/model.go @@ -48,8 +48,8 @@ type Request struct { MaxTokens int `json:"max_tokens,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` diff --git a/relay/adaptor/aws/claude/model.go b/relay/adaptor/aws/claude/model.go index 6d00b68865..106228877b 100644 --- a/relay/adaptor/aws/claude/model.go +++ b/relay/adaptor/aws/claude/model.go @@ -11,8 +11,8 @@ type Request struct { Messages []anthropic.Message `json:"messages"` System string `json:"system,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` diff --git a/relay/adaptor/aws/llama3/model.go b/relay/adaptor/aws/llama3/model.go index 7b86c3b8ff..6cb64cdeac 100644 --- a/relay/adaptor/aws/llama3/model.go +++ b/relay/adaptor/aws/llama3/model.go @@ -4,10 +4,10 @@ package aws // // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html type Request struct { - Prompt string `json:"prompt"` - MaxGenLen int `json:"max_gen_len,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Prompt string `json:"prompt"` + MaxGenLen int `json:"max_gen_len,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` } // Response is the response from AWS Llama3 diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go index ebe70c3241..ac8a562544 100644 --- a/relay/adaptor/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -35,9 +35,9 @@ type Message struct { type ChatRequest struct { Messages []Message `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - PenaltyScore float64 `json:"penalty_score,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + PenaltyScore *float64 `json:"penalty_score,omitempty"` Stream bool `json:"stream,omitempty"` System string `json:"system,omitempty"` DisableSearch bool `json:"disable_search,omitempty"` diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go index 0d3bafe098..8e382ba7ad 100644 --- a/relay/adaptor/cloudflare/model.go +++ b/relay/adaptor/cloudflare/model.go @@ -9,5 +9,5 @@ type Request struct { Prompt string `json:"prompt,omitempty"` Raw bool `json:"raw,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` } diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go index 45db437b6b..736c5a8d86 100644 --- a/relay/adaptor/cohere/main.go +++ b/relay/adaptor/cohere/main.go @@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { K: textRequest.TopK, Stream: textRequest.Stream, FrequencyPenalty: textRequest.FrequencyPenalty, - PresencePenalty: textRequest.FrequencyPenalty, + PresencePenalty: textRequest.PresencePenalty, Seed: int(textRequest.Seed), } if cohereRequest.Model == "" { diff --git a/relay/adaptor/cohere/model.go b/relay/adaptor/cohere/model.go index 64fa9c9403..3a8bc99dc7 100644 --- a/relay/adaptor/cohere/model.go +++ b/relay/adaptor/cohere/model.go @@ -10,15 +10,15 @@ type Request struct { PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" Connectors []Connector `json:"connectors,omitempty"` Documents []Document `json:"documents,omitempty"` - Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3 + Temperature *float64 `json:"temperature,omitempty"` // 默认值为0.3 MaxTokens int `json:"max_tokens,omitempty"` MaxInputTokens int `json:"max_input_tokens,omitempty"` K int `json:"k,omitempty"` // 默认值为0 - P float64 `json:"p,omitempty"` // 默认值为0.75 + P *float64 `json:"p,omitempty"` // 默认值为0.75 Seed int `json:"seed,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 - PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 Tools []Tool `json:"tools,omitempty"` ToolResults []ToolResult `json:"tool_results,omitempty"` } diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go index f6a3b25042..720cb65d19 100644 --- a/relay/adaptor/gemini/model.go +++ b/relay/adaptor/gemini/model.go @@ -67,8 +67,8 @@ type ChatTools struct { type ChatGenerationConfig struct { ResponseMimeType string `json:"responseMimeType,omitempty"` ResponseSchema any `json:"responseSchema,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` TopK float64 `json:"topK,omitempty"` MaxOutputTokens int `json:"maxOutputTokens,omitempty"` CandidateCount int `json:"candidateCount,omitempty"` diff --git a/relay/adaptor/ollama/model.go b/relay/adaptor/ollama/model.go index 7039984fcc..94f2ab7332 100644 --- a/relay/adaptor/ollama/model.go +++ b/relay/adaptor/ollama/model.go @@ -1,14 +1,14 @@ package ollama type Options struct { - Seed int `json:"seed,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` + Seed int `json:"seed,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` } type Message struct { diff --git a/relay/adaptor/palm/model.go b/relay/adaptor/palm/model.go index f653022c3e..2bdd8f298b 100644 --- a/relay/adaptor/palm/model.go +++ b/relay/adaptor/palm/model.go @@ -19,11 +19,11 @@ type Prompt struct { } type ChatRequest struct { - Prompt Prompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` + Prompt Prompt `json:"prompt"` + Temperature *float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` } type Error struct { diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index 365e33aef6..827c8a46dd 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { Model: &request.Model, Stream: &request.Stream, Messages: messages, - TopP: &request.TopP, - Temperature: &request.Temperature, + TopP: request.TopP, + Temperature: request.Temperature, } } diff --git a/relay/adaptor/vertexai/claude/model.go b/relay/adaptor/vertexai/claude/model.go index e1bd5dd48d..c08ba460d9 100644 --- a/relay/adaptor/vertexai/claude/model.go +++ b/relay/adaptor/vertexai/claude/model.go @@ -11,8 +11,8 @@ type Request struct { MaxTokens int `json:"max_tokens,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` diff --git a/relay/adaptor/xunfei/model.go b/relay/adaptor/xunfei/model.go index 1f37c04655..c9fb1bb8f2 100644 --- a/relay/adaptor/xunfei/model.go +++ b/relay/adaptor/xunfei/model.go @@ -19,11 +19,11 @@ type ChatRequest struct { } `json:"header"` Parameter struct { Chat struct { - Domain string `json:"domain,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Auditing bool `json:"auditing,omitempty"` + Domain string `json:"domain,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` } `json:"chat"` } `json:"parameter"` Payload struct { diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go index 78b01fb3f7..660bd37960 100644 --- a/relay/adaptor/zhipu/adaptor.go +++ b/relay/adaptor/zhipu/adaptor.go @@ -4,13 +4,13 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" "io" - "math" "net/http" "strings" ) @@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request) return baiduEmbeddingRequest, err default: - // TopP (0.0, 1.0) - request.TopP = math.Min(0.99, request.TopP) - request.TopP = math.Max(0.01, request.TopP) + // TopP [0.0, 1.0] + request.TopP = helper.Float64PtrMax(request.TopP, 1) + request.TopP = helper.Float64PtrMin(request.TopP, 0) - // Temperature (0.0, 1.0) - request.Temperature = math.Min(0.99, request.Temperature) - request.Temperature = math.Max(0.01, request.Temperature) + // Temperature [0.0, 1.0] + request.Temperature = helper.Float64PtrMax(request.Temperature, 1) + request.Temperature = helper.Float64PtrMin(request.Temperature, 0) a.SetVersionByModeName(request.Model) if a.APIVersion == "v4" { return request, nil diff --git a/relay/adaptor/zhipu/model.go b/relay/adaptor/zhipu/model.go index f91de1dced..06e22dc153 100644 --- a/relay/adaptor/zhipu/model.go +++ b/relay/adaptor/zhipu/model.go @@ -12,8 +12,8 @@ type Message struct { type Request struct { Prompt []Message `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` RequestId string `json:"request_id,omitempty"` Incremental bool `json:"incremental,omitempty"` } diff --git a/relay/model/general.go b/relay/model/general.go index fe73779ed7..a84e64abf1 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -26,17 +26,17 @@ type GeneralOpenAIRequest struct { Model string `json:"model,omitempty"` Modalities []string `json:"modalities,omitempty"` Audio *Audio `json:"audio,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` N int `json:"n,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` Stop any `json:"stop,omitempty"` Stream bool `json:"stream,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`