From dd7f5824f9a4c3860cccfaf8350d5d09e864038f Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sat, 17 Aug 2024 01:11:38 +0800 Subject: [PATCH] fix: fullURL endpoint generation (#817) --- api_internal_test.go | 24 ++++++++--- audio.go | 9 ++++- chat.go | 7 +++- chat_stream.go | 7 +++- client.go | 84 ++++++++++++++++++++++++-------------- client_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++++ completion.go | 7 +++- edits.go | 7 +++- embeddings.go | 7 +++- example_test.go | 2 +- image.go | 25 +++++++++--- moderation.go | 7 +++- speech.go | 5 ++- stream.go | 8 +++- 14 files changed, 244 insertions(+), 51 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index a590ec9ab..09677968a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -112,6 +112,7 @@ func TestAzureFullURL(t *testing.T) { Name string BaseURL string AzureModelMapper map[string]string + Suffix string Model string Expect string }{ @@ -119,6 +120,7 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -128,11 +130,20 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithoutSlashOK", "https://httpbin.org", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + "/chat/completions?api-version=2023-05-15", }, + { + "", + "https://httpbin.org", + nil, + "/assistants?limit=10", + "chatgpt-demo", + "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", + }, } for _, c := range cases { @@ -140,7 +151,7 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions", c.Model) + actual := cli.fullURL(c.Suffix, withModel(c.Model)) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } @@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) { cases := []struct { Name string BaseURL string + Suffix string Expect string }{ { "CloudflareAzureBaseURLWithSlashAutoStrip", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "/chat/completions", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, { - "CloudflareAzureBaseURLWithoutSlashOK", + "", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + - "chat/completions?api-version=2023-05-15", + "/assistants?limit=10", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + + "/assistants?api-version=2023-05-15&limit=10", }, } @@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/audio.go b/audio.go index dbc26d154..f321f93d6 100644 --- a/audio.go +++ b/audio.go @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), - withBody(&formBody), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(&formBody), + withContentType(builder.FormDataContentType()), + ) if err != nil { return AudioResponse{}, err } diff --git a/chat.go b/chat.go index 31fa887d6..826fd3bd5 100644 --- a/chat.go +++ b/chat.go @@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index ffd512ff6..3f90bc019 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/client.go b/client.go index d5d555c3d..9f547e7cb 100644 --- a/client.go +++ b/client.go @@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error { return nil } +type fullURLOptions struct { + model string +} + +type fullURLOption func(*fullURLOptions) + +func withModel(model string) fullURLOption { + return func(args *fullURLOptions) { + args.model = model + } +} + +var azureDeploymentsEndpoints = []string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", +} + // fullURL returns full URL for request. -// args[0] is model name, if API type is Azure, model name is required to get deployment name. -func (c *Client) fullURL(suffix string, args ...any) string { - // /openai/deployments/{model}/chat/completions?api-version={api_version} +func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) + } + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - parseURL, _ := url.Parse(baseURL) - query := parseURL.Query() - query.Add("api-version", c.config.APIVersion) - // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 - // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) - } - azureDeploymentName := "UNKNOWN" - if len(args) > 0 { - model, ok := args[0].(string) - if ok { - azureDeploymentName = c.config.GetAzureDeploymentByModel(model) - } - } - return fmt.Sprintf("%s/%s/%s/%s%s?%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, query.Encode(), - ) + baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model) + } + + if c.config.APIVersion != "" { + suffix = c.suffixWithAPIVersion(suffix) } + return fmt.Sprintf("%s%s", baseURL, suffix) +} - // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ - if c.config.APIType == APITypeCloudflareAzure { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) +func (c *Client) suffixWithAPIVersion(suffix string) string { + parsedSuffix, err := url.Parse(suffix) + if err != nil { + panic("failed to parse url suffix") } + query := parsedSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode()) +} - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) +func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { + baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + azureDeploymentName := c.config.GetAzureDeploymentByModel(model) + if azureDeploymentName == "" { + azureDeploymentName = "UNKNOWN" + } + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) + } + return baseURL } func (c *Client) handleErrorResp(resp *http.Response) error { diff --git a/client_test.go b/client_test.go index e49da9b3d..a0d3bb390 100644 --- a/client_test.go +++ b/client_test.go @@ -431,3 +431,99 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestClient_suffixWithAPIVersion(t *testing.T) { + type fields struct { + apiVersion string + } + type args struct { + suffix string + } + tests := []struct { + name string + fields fields + args args + want string + wantPanic string + }{ + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants"}, + "/assistants?api-version=2023-05", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "123:assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "failed to parse url suffix", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + config: ClientConfig{APIVersion: tt.fields.apiVersion}, + } + defer func() { + if r := recover(); r != nil { + if r.(string) != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic) + } + } + }() + if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want { + t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_baseURLWithAzureDeployment(t *testing.T) { + type args struct { + baseURL string + suffix string + model string + } + tests := []struct { + name string + args args + wantNewBaseURL string + }{ + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai/deployments/gpt-4o-mini", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""}, + "https://test.openai.azure.com/openai/deployments/UNKNOWN", + }, + } + client := NewClient("") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotNewBaseURL := client.baseURLWithAzureDeployment( + tt.args.baseURL, + tt.args.suffix, + tt.args.model, + ); gotNewBaseURL != tt.wantNewBaseURL { + t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL) + } + }) + } +} diff --git a/completion.go b/completion.go index bc2a63795..e8e9242c9 100644 --- a/completion.go +++ b/completion.go @@ -213,7 +213,12 @@ func (c *Client) CreateCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/edits.go b/edits.go index 97d026029..fe8ecd0c1 100644 --- a/edits.go +++ b/edits.go @@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024. You can use CreateChatCompletion or CreateChatCompletionStream instead. */ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/edits", withModel(fmt.Sprint(request.Model))), + withBody(request), + ) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index b513ba6a7..74eb8aa57 100644 --- a/embeddings.go +++ b/embeddings.go @@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/embeddings", withModel(string(baseReq.Model))), + withBody(baseReq), + ) if err != nil { return } diff --git a/example_test.go b/example_test.go index de67c57cd..1bdb8496e 100644 --- a/example_test.go +++ b/example_test.go @@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() { return } - fmt.Printf(response.Choices[0].Delta.Content) + fmt.Println(response.Choices[0].Delta.Content) } } diff --git a/image.go b/image.go index 665de1a74..577d7db95 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,12 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } @@ -132,8 +137,13 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/edits", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } @@ -183,8 +193,13 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/variations", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } diff --git a/moderation.go b/moderation.go index ae285ef83..c8652efc8 100644 --- a/moderation.go +++ b/moderation.go @@ -88,7 +88,12 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re err = ErrModerationInvalidModel return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/moderations", withModel(request.Model)), + withBody(&request), + ) if err != nil { return } diff --git a/speech.go b/speech.go index 19b21bdf1..20b52e334 100644 --- a/speech.go +++ b/speech.go @@ -44,7 +44,10 @@ type CreateSpeechRequest struct { } func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/audio/speech", withModel(string(request.Model))), withBody(request), withContentType("application/json"), ) diff --git a/stream.go b/stream.go index b277f3c29..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -3,6 +3,7 @@ package openai import ( "context" "errors" + "net/http" ) var ( @@ -33,7 +34,12 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err }