This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit 7dd6748

Merge pull request #122 from philpax/buffer-utf8

fix(llama): buffer tokens until valid UTF-8

philpax authored Apr 13, 2023
2 parents 0e553a0 + 6b1488f
Showing 4 changed files with 141 additions and 53 deletions.
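For context on the fix: vocabulary tokens are byte sequences, and a single multi-byte UTF-8 codepoint can be split across consecutive tokens, so decoding each token in isolation can fail. A minimal standalone Rust sketch of the failure mode (illustrative only, not part of this diff):

fn main() {
    // "€" is the three bytes 0xE2 0x82 0xAC in UTF-8; a tokenizer is free to
    // split them across two tokens.
    let first: &[u8] = &[0xE2, 0x82]; // first token: incomplete codepoint
    let second: &[u8] = &[0xAC]; // second token: the remaining byte

    // Decoding the first token alone fails...
    assert!(std::str::from_utf8(first).is_err());

    // ...but concatenating the two tokens yields valid UTF-8.
    let mut buf = first.to_vec();
    buf.extend_from_slice(second);
    assert_eq!(std::str::from_utf8(&buf), Ok("€"));
}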
2 changes: 1 addition & 1 deletion ggml/src/lib.rs
@@ -406,7 +406,7 @@ impl Tensor {
         }
     }

-    fn with_alive_ctx<U>(&self, f: impl Fn() -> U) -> U {
+    fn with_alive_ctx<U>(&self, mut f: impl FnMut() -> U) -> U {
         if let Some(_ctx) = self.ctx.upgrade() {
             f()
         } else {
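The `Fn` bound is loosened to `FnMut` here, presumably so that stateful closures (such as the UTF-8 buffering callbacks introduced below) can be passed through. A small sketch of why that matters, with hypothetical names:

fn run(mut f: impl FnMut() -> usize) -> usize {
    f()
}

fn main() {
    let mut out = String::new();
    // This closure mutates its capture, so it implements FnMut but not Fn;
    // an `impl Fn() -> usize` bound would reject it.
    let n = run(|| {
        out.push_str("token");
        out.len()
    });
    assert_eq!(n, 5);
}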
3 changes: 0 additions & 3 deletions llama-cli/src/cli_args.rs
@@ -268,9 +268,6 @@ impl ModelLoad {
             LoadProgress::HyperparametersLoaded(hparams) => {
                 log::debug!("Loaded hyperparameters {hparams:#?}")
             }
-            LoadProgress::BadToken { index } => {
-                log::info!("Warning: Bad token in vocab at index {index}")
-            }
             LoadProgress::ContextSize { bytes } => log::info!(
                 "ggml ctx size = {:.2} MB\n",
                 bytes as f64 / (1024.0 * 1024.0)
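With tokens now stored as raw bytes, the `BadToken` progress event (and this log arm) is obsolete: loading no longer replaces undecodable tokens with `�`. A sketch contrasting the old and new behaviour (illustrative only, not part of this diff):

fn main() {
    let bytes: Vec<u8> = vec![0xE2, 0x82]; // an incomplete UTF-8 sequence

    // Old behaviour (see the removed loading code further down): a token that
    // fails to decode is replaced wholesale with U+FFFD, losing its bytes.
    let old = String::from_utf8(bytes.clone()).unwrap_or_else(|_| "\u{FFFD}".to_string());
    assert_eq!(old, "\u{FFFD}");

    // New behaviour: the raw bytes are kept and may later combine with
    // neighbouring tokens into valid UTF-8.
    let token: Vec<u8> = bytes;
    assert_eq!(token, [0xE2, 0x82]);
}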
23 changes: 14 additions & 9 deletions llama-rs/src/convert.rs
@@ -49,11 +49,12 @@ fn load_vocabulary(path: &Path) -> Vocabulary {
     let mut token_to_id = HashMap::new();
     let mut max_token_length = 0;

+    // TODO: Does the original model use valid UTF-8 for its tokens? This seems a little suspect to me.
     for (idx, piece) in proto.get_pieces().iter().enumerate() {
-        let word = piece.get_piece().to_string();
+        let word = piece.get_piece().as_bytes();
         max_token_length = max_token_length.max(word.len());
-        id_to_token.push(word.clone());
-        token_to_id.insert(word, idx as i32);
+        id_to_token.push(word.to_owned());
+        token_to_id.insert(word.to_owned(), idx as i32);
         id_to_token_score.push(piece.get_score());
     }
     Vocabulary {
@@ -128,13 +129,17 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String>
 fn write_tokens(file: &mut File, vocab: &Vocabulary) -> Result<(), String> {
     let mut values: Vec<u8> = vec![];
     for (i, token) in vocab.id_to_token.iter().enumerate() {
-        let text = match token {
-            _ if token.contains("<unk>") => " \u{2047} ".as_bytes().to_vec(),
-            _ if token.contains("s>") => vec![],
-            _ if token.len() == 6 && token.contains("<0x") => {
-                vec![u8::from_str_radix(&token[3..5], 16).unwrap()]
+        let text = if let Ok(token) = std::str::from_utf8(token) {
+            match token {
+                _ if token.contains("<unk>") => " \u{2047} ".as_bytes().to_vec(),
+                _ if token.contains("s>") => vec![],
+                _ if token.len() == 6 && token.contains("<0x") => {
+                    vec![u8::from_str_radix(&token[3..5], 16).unwrap()]
+                }
+                _ => token.replace('\u{2581}', " ").as_bytes().to_vec(),
             }
-            _ => token.replace('\u{2581}', " ").as_bytes().to_vec(),
+        } else {
+            token.clone()
         };
         values.extend((text.len() as i32).to_le_bytes());
         values.extend(&text);
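The converter now keeps vocabulary tokens as raw bytes and only applies the string-level rewrites when a token happens to be valid UTF-8. For reference, a standalone sketch of the two rewrites used above, assuming SentencePiece conventions (byte-fallback pieces like "<0x0A>", and U+2581 as the word-boundary marker):

fn main() {
    // Byte-fallback pieces are six characters, "<0x" + two hex digits + ">",
    // naming a single raw byte.
    let token = "<0x0A>";
    assert_eq!(token.len(), 6);
    let byte = u8::from_str_radix(&token[3..5], 16).unwrap();
    assert_eq!(byte, b'\n');

    // "▁" (U+2581) marks a leading space in SentencePiece pieces.
    let piece = "\u{2581}hello";
    assert_eq!(piece.replace('\u{2581}', " "), " hello");
}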
166 changes: 126 additions & 40 deletions llama-rs/src/lib.rs
@@ -309,7 +309,7 @@ impl Display for InferenceStats {
 }

 type TokenId = i32;
-type Token = String;
+type Token = Vec<u8>;
 type TokenScore = f32;

 /// The vocabulary used by a model.
@@ -329,7 +329,7 @@ pub struct Vocabulary {
     max_token_length: usize,
 }
 impl Vocabulary {
-    fn token(&self, idx: usize) -> &str {
+    fn token(&self, idx: usize) -> &[u8] {
         &self.id_to_token[idx]
     }
 }
@@ -406,14 +406,6 @@ impl std::fmt::Display for TokenBias {
 pub enum LoadProgress<'a> {
     /// The hyperparameters have been loaded from the model.
     HyperparametersLoaded(&'a Hyperparameters),
-    /// A bad token was encountered during the loading process.
-    ///
-    /// This can be ignored, but invalid tokens will be replaced with
-    /// the `�` character.
-    BadToken {
-        /// The index within the vocabulary.
-        index: usize,
-    },
     /// The context has been created.
     ContextSize {
         /// The size of the context.
@@ -595,7 +587,7 @@ impl Model {
     pub fn load(
         path: impl AsRef<Path>,
         n_context_tokens: usize,
-        load_progress_callback: impl Fn(LoadProgress),
+        mut load_progress_callback: impl FnMut(LoadProgress),
     ) -> Result<(Model, Vocabulary), LoadError> {
         use std::fs::File;
         use std::io::BufReader;
@@ -621,6 +613,20 @@ impl Model {
             Ok(bytes)
         }

+        fn read_bytes_with_len(
+            reader: &mut impl BufRead,
+            len: usize,
+        ) -> Result<Vec<u8>, LoadError> {
+            let mut bytes = vec![0u8; len];
+            reader
+                .read_exact(&mut bytes)
+                .map_err(|e| LoadError::ReadExactFailed {
+                    source: e,
+                    bytes: len,
+                })?;
+            Ok(bytes)
+        }
+
         fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
             Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
         }
@@ -635,15 +641,7 @@ impl Model {

         /// Helper function. Reads a string from the buffer and returns it.
         fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> {
-            let mut buf = vec![0; len];
-            reader
-                .read_exact(&mut buf)
-                .map_err(|e| LoadError::ReadExactFailed {
-                    source: e,
-                    bytes: buf.len(),
-                })?;
-            let s = String::from_utf8(buf)?;
-            Ok(s)
+            Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?)
         }

         // Verify magic
@@ -699,14 +697,10 @@ impl Model {

         for i in 0..hparams.n_vocab {
             let len = read_i32(&mut reader)?;
-            if let Ok(word) = read_string(&mut reader, len as usize) {
-                max_token_length = max_token_length.max(word.len());
-                id_to_token.push(word.clone());
-                token_to_id.insert(word, TokenId::try_from(i)?);
-            } else {
-                load_progress_callback(LoadProgress::BadToken { index: i });
-                id_to_token.push("�".to_string());
-            }
+            let token = read_bytes_with_len(&mut reader, len as usize)?;
+            max_token_length = max_token_length.max(token.len());
+            id_to_token.push(token.clone());
+            token_to_id.insert(token, TokenId::try_from(i)?);

             // Token score, currently unused
             if !is_legacy_model {
@@ -1444,7 +1438,7 @@ impl InferenceSession {
         vocab: &Vocabulary,
         params: &InferenceParameters,
         prompt: &str,
-        callback: impl Fn(&str) -> Result<(), E>,
+        mut callback: impl FnMut(&[u8]) -> Result<(), E>,
     ) -> Result<(), InferenceError> {
         let beginning_of_sentence = self.n_past == 0;
         let prompt_tokens: Vec<TokenId> = vocab
@@ -1481,7 +1475,7 @@ impl InferenceSession {
         vocab: &'v Vocabulary,
         params: &InferenceParameters,
         rng: &mut impl rand::Rng,
-    ) -> Result<&'v str, InferenceError> {
+    ) -> Result<&'v [u8], InferenceError> {
         if self.n_past + 1 >= model.hparams.n_ctx {
             return Err(InferenceError::ContextFull);
         }
@@ -1522,15 +1516,19 @@ impl InferenceSession {
         prompt: &str,
         maximum_token_count: Option<usize>,
         rng: &mut impl rand::Rng,
-        callback: impl Fn(&str) -> Result<(), E>,
+        mut callback: impl FnMut(&str) -> Result<(), E>,
     ) -> Result<InferenceStats, InferenceError> {
         let maximum_token_count = maximum_token_count.unwrap_or(usize::MAX);
         if params.play_back_previous_tokens {
             // "Play back" the existing tokens, so that loading from an inference snapshot works
             // as expected.
+            let mut token_utf8_buf = TokenUtf8Buffer::new();
             for token_id in &self.tokens {
-                if let Err(e) = callback(vocab.token(*token_id as usize)) {
-                    return Err(InferenceError::UserCallback(Box::new(e)));
+                // Buffer the token until it's valid UTF-8, then call the callback.
+                if let Some(tokens) = token_utf8_buf.push(vocab.token(*token_id as usize)) {
+                    if let Err(e) = callback(&tokens) {
+                        return Err(InferenceError::UserCallback(Box::new(e)));
+                    }
                 }
             }
         }
@@ -1541,7 +1539,13 @@ impl InferenceSession {

         // Feed the initial prompt through the transformer, to update its
         // context window with new data.
-        self.feed_prompt(model, vocab, params, prompt, |tk| callback(tk))?;
+        self.feed_prompt(
+            model,
+            vocab,
+            params,
+            prompt,
+            TokenUtf8Buffer::adapt_callback(&mut callback),
+        )?;
         stats.feed_prompt_duration = start_at.elapsed().unwrap();
         stats.prompt_tokens = self.n_past;

@@ -1550,15 +1554,19 @@ impl InferenceSession {
         // EndOfText token, or we run out of space in the context window,
         // or we reach the specified limit.
         let mut tokens_processed = 0;
+        let mut token_utf8_buf = TokenUtf8Buffer::new();
         while tokens_processed < maximum_token_count {
             let token = match self.infer_next_token(model, vocab, params, rng) {
                 Ok(token) => token,
                 Err(InferenceError::EndOfText) => break,
                 Err(e) => return Err(e),
             };

-            if let Err(e) = callback(token) {
-                return Err(InferenceError::UserCallback(Box::new(e)));
+            // Buffer the token until it's valid UTF-8, then call the callback.
+            if let Some(tokens) = token_utf8_buf.push(token) {
+                if let Err(e) = callback(&tokens) {
+                    return Err(InferenceError::UserCallback(Box::new(e)));
+                }
             }

             tokens_processed += 1;
@@ -1691,7 +1699,7 @@ impl Vocabulary {
         &'a self,
         text: &str,
         bos: bool,
-    ) -> Result<Vec<(&'a str, TokenId)>, InferenceError> {
+    ) -> Result<Vec<(&'a [u8], TokenId)>, InferenceError> {
         let len = text.len();

         let mut score = vec![0usize; len + 1];
@@ -1701,7 +1709,6 @@ impl Vocabulary {
             let max_len = (len - i).min(self.max_token_length);
             for sub_len in 1..=max_len {
                 let sub = &text.as_bytes()[i..i + sub_len];
-                let Ok(sub) = std::str::from_utf8(sub) else { continue; };
                 let token = self.token_to_id.get(sub);

                 if let Some(token) = token {
@@ -1725,14 +1732,14 @@ impl Vocabulary {
             if token_id == 0 {
                 return Err(InferenceError::TokenizationFailed);
             }
-            let token = self.id_to_token[token_id as usize].as_str();
+            let token = self.id_to_token[token_id as usize].as_slice();
             res.push((token, token_id));
             i -= token.len();
         }

         if bos {
             // TODO: replace with vocab.bos
-            res.push(("", 1));
+            res.push((&[], 1));
         }

         // Pieces are in reverse order so correct that
@@ -1741,3 +1748,82 @@ impl Vocabulary {
         Ok(res)
     }
 }
+
+/// Used to buffer incoming tokens until they produce a valid string of UTF-8 text.
+///
+/// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8
+/// from multiple tokens. This helps alleviate that issue.
+#[derive(Clone, PartialEq, Default)]
+pub struct TokenUtf8Buffer(Vec<u8>);
+impl TokenUtf8Buffer {
+    /// Create a new buffer.
+    pub const fn new() -> Self {
+        Self(vec![])
+    }
+
+    /// Add a token to the buffer. If the buffer contains a valid string of UTF-8 text,
+    /// it is returned and the buffer is cleared for next use.
+    pub fn push(&mut self, token: &[u8]) -> Option<String> {
+        self.0.extend_from_slice(token);
+        match std::str::from_utf8(&self.0) {
+            Ok(s) => {
+                let out = s.to_owned();
+                self.0 = vec![];
+                Some(out)
+            }
+            Err(..) => {
+                for i in 1..self.0.len() {
+                    let slice = &self.0[i..];
+                    if slice.is_empty() {
+                        break;
+                    }
+
+                    if let Ok(s) = std::str::from_utf8(slice) {
+                        let out = s.to_owned();
+                        self.0 = vec![];
+                        return Some(out);
+                    }
+                }
+                None
+            }
+        }
+    }
+
+    /// Adapt a `&str` callback so that it can be used in a `&[u8]` context.
+    fn adapt_callback<'a, E: std::error::Error + 'static>(
+        mut callback: impl FnMut(&str) -> Result<(), E> + 'a,
+    ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a {
+        let mut buffer = Self::new();
+        move |token| match buffer.push(token) {
+            Some(tokens) => callback(&tokens),
+            None => Ok(()),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_valid_utf8() {
+        let mut buffer = TokenUtf8Buffer::new();
+        assert_eq!(buffer.push(b"hello").as_deref(), Some("hello"));
+        assert_eq!(buffer.push(&[0xE2, 0x82, 0xAC]).as_deref(), Some("€"));
+    }
+
+    #[test]
+    fn test_partial_utf8() {
+        let mut buffer = TokenUtf8Buffer::new();
+        assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None);
+        assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€"));
+    }
+
+    #[test]
+    fn test_invalid_prelude_for_valid_utf8() {
+        let mut buffer = TokenUtf8Buffer::new();
+        assert_eq!(buffer.push(&[0xD8]).as_deref(), None);
+        assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None);
+        assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€"));
+    }
+}
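A usage sketch of the `TokenUtf8Buffer` added above (the surrounding helper and data are hypothetical): push each incoming byte token, and emit text only once the buffered bytes form valid UTF-8:

fn print_tokens(tokens: &[&[u8]]) {
    let mut buf = TokenUtf8Buffer::new();
    for &token in tokens {
        // push() returns Some(text) only once the buffered bytes decode;
        // otherwise they stay buffered until a later token completes them.
        if let Some(text) = buf.push(token) {
            print!("{text}");
        }
    }
}

fn main() {
    // "€" split across two tokens is printed only after both have arrived.
    print_tokens(&[b"price: " as &[u8], &[0xE2, 0x82], &[0xAC]]);
    println!();
}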
