Skip to content

Commit

Permalink
feat: update flows interfaces
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
  • Loading branch information
Tomas2D committed Dec 19, 2024
1 parent b52bc08 commit 358ffb2
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 36 deletions.
2 changes: 1 addition & 1 deletion examples/flows/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const workflow = new Flow({ schema: schema })
next: "critique",
};
})
.addStep("critique", schema.required(), async (state) => {
.addStrictStep("critique", schema.required(), async (state) => {
const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct");
const { parsed: critiqueResponse } = await new JsonDriver(llm).generate(
z.object({ score: z.number().int().min(0).max(100) }),
Expand Down
8 changes: 3 additions & 5 deletions examples/flows/contentCreator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ const flow = new Flow({
? { update: { output: parsed.error }, next: Flow.END }
: { update: pick(parsed, ["notes", "topic"]) };
})
.addStep("planner", schema.required({ topic: true }), async (state) => {
.addStrictStep("planner", schema.required({ topic: true }), async (state) => {
const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct");
const agent = new BeeAgent({
llm,
Expand All @@ -86,15 +86,13 @@ const flow = new Flow({
].join("\n"),
});

console.info(result.text);

return {
update: {
plan: result.text,
},
};
})
.addStep("writer", schema.required({ plan: true }), async (state) => {
.addStrictStep("writer", schema.required({ plan: true }), async (state) => {
const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct");
const output = await llm.generate([
BaseMessage.of({
Expand Down Expand Up @@ -122,7 +120,7 @@ const flow = new Flow({
update: { draft: output.getTextContent() },
};
})
.addStep("editor", schema.required({ draft: true }), async (state) => {
.addStrictStep("editor", schema.required({ draft: true }), async (state) => {
const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct");
const output = await llm.generate([
BaseMessage.of({
Expand Down
34 changes: 28 additions & 6 deletions examples/flows/simple.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,43 @@ const schema = z.object({
hops: z.number().default(0),
});

const flow = new Flow({ schema })
const sumFlow = new Flow({ schema })
.addStep("a", async (state) => ({})) // does nothing
.addStep("b", async (state) => ({
// adds one and moves to b
update: { hops: state.hops + 1 },
}))
.addStep("c", async (state) => ({
update: { hops: state.hops + 1 },
.addStep("c", async () => ({
next: Math.random() > 0.5 ? "b" : Flow.END,
}));

const multipleFlow = new Flow({
schema: schema.extend({ multiplier: z.number().int().min(1) }),
})
.addStep("a", async (state) => ({
// adds one and moves to b
update: { hops: state.hops * state.multiplier },
}))
.addStep("b", async () => ({
next: Math.random() > 0.5 ? "a" : Flow.END,
}));

const flow = new Flow({ schema })
.addStep("start", () => ({
next: Math.random() > 0.5 ? "sum" : "multiple",
}))
.addStep(
"multiple",
multipleFlow.asStep({ next: Flow.END, input: ({ hops }) => ({ hops, multiplier: 2 }) }),
)
.addStep("sum", sumFlow);

const response = await flow.run({ hops: 0 }).observe((emitter) => {
emitter.on("start", (data) => console.log(`-> start ${data.step}`));
emitter.on("error", (data) => console.log(`-> error ${data.step}`));
emitter.on("success", (data) => console.log(`-> finish ${data.step}`));
emitter.on("start", (data, event) =>
console.log(`-> step ${data.step}`, event.path, event.trace),
);
//emitter.on("error", (data) => console.log(`-> error ${data.step}`));
//emitter.on("success", (data) => console.log(`-> finish ${data.step}`));
});

console.log(`Hops: ${response.result.hops}`);
Expand Down
80 changes: 56 additions & 24 deletions src/flows.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { ZodSchema, z } from "zod";
import { z, ZodSchema } from "zod";
import { Serializable } from "@/internals/serializable.js";
import { Callback, Emitter } from "@/emitter/emitter.js";
import { RunContext } from "@/context.js";
import { omit, pick, toCamelCase } from "remeda";
import { shallowCopy } from "@/serializer/utils.js";
import { FrameworkError } from "@/errors.js";
import { FrameworkError, ValueError } from "@/errors.js";

export interface FlowStepResponse<T extends ZodSchema, K extends string> {
update?: Partial<z.output<T>>;
Expand All @@ -17,7 +17,8 @@ export interface FlowRun<T extends ZodSchema, T2 extends ZodSchema, K extends st
state: z.output<T>;
}

export interface FlowRunOptions {
export interface FlowRunOptions<K extends string> {
start?: K;
signal?: AbortSignal;
}

Expand Down Expand Up @@ -103,43 +104,52 @@ export class Flow<

addStep<L extends string>(
name: L,
step: FlowStepHandler<TInput, TKeys>,
): Flow<TInput, TOutput, L | TKeys>;
addStep<L extends string, TI2 extends ZodSchema>(
step: FlowStepHandler<TInput, TKeys> | Flow<TInput, TInput, TKeys>,
): Flow<TInput, TOutput, L | TKeys> {
return this._addStep(name, step);
}

addStrictStep<L extends string, TI2 extends ZodSchema>(
name: L,
schema: TI2,
step: FlowStepHandler<TI2, TKeys>,
): Flow<TInput, TOutput, L | TKeys>;
addStep<L extends string, TI2 extends ZodSchema = TInput>(
step: FlowStepHandler<TI2, TKeys> | Flow<TInput, TInput, TKeys>,
): Flow<TInput, TOutput, L | TKeys> {
return this._addStep(name, schema, step);
}

protected _addStep<TI2 extends ZodSchema = TInput, L extends string = TKeys>(
name: L,
schemaOrStep: TI2 | FlowStepHandler<TInput, TKeys>,
step?: FlowStepHandler<TI2, TKeys>,
schemaOrStep: TI2 | FlowStepHandler<TInput, TKeys> | Flow<TInput, TInput, TKeys>,
stepOrEmpty?: FlowStepHandler<TI2, TKeys> | Flow<TInput, TInput, TKeys>,
next?: FlowNextStep<TKeys>,
): Flow<TInput, TOutput, L | TKeys> {
if (!name.trim()) {
throw new ValueError(`Step name cannot be empty!`);
}
if (this.steps.has(name)) {
throw new FlowError(`Step '${name}' already exists!`);
throw new ValueError(`Step '${name}' already exists!`);
}
if (name === Flow.END) {
throw new FlowError(`The name '${name}' cannot be used!`);
throw new ValueError(`The name '${name}' cannot be used!`);
}

if (schemaOrStep && step) {
// @ts-expect-error
this.steps.set(name, { handler: step, schema: schemaOrStep });
} else if (typeof schemaOrStep === "function") {
this.steps.set(name, { handler: schemaOrStep, schema: this.input.schema });
} else {
throw new FlowError(`Wrong parameters provided. Neither 'schema' nor 'node' were provided.`);
}
const schema = (schemaOrStep && stepOrEmpty ? schemaOrStep : this.input.schema) as TInput;
const stepOrFlow = stepOrEmpty || schemaOrStep;

this.steps.set(name, {
schema,
handler: stepOrFlow instanceof Flow ? stepOrFlow.asStep({ next }) : stepOrFlow,
} as FlowStepDef<TInput, TKeys>);

return this as Flow<TInput, TOutput, L | TKeys>;
return this as unknown as Flow<TInput, TOutput, L | TKeys>;
}

setStart(name: TKeys) {
this.startStep = name;
return this;
}

run(state: z.input<TInput>, options: FlowRunOptions = {}) {
run(state: z.input<TInput>, options: FlowRunOptions<TKeys> = {}) {
return RunContext.enter(
this,
{ signal: options?.signal, params: [state, options] as const },
Expand All @@ -155,7 +165,7 @@ export class Flow<
abort: (reason) => runContext.abort(reason),
};

let stepName = this.startStep ?? this.findNextStep();
let stepName = options?.start || this.startStep || this.findNextStep();
while (stepName !== Flow.END) {
const step = this.steps.get(stepName);
if (!step) {
Expand Down Expand Up @@ -204,6 +214,28 @@ export class Flow<
return this as unknown as Flow<TInput, TOutput, Exclude<TKeys, L>>;
}

asStep<
TInput2 extends ZodSchema = TInput,
TOutput2 extends ZodSchema = TOutput,
TKeys2 extends string = TKeys,
>(overrides: {
input?: (input: z.output<TInput2>) => z.output<TInput> | z.input<TInput>;
output?: (output: z.output<TOutput>) => z.output<TOutput2> | z.input<TOutput2>;
start?: TKeys;
next?: FlowNextStep<TKeys2>;
}): FlowStepHandler<TInput2, TKeys | TKeys2> {
return async (input, ctx) => {
const mappedInput = overrides?.input ? overrides.input(input) : input;
const result = await this.run(mappedInput, { start: overrides?.start, signal: ctx.signal });
const mappedOutput = overrides?.output ? overrides.output(result.state) : result.state;

return {
update: mappedOutput,
next: overrides?.next,
};
};
}

protected findNextStep(start: TKeys | null = null): FlowNextStep<TKeys> {
const keys = Array.from(this.steps.keys()) as TKeys[];
const curIndex = start ? keys.indexOf(start) : -1;
Expand Down

0 comments on commit 358ffb2

Please sign in to comment.