From 41dc99cff9a1fe634eeb143020ee64504f4bf75e Mon Sep 17 00:00:00 2001
From: pelikhan
Date: Wed, 11 Dec 2024 08:25:16 -0800
Subject: [PATCH] feat: ➕ add logitBias, topP support and refactor code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/public/schemas/llms.json  | 11 +++++++
 packages/core/src/chat.ts      | 58 ++++++++++++++++++++++++++++++----
 packages/core/src/constants.ts | 14 +++-----
 packages/core/src/llms.json    |  5 +--
 packages/core/src/openai.ts    | 18 -----------
 5 files changed, 69 insertions(+), 37 deletions(-)

diff --git a/docs/public/schemas/llms.json b/docs/public/schemas/llms.json
index ba6c853b66..82970443d3 100644
--- a/docs/public/schemas/llms.json
+++ b/docs/public/schemas/llms.json
@@ -37,6 +37,10 @@
                         "type": "string",
                         "description": "Description of the LLM provider"
                     },
+                    "logitBias": {
+                        "type": "boolean",
+                        "description": "Indicates if logit_bias is supported"
+                    },
                     "logprobs": {
                         "type": "boolean",
                         "description": "Indicates if log probabilities are supported"
@@ -45,6 +49,10 @@
                         "type": "boolean",
                         "description": "Indicates if top log probabilities are supported"
                     },
+                    "topP": {
+                        "type": "boolean",
+                        "description": "Indicates if top_p is supported"
+                    },
                     "seed": {
                         "type": "boolean",
                         "description": "Indicates if seeding is supported"
@@ -54,15 +62,18 @@
                         "description": "Indicates if tools are supported"
                     }
                 },
+                "additionalProperties": false,
                 "required": ["id", "detail"]
             }
         }
     },
     "pricings": {
         "type": "object",
+        "additionalProperties": false,
         "patternProperties": {
             "^[a-zA-Z0-9:_-]+$": {
                 "type": "object",
+                "additionalProperties": false,
                 "properties": {
                     "price_per_million_input_tokens": {
                         "type": "number"
diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts
index 128a6cc96d..83ab1d92d4 100644
--- a/packages/core/src/chat.ts
+++ b/packages/core/src/chat.ts
@@ -37,6 +37,7 @@
     MAX_DATA_REPAIRS,
     MAX_TOOL_CALLS,
     MAX_TOOL_CONTENT_TOKENS,
+    MODEL_PROVIDERS,
     SYSTEM_FENCE,
 } from "./constants"
 import { parseAnnotations } from "./annotations"
@@ -44,7 +45,7 @@ import { errorMessage, isCancelError, serializeError } from "./error"
 import { estimateChatTokens } from "./chatencoder"
 import { createChatTurnGenerationContext } from "./runpromptcontext"
 import { dedent } from "./indent"
-import { traceLanguageModelConnection } from "./models"
+import { parseModelIdentifier, traceLanguageModelConnection } from "./models"
 import {
     ChatCompletionAssistantMessageParam,
     ChatCompletionContentPartImage,
@@ -834,7 +835,6 @@
         while (true) {
             stats.turns++
             const tokens = estimateChatTokens(model, messages)
-            logVerbose(`prompting ${model} (~${tokens ?? "?"} tokens)\n`)
             if (messages)
                 trace.details(
                     `💬 messages (${messages.length})`,
@@ -857,7 +857,7 @@
                 model,
                 choices
             )
-            req = deleteUndefinedValues({
+            req = {
                 model,
                 temperature: temperature,
                 top_p: topP,
@@ -888,11 +888,11 @@
                       }
                     : undefined,
                 messages,
-            })
-            if (/^o1/i.test(model)) {
-                req.max_completion_tokens = maxTokens
-                delete req.max_tokens
             }
+            updateChatFeatures(trace, req)
+            logVerbose(
+                `chat: sending ${messages.length} messages to ${model} (~${tokens ?? "?"} tokens)\n`
+            )
             resp = await completer(
                 req,
                 connectionToken,
@@ -941,6 +941,50 @@
     }
 }
 
+function updateChatFeatures(
+    trace: MarkdownTrace,
+    req: CreateChatCompletionRequest
+) {
+    const { provider, model } = parseModelIdentifier(req.model)
+    const features = MODEL_PROVIDERS.find(({ id }) => id === provider)
+
+    if (!isNaN(req.seed) && features?.seed === false) {
+        logVerbose(`seed: disabled, not supported by ${provider}`)
+        trace.itemValue(`seed`, `disabled`)
+        delete req.seed // some providers do not support seed
+    }
+    if (req.logit_bias && features?.logitBias === false) {
+        logVerbose(`logit_bias: disabled, not supported by ${provider}`)
+        trace.itemValue(`logit_bias`, `disabled`)
+        delete req.logit_bias // some providers do not support logit_bias
+    }
+    if (!isNaN(req.top_p) && features?.topP === false) {
+        logVerbose(`top_p: disabled, not supported by ${provider}`)
+        trace.itemValue(`top_p`, `disabled`)
+        delete req.top_p
+    }
+    if (req.logprobs && features?.logprobs === false) {
+        logVerbose(`logprobs: disabled, not supported by ${provider}`)
+        trace.itemValue(`logprobs`, `disabled`)
+        delete req.logprobs
+        delete req.top_logprobs
+    }
+    if (
+        req.top_logprobs &&
+        (features?.logprobs === false || features?.topLogprobs === false)
+    ) {
+        logVerbose(`top_logprobs: disabled, not supported by ${provider}`)
+        trace.itemValue(`top_logprobs`, `disabled`)
+        delete req.top_logprobs
+    }
+    if (/^o1/i.test(model) && !req.max_completion_tokens) {
+        req.max_completion_tokens = req.max_tokens
+        delete req.max_tokens
+    }
+
+    deleteUndefinedValues(req)
+}
+
 export function tracePromptResult(trace: MarkdownTrace, resp: RunPromptResult) {
     const { json, text } = resp
 
diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts
index d6f8ac40f3..aded0ccdc3 100644
--- a/packages/core/src/constants.ts
+++ b/packages/core/src/constants.ts
@@ -214,18 +214,12 @@ export const MODEL_PROVIDERS = Object.freeze<
     {
         id: string
         detail: string
-        /**
-         * Supports seed
-         */
         seed?: boolean
-        /**
-         * Supports logit_bias (choices)
-         */
-        logit_bias?: boolean
-        /**
-         * Supports tools. Set to false to enable fallbackTools
-         */
+        logitBias?: boolean
         tools?: boolean
+        logprobs?: boolean
+        topLogprobs?: boolean
+        topP?: boolean
     }[]
 >(CONFIGURATION_DATA.providers)
 export const MODEL_PRICINGS = Object.freeze<
diff --git a/packages/core/src/llms.json b/packages/core/src/llms.json
index 80e1a5cd69..66910aef87 100644
--- a/packages/core/src/llms.json
+++ b/packages/core/src/llms.json
@@ -8,7 +8,8 @@
         {
             "id": "github",
             "detail": "GitHub Models",
-            "logprobs": false
+            "logprobs": false,
+            "topLogprobs": false
         },
         {
             "id": "azure",
@@ -47,7 +48,7 @@
         {
             "id": "ollama",
             "detail": "Ollama local model",
-            "logit_bias": false
+            "logitBias": false
         },
         {
             "id": "lmstudio",
diff --git a/packages/core/src/openai.ts b/packages/core/src/openai.ts
index c8d94e2af0..478b443875 100644
--- a/packages/core/src/openai.ts
+++ b/packages/core/src/openai.ts
@@ -98,8 +98,6 @@
     const { provider, model } = parseModelIdentifier(req.model)
     const { encode: encoder } = await resolveTokenEncoder(model)
 
-    const features = MODEL_PROVIDERS.find(({ id }) => id === provider)
-
     const cache = !!cacheOrName || !!cacheName
     const cacheStore = getChatCompletionCache(
         typeof cacheOrName === "string" ? cacheOrName : cacheName
@@ -140,22 +138,6 @@
         model,
     } satisfies CreateChatCompletionRequest)
 
-    if (!isNaN(postReq.seed) && features?.seed === false) {
-        logVerbose(`seed: disabled, not supported by ${provider}`)
-        trace.itemValue(`seed`, `disabled`)
-        delete postReq.seed // some providers do not support seed
-    }
-    if (postReq.logit_bias && features?.logit_bias === false) {
-        logVerbose(`logit_bias: disabled, not supported by ${provider}`)
-        trace.itemValue(`logit_bias`, `disabled`)
-        delete postReq.logit_bias // some providers do not support logit_bias
-    }
-    if (!isNaN(postReq.top_p) && features?.top_p === false) {
-        logVerbose(`top_p: disabled, not supported by ${provider}`)
-        trace.itemValue(`top_p`, `disabled`)
-        delete postReq.top_p
-    }
-
     // stream_options fails in some cases
     if (model === "gpt-4-turbo-v" || /mistral/i.test(model)) {
         delete postReq.stream_options
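
Reviewer note: the sketch below restates the gating rule that updateChatFeatures introduces, in a runnable, self-contained form. The ProviderFeatures and ChatRequest types, the trimmed PROVIDERS table, and the applyFeatureGates helper are illustrative stand-ins for MODEL_PROVIDERS and CreateChatCompletionRequest, not the repo's actual types.

// Standalone sketch of the provider feature gating added in chat.ts.
// All names below are stand-ins for illustration only.
interface ProviderFeatures {
    id: string
    seed?: boolean
    logitBias?: boolean
    topP?: boolean
    logprobs?: boolean
    topLogprobs?: boolean
}

interface ChatRequest {
    model: string // "provider:model", e.g. "ollama:llama3.2"
    seed?: number
    top_p?: number
    logit_bias?: Record<string, number>
    logprobs?: boolean
    top_logprobs?: number
}

// Mirrors the shape of the entries in packages/core/src/llms.json.
const PROVIDERS: ProviderFeatures[] = [
    { id: "github", logprobs: false, topLogprobs: false },
    { id: "ollama", logitBias: false },
]

// A parameter is dropped only when the provider explicitly declares the
// feature as false; an absent flag means "assume supported", which is why
// the patch tests features?.x === false rather than !features?.x.
function applyFeatureGates(req: ChatRequest): ChatRequest {
    const [provider] = req.model.split(":")
    const features = PROVIDERS.find(({ id }) => id === provider)
    const gated = { ...req }
    if (features?.seed === false) delete gated.seed
    if (features?.logitBias === false) delete gated.logit_bias
    if (features?.topP === false) delete gated.top_p
    // logprobs: false implies top_logprobs is unsupported as well
    if (features?.logprobs === false) delete gated.logprobs
    if (features?.logprobs === false || features?.topLogprobs === false)
        delete gated.top_logprobs
    return gated
}

// ollama declares logitBias: false, so logit_bias is stripped while
// top_p passes through; unknown providers keep every parameter.
console.log(
    applyFeatureGates({
        model: "ollama:llama3.2",
        top_p: 0.9,
        logit_bias: { "1234": -100 },
    })
)

Centralizing these checks in chat.ts also means the o1 rename of max_tokens to max_completion_tokens now lives next to the other per-provider adjustments instead of being special-cased in the request builder.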
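The schema changes are more than documentation: with additionalProperties: false, a stale snake_case key such as the old logit_bias now fails validation instead of being silently ignored. A minimal sketch, assuming Ajv as the validator (an assumption; the repo may check llms.json by other means), with the fragment mirroring the patched provider items schema:

// Hypothetical validation of the tightened provider schema using Ajv.
import Ajv from "ajv"

const providerItemSchema = {
    type: "object",
    properties: {
        id: { type: "string" },
        detail: { type: "string" },
        logitBias: { type: "boolean" },
        logprobs: { type: "boolean" },
        topLogprobs: { type: "boolean" },
        topP: { type: "boolean" },
        seed: { type: "boolean" },
        tools: { type: "boolean" },
    },
    additionalProperties: false, // newly added by this patch
    required: ["id", "detail"],
}

const validate = new Ajv().compile(providerItemSchema)

// The renamed camelCase key validates...
console.log(validate({ id: "ollama", detail: "Ollama local model", logitBias: false })) // true
// ...while the old snake_case key is now rejected rather than ignored.
console.log(validate({ id: "ollama", detail: "Ollama local model", logit_bias: false })) // false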