Skip to content

Commit

Permalink
feat(openai): add json_schema format type and strict mode (#3193)
Browse files Browse the repository at this point in the history
* feat(openai): add json_schema and strict mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* handle err vs _

security scanners prefer if we put these branches in, and I tend to agree.

Signed-off-by: Dave <dave@gray101.com>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Dave <dave@gray101.com>
Co-authored-by: Dave <dave@gray101.com>
  • Loading branch information
mudler and dave-gray101 authored Aug 7, 2024
1 parent 66cf38b commit e198347
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
37 changes: 34 additions & 3 deletions core/http/endpoints/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
strictMode := false

for _, f := range input.Functions {
if f.Strict {
strictMode = true
break
}
}

// Allow the user to set custom actions via config file
// to be "embedded" in each model
Expand All @@ -187,10 +195,33 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

if config.ResponseFormatMap != nil {
d := schema.ChatCompletionResponseFormat{}
dat, _ := json.Marshal(config.ResponseFormatMap)
_ = json.Unmarshal(dat, &d)
dat, err := json.Marshal(config.ResponseFormatMap)
if err != nil {
return err
}
err = json.Unmarshal(dat, &d)
if err != nil {
return err
}
if d.Type == "json_object" {
input.Grammar = functions.JSONBNF
} else if d.Type == "json_schema" {
d := schema.JsonSchemaRequest{}
dat, err := json.Marshal(config.ResponseFormatMap)
if err != nil {
return err
}
err = json.Unmarshal(dat, &d)
if err != nil {
return err
}
fs := &functions.JSONFunctionStructure{
AnyOf: []functions.Item{d.JsonSchema.Schema},
}
g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...)
if err == nil {
input.Grammar = g
}
}
}

Expand All @@ -201,7 +232,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}

switch {
case !config.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn:
case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
noActionGrammar := functions.Function{
Name: noActionName,
Description: noActionDescription,
Expand Down
11 changes: 11 additions & 0 deletions core/schema/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ type ChatCompletionResponseFormat struct {
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
}

type JsonSchemaRequest struct {
Type string `json:"type"`
JsonSchema JsonSchema `json:"json_schema"`
}

type JsonSchema struct {
Name string `json:"name"`
Strict bool `json:"strict"`
Schema functions.Item `json:"schema"`
}

type OpenAIRequest struct {
PredictionOptions

Expand Down
1 change: 1 addition & 0 deletions pkg/functions/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const (
type Function struct {
Name string `json:"name"`
Description string `json:"description"`
Strict bool `json:"strict"`
Parameters map[string]interface{} `json:"parameters"`
}
type Functions []Function
Expand Down

0 comments on commit e198347

Please sign in to comment.