Skip to content

Commit

Permalink
add /audio/speech to AzureOpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
eiixy committed Aug 8, 2024
1 parent dfb20a1 commit f4cb38b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 33 deletions.
15 changes: 3 additions & 12 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
if c.config.APIType.IsAzure() {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else if c.config.authToken != "" {
// OpenAI or Azure AD authentication
Expand Down Expand Up @@ -226,7 +226,7 @@ func decodeString(body io.Reader, output *string) error {
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
func (c *Client) fullURL(suffix string, args ...any) string {
// /openai/deployments/{model}/chat/completions?api-version={api_version}
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
if c.config.APIType.IsAzure() {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
parseSuffix, _ := url.Parse(suffix)
Expand All @@ -241,6 +241,7 @@ func (c *Client) fullURL(suffix string, args ...any) string {
"/chat/completions",
"/audio/transcriptions",
"/audio/translations",
"/audio/speech",
"/images/generations",
}, parseSuffix.Path) {
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, parseSuffix.Path, query.Encode())
Expand All @@ -258,16 +259,6 @@ func (c *Client) fullURL(suffix string, args ...any) string {
)
}

// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
if c.config.APIType == APITypeCloudflareAzure {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
parseSuffix, _ := url.Parse(suffix)
query := parseSuffix.Query()
query.Add("api-version", c.config.APIVersion)
return fmt.Sprintf("%s%s?%s", baseURL, parseSuffix.Path, query.Encode())
}

return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}

Expand Down
27 changes: 6 additions & 21 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,17 +444,6 @@ func TestClient_fullURL(t *testing.T) {
BaseURL: "https://xxx.openai.azure.com/",
APIVersion: "2023-05-15",
})
azureADClient := NewClientWithConfig(ClientConfig{
APIType: APITypeAzureAD,
BaseURL: "https://xxx.openai.azure.com/",
APIVersion: "2023-05-15",
})
cloudflareAzureClient := NewClientWithConfig(ClientConfig{
APIType: APITypeCloudflareAzure,
BaseURL: "https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}" +
"/azure-openai/{resource_name}/{deployment_name}",
APIVersion: "2023-05-15",
})
suffix := fmt.Sprintf("%s?limit=10", assistantsSuffix)
tests := []struct {
name string
Expand All @@ -466,21 +455,17 @@ func TestClient_fullURL(t *testing.T) {
"https://api.openai.com/v1/assistants?limit=10"},
{"", args{client: azureClient, suffix: suffix, args: nil},
"https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"},
{"", args{client: azureADClient, suffix: suffix, args: nil},
"https://xxx.openai.azure.com/openai/assistants?api-version=2023-05-15&limit=10"},
{"", args{client: cloudflareAzureClient, suffix: suffix, args: nil},
"https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" +
"/{resource_name}/{deployment_name}/assistants?api-version=2023-05-15&limit=10"},
// /chat/completions
{"", args{client: client, suffix: chatCompletionsSuffix, args: nil},
"https://api.openai.com/v1/chat/completions"},
{"", args{client: azureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}},
"https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"},
{"", args{client: azureADClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}},
"https://xxx.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-05-15"},
{"", args{client: cloudflareAzureClient, suffix: chatCompletionsSuffix, args: []any{GPT4oMini}},
"https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/azure-openai" +
"/{resource_name}/{deployment_name}/chat/completions?api-version=2023-05-15"},

// /audio/speech
{"", args{client: client, suffix: "/audio/speech", args: nil},
"https://api.openai.com/v1/audio/speech"},
{"", args{client: azureClient, suffix: "/audio/speech", args: []any{string(TTSModel1HD)}},
"https://xxx.openai.azure.com/openai/deployments/tts-1-hd/audio/speech?api-version=2023-05-15"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ const (

type APIType string

func (r APIType) IsAzure() bool {
switch r {

Check failure on line 19 in config.go

View workflow job for this annotation

GitHub Actions / Sanity check

missing cases in switch of type openai.APIType: openai.APITypeOpenAI (exhaustive)
case APITypeAzure, APITypeAzureAD, APITypeCloudflareAzure:
return true
}
return false
}

const (
APITypeOpenAI APIType = "OPEN_AI"
APITypeAzure APIType = "AZURE"
Expand Down

0 comments on commit f4cb38b

Please sign in to comment.