From 5c3357a890878daa29823d79aac274ba6d11a829 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 30 Aug 2024 12:18:36 +0200 Subject: [PATCH] test: refactor tests --- tests/e2e/adapters/sdk/chat.test.ts | 2 +- tests/e2e/adapters/watsonx/chat.test.ts | 43 ++++--------------------- tests/e2e/agents/bee.test.ts | 7 ++-- tests/utils/llmFactory.ts | 33 +++++++++++++++++++ 4 files changed, 43 insertions(+), 42 deletions(-) create mode 100644 tests/utils/llmFactory.ts diff --git a/tests/e2e/adapters/sdk/chat.test.ts b/tests/e2e/adapters/sdk/chat.test.ts index 1a5b20b4..f02ff6ed 100644 --- a/tests/e2e/adapters/sdk/chat.test.ts +++ b/tests/e2e/adapters/sdk/chat.test.ts @@ -19,7 +19,7 @@ import { BaseMessage } from "@/llms/primitives/message.js"; import { expect } from "vitest"; import { verifyDeserialization } from "@tests/e2e/utils.js"; -describe("Adapter SDK Chat LLM", () => { +describe.runIf(Boolean(process.env.GENAI_API_KEY))("Adapter SDK Chat LLM", () => { const createChatLLM = () => { return BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); }; diff --git a/tests/e2e/adapters/watsonx/chat.test.ts b/tests/e2e/adapters/watsonx/chat.test.ts index bf31a5a5..d2252884 100644 --- a/tests/e2e/adapters/watsonx/chat.test.ts +++ b/tests/e2e/adapters/watsonx/chat.test.ts @@ -14,52 +14,23 @@ * limitations under the License. 
*/ -import { PromptTemplate } from "@/template.js"; import { BaseMessage } from "@/llms/primitives/message.js"; import { expect } from "vitest"; import { verifyDeserialization } from "@tests/e2e/utils.js"; import { WatsonXChatLLM } from "@/adapters/watsonx/chat.js"; -import { WatsonXLLM } from "@/adapters/watsonx/llm.js"; const apiKey = process.env.WATSONX_API_KEY!; const projectId = process.env.WATSONX_PROJECT_ID!; describe.runIf(Boolean(apiKey && projectId))("WatsonX Chat LLM", () => { const createChatLLM = () => { - const template = new PromptTemplate({ - variables: ["messages"], - template: `{{#messages}}{{#system}}<|begin_of_text|><|start_header_id|>system<|end_header_id|> - -{{system}}<|eot_id|>{{/system}}{{#user}}<|start_header_id|>user<|end_header_id|> - -{{user}}<|eot_id|>{{/user}}{{#assistant}}<|start_header_id|>assistant<|end_header_id|> - -{{assistant}}<|eot_id|>{{/assistant}}{{/messages}}<|start_header_id|>assistant<|end_header_id|> - -`, - }); - - return new WatsonXChatLLM({ - llm: new WatsonXLLM({ - modelId: "meta-llama/llama-3-70b-instruct", - projectId, - apiKey, - parameters: { - decoding_method: "greedy", - min_new_tokens: 5, - max_new_tokens: 50, - }, - }), - config: { - messagesToPrompt(messages: BaseMessage[]) { - return template.render({ - messages: messages.map((message) => ({ - system: message.role === "system" ? [message.text] : [], - user: message.role === "user" ? [message.text] : [], - assistant: message.role === "assistant" ? 
[message.text] : [], - })), - }); - }, + return WatsonXChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct", { + apiKey, + projectId, + parameters: { + decoding_method: "greedy", + min_new_tokens: 5, + max_new_tokens: 50, }, }); }; diff --git a/tests/e2e/agents/bee.test.ts b/tests/e2e/agents/bee.test.ts index ab086bf4..a5825465 100644 --- a/tests/e2e/agents/bee.test.ts +++ b/tests/e2e/agents/bee.test.ts @@ -19,17 +19,14 @@ import { FrameworkError } from "@/errors.js"; import { beforeEach, expect, vi } from "vitest"; import { Logger } from "@/logger/logger.js"; import { BeeAgent } from "@/agents/bee/agent.js"; -import { BAMChatLLM } from "@/adapters/bam/chat.js"; import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js"; import { BaseMessage } from "@/llms/primitives/message.js"; import { createCallbackRegister } from "@tests/e2e/utils.js"; import { omitEmptyValues } from "@/internals/helpers/object.js"; +import * as process from "node:process"; +import { createChatLLM } from "@tests/utils/llmFactory.js"; describe("Bee Agent", () => { - const createChatLLM = () => { - return BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"); - }; - const createAgent = () => { return new BeeAgent({ llm: createChatLLM(), diff --git a/tests/utils/llmFactory.ts b/tests/utils/llmFactory.ts new file mode 100644 index 00000000..4b207b64 --- /dev/null +++ b/tests/utils/llmFactory.ts @@ -0,0 +1,33 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
+import process from "node:process";
+import { BAMChatLLM } from "@/adapters/bam/chat.js";
+import { OpenAIChatLLM } from "@/adapters/openai/chat.js";
+import { WatsonXChatLLM } from "@/adapters/watsonx/chat.js";
+
+export function createChatLLM(): ChatLLM<ChatLLMOutput> {
+  if (process.env.GENAI_API_KEY) {
+    return BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct");
+  } else if (process.env.OPENAI_API_KEY) {
+    return new OpenAIChatLLM({ modelId: "gpt-4o" });
+  } else if (process.env.WATSONX_API_KEY) {
+    return WatsonXChatLLM.fromPreset("meta-llama/llama-3-70b-instruct");
+  } else {
+    throw new Error("No API key for any LLM provider has been provided. Cannot run test case.");
+  }
+}