Skip to content

Commit

Permalink
Improve handling of JSON Schema in OpenAI API Response Context (#819)
Browse files Browse the repository at this point in the history
* feat: add jsonschema.Validate and jsonschema.Unmarshal

* fix Sanity check

* remove slices.Contains

* fix Sanity check

* add SchemaWrapper

* update api_integration_test.go

* update method 'reflectSchema' to support 'omitempty' in JSON tag

* add GenerateSchemaForType

* update json_test.go

* update `Warp` to `Wrap`

* fix Sanity check

* fix Sanity check

* update api_internal_test.go

* update README.md

* update README.md

* remove jsonschema.SchemaWrapper

* remove jsonschema.SchemaWrapper

* fix Sanity check

* optimize code formatting
  • Loading branch information
eiixy authored Aug 24, 2024
1 parent 5162adb commit a3bd256
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 30 deletions.
64 changes: 64 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,70 @@ func main() {
}
```
</details>

<details>
<summary>Structured Outputs</summary>

```go
package main

import (
"context"
"fmt"
"log"

"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)

func main() {
client := openai.NewClient("your token")
ctx := context.Background()

type Result struct {
Steps []struct {
Explanation string `json:"explanation"`
Output string `json:"output"`
} `json:"steps"`
FinalAnswer string `json:"final_answer"`
}
var result Result
schema, err := jsonschema.GenerateSchemaForType(result)
if err != nil {
log.Fatalf("GenerateSchemaForType error: %v", err)
}
resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "You are a helpful math tutor. Guide the user through the solution step by step.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "how can I solve 8x + 7 = -23",
},
},
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "math_reasoning",
Schema: schema,
Strict: true,
},
},
})
if err != nil {
log.Fatalf("CreateChatCompletion error: %v", err)
}
err = schema.Unmarshal(resp.Choices[0].Message.Content, &result)
if err != nil {
log.Fatalf("Unmarshal schema error: %v", err)
}
fmt.Println(result)
}
```
</details>
See the `examples/` folder for more.

## Frequently Asked Questions
Expand Down
36 changes: 16 additions & 20 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package openai_test

import (
"context"
"encoding/json"
"errors"
"io"
"os"
Expand Down Expand Up @@ -190,6 +189,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
c := openai.NewClient(apiToken)
ctx := context.Background()

type MyStructuredResponse struct {
PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"`
SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
}
var result MyStructuredResponse
schema, err := jsonschema.GenerateSchemaForType(result)
if err != nil {
t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error")
}
resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Expand All @@ -212,31 +222,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
Schema: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": jsonschema.Definition{Type: jsonschema.String},
"CamelCase": jsonschema.Definition{Type: jsonschema.String},
"KebabCase": jsonschema.Definition{Type: jsonschema.String},
"SnakeCase": jsonschema.Definition{Type: jsonschema.String},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
Name: "cases",
Schema: schema,
Strict: true,
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error")
var result = make(map[string]string)
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
if _, ok := result[key]; !ok {
t.Errorf("key:%s does not exist.", key)
}
if err == nil {
err = schema.Unmarshal(resp.Choices[0].Message.Content, &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
}
}

Expand Down
10 changes: 4 additions & 6 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"encoding/json"
"errors"
"net/http"

"github.com/sashabaranov/go-openai/jsonschema"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct {
}

type ChatCompletionResponseFormatJSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema jsonschema.Definition `json:"schema"`
Strict bool `json:"strict"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema json.Marshaler `json:"schema"`
Strict bool `json:"strict"`
}

// ChatCompletionRequest represents a request structure for chat completion API.
Expand Down
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() {
}
defer stream.Close()

fmt.Printf("Stream response: ")
fmt.Print("Stream response: ")
for {
var response openai.ChatCompletionStreamResponse
response, err = stream.Recv()
Expand Down
105 changes: 102 additions & 3 deletions jsonschema/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
// and/or pass in the schema in []byte format.
package jsonschema

import "encoding/json"
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
)

type DataType string

Expand Down Expand Up @@ -42,14 +48,107 @@ type Definition struct {
AdditionalProperties any `json:"additionalProperties,omitempty"`
}

func (d Definition) MarshalJSON() ([]byte, error) {
func (d *Definition) MarshalJSON() ([]byte, error) {
if d.Properties == nil {
d.Properties = make(map[string]Definition)
}
type Alias Definition
return json.Marshal(struct {
Alias
}{
Alias: (Alias)(d),
Alias: (Alias)(*d),
})
}

func (d *Definition) Unmarshal(content string, v any) error {
return VerifySchemaAndUnmarshal(*d, []byte(content), v)
}

func GenerateSchemaForType(v any) (*Definition, error) {
return reflectSchema(reflect.TypeOf(v))
}

func reflectSchema(t reflect.Type) (*Definition, error) {
var d Definition
switch t.Kind() {
case reflect.String:
d.Type = String
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
d.Type = Integer
case reflect.Float32, reflect.Float64:
d.Type = Number
case reflect.Bool:
d.Type = Boolean
case reflect.Slice, reflect.Array:
d.Type = Array
items, err := reflectSchema(t.Elem())
if err != nil {
return nil, err
}
d.Items = items
case reflect.Struct:
d.Type = Object
d.AdditionalProperties = false
object, err := reflectSchemaObject(t)
if err != nil {
return nil, err
}
d = *object
case reflect.Ptr:
definition, err := reflectSchema(t.Elem())
if err != nil {
return nil, err
}
d = *definition
case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128,
reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
reflect.UnsafePointer:
return nil, fmt.Errorf("unsupported type: %s", t.Kind().String())
default:
}
return &d, nil
}

func reflectSchemaObject(t reflect.Type) (*Definition, error) {
var d = Definition{
Type: Object,
AdditionalProperties: false,
}
properties := make(map[string]Definition)
var requiredFields []string
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if !field.IsExported() {
continue
}
jsonTag := field.Tag.Get("json")
var required = true
if jsonTag == "" {
jsonTag = field.Name
} else if strings.HasSuffix(jsonTag, ",omitempty") {
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
required = false
}

item, err := reflectSchema(field.Type)
if err != nil {
return nil, err
}
description := field.Tag.Get("description")
if description != "" {
item.Description = description
}
properties[jsonTag] = *item

if s := field.Tag.Get("required"); s != "" {
required, _ = strconv.ParseBool(s)
}
if required {
requiredFields = append(requiredFields, jsonTag)
}
}
d.Required = requiredFields
d.Properties = properties
return &d, nil
}
89 changes: 89 additions & 0 deletions jsonschema/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package jsonschema

import (
"encoding/json"
"errors"
)

func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error {
var data any
err := json.Unmarshal(content, &data)
if err != nil {
return err
}
if !Validate(schema, data) {
return errors.New("data validation failed against the provided schema")
}
return json.Unmarshal(content, &v)
}

func Validate(schema Definition, data any) bool {
switch schema.Type {
case Object:
return validateObject(schema, data)
case Array:
return validateArray(schema, data)
case String:
_, ok := data.(string)
return ok
case Number: // float64 and int
_, ok := data.(float64)
if !ok {
_, ok = data.(int)
}
return ok
case Boolean:
_, ok := data.(bool)
return ok
case Integer:
_, ok := data.(int)
return ok
case Null:
return data == nil
default:
return false
}
}

func validateObject(schema Definition, data any) bool {
dataMap, ok := data.(map[string]any)
if !ok {
return false
}
for _, field := range schema.Required {
if _, exists := dataMap[field]; !exists {
return false
}
}
for key, valueSchema := range schema.Properties {
value, exists := dataMap[key]
if exists && !Validate(valueSchema, value) {
return false
} else if !exists && contains(schema.Required, key) {
return false
}
}
return true
}

func validateArray(schema Definition, data any) bool {
dataArray, ok := data.([]any)
if !ok {
return false
}
for _, item := range dataArray {
if !Validate(*schema.Items, item) {
return false
}
}
return true
}

func contains[S ~[]E, E comparable](s S, v E) bool {
for i := range s {
if v == s[i] {
return true
}
}
return false
}
Loading

0 comments on commit a3bd256

Please sign in to comment.