From 58c89d93b737e5221ffb2f4c63f8e1cf508a4c33 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Tue, 3 Sep 2024 14:03:08 +0200 Subject: [PATCH] feat(agent): add support for overriding templates --- src/agents/bee/agent.ts | 5 +++-- src/agents/bee/prompts.ts | 3 +-- src/agents/bee/runner.ts | 18 +++++++++++------- src/agents/bee/types.ts | 15 +++++++++++++++ 4 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/agents/bee/agent.ts b/src/agents/bee/agent.ts index 50a69a33..67535743 100644 --- a/src/agents/bee/agent.ts +++ b/src/agents/bee/agent.ts @@ -20,11 +20,12 @@ import { BaseMemory } from "@/memory/base.js"; import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js"; import { BaseMessage, Role } from "@/llms/primitives/message.js"; import { AgentMeta } from "@/agents/types.js"; -import { BeeAgentTemplate, BeeAssistantPrompt } from "@/agents/bee/prompts.js"; +import { BeeAssistantPrompt } from "@/agents/bee/prompts.js"; import * as R from "remeda"; import { Emitter } from "@/emitter/emitter.js"; import { BeeAgentRunIteration, + BeeAgentTemplates, BeeCallbacks, BeeMeta, BeeRunInput, @@ -40,8 +41,8 @@ export interface BeeInput { llm: ChatLLM; tools: AnyTool[]; memory: BaseMemory; - promptTemplate?: BeeAgentTemplate; meta?: AgentMeta; + templates?: BeeAgentTemplates; } export class BeeAgent extends BaseAgent { diff --git a/src/agents/bee/prompts.ts b/src/agents/bee/prompts.ts index 89d44b46..f7cfb390 100644 --- a/src/agents/bee/prompts.ts +++ b/src/agents/bee/prompts.ts @@ -16,7 +16,7 @@ import { PromptTemplate } from "@/template.js"; -export const BeeAgentSystemPrompt = new PromptTemplate({ +export const BeeSystemPrompt = new PromptTemplate({ variables: ["instructions", "tools", "tool_names"] as const, defaults: { instructions: "You are a helpful assistant that uses tools to answer questions.", @@ -84,7 +84,6 @@ Responses must always have the following structure: - IMPORTANT: Lines 'Thought', 'Tool Name', 'Tool Caption', 'Tool Input', 'Tool Output' and 'Final Answer' must be sent within a single message. `, }); -export type BeeAgentTemplate = typeof BeeAgentSystemPrompt; export const BeeAssistantPrompt = new PromptTemplate({ variables: ["thought", "toolName", "toolCaption", "toolInput", "toolOutput", "finalAnswer"], diff --git a/src/agents/bee/runner.ts b/src/agents/bee/runner.ts index 6988ab0f..2925a4dd 100644 --- a/src/agents/bee/runner.ts +++ b/src/agents/bee/runner.ts @@ -27,7 +27,7 @@ import { FrameworkError } from "@/errors.js"; import { BeeInput } from "@/agents/bee/agent.js"; import { RetryCounter } from "@/internals/helpers/counter.js"; import { - BeeAgentSystemPrompt, + BeeSystemPrompt, BeeToolErrorPrompt, BeeToolInputErrorPrompt, BeeToolNoResultsPrompt, @@ -86,11 +86,10 @@ export class BeeAgentRunner { }, }, }); - const template = input.promptTemplate ?? BeeAgentSystemPrompt; await memory.addMany([ BaseMessage.of({ role: Role.SYSTEM, - text: template.render({ + text: (input.templates?.system ?? BeeSystemPrompt).render({ tools: await Promise.all( input.tools.map(async (tool) => ({ name: tool.name, @@ -105,7 +104,7 @@ export class BeeAgentRunner { ...input.memory.messages, BaseMessage.of({ role: Role.USER, - text: BeeUserPrompt.clone().render({ + text: (input.templates?.user ?? BeeUserPrompt).render({ input: prompt.trim() ? prompt : "Empty message.", }), }), @@ -259,7 +258,8 @@ export class BeeAgentRunner { }); if (toolOutput.isEmpty()) { - return { output: BeeToolNoResultsPrompt.render({}), success: true }; + const template = this.input.templates?.toolNoResultError ?? BeeToolNoResultsPrompt; + return { output: template.render({}), success: true }; } return { @@ -280,9 +280,11 @@ export class BeeAgentRunner { if (error instanceof ToolInputValidationError) { this.failedAttemptsCounter.use(error); + + const template = this.input.templates?.toolInputError ?? BeeToolInputErrorPrompt; return { success: false, - output: BeeToolInputErrorPrompt.render({ + output: template.render({ reason: error.toString(), }), }; @@ -290,9 +292,11 @@ export class BeeAgentRunner { if (FrameworkError.isRetryable(error)) { this.failedAttemptsCounter.use(error); + + const template = this.input.templates?.toolError ?? BeeToolErrorPrompt; return { success: false, - output: BeeToolErrorPrompt.render({ + output: template.render({ reason: FrameworkError.ensure(error).explain(), }), }; diff --git a/src/agents/bee/types.ts b/src/agents/bee/types.ts index 1a5f66c1..d2009b56 100644 --- a/src/agents/bee/types.ts +++ b/src/agents/bee/types.ts @@ -20,6 +20,13 @@ import { BaseMemory } from "@/memory/base.js"; import { BaseMessage } from "@/llms/primitives/message.js"; import { Callback } from "@/emitter/types.js"; import { AnyTool, BaseToolRunOptions, Tool, ToolError, ToolOutput } from "@/tools/base.js"; +import { + BeeSystemPrompt, + BeeToolErrorPrompt, + BeeToolInputErrorPrompt, + BeeToolNoResultsPrompt, + BeeUserPrompt, +} from "@/agents/bee/prompts.js"; export interface BeeRunInput { prompt: string; @@ -112,3 +119,11 @@ export interface BeeCallbacks { meta: BeeMeta; }>; } + +export interface BeeAgentTemplates { + system: typeof BeeSystemPrompt; + user: typeof BeeUserPrompt; + toolError: typeof BeeToolErrorPrompt; + toolInputError: typeof BeeToolInputErrorPrompt; + toolNoResultError: typeof BeeToolNoResultsPrompt; +}