fix: Attempt to increase the reliability of the ollama inference
MohamedBassem committed Mar 27, 2024
1 parent 5cbce67 commit 9986746
Showing 4 changed files with 49 additions and 17 deletions.
45 changes: 32 additions & 13 deletions apps/workers/inference.ts
@@ -2,6 +2,7 @@ import { Ollama } from "ollama";
 import OpenAI from "openai";
 
 import serverConfig from "@hoarder/shared/config";
+import logger from "@hoarder/shared/logger";
 
 export interface InferenceResponse {
   response: string;
@@ -96,30 +97,48 @@ class OllamaInferenceClient implements InferenceClient {
     });
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
+  async runModel(model: string, prompt: string, image?: string) {
     const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.textModel,
+      model: model,
       format: "json",
-      messages: [{ role: "system", content: prompt }],
+      stream: true,
+      messages: [
+        { role: "user", content: prompt, images: image ? [image] : undefined },
+      ],
     });
 
-    const response = chatCompletion.message.content;
+    let totalTokens = 0;
+    let response = "";
+    try {
+      for await (const part of chatCompletion) {
+        response += part.message.content;
+        if (!isNaN(part.eval_count)) {
+          totalTokens += part.eval_count;
+        }
+        if (!isNaN(part.prompt_eval_count)) {
+          totalTokens += part.prompt_eval_count;
+        }
+      }
+    } catch (e) {
+      // There seem to be some bug in ollama where you can get some successfull response, but still throw an error.
+      // Using stream + accumulating the response so far is a workaround.
+      // https://github.com/ollama/ollama-js/issues/72
+      totalTokens = NaN;
+      logger.warn(`Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`)
+    }
 
-    return { response, totalTokens: chatCompletion.eval_count };
+    return { response, totalTokens };
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    return await this.runModel(serverConfig.inference.textModel, prompt);
   }
 
   async inferFromImage(
     prompt: string,
     _contentType: string,
     image: string,
   ): Promise<InferenceResponse> {
-    const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.imageModel,
-      format: "json",
-      messages: [{ role: "user", content: prompt, images: [`${image}`] }],
-    });
-
-    const response = chatCompletion.message.content;
-    return { response, totalTokens: chatCompletion.eval_count };
+    return await this.runModel(serverConfig.inference.imageModel, prompt, image);
  }
 }
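
For readers unfamiliar with the workaround, the following is a minimal standalone sketch of the same streaming-and-accumulating pattern, outside the class. It assumes a local ollama instance on the default port (11434) and a pulled `mistral` model; the function name and prompt are illustrative and not part of the commit.

import { Ollama } from "ollama";

// Standalone sketch of the workaround above (assumptions: ollama running
// locally on its default port, and a "mistral" model already pulled).
// Streaming and accumulating chunks means that when ollama-js throws after a
// successful-looking response (https://github.com/ollama/ollama-js/issues/72),
// the text received so far is not lost.
async function runModelSketch(prompt: string) {
  const ollama = new Ollama({ host: "http://localhost:11434" });
  let response = "";
  let totalTokens = 0;

  try {
    const stream = await ollama.chat({
      model: "mistral",
      format: "json",
      stream: true,
      messages: [{ role: "user", content: prompt }],
    });
    for await (const part of stream) {
      response += part.message.content;
      // Token counts only arrive on some chunks; skip the rest.
      if (!isNaN(part.eval_count)) {
        totalTokens += part.eval_count;
      }
      if (!isNaN(part.prompt_eval_count)) {
        totalTokens += part.prompt_eval_count;
      }
    }
  } catch (e) {
    // Keep whatever was accumulated; the token count is unreliable here.
    totalTokens = NaN;
    console.warn(`ollama threw after streaming started: ${e}`);
  }

  return { response, totalTokens };
}

runModelSketch("Respond with a JSON object describing this bookmark.").then(
  (result) => console.log(result),
);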
11 changes: 8 additions & 3 deletions apps/workers/openaiWorker.ts
@@ -14,7 +14,7 @@ import {
   zOpenAIRequestSchema,
 } from "@hoarder/shared/queues";
 
-import { InferenceClientFactory, InferenceClient } from "./inference";
+import { InferenceClient, InferenceClientFactory } from "./inference";
 
 const openAIResponseSchema = z.object({
   tags: z.array(z.string()),
@@ -36,7 +36,7 @@ async function attemptMarkTaggingStatus(
       })
       .where(eq(bookmarks.id, request.bookmarkId));
   } catch (e) {
-    console.log(`Something went wrong when marking the tagging status: ${e}`);
+    logger.error(`Something went wrong when marking the tagging status: ${e}`);
   }
 }
 
@@ -196,8 +196,9 @@ async function inferTags(
 
     return tags;
   } catch (e) {
+    const responseSneak = response.response.substr(0, 20);
     throw new Error(
-      `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`,
+      `[inference][${jobId}] The model ignored our prompt and didn't respond with the expected JSON: ${JSON.stringify(e)}. Here's a sneak peak from the response: ${responseSneak}`,
     );
   }
 }
@@ -285,6 +286,10 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
     );
   }
 
+  logger.info(
+    `[inference][${jobId}] Starting an inference job for bookmark with id "${bookmark.id}"`,
+  );
+
   const tags = await inferTags(jobId, bookmark, inferenceClient);
 
   await connectTags(bookmarkId, tags, bookmark.userId);
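
The reworked error in inferTags follows a defensive-parsing pattern: validate the model output against the zod schema and, on failure, surface only a short prefix of the raw response. Below is a sketch of that pattern; parseTags is a hypothetical helper, not the actual inferTags code, and the schema shape and error wording are adapted from the diff.

import { z } from "zod";

// Same shape as the openAIResponseSchema in the diff above.
const openAIResponseSchema = z.object({
  tags: z.array(z.string()),
});

// Hypothetical helper illustrating the pattern; inferTags in the repo does
// more (prompt building, calling the inference client, etc.).
function parseTags(jobId: string, rawResponse: string): string[] {
  try {
    return openAIResponseSchema.parse(JSON.parse(rawResponse)).tags;
  } catch (e) {
    // Only a 20-character sneak peek of the raw output goes into the error,
    // so a long or garbled model response doesn't flood the logs.
    const responseSneak = rawResponse.substring(0, 20);
    throw new Error(
      `[inference][${jobId}] The model ignored our prompt and didn't respond with the expected JSON: ${JSON.stringify(e)}. Here's a sneak peek from the response: ${responseSneak}`,
    );
  }
}

console.log(parseTags("job-1", '{"tags": ["reading", "typescript"]}'));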
2 changes: 1 addition & 1 deletion docs/docs/02-installation.md
@@ -55,7 +55,7 @@ Learn more about the costs of using openai [here](/openai).
 
 - Make sure ollama is running.
 - Set the `OLLAMA_BASE_URL` env variable to the address of the ollama API.
-- Set `INFERENCE_TEXT_MODEL` to the model you want to use for text inference in ollama (for example: `llama2`)
+- Set `INFERENCE_TEXT_MODEL` to the model you want to use for text inference in ollama (for example: `mistral`)
 - Set `INFERENCE_IMAGE_MODEL` to the model you want to use for image inference in ollama (for example: `llava`)
 - Make sure that you `ollama pull`-ed the models that you want to use.
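
Putting the ollama-related settings together, the environment might look like the snippet below; the base URL assumes ollama is running locally on its default port (11434), and the model names are just the examples from the docs:

OLLAMA_BASE_URL=http://localhost:11434
INFERENCE_TEXT_MODEL=mistral
INFERENCE_IMAGE_MODEL=llava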
8 changes: 8 additions & 0 deletions packages/shared/queues.ts
@@ -1,5 +1,6 @@
 import { Queue } from "bullmq";
 import { z } from "zod";
+
 import serverConfig from "./config";
 
 export const queueConnectionDetails = {
@@ -27,6 +28,13 @@ export type ZOpenAIRequest = z.infer<typeof zOpenAIRequestSchema>;
 
 export const OpenAIQueue = new Queue<ZOpenAIRequest, void>("openai_queue", {
   connection: queueConnectionDetails,
+  defaultJobOptions: {
+    attempts: 3,
+    backoff: {
+      type: "exponential",
+      delay: 500,
+    },
+  },
 });
 
 // Search Indexing Worker
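
The added defaultJobOptions are what make the inference changes pay off end to end: a tagging job that fails (for example because ollama returned garbage) is retried up to 3 times, with exponential backoff of roughly 500 ms before the second attempt and 1 s before the third. A minimal sketch of how these options behave in BullMQ, with a hypothetical queue name, payload, and Redis connection:

import { Queue, Worker } from "bullmq";

// Hypothetical queue name, payload, and Redis connection; only the
// defaultJobOptions mirror the commit.
const connection = { host: "localhost", port: 6379 };

const queue = new Queue("example_queue", {
  connection,
  defaultJobOptions: {
    attempts: 3,
    backoff: {
      type: "exponential",
      delay: 500,
    },
  },
});

const worker = new Worker(
  "example_queue",
  async () => {
    // Throwing fails the current attempt; BullMQ re-queues the job with the
    // exponential backoff above until all 3 attempts are exhausted.
    throw new Error("transient inference failure");
  },
  { connection },
);

worker.on("failed", (job, err) => {
  console.log(`job ${job?.id} failed after ${job?.attemptsMade} attempt(s): ${err.message}`);
});

async function main() {
  await queue.add("tag-bookmark", { bookmarkId: "demo" });
}

void main();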
