feat(adapters): add embedding support for Groq
Ref: #176
Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
Tomas2D committed Dec 13, 2024
1 parent 7e3d8ef commit 4673b5e
Showing 2 changed files with 26 additions and 5 deletions.
src/adapters/groq/chat.ts (15 changes: 12 additions & 3 deletions)

```diff
@@ -34,7 +34,6 @@ import { GetRunContext } from "@/context.js";
 import { Serializer } from "@/serializer/serializer.js";
 import { getPropStrict } from "@/internals/helpers/object.js";
 import { ChatCompletionCreateParams } from "groq-sdk/resources/chat/completions";
-import { NotImplementedError } from "@/errors.js";

 type Parameters = Omit<ChatCompletionCreateParams, "stream" | "messages" | "model">;
 type Response = Omit<Client.Chat.ChatCompletionChunk, "object">;
@@ -148,9 +147,19 @@ export class GroqChatLLM extends ChatLLM<ChatGroqOutput> {
     };
   }

-  // eslint-disable-next-line unused-imports/no-unused-vars
   async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise<EmbeddingOutput> {
-    throw new NotImplementedError();
+    const { data } = await this.client.embeddings.create(
+      {
+        model: this.modelId,
+        input: input.flatMap((msgs) => msgs.map((msg) => msg.text)) as string[],
+        encoding_format: "float",
+      },
+      {
+        signal: options?.signal,
+        stream: false,
+      },
+    );
+    return { embeddings: data.map(({ embedding }) => embedding as number[]) };
   }

   async tokenize(input: BaseMessage[]): Promise<BaseLLMTokenizeOutput> {
```
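For reference, a minimal sketch of how the new `embed` method could be called from application code — not part of this commit. The `BaseMessage` import path and the `nomic-embed-text-v1_5` model id are assumptions mirroring the test below; substitute whatever embedding model Groq exposes.

```ts
import { GroqChatLLM } from "@/adapters/groq/chat.js";
// Assumed path for BaseMessage; use however the repo actually exports it.
import { BaseMessage } from "@/llms/primitives/message.js";

const llm = new GroqChatLLM({ modelId: "nomic-embed-text-v1_5" }); // illustrative model id

// Each inner array is one conversation; embed() flattens every message
// to its plain text and sends a single batched request to Groq.
const { embeddings } = await llm.embed([
  [BaseMessage.of({ role: "user", text: "Hello world!" })],
  [BaseMessage.of({ role: "user", text: "Hello family!" })],
]);
console.info(embeddings.length); // 2 — one vector per message across the batch
```

Note that because the input is flattened, the returned vectors correspond to individual messages, not to conversations; with one message per conversation (as in the test below) the two views coincide.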
tests/e2e/adapters/groq/chat.test.ts (16 changes: 14 additions & 2 deletions)

```diff
@@ -20,9 +20,9 @@ import { GroqChatLLM } from "@/adapters/groq/chat.js";
 const apiKey = process.env.GROQ_API_KEY;

 describe.runIf(Boolean(apiKey))("Adapter Groq Chat LLM", () => {
-  const createChatLLM = () => {
+  const createChatLLM = (modelId = "llama3-8b-8192") => {
     const model = new GroqChatLLM({
-      modelId: "llama3-8b-8192",
+      modelId,
       parameters: {
         temperature: 0,
         max_tokens: 1024,
@@ -69,4 +69,16 @@ describe.runIf(Boolean(apiKey))("Adapter Groq Chat LLM", () => {
       );
     }
   });
+
+  // Embedding model is not available right now
+  it.skip("Embeds", async () => {
+    const llm = createChatLLM("nomic-embed-text-v1_5");
+    const response = await llm.embed([
+      [BaseMessage.of({ role: "user", text: `Hello world!` })],
+      [BaseMessage.of({ role: "user", text: `Hello family!` })],
+    ]);
+    expect(response.embeddings.length).toBe(2);
+    expect(response.embeddings[0].length).toBe(1024);
+    expect(response.embeddings[1].length).toBe(1024);
+  });
 });
```
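Because `embed` forwards `options?.signal` to the Groq client, an in-flight embedding request can be cancelled. A sketch under the same assumptions as above:

```ts
const controller = new AbortController();

const pending = llm.embed(
  [[BaseMessage.of({ role: "user", text: "Cancel me" })]],
  { signal: controller.signal }, // threaded through to this.client.embeddings.create
);

controller.abort(); // the Groq SDK rejects the underlying request
await pending.catch((err) => console.warn("embedding aborted:", err));
```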
