Skip to content

Commit

Permalink
feat(functions): simplify parsing, read functions as list (#2340)
Browse files Browse the repository at this point in the history
Signed-off-by: mudler <mudler@localai.io>
  • Loading branch information
mudler authored May 18, 2024
1 parent 9ab8f8f commit 02f1b47
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 132 deletions.
147 changes: 52 additions & 95 deletions pkg/functions/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package functions

import (
"encoding/json"
"fmt"
"regexp"
"strings"

Expand Down Expand Up @@ -68,134 +67,92 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC

log.Debug().Msgf("LLM result(processed): %s", llmresult)

multipleResults := functionConfig.ParallelCalls
useGrammars := !functionConfig.NoGrammar

functionNameKey := "function"
if functionConfig.FunctionName {
functionNameKey = "name"
}

results := []FuncCallResults{}

returnResult := func(s string) (name, arguments string, e error) {
returnResult := func(s string) (result []FuncCallResults, e error) {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
var ss []map[string]interface{}
result = make([]FuncCallResults, 0)
s = utils.EscapeNewLines(s)
err := json.Unmarshal([]byte(s), &ss)
if err != nil {
log.Warn().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
}
log.Debug().Msgf("Function return: %s %+v", s, ss)

// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss[functionNameKey]
if !ok {
return "", "", fmt.Errorf("unable to find function name in result")
}
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
return "", "", fmt.Errorf("unable to find arguments in result")
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
return "", "", fmt.Errorf("unable to cast function name to string")
}

return funcName, string(d), nil
}

// if no grammar is used, we have to extract function and arguments from the result
if !useGrammars {
// the response is a string that we have to parse
result := make(map[string]string)

if functionConfig.ResponseRegex != "" {
// We use named regexes here to extract the function name and arguments
// obviously, this expects the LLM to be stable and return correctly formatted JSON
// TODO: optimize this and pre-compile it
var respRegex = regexp.MustCompile(functionConfig.ResponseRegex)
match := respRegex.FindStringSubmatch(llmresult)
for i, name := range respRegex.SubexpNames() {
if i != 0 && name != "" && len(match) > i {
result[name] = match[i]
}
}

// TODO: open point about multiple results and/or mixed with chat messages
// This is not handled as for now, we only expect one function call per response
functionName := result[functionNameKey]
if functionName == "" {
return results
}
} else if functionConfig.JSONRegexMatch != "" {
//re := regexp.MustCompile(`(?s)<tool_call>(.*?)</tool_call>`)
//m:= re.FindStringSubmatch(`<tool_call>{ foo barr }</tool_call>`)

// We use a regex to extract the JSON object from the response
var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch)
match := respRegex.FindStringSubmatch(llmresult)
if len(match) < 2 {
return results
}

funcName, args, err := returnResult(match[1])
if err != nil {
return results
}

return append(results, FuncCallResults{Name: funcName, Arguments: args})

} else {

funcName, args, err := returnResult(llmresult)
// If the LLM result is a single object, try unmarshaling it into a single map
var singleObj map[string]interface{}
err = json.Unmarshal([]byte(s), &singleObj)
if err != nil {
return results
log.Warn().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
} else {
ss = []map[string]interface{}{singleObj}
}

return append(results, FuncCallResults{Name: funcName, Arguments: args})
}

return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
}

// with grammars
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
err := json.Unmarshal([]byte(s), &ss)
if err != nil {
log.Warn().Err(err).Str("escapedLLMResult", s).Msg("multiple results: unable to unmarshal llm result")
}
log.Debug().Msgf("Function return: %s %+v", s, ss)

for _, s := range ss {
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := s[functionNameKey]
if !ok {
continue
//return result, fmt.Errorf("unable to find function name in result")
}
args, ok := s["arguments"]
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := s["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
continue
//return result, fmt.Errorf("unable to find arguments in result")
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
//return result, fmt.Errorf("unable to cast function name to string")
}
results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})

result = append(result, FuncCallResults{Name: funcName, Arguments: string(d)})
}
} else {
funcName, args, err := returnResult(llmresult)
if err != nil {

return result, nil
}

// the response is a string that we have to parse
result := make(map[string]string)

if functionConfig.ResponseRegex != "" {
// We use named regexes here to extract the function name and arguments
// obviously, this expects the LLM to be stable and return correctly formatted JSON
// TODO: optimize this and pre-compile it
var respRegex = regexp.MustCompile(functionConfig.ResponseRegex)
match := respRegex.FindStringSubmatch(llmresult)
for i, name := range respRegex.SubexpNames() {
if i != 0 && name != "" && len(match) > i {
result[name] = match[i]
}
}

// TODO: open point about multiple results and/or mixed with chat messages
// This is not handled as for now, we only expect one function call per response
functionName := result[functionNameKey]
if functionName == "" {
return results
}
results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
} else if functionConfig.JSONRegexMatch != "" {

results = append(results, FuncCallResults{Name: funcName, Arguments: args})
// We use a regex to extract the JSON object from the response
var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch)
match := respRegex.FindStringSubmatch(llmresult)
if len(match) < 2 {
return results
}

results, _ = returnResult(match[1])
} else {
results, _ = returnResult(llmresult)
}

return results
Expand Down
41 changes: 4 additions & 37 deletions pkg/functions/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,12 @@ var _ = Describe("LocalAI function parse tests", func() {

BeforeEach(func() {
// Default configuration setup
functionConfig = FunctionsConfig{
ParallelCalls: false,
NoGrammar: false,
ResponseRegex: `(?P<function>\w+)\s*\((?P<arguments>.*)\)`,
}
functionConfig = FunctionsConfig{}
})

Context("when using grammars and single result expected", func() {
It("should parse the function name and arguments correctly", func() {
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = false

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
Expand All @@ -34,7 +28,7 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("when not using grammars and regex is needed", func() {
It("should extract function name and arguments from the regex", func() {
input := `add({"x":5,"y":3})`
functionConfig.NoGrammar = true
functionConfig.ResponseRegex = `(?P<function>\w+)\s*\((?P<arguments>.*)\)`

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
Expand All @@ -46,33 +40,19 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("when having invalid input", func() {
It("returns no results when there is no input", func() {
input := ""
functionConfig.NoGrammar = true

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0))

functionConfig.NoGrammar = false

results = ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0))
})
It("returns no results when is invalid", func() {
input := "invalid input"
functionConfig.NoGrammar = true

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0))
functionConfig.NoGrammar = false

results = ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(0))
})
})
Context("when parallel calls are enabled", func() {
It("should handle multiple function calls", func() {
input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]`
functionConfig.ParallelCalls = true
functionConfig.NoGrammar = false

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(2))
Expand All @@ -86,9 +66,6 @@ var _ = Describe("LocalAI function parse tests", func() {
Context("without grammars and without regex", func() {
It("should parse the function name and arguments correctly with the name key", func() {
input := `{"name": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = true

results := ParseFunctionCall(input, functionConfig)
Expand All @@ -99,10 +76,6 @@ var _ = Describe("LocalAI function parse tests", func() {

It("should parse the function name and arguments correctly with the function key", func() {
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
Expand All @@ -115,11 +88,8 @@ var _ = Describe("LocalAI function parse tests", func() {
<tool_call>
{"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true

functionConfig.JSONRegexMatch = `(?s)<tool_call>(.*?)</tool_call>`
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
Expand All @@ -131,11 +101,8 @@ var _ = Describe("LocalAI function parse tests", func() {
input := `
{"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true

functionConfig.JSONRegexMatch = `(?s)(.*?)</tool_call>`
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
Expand Down

0 comments on commit 02f1b47

Please sign in to comment.