From 5ac89c15fce25ce78698fd5a9b80437de0af5e88 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Thu, 3 Oct 2024 18:55:29 +0200 Subject: [PATCH] feat(agent): add fallback for invalid llm output Ref: #55 --- src/agents/bee/runner.ts | 92 ++++++++++++++++++-------------- src/agents/parsers/linePrefix.ts | 24 ++++++++- 2 files changed, 74 insertions(+), 42 deletions(-) diff --git a/src/agents/bee/runner.ts b/src/agents/bee/runner.ts index 6e88fbcd..f9233c41 100644 --- a/src/agents/bee/runner.ts +++ b/src/agents/bee/runner.ts @@ -161,47 +161,59 @@ export class BeeAgentRunner { const parserRegex = /Thought:.+\n(?:Final Answer:[\S\s]+|Function Name:.+\nFunction Input:\{.+\}\nFunction Caption:.+\nFunction Output:)?/; - const parser = new LinePrefixParser({ - thought: { - prefix: "Thought:", - next: ["tool_name", "final_answer"], - isStart: true, - field: new ZodParserField(z.string().min(1)), - }, - tool_name: { - prefix: "Function Name:", - next: ["tool_input"], - field: new ZodParserField(z.enum(tools.map((tool) => tool.name) as [string, ...string[]])), - }, - tool_input: { - prefix: "Function Input:", - next: ["tool_caption", "tool_output"], - isEnd: true, - field: new JSONParserField({ - schema: z.object({}).passthrough(), - base: {}, - }), - }, - tool_caption: { - prefix: "Function Caption:", - next: ["tool_output"], - isEnd: true, - field: new ZodParserField(z.string()), - }, - tool_output: { - prefix: "Function Output:", - next: ["final_answer"], - isEnd: true, - field: new ZodParserField(z.string()), - }, - final_answer: { - prefix: "Final Answer:", - next: [], - isStart: true, - isEnd: true, - field: new ZodParserField(z.string().min(1)), + const parser = new LinePrefixParser( + { + thought: { + prefix: "Thought:", + next: ["tool_name", "final_answer"], + isStart: true, + field: new ZodParserField(z.string().min(1)), + }, + tool_name: { + prefix: "Function Name:", + next: ["tool_input"], + field: new ZodParserField( + z.enum(tools.map((tool) => tool.name) as [string, ...string[]]), + ), + }, + tool_input: { + prefix: "Function Input:", + next: ["tool_caption", "tool_output"], + isEnd: true, + field: new JSONParserField({ + schema: z.object({}).passthrough(), + base: {}, + matchPair: ["{", "}"], + }), + }, + tool_caption: { + prefix: "Function Caption:", + next: ["tool_output"], + isEnd: true, + field: new ZodParserField(z.string()), + }, + tool_output: { + prefix: "Function Output:", + next: ["final_answer"], + isEnd: true, + field: new ZodParserField(z.string()), + }, + final_answer: { + prefix: "Final Answer:", + next: [], + isStart: true, + isEnd: true, + field: new ZodParserField(z.string().min(1)), + }, + } as const, + { + fallback: (stash) => + [ + { key: "thought", value: "I now know the final answer." }, + { key: "final_answer", value: stash }, + ] as const, }, - } as const); + ); return { parser, diff --git a/src/agents/parsers/linePrefix.ts b/src/agents/parsers/linePrefix.ts index 4ca2ae0b..2041e9ee 100644 --- a/src/agents/parsers/linePrefix.ts +++ b/src/agents/parsers/linePrefix.ts @@ -100,7 +100,12 @@ export class LinePrefixParser< return this.done; } - constructor(protected readonly nodes: T) { + constructor( + protected readonly nodes: T, + protected readonly options: { + fallback?: (value: string) => readonly { key: StringKey; value: string }[]; + } = {}, + ) { super(); let hasStartNode = false; @@ -219,6 +224,8 @@ export class LinePrefixParser< context: { lines: linesToString(this.lines.concat(extra.line ? [extra.line] : [])), excludedLines: linesToString(this.excludedLines), + finalState: this.finalState, + partialState: this.partialState, }, }, ); @@ -228,6 +235,18 @@ export class LinePrefixParser< if (this.done) { return this.finalState; } + + if (!this.lastNodeKey && this.options.fallback) { + const stash = linesToString(this.excludedLines); + this.lines.length = 0; + this.excludedLines.length = 0; + + const nodes = this.options.fallback(stash); + await this.add( + nodes.map((node) => `${this.nodes[node.key].prefix}${node.value}`).join(NEW_LINE_CHARACTER), + ); + } + this.done = true; if (!this.lastNodeKey) { @@ -285,7 +304,7 @@ export class LinePrefixParser< } catch (e) { if (e instanceof ZodError) { this.throwWithContext( - `Value for ${key} cannot be retrieved because it's value does not adhere to the appropriate schema.`, + `Value for '${key}' cannot be retrieved because it's value does not adhere to the appropriate schema.`, { errors: [e] }, ); } @@ -332,6 +351,7 @@ export class LinePrefixParser< emitter: this.emitter, done: this.done, lastNodeKey: this.lastNodeKey, + options: this.options, }; }