diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts index 2d3531dfcc..ca7fd552e2 100644 --- a/packages/cli/src/nodehost.ts +++ b/packages/cli/src/nodehost.ts @@ -71,6 +71,7 @@ import { } from "../../core/src/azurecontentsafety" import { resolveGlobalConfiguration } from "../../core/src/config" import { HostConfiguration } from "../../core/src/hostconfiguration" +import { YAMLStringify } from "../../core/src/yaml" class NodeServerManager implements ServerManager { async start(): Promise { @@ -102,13 +103,31 @@ class ModelManager implements ModelService { if (this.pulled.includes(modelid)) return { ok: true } if (provider === MODEL_PROVIDER_OLLAMA) { - logVerbose(`${provider}: pull ${model}`) try { + logVerbose(`${provider}: show ${model}`) const conn = await this.getModelToken(modelid) - let res: Response - // OLLAMA - res = await fetch(`${conn.base}/api/pull`, { + // test if model is present + const resTags = await fetch(`${conn.base}/api/tags`, { + method: "GET", + headers: { + "User-Agent": TOOL_ID, + "Content-Type": "application/json", + }, + }) + if (resTags.ok) { + const { models }: { models: { model: string }[] } = + await resTags.json() + if (models.find((m) => m.model === model)) + return { ok: true } + logVerbose( + `${provider}: ${model} not found in\n${YAMLStringify(models.map((m) => m.model))}` + ) + } + + // pull + logVerbose(`${provider}: pull ${model}`) + const resPull = await fetch(`${conn.base}/api/pull`, { method: "POST", headers: { "User-Agent": TOOL_ID, @@ -120,15 +139,19 @@ class ModelManager implements ModelService { 2 ), }) - if (res.ok) { - const resj = await res.json() - //logVerbose(JSON.stringify(resj, null, 2)) + if (resPull.ok) { + const resj = await resPull.json() + logVerbose(JSON.stringify(resj, null, 2)) + } + if (resPull.ok) this.pulled.push(modelid) + else { + logError(`${provider}: failed to pull model ${model}`) + trace?.error(`${provider}: pull model ${model} failed`) } - if (res.ok) this.pulled.push(modelid) - return { ok: res.ok, status: res.status } + return { ok: resPull.ok, status: resPull.status } } catch (e) { logError(`${provider}: failed to pull model ${model}`) - trace?.error(`${provider}: pull model failed`, e) + trace?.error(`${provider}: pull model ${model} failed`, e) return { ok: false, status: 500, error: serializeError(e) } } }