Skip to content

Commit

Permalink
fix: fullURL endpoint generation, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eiixy committed Aug 7, 2024
1 parent 623074c commit dfb20a1
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 5 deletions.
21 changes: 16 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,21 @@ func (c *Client) fullURL(suffix string, args ...any) string {
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
parseURL, _ := url.Parse(baseURL)
query := parseURL.Query()
parseSuffix, _ := url.Parse(suffix)
query := parseSuffix.Query()
query.Add("api-version", c.config.APIVersion)
// if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) {
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode())
// the interface is not supported by AzureOpenAI, and there is currently no better handling solution.
if !containsSubstr([]string{
"/completions",
"/embeddings",
"/chat/completions",
"/audio/transcriptions",
"/audio/translations",
"/images/generations",
}, parseSuffix.Path) {
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, parseSuffix.Path, query.Encode())
}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
Expand All @@ -254,7 +262,10 @@ func (c *Client) fullURL(suffix string, args ...any) string {
if c.config.APIType == APITypeCloudflareAzure {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
parseSuffix, _ := url.Parse(suffix)
query := parseSuffix.Query()
query.Add("api-version", c.config.APIVersion)
return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode())
}

return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
Expand Down
59 changes: 59 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,62 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

func TestClient_fullURL(t *testing.T) {
type args struct {
client *Client
suffix string
args []any
}
client := NewClient("")
azureClient := NewClientWithConfig(ClientConfig{
APIType: APITypeAzure,
BaseURL: "https://xxx.openai.azure.com/",
APIVersion: "2023-05-15",
})
azureADClient := NewClientWithConfig(ClientConfig{
APIType: APITypeAzureAD,
BaseURL: "https://xxx.openai.azure.com/",
APIVersion: "2023-05-15",
})
cloudflareAzureClient := NewClientWithConfig(ClientConfig{
APIType: APITypeCloudflareAzure,
BaseURL: "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}" +
"/azure-openai/{resource_name}/{deployment_name}",
APIVersion: "2023-05-15",
})
suffix := fmt.Sprintf("%s?limit=10", assistantsSuffix)
tests := []struct {
name string
args args
want string
}{
// /assistants
{"", args{client: client, suffix: suffix, args: nil},
"https://api.openai.com/v1/assistants?limit=10"},
{"", args{client: azureClient, suffix: suffix, args: nil},
"https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"},
{"", args{client: azureADClient, suffix: suffix, args: nil},
"https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"},
{"", args{client: cloudflareAzureClient, suffix: suffix, args: nil},
"https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" +
"/{resource_name}/{deployment_name}/assistants?api-version=2023-05-15&limit=10"},
// /chat/completions
{"", args{client: client, suffix: chatCompletionsSuffix, args: nil},
"https://api.openai.com/v1/chat/completions"},
{"", args{client: azureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}},
"https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"},
{"", args{client: azureADClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}},
"https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"},
{"", args{client: cloudflareAzureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}},
"https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" +
"/{resource_name}/{deployment_name}/chat/completions?api-version=2023-05-15"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.args.client.fullURL(tt.args.suffix, tt.args.args...); got != tt.want {
t.Errorf("fullURL() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit dfb20a1

Please sign in to comment.