From 3058f9ef7bb151d7326bb74b484f655f0aea9a72 Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Sun, 1 Sep 2024 19:44:30 -0400 Subject: [PATCH 1/2] Don't serialise empty content on ChatCompletionMessage with tool calls --- api_integration_test.go | 111 ++++++++++++++++++++++++++++++++++++++++ chat.go | 11 ++++ 2 files changed, 122 insertions(+) diff --git a/api_integration_test.go b/api_integration_test.go index 7828d9451..7e4b6c47c 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -312,3 +312,114 @@ func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) { } } } + +func TestChatCompletionJsonSchemaWithFunctionCallingResponse(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + var result MyStructuredResponse + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") + } + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + { + Role: openai.ChatMessageRoleAssistant, + ToolCalls: []openai.ToolCall{ + { + ID: "call_cTSjmyVCPkRh870yFvkUrql5", + Type: openai.ToolTypeFunction, + Function: openai.FunctionCall{ + Name: "display_cases", + Arguments: `{"PascalCase":"HelloWorld","CamelCase":"helloWorld","KebabCase":"hello-world","SnakeCase":"hello_world"}`, + }, + }, + }, + }, + { + Role: openai.ChatMessageRoleTool, + ToolCallID: "call_cTSjmyVCPkRh870yFvkUrql5", + Content: "Here are the naming conventions for 'Hello World':\n" + + "PascalCase: HelloWorld\n" + + "CamelCase: helloWorld\n" + + "KebabCase: hello-world\n" + + "SnakeCase: hello_world\n", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "cases", + Schema: schema, + Strict: true, + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: "display_cases", + Strict: true, + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": { + Type: jsonschema.String, + }, + "CamelCase": { + Type: jsonschema.String, + }, + "KebabCase": { + Type: jsonschema.String, + }, + "SnakeCase": { + Type: jsonschema.String, + }, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + }, + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion returned error") + if err == nil { + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") + } + if result.PascalCase != "HelloWorld" { + t.Errorf("PascalCase: expected 'HelloWorld', got '%s'", result.PascalCase) + } +} diff --git a/chat.go b/chat.go index 56e99a78b..174b44eac 100644 --- a/chat.go +++ b/chat.go @@ -114,6 +114,17 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) + } else if len(m.ToolCalls) > 0 && m.Content == "" { + msg := struct { + Role string `json:"role"` + Content string `json:"-"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) } msg := struct { Role string `json:"role"` From 43814294eab1447cf21ed0b07bd5a66a59881415 Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Sun, 1 Sep 2024 20:13:11 -0400 Subject: [PATCH 2/2] Add unit test --- chat_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/chat_test.go b/chat_test.go index 37dc09d4d..8a962ce73 100644 --- a/chat_test.go +++ b/chat_test.go @@ -425,6 +425,29 @@ func TestMultipartChatMessageSerialization(t *testing.T) { } } +func TestToolCallChatMessageSerialization(t *testing.T) { + jsonText := `{"role":"assistant","tool_calls":` + + `[{"id":"123","type":"function","function":{"name":"my_func","arguments":"{}"}}]}` + + toolCallMsg := openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + ToolCalls: []openai.ToolCall{{ + ID: "123", + Type: openai.ToolTypeFunction, + Function: openai.FunctionCall{Name: "my_func", Arguments: "{}"}, + }}, + } + + s, err := json.Marshal(toolCallMsg) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + res := strings.ReplaceAll(string(s), " ", "") + if res != jsonText { + t.Fatalf("invalid message: %s", string(s)) + } +} + // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error