Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fullURL endpoint generation #817

Merged
merged 18 commits into from
Aug 16, 2024
24 changes: 19 additions & 5 deletions api_internal_test.go
Original file line number Diff line number Diff line change
@@ -112,13 +112,15 @@ func TestAzureFullURL(t *testing.T) {
Name string
BaseURL string
AzureModelMapper map[string]string
Suffix string
Model string
Expect string
}{
{
"AzureBaseURLWithSlashAutoStrip",
"https://httpbin.org/",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
@@ -128,19 +130,28 @@ func TestAzureFullURL(t *testing.T) {
"AzureBaseURLWithoutSlashOK",
"https://httpbin.org",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
"/chat/completions?api-version=2023-05-15",
},
{
"",
"https://httpbin.org",
nil,
"/assistants?limit=10",
"chatgpt-demo",
"https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL)
cli := NewClientWithConfig(az)
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions", c.Model)
actual := cli.fullURL(c.Suffix, withModel(c.Model))
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
@@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) {
cases := []struct {
Name string
BaseURL string
Suffix string
Expect string
}{
{
"CloudflareAzureBaseURLWithSlashAutoStrip",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/",
"/chat/completions",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
},
{
"CloudflareAzureBaseURLWithoutSlashOK",
"",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
"/assistants?limit=10",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" +
"/assistants?api-version=2023-05-15&limit=10",
},
}

@@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) {

cli := NewClientWithConfig(az)

actual := cli.fullURL("/chat/completions")
actual := cli.fullURL(c.Suffix)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
9 changes: 7 additions & 2 deletions audio.go
Original file line number Diff line number Diff line change
@@ -122,8 +122,13 @@ func (c *Client) callAudioAPI(
}

urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
withBody(&formBody), withContentType(builder.FormDataContentType()))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(&formBody),
withContentType(builder.FormDataContentType()),
)
if err != nil {
return AudioResponse{}, err
}
7 changes: 6 additions & 1 deletion chat.go
Original file line number Diff line number Diff line change
@@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion(
return
}

req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return
}
7 changes: 6 additions & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
@@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return nil, err
}
84 changes: 54 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
@@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error {
return nil
}

type fullURLOptions struct {
model string
}

type fullURLOption func(*fullURLOptions)

func withModel(model string) fullURLOption {
return func(args *fullURLOptions) {
args.model = model
}
}

var azureDeploymentsEndpoints = []string{
"/completions",
"/embeddings",
"/chat/completions",
"/audio/transcriptions",
"/audio/translations",
"/audio/speech",
"/images/generations",
}

// fullURL returns full URL for request.
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
func (c *Client) fullURL(suffix string, args ...any) string {
// /openai/deployments/{model}/chat/completions?api-version={api_version}
func (c *Client) fullURL(suffix string, setters ...fullURLOption) string {
baseURL := strings.TrimRight(c.config.BaseURL, "/")
args := fullURLOptions{}
for _, setter := range setters {
setter(&args)
}

if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
parseURL, _ := url.Parse(baseURL)
query := parseURL.Query()
query.Add("api-version", c.config.APIVersion)
// if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) {
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode())
}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
model, ok := args[0].(string)
if ok {
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
}
}
return fmt.Sprintf("%s/%s/%s/%s%s?%s",
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
azureDeploymentName, suffix, query.Encode(),
)
baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model)
}

if c.config.APIVersion != "" {
suffix = c.suffixWithAPIVersion(suffix)
}
return fmt.Sprintf("%s%s", baseURL, suffix)
}

// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
if c.config.APIType == APITypeCloudflareAzure {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
func (c *Client) suffixWithAPIVersion(suffix string) string {
parsedSuffix, err := url.Parse(suffix)
if err != nil {
panic("failed to parse url suffix")
}
query := parsedSuffix.Query()
query.Add("api-version", c.config.APIVersion)
return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode())
}

