diff --git a/src/tools/custom.test.ts b/src/tools/custom.test.ts index 5757c819..07f4dd14 100644 --- a/src/tools/custom.test.ts +++ b/src/tools/custom.test.ts @@ -52,7 +52,7 @@ describe("CustomTool", () => { }, }); - const customTool = await CustomTool.fromSourceCode("http://localhost", "source code"); + const customTool = await CustomTool.fromSourceCode({ url: "http://localhost" }, "source code"); expect(customTool.name).toBe("test"); expect(customTool.description).toBe("A test tool"); @@ -76,9 +76,9 @@ describe("CustomTool", () => { }, }); - await expect(CustomTool.fromSourceCode("http://localhost", "source code")).rejects.toThrow( - "Error parsing tool", - ); + await expect( + CustomTool.fromSourceCode({ url: "http://localhost" }, "source code"), + ).rejects.toThrow("Error parsing tool"); }); it("should run the custom tool", async () => { @@ -101,7 +101,7 @@ describe("CustomTool", () => { }); const customTool = await CustomTool.fromSourceCode( - "http://localhost", + { url: "http://localhost" }, "source code", "executor-id", ); @@ -148,7 +148,7 @@ describe("CustomTool", () => { }); const customTool = await CustomTool.fromSourceCode( - "http://localhost", + { url: "http://localhost" }, "source code", "executor-id", ); diff --git a/src/tools/custom.ts b/src/tools/custom.ts index 9a560936..e4dabccb 100644 --- a/src/tools/custom.ts +++ b/src/tools/custom.ts @@ -21,13 +21,17 @@ import { FrameworkError } from "@/errors.js"; import { z } from "zod"; import { validate } from "@/internals/helpers/general.js"; import { CodeInterpreterService } from "bee-proto/code_interpreter/v1/code_interpreter_service_connect"; +import { CodeInterpreterOptions } from "./python/python.js"; export class CustomToolCreateError extends FrameworkError {} export class CustomToolExecuteError extends FrameworkError {} const toolOptionsSchema = z .object({ - codeInterpreterUrl: z.string().url(), + codeInterpreter: z.object({ + url: z.string().url(), + connectionOptions: z.any().optional(), + }), sourceCode: z.string().min(1), name: z.string().min(1), description: z.string().min(1), @@ -38,10 +42,14 @@ const toolOptionsSchema = z export type CustomToolOptions = z.output & BaseToolOptions; -function createCodeInterpreterClient(url: string) { +function createCodeInterpreterClient(codeInterpreter: CodeInterpreterOptions) { return createPromiseClient( CodeInterpreterService, - createGrpcTransport({ baseUrl: url, httpVersion: "2" }), + createGrpcTransport({ + baseUrl: codeInterpreter.url, + httpVersion: "2", + nodeOptions: codeInterpreter.connectionOptions, + }), ); } @@ -65,7 +73,7 @@ export class CustomTool extends Tool { ) { validate(options, toolOptionsSchema); super(options); - this.client = client || createCodeInterpreterClient(options.codeInterpreterUrl); + this.client = client || createCodeInterpreterClient(options.codeInterpreter); this.name = options.name; this.description = options.description; } @@ -89,11 +97,15 @@ export class CustomTool extends Tool { loadSnapshot(snapshot: ReturnType): void { super.loadSnapshot(snapshot); - this.client = createCodeInterpreterClient(this.options.codeInterpreterUrl); + this.client = createCodeInterpreterClient(this.options.codeInterpreter); } - static async fromSourceCode(codeInterpreterUrl: string, sourceCode: string, executorId?: string) { - const client = createCodeInterpreterClient(codeInterpreterUrl); + static async fromSourceCode( + codeInterpreter: CodeInterpreterOptions, + sourceCode: string, + executorId?: string, + ) { + const client = createCodeInterpreterClient(codeInterpreter); const response = await client.parseCustomTool({ toolSourceCode: sourceCode }); if (response.response.case === "error") { @@ -104,7 +116,7 @@ export class CustomTool extends Tool { return new CustomTool( { - codeInterpreterUrl, + codeInterpreter, sourceCode, name: toolName, description: toolDescription, diff --git a/src/tools/python/python.ts b/src/tools/python/python.ts index 8995f40f..cca07409 100644 --- a/src/tools/python/python.ts +++ b/src/tools/python/python.ts @@ -35,11 +35,13 @@ import { ValidationError } from "ajv"; import { ConnectionOptions } from "node:tls"; import { AnySchemaLike } from "@/internals/helpers/schema.js"; +export interface CodeInterpreterOptions { + url: string; + connectionOptions?: ConnectionOptions; +} + export interface PythonToolOptions extends BaseToolOptions { - codeInterpreter: { - url: string; - connectionOptions?: ConnectionOptions; - }; + codeInterpreter: CodeInterpreterOptions; executorId?: string; preprocess?: { llm: LLM; promptTemplate: PromptTemplate<"input"> }; storage: PythonStorage;