diff --git a/examples/flows/agent.ts b/examples/flows/agent.ts new file mode 100644 index 00000000..0e73a8c8 --- /dev/null +++ b/examples/flows/agent.ts @@ -0,0 +1,78 @@ +import "dotenv/config"; +import { BeeAgent } from "bee-agent-framework/agents/bee/agent"; +import { BAMChatLLM } from "bee-agent-framework/adapters/bam/chat"; +import { z } from "zod"; +import { BaseMessage } from "bee-agent-framework/llms/primitives/message"; +import { JsonDriver } from "bee-agent-framework/llms/drivers/json"; +import { WikipediaTool } from "bee-agent-framework/tools/search/wikipedia"; +import { OpenMeteoTool } from "bee-agent-framework/tools/weather/openMeteo"; +import { ReadOnlyMemory } from "bee-agent-framework/memory/base"; +import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory"; +import { Flow } from "bee-agent-framework/flows"; +import { createConsoleReader } from "examples/helpers/io.js"; + +const schema = z.object({ + answer: z.instanceof(BaseMessage).optional(), + memory: z.instanceof(ReadOnlyMemory), +}); + +const workflow = new Flow({ schema: schema }) + .addStep("simpleAgent", async (state) => { + const simpleAgent = new BeeAgent({ + llm: BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"), + tools: [], + memory: state.memory, + }); + const answer = await simpleAgent.run({ prompt: null }); + reader.write("🤖 Simple Agent", answer.result.text); + + return { + update: { answer: answer.result }, + next: "critique", + }; + }) + .addStep("critique", schema.required(), async (state) => { + const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); + const { parsed: critiqueResponse } = await new JsonDriver(llm).generate( + z.object({ score: z.number().int().min(0).max(100) }), + [ + BaseMessage.of({ + role: "system", + text: `You are an evaluation assistant who scores the credibility of the last assistant's response. Chitchatting always has a score of 100. If the assistant was unable to answer the user's query, then the score will be 0.`, + }), + ...state.memory.messages, + state.answer, + ], + ); + reader.write("🧠 Score", critiqueResponse.score.toString()); + + return { + next: critiqueResponse.score < 75 ? "complexAgent" : Flow.END, + }; + }) + .addStep("complexAgent", async (state) => { + const complexAgent = new BeeAgent({ + llm: BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"), + tools: [new WikipediaTool(), new OpenMeteoTool()], + memory: state.memory, + }); + const { result } = await complexAgent.run({ prompt: null }); + reader.write("🤖 Complex Agent", result.text); + return { update: { answer: result } }; + }) + .setStart("simpleAgent"); + +const reader = createConsoleReader(); +const memory = new UnconstrainedMemory(); + +for await (const { prompt } of reader) { + const userMessage = BaseMessage.of({ role: "user", text: prompt }); + await memory.add(userMessage); + + const response = await workflow.run({ + memory: memory.asReadOnly(), + }); + await memory.add(response.state.answer!); + + reader.write("🤖 Final Answer", response.state.answer!.text); +} diff --git a/examples/flows/contentCreator.ts b/examples/flows/contentCreator.ts new file mode 100644 index 00000000..76410f13 --- /dev/null +++ b/examples/flows/contentCreator.ts @@ -0,0 +1,151 @@ +import "dotenv/config.js"; +import { Flow } from "bee-agent-framework/flows"; +import { z } from "zod"; +import { BeeAgent } from "bee-agent-framework/agents/bee/agent"; +import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory"; +import { BAMChatLLM } from "bee-agent-framework/adapters/bam/chat"; +import { createConsoleReader } from "examples/helpers/io.js"; +import { BaseMessage } from "bee-agent-framework/llms/primitives/message"; +import { DuckDuckGoSearchTool } from "bee-agent-framework/tools/search/duckDuckGoSearch"; +import { JsonDriver } from "bee-agent-framework/llms/drivers/json"; +import { isEmpty, pick } from "remeda"; + +const schema = z.object({ + input: z.string(), + output: z.string().optional(), + + topic: z.string().optional(), + notes: z.array(z.string()).default([]), + plan: z.string().optional(), + draft: z.string().optional(), +}); + +const flow = new Flow({ + schema: schema, + outputSchema: schema.required({ output: true }), +}) + .addStep("preprocess", async (state) => { + const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); + const driver = new JsonDriver(llm); + + const { parsed } = await driver.generate( + schema.pick({ topic: true, notes: true }).or( + z.object({ + error: z.string().describe("Use this field if the user input is not a valid topic."), + }), + ), + [ + BaseMessage.of({ + role: `user`, + text: [ + "Your task is to rewrite the user input so that it guides the content planner and editor to craft a blog post that perfectly aligns with the user's needs. Notes should be used only if the user complains about something.", + "", + ...[!isEmpty(state.notes) && ["# Previous Topic", state.topic, ""]], + ...[!isEmpty(state.notes) && ["# Previous Notes", state.notes.join("\n"), ""]], + "# User Query", + state.input, + ] + .filter(Boolean) + .join("\n"), + }), + ], + ); + + return "error" in parsed + ? { update: { output: parsed.error }, next: Flow.END } + : { update: pick(parsed, ["notes", "topic"]) }; + }) + .addStep("planner", schema.required({ topic: true }), async (state) => { + const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); + const agent = new BeeAgent({ + llm, + memory: new UnconstrainedMemory(), + tools: [new DuckDuckGoSearchTool()], + }); + + const { result } = await agent.run({ + prompt: [ + `You are a Content Planner. Your task is to create a content plan for "${state.topic}" topic.`, + ``, + `# Objectives`, + `1. Prioritize the latest trends, key players, and noteworthy news.`, + `2. Identify the target audience, considering their interests and pain points.`, + `3. Develop a detailed content outline including introduction, key points, and a call to action.`, + `4. Include SEO keywords and relevant sources.`, + ``, + ...[!isEmpty(state.notes) && ["# Notes", state.notes.join("\n"), ""]], + ].join("\n"), + }); + + return { + update: { + plan: result.text, + }, + }; + }) + .addStep("writer", schema.required({ plan: true }), async (state) => { + const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); + const output = await llm.generate([ + BaseMessage.of({ + role: `system`, + text: `You are a Content Writer. Your task is to write a compelling blog post based on the provided Context. + +# Context +${state.plan} + +# Objectives +- An engaging introduction +- Insightful body paragraphs (2-3 per section) +- Properly named sections/subtitles +- A summarizing conclusion + +Ensure the content flows naturally, incorporates SEO keywords, and is well-structured.`, + }), + ]); + + return { + update: { draft: output.getTextContent() }, + }; + }) + .addStep("editor", schema.required({ draft: true }), async (state) => { + const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); + const output = await llm.generate([ + BaseMessage.of({ + role: `system`, + text: `You are an Editor. Your task is to transform the following draft blog post to a final version. + +# Draft +${state.draft} + +# Objectives +- Fix Grammatical errors +- Journalistic best practices + +IMPORTANT: The final version must not contain any editor's comments. +`, + }), + ]); + + return { + update: { output: output.getTextContent() }, + }; + }); + +let lastResult = {} as Flow.output; +const reader = createConsoleReader(); +for await (const { prompt } of reader) { + const { result } = await flow + .run({ + input: prompt, + notes: lastResult?.notes, + topic: lastResult?.topic, + }) + .observe((emitter) => { + emitter.on("start", ({ step, run }) => { + reader.write(`-> ▶️ ${step}`, JSON.stringify(run.state).substring(0, 200).concat("...")); + }); + }); + + lastResult = result; + reader.write("🤖 Answer", lastResult.output); +} diff --git a/examples/flows/simple.ts b/examples/flows/simple.ts new file mode 100644 index 00000000..ca14324d --- /dev/null +++ b/examples/flows/simple.ts @@ -0,0 +1,43 @@ +import { Flow } from "bee-agent-framework/flows"; +import { z } from "zod"; + +const schema = z.object({ + hops: z.number().default(0), +}); + +const flow = new Flow({ schema }) + .addStep("a", async (state) => ({ + update: { hops: state.hops + 1 }, + })) + .addStep("b", async (state) => ({ + update: { hops: state.hops + 1 }, + })) + .addStep("c", async (state) => ({ + update: { hops: state.hops + 1 }, + next: Math.random() > 0.5 ? "a" : Flow.END, + })) + .addStep("d", () => ({})) + .delStep("d"); + +{ + console.info("Example 1"); + const result = await flow.run({ hops: 0 }); + console.info(`-> steps`, result.steps.map((step) => step.name).join(",")); +} + +{ + console.info("Example 2"); + const result = await flow.setStart("c").run({ hops: 10 }); + console.info(`-> steps`, result.steps.map((step) => step.name).join(",")); +} + +{ + // Type utils + const input: Flow.input = {}; + const output: Flow.output = { hops: 10 }; + const response: Flow.run = { + steps: [], + state: { hops: 1 }, + result: { hops: 1 }, + }; +} diff --git a/package.json b/package.json index 341b7202..d7cce65a 100644 --- a/package.json +++ b/package.json @@ -39,7 +39,8 @@ "serializer", "infra", "deps", - "instrumentation" + "instrumentation", + "flows" ] ] } diff --git a/src/context.ts b/src/context.ts index 45966a7b..7b635685 100644 --- a/src/context.ts +++ b/src/context.ts @@ -97,8 +97,8 @@ export class RunContext extends Serializable { return this.controller.signal; } - abort() { - this.controller.abort(); + abort(reason?: Error) { + this.controller.abort(reason); } constructor( diff --git a/src/flows.ts b/src/flows.ts new file mode 100644 index 00000000..7449a4e2 --- /dev/null +++ b/src/flows.ts @@ -0,0 +1,237 @@ +import { FrameworkError } from "@/errors.js"; +import { ZodSchema, z } from "zod"; +import { Serializable } from "@/internals/serializable.js"; +import { Callback, Emitter } from "@/emitter/emitter.js"; +import { RunContext } from "@/context.js"; +import { omit, pick, toCamelCase } from "remeda"; +import { shallowCopy } from "@/serializer/utils.js"; + +export interface FlowStepResponse { + update?: Partial>; + next?: FlowNextStep; +} + +export interface FlowRun { + result: z.output; + steps: FlowStepRes[]; + state: z.output; +} + +export interface FlowRunOptions { + signal?: AbortSignal; +} + +export interface FlowStepDef { + schema: T; + handler: FlowStepHandler; +} + +export interface FlowStepRes { + name: K; + state: z.output; +} + +export interface FlowRunContext { + steps: FlowStepRes[]; + signal: AbortSignal; + abort: (reason?: Error) => void; +} + +export type FlowStepHandler = ( + state: z.output, + context: FlowRunContext, +) => Promise> | FlowStepResponse; + +export class WorkflowError< + T extends ZodSchema, + T2 extends ZodSchema, + K extends string, +> extends FrameworkError { + constructor(message: string, extra?: { run?: FlowRun; errors?: Error[] }) { + super(message, extra?.errors, { + context: extra?.run ?? {}, + isRetryable: false, + isFatal: true, + }); + } +} + +export interface FlowEvents { + start: Callback<{ step: K; run: FlowRun }>; + error: Callback<{ + step: K; + error: Error; + run: FlowRun; + }>; + success: Callback<{ + step: K; + response: FlowStepResponse; + run: FlowRun; + }>; +} + +interface FlowInput { + name?: string; + schema: TS; + outputSchema?: TS2; +} + +type FlowNextStep = K | typeof Flow.END; + +export class Flow< + TInput extends ZodSchema, + TOutput extends ZodSchema = TInput, + TKeys extends string = string, +> extends Serializable { + public static readonly END = "__end__"; + public readonly emitter: Emitter>; + + protected readonly steps = new Map>(); + protected startStep: TKeys | null = null; + + constructor(protected readonly input: FlowInput) { + super(); + this.emitter = Emitter.root.child({ + namespace: ["flow", toCamelCase(input?.name ?? "")].filter(Boolean), + creator: this, + }); + } + + get schemas() { + return pick(this.input, ["schema", "outputSchema"]); + } + + addStep( + name: L, + step: FlowStepHandler, + ): Flow; + addStep( + name: L, + schema: TI2, + step: FlowStepHandler, + ): Flow; + addStep( + name: L, + schemaOrStep: TI2 | FlowStepHandler, + step?: FlowStepHandler, + ): Flow { + if (this.steps.has(name)) { + throw new WorkflowError(`Step '${name}' already exists!`); + } + if (name === Flow.END) { + throw new WorkflowError(`The name '${name}' cannot be used!`); + } + + if (schemaOrStep && step) { + // @ts-expect-error + this.steps.set(name, { handler: step, schema: schemaOrStep }); + } else if (typeof schemaOrStep === "function") { + this.steps.set(name, { handler: schemaOrStep, schema: this.input.schema }); + } else { + throw new WorkflowError( + `Wrong parameters provided. Neither 'schema' nor 'node' were provided.`, + ); + } + + return this as Flow; + } + + setStart(name: TKeys) { + this.startStep = name; + return this; + } + + run(state: z.input, options: FlowRunOptions = {}) { + return RunContext.enter( + this, + { signal: options?.signal, params: [state, options] as const }, + async (runContext): Promise> => { + const run: FlowRun = { + steps: [], + state: this.input.schema.parse(state), + result: undefined as z.output, + }; + const handlers: FlowRunContext = { + steps: run.steps, + signal: runContext.signal, + abort: (reason) => runContext.abort(reason), + }; + + let stepName = this.startStep ?? this.findNextStep(); + while (stepName !== Flow.END) { + const step = this.steps.get(stepName); + if (!step) { + throw new WorkflowError(`Step '${stepName}' was not found.`, { run }); + } + run.steps.push({ name: stepName, state: run.state }); + await runContext.emitter.emit("start", { run, step: stepName }); + try { + const stepInput = await step.schema.parseAsync(run.state).catch((err: Error) => { + throw new WorkflowError( + `Step '${stepName}' cannot be executed because the workflow schema doesn't adhere to the required one.`, + { run: shallowCopy(run), errors: [err] }, + ); + }); + const response = await step.handler(stepInput, handlers); + await runContext.emitter.emit("success", { + run: shallowCopy(run), + response, + step: stepName as TKeys, + }); + if (response.update) { + run.state = { ...run.state, ...response.update }; + } + stepName = response.next || this.findNextStep(stepName); + } catch (error) { + await runContext.emitter.emit("error", { + run: shallowCopy(run), + step: stepName as TKeys, + error, + }); + throw error; + } + } + + run.result = (this.input.outputSchema ?? this.input.schema).parse(run.state); + return run; + }, + ); + } + + delStep(name: L): Flow> { + if (this.startStep === name) { + this.startStep = null; + } + this.steps.delete(name); + return this as unknown as Flow>; + } + + protected findNextStep(start: TKeys | null = null): FlowNextStep { + const keys = Array.from(this.steps.keys()) as TKeys[]; + const curIndex = start ? keys.indexOf(start) : -1; + return keys[curIndex + 1] ?? Flow.END; + } + + createSnapshot() { + return { + input: omit(this.input, ["schema", "outputSchema"]), + emitter: this.emitter, + startStep: this.startStep, + steps: this.steps, + }; + } + + loadSnapshot(snapshot: ReturnType) { + Object.assign(this, snapshot); + this.input.schema ??= z.any() as unknown as TInput; + this.input.outputSchema ??= z.any() as unknown as TOutput; + } +} + +// eslint-disable-next-line @typescript-eslint/no-namespace +export namespace Flow { + export type run = T extends Flow ? FlowRun : never; + export type state = T extends Flow ? z.output : never; + export type input = T extends Flow ? z.input : never; + export type output = T extends Flow ? z.output : never; +}