Skip to content

Commit

Permalink
feat: support cloudflare AI Gateway flavored azure openai (#715)
Browse files Browse the repository at this point in the history
* feat: support cloudflare AI Gateway flavored azure openai

Signed-off-by: STRRL <im@strrl.dev>

* test: add test for cloudflare azure fullURL

---------

Signed-off-by: STRRL <im@strrl.dev>
Co-authored-by: STRRL <im@strrl.dev>
  • Loading branch information
woorui and STRRL authored Apr 24, 2024
1 parent 2d58f8f commit c84ab5f
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
36 changes: 36 additions & 0 deletions api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,39 @@ func TestAzureFullURL(t *testing.T) {
})
}
}

func TestCloudflareAzureFullURL(t *testing.T) {
cases := []struct {
Name string
BaseURL string
Expect string
}{
{
"CloudflareAzureBaseURLWithSlashAutoStrip",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
},
{
"CloudflareAzureBaseURLWithoutSlashOK",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL)
az.APIType = APITypeCloudflareAzure

cli := NewClientWithConfig(az)

actual := cli.fullURL("/chat/completions")
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
t.Logf("Full URL: %s", actual)
})
}
}
10 changes: 8 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType == APITypeAzure {
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else if c.config.authToken != "" {
// OpenAI or Azure AD authentication
Expand Down Expand Up @@ -246,7 +246,13 @@ func (c *Client) fullURL(suffix string, args ...any) string {
)
}

// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
if c.config.APIType == APITypeCloudflareAzure {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
}

return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}

Expand Down
7 changes: 4 additions & 3 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ const (
type APIType string

const (
APITypeOpenAI APIType = "OPEN_AI"
APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD"
APITypeOpenAI APIType = "OPEN_AI"
APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD"
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
)

const AzureAPIKeyHeader = "api-key"
Expand Down

0 comments on commit c84ab5f

Please sign in to comment.