diff --git a/client.go b/client.go index e7a4d5beb..7b1a313a8 100644 --- a/client.go +++ b/client.go @@ -221,7 +221,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { baseURL = strings.TrimRight(baseURL, "/") // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if strings.Contains(suffix, "/models") || strings.Contains(suffix, "/assistants") { + if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } azureDeploymentName := "UNKNOWN" @@ -258,3 +258,12 @@ func (c *Client) handleErrorResp(resp *http.Response) error { errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } + +func containsSubstr(s []string, e string) bool { + for _, v := range s { + if strings.Contains(e, v) { + return true + } + } + return false +} diff --git a/thread_test.go b/thread_test.go index 227ab6330..1ac0f3c0e 100644 --- a/thread_test.go +++ b/thread_test.go @@ -93,3 +93,86 @@ func TestThread(t *testing.T) { _, err = client.DeleteThread(ctx, threadID) checks.NoError(t, err, "DeleteThread error") } + +// TestAzureThread Tests the thread endpoint of the API using the Azure mocked server. +func TestAzureThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +}