diff --git a/src/chonkie/refinery/base.py b/src/chonkie/refinery/base.py index a26c62d..7896f6a 100644 --- a/src/chonkie/refinery/base.py +++ b/src/chonkie/refinery/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Union from chonkie.types import Chunk @@ -23,6 +23,10 @@ def __init__(self, context_size: int = 0) -> None: def refine(self, chunks: List[Chunk]) -> List[Chunk]: """Refine the given list of chunks and return the refined list.""" pass + + def refine_batch(self, chunks_batch: List[List[Chunk]]) -> List[List[Chunk]]: + """Refine the given list of chunks and return the refined list.""" + return [self.refine(chunks) for chunks in chunks_batch] @classmethod @abstractmethod @@ -34,6 +38,29 @@ def __repr__(self) -> str: """Representation of the Refinery.""" return f"{self.__class__.__name__}(context_size={self.context_size})" - def __call__(self, chunks: List[Chunk]) -> List[Chunk]: - """Call the Refinery.""" - return self.refine(chunks) + def __call__(self, chunks: Union[List[Chunk], List[List[Chunk]]]) -> Union[List[Chunk], List[List[Chunk]]]: + """Call the Refinery. + + Args: + chunks: Either a list of Chunks or a list of lists of Chunks + + Returns: + Refined chunks in the same format as input + + Raises: + ValueError: If input type is not a list of Chunks or list of lists of Chunks + + """ + # If chunks is not a list or is empty, return chunks + if not isinstance(chunks, list) or not chunks: + return chunks + + # Check if it's a list of Chunks + if isinstance(chunks[0], Chunk): + return self.refine(chunks) + + # Check if it's a list of lists of Chunks + if isinstance(chunks[0], list) and chunks[0] and isinstance(chunks[0][0], Chunk): + return self.refine_batch(chunks) + + raise ValueError("Invalid input type for Refinery: must be List[Chunk] or List[List[Chunk]]") diff --git a/src/chonkie/refinery/overlap.py b/src/chonkie/refinery/overlap.py index b823ef2..7b4d248 100644 --- a/src/chonkie/refinery/overlap.py +++ b/src/chonkie/refinery/overlap.py @@ -50,6 +50,9 @@ def __init__( else: # Without tokenizer, must use approximate method self.approximate = True + + # Average number of characters per token + self._AVG_CHAR_PER_TOKEN = 7 def _get_refined_chunks( self, chunks: List[Chunk], inplace: bool = True @@ -168,7 +171,7 @@ def _suffix_overlap_token_exact(self, chunk: Chunk) -> Optional[Context]: return None # Take 6x context_size characters to ensure enough tokens - char_window = min(len(chunk.text), self.context_size * 6) + char_window = min(len(chunk.text), self.context_size * self._AVG_CHAR_PER_TOKEN) text_portion = chunk.text[:char_window] # Get exact token boundaries