Add support for Transformers and TikToken #4

Merged · 6 commits · Nov 6, 2024

Changes from all commits
6 changes: 6 additions & 0 deletions README.md
@@ -53,8 +53,14 @@ pip install chonkie[all]
Here's a basic example to get you started:

```python
# First import the chunker you want from Chonkie
from chonkie import TokenChunker

# Import your favorite tokenizer library
# Also supports transformers AutoTokenizer, tiktoken and AutoTikTokenizer
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_pretrained("gpt2")

# Initialize the chunker with the tokenizer
chunker = TokenChunker(tokenizer)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ Homepage = "https://github.com/bhavnicksm/chonkie"
sentence = ["spacy>=3.0.0"]
semantic = ["sentence-transformers>=2.0.0", "numpy>=1.23.0"]
all = ["spacy>=3.0.0", "sentence-transformers>=2.0.0", "numpy>=1.23.0"]
dev = ["pytest>=6.2.0"]
dev = ["pytest>=6.2.0", "tranformers>=4.0.0", "tiktoken>=0.2.0"]

[tool.setuptools]
package-dir = {"" = "src"}
41 changes: 41 additions & 0 deletions src/chonkie/chunker/base.py
@@ -16,7 +16,48 @@ class BaseChunker(ABC):
All chunker implementations should inherit from this class and implement
the chunk() method according to their specific chunking strategy.
"""
def __init__(self, tokenizer):
"""Initialize the chunker with a tokenizer.

Args:
tokenizer: Tokenizer object to be used for tokenizing text
"""
self.tokenizer = tokenizer
self._tokenizer_backend = self._get_tokenizer_backend()

def _get_tokenizer_backend(self):
"""Return the backend tokenizer object."""
if "transformers" in str(type(self.tokenizer)):
return "transformers"
elif "tokenizers" in str(type(self.tokenizer)):
return "tokenizers"
elif "tiktoken" in str(type(self.tokenizer)):
return "tiktoken"
else:
raise ValueError("Tokenizer backend not supported")

def _encode(self, text: str):
"""Encode text using the backend tokenizer."""
if self._tokenizer_backend == "transformers":
return self.tokenizer.encode(text)
elif self._tokenizer_backend == "tokenizers":
return self.tokenizer.encode(text).ids
elif self._tokenizer_backend == "tiktoken":
return self.tokenizer.encode(text)
else:
raise ValueError("Tokenizer backend not supported.")

def _encode_batch(self, texts: List[str]):
"""Encode a batch of texts using the backend tokenizer."""
if self._tokenizer_backend == "transformers":
return self.tokenizer.batch_encode_plus(texts)['input_ids']
elif self._tokenizer_backend == "tokenizers":
return self.tokenizer.encode_batch(texts)
elif self._tokenizer_backend == "tiktoken":
return self.tokenizer.encode_batch(texts)
else:
raise ValueError("Tokenizer backend not supported.")

@abstractmethod
def chunk(self, text: str) -> List[Chunk]:
"""Split text into chunks according to the implementation strategy.
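The string match on `str(type(...))` works because each library's tokenizer type carries its module path in its repr. A minimal sketch of what the dispatch normalizes, outside this diff (tokenizer names are illustrative):

```python
# Sketch: how _get_tokenizer_backend distinguishes the three supported
# libraries, and why _encode needs `.ids` only for `tokenizers`.
import tiktoken
from tokenizers import Tokenizer
from transformers import AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("gpt2")   # -> "transformers" backend
raw_tok = Tokenizer.from_pretrained("gpt2")      # -> "tokenizers" backend
tik_tok = tiktoken.get_encoding("gpt2")          # -> "tiktoken" backend

# Each type's repr contains its module path, which the backend check matches:
print(str(type(hf_tok)))   # <class 'transformers...GPT2TokenizerFast'>
print(str(type(raw_tok)))  # <class 'tokenizers.Tokenizer'>
print(str(type(tik_tok)))  # <class 'tiktoken.core.Encoding'>

# `tokenizers` wraps results in an Encoding object, so ids live under
# `.ids`; the other two return plain lists of ints. For the shared GPT-2
# vocabulary, all three should agree:
ids_hf = hf_tok.encode("hello world")
ids_raw = raw_tok.encode("hello world").ids
ids_tik = tik_tok.encode("hello world")
assert ids_hf == ids_raw == ids_tik
```

Since `_encode` and `_encode_batch` strip this difference in return types, downstream chunkers can treat every backend as producing a plain `List[int]`.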
7 changes: 4 additions & 3 deletions src/chonkie/chunker/semantic.py
@@ -29,9 +29,9 @@ def __init__(
self,
tokenizer: Tokenizer,
sentence_transformer_model: str,
max_chunk_size: int,
similarity_threshold: Optional[float] = None,
similarity_percentile: Optional[float] = None,
max_chunk_size: int = 512,
initial_sentences: int = 1,
sentence_mode: str = "heuristic",
spacy_model: str = "en_core_web_sm"
@@ -52,6 +52,8 @@ def __init__(
ValueError: If parameters are invalid
ImportError: If required dependencies aren't installed
"""
super().__init__(tokenizer)

if max_chunk_size <= 0:
raise ValueError("max_chunk_size must be positive")
if similarity_threshold is not None and (similarity_threshold < 0 or similarity_threshold > 1):
@@ -65,7 +67,6 @@
if sentence_mode not in ["heuristic", "spacy"]:
raise ValueError("sentence_mode must be 'heuristic' or 'spacy'")

self.tokenizer = tokenizer
self.max_chunk_size = max_chunk_size
self.similarity_threshold = similarity_threshold
self.similarity_percentile = similarity_percentile
@@ -159,7 +160,7 @@ def _prepare_sentences(self, text: str) -> List[Sentence]:
embeddings = self.sentence_transformer.encode(raw_sentences, convert_to_numpy=True)

# Batch compute token counts
token_counts = [len(encoding.ids) for encoding in self.tokenizer.encode_batch(raw_sentences)]
token_counts = [len(encoding) for encoding in self._encode_batch(raw_sentences)]

# Create Sentence objects with all precomputed information
sentences = [
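Because token counts in `_prepare_sentences` now come from `_encode_batch`, the semantic chunker no longer assumes a `tokenizers.Tokenizer`. A hedged sketch of constructing it with a tiktoken backend (the package-level export, model name, and threshold here are illustrative assumptions, not taken from this diff):

```python
import tiktoken
from chonkie import SemanticChunker  # assuming a package-level export

# Any supported backend works now, since token counts come from the
# base-class dispatch rather than tokenizer.encode_batch(...).ids.
chunker = SemanticChunker(
    tokenizer=tiktoken.get_encoding("gpt2"),
    sentence_transformer_model="all-MiniLM-L6-v2",  # illustrative model
    similarity_threshold=0.7,                       # illustrative threshold
)  # max_chunk_size defaults to 512 after this change
```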
11 changes: 6 additions & 5 deletions src/chonkie/chunker/sentence.py
@@ -39,8 +39,8 @@ class SentenceChunker(BaseChunker):
def __init__(
self,
tokenizer: Tokenizer,
chunk_size: int,
chunk_overlap: int,
chunk_size: int = 512,
chunk_overlap: int = 128,
mode: str = "simple",
min_sentences_per_chunk: int = 1,
spacy_model: str = "en_core_web_sm"
@@ -59,6 +59,8 @@ def __init__(
ValueError: If parameters are invalid
Warning: If spacy mode is requested but spacy is not available
"""
super().__init__(tokenizer)

if chunk_size <= 0:
raise ValueError("chunk_size must be positive")
if chunk_overlap >= chunk_size:
@@ -68,7 +70,6 @@ def __init__(
if min_sentences_per_chunk < 1:
raise ValueError("min_sentences_per_chunk must be at least 1")

self.tokenizer = tokenizer
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.min_sentences_per_chunk = min_sentences_per_chunk
@@ -191,8 +192,8 @@ def _get_token_counts(self, sentences: List[str]) -> List[int]:
List of token counts for each sentence
"""
# Batch encode all sentences at once
encoded_sentences = self.tokenizer.encode_batch(sentences)
return [len(encoded.ids) for encoded in encoded_sentences]
encoded_sentences = self._encode_batch(sentences)
return [len(encoded) for encoded in encoded_sentences]

def _create_chunk(
self,
2 changes: 1 addition & 1 deletion src/chonkie/chunker/spdm.py
@@ -6,9 +6,9 @@ def __init__(
self,
tokenizer,
sentence_transformer_model: str,
max_chunk_size: int,
similarity_threshold: float = None,
similarity_percentile: float = None,
max_chunk_size: int = 512,
initial_sentences: int = 1,
sentence_mode: str = "heuristic",
spacy_model: str = "en_core_web_sm",
8 changes: 4 additions & 4 deletions src/chonkie/chunker/token.py
@@ -3,7 +3,7 @@

from .base import Chunk, BaseChunker
class TokenChunker(BaseChunker):
def __init__(self, tokenizer: Tokenizer, chunk_size: int, chunk_overlap: int):
def __init__(self, tokenizer: Tokenizer, chunk_size: int = 512, chunk_overlap: int = 128):
"""Initialize the TokenChunker with configuration parameters.

Args:
@@ -14,12 +14,12 @@ def __init__(self, tokenizer: Tokenizer, chunk_size: int, chunk_overlap: int):
Raises:
ValueError: If chunk_size <= 0 or chunk_overlap >= chunk_size
"""
super().__init__(tokenizer)
if chunk_size <= 0:
raise ValueError("chunk_size must be positive")
if chunk_overlap >= chunk_size:
raise ValueError("chunk_overlap must be less than chunk_size")

self.tokenizer = tokenizer
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap

@@ -36,8 +36,8 @@ def chunk(self, text: str) -> List[Chunk]:
return []

# Encode full text
encoding = self.tokenizer.encode(text)
text_tokens = encoding.ids
encoding = self._encode(text)
text_tokens = encoding
chunks = []

# Calculate chunk positions
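With `chunk()` routed through `_encode`, which always yields a plain list of token ids, a single TokenChunker code path serves all three backends. A small usage sketch (not part of this diff):

```python
import tiktoken
from tokenizers import Tokenizer
from chonkie import TokenChunker

text = "Chunk boundaries should not depend on which tokenizer library you use."

# The same chunker works with either backend; chunk_size=512 and
# chunk_overlap=128 are the defaults after this PR.
for tok in (Tokenizer.from_pretrained("gpt2"), tiktoken.get_encoding("gpt2")):
    chunks = TokenChunker(tok).chunk(text)
    print(type(tok).__name__, [c.token_count for c in chunks])
```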
11 changes: 6 additions & 5 deletions src/chonkie/chunker/word.py
@@ -4,7 +4,7 @@
from .base import Chunk, BaseChunker

class WordChunker(BaseChunker):
def __init__(self, tokenizer: Tokenizer, chunk_size: int, chunk_overlap: int, mode: str = "simple"):
def __init__(self, tokenizer: Tokenizer, chunk_size: int = 512, chunk_overlap: int = 128, mode: str = "simple"):
"""Initialize the WordChunker with configuration parameters.

Args:
@@ -16,14 +16,15 @@ def __init__(self, tokenizer: Tokenizer, chunk_size: int, chunk_overlap: int, mo
Raises:
ValueError: If chunk_size <= 0 or chunk_overlap >= chunk_size or invalid mode
"""
super().__init__(tokenizer)

if chunk_size <= 0:
raise ValueError("chunk_size must be positive")
if chunk_overlap >= chunk_size:
raise ValueError("chunk_overlap must be less than chunk_size")
if mode not in ["simple", "advanced"]:
raise ValueError("mode must be either 'heuristic' or 'advanced'")

self.tokenizer = tokenizer
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.mode = mode
@@ -139,7 +140,7 @@ def _get_token_count(self, text: str) -> int:
Returns:
Number of tokens
"""
return len(self.tokenizer.encode(text).ids)
return len(self._encode(text))

def _create_chunk(self, words: List[str], start_idx: int, end_idx: int) -> Tuple[Chunk, int]:
"""Create a chunk from a list of words.
@@ -170,8 +171,8 @@ def _get_word_list_token_counts(self, words: List[str]) -> List[int]:
Returns:
List of token counts for each word
"""
encodings = self.tokenizer.encode_batch(words)
return [len(encoding.ids) for encoding in encodings]
encodings = self._encode_batch(words)
return [len(encoding) for encoding in encodings]

def chunk(self, text: str) -> List[Chunk]:
"""Split text into overlapping chunks based on words while respecting token limits.
66 changes: 65 additions & 1 deletion tests/chunker/test_token_chunker.py
@@ -1,8 +1,20 @@
import pytest

import tiktoken
from transformers import AutoTokenizer
from tokenizers import Tokenizer

from chonkie import Chunk
from chonkie import TokenChunker

@pytest.fixture
def transformers_tokenizer():
return AutoTokenizer.from_pretrained("gpt2")

@pytest.fixture
def tiktokenizer():
return tiktoken.get_encoding("gpt2")

@pytest.fixture
def tokenizer():
return Tokenizer.from_pretrained("gpt2")
@@ -12,7 +24,7 @@ def sample_text():
text = """The process of text chunking in RAG applications represents a delicate balance between competing requirements. On one side, we have the need for semantic coherence – ensuring that each chunk maintains meaningful context that can be understood and processed independently. On the other, we must optimize for information density, ensuring that each chunk carries sufficient signal without excessive noise that might impede retrieval accuracy. In this post, we explore the challenges of text chunking in RAG applications and propose a novel approach that leverages recent advances in transformer-based language models to achieve a more effective balance between these competing requirements."""
return text

def test_token_chunker_initialization(tokenizer):
def test_token_chunker_initialization_tok(tokenizer):
"""
Test that the TokenChunker can be initialized with a tokenizer.
"""
@@ -23,6 +35,29 @@ def test_token_chunker_initialization(tokenizer):
assert chunker.chunk_size == 512
assert chunker.chunk_overlap == 128

def test_token_chunker_initialization_hftok(transformers_tokenizer):
"""
Test that the TokenChunker can be initialized with a transformers AutoTokenizer.
"""
chunker = TokenChunker(tokenizer=transformers_tokenizer, chunk_size=512, chunk_overlap=128)

assert chunker is not None
assert chunker.tokenizer == transformers_tokenizer
assert chunker.chunk_size == 512
assert chunker.chunk_overlap == 128


def test_token_chunker_initialization_tik(tiktokenizer):
"""
Test that the TokenChunker can be initialized with a tiktoken tokenizer.
"""
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)

assert chunker is not None
assert chunker.tokenizer == tiktokenizer
assert chunker.chunk_size == 512
assert chunker.chunk_overlap == 128

def test_token_chunker_chunking(tokenizer, sample_text):
"""
Test that the TokenChunker can chunk a sample text into tokens.
@@ -38,6 +73,35 @@ def test_token_chunker_chunking(tokenizer, sample_text):
assert all([chunk.start_index is not None for chunk in chunks])
assert all([chunk.end_index is not None for chunk in chunks])

def test_token_chunker_chunking_hf(transformers_tokenizer, sample_text):
"""
Test that the TokenChunker can chunk a sample text with a transformers AutoTokenizer.
"""
chunker = TokenChunker(tokenizer=transformers_tokenizer, chunk_size=512, chunk_overlap=128)
chunks = chunker.chunk(sample_text)

assert len(chunks) > 0
assert type(chunks[0]) is Chunk
assert all([chunk.token_count <= 512 for chunk in chunks])
assert all([chunk.token_count > 0 for chunk in chunks])
assert all([chunk.text is not None for chunk in chunks])
assert all([chunk.start_index is not None for chunk in chunks])
assert all([chunk.end_index is not None for chunk in chunks])

def test_token_chunker_chunking_tik(tiktokenizer, sample_text):
"""
Test that the TokenChunker can chunk a sample text with a tiktoken tokenizer.
"""
chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
chunks = chunker.chunk(sample_text)

assert len(chunks) > 0
assert type(chunks[0]) is Chunk
assert all([chunk.token_count <= 512 for chunk in chunks])
assert all([chunk.token_count > 0 for chunk in chunks])
assert all([chunk.text is not None for chunk in chunks])
assert all([chunk.start_index is not None for chunk in chunks])
assert all([chunk.end_index is not None for chunk in chunks])

def test_token_chunker_empty_text(tokenizer):
"""