Skip to content

Commit

Permalink
add SchemaWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
eiixy committed Aug 8, 2024
1 parent 051a17f commit 8b49a36
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 28 deletions.
35 changes: 14 additions & 21 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package openai_test

import (
"context"
"encoding/json"
"errors"
"io"
"os"
Expand Down Expand Up @@ -190,6 +189,14 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
c := openai.NewClient(apiToken)
ctx := context.Background()

type MyStructuredResponse struct {
PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"`
SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
Keywords []string `json:"keywords" description:"Keywords" required:"true"`
}
schema := jsonschema.Warp(MyStructuredResponse{})
resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Expand All @@ -211,31 +218,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
},
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
Schema: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": jsonschema.Definition{Type: jsonschema.String},
"CamelCase": jsonschema.Definition{Type: jsonschema.String},
"KebabCase": jsonschema.Definition{Type: jsonschema.String},
"SnakeCase": jsonschema.Definition{Type: jsonschema.String},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
Schema: schema,
Strict: true,
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error")
var result = make(map[string]string)
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
if _, ok := result[key]; !ok {
t.Errorf("key:%s does not exist.", key)
}
if err == nil {
_, err = schema.Unmarshal(resp.Choices[0].Message.Content)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
}
}
10 changes: 4 additions & 6 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"encoding/json"
"errors"
"net/http"

"github.com/sashabaranov/go-openai/jsonschema"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct {
}

type ChatCompletionResponseFormatJSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema jsonschema.Definition `json:"schema"`
Strict bool `json:"strict"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema json.Marshaler `json:"schema"`
Strict bool `json:"strict"`
}

// ChatCompletionRequest represents a request structure for chat completion API.
Expand Down
87 changes: 86 additions & 1 deletion jsonschema/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
// and/or pass in the schema in []byte format.
package jsonschema

import "encoding/json"
import (
"encoding/json"
"reflect"
"strconv"
)

type DataType string

Expand Down Expand Up @@ -53,3 +57,84 @@ func (d Definition) MarshalJSON() ([]byte, error) {
Alias: (Alias)(d),
})
}

type SchemaWrapper[T any] struct {
data T
schema Definition
}

func (r SchemaWrapper[T]) Schema() Definition {
return r.schema
}

func (r SchemaWrapper[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(r.schema)
}

func (r SchemaWrapper[T]) Unmarshal(content string) (*T, error) {
var v T
err := Unmarshal(r.schema, []byte(content), &v)
if err != nil {
return nil, err
}
return &v, nil
}

func (r SchemaWrapper[T]) String() string {
bytes, _ := json.MarshalIndent(r.schema, "", " ")
return string(bytes)
}

func Warp[T any](v T) SchemaWrapper[T] {
return SchemaWrapper[T]{
data: v,
schema: reflectSchema(reflect.TypeOf(v)),
}
}

func reflectSchema(t reflect.Type) Definition {
var d Definition
switch t.Kind() {

Check failure on line 97 in jsonschema/json.go

View workflow job for this annotation

GitHub Actions / Sanity check

missing cases in switch of type reflect.Kind: reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer|reflect.Ptr, reflect.UnsafePointer (exhaustive)
case reflect.String:
d.Type = String
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
d.Type = Integer
case reflect.Float32, reflect.Float64:
d.Type = Number
case reflect.Bool:
d.Type = Boolean
case reflect.Slice, reflect.Array:
d.Type = Array
items := reflectSchema(t.Elem())
d.Items = &items
case reflect.Struct:
d.Type = Object
d.AdditionalProperties = false
properties := make(map[string]Definition)
var requiredFields []string
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
jsonTag := field.Tag.Get("json")
if jsonTag == "" {
jsonTag = field.Name
}

item := reflectSchema(field.Type)
description := field.Tag.Get("description")
if description != "" {
item.Description = description
}
properties[jsonTag] = item

required, _ := strconv.ParseBool(field.Tag.Get("required"))
if required {
requiredFields = append(requiredFields, jsonTag)
}
}
d.Required = requiredFields
d.Properties = properties
default:
}
return d
}
67 changes: 67 additions & 0 deletions jsonschema/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,70 @@ func structToMap(t *testing.T, v any) map[string]any {
}
return got
}

type MyStructuredResponse struct {
PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
KebabCase string `json:"kebab_case" required:"false" description:"KebabCase"`
SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
Keywords []string `json:"keywords" description:"Keywords" required:"true"`
}

func TestWarp(t *testing.T) {
schemaStr := `{
"type": "object",
"properties": {
"camel_case": {
"type": "string",
"description": "CamelCase"
},
"kebab_case": {
"type": "string",
"description": "KebabCase"
},
"keywords": {
"type": "array",
"description": "Keywords",
"items": {
"type": "string"
}
},
"pascal_case": {
"type": "string",
"description": "PascalCase"
},
"snake_case": {
"type": "string",
"description": "SnakeCase"
}
},
"required": [
"pascal_case",
"camel_case",
"snake_case",
"keywords"
]
}`
schema := jsonschema.Warp(MyStructuredResponse{})
if schema.String() == schemaStr {
t.Errorf("Failed to Generate JSONSchema: schema = %s", schema)
}
}

func TestSchemaWrapper_Unmarshal(t *testing.T) {
schema := jsonschema.Warp(MyStructuredResponse{})
result, err := schema.Unmarshal(`{"pascal_case":"a","camel_case":"b","snake_case":"c","keywords":[]}`)
if err != nil {
t.Errorf("Failed to SchemaWrapper Unmarshal: error = %v", err)
} else {
var v = MyStructuredResponse{
PascalCase: "a",
CamelCase: "b",
SnakeCase: "c",
Keywords: []string{},
}
if !reflect.DeepEqual(*result, v) {
t.Errorf("Failed to SchemaWrapper Unmarshal: result = %v", *result)
}
}
}

0 comments on commit 8b49a36

Please sign in to comment.