Skip to content

Commit

Permalink
feat: implement structured generators
Browse files Browse the repository at this point in the history
  • Loading branch information
JanPokorny committed Sep 11, 2024
1 parent d819562 commit ecbb1e2
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 134 deletions.
4 changes: 3 additions & 1 deletion examples/llms/structured.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ import "dotenv/config.js";
import { z } from "zod";
import { BaseMessage, Role } from "bee-agent-framework/llms/primitives/message";
import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";
import { JsonDriver } from "bee-agent-framework/drivers/json";

const llm = new OllamaChatLLM();
const response = await llm.generateStructured(
const driver = new JsonDriver(llm);
const response = await driver.generate(
z.union([
z.object({
firstName: z.string().min(1),
Expand Down
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@
"fast-xml-parser": "^4.4.1",
"header-generator": "^2.1.54",
"joplin-turndown-plugin-gfm": "^1.0.12",
"js-yaml": "^4.1.0",
"json-schema-to-typescript": "^15.0.2",
"mathjs": "^13.1.1",
"mustache": "^4.2.0",
"object-hash": "^3.0.0",
Expand Down Expand Up @@ -129,6 +131,7 @@
"@types/eslint": "^9.6.1",
"@types/eslint-config-prettier": "^6.11.3",
"@types/eslint__js": "^8.42.3",
"@types/js-yaml": "^4.0.9",
"@types/mustache": "^4",
"@types/needle": "^3.3.0",
"@types/node": "^20.16.1",
Expand Down
108 changes: 108 additions & 0 deletions src/drivers/base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import {
AnySchemaLike,
FromSchemaLike,
createSchemaValidator,
toJsonSchema,
} from "@/internals/helpers/schema.js";
import { GenerateOptions, LLMError } from "@/llms/base.js";
import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { Retryable } from "@/internals/helpers/retryable.js";
import { PromptTemplate } from "@/template.js";
import { SchemaObject } from "ajv";
import { z } from "zod";

export interface GenerateSchemaInput<T> {
maxRetries?: number;
options?: T;
}

export abstract class BaseDriver<TGenerateOptions extends GenerateOptions = GenerateOptions> {
protected abstract template: PromptTemplate.infer<{ schema: string }>;
protected errorTemplate = new PromptTemplate({
schema: z.object({
errors: z.string(),
expected: z.string(),
received: z.string(),
}),
template: `Generated response does not match the expected schema!
Validation Errors: "{{errors}}"`,
});

constructor(protected llm: ChatLLM<ChatLLMOutput, TGenerateOptions>) {}

protected abstract parseResponse(textResponse: string): unknown;
protected abstract schemaToString(schema: SchemaObject): Promise<string> | string;
protected guided(schema: SchemaObject): GenerateOptions["guided"] | undefined {

Check warning on line 36 in src/drivers/base.ts

View workflow job for this annotation

GitHub Actions / Lint & Build & Test

'schema' is defined but never used. Allowed unused args must match /^_/u
return undefined;
}

async generate<T extends AnySchemaLike>(
schema: T,
input: BaseMessage[],
{ maxRetries = 3, options }: GenerateSchemaInput<TGenerateOptions> = {},
): Promise<FromSchemaLike<T>> {
const jsonSchema = toJsonSchema(schema);
const validator = createSchemaValidator(jsonSchema);
const schemaString = await this.schemaToString(jsonSchema);

const messages: BaseMessage[] = [
BaseMessage.of({
role: Role.SYSTEM,
text: this.template.render({ schema: schemaString }),
}),
...input,
];

return new Retryable({
executor: async () => {
const rawResponse = await this.llm.generate(messages, {
guided: this.guided(jsonSchema),
...options,
} as TGenerateOptions);
const textResponse = rawResponse.getTextContent();
let parsedResponse: any;

try {
parsedResponse = this.parseResponse(textResponse);
} catch (error) {
throw new LLMError(`Failed to parse the generated response.`, [], {
isFatal: false,
isRetryable: true,
context: { error: (error as Error).message, received: textResponse },
});
}

const success = validator(parsedResponse);
if (!success) {
const context = {
expected: schemaString,
received: textResponse,
errors: JSON.stringify(validator.errors ?? []),
};

messages.push(
BaseMessage.of({
role: Role.USER,
text: this.errorTemplate.render(context),
}),
);
throw new LLMError(
"Failed to generate a structured response adhering to the provided schema.",
[],
{
isFatal: false,
isRetryable: true,
context,
},
);
}
return parsedResponse as FromSchemaLike<T>;
},
config: {
signal: options?.signal,
maxRetries,
},
}).get();
}
}
36 changes: 36 additions & 0 deletions src/drivers/json.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { parseBrokenJson } from "@/internals/helpers/schema.js";
import { GenerateOptions } from "@/llms/base.js";
import { PromptTemplate } from "@/template.js";
import { BaseDriver } from "./base.js";
import { SchemaObject } from "ajv";
import { z } from "zod";

export class JsonDriver<
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends BaseDriver<TGenerateOptions> {
protected template = new PromptTemplate({
schema: z.object({
schema: z.string(),
}),
template: `You are a helpful assistant that generates only valid JSON adhering to the following JSON Schema.
\`\`\`
{{schema}}
\`\`\`
IMPORTANT: Every message must be a parsable JSON string without additional output.
`,
});

protected parseResponse(textResponse: string): unknown {
return parseBrokenJson(textResponse);
}

protected schemaToString(schema: SchemaObject): string {
return JSON.stringify(schema, null, 2);
}

protected guided(schema: SchemaObject) {
return { json: schema } as const;
}
}
37 changes: 37 additions & 0 deletions src/drivers/typescript.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { parseBrokenJson } from "@/internals/helpers/schema.js";
import { GenerateOptions } from "@/llms/base.js";
import { PromptTemplate } from "@/template.js";
import { BaseDriver } from "./base.js";
import * as jsonSchemaToTypescript from "json-schema-to-typescript";
import { SchemaObject } from "ajv";
import { z } from "zod";

export class TypescriptDriver<
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends BaseDriver<TGenerateOptions> {
protected template = new PromptTemplate({
schema: z.object({
schema: z.string(),
}),
template: `You are a helpful assistant that generates only valid JSON adhering to the following TypeScript type.
\`\`\`
{{schema}}
\`\`\`
IMPORTANT: Every message must be a parsable JSON string without additional output.
`,
});

protected parseResponse(textResponse: string): unknown {
return parseBrokenJson(textResponse);
}

protected async schemaToString(schema: SchemaObject): Promise<string> {
return await jsonSchemaToTypescript.compile(schema, "Output");
}

protected guided(schema: SchemaObject) {
return { json: schema } as const;
}
}
32 changes: 32 additions & 0 deletions src/drivers/yaml.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { GenerateOptions } from "@/llms/base.js";
import { PromptTemplate } from "@/template.js";
import { BaseDriver } from "./base.js";
import yaml from "js-yaml";
import { SchemaObject } from "ajv";
import { z } from "zod";

export class YamlDriver<
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends BaseDriver<TGenerateOptions> {
protected template = new PromptTemplate({
schema: z.object({
schema: z.string(),
}),
template: `You are a helpful assistant that generates only valid YAML adhering to the following schema.
\`\`\`
{{schema}}
\`\`\`
IMPORTANT: Every message must be a parsable YAML string without additional output.
`,
});

protected parseResponse(textResponse: string): unknown {
return yaml.load(textResponse);
}

protected schemaToString(schema: SchemaObject): string {
return yaml.dump(schema);
}
}
89 changes: 3 additions & 86 deletions src/llms/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,97 +14,14 @@
* limitations under the License.
*/

import { BaseLLM, BaseLLMOutput, GenerateOptions, LLMError } from "@/llms/base.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import {
AnySchemaLike,
createSchemaValidator,
FromSchemaLike,
parseBrokenJson,
toJsonSchema,
} from "@/internals/helpers/schema.js";
import { Retryable } from "@/internals/helpers/retryable.js";
import { GeneratedStructuredErrorTemplate, GeneratedStructuredTemplate } from "@/llms/prompts.js";
import { BaseLLM, BaseLLMOutput, GenerateOptions } from "@/llms/base.js";
import { BaseMessage } from "@/llms/primitives/message.js";

export abstract class ChatLLMOutput extends BaseLLMOutput {
abstract get messages(): readonly BaseMessage[];
}

export interface GenerateSchemaInput<T> {
template?: typeof GeneratedStructuredTemplate;
errorTemplate?: typeof GeneratedStructuredErrorTemplate;
maxRetries?: number;
options?: T;
}

export abstract class ChatLLM<
TOutput extends ChatLLMOutput,
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends BaseLLM<BaseMessage[], TOutput, TGenerateOptions> {
async generateStructured<T extends AnySchemaLike>(
schema: T,
input: BaseMessage[],
{
template = GeneratedStructuredTemplate,
errorTemplate = GeneratedStructuredErrorTemplate,
maxRetries = 3,
options,
}: GenerateSchemaInput<TGenerateOptions> = {},
): Promise<FromSchemaLike<T>> {
const jsonSchema = toJsonSchema(schema);
const validator = createSchemaValidator(jsonSchema);

const finalOptions = { ...options } as TGenerateOptions;
if (!options?.guided) {
finalOptions.guided = { json: jsonSchema };
}

const messages: BaseMessage[] = [
BaseMessage.of({
role: Role.SYSTEM,
text: template.render({
schema: JSON.stringify(jsonSchema, null, 2),
}),
}),
...input,
];

return new Retryable({
executor: async () => {
const rawResponse = await this.generate(messages, finalOptions);
const textResponse = rawResponse.getTextContent();
const jsonResponse: any = parseBrokenJson(textResponse);

const success = validator(jsonResponse);
if (!success) {
const context = {
expected: JSON.stringify(jsonSchema),
received: jsonResponse ? JSON.stringify(jsonResponse) : textResponse,
errors: JSON.stringify(validator.errors ?? []),
};

messages.push(
BaseMessage.of({
role: Role.USER,
text: errorTemplate.render(context),
}),
);
throw new LLMError(
"Failed to generate a structured response adhering to the provided schema.",
[],
{
isFatal: false,
isRetryable: true,
context,
},
);
}
return jsonResponse as FromSchemaLike<T>;
},
config: {
signal: options?.signal,
maxRetries,
},
}).get();
}
}
> extends BaseLLM<BaseMessage[], TOutput, TGenerateOptions> {}
42 changes: 0 additions & 42 deletions src/llms/prompts.ts

This file was deleted.

Loading

0 comments on commit ecbb1e2

Please sign in to comment.