Commit

feat: 🎉 add custom weights and pre-encoded tokens support
pelikhan committed Dec 14, 2024
1 parent f7a9321 commit 0f75f5d
Showing 4 changed files with 36 additions and 12 deletions.
31 changes: 26 additions & 5 deletions docs/src/content/docs/reference/scripts/choices.md
@@ -8,8 +8,8 @@ sidebar:

You can specify a list of preferred words (choices) in the script metadata. Doing so increases the probability that the model generates those words.

- Each word should match a single token for the desired model!
- For some models, GenAIScript does not have a token encoder, so it cannot compute the logit bias for the choices.

```js
script({
@@ -22,12 +22,33 @@ script({
ERR
```

## Custom weights

You can tune the probability of each choice by giving it a weight.
The default weight is `5`.

```js '{ token: "ERR", weight: 10 }'
script({
choices: ["OK", { token: "ERR", weight: 10 }],
})
```
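To see why a larger weight makes a choice more likely: in OpenAI-style APIs, the logit bias is added to a token's logit before sampling. The following is a minimal sketch of that arithmetic, not GenAIScript code; the function name `softmaxWithBias` and the logit values are made up for illustration.

```js
// Illustrative only: softmax over raw logits with an optional additive bias per token.
function softmaxWithBias(logits, bias = {}) {
    const tokens = Object.keys(logits)
    // Add the bias (0 when a token has none) to each raw logit.
    const shifted = tokens.map((t) => logits[t] + (bias[t] ?? 0))
    // Subtract the maximum before exponentiating, for numerical stability.
    const max = Math.max(...shifted)
    const exps = shifted.map((v) => Math.exp(v - max))
    const sum = exps.reduce((a, b) => a + b, 0)
    return Object.fromEntries(tokens.map((t, i) => [t, exps[i] / sum]))
}

// Made-up logits for three candidate tokens.
const logits = { OK: 1.0, ERR: 0.5, maybe: 2.0 }
const plain = softmaxWithBias(logits)
// Same weights as the example above: default 5 for OK, 10 for ERR.
const biased = softmaxWithBias(logits, { OK: 5, ERR: 10 })
console.log(plain, biased)
```

Under this model, a weight of `w` multiplies a token's odds relative to an unbiased token by `e^w`, so even small differences between weights shift the distribution noticeably.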

## Pre-encoded tokens

For models where GenAIScript does not have a token encoder, you can provide the pre-encoded tokens.

```js
script({
choices: [{ token: 12345, weight: 10 }],
})
```
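Because a pre-encoded token bypasses the tokenizer, it may be worth checking entries before passing them to `script`. The helper below is a hypothetical sketch, not part of GenAIScript; the name `validateChoice` is invented, and the `-100` to `100` range is the one OpenAI documents for `logit_bias`, which other providers may not share.

```js
// Hypothetical helper (not a GenAIScript API): returns a list of problems for one choice entry.
function validateChoice(choice) {
    // A plain string is shorthand for { token: <string> } with the default weight.
    const { token, weight } = typeof choice === "string" ? { token: choice } : choice
    const errors = []
    if (typeof token === "number") {
        // Pre-encoded tokens are token ids, so they should be non-negative integers.
        if (!Number.isInteger(token) || token < 0)
            errors.push(`token id ${token} is not a non-negative integer`)
    } else if (typeof token !== "string" || token.length === 0) {
        errors.push("token must be a non-empty string or a token id")
    }
    if (weight !== undefined) {
        // Assumption: OpenAI's documented logit_bias range; adjust for other providers.
        if (typeof weight !== "number" || Number.isNaN(weight) || weight < -100 || weight > 100)
            errors.push(`weight ${weight} is outside the range -100..100`)
    }
    return errors
}

console.log(validateChoice({ token: 12345, weight: 10 }))
console.log(validateChoice({ token: 1.5, weight: 500 }))
```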

## Logit Bias

Internally, GenAIScript tokenizes each word and builds the [logit_bias](https://help.openai.com/en/articles/5247780-using-logit-bias-to-alter-token-probability-with-the-openai-api) map, with one entry per token.

- choices: `OK`, `ERR`
- logit bias: `{"5175":5,"5392":5}`
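The mapping from choices to a logit bias object can be sketched as follows. This is a simplified standalone illustration, not the GenAIScript implementation: `sketchLogitBias` and `toyEncode` are invented names, and the pairing of `OK` and `ERR` with the ids `5175` and `5392` is assumed for the sake of the example.

```js
// Default weight when a choice does not specify one (the docs state 5).
const DEFAULT_WEIGHT = 5

// Hypothetical stand-in for a real tokenizer; the word-to-id pairing is assumed.
const toyVocabulary = { OK: [5175], ERR: [5392] }
const toyEncode = (text) => toyVocabulary[text] ?? []

// Sketch: turn choices (strings, or { token, weight } objects) into a logit bias map.
function sketchLogitBias(choices, encode = toyEncode) {
    const warnings = []
    const entries = []
    // Accept a single choice or an array of choices.
    for (const choice of [choices].flat()) {
        const { token, weight } =
            typeof choice === "string" ? { token: choice } : choice
        // Numeric tokens are treated as pre-encoded ids; strings go through the encoder.
        const encoded = typeof token === "number" ? [token] : encode(token)
        if (encoded.length !== 1)
            warnings.push(
                `choice ${JSON.stringify(choice)} tokenizes to ${encoded.length} tokens (expected one)`
            )
        // This sketch skips choices that produce no token at all.
        if (encoded.length === 0) continue
        const bias =
            typeof weight === "number" && !Number.isNaN(weight)
                ? weight
                : DEFAULT_WEIGHT
        entries.push([encoded[0], bias])
    }
    return { logitBias: Object.fromEntries(entries), warnings }
}

const { logitBias } = sketchLogitBias([
    "OK",
    { token: "ERR", weight: 10 },
    { token: 12345, weight: -2 },
])
console.log(JSON.stringify(logitBias)) // {"5175":5,"5392":10,"12345":-2}
```

Note that JavaScript orders integer-like object keys numerically, so the order of entries in the resulting object does not reveal which id belongs to which choice.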

## Logprobs

@@ -38,4 +59,4 @@ You can enable [logprobs](/genaiscript/reference/scripts/logprobs) to visualize
<span class="logprobs" title="100% (-0.000003)" style="background: rgb(0, 0, 180); color: white; white-space: pre; font-family: monospace;">ERR</span>
<span class="logprobs" title="32.07% (-1.14)" style="background: rgb(122, 0, 58); color: white; white-space: pre; font-family: monospace;">.</span>

---
13 changes: 8 additions & 5 deletions packages/core/src/chat.ts
@@ -748,7 +748,9 @@ export function mergeGenerationOptions(
 async function choicesToLogitBias(
     trace: MarkdownTrace,
     model: string,
-    choices: ElementOrArray<string>
+    choices: ElementOrArray<
+        string | { token: string | number; weight?: number }
+    >
 ) {
     choices = arrayify(choices)
     if (!choices?.length) return undefined
@@ -764,12 +766,13 @@ async function choicesToLogitBias(
     }
     const res = Object.fromEntries(
         choices.map((c) => {
-            const tokens = encode(c)
-            if (tokens.length !== 1)
+            const { token, weight } = typeof c === "string" ? { token: c } : c
+            const encoded = typeof token === "number" ? [token] : encode(token)
+            if (encoded.length !== 1)
                 trace.warn(
-                    `choice ${c} tokenizes to ${tokens.join(", ")} (expected one token)`
+                    `choice ${c} tokenizes to ${encoded.join(", ")} (expected one token)`
                 )
-            return [tokens[0], CHOICE_LOGIT_BIAS]
+            return [encoded[0], isNaN(weight) ? CHOICE_LOGIT_BIAS : weight]
         })
     )
     trace.itemValue("choices", choices.join(", "))
2 changes: 1 addition & 1 deletion packages/core/src/types/prompt_template.d.ts
@@ -219,7 +219,7 @@ interface ModelOptions extends ModelConnectionOptions, ModelTemplateOptions {
     /**
      * A list of keywords that should be found in the output.
      */
-    choices?: ElementOrArray<string>
+    choices?: ElementOrArray<string | { token: string | number; weight?: number }>

     /**
      * Returns the log probabilities of the each tokens. Not supported in all models.
2 changes: 1 addition & 1 deletion packages/sample/genaisrc/choices.genai.mjs
@@ -1,5 +1,5 @@
 script({
-    choices: ["OK", "ERR"],
+    choices: ["OK", { token: "ERR", weight: 0.2 }],
 })
 // tests logit_bias
 const res = await runPrompt(
