From dfb20a1bfdc0595da9d8073bec60512c19330274 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 7 Aug 2024 23:06:00 +0800 Subject: [PATCH 01/18] fix: fullURL endpoint generation, add tests --- client.go | 21 +++++++++++++----- client_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index d5d555c3d..bf881c4dd 100644 --- a/client.go +++ b/client.go @@ -229,13 +229,21 @@ func (c *Client) fullURL(suffix string, args ...any) string { if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") - parseURL, _ := url.Parse(baseURL) - query := parseURL.Query() + parseSuffix, _ := url.Parse(suffix) + query := parseSuffix.Query() query.Add("api-version", c.config.APIVersion) // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) + // the interface is not supported by AzureOpenAI, and there is currently no better handling solution. + if !containsSubstr([]string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/images/generations", + }, parseSuffix.Path) { + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, parseSuffix.Path, query.Encode()) } azureDeploymentName := "UNKNOWN" if len(args) > 0 { @@ -254,7 +262,10 @@ func (c *Client) fullURL(suffix string, args ...any) string { if c.config.APIType == APITypeCloudflareAzure { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") - return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) + parseSuffix, _ := url.Parse(suffix) + query := parseSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode()) } return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) diff --git a/client_test.go b/client_test.go index e49da9b3d..6dc997daf 100644 --- a/client_test.go +++ b/client_test.go @@ -431,3 +431,62 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestClient_fullURL(t *testing.T) { + type args struct { + client *Client + suffix string + args []any + } + client := NewClient("") + azureClient := NewClientWithConfig(ClientConfig{ + APIType: APITypeAzure, + BaseURL: "https://xxx.openai.azure.com/", + APIVersion: "2023-05-15", + }) + azureADClient := NewClientWithConfig(ClientConfig{ + APIType: APITypeAzureAD, + BaseURL: "https://xxx.openai.azure.com/", + APIVersion: "2023-05-15", + }) + cloudflareAzureClient := NewClientWithConfig(ClientConfig{ + APIType: APITypeCloudflareAzure, + BaseURL: "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}" + + "/azure-openai/{resource_name}/{deployment_name}", + APIVersion: "2023-05-15", + }) + suffix := fmt.Sprintf("%s?limit=10", assistantsSuffix) + tests := []struct { + name string + args args + want string + }{ + // /assistants + {"", args{client: client, suffix: suffix, args: nil}, + "https://api.openai.com/v1/assistants?limit=10"}, + {"", args{client: azureClient, suffix: suffix, args: nil}, + "https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"}, + {"", args{client: azureADClient, suffix: suffix, args: nil}, + "https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"}, + {"", args{client: cloudflareAzureClient, suffix: suffix, args: nil}, + "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" + + "/{resource_name}/{deployment_name}/assistants?api-version=2023-05-15&limit=10"}, + // /chat/completions + {"", args{client: client, suffix: chatCompletionsSuffix, args: nil}, + "https://api.openai.com/v1/chat/completions"}, + {"", args{client: azureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, + "https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"}, + {"", args{client: azureADClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, + "https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"}, + {"", args{client: cloudflareAzureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, + "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" + + "/{resource_name}/{deployment_name}/chat/completions?api-version=2023-05-15"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.args.client.fullURL(tt.args.suffix, tt.args.args...); got != tt.want { + t.Errorf("fullURL() = %v, want %v", got, tt.want) + } + }) + } +} From f4cb38bcc59963dbe64ac49cbe396a2c09cdcdf1 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 17:36:04 +0800 Subject: [PATCH 02/18] add `/audio/speech` to AzureOpenAI --- client.go | 15 +++------------ client_test.go | 27 ++++++--------------------- config.go | 8 ++++++++ 3 files changed, 17 insertions(+), 33 deletions(-) diff --git a/client.go b/client.go index bf881c4dd..25cb46839 100644 --- a/client.go +++ b/client.go @@ -183,7 +183,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { + if c.config.APIType.IsAzure() { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else if c.config.authToken != "" { // OpenAI or Azure AD authentication @@ -226,7 +226,7 @@ func decodeString(body io.Reader, output *string) error { // args[0] is model name, if API type is Azure, model name is required to get deployment name. func (c *Client) fullURL(suffix string, args ...any) string { // /openai/deployments/{model}/chat/completions?api-version={api_version} - if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { + if c.config.APIType.IsAzure() { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") parseSuffix, _ := url.Parse(suffix) @@ -241,6 +241,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { "/chat/completions", "/audio/transcriptions", "/audio/translations", + "/audio/speech", "/images/generations", }, parseSuffix.Path) { return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, parseSuffix.Path, query.Encode()) @@ -258,16 +259,6 @@ func (c *Client) fullURL(suffix string, args ...any) string { ) } - // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ - if c.config.APIType == APITypeCloudflareAzure { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - parseSuffix, _ := url.Parse(suffix) - query := parseSuffix.Query() - query.Add("api-version", c.config.APIVersion) - return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode()) - } - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } diff --git a/client_test.go b/client_test.go index 6dc997daf..4efa17ba9 100644 --- a/client_test.go +++ b/client_test.go @@ -444,17 +444,6 @@ func TestClient_fullURL(t *testing.T) { BaseURL: "https://xxx.openai.azure.com/", APIVersion: "2023-05-15", }) - azureADClient := NewClientWithConfig(ClientConfig{ - APIType: APITypeAzureAD, - BaseURL: "https://xxx.openai.azure.com/", - APIVersion: "2023-05-15", - }) - cloudflareAzureClient := NewClientWithConfig(ClientConfig{ - APIType: APITypeCloudflareAzure, - BaseURL: "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}" + - "/azure-openai/{resource_name}/{deployment_name}", - APIVersion: "2023-05-15", - }) suffix := fmt.Sprintf("%s?limit=10", assistantsSuffix) tests := []struct { name string @@ -466,21 +455,17 @@ func TestClient_fullURL(t *testing.T) { "https://api.openai.com/v1/assistants?limit=10"}, {"", args{client: azureClient, suffix: suffix, args: nil}, "https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"}, - {"", args{client: azureADClient, suffix: suffix, args: nil}, - "https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"}, - {"", args{client: cloudflareAzureClient, suffix: suffix, args: nil}, - "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" + - "/{resource_name}/{deployment_name}/assistants?api-version=2023-05-15&limit=10"}, // /chat/completions {"", args{client: client, suffix: chatCompletionsSuffix, args: nil}, "https://api.openai.com/v1/chat/completions"}, {"", args{client: azureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, "https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"}, - {"", args{client: azureADClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, - "https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"}, - {"", args{client: cloudflareAzureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, - "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" + - "/{resource_name}/{deployment_name}/chat/completions?api-version=2023-05-15"}, + + // /audio/speech + {"", args{client: client, suffix: "/audio/speech", args: nil}, + "https://api.openai.com/v1/audio/speech"}, + {"", args{client: azureClient, suffix: "/audio/speech", args: []any{string(TTSModel1HD)}}, + "https://xxx.openai.azure.com/openai/deployments/tts-1-hd/audio/speech?api-version=2023-05-15"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/config.go b/config.go index 1347567d7..abfda16b5 100644 --- a/config.go +++ b/config.go @@ -15,6 +15,14 @@ const ( type APIType string +func (r APIType) IsAzure() bool { + switch r { + case APITypeAzure, APITypeAzureAD, APITypeCloudflareAzure: + return true + } + return false +} + const ( APITypeOpenAI APIType = "OPEN_AI" APITypeAzure APIType = "AZURE" From 33a2a5d09b4d6ce3aa2d28215502981452d11bcf Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 17:39:24 +0800 Subject: [PATCH 03/18] fix Sanity check --- config.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index abfda16b5..ee4425138 100644 --- a/config.go +++ b/config.go @@ -19,8 +19,9 @@ func (r APIType) IsAzure() bool { switch r { case APITypeAzure, APITypeAzureAD, APITypeCloudflareAzure: return true + default: + return false } - return false } const ( From 192686f92838849299fff8e2c975a259b8d02e5e Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 17:44:41 +0800 Subject: [PATCH 04/18] fix Sanity check --- config.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/config.go b/config.go index ee4425138..943093d58 100644 --- a/config.go +++ b/config.go @@ -15,13 +15,15 @@ const ( type APIType string +var azureTypes = []APIType{APITypeAzure, APITypeAzureAD, APITypeCloudflareAzure} + func (r APIType) IsAzure() bool { - switch r { - case APITypeAzure, APITypeAzureAD, APITypeCloudflareAzure: - return true - default: - return false + for i := range azureTypes { + if r == azureTypes[i] { + return true + } } + return false } const ( From 0e4b7d911aa455ce72663c021e845fed1597f7e4 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 18:06:21 +0800 Subject: [PATCH 05/18] update TestCloudflareAzureFullURL --- api_internal_test.go | 9 ++++++--- client.go | 3 +++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index a590ec9ab..0237c3d3e 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -153,17 +153,20 @@ func TestCloudflareAzureFullURL(t *testing.T) { cases := []struct { Name string BaseURL string + Model string Expect string }{ { "CloudflareAzureBaseURLWithSlashAutoStrip", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/", + "chatgpt-demo", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, { "CloudflareAzureBaseURLWithoutSlashOK", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/", + "chatgpt-demo", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, @@ -176,7 +179,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL("/chat/completions", c.Model) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/client.go b/client.go index 25cb46839..a571f71eb 100644 --- a/client.go +++ b/client.go @@ -253,6 +253,9 @@ func (c *Client) fullURL(suffix string, args ...any) string { azureDeploymentName = c.config.GetAzureDeploymentByModel(model) } } + if c.config.APIType == APITypeCloudflareAzure { + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureDeploymentName, parseSuffix.Path, query.Encode()) + } return fmt.Sprintf("%s/%s/%s/%s%s?%s", baseURL, azureAPIPrefix, azureDeploymentsPrefix, azureDeploymentName, suffix, query.Encode(), From 3867539971651fd72f7afa6e2364fcce6cb06806 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 18:27:41 +0800 Subject: [PATCH 06/18] update fullURL --- client.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index a571f71eb..ad069d29d 100644 --- a/client.go +++ b/client.go @@ -235,6 +235,9 @@ func (c *Client) fullURL(suffix string, args ...any) string { // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP // the interface is not supported by AzureOpenAI, and there is currently no better handling solution. + if c.config.APIType != APITypeCloudflareAzure { + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureAPIPrefix, azureDeploymentsPrefix) + } if !containsSubstr([]string{ "/completions", "/embeddings", @@ -244,7 +247,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { "/audio/speech", "/images/generations", }, parseSuffix.Path) { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, parseSuffix.Path, query.Encode()) + return fmt.Sprintf("%s/%s?%s", baseURL, parseSuffix.Path, query.Encode()) } azureDeploymentName := "UNKNOWN" if len(args) > 0 { @@ -253,13 +256,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { azureDeploymentName = c.config.GetAzureDeploymentByModel(model) } } - if c.config.APIType == APITypeCloudflareAzure { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureDeploymentName, parseSuffix.Path, query.Encode()) - } - return fmt.Sprintf("%s/%s/%s/%s%s?%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, query.Encode(), - ) + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureDeploymentName, suffix, query.Encode()) } return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) From 656031e762d411db17bc16ca75154332b8a2bb5d Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 18:46:40 +0800 Subject: [PATCH 07/18] update api_internal_test.go --- api_internal_test.go | 26 ++++++++++++++++++++++++-- client.go | 4 ++-- client_test.go | 44 -------------------------------------------- 3 files changed, 26 insertions(+), 48 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 0237c3d3e..ea50f2a4b 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -112,6 +112,7 @@ func TestAzureFullURL(t *testing.T) { Name string BaseURL string AzureModelMapper map[string]string + Suffix string Model string Expect string }{ @@ -119,6 +120,7 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -128,11 +130,20 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithoutSlashOK", "https://httpbin.org", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + "/chat/completions?api-version=2023-05-15", }, + { + "", + "https://httpbin.org", + nil, + "/assistants?limit=10", + "chatgpt-demo", + "https://httpbin.org/openai/deployments/assistants?api-version=2023-05-15&limit=10", + }, } for _, c := range cases { @@ -140,7 +151,7 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions", c.Model) + actual := cli.fullURL(c.Suffix, c.Model) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } @@ -153,12 +164,14 @@ func TestCloudflareAzureFullURL(t *testing.T) { cases := []struct { Name string BaseURL string + Suffix string Model string Expect string }{ { "CloudflareAzureBaseURLWithSlashAutoStrip", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/", + "/chat/completions", "chatgpt-demo", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", @@ -166,10 +179,19 @@ func TestCloudflareAzureFullURL(t *testing.T) { { "CloudflareAzureBaseURLWithoutSlashOK", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/", + "/chat/completions", "chatgpt-demo", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, + { + "", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/", + "/assistants?limit=10", + "chatgpt-demo", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource" + + "/assistants?api-version=2023-05-15&limit=10", + }, } for _, c := range cases { @@ -179,7 +201,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL("/chat/completions", c.Model) + actual := cli.fullURL(c.Suffix, c.Model) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/client.go b/client.go index ad069d29d..3b79908f4 100644 --- a/client.go +++ b/client.go @@ -183,7 +183,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.APIType.IsAzure() { + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else if c.config.authToken != "" { // OpenAI or Azure AD authentication @@ -247,7 +247,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { "/audio/speech", "/images/generations", }, parseSuffix.Path) { - return fmt.Sprintf("%s/%s?%s", baseURL, parseSuffix.Path, query.Encode()) + return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode()) } azureDeploymentName := "UNKNOWN" if len(args) > 0 { diff --git a/client_test.go b/client_test.go index 4efa17ba9..e49da9b3d 100644 --- a/client_test.go +++ b/client_test.go @@ -431,47 +431,3 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } - -func TestClient_fullURL(t *testing.T) { - type args struct { - client *Client - suffix string - args []any - } - client := NewClient("") - azureClient := NewClientWithConfig(ClientConfig{ - APIType: APITypeAzure, - BaseURL: "https://xxx.openai.azure.com/", - APIVersion: "2023-05-15", - }) - suffix := fmt.Sprintf("%s?limit=10", assistantsSuffix) - tests := []struct { - name string - args args - want string - }{ - // /assistants - {"", args{client: client, suffix: suffix, args: nil}, - "https://api.openai.com/v1/assistants?limit=10"}, - {"", args{client: azureClient, suffix: suffix, args: nil}, - "https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"}, - // /chat/completions - {"", args{client: client, suffix: chatCompletionsSuffix, args: nil}, - "https://api.openai.com/v1/chat/completions"}, - {"", args{client: azureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}}, - "https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"}, - - // /audio/speech - {"", args{client: client, suffix: "/audio/speech", args: nil}, - "https://api.openai.com/v1/audio/speech"}, - {"", args{client: azureClient, suffix: "/audio/speech", args: []any{string(TTSModel1HD)}}, - "https://xxx.openai.azure.com/openai/deployments/tts-1-hd/audio/speech?api-version=2023-05-15"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.args.client.fullURL(tt.args.suffix, tt.args.args...); got != tt.want { - t.Errorf("fullURL() = %v, want %v", got, tt.want) - } - }) - } -} From bc298b574c0fb37572845b14061274c49c159a5c Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 8 Aug 2024 18:49:48 +0800 Subject: [PATCH 08/18] fix Sanity check --- client.go | 4 +++- config.go | 11 ----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 3b79908f4..68b0fe3ca 100644 --- a/client.go +++ b/client.go @@ -226,7 +226,9 @@ func decodeString(body io.Reader, output *string) error { // args[0] is model name, if API type is Azure, model name is required to get deployment name. func (c *Client) fullURL(suffix string, args ...any) string { // /openai/deployments/{model}/chat/completions?api-version={api_version} - if c.config.APIType.IsAzure() { + if c.config.APIType == APITypeAzure || + c.config.APIType == APITypeAzureAD || + c.config.APIType == APITypeCloudflareAzure { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") parseSuffix, _ := url.Parse(suffix) diff --git a/config.go b/config.go index 943093d58..1347567d7 100644 --- a/config.go +++ b/config.go @@ -15,17 +15,6 @@ const ( type APIType string -var azureTypes = []APIType{APITypeAzure, APITypeAzureAD, APITypeCloudflareAzure} - -func (r APIType) IsAzure() bool { - for i := range azureTypes { - if r == azureTypes[i] { - return true - } - } - return false -} - const ( APITypeOpenAI APIType = "OPEN_AI" APITypeAzure APIType = "AZURE" From 4c43288513d2c8a222c08944502c9baedc600661 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 09:48:09 +0800 Subject: [PATCH 09/18] feature: #636 --- api_internal_test.go | 4 ++-- audio.go | 2 +- chat.go | 2 +- chat_stream.go | 2 +- client.go | 36 +++++++++++++++++++++++++++--------- completion.go | 2 +- edits.go | 5 ++++- embeddings.go | 5 ++++- image.go | 6 +++--- moderation.go | 2 +- speech.go | 2 +- stream.go | 2 +- 12 files changed, 47 insertions(+), 23 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index ea50f2a4b..b6f970c97 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -151,7 +151,7 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL(c.Suffix, c.Model) + actual := cli.fullURL(c.Suffix, withModel(c.Model)) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } @@ -201,7 +201,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL(c.Suffix, c.Model) + actual := cli.fullURL(c.Suffix, withModel(c.Model)) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/audio.go b/audio.go index dbc26d154..5574a3a0e 100644 --- a/audio.go +++ b/audio.go @@ -122,7 +122,7 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(&formBody), withContentType(builder.FormDataContentType())) if err != nil { return AudioResponse{}, err diff --git a/chat.go b/chat.go index 8bfe558b5..704b7053b 100644 --- a/chat.go +++ b/chat.go @@ -358,7 +358,7 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index ffd512ff6..014018ea8 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,7 +60,7 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) if err != nil { return nil, err } diff --git a/client.go b/client.go index 68b0fe3ca..912edafaa 100644 --- a/client.go +++ b/client.go @@ -222,9 +222,21 @@ func decodeString(body io.Reader, output *string) error { return nil } +type fullURLOptions struct { + model string +} + +type fullURLOption func(*fullURLOptions) + +func withModel(model string) fullURLOption { + return func(args *fullURLOptions) { + args.model = model + } +} + // fullURL returns full URL for request. // args[0] is model name, if API type is Azure, model name is required to get deployment name. -func (c *Client) fullURL(suffix string, args ...any) string { +func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { // /openai/deployments/{model}/chat/completions?api-version={api_version} if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD || @@ -238,8 +250,18 @@ func (c *Client) fullURL(suffix string, args ...any) string { // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP // the interface is not supported by AzureOpenAI, and there is currently no better handling solution. if c.config.APIType != APITypeCloudflareAzure { - baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureAPIPrefix, azureDeploymentsPrefix) + baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) } + + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) + } + azureDeploymentName := c.config.GetAzureDeploymentByModel(args.model) + if azureDeploymentName == "" { + azureDeploymentName = "UNKNOWN" + } + if !containsSubstr([]string{ "/completions", "/embeddings", @@ -250,14 +272,10 @@ func (c *Client) fullURL(suffix string, args ...any) string { "/images/generations", }, parseSuffix.Path) { return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode()) + } else if c.config.APIType != APITypeCloudflareAzure { + baseURL = fmt.Sprintf("%s/%s", baseURL, azureDeploymentsPrefix) } - azureDeploymentName := "UNKNOWN" - if len(args) > 0 { - model, ok := args[0].(string) - if ok { - azureDeploymentName = c.config.GetAzureDeploymentByModel(model) - } - } + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureDeploymentName, suffix, query.Encode()) } diff --git a/completion.go b/completion.go index d435eb382..08bce407b 100644 --- a/completion.go +++ b/completion.go @@ -212,7 +212,7 @@ func (c *Client) CreateCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) if err != nil { return } diff --git a/edits.go b/edits.go index 97d026029..46a04c048 100644 --- a/edits.go +++ b/edits.go @@ -38,7 +38,10 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024. You can use CreateChatCompletion or CreateChatCompletionStream instead. */ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL( + "/edits", + withModel(fmt.Sprint(request.Model)), + ), withBody(request)) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index b513ba6a7..c9de6837e 100644 --- a/embeddings.go +++ b/embeddings.go @@ -241,7 +241,10 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL( + "/embeddings", + withModel(string(baseReq.Model)), + ), withBody(baseReq)) if err != nil { return } diff --git a/image.go b/image.go index 665de1a74..213ed90f7 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) if err != nil { return } @@ -132,7 +132,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", withModel(request.Model)), withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return @@ -183,7 +183,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", withModel(request.Model)), withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return diff --git a/moderation.go b/moderation.go index ae285ef83..855d3cc0b 100644 --- a/moderation.go +++ b/moderation.go @@ -88,7 +88,7 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re err = ErrModerationInvalidModel return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", withModel(request.Model)), withBody(&request)) if err != nil { return } diff --git a/speech.go b/speech.go index 19b21bdf1..bbdffd54e 100644 --- a/speech.go +++ b/speech.go @@ -44,7 +44,7 @@ type CreateSpeechRequest struct { } func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", withModel(string(request.Model))), withBody(request), withContentType("application/json"), ) diff --git a/stream.go b/stream.go index b277f3c29..ad7c956f0 100644 --- a/stream.go +++ b/stream.go @@ -33,7 +33,7 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) if err != nil { return nil, err } From c5245232ab54f138f2587179e0f43e381e3282b1 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 14:40:40 +0800 Subject: [PATCH 10/18] update fullURL --- api_internal_test.go | 2 +- client.go | 38 +++++++++++++++++++++----------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index b6f970c97..18fc563e3 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -142,7 +142,7 @@ func TestAzureFullURL(t *testing.T) { nil, "/assistants?limit=10", "chatgpt-demo", - "https://httpbin.org/openai/deployments/assistants?api-version=2023-05-15&limit=10", + "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", }, } diff --git a/client.go b/client.go index 912edafaa..c457e4d8c 100644 --- a/client.go +++ b/client.go @@ -237,22 +237,15 @@ func withModel(model string) fullURLOption { // fullURL returns full URL for request. // args[0] is model name, if API type is Azure, model name is required to get deployment name. func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { - // /openai/deployments/{model}/chat/completions?api-version={api_version} if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD || c.config.APIType == APITypeCloudflareAzure { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") parseSuffix, _ := url.Parse(suffix) + suffix = parseSuffix.Path query := parseSuffix.Query() query.Add("api-version", c.config.APIVersion) - // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 - // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - // the interface is not supported by AzureOpenAI, and there is currently no better handling solution. - if c.config.APIType != APITypeCloudflareAzure { - baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) - } - args := fullURLOptions{} for _, setter := range setters { setter(&args) @@ -261,8 +254,7 @@ func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { if azureDeploymentName == "" { azureDeploymentName = "UNKNOWN" } - - if !containsSubstr([]string{ + inEndpoints := containsSubstr([]string{ "/completions", "/embeddings", "/chat/completions", @@ -270,15 +262,27 @@ func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { "/audio/translations", "/audio/speech", "/images/generations", - }, parseSuffix.Path) { - return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode()) - } else if c.config.APIType != APITypeCloudflareAzure { - baseURL = fmt.Sprintf("%s/%s", baseURL, azureDeploymentsPrefix) + }, suffix) + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { + baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) + if inEndpoints { + return fmt.Sprintf("%s/%s/%s%s?%s", + baseURL, + azureDeploymentsPrefix, + azureDeploymentName, + suffix, + query.Encode(), + ) + } + return fmt.Sprintf("%s%s?%s", baseURL, suffix, query.Encode()) + } + if c.config.APIType == APITypeCloudflareAzure { + if inEndpoints { + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureDeploymentName, suffix, query.Encode()) + } + return fmt.Sprintf("%s%s?%s", baseURL, suffix, query.Encode()) } - - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureDeploymentName, suffix, query.Encode()) } - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } From 194007e071e40b5a2f17388251c8a1e1b5feb64d Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 15:01:28 +0800 Subject: [PATCH 11/18] fix Sanity check --- client.go | 88 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/client.go b/client.go index c457e4d8c..688db646f 100644 --- a/client.go +++ b/client.go @@ -235,55 +235,50 @@ func withModel(model string) fullURLOption { } // fullURL returns full URL for request. -// args[0] is model name, if API type is Azure, model name is required to get deployment name. func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { - if c.config.APIType == APITypeAzure || - c.config.APIType == APITypeAzureAD || - c.config.APIType == APITypeCloudflareAzure { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - parseSuffix, _ := url.Parse(suffix) - suffix = parseSuffix.Path - query := parseSuffix.Query() - query.Add("api-version", c.config.APIVersion) - args := fullURLOptions{} - for _, setter := range setters { - setter(&args) - } + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) + } + if contains([]APIType{APITypeAzure, APITypeAzureAD, APITypeCloudflareAzure}, c.config.APIType) { azureDeploymentName := c.config.GetAzureDeploymentByModel(args.model) if azureDeploymentName == "" { azureDeploymentName = "UNKNOWN" } - inEndpoints := containsSubstr([]string{ - "/completions", - "/embeddings", - "/chat/completions", - "/audio/transcriptions", - "/audio/translations", - "/audio/speech", - "/images/generations", - }, suffix) - if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) - if inEndpoints { - return fmt.Sprintf("%s/%s/%s%s?%s", - baseURL, - azureDeploymentsPrefix, - azureDeploymentName, - suffix, - query.Encode(), - ) - } - return fmt.Sprintf("%s%s?%s", baseURL, suffix, query.Encode()) - } - if c.config.APIType == APITypeCloudflareAzure { - if inEndpoints { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureDeploymentName, suffix, query.Encode()) - } - return fmt.Sprintf("%s%s?%s", baseURL, suffix, query.Encode()) + return c.azureFullURL(suffix, azureDeploymentName) + } + baseURL := strings.TrimRight(c.config.BaseURL, "/") + return fmt.Sprintf("%s%s", baseURL, suffix) +} + +func (c *Client) azureFullURL(suffix string, deployment string) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") + parseSuffix, _ := url.Parse(suffix) + suffix = parseSuffix.Path + query := parseSuffix.Query() + query.Add("api-version", c.config.APIVersion) + inEndpoints := containsSubstr([]string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", + }, suffix) + + if c.config.APIType == APITypeCloudflareAzure { + if inEndpoints { + return fmt.Sprintf("%s/%s%s?%s", baseURL, deployment, suffix, query.Encode()) } + return fmt.Sprintf("%s%s?%s", baseURL, suffix, query.Encode()) } - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) + + baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) + if inEndpoints { + return fmt.Sprintf("%s/%s/%s%s?%s", baseURL, azureDeploymentsPrefix, deployment, suffix, query.Encode()) + } + return fmt.Sprintf("%s%s?%s", baseURL, suffix, query.Encode()) } func (c *Client) handleErrorResp(resp *http.Response) error { @@ -312,3 +307,12 @@ func containsSubstr(s []string, e string) bool { } return false } + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} From 18a57255c5ac1349ba49b11e1e4ba47beb233f47 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 18:52:48 +0800 Subject: [PATCH 12/18] update TestCloudflareAzureFullURL --- api_internal_test.go | 19 ++++----------- client.go | 57 +++++++++++++++++++------------------------- 2 files changed, 28 insertions(+), 48 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 18fc563e3..09677968a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -165,31 +165,20 @@ func TestCloudflareAzureFullURL(t *testing.T) { Name string BaseURL string Suffix string - Model string Expect string }{ { "CloudflareAzureBaseURLWithSlashAutoStrip", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", "/chat/completions", - "chatgpt-demo", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + - "chat/completions?api-version=2023-05-15", - }, - { - "CloudflareAzureBaseURLWithoutSlashOK", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/", - "/chat/completions", - "chatgpt-demo", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, { "", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", "/assistants?limit=10", - "chatgpt-demo", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource" + + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + "/assistants?api-version=2023-05-15&limit=10", }, } @@ -201,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL(c.Suffix, withModel(c.Model)) + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/client.go b/client.go index 688db646f..267f5ee76 100644 --- a/client.go +++ b/client.go @@ -234,51 +234,42 @@ func withModel(model string) fullURLOption { } } +var azureDeploymentsEndpoints = []string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", +} + // fullURL returns full URL for request. func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") args := fullURLOptions{} for _, setter := range setters { setter(&args) } - if contains([]APIType{APITypeAzure, APITypeAzureAD, APITypeCloudflareAzure}, c.config.APIType) { + + if c.config.APIVersion != "" { + parseSuffix, _ := url.Parse(suffix) + query := parseSuffix.Query() + query.Add("api-version", c.config.APIVersion) + suffix = fmt.Sprintf("%s?%s", parseSuffix.Path, query.Encode()) + } + + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { azureDeploymentName := c.config.GetAzureDeploymentByModel(args.model) if azureDeploymentName == "" { azureDeploymentName = "UNKNOWN" } - return c.azureFullURL(suffix, azureDeploymentName) - } - baseURL := strings.TrimRight(c.config.BaseURL, "/") - return fmt.Sprintf("%s%s", baseURL, suffix) -} - -func (c *Client) azureFullURL(suffix string, deployment string) string { - baseURL := strings.TrimRight(c.config.BaseURL, "/") - parseSuffix, _ := url.Parse(suffix) - suffix = parseSuffix.Path - query := parseSuffix.Query() - query.Add("api-version", c.config.APIVersion) - inEndpoints := containsSubstr([]string{ - "/completions", - "/embeddings", - "/chat/completions", - "/audio/transcriptions", - "/audio/translations", - "/audio/speech", - "/images/generations", - }, suffix) - - if c.config.APIType == APITypeCloudflareAzure { - if inEndpoints { - return fmt.Sprintf("%s/%s%s?%s", baseURL, deployment, suffix, query.Encode()) + baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) } - return fmt.Sprintf("%s%s?%s", baseURL, suffix, query.Encode()) } - - baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) - if inEndpoints { - return fmt.Sprintf("%s/%s/%s%s?%s", baseURL, azureDeploymentsPrefix, deployment, suffix, query.Encode()) - } - return fmt.Sprintf("%s%s?%s", baseURL, suffix, query.Encode()) + return fmt.Sprintf("%s%s", baseURL, suffix) } func (c *Client) handleErrorResp(resp *http.Response) error { From 88fddbe4c23f7535dfb22d50447f3cf6904287fb Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 9 Aug 2024 18:55:47 +0800 Subject: [PATCH 13/18] fix Sanity check --- client.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/client.go b/client.go index 267f5ee76..ad1ca80a9 100644 --- a/client.go +++ b/client.go @@ -298,12 +298,3 @@ func containsSubstr(s []string, e string) bool { } return false } - -func contains[S ~[]E, E comparable](s S, v E) bool { - for i := range s { - if v == s[i] { - return true - } - } - return false -} From 71311b9e9bb895d36ea3d0ef8ce20d5b00f2c8ca Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 14 Aug 2024 10:35:22 +0800 Subject: [PATCH 14/18] update fullURL --- audio.go | 9 +++++++-- chat.go | 7 ++++++- chat_stream.go | 7 ++++++- client.go | 34 +++++++++++++++++++++++----------- completion.go | 7 ++++++- edits.go | 10 ++++++---- embeddings.go | 10 ++++++---- image.go | 25 ++++++++++++++++++++----- moderation.go | 7 ++++++- speech.go | 5 ++++- stream.go | 8 +++++++- 11 files changed, 97 insertions(+), 32 deletions(-) diff --git a/audio.go b/audio.go index 5574a3a0e..f321f93d6 100644 --- a/audio.go +++ b/audio.go @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(&formBody), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(&formBody), + withContentType(builder.FormDataContentType()), + ) if err != nil { return AudioResponse{}, err } diff --git a/chat.go b/chat.go index 704b7053b..cfc93e0b5 100644 --- a/chat.go +++ b/chat.go @@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 014018ea8..3f90bc019 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/client.go b/client.go index ad1ca80a9..a8bdf6f01 100644 --- a/client.go +++ b/client.go @@ -252,24 +252,36 @@ func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { setter(&args) } + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { + baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model) + } + if c.config.APIVersion != "" { - parseSuffix, _ := url.Parse(suffix) - query := parseSuffix.Query() - query.Add("api-version", c.config.APIVersion) - suffix = fmt.Sprintf("%s?%s", parseSuffix.Path, query.Encode()) + suffix = c.suffixWithAPIVersion(suffix) } + return fmt.Sprintf("%s%s", baseURL, suffix) +} - if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - azureDeploymentName := c.config.GetAzureDeploymentByModel(args.model) +func (c *Client) suffixWithAPIVersion(suffix string) string { + parsedSuffix, err := url.Parse(suffix) + if err != nil { + panic("failed to parse url suffix") + } + query := parsedSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode()) +} + +func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { + baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + azureDeploymentName := c.config.GetAzureDeploymentByModel(model) if azureDeploymentName == "" { azureDeploymentName = "UNKNOWN" } - baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) - if containsSubstr(azureDeploymentsEndpoints, suffix) { - baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) - } + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) } - return fmt.Sprintf("%s%s", baseURL, suffix) + return baseURL } func (c *Client) handleErrorResp(resp *http.Response) error { diff --git a/completion.go b/completion.go index 08bce407b..8761d30bb 100644 --- a/completion.go +++ b/completion.go @@ -212,7 +212,12 @@ func (c *Client) CreateCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/edits.go b/edits.go index 46a04c048..fe8ecd0c1 100644 --- a/edits.go +++ b/edits.go @@ -38,10 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024. You can use CreateChatCompletion or CreateChatCompletionStream instead. */ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL( - "/edits", - withModel(fmt.Sprint(request.Model)), - ), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/edits", withModel(fmt.Sprint(request.Model))), + withBody(request), + ) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index c9de6837e..74eb8aa57 100644 --- a/embeddings.go +++ b/embeddings.go @@ -241,10 +241,12 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL( - "/embeddings", - withModel(string(baseReq.Model)), - ), withBody(baseReq)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/embeddings", withModel(string(baseReq.Model))), + withBody(baseReq), + ) if err != nil { return } diff --git a/image.go b/image.go index 213ed90f7..577d7db95 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,12 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } @@ -132,8 +137,13 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", withModel(request.Model)), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/edits", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } @@ -183,8 +193,13 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", withModel(request.Model)), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/variations", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } diff --git a/moderation.go b/moderation.go index 855d3cc0b..c8652efc8 100644 --- a/moderation.go +++ b/moderation.go @@ -88,7 +88,12 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re err = ErrModerationInvalidModel return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", withModel(request.Model)), withBody(&request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/moderations", withModel(request.Model)), + withBody(&request), + ) if err != nil { return } diff --git a/speech.go b/speech.go index bbdffd54e..20b52e334 100644 --- a/speech.go +++ b/speech.go @@ -44,7 +44,10 @@ type CreateSpeechRequest struct { } func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", withModel(string(request.Model))), + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/audio/speech", withModel(string(request.Model))), withBody(request), withContentType("application/json"), ) diff --git a/stream.go b/stream.go index ad7c956f0..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -3,6 +3,7 @@ package openai import ( "context" "errors" + "net/http" ) var ( @@ -33,7 +34,12 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, withModel(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } From 903fcbf8b781e8c86a026444b2185efc90d4b6e2 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 14 Aug 2024 22:46:06 +0800 Subject: [PATCH 15/18] fix Sanity check --- example_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_test.go b/example_test.go index de67c57cd..1bdb8496e 100644 --- a/example_test.go +++ b/example_test.go @@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() { return } - fmt.Printf(response.Choices[0].Delta.Content) + fmt.Println(response.Choices[0].Delta.Content) } } From 0fb96ea210f8034a20fe7bfdb91b7dd9a96ed78c Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 14 Aug 2024 23:36:38 +0800 Subject: [PATCH 16/18] add test cases --- client.go | 2 +- client_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index a8bdf6f01..9f547e7cb 100644 --- a/client.go +++ b/client.go @@ -273,7 +273,7 @@ func (c *Client) suffixWithAPIVersion(suffix string) string { } func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { - baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix) + baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix) if containsSubstr(azureDeploymentsEndpoints, suffix) { azureDeploymentName := c.config.GetAzureDeploymentByModel(model) if azureDeploymentName == "" { diff --git a/client_test.go b/client_test.go index e49da9b3d..5a62a571f 100644 --- a/client_test.go +++ b/client_test.go @@ -431,3 +431,78 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestClient_suffixWithAPIVersion(t *testing.T) { + type fields struct { + apiVersion string + } + type args struct { + suffix string + } + tests := []struct { + name string + fields fields + args args + want string + }{ + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants"}, + "/assistants?api-version=2023-05", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + config: ClientConfig{APIVersion: tt.fields.apiVersion}, + } + if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want { + t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_baseURLWithAzureDeployment(t *testing.T) { + + type args struct { + baseURL string + suffix string + model string + } + tests := []struct { + name string + args args + wantNewBaseURL string + }{ + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai/deployments/gpt-4o-mini", + }, + } + client := NewClient("") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotNewBaseURL := client.baseURLWithAzureDeployment( + tt.args.baseURL, + tt.args.suffix, + tt.args.model, + ); gotNewBaseURL != tt.wantNewBaseURL { + t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL) + } + }) + } +} From fbe6e1c1b1bd715d3e6867a139b1482278e2518b Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 14 Aug 2024 23:39:44 +0800 Subject: [PATCH 17/18] fix Sanity check --- client_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client_test.go b/client_test.go index 5a62a571f..5774ae5c5 100644 --- a/client_test.go +++ b/client_test.go @@ -471,7 +471,6 @@ func TestClient_suffixWithAPIVersion(t *testing.T) { } func TestClient_baseURLWithAzureDeployment(t *testing.T) { - type args struct { baseURL string suffix string From 2c75de4f2e7754b9d7b2398850068be397b01085 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 15 Aug 2024 00:14:36 +0800 Subject: [PATCH 18/18] update test cases --- client_test.go | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/client_test.go b/client_test.go index 5774ae5c5..a0d3bb390 100644 --- a/client_test.go +++ b/client_test.go @@ -440,22 +440,32 @@ func TestClient_suffixWithAPIVersion(t *testing.T) { suffix string } tests := []struct { - name string - fields fields - args args - want string + name string + fields fields + args args + want string + wantPanic string }{ { "", fields{apiVersion: "2023-05"}, args{suffix: "/assistants"}, "/assistants?api-version=2023-05", + "", }, { "", fields{apiVersion: "2023-05"}, args{suffix: "/assistants?limit=5"}, "/assistants?api-version=2023-05&limit=5", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "123:assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "failed to parse url suffix", }, } for _, tt := range tests { @@ -463,6 +473,13 @@ func TestClient_suffixWithAPIVersion(t *testing.T) { c := &Client{ config: ClientConfig{APIVersion: tt.fields.apiVersion}, } + defer func() { + if r := recover(); r != nil { + if r.(string) != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic) + } + } + }() if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want { t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want) } @@ -491,6 +508,11 @@ func TestClient_baseURLWithAzureDeployment(t *testing.T) { args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini}, "https://test.openai.azure.com/openai/deployments/gpt-4o-mini", }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""}, + "https://test.openai.azure.com/openai/deployments/UNKNOWN", + }, } client := NewClient("") for _, tt := range tests {