Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possessive quantifiers to avoid catastrophic backtracking #258

Merged
merged 5 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ crate-type = ["cdylib"]
pyo3 = { version = "0.20.0", features = ["extension-module"] }

# tiktoken dependencies
fancy-regex = "0.11.0"
regex = "1.8.3"
fancy-regex = "0.13.0"
regex = "1.10.3"
Comment on lines +15 to +16
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not absolutely necessary, but adds a tiny speed increase

rustc-hash = "1.1.0"
bstr = "1.5.0"
16 changes: 15 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::num::NonZeroU64;
use std::thread;

use fancy_regex::Regex;
use fancy_regex::RegexBuilder;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::pyclass;
Expand Down Expand Up @@ -417,7 +418,7 @@ impl CoreBPE {
special_tokens_encoder: HashMap<String, Rank>,
pattern: &str,
) -> PyResult<Self> {
let regex = Regex::new(pattern)
let regex = RegexBuilder::new(pattern).backtrack_limit(10_000).build()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after this change we should never backtrack catastrophically — and if we ever do, this limit will surface the problem early as an error

.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

let special_regex = {
Expand Down Expand Up @@ -572,6 +573,7 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {

#[cfg(test)]
mod tests {
use fancy_regex::RegexBuilder;
use rustc_hash::FxHashMap as HashMap;

use crate::{byte_pair_split, Rank};
Expand All @@ -596,4 +598,16 @@ mod tests {
let res = byte_pair_split(b"abab", &ranks);
assert_eq!(res, vec![b"ab", b"ab"]);
}

#[test]
fn test_effect_of_backtrack_limit() {
    // Verify that a deliberately tiny backtrack limit turns catastrophic
    // backtracking into an early `Err` from `is_match` instead of a hang.
    // `(a|b|ab)*(?=c)` is a classic pathological pattern: each "ab" can be
    // matched two ways, so the engine's state count grows exponentially.
    let regex = RegexBuilder::new(r"(a|b|ab)*(?=c)")
        .backtrack_limit(10)
        .build()
        .expect("Failed to build regex");
    // NOTE: the original called `.clone()` on the freshly built regex; that
    // clone was redundant (the built value is already owned) and is removed.

    // 100 repetitions of "ab" vastly exceeds a backtrack budget of 10.
    let input = "ab".repeat(100) + "c";
    assert!(regex.is_match(&input).is_err(), "Should throw");
}
}
16 changes: 16 additions & 0 deletions tests/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@
from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES


@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
enc = make_enc()
for c in ["^", "0", "a", "'s", " ", "\n"]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stressing different parts of the regex, making sure none of them exhibit catastrophic backtracking

print(f"Validating `{c}`")

big_value = c * 10_000
assert big_value == enc.decode(enc.encode(big_value))

big_value = " " + big_value
Copy link
Contributor Author

@l0rinc l0rinc Feb 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space is often optional at the beginning, this way the backtracking can reach the space - let's test that as well

assert big_value == enc.decode(enc.encode(big_value))

big_value = big_value + "\n"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some groups require a newline at the end, stress those paths as well

assert big_value == enc.decode(enc.encode(big_value))

Copy link
Contributor Author

@l0rinc l0rinc Feb 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

big_value = big_value + "x" would still fail for runs of whitespace, e.g. "        x".
That case seems less typical than the others fixed here; I'm not yet sure how to fix it, though — fancy-regex seems pretty basic in this regard...


def test_simple():
enc = tiktoken.get_encoding("gpt2")
assert enc.encode("hello world") == [31373, 995]
Expand Down
18 changes: 10 additions & 8 deletions tiktoken_ext/openai_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"

# The pattern in the original GPT-2 release is:
# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# This is equivalent, but executes faster:
_legacy_splitter_regex = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whitespace quantifiers can't be possessive (the match needs to step back when it encounters a non-whitespace character), but we can rule out the offending backtracking case by adding a possessive trailing-whitespace check.



def gpt2():
mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
Expand All @@ -17,10 +22,7 @@ def gpt2():
return {
"name": "gpt2",
"explicit_n_vocab": 50257,
# The pattern in the original GPT-2 release is:
# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# This is equivalent, but executes faster:
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": _legacy_splitter_regex,
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
Expand All @@ -34,7 +36,7 @@ def r50k_base():
return {
"name": "r50k_base",
"explicit_n_vocab": 50257,
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": _legacy_splitter_regex,
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
Expand All @@ -48,7 +50,7 @@ def p50k_base():
return {
"name": "p50k_base",
"explicit_n_vocab": 50281,
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": _legacy_splitter_regex,
"mergeable_ranks": mergeable_ranks,
"special_tokens": {ENDOFTEXT: 50256},
}
Expand All @@ -62,7 +64,7 @@ def p50k_edit():
special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
return {
"name": "p50k_edit",
"pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
"pat_str": _legacy_splitter_regex,
"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
Expand All @@ -82,7 +84,7 @@ def cl100k_base():
}
return {
"name": "cl100k_base",
"pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
"pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems the cl100k also had some backtracking problems, these possessives improve the situation considerably (e.g. in Java these aren't necessary, see knuddelsgmbh/jtokkit#87)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we collapse the \s+(?!\S) too? it should be equivalent to \s+$ no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds reasonable, but check the tests to understand why they're not the same

"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
Expand Down
Loading