diff --git a/examples/llms/structured.ts b/examples/llms/structured.ts index 4708e16a..69efccc3 100644 --- a/examples/llms/structured.ts +++ b/examples/llms/structured.ts @@ -2,9 +2,11 @@ import "dotenv/config.js"; import { z } from "zod"; import { BaseMessage, Role } from "bee-agent-framework/llms/primitives/message"; import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat"; +import { JsonDriver } from "bee-agent-framework/drivers/json"; const llm = new OllamaChatLLM(); -const response = await llm.generateStructured( +const driver = new JsonDriver(llm); +const response = await driver.generate( z.union([ z.object({ firstName: z.string().min(1), diff --git a/package.json b/package.json index bc46ea7d..15413775 100644 --- a/package.json +++ b/package.json @@ -89,6 +89,8 @@ "fast-xml-parser": "^4.4.1", "header-generator": "^2.1.54", "joplin-turndown-plugin-gfm": "^1.0.12", + "js-yaml": "^4.1.0", + "json-schema-to-typescript": "^15.0.2", "mathjs": "^13.1.1", "mustache": "^4.2.0", "object-hash": "^3.0.0", @@ -129,6 +131,7 @@ "@types/eslint": "^9.6.1", "@types/eslint-config-prettier": "^6.11.3", "@types/eslint__js": "^8.42.3", + "@types/js-yaml": "^4.0.9", "@types/mustache": "^4", "@types/needle": "^3.3.0", "@types/node": "^20.16.1", diff --git a/src/drivers/base.ts b/src/drivers/base.ts new file mode 100644 index 00000000..eb30d7ca --- /dev/null +++ b/src/drivers/base.ts @@ -0,0 +1,108 @@ +import { + AnySchemaLike, + FromSchemaLike, + createSchemaValidator, + toJsonSchema, +} from "@/internals/helpers/schema.js"; +import { GenerateOptions, LLMError } from "@/llms/base.js"; +import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js"; +import { BaseMessage, Role } from "@/llms/primitives/message.js"; +import { Retryable } from "@/internals/helpers/retryable.js"; +import { PromptTemplate } from "@/template.js"; +import { SchemaObject } from "ajv"; +import { z } from "zod"; + +export interface GenerateSchemaInput { + maxRetries?: number; + options?: T; +} + +export abstract class BaseDriver { + protected abstract template: PromptTemplate.infer<{ schema: string }>; + protected errorTemplate = new PromptTemplate({ + schema: z.object({ + errors: z.string(), + expected: z.string(), + received: z.string(), + }), + template: `Generated response does not match the expected schema! +Validation Errors: "{{errors}}"`, + }); + + constructor(protected llm: ChatLLM) {} + + protected abstract parseResponse(textResponse: string): unknown; + protected abstract schemaToString(schema: SchemaObject): Promise | string; + protected guided(schema: SchemaObject): GenerateOptions["guided"] | undefined { + return undefined; + } + + async generate( + schema: T, + input: BaseMessage[], + { maxRetries = 3, options }: GenerateSchemaInput = {}, + ): Promise> { + const jsonSchema = toJsonSchema(schema); + const validator = createSchemaValidator(jsonSchema); + const schemaString = await this.schemaToString(jsonSchema); + + const messages: BaseMessage[] = [ + BaseMessage.of({ + role: Role.SYSTEM, + text: this.template.render({ schema: schemaString }), + }), + ...input, + ]; + + return new Retryable({ + executor: async () => { + const rawResponse = await this.llm.generate(messages, { + guided: this.guided(jsonSchema), + ...options, + } as TGenerateOptions); + const textResponse = rawResponse.getTextContent(); + let parsedResponse: any; + + try { + parsedResponse = this.parseResponse(textResponse); + } catch (error) { + throw new LLMError(`Failed to parse the generated response.`, [], { + isFatal: false, + isRetryable: true, + context: { error: (error as Error).message, received: textResponse }, + }); + } + + const success = validator(parsedResponse); + if (!success) { + const context = { + expected: schemaString, + received: textResponse, + errors: JSON.stringify(validator.errors ?? []), + }; + + messages.push( + BaseMessage.of({ + role: Role.USER, + text: this.errorTemplate.render(context), + }), + ); + throw new LLMError( + "Failed to generate a structured response adhering to the provided schema.", + [], + { + isFatal: false, + isRetryable: true, + context, + }, + ); + } + return parsedResponse as FromSchemaLike; + }, + config: { + signal: options?.signal, + maxRetries, + }, + }).get(); + } +} diff --git a/src/drivers/json.ts b/src/drivers/json.ts new file mode 100644 index 00000000..e74a88e7 --- /dev/null +++ b/src/drivers/json.ts @@ -0,0 +1,36 @@ +import { parseBrokenJson } from "@/internals/helpers/schema.js"; +import { GenerateOptions } from "@/llms/base.js"; +import { PromptTemplate } from "@/template.js"; +import { BaseDriver } from "./base.js"; +import { SchemaObject } from "ajv"; +import { z } from "zod"; + +export class JsonDriver< + TGenerateOptions extends GenerateOptions = GenerateOptions, +> extends BaseDriver { + protected template = new PromptTemplate({ + schema: z.object({ + schema: z.string(), + }), + template: `You are a helpful assistant that generates only valid JSON adhering to the following JSON Schema. + +\`\`\` +{{schema}} +\`\`\` + +IMPORTANT: Every message must be a parsable JSON string without additional output. +`, + }); + + protected parseResponse(textResponse: string): unknown { + return parseBrokenJson(textResponse); + } + + protected schemaToString(schema: SchemaObject): string { + return JSON.stringify(schema, null, 2); + } + + protected guided(schema: SchemaObject) { + return { json: schema } as const; + } +} diff --git a/src/drivers/typescript.ts b/src/drivers/typescript.ts new file mode 100644 index 00000000..441037ba --- /dev/null +++ b/src/drivers/typescript.ts @@ -0,0 +1,37 @@ +import { parseBrokenJson } from "@/internals/helpers/schema.js"; +import { GenerateOptions } from "@/llms/base.js"; +import { PromptTemplate } from "@/template.js"; +import { BaseDriver } from "./base.js"; +import * as jsonSchemaToTypescript from "json-schema-to-typescript"; +import { SchemaObject } from "ajv"; +import { z } from "zod"; + +export class TypescriptDriver< + TGenerateOptions extends GenerateOptions = GenerateOptions, +> extends BaseDriver { + protected template = new PromptTemplate({ + schema: z.object({ + schema: z.string(), + }), + template: `You are a helpful assistant that generates only valid JSON adhering to the following TypeScript type. + +\`\`\` +{{schema}} +\`\`\` + +IMPORTANT: Every message must be a parsable JSON string without additional output. +`, + }); + + protected parseResponse(textResponse: string): unknown { + return parseBrokenJson(textResponse); + } + + protected async schemaToString(schema: SchemaObject): Promise { + return await jsonSchemaToTypescript.compile(schema, "Output"); + } + + protected guided(schema: SchemaObject) { + return { json: schema } as const; + } +} diff --git a/src/drivers/yaml.ts b/src/drivers/yaml.ts new file mode 100644 index 00000000..8c8c2b70 --- /dev/null +++ b/src/drivers/yaml.ts @@ -0,0 +1,32 @@ +import { GenerateOptions } from "@/llms/base.js"; +import { PromptTemplate } from "@/template.js"; +import { BaseDriver } from "./base.js"; +import yaml from "js-yaml"; +import { SchemaObject } from "ajv"; +import { z } from "zod"; + +export class YamlDriver< + TGenerateOptions extends GenerateOptions = GenerateOptions, +> extends BaseDriver { + protected template = new PromptTemplate({ + schema: z.object({ + schema: z.string(), + }), + template: `You are a helpful assistant that generates only valid YAML adhering to the following schema. + +\`\`\` +{{schema}} +\`\`\` + +IMPORTANT: Every message must be a parsable YAML string without additional output. +`, + }); + + protected parseResponse(textResponse: string): unknown { + return yaml.load(textResponse); + } + + protected schemaToString(schema: SchemaObject): string { + return yaml.dump(schema); + } +} diff --git a/src/llms/chat.ts b/src/llms/chat.ts index c8cdfc44..55332198 100644 --- a/src/llms/chat.ts +++ b/src/llms/chat.ts @@ -14,97 +14,14 @@ * limitations under the License. */ -import { BaseLLM, BaseLLMOutput, GenerateOptions, LLMError } from "@/llms/base.js"; -import { BaseMessage, Role } from "@/llms/primitives/message.js"; -import { - AnySchemaLike, - createSchemaValidator, - FromSchemaLike, - parseBrokenJson, - toJsonSchema, -} from "@/internals/helpers/schema.js"; -import { Retryable } from "@/internals/helpers/retryable.js"; -import { GeneratedStructuredErrorTemplate, GeneratedStructuredTemplate } from "@/llms/prompts.js"; +import { BaseLLM, BaseLLMOutput, GenerateOptions } from "@/llms/base.js"; +import { BaseMessage } from "@/llms/primitives/message.js"; export abstract class ChatLLMOutput extends BaseLLMOutput { abstract get messages(): readonly BaseMessage[]; } -export interface GenerateSchemaInput { - template?: typeof GeneratedStructuredTemplate; - errorTemplate?: typeof GeneratedStructuredErrorTemplate; - maxRetries?: number; - options?: T; -} - export abstract class ChatLLM< TOutput extends ChatLLMOutput, TGenerateOptions extends GenerateOptions = GenerateOptions, -> extends BaseLLM { - async generateStructured( - schema: T, - input: BaseMessage[], - { - template = GeneratedStructuredTemplate, - errorTemplate = GeneratedStructuredErrorTemplate, - maxRetries = 3, - options, - }: GenerateSchemaInput = {}, - ): Promise> { - const jsonSchema = toJsonSchema(schema); - const validator = createSchemaValidator(jsonSchema); - - const finalOptions = { ...options } as TGenerateOptions; - if (!options?.guided) { - finalOptions.guided = { json: jsonSchema }; - } - - const messages: BaseMessage[] = [ - BaseMessage.of({ - role: Role.SYSTEM, - text: template.render({ - schema: JSON.stringify(jsonSchema, null, 2), - }), - }), - ...input, - ]; - - return new Retryable({ - executor: async () => { - const rawResponse = await this.generate(messages, finalOptions); - const textResponse = rawResponse.getTextContent(); - const jsonResponse: any = parseBrokenJson(textResponse); - - const success = validator(jsonResponse); - if (!success) { - const context = { - expected: JSON.stringify(jsonSchema), - received: jsonResponse ? JSON.stringify(jsonResponse) : textResponse, - errors: JSON.stringify(validator.errors ?? []), - }; - - messages.push( - BaseMessage.of({ - role: Role.USER, - text: errorTemplate.render(context), - }), - ); - throw new LLMError( - "Failed to generate a structured response adhering to the provided schema.", - [], - { - isFatal: false, - isRetryable: true, - context, - }, - ); - } - return jsonResponse as FromSchemaLike; - }, - config: { - signal: options?.signal, - maxRetries, - }, - }).get(); - } -} +> extends BaseLLM {} diff --git a/src/llms/prompts.ts b/src/llms/prompts.ts deleted file mode 100644 index 41fc5f77..00000000 --- a/src/llms/prompts.ts +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2024 IBM Corp. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { PromptTemplate } from "@/template.js"; -import { z } from "zod"; - -export const GeneratedStructuredTemplate = new PromptTemplate({ - schema: z.object({ - schema: z.string(), - }), - template: `You are a helpful assistant that generates only valid JSON adhering to the following JSON Schema. - -\`\`\` -{{schema}} -\`\`\` - -IMPORTANT: Every message must be a parsable JSON string without additional output. -`, -}); - -export const GeneratedStructuredErrorTemplate = new PromptTemplate({ - schema: z.object({ - errors: z.string(), - expected: z.string(), - received: z.string(), - }), - template: `Generated response does not match the expected schema! -Validation Errors: "{{errors}}"`, -}); diff --git a/yarn.lock b/yarn.lock index f99c2585..e74e15d5 100644 --- a/yarn.lock +++ b/yarn.lock @@ -24,6 +24,17 @@ __metadata: languageName: node linkType: hard +"@apidevtools/json-schema-ref-parser@npm:^11.5.5": + version: 11.7.0 + resolution: "@apidevtools/json-schema-ref-parser@npm:11.7.0" + dependencies: + "@jsdevtools/ono": "npm:^7.1.3" + "@types/json-schema": "npm:^7.0.15" + js-yaml: "npm:^4.1.0" + checksum: 10c0/0ae9ced2953918a14b17874dc0515d3367a95197d58d869f44e756223e23eb9996283e21cb7ef1c7e016ae94467920b4aca705e0ea3d61a61455431572ad4ab5 + languageName: node + linkType: hard + "@babel/code-frame@npm:^7.0.0, @babel/code-frame@npm:^7.21.4, @babel/code-frame@npm:^7.22.13": version: 7.24.7 resolution: "@babel/code-frame@npm:7.24.7" @@ -798,6 +809,13 @@ __metadata: languageName: node linkType: hard +"@jsdevtools/ono@npm:^7.1.3": + version: 7.1.3 + resolution: "@jsdevtools/ono@npm:7.1.3" + checksum: 10c0/a9f7e3e8e3bc315a34959934a5e2f874c423cf4eae64377d3fc9de0400ed9f36cb5fd5ebce3300d2e8f4085f557c4a8b591427a583729a87841fda46e6c216b9 + languageName: node + linkType: hard + "@langchain/community@npm:~0.2.28": version: 0.2.28 resolution: "@langchain/community@npm:0.2.28" @@ -1878,7 +1896,14 @@ __metadata: languageName: node linkType: hard -"@types/json-schema@npm:*": +"@types/js-yaml@npm:^4.0.9": + version: 4.0.9 + resolution: "@types/js-yaml@npm:4.0.9" + checksum: 10c0/24de857aa8d61526bbfbbaa383aa538283ad17363fcd5bb5148e2c7f604547db36646440e739d78241ed008702a8920665d1add5618687b6743858fae00da211 + languageName: node + linkType: hard + +"@types/json-schema@npm:*, @types/json-schema@npm:^7.0.15": version: 7.0.15 resolution: "@types/json-schema@npm:7.0.15" checksum: 10c0/a996a745e6c5d60292f36731dd41341339d4eeed8180bb09226e5c8d23759067692b1d88e5d91d72ee83dfc00d3aca8e7bd43ea120516c17922cbcb7c3e252db @@ -1894,7 +1919,7 @@ __metadata: languageName: node linkType: hard -"@types/lodash@npm:*": +"@types/lodash@npm:*, @types/lodash@npm:^4.17.7": version: 4.17.7 resolution: "@types/lodash@npm:4.17.7" checksum: 10c0/40c965b5ffdcf7ff5c9105307ee08b782da228c01b5c0529122c554c64f6b7168fc8f11dc79aa7bae4e67e17efafaba685dc3a47e294dbf52a65ed2b67100561 @@ -2527,6 +2552,7 @@ __metadata: "@types/eslint": "npm:^9.6.1" "@types/eslint-config-prettier": "npm:^6.11.3" "@types/eslint__js": "npm:^8.42.3" + "@types/js-yaml": "npm:^4.0.9" "@types/mustache": "npm:^4" "@types/needle": "npm:^3.3.0" "@types/node": "npm:^20.16.1" @@ -2546,6 +2572,8 @@ __metadata: header-generator: "npm:^2.1.54" husky: "npm:^9.1.5" joplin-turndown-plugin-gfm: "npm:^1.0.12" + js-yaml: "npm:^4.1.0" + json-schema-to-typescript: "npm:^15.0.2" langchain: "npm:~0.2.16" lint-staged: "npm:^15.2.9" mathjs: "npm:^13.1.1" @@ -4610,7 +4638,7 @@ __metadata: languageName: node linkType: hard -"glob@npm:^10.2.2, glob@npm:^10.3.10, glob@npm:^10.4.1": +"glob@npm:^10.2.2, glob@npm:^10.3.10, glob@npm:^10.3.12, glob@npm:^10.4.1": version: 10.4.5 resolution: "glob@npm:10.4.5" dependencies: @@ -5446,6 +5474,25 @@ __metadata: languageName: node linkType: hard +"json-schema-to-typescript@npm:^15.0.2": + version: 15.0.2 + resolution: "json-schema-to-typescript@npm:15.0.2" + dependencies: + "@apidevtools/json-schema-ref-parser": "npm:^11.5.5" + "@types/json-schema": "npm:^7.0.15" + "@types/lodash": "npm:^4.17.7" + glob: "npm:^10.3.12" + is-glob: "npm:^4.0.3" + js-yaml: "npm:^4.1.0" + lodash: "npm:^4.17.21" + minimist: "npm:^1.2.8" + prettier: "npm:^3.2.5" + bin: + json2ts: dist/src/cli.js + checksum: 10c0/cad5f9f525bef3171253d5029be094c1db3fd948aa35a2f7303e9223c86fa16e5123697ae9599201c19da0d728110e77972ddefb6aa69b9b887971120ac73168 + languageName: node + linkType: hard + "json-schema-traverse@npm:^0.4.1": version: 0.4.1 resolution: "json-schema-traverse@npm:0.4.1" @@ -5960,7 +6007,7 @@ __metadata: languageName: node linkType: hard -"lodash@npm:4.17.21": +"lodash@npm:4.17.21, lodash@npm:^4.17.21": version: 4.17.21 resolution: "lodash@npm:4.17.21" checksum: 10c0/d8cbea072bb08655bb4c989da418994b073a608dffa608b09ac04b43a791b12aeae7cd7ad919aa4c925f33b48490b5cfe6c1f71d827956071dae2e7bb3a6b74c @@ -7728,7 +7775,7 @@ __metadata: languageName: node linkType: hard -"prettier@npm:^3.3.3": +"prettier@npm:^3.2.5, prettier@npm:^3.3.3": version: 3.3.3 resolution: "prettier@npm:3.3.3" bin: