Skip to content

Commit

Permalink
refactor: use setModelAlias for model alias updates ✨
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Dec 13, 2024
1 parent 5c9ba9b commit 5c0d888
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
5 changes: 3 additions & 2 deletions packages/cli/src/nodehost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ export class NodeHost implements RuntimeHost {
if (typeof value === "string") value = { model: value }
const aliases = this._modelAliases[source]
const c = aliases[id] || (aliases[id] = {})
c.model = value.model
c.temperature = value.temperature
if (value.model !== undefined) (c as any).model = value.model
if (!isNaN(value.temperature))
(c as any).temperature = value.temperature
}

async readConfig(): Promise<HostConfiguration> {
Expand Down
8 changes: 3 additions & 5 deletions packages/cli/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,9 @@ export async function runScriptInternal(
const fenceFormat = options.fenceFormat

if (options.json || options.yaml) overrideStdoutWithStdErr()
if (options.model) runtimeHost.modelAliases.large.model = options.model
if (options.smallModel)
runtimeHost.modelAliases.small.model = options.smallModel
if (options.visionModel)
runtimeHost.modelAliases.vision.model = options.visionModel
if (options.model) runtimeHost.setModelAlias("cli", "large", options.model)
if (options.smallModel) runtimeHost.setModelAlias("cli", "small", options.smallModel)
if (options.visionModel) runtimeHost.setModelAlias("cli", "vision", options.visionModel)
for (const kv of options.modelAlias || []) {
const aliases = parseKeyValuePair(kv)
for (const [key, value] of Object.entries(aliases))
Expand Down
6 changes: 3 additions & 3 deletions packages/cli/src/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ export async function runPromptScriptTests(
testDelay?: string
}
): Promise<PromptScriptTestRunResponse> {
if (options.model) runtimeHost.modelAliases.large.model = options.model
if (options.model) runtimeHost.setModelAlias("cli", "large", options.model)
if (options.smallModel)
runtimeHost.modelAliases.small.model = options.smallModel
runtimeHost.setModelAlias("cli", "small", options.smallModel)
if (options.visionModel)
runtimeHost.modelAliases.vision.model = options.visionModel
runtimeHost.setModelAlias("cli", "vision", options.visionModel)
Object.entries(runtimeHost.modelAliases).forEach(([key, value]) =>
logVerbose(` ${key}: ${value.model}`)
)
Expand Down
7 changes: 4 additions & 3 deletions packages/cli/src/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ export async function worker() {
} & object
switch (type) {
case "run": {
const { scriptId, files, ...options } = data as {
const { scriptId, files, options } = data as {
scriptId: string
files: string[]
} & object
files: string[],
options: object
}
const { result } = await runScriptInternal(scriptId, files, options)
parentPort.postMessage(result)
break
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export function findEnvVar(
export async function parseDefaultsFromEnv(env: Record<string, string>) {
// legacy
if (env.GENAISCRIPT_DEFAULT_MODEL)
runtimeHost.modelAliases.large.model = env.GENAISCRIPT_DEFAULT_MODEL
runtimeHost.setModelAlias("env", "large", env.GENAISCRIPT_DEFAULT_MODEL)

const rx =
/^GENAISCRIPT(_DEFAULT)?_((?<id>[A-Z0-9_\-]+)_MODEL|MODEL_(?<id2>[A-Z0-9_\-]+))$/i
Expand All @@ -88,7 +88,7 @@ export async function parseDefaultsFromEnv(env: Record<string, string>) {
runtimeHost.setModelAlias("env", id, v)
}
const t = normalizeFloat(env.GENAISCRIPT_DEFAULT_TEMPERATURE)
if (!isNaN(t)) runtimeHost.modelAliases.large.temperature = t
if (!isNaN(t)) runtimeHost.setModelAlias("env", "large", { temperature: t })
}

export async function parseTokenFromEnv(
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ export interface AzureTokenResolver {
): Promise<AuthenticationToken>
}

export type ModelConfiguration = Pick<ModelOptions, "model" | "temperature">
export type ModelConfiguration = Readonly<Pick<ModelOptions, "model" | "temperature">>

export type ModelConfigurations = {
large: ModelConfiguration
Expand Down

0 comments on commit 5c0d888

Please sign in to comment.