Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for 'enum' tag to JSON schema generation #912

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions jsonschema/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
required = false
}

enumTag := field.Tag.Get("enum")
var enumValues []string
if enumTag != "" {
enumValues = strings.Split(enumTag, ",")
}

item, err := reflectSchema(field.Type)
if err != nil {
return nil, err
Expand All @@ -139,6 +145,11 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
if description != "" {
item.Description = description
}

if len(enumValues) > 0 {
item.Enum = enumValues
}

properties[jsonTag] = *item

if s := field.Tag.Get("required"); s != "" {
Expand Down
268 changes: 247 additions & 21 deletions jsonschema/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,232 @@ import (
"github.com/sashabaranov/go-openai/jsonschema"
)

func TestDefinition_GenerateSchemaForType(t *testing.T) {
type MenuItem struct {
ItemName string `json:"item_name" description:"Menu item name" required:"true"`
Quantity int `json:"quantity" description:"Quantity of menu item ordered" required:"true"`
Price int `json:"price" description:"Price of the menu item" required:"true"`
}

type UserOrder struct {
MenuItems []MenuItem `json:"menu_items" description:"List of menu items ordered by the user" required:"true"`
DeliveryAddress string `json:"delivery_address" description:"Delivery address for the order" required:"true"`
UserName string `json:"user_name" description:"Name of the user placing the order" required:"true"`
PhoneNumber string `json:"phone_number" description:"Phone number of the user" required:"true"`
PaymentMethod string `json:"payment_method" description:"Payment method" required:"true" enum:"cash,transfer"`
}

tests := []struct {
name string
input any
want string
wantErr bool
}{
{
name: "Test MenuItem Schema",
input: MenuItem{},
want: `{
"type":"object",
"additionalProperties":false,
"properties":{
"item_name":{
"type":"string",
"description":"Menu item name"
},
"quantity":{
"type":"integer",
"description":"Quantity of menu item ordered"
},
"price":{
"type":"integer",
"description":"Price of the menu item"
}
},
"required":[
"item_name",
"quantity",
"price"
]
}`,
},
{
name: "Test UserOrder Schema",
input: UserOrder{},
want: `{
"type":"object",
"additionalProperties":false,
"properties":{
"menu_items":{
"type":"array",
"description":"List of menu items ordered by the user",
"items":{
"type":"object",
"additionalProperties":false,
"properties":{
"item_name":{
"type":"string",
"description":"Menu item name"
},
"quantity":{
"type":"integer",
"description":"Quantity of menu item ordered"
},
"price":{
"type":"integer",
"description":"Price of the menu item"
}
},
"required":[
"item_name",
"quantity",
"price"
]
}
},
"delivery_address":{
"type":"string",
"description":"Delivery address for the order"
},
"user_name":{
"type":"string",
"description":"Name of the user placing the order"
},
"phone_number":{
"type":"string",
"description":"Phone number of the user"
},
"payment_method":{
"type":"string",
"description":"Payment method",
"enum":[
"cash",
"transfer"
]
}
},
"required":[
"menu_items",
"delivery_address",
"user_name",
"phone_number",
"payment_method"
]
}`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate schema
got, err := jsonschema.GenerateSchemaForType(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateSchemaForType() error = %v, wantErr %v", err, tt.wantErr)
return
}

// Convert both the generated schema and the expected JSON to maps for comparison
wantBytes := []byte(tt.want)
var want map[string]interface{}
err = json.Unmarshal(wantBytes, &want)
if err != nil {
t.Errorf("Failed to Unmarshal expected JSON: error = %v", err)
return
}

gotMap := structToMap(t, got)

// Compare the maps
if !reflect.DeepEqual(gotMap, want) {
t.Errorf("GenerateSchemaForType() got = %v, want %v", gotMap, want)
}
})
}
}

func TestDefinition_SchemaGenerationComparison(t *testing.T) {
type MenuItem struct {
ItemName string `json:"item_name" description:"Menu item name" required:"true"`
Quantity int `json:"quantity" description:"Quantity of menu item ordered" required:"true"`
Price float64 `json:"price" description:"Price of the menu item" required:"true"`
}

type UserOrder struct {
MenuItems []MenuItem `json:"menu_items" description:"List of menu items ordered by the user" required:"true"`
DeliveryAddress string `json:"delivery_address" description:"Delivery address for the order" required:"true"`
UserName string `json:"user_name" description:"Name of the user placing the order" required:"true"`
PhoneNumber string `json:"phone_number" description:"Phone number of the user" required:"true"`
PaymentMethod string `json:"payment_method" description:"Payment method" required:"true" enum:"cash,transfer"`
}

// Manually created schema to compare against struct-generated schema
manualSchema := &jsonschema.Definition{
Type: jsonschema.Object,
AdditionalProperties: false,
Properties: map[string]jsonschema.Definition{
"menu_items": {
Type: jsonschema.Array,
Description: "List of menu items ordered by the user",
Items: &jsonschema.Definition{
Type: jsonschema.Object,
AdditionalProperties: false,
Properties: map[string]jsonschema.Definition{
"item_name": {
Type: jsonschema.String,
Description: "Menu item name",
},
"quantity": {
Type: jsonschema.Integer,
Description: "Quantity of menu item ordered",
},
"price": {
Type: jsonschema.Number,
Description: "Price of the menu item",
},
},
Required: []string{"item_name", "quantity", "price"},
},
},
"delivery_address": {
Type: jsonschema.String,
Description: "Delivery address for the order",
},
"user_name": {
Type: jsonschema.String,
Description: "Name of the user placing the order",
},
"phone_number": {
Type: jsonschema.String,
Description: "Phone number of the user",
},
"payment_method": {
Type: jsonschema.String,
Description: "Payment method",
Enum: []string{"cash", "transfer"},
},
},
Required: []string{"menu_items", "delivery_address", "user_name", "phone_number", "payment_method"},
}

t.Run("Compare Struct-Generated and Manual Schema", func(t *testing.T) {
// Generate schema from struct
structSchema, err := jsonschema.GenerateSchemaForType(UserOrder{})
if err != nil {
t.Fatalf("Failed to generate schema from struct: %v", err)
}

// Convert both schemas to maps for comparison
structMap := structToMap(t, structSchema)
manualMap := structToMap(t, manualSchema)

// Compare the maps
if !reflect.DeepEqual(structMap, manualMap) {
t.Errorf("Schema generated from struct and manual schema do not match")
t.Errorf("Struct generated schema: %v", structMap)
t.Errorf("Manual schema: %v", manualMap)
}
})
}

func TestDefinition_MarshalJSON(t *testing.T) {
tests := []struct {
name string
Expand All @@ -17,7 +243,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
{
name: "Test with empty Definition",
def: jsonschema.Definition{},
want: `{"properties":{}}`,
want: `{}`,
},
{
name: "Test with Definition properties set",
Expand All @@ -35,8 +261,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"description":"A string type",
"properties":{
"name":{
"type":"string",
"properties":{}
"type":"string"
}
}
}`,
Expand Down Expand Up @@ -66,12 +291,10 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object",
"properties":{
"name":{
"type":"string",
"properties":{}
"type":"string"
},
"age":{
"type":"integer",
"properties":{}
"type":"integer"
}
}
}
Expand Down Expand Up @@ -114,23 +337,19 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object",
"properties":{
"name":{
"type":"string",
"properties":{}
"type":"string"
},
"age":{
"type":"integer",
"properties":{}
"type":"integer"
},
"address":{
"type":"object",
"properties":{
"city":{
"type":"string",
"properties":{}
"type":"string"
},
"country":{
"type":"string",
"properties":{}
"type":"string"
}
}
}
Expand All @@ -155,19 +374,26 @@ func TestDefinition_MarshalJSON(t *testing.T) {
want: `{
"type":"array",
"items":{
"type":"string",
"properties":{

}
"type":"string"
},
"properties":{
"name":{
"type":"string",
"properties":{}
"type":"string"
}
}
}`,
},
{
name: "Test with Enum type Definition",
def: jsonschema.Definition{
Type: jsonschema.String,
Enum: []string{"celsius", "fahrenheit"},
},
want: `{
"type":"string",
"enum":["celsius","fahrenheit"]
}`,
},
}

for _, tt := range tests {
Expand Down
Loading