From 0925563e86c2fdc5011310aa616ba493989cfe0a Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Fri, 15 Mar 2024 18:59:16 +0800 Subject: [PATCH] Fix broken implementation AssistantModify implementation (#685) * add custom marshaller, documentation and isolate tests * fix linter --- assistant.go | 30 ++++++++++++- assistant_test.go | 109 ++++++++++++++++++++++++++++++++++------------ 2 files changed, 109 insertions(+), 30 deletions(-) diff --git a/assistant.go b/assistant.go index 7a7a7652e..4ca2dda62 100644 --- a/assistant.go +++ b/assistant.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -21,7 +22,7 @@ type Assistant struct { Description *string `json:"description,omitempty"` Model string `json:"model"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools,omitempty"` + Tools []AssistantTool `json:"tools"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` @@ -41,16 +42,41 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +// AssistantRequest provides the assistant request parameters. +// When modifying the tools the API functions as the following: +// If Tools is undefined, no changes are made to the Assistant's tools. +// If Tools is empty slice it will effectively delete all of the Assistant's tools. +// If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { Model string `json:"model"` Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` + Tools []AssistantTool `json:"-"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } +// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases +// If Tools is nil, the field is omitted from the JSON. +// If Tools is an empty slice, it's included in the JSON as an empty array ([]). +// If Tools is populated, it's included in the JSON with the elements. +func (a AssistantRequest) MarshalJSON() ([]byte, error) { + type Alias AssistantRequest + assistantAlias := &struct { + Tools *[]AssistantTool `json:"tools,omitempty"` + *Alias + }{ + Alias: (*Alias)(&a), + } + + if a.Tools != nil { + assistantAlias.Tools = &a.Tools + } + + return json.Marshal(assistantAlias) +} + // AssistantsList is a list of assistants. type AssistantsList struct { Assistants []Assistant `json:"data"` diff --git a/assistant_test.go b/assistant_test.go index 48bc6f91d..40de0e50f 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -96,7 +96,7 @@ When asked a question, write and run Python code to answer the question.` }) fmt.Fprintln(w, string(resBytes)) case http.MethodPost: - var request openai.AssistantRequest + var request openai.Assistant err := json.NewDecoder(r.Body).Decode(&request) checks.NoError(t, err, "Decode error") @@ -163,44 +163,97 @@ When asked a question, write and run Python code to answer the question.` ctx := context.Background() - _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ - Name: &assistantName, - Description: &assistantDescription, - Model: openai.GPT4TurboPreview, - Instructions: &assistantInstructions, + t.Run("create_assistant", func(t *testing.T) { + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") }) - checks.NoError(t, err, "CreateAssistant error") - _, err = client.RetrieveAssistant(ctx, assistantID) - checks.NoError(t, err, "RetrieveAssistant error") + t.Run("retrieve_assistant", func(t *testing.T) { + _, err := client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + }) - _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ - Name: &assistantName, - Description: &assistantDescription, - Model: openai.GPT4TurboPreview, - Instructions: &assistantInstructions, + t.Run("delete_assistant", func(t *testing.T) { + _, err := client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") }) - checks.NoError(t, err, "ModifyAssistant error") - _, err = client.DeleteAssistant(ctx, assistantID) - checks.NoError(t, err, "DeleteAssistant error") + t.Run("list_assistant", func(t *testing.T) { + _, err := client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + }) - _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) - checks.NoError(t, err, "ListAssistants error") + t.Run("create_assistant_file", func(t *testing.T) { + _, err := client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + }) - _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ - FileID: assistantFileID, + t.Run("list_assistant_files", func(t *testing.T) { + _, err := client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") }) - checks.NoError(t, err, "CreateAssistantFile error") - _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) - checks.NoError(t, err, "ListAssistantFiles error") + t.Run("retrieve_assistant_file", func(t *testing.T) { + _, err := client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + }) - _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) - checks.NoError(t, err, "RetrieveAssistantFile error") + t.Run("delete_assistant_file", func(t *testing.T) { + err := client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") + }) - err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) - checks.NoError(t, err, "DeleteAssistantFile error") + t.Run("modify_assistant_no_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools != nil { + t.Errorf("expected nil got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_with_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: []openai.AssistantTool{{Type: openai.AssistantToolTypeFunction}}, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil || len(assistant.Tools) != 1 { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_empty_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: make([]openai.AssistantTool, 0), + }) + + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) } func TestAzureAssistant(t *testing.T) {