From dfb20a1bfdc0595da9d8073bec60512c19330274 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 7 Aug 2024 23:06:00 +0800 Subject: [PATCH] fix: fullURL endpoint generation, add tests --- client.go | 21 +++++++++++++----- client_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index d5d555c3..bf881c4d 100644 --- a/client.go +++ b/client.go @@ -229,13 +229,21 @@ func (c *Client) fullURL(suffix string, args ...any) string { if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") - parseURL, _ := url.Parse(baseURL) - query := parseURL.Query() + parseSuffix, _ := url.Parse(suffix) + query := parseSuffix.Query() query.Add("api-version", c.config.APIVersion) // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) + // the interface is not supported by AzureOpenAI, and there is currently no better handling solution. + if !containsSubstr([]string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/images/generations", + }, parseSuffix.Path) { + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, parseSuffix.Path, query.Encode()) } azureDeploymentName := "UNKNOWN" if len(args) > 0 { @@ -254,7 +262,10 @@ func (c *Client) fullURL(suffix string, args ...any) string { if c.config.APIType == APITypeCloudflareAzure { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") - return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) + parseSuffix, _ := url.Parse(suffix) + query := parseSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode()) } return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) diff --git a/client_test.go b/client_test.go index e49da9b3..6dc997da 100644 --- a/client_test.go +++ b/client_test.go @@ -431,3 +431,62 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestClient_fullURL(t *testing.T) { + type args struct { + client *Client + suffix string + args []any + } + client := NewClient("") + azureClient := NewClientWithConfig(ClientConfig{ + APIType: APITypeAzure, + BaseURL: "https://xxx.openai.azure.com/", + APIVersion: "2023-05-15", + }) + azureADClient := NewClientWithConfig(ClientConfig{ + APIType: APITypeAzureAD, + BaseURL: "https://xxx.openai.azure.com/", + APIVersion: "2023-05-15", + }) + cloudflareAzureClient := NewClientWithConfig(ClientConfig{ + APIType: APITypeCloudflareAzure, + BaseURL: "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}" + + "/azure-openai/{resource_name}/{deployment_name}", + APIVersion: "2023-05-15", + }) + suffix := fmt.Sprintf("%s?limit=10", assistantsSuffix) + tests := []struct { + name string + args args + want string + }{ + // /assistants + {"", args{client: client, suffix: suffix, args: nil}, + "https://api.openai.com/v1/assistants?limit=10"}, + {"", args{client: azureClient, suffix: suffix, args: nil}, + "https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"}, + {"", args{client: azureADClient, suffix: suffix, args: nil}, + "https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"}, + {"", args{client: cloudflareAzureClient, suffix: suffix, args: nil}, + "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" + + "/{resource_name}/{deployment_name}/assistants?api-version=2023-05-15&limit=10"}, + // /chat/completions + {"", args{client: client, suffix: chatCompletionsSuffix, args: nil}, + "https://api.openai.com/v1/chat/completions"}, + {"", args{client: azureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, + "https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"}, + {"", args{client: azureADClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, + "https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"}, + {"", args{client: cloudflareAzureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, + "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" + + "/{resource_name}/{deployment_name}/chat/completions?api-version=2023-05-15"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.args.client.fullURL(tt.args.suffix, tt.args.args...); got != tt.want { + t.Errorf("fullURL() = %v, want %v", got, tt.want) + } + }) + } +}