diff --git a/examples/flows/agent.ts b/examples/flows/agent.ts index 0e73a8c..3cdcba3 100644 --- a/examples/flows/agent.ts +++ b/examples/flows/agent.ts @@ -31,7 +31,7 @@ const workflow = new Flow({ schema: schema }) next: "critique", }; }) - .addStep("critique", schema.required(), async (state) => { + .addStrictStep("critique", schema.required(), async (state) => { const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); const { parsed: critiqueResponse } = await new JsonDriver(llm).generate( z.object({ score: z.number().int().min(0).max(100) }), diff --git a/examples/flows/contentCreator.ts b/examples/flows/contentCreator.ts index 8919e4b..b31b87e 100644 --- a/examples/flows/contentCreator.ts +++ b/examples/flows/contentCreator.ts @@ -59,7 +59,7 @@ const flow = new Flow({ ? { update: { output: parsed.error }, next: Flow.END } : { update: pick(parsed, ["notes", "topic"]) }; }) - .addStep("planner", schema.required({ topic: true }), async (state) => { + .addStrictStep("planner", schema.required({ topic: true }), async (state) => { const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); const agent = new BeeAgent({ llm, @@ -86,15 +86,13 @@ const flow = new Flow({ ].join("\n"), }); - console.info(result.text); - return { update: { plan: result.text, }, }; }) - .addStep("writer", schema.required({ plan: true }), async (state) => { + .addStrictStep("writer", schema.required({ plan: true }), async (state) => { const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); const output = await llm.generate([ BaseMessage.of({ @@ -122,7 +120,7 @@ const flow = new Flow({ update: { draft: output.getTextContent() }, }; }) - .addStep("editor", schema.required({ draft: true }), async (state) => { + .addStrictStep("editor", schema.required({ draft: true }), async (state) => { const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); const output = await llm.generate([ BaseMessage.of({ diff --git a/examples/flows/simple.ts b/examples/flows/simple.ts index ff2daf9..2318b16 100644 --- a/examples/flows/simple.ts +++ b/examples/flows/simple.ts @@ -5,21 +5,43 @@ const schema = z.object({ hops: z.number().default(0), }); -const flow = new Flow({ schema }) +const sumFlow = new Flow({ schema }) .addStep("a", async (state) => ({})) // does nothing .addStep("b", async (state) => ({ // adds one and moves to b update: { hops: state.hops + 1 }, })) - .addStep("c", async (state) => ({ - update: { hops: state.hops + 1 }, + .addStep("c", async () => ({ next: Math.random() > 0.5 ? "b" : Flow.END, })); +const multipleFlow = new Flow({ + schema: schema.extend({ multiplier: z.number().int().min(1) }), +}) + .addStep("a", async (state) => ({ + // adds one and moves to b + update: { hops: state.hops * state.multiplier }, + })) + .addStep("b", async () => ({ + next: Math.random() > 0.5 ? "a" : Flow.END, + })); + +const flow = new Flow({ schema }) + .addStep("start", () => ({ + next: Math.random() > 0.5 ? "sum" : "multiple", + })) + .addStep( + "multiple", + multipleFlow.asStep({ next: Flow.END, input: ({ hops }) => ({ hops, multiplier: 2 }) }), + ) + .addStep("sum", sumFlow); + const response = await flow.run({ hops: 0 }).observe((emitter) => { - emitter.on("start", (data) => console.log(`-> start ${data.step}`)); - emitter.on("error", (data) => console.log(`-> error ${data.step}`)); - emitter.on("success", (data) => console.log(`-> finish ${data.step}`)); + emitter.on("start", (data, event) => + console.log(`-> step ${data.step}`, event.path, event.trace), + ); + //emitter.on("error", (data) => console.log(`-> error ${data.step}`)); + //emitter.on("success", (data) => console.log(`-> finish ${data.step}`)); }); console.log(`Hops: ${response.result.hops}`); diff --git a/src/flows.ts b/src/flows.ts index 4dccbe1..a4edf02 100644 --- a/src/flows.ts +++ b/src/flows.ts @@ -1,10 +1,10 @@ -import { ZodSchema, z } from "zod"; +import { z, ZodSchema } from "zod"; import { Serializable } from "@/internals/serializable.js"; import { Callback, Emitter } from "@/emitter/emitter.js"; import { RunContext } from "@/context.js"; import { omit, pick, toCamelCase } from "remeda"; import { shallowCopy } from "@/serializer/utils.js"; -import { FrameworkError } from "@/errors.js"; +import { FrameworkError, ValueError } from "@/errors.js"; export interface FlowStepResponse { update?: Partial>; @@ -17,7 +17,8 @@ export interface FlowRun; } -export interface FlowRunOptions { +export interface FlowRunOptions { + start?: K; signal?: AbortSignal; } @@ -103,35 +104,44 @@ export class Flow< addStep( name: L, - step: FlowStepHandler, - ): Flow; - addStep( + step: FlowStepHandler | Flow, + ): Flow { + return this._addStep(name, step); + } + + addStrictStep( name: L, schema: TI2, - step: FlowStepHandler, - ): Flow; - addStep( + step: FlowStepHandler | Flow, + ): Flow { + return this._addStep(name, schema, step); + } + + protected _addStep( name: L, - schemaOrStep: TI2 | FlowStepHandler, - step?: FlowStepHandler, + schemaOrStep: TI2 | FlowStepHandler | Flow, + stepOrEmpty?: FlowStepHandler | Flow, + next?: FlowNextStep, ): Flow { + if (!name.trim()) { + throw new ValueError(`Step name cannot be empty!`); + } if (this.steps.has(name)) { - throw new FlowError(`Step '${name}' already exists!`); + throw new ValueError(`Step '${name}' already exists!`); } if (name === Flow.END) { - throw new FlowError(`The name '${name}' cannot be used!`); + throw new ValueError(`The name '${name}' cannot be used!`); } - if (schemaOrStep && step) { - // @ts-expect-error - this.steps.set(name, { handler: step, schema: schemaOrStep }); - } else if (typeof schemaOrStep === "function") { - this.steps.set(name, { handler: schemaOrStep, schema: this.input.schema }); - } else { - throw new FlowError(`Wrong parameters provided. Neither 'schema' nor 'node' were provided.`); - } + const schema = (schemaOrStep && stepOrEmpty ? schemaOrStep : this.input.schema) as TInput; + const stepOrFlow = stepOrEmpty || schemaOrStep; + + this.steps.set(name, { + schema, + handler: stepOrFlow instanceof Flow ? stepOrFlow.asStep({ next }) : stepOrFlow, + } as FlowStepDef); - return this as Flow; + return this as unknown as Flow; } setStart(name: TKeys) { @@ -139,7 +149,7 @@ export class Flow< return this; } - run(state: z.input, options: FlowRunOptions = {}) { + run(state: z.input, options: FlowRunOptions = {}) { return RunContext.enter( this, { signal: options?.signal, params: [state, options] as const }, @@ -155,7 +165,7 @@ export class Flow< abort: (reason) => runContext.abort(reason), }; - let stepName = this.startStep ?? this.findNextStep(); + let stepName = options?.start || this.startStep || this.findNextStep(); while (stepName !== Flow.END) { const step = this.steps.get(stepName); if (!step) { @@ -204,6 +214,28 @@ export class Flow< return this as unknown as Flow>; } + asStep< + TInput2 extends ZodSchema = TInput, + TOutput2 extends ZodSchema = TOutput, + TKeys2 extends string = TKeys, + >(overrides: { + input?: (input: z.output) => z.output | z.input; + output?: (output: z.output) => z.output | z.input; + start?: TKeys; + next?: FlowNextStep; + }): FlowStepHandler { + return async (input, ctx) => { + const mappedInput = overrides?.input ? overrides.input(input) : input; + const result = await this.run(mappedInput, { start: overrides?.start, signal: ctx.signal }); + const mappedOutput = overrides?.output ? overrides.output(result.state) : result.state; + + return { + update: mappedOutput, + next: overrides?.next, + }; + }; + } + protected findNextStep(start: TKeys | null = null): FlowNextStep { const keys = Array.from(this.steps.keys()) as TKeys[]; const curIndex = start ? keys.indexOf(start) : -1;