
Commit 41dc99c
feat: ➕ add logitBias, topP support and refactor code
pelikhan committed Dec 11, 2024
1 parent 1ef1dd8 commit 41dc99c
Showing 5 changed files with 69 additions and 37 deletions.
11 changes: 11 additions & 0 deletions docs/public/schemas/llms.json
@@ -37,6 +37,10 @@
"type": "string",
"description": "Description of the LLM provider"
},
"logitBias": {
"type": "boolean",
"description": "Indicates if logit_bias is supported"
},
"logprobs": {
"type": "boolean",
"description": "Indicates if log probabilities are supported"
@@ -45,6 +49,10 @@
"type": "boolean",
"description": "Indicates if top log probabilities are supported"
},
"topP": {
"type": "boolean",
"description": "Indicates if top_p is supported"
},
"seed": {
"type": "boolean",
"description": "Indicates if seeding is supported"
@@ -54,15 +62,18 @@
"description": "Indicates if tools are supported"
}
},
"additionalProperties": false,
"required": ["id", "detail"]
}
}
},
"pricings": {
"type": "object",
"additionalProperties": false,
"patternProperties": {
"^[a-zA-Z0-9:_-]+$": {
"type": "object",
"additionalProperties": false,
"properties": {
"price_per_million_input_tokens": {
"type": "number"
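For orientation, the two new flags follow the same pattern as the existing capability booleans: a provider entry opts out of a request parameter by setting the flag to false. Below is a minimal sketch of the entry shape, mirroring the MODEL_PROVIDERS entry type in packages/core/src/constants.ts further down; the example values are illustrative and not taken from llms.json.

```ts
// Shape of one "providers" entry in llms.json (mirrors the MODEL_PROVIDERS
// entry type in packages/core/src/constants.ts).
interface ModelProviderInfo {
    id: string
    detail: string
    seed?: boolean // false: seed is stripped from requests
    logitBias?: boolean // false: logit_bias is stripped from requests
    tools?: boolean
    logprobs?: boolean
    topLogprobs?: boolean
    topP?: boolean // false: top_p is stripped from requests
}

// Illustrative entry only, not part of this commit's llms.json changes.
const exampleProvider: ModelProviderInfo = {
    id: "example",
    detail: "Provider without logit_bias or top_p support",
    logitBias: false,
    topP: false,
}
```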
58 changes: 51 additions & 7 deletions packages/core/src/chat.ts
@@ -37,14 +37,15 @@ import {
MAX_DATA_REPAIRS,
MAX_TOOL_CALLS,
MAX_TOOL_CONTENT_TOKENS,
MODEL_PROVIDERS,
SYSTEM_FENCE,
} from "./constants"
import { parseAnnotations } from "./annotations"
import { errorMessage, isCancelError, serializeError } from "./error"
import { estimateChatTokens } from "./chatencoder"
import { createChatTurnGenerationContext } from "./runpromptcontext"
import { dedent } from "./indent"
import { traceLanguageModelConnection } from "./models"
import { parseModelIdentifier, traceLanguageModelConnection } from "./models"
import {
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImage,
@@ -834,7 +835,6 @@ export async function executeChatSession(
while (true) {
stats.turns++
const tokens = estimateChatTokens(model, messages)
logVerbose(`prompting ${model} (~${tokens ?? "?"} tokens)\n`)
if (messages)
trace.details(
`💬 messages (${messages.length})`,
@@ -857,7 +857,7 @@
model,
choices
)
req = deleteUndefinedValues({
req = {
model,
temperature: temperature,
top_p: topP,
@@ -888,11 +888,11 @@
}
: undefined,
messages,
})
if (/^o1/i.test(model)) {
req.max_completion_tokens = maxTokens
delete req.max_tokens
}
updateChatFeatures(trace, req)
logVerbose(
`chat: sending ${messages.length} messages to ${model} (~${tokens ?? "?"} tokens)\n`
)
resp = await completer(
req,
connectionToken,
@@ -941,6 +941,50 @@
}
}

function updateChatFeatures(
trace: MarkdownTrace,
req: CreateChatCompletionRequest
) {
const { provider, model } = parseModelIdentifier(req.model)
const features = MODEL_PROVIDERS.find(({ id }) => id === provider)

if (!isNaN(req.seed) && features?.seed === false) {
logVerbose(`seed: disabled, not supported by ${provider}`)
trace.itemValue(`seed`, `disabled`)
delete req.seed // some providers do not support seed
}
if (req.logit_bias && features?.logitBias === false) {
logVerbose(`logit_bias: disabled, not supported by ${provider}`)
trace.itemValue(`logit_bias`, `disabled`)
delete req.logit_bias // some providers do not support logit_bias
}
if (!isNaN(req.top_p) && features?.topP === false) {
logVerbose(`top_p: disabled, not supported by ${provider}`)
trace.itemValue(`top_p`, `disabled`)
delete req.top_p
}
if (req.logprobs && features?.logprobs === false) {
logVerbose(`logprobs: disabled, not supported by ${provider}`)
trace.itemValue(`logprobs`, `disabled`)
delete req.logprobs
delete req.top_logprobs
}
if (
req.top_logprobs &&
(features?.logprobs === false || features?.topLogprobs === false)
) {
logVerbose(`top_logprobs: disabled, not supported by ${provider}`)
trace.itemValue(`top_logprobs`, `disabled`)
delete req.top_logprobs
}
if (/^o1/i.test(model) && !req.max_completion_tokens) {
req.max_completion_tokens = req.max_tokens
delete req.max_tokens
}

deleteUndefinedValues(req)
}

export function tracePromptResult(trace: MarkdownTrace, resp: RunPromptResult) {
const { json, text } = resp

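The new updateChatFeatures helper centralizes the per-provider feature filtering that previously lived in openai.ts (see the openai.ts diff below). As a rough sketch of its effect, assuming a model id of the form "ollama:phi3" resolves to the ollama provider, whose entry sets logitBias to false in llms.json:

```ts
// Hypothetical request; field values are illustrative only.
const req: any = {
    model: "ollama:phi3", // assumed to parse to provider "ollama"
    temperature: 0.2,
    top_p: 0.9, // kept: the ollama entry does not set topP to false
    logit_bias: { "1234": -100 }, // dropped: the ollama entry sets logitBias to false
    seed: undefined, // stripped by deleteUndefinedValues at the end
}

// updateChatFeatures(trace, req)
// - logs "logit_bias: disabled, not supported by ollama" and records it in the trace
// - deletes req.logit_bias, then removes the undefined seed
```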
14 changes: 4 additions & 10 deletions packages/core/src/constants.ts
@@ -214,18 +214,12 @@ export const MODEL_PROVIDERS = Object.freeze<
{
id: string
detail: string
/**
* Supports seed
*/
seed?: boolean
/**
* Supports logit_bias (choices)
*/
logit_bias?: boolean
/**
* Supports tools. Set to false to enable fallbackTools
*/
logitBias?: boolean
tools?: boolean
logprobs?: boolean
topLogprobs?: boolean
topP?: boolean
}[]
>(CONFIGURATION_DATA.providers)
export const MODEL_PRICINGS = Object.freeze<
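A quick sketch of how these flags are consumed, using the same lookup pattern as updateChatFeatures above; "github" is one of the provider ids declared in llms.json:

```ts
import { MODEL_PROVIDERS } from "./constants"

// Look up the capability flags for a provider id.
const features = MODEL_PROVIDERS.find(({ id }) => id === "github")
if (features?.logprobs === false) {
    // the provider cannot return log probabilities,
    // so logprobs/top_logprobs should be dropped from the request
}
```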
5 changes: 3 additions & 2 deletions packages/core/src/llms.json
@@ -8,7 +8,8 @@
{
"id": "github",
"detail": "GitHub Models",
"logprobs": false
"logprobs": false,
"topLogprobs": false
},
{
"id": "azure",
@@ -47,7 +48,7 @@
{
"id": "ollama",
"detail": "Ollama local model",
"logit_bias": false
"logitBias": false
},
{
"id": "lmstudio",
18 changes: 0 additions & 18 deletions packages/core/src/openai.ts
@@ -98,8 +98,6 @@ export const OpenAIChatCompletion: ChatCompletionHandler = async (
const { provider, model } = parseModelIdentifier(req.model)
const { encode: encoder } = await resolveTokenEncoder(model)

const features = MODEL_PROVIDERS.find(({ id }) => id === provider)

const cache = !!cacheOrName || !!cacheName
const cacheStore = getChatCompletionCache(
typeof cacheOrName === "string" ? cacheOrName : cacheName
@@ -140,22 +138,6 @@
model,
} satisfies CreateChatCompletionRequest)

if (!isNaN(postReq.seed) && features?.seed === false) {
logVerbose(`seed: disabled, not supported by ${provider}`)
trace.itemValue(`seed`, `disabled`)
delete postReq.seed // some providers do not support seed
}
if (postReq.logit_bias && features?.logit_bias === false) {
logVerbose(`logit_bias: disabled, not supported by ${provider}`)
trace.itemValue(`logit_bias`, `disabled`)
delete postReq.logit_bias // some providers do not support logit_bias
}
if (!isNaN(postReq.top_p) && features?.top_p === false) {
logVerbose(`top_p: disabled, not supported by ${provider}`)
trace.itemValue(`top_p`, `disabled`)
delete postReq.top_p
}

// stream_options fails in some cases
if (model === "gpt-4-turbo-v" || /mistral/i.test(model)) {
delete postReq.stream_options
