Skip to content

Commit

Permalink
update api_internal_test.go
Browse files Browse the repository at this point in the history
  • Loading branch information
eiixy committed Aug 8, 2024
1 parent 3867539 commit 656031e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 48 deletions.
26 changes: 24 additions & 2 deletions api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,15 @@ func TestAzureFullURL(t *testing.T) {
Name string
BaseURL string
AzureModelMapper map[string]string
Suffix string
Model string
Expect string
}{
{
"AzureBaseURLWithSlashAutoStrip",
"https://httpbin.org/",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
Expand All @@ -128,19 +130,28 @@ func TestAzureFullURL(t *testing.T) {
"AzureBaseURLWithoutSlashOK",
"https://httpbin.org",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
"/chat/completions?api-version=2023-05-15",
},
{
"",
"https://httpbin.org",
nil,
"/assistants?limit=10",
"chatgpt-demo",
"https://httpbin.org/openai/deployments/assistants?api-version=2023-05-15&limit=10",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL)
cli := NewClientWithConfig(az)
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions", c.Model)
actual := cli.fullURL(c.Suffix, c.Model)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand All @@ -153,23 +164,34 @@ func TestCloudflareAzureFullURL(t *testing.T) {
cases := []struct {
Name string
BaseURL string
Suffix string
Model string
Expect string
}{
{
"CloudflareAzureBaseURLWithSlashAutoStrip",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/",
"/chat/completions",
"chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
},
{
"CloudflareAzureBaseURLWithoutSlashOK",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/",
"/chat/completions",
"chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
},
{
"",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/",
"/assistants?limit=10",
"chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource" +
"/assistants?api-version=2023-05-15&limit=10",
},
}

for _, c := range cases {
Expand All @@ -179,7 +201,7 @@ func TestCloudflareAzureFullURL(t *testing.T) {

cli := NewClientWithConfig(az)

actual := cli.fullURL("/chat/completions", c.Model)
actual := cli.fullURL(c.Suffix, c.Model)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand Down
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType.IsAzure() {
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else if c.config.authToken != "" {
// OpenAI or Azure AD authentication
Expand Down Expand Up @@ -247,7 +247,7 @@ func (c *Client) fullURL(suffix string, args ...any) string {
"/audio/speech",
"/images/generations",
}, parseSuffix.Path) {
return fmt.Sprintf("%s/%s?%s", baseURL, parseSuffix.Path, query.Encode())
return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode())
}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
Expand Down
44 changes: 0 additions & 44 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,47 +431,3 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

func TestClient_fullURL(t *testing.T) {
type args struct {
client *Client
suffix string
args []any
}
client := NewClient("")
azureClient := NewClientWithConfig(ClientConfig{
APIType: APITypeAzure,
BaseURL: "https://xxx.openai.azure.com/",
APIVersion: "2023-05-15",
})
suffix := fmt.Sprintf("%s?limit=10", assistantsSuffix)
tests := []struct {
name string
args args
want string
}{
// /assistants
{"", args{client: client, suffix: suffix, args: nil},
"https://api.openai.com/v1/assistants?limit=10"},
{"", args{client: azureClient, suffix: suffix, args: nil},
"https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"},
// /chat/completions
{"", args{client: client, suffix: chatCompletionsSuffix, args: nil},
"https://api.openai.com/v1/chat/completions"},
{"", args{client: azureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}},
"https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"},

// /audio/speech
{"", args{client: client, suffix: "/audio/speech", args: nil},
"https://api.openai.com/v1/audio/speech"},
{"", args{client: azureClient, suffix: "/audio/speech", args: []any{string(TTSModel1HD)}},
"https://xxx.openai.azure.com/openai/deployments/tts-1-hd/audio/speech?api-version=2023-05-15"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.args.client.fullURL(tt.args.suffix, tt.args.args...); got != tt.want {
t.Errorf("fullURL() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 656031e

Please sign in to comment.