return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) {
baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix)
if containsSubstr(azureDeploymentsEndpoints, suffix) {
azureDeploymentName := c.config.GetAzureDeploymentByModel(model)
if azureDeploymentName == "" {
azureDeploymentName = "UNKNOWN"
}
baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName)
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is pretty hard to figure about 🤔 Could I suggest breaking it down into more functions and keeping fullURLOptions closer to the code that uses it?

Something like this:

func (c *Client) suffixWithAPIVersion(previousSuffix string) (newSuffix string) {
	parsedSuffix, err := url.Parse(previousSuffix)
	if err != nil {
		panic("failed to parse url suffix")
	}

	query := parsedSuffix.Query()
	query.Add("api-version", c.config.APIVersion)
	return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode())
}

func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) {
	azureDeploymentName := c.config.GetAzureDeploymentByModel(model)
	if azureDeploymentName == "" {
		azureDeploymentName = "UNKNOWN"
	}
	baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix)
	if containsSubstr(azureDeploymentsEndpoints, suffix) {
		baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName)
	}
	return baseURL
}


// fullURL returns full URL for request.
func (c *Client) fullURL(suffix string, setters ...fullURLOption) string {
	baseURL := strings.TrimRight(c.config.BaseURL, "/")
	urlOptions := fullURLOptions{}
	for _, setter := range setters {
		setter(&args)
	}

	if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
		baseURL = baseURLWithAzureDeployment(baseURL, suffix, urlOptions.model)
	}

	if c.config.APIVersion != "" {
		suffix = c.suffixWithAPIVersion(suffix)
	}

	return fmt.Sprintf("%s%s", baseURL, suffix)
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've completed the suggested changes. Could you please review it again?

return baseURL
}

func (c *Client) handleErrorResp(resp *http.Response) error {
96 changes: 96 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -431,3 +431,99 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

func TestClient_suffixWithAPIVersion(t *testing.T) {
type fields struct {
apiVersion string
}
type args struct {
suffix string
}
tests := []struct {
name string
fields fields
args args
want string
wantPanic string
}{
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "/assistants"},
"/assistants?api-version=2023-05",
"",
},
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "/assistants?limit=5"},
"/assistants?api-version=2023-05&limit=5",
"",
},
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "123:assistants?limit=5"},
"/assistants?api-version=2023-05&limit=5",
"failed to parse url suffix",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{
config: ClientConfig{APIVersion: tt.fields.apiVersion},
}
defer func() {
if r := recover(); r != nil {
if r.(string) != tt.wantPanic {
t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic)
}
}
}()
if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want {
t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want)
}
})
}
}

func TestClient_baseURLWithAzureDeployment(t *testing.T) {
type args struct {
baseURL string
suffix string
model string
}
tests := []struct {
name string
args args
wantNewBaseURL string
}{
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini},
"https://test.openai.azure.com/openai",
},
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini},
"https://test.openai.azure.com/openai/deployments/gpt-4o-mini",
},
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""},
"https://test.openai.azure.com/openai/deployments/UNKNOWN",
},
}
client := NewClient("")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotNewBaseURL := client.baseURLWithAzureDeployment(
tt.args.baseURL,
tt.args.suffix,
tt.args.model,
); gotNewBaseURL != tt.wantNewBaseURL {
t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL)
}
})
}
}
7 changes: 6 additions & 1 deletion completion.go
Original file line number Diff line number Diff line change
@@ -212,7 +212,12 @@ func (c *Client) CreateCompletion(
return
}

req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return
}
7 changes: 6 additions & 1 deletion edits.go
Original file line number Diff line number Diff line change
@@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024.
You can use CreateChatCompletion or CreateChatCompletionStream instead.
*/
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/edits", withModel(fmt.Sprint(request.Model))),
withBody(request),
)
if err != nil {
return
}
7 changes: 6 additions & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
@@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
withBody(baseReq),
)
if err != nil {
return
}
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() {
return
}

fmt.Printf(response.Choices[0].Delta.Content)
fmt.Println(response.Choices[0].Delta.Content)
}
}

Loading
Loading