diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 86b75601bc45..12a14eace4fb 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -172,6 +172,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup funcs := input.Functions shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions() + strictMode := false + + for _, f := range input.Functions { + if f.Strict { + strictMode = true + break + } + } // Allow the user to set custom actions via config file // to be "embedded" in each model @@ -187,10 +195,33 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup if config.ResponseFormatMap != nil { d := schema.ChatCompletionResponseFormat{} - dat, _ := json.Marshal(config.ResponseFormatMap) - _ = json.Unmarshal(dat, &d) + dat, err := json.Marshal(config.ResponseFormatMap) + if err != nil { + return err + } + err = json.Unmarshal(dat, &d) + if err != nil { + return err + } if d.Type == "json_object" { input.Grammar = functions.JSONBNF + } else if d.Type == "json_schema" { + d := schema.JsonSchemaRequest{} + dat, err := json.Marshal(config.ResponseFormatMap) + if err != nil { + return err + } + err = json.Unmarshal(dat, &d) + if err != nil { + return err + } + fs := &functions.JSONFunctionStructure{ + AnyOf: []functions.Item{d.JsonSchema.Schema}, + } + g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...) + if err == nil { + input.Grammar = g + } } } @@ -201,7 +232,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup } switch { - case !config.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn: + case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn: noActionGrammar := functions.Function{ Name: noActionName, Description: noActionDescription, diff --git a/core/schema/openai.go b/core/schema/openai.go index 3b39eaf3c7c6..fe4745bfcbd6 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -139,6 +139,17 @@ type ChatCompletionResponseFormat struct { Type ChatCompletionResponseFormatType `json:"type,omitempty"` } +type JsonSchemaRequest struct { + Type string `json:"type"` + JsonSchema JsonSchema `json:"json_schema"` +} + +type JsonSchema struct { + Name string `json:"name"` + Strict bool `json:"strict"` + Schema functions.Item `json:"schema"` +} + type OpenAIRequest struct { PredictionOptions diff --git a/pkg/functions/functions.go b/pkg/functions/functions.go index 19012d53dd99..1a7e1ff1711e 100644 --- a/pkg/functions/functions.go +++ b/pkg/functions/functions.go @@ -14,6 +14,7 @@ const ( type Function struct { Name string `json:"name"` Description string `json:"description"` + Strict bool `json:"strict"` Parameters map[string]interface{} `json:"parameters"` } type Functions []Function