From 5bff776607b6602499d4476b4ca31d69415bb476 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 13 Dec 2024 19:12:41 +0100 Subject: [PATCH] feat(adapters): add embedding support for Bedrock Ref: #176 Signed-off-by: Tomas Dvorak --- src/adapters/bedrock/chat.ts | 27 ++++++++++++++++--- tests/e2e/adapters/bedrock/chat.test.ts | 35 +++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 tests/e2e/adapters/bedrock/chat.test.ts diff --git a/src/adapters/bedrock/chat.ts b/src/adapters/bedrock/chat.ts index d70082a5..e7338d19 100644 --- a/src/adapters/bedrock/chat.ts +++ b/src/adapters/bedrock/chat.ts @@ -32,6 +32,7 @@ import { Emitter } from "@/emitter/emitter.js"; import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types"; import { BedrockRuntimeClient as Client, + InvokeModelCommand, ConverseCommand, ConverseCommandOutput, ConverseStreamCommand, @@ -42,10 +43,13 @@ import { } from "@aws-sdk/client-bedrock-runtime"; import { GetRunContext } from "@/context.js"; import { Serializer } from "@/serializer/serializer.js"; -import { NotImplementedError } from "@/errors.js"; type Response = ContentBlockDeltaEvent | ConverseCommandOutput; +export interface BedrockEmbeddingOptions extends EmbeddingOptions { + body?: Record; +} + export class ChatBedrockOutput extends ChatLLMOutput { public readonly responses: Response[]; @@ -204,9 +208,24 @@ export class BedrockChatLLM extends ChatLLM { }; } - // eslint-disable-next-line unused-imports/no-unused-vars - async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise { - throw new NotImplementedError(); + async embed( + input: BaseMessage[][], + options: BedrockEmbeddingOptions = {}, + ): Promise { + const command = new InvokeModelCommand({ + modelId: this.modelId, + contentType: "application/json", + accept: "application/json", + body: JSON.stringify({ + texts: input.map((msgs) => msgs.map((msg) => msg.text)), + input_type: "search_document", + ...options?.body, + }), + }); + + const response = await this.client.send(command, { abortSignal: options?.signal }); + const jsonString = new TextDecoder().decode(response.body); + return JSON.parse(jsonString); } async tokenize(input: BaseMessage[]): Promise { diff --git a/tests/e2e/adapters/bedrock/chat.test.ts b/tests/e2e/adapters/bedrock/chat.test.ts new file mode 100644 index 00000000..e7696a6f --- /dev/null +++ b/tests/e2e/adapters/bedrock/chat.test.ts @@ -0,0 +1,35 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { BedrockChatLLM } from "@/adapters/bedrock/chat.js"; +import { BaseMessage } from "@/llms/primitives/message.js"; + +describe.runIf([process.env.AWS_REGION].every((env) => Boolean(env)))("Bedrock Chat LLM", () => { + it("Embeds", async () => { + const llm = new BedrockChatLLM({ + region: process.env.AWS_REGION, + modelId: "amazon.titan-embed-text-v1", + }); + + const response = await llm.embed([ + [BaseMessage.of({ role: "user", text: `Hello world!` })], + [BaseMessage.of({ role: "user", text: `Hello family!` })], + ]); + expect(response.embeddings.length).toBe(2); + expect(response.embeddings[0].length).toBe(512); + expect(response.embeddings[1].length).toBe(512); + }); +});