Skip to content

Commit

Permalink
Don't serialise empty content on ChatCompletionMessage with tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
greysteil committed Sep 1, 2024
1 parent c37cf9a commit 3058f9e
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
111 changes: 111 additions & 0 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,114 @@ func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) {
}
}
}

func TestChatCompletionJsonSchemaWithFunctionCallingResponse(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := openai.NewClient(apiToken)
ctx := context.Background()

type MyStructuredResponse struct {
PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"`
SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
}
var result MyStructuredResponse
schema, err := jsonschema.GenerateSchemaForType(result)
if err != nil {
t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error")
}

resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "Please enter a string, and we will convert it into the following naming conventions:" +
"1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." +
"2. CamelCase: The first word starts with a lowercase letter, " +
"and subsequent words start with an uppercase letter, with no spaces or separators." +
"3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." +
"4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello World",
},
{
Role: openai.ChatMessageRoleAssistant,
ToolCalls: []openai.ToolCall{
{
ID: "call_cTSjmyVCPkRh870yFvkUrql5",
Type: openai.ToolTypeFunction,
Function: openai.FunctionCall{
Name: "display_cases",
Arguments: `{"PascalCase":"HelloWorld","CamelCase":"helloWorld","KebabCase":"hello-world","SnakeCase":"hello_world"}`,
},
},
},
},
{
Role: openai.ChatMessageRoleTool,
ToolCallID: "call_cTSjmyVCPkRh870yFvkUrql5",
Content: "Here are the naming conventions for 'Hello World':\n" +
"PascalCase: HelloWorld\n" +
"CamelCase: helloWorld\n" +
"KebabCase: hello-world\n" +
"SnakeCase: hello_world\n",
},
},
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
Schema: schema,
Strict: true,
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
Function: &openai.FunctionDefinition{
Name: "display_cases",
Strict: true,
Parameters: &jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": {
Type: jsonschema.String,
},
"CamelCase": {
Type: jsonschema.String,
},
"KebabCase": {
Type: jsonschema.String,
},
"SnakeCase": {
Type: jsonschema.String,
},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
},
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion returned error")
if err == nil {
err = schema.Unmarshal(resp.Choices[0].Message.Content, &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
}
if result.PascalCase != "HelloWorld" {
t.Errorf("PascalCase: expected 'HelloWorld', got '%s'", result.PascalCase)
}
}
11 changes: 11 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
ToolCallID string `json:"tool_call_id,omitempty"`
}(m)
return json.Marshal(msg)
} else if len(m.ToolCalls) > 0 && m.Content == "" {
msg := struct {
Role string `json:"role"`
Content string `json:"-"`
MultiContent []ChatMessagePart `json:"-"`
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}(m)
return json.Marshal(msg)

Check warning on line 127 in chat.go

View check run for this annotation

Codecov / codecov/patch

chat.go#L118-L127

Added lines #L118 - L127 were not covered by tests
}
msg := struct {
Role string `json:"role"`
Expand Down

0 comments on commit 3058f9e

Please sign in to comment.