Skip to content

Commit

Permalink
feat: enhance model pull logic with tag check 🛠️
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Dec 13, 2024
1 parent da9e171 commit 143ac3a
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions packages/cli/src/nodehost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import {
} from "../../core/src/azurecontentsafety"
import { resolveGlobalConfiguration } from "../../core/src/config"
import { HostConfiguration } from "../../core/src/hostconfiguration"
import { YAMLStringify } from "../../core/src/yaml"

class NodeServerManager implements ServerManager {
async start(): Promise<void> {
Expand Down Expand Up @@ -102,13 +103,31 @@ class ModelManager implements ModelService {
if (this.pulled.includes(modelid)) return { ok: true }

if (provider === MODEL_PROVIDER_OLLAMA) {
logVerbose(`${provider}: pull ${model}`)
try {
logVerbose(`${provider}: show ${model}`)
const conn = await this.getModelToken(modelid)

let res: Response
// OLLAMA
res = await fetch(`${conn.base}/api/pull`, {
// test if model is present
const resTags = await fetch(`${conn.base}/api/tags`, {
method: "GET",
headers: {
"User-Agent": TOOL_ID,
"Content-Type": "application/json",
},
})
if (resTags.ok) {
const { models }: { models: { model: string }[] } =
await resTags.json()
if (models.find((m) => m.model === model))
return { ok: true }
logVerbose(
`${provider}: ${model} not found in\n${YAMLStringify(models.map((m) => m.model))}`
)
}

// pull
logVerbose(`${provider}: pull ${model}`)
const resPull = await fetch(`${conn.base}/api/pull`, {
method: "POST",
headers: {
"User-Agent": TOOL_ID,
Expand All @@ -120,15 +139,19 @@ class ModelManager implements ModelService {
2
),
})
if (res.ok) {
const resj = await res.json()
//logVerbose(JSON.stringify(resj, null, 2))
if (resPull.ok) {
const resj = await resPull.json()
logVerbose(JSON.stringify(resj, null, 2))
}
if (resPull.ok) this.pulled.push(modelid)
else {
logError(`${provider}: failed to pull model ${model}`)
trace?.error(`${provider}: pull model ${model} failed`)
}
if (res.ok) this.pulled.push(modelid)
return { ok: res.ok, status: res.status }
return { ok: resPull.ok, status: resPull.status }
} catch (e) {
logError(`${provider}: failed to pull model ${model}`)
trace?.error(`${provider}: pull model failed`, e)
trace?.error(`${provider}: pull model ${model} failed`, e)
return { ok: false, status: 500, error: serializeError(e) }
}
}
Expand Down

0 comments on commit 143ac3a

Please sign in to comment.