From 9986746aa890f2490ff18fd4fc79be4de0e4dbe2 Mon Sep 17 00:00:00 2001
From: MohamedBassem
Date: Wed, 27 Mar 2024 16:30:27 +0000
Subject: [PATCH] fix: Attempt to increase the reliability of the ollama inference

---
 apps/workers/inference.ts    | 45 +++++++++++++++++++++++++-----------
 apps/workers/openaiWorker.ts | 11 ++++++---
 docs/docs/02-installation.md |  2 +-
 packages/shared/queues.ts    |  8 +++++++
 4 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
index c622dd54..3b0b5943 100644
--- a/apps/workers/inference.ts
+++ b/apps/workers/inference.ts
@@ -2,6 +2,7 @@ import { Ollama } from "ollama";
 import OpenAI from "openai";
 
 import serverConfig from "@hoarder/shared/config";
+import logger from "@hoarder/shared/logger";
 
 export interface InferenceResponse {
   response: string;
@@ -96,16 +97,41 @@ class OllamaInferenceClient implements InferenceClient {
     });
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
+  async runModel(model: string, prompt: string, image?: string) {
     const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.textModel,
+      model: model,
       format: "json",
-      messages: [{ role: "system", content: prompt }],
+      stream: true,
+      messages: [
+        { role: "user", content: prompt, images: image ? [image] : undefined },
+      ],
     });
 
-    const response = chatCompletion.message.content;
+    let totalTokens = 0;
+    let response = "";
+    try {
+      for await (const part of chatCompletion) {
+        response += part.message.content;
+        if (!isNaN(part.eval_count)) {
+          totalTokens += part.eval_count;
+        }
+        if (!isNaN(part.prompt_eval_count)) {
+          totalTokens += part.prompt_eval_count;
+        }
+      }
+    } catch (e) {
+      // There seems to be a bug in ollama where you can get a successful response but still have an error thrown.
+      // Streaming and accumulating the response so far is a workaround.
+      // https://github.com/ollama/ollama-js/issues/72
+      totalTokens = NaN;
+      logger.warn(`Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`);
+    }
+
+    return { response, totalTokens };
+  }
 
-    return { response, totalTokens: chatCompletion.eval_count };
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    return await this.runModel(serverConfig.inference.textModel, prompt);
   }
 
   async inferFromImage(
@@ -113,13 +139,6 @@ class OllamaInferenceClient implements InferenceClient {
     _contentType: string,
     image: string,
   ): Promise<InferenceResponse> {
-    const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.imageModel,
-      format: "json",
-      messages: [{ role: "user", content: prompt, images: [`${image}`] }],
-    });
-
-    const response = chatCompletion.message.content;
-    return { response, totalTokens: chatCompletion.eval_count };
+    return await this.runModel(serverConfig.inference.imageModel, prompt, image);
   }
 }
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index b706fb90..9b2934e3 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -14,7 +14,7 @@ import {
   zOpenAIRequestSchema,
 } from "@hoarder/shared/queues";
 
-import { InferenceClientFactory, InferenceClient } from "./inference";
+import { InferenceClient, InferenceClientFactory } from "./inference";
 
 const openAIResponseSchema = z.object({
   tags: z.array(z.string()),
@@ -36,7 +36,7 @@ async function attemptMarkTaggingStatus(
       })
       .where(eq(bookmarks.id, request.bookmarkId));
   } catch (e) {
-    console.log(`Something went wrong when marking the tagging status: ${e}`);
+    logger.error(`Something went wrong when marking the tagging status: ${e}`);
   }
 }
@@ -196,8 +196,9 @@ async function inferTags(
     return tags;
   } catch (e) {
+    const responseSneak = response.response.substr(0, 20);
     throw new Error(
-      `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`,
+      `[inference][${jobId}] The model ignored our prompt and didn't respond with the expected JSON: ${JSON.stringify(e)}. Here's a sneak peek from the response: ${responseSneak}`,
     );
   }
 }
@@ -285,6 +286,10 @@ async function runOpenAI(job: Job) {
     );
   }
 
+  logger.info(
+    `[inference][${jobId}] Starting an inference job for bookmark with id "${bookmark.id}"`,
+  );
+
   const tags = await inferTags(jobId, bookmark, inferenceClient);
 
   await connectTags(bookmarkId, tags, bookmark.userId);
diff --git a/docs/docs/02-installation.md b/docs/docs/02-installation.md
index 70fc3bb1..94d44f5d 100644
--- a/docs/docs/02-installation.md
+++ b/docs/docs/02-installation.md
@@ -55,7 +55,7 @@ Learn more about the costs of using openai [here](/openai).
 
   - Make sure ollama is running.
   - Set the `OLLAMA_BASE_URL` env variable to the address of the ollama API.
-  - Set `INFERENCE_TEXT_MODEL` to the model you want to use for text inference in ollama (for example: `llama2`)
+  - Set `INFERENCE_TEXT_MODEL` to the model you want to use for text inference in ollama (for example: `mistral`)
   - Set `INFERENCE_IMAGE_MODEL` to the model you want to use for image inference in ollama (for example: `llava`)
   - Make sure that you `ollama pull`-ed the models that you want to use.
diff --git a/packages/shared/queues.ts b/packages/shared/queues.ts
index b264e2c4..146c19c6 100644
--- a/packages/shared/queues.ts
+++ b/packages/shared/queues.ts
@@ -1,5 +1,6 @@
 import { Queue } from "bullmq";
 import { z } from "zod";
+
 import serverConfig from "./config";
 
 export const queueConnectionDetails = {
@@ -27,6 +28,13 @@ export type ZOpenAIRequest = z.infer<typeof zOpenAIRequestSchema>;
 
 export const OpenAIQueue = new Queue("openai_queue", {
   connection: queueConnectionDetails,
+  defaultJobOptions: {
+    attempts: 3,
+    backoff: {
+      type: "exponential",
+      delay: 500,
+    },
+  },
 });
 
 // Search Indexing Worker
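
For readers skimming the diff, the heart of the fix is switching from a single blocking `chat()` call to a streamed one whose chunks are accumulated, so a late failure inside ollama-js (https://github.com/ollama/ollama-js/issues/72) no longer throws away content that already arrived. Below is a minimal, self-contained sketch of that technique, not part of the patch: the `runModelSketch` name, the `console.warn` call, and reading `OLLAMA_BASE_URL` from the environment are illustrative stand-ins for the patched `OllamaInferenceClient.runModel`.

```ts
import { Ollama } from "ollama";

// Illustrative sketch (not part of the patch): stream a chat completion and
// keep whatever content has arrived, even if the stream throws near the end.
const ollama = new Ollama({ host: process.env.OLLAMA_BASE_URL });

async function runModelSketch(model: string, prompt: string) {
  const stream = await ollama.chat({
    model,
    format: "json",
    stream: true,
    messages: [{ role: "user", content: prompt }],
  });

  let response = "";
  let totalTokens = 0;
  try {
    for await (const part of stream) {
      response += part.message.content;
      // Token counts are typically only present on the final chunk.
      if (!isNaN(part.eval_count)) totalTokens += part.eval_count;
      if (!isNaN(part.prompt_eval_count)) totalTokens += part.prompt_eval_count;
    }
  } catch (e) {
    // Keep the partial response; the caller can still try to JSON.parse it.
    totalTokens = NaN;
    console.warn(`ollama stream failed, using the partial response: ${e}`);
  }
  return { response, totalTokens };
}
```

If the stream dies after the model has already emitted its JSON, the accumulated `response` can still be parsed downstream, which is exactly the failure mode the linked issue describes.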
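The queue change works together with the new error handling in `inferTags`: since a malformed model response now throws instead of being logged and dropped, BullMQ's `defaultJobOptions` give the job fresh attempts. Here is a rough sketch of how those options behave, assuming a local Redis connection, a made-up `openai_queue_demo` queue name, and a random stand-in for "the model returned valid JSON":

```ts
import { Queue, Worker } from "bullmq";

const connection = { host: "localhost", port: 6379 }; // assumed local Redis

// Same shape as the patch's defaultJobOptions: up to 3 attempts, waiting
// ~500 ms before the 2nd attempt and ~1 s before the 3rd.
const demoQueue = new Queue("openai_queue_demo", {
  connection,
  defaultJobOptions: {
    attempts: 3,
    backoff: { type: "exponential", delay: 500 },
  },
});

// A processor that throws simply fails that attempt; BullMQ re-enqueues the
// job until the attempts are exhausted, then marks it as failed.
new Worker(
  "openai_queue_demo",
  async () => {
    const gotValidJson = Math.random() > 0.5; // stand-in for a flaky model
    if (!gotValidJson) {
      throw new Error("model ignored the JSON prompt, failing this attempt");
    }
    return "tagged";
  },
  { connection },
);

void demoQueue.add("tag-bookmark", { bookmarkId: "example-id" });
```

With `attempts: 3` and an exponential backoff of 500 ms, BullMQ retries a failed job after roughly 500 ms and then 1 s; if the third attempt also throws, the job is moved to the failed set.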