Skip to content

Commit

Permalink
feat(agent): add support for overriding templates
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D committed Sep 3, 2024
1 parent 0e25450 commit 58c89d9
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 11 deletions.
5 changes: 3 additions & 2 deletions src/agents/bee/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import { BaseMemory } from "@/memory/base.js";
import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { AgentMeta } from "@/agents/types.js";
import { BeeAgentTemplate, BeeAssistantPrompt } from "@/agents/bee/prompts.js";
import { BeeAssistantPrompt } from "@/agents/bee/prompts.js";
import * as R from "remeda";
import { Emitter } from "@/emitter/emitter.js";
import {
BeeAgentRunIteration,
BeeAgentTemplates,
BeeCallbacks,
BeeMeta,
BeeRunInput,
Expand All @@ -40,8 +41,8 @@ export interface BeeInput {
llm: ChatLLM<ChatLLMOutput>;
tools: AnyTool[];
memory: BaseMemory;
promptTemplate?: BeeAgentTemplate;
meta?: AgentMeta;
templates?: BeeAgentTemplates;
}

export class BeeAgent extends BaseAgent<BeeRunInput, BeeRunOutput, BeeRunOptions> {
Expand Down
3 changes: 1 addition & 2 deletions src/agents/bee/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import { PromptTemplate } from "@/template.js";

export const BeeAgentSystemPrompt = new PromptTemplate({
export const BeeSystemPrompt = new PromptTemplate({
variables: ["instructions", "tools", "tool_names"] as const,
defaults: {
instructions: "You are a helpful assistant that uses tools to answer questions.",
Expand Down Expand Up @@ -84,7 +84,6 @@ Responses must always have the following structure:
- IMPORTANT: Lines 'Thought', 'Tool Name', 'Tool Caption', 'Tool Input', 'Tool Output' and 'Final Answer' must be sent within a single message.
`,
});
export type BeeAgentTemplate = typeof BeeAgentSystemPrompt;

export const BeeAssistantPrompt = new PromptTemplate({
variables: ["thought", "toolName", "toolCaption", "toolInput", "toolOutput", "finalAnswer"],
Expand Down
18 changes: 11 additions & 7 deletions src/agents/bee/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import { FrameworkError } from "@/errors.js";
import { BeeInput } from "@/agents/bee/agent.js";
import { RetryCounter } from "@/internals/helpers/counter.js";
import {
BeeAgentSystemPrompt,
BeeSystemPrompt,
BeeToolErrorPrompt,
BeeToolInputErrorPrompt,
BeeToolNoResultsPrompt,
Expand Down Expand Up @@ -86,11 +86,10 @@ export class BeeAgentRunner {
},
},
});
const template = input.promptTemplate ?? BeeAgentSystemPrompt;
await memory.addMany([
BaseMessage.of({
role: Role.SYSTEM,
text: template.render({
text: (input.templates?.system ?? BeeSystemPrompt).render({
tools: await Promise.all(
input.tools.map(async (tool) => ({
name: tool.name,
Expand All @@ -105,7 +104,7 @@ export class BeeAgentRunner {
...input.memory.messages,
BaseMessage.of({
role: Role.USER,
text: BeeUserPrompt.clone().render({
text: (input.templates?.user ?? BeeUserPrompt).render({
input: prompt.trim() ? prompt : "Empty message.",
}),
}),
Expand Down Expand Up @@ -259,7 +258,8 @@ export class BeeAgentRunner {
});

if (toolOutput.isEmpty()) {
return { output: BeeToolNoResultsPrompt.render({}), success: true };
const template = this.input.templates?.toolNoResultError ?? BeeToolNoResultsPrompt;
return { output: template.render({}), success: true };
}

return {
Expand All @@ -280,19 +280,23 @@ export class BeeAgentRunner {

if (error instanceof ToolInputValidationError) {
this.failedAttemptsCounter.use(error);

const template = this.input.templates?.toolInputError ?? BeeToolInputErrorPrompt;
return {
success: false,
output: BeeToolInputErrorPrompt.render({
output: template.render({
reason: error.toString(),
}),
};
}

if (FrameworkError.isRetryable(error)) {
this.failedAttemptsCounter.use(error);

const template = this.input.templates?.toolError ?? BeeToolErrorPrompt;
return {
success: false,
output: BeeToolErrorPrompt.render({
output: template.render({
reason: FrameworkError.ensure(error).explain(),
}),
};
Expand Down
15 changes: 15 additions & 0 deletions src/agents/bee/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ import { BaseMemory } from "@/memory/base.js";
import { BaseMessage } from "@/llms/primitives/message.js";
import { Callback } from "@/emitter/types.js";
import { AnyTool, BaseToolRunOptions, Tool, ToolError, ToolOutput } from "@/tools/base.js";
import {
BeeSystemPrompt,
BeeToolErrorPrompt,
BeeToolInputErrorPrompt,
BeeToolNoResultsPrompt,
BeeUserPrompt,
} from "@/agents/bee/prompts.js";

export interface BeeRunInput {
prompt: string;
Expand Down Expand Up @@ -112,3 +119,11 @@ export interface BeeCallbacks {
meta: BeeMeta;
}>;
}

export interface BeeAgentTemplates {
system: typeof BeeSystemPrompt;
user: typeof BeeUserPrompt;
toolError: typeof BeeToolErrorPrompt;
toolInputError: typeof BeeToolInputErrorPrompt;
toolNoResultError: typeof BeeToolNoResultsPrompt;
}

0 comments on commit 58c89d9

Please sign in to comment.