diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py index 0d40119c7..f6a23960a 100644 --- a/app_tests/integration_tests/llm/shortfin/conftest.py +++ b/app_tests/integration_tests/llm/shortfin/conftest.py @@ -51,6 +51,7 @@ def model_test_dir(request, tmp_path_factory): tokenizer_id = request.param["tokenizer_id"] settings = request.param["settings"] batch_sizes = request.param["batch_sizes"] + prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"] tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test") hf_home = os.environ.get("HF_HOME", None) @@ -83,7 +84,11 @@ def model_test_dir(request, tmp_path_factory): "prefill_batch_sizes": batch_sizes, "decode_batch_sizes": batch_sizes, "transformer_block_count": 26, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + "paged_kv_cache": { + "block_seq_stride": 16, + "device_block_count": 256, + "prefix_sharing_algorithm": prefix_sharing_algorithm, + }, } logger.info(f"Saving edited config to: {edited_config_path}\n") logger.info(f"Config: {json.dumps(config, indent=2)}") diff --git a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py index c4da9e4eb..99ce2d802 100644 --- a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py +++ b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py @@ -65,16 +65,28 @@ def do_generate(prompt, port): @pytest.mark.parametrize( "model_test_dir,llm_server", [ - ( + pytest.param( { "repo_id": "SlyEcho/open_llama_3b_v2_gguf", "model_file": "open-llama-3b-v2-f16.gguf", "tokenizer_id": "openlm-research/open_llama_3b_v2", "settings": CPU_SETTINGS, "batch_sizes": [1, 4], + "prefix_sharing_algorithm": "trie", }, {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - ) + ), + pytest.param( + { + "repo_id": "SlyEcho/open_llama_3b_v2_gguf", + "model_file": "open-llama-3b-v2-f16.gguf", + "tokenizer_id": "openlm-research/open_llama_3b_v2", + "settings": CPU_SETTINGS, + "batch_sizes": [1, 4], + "prefix_sharing_algorithm": "none", + }, + {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, + ), ], indirect=True, ) diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 141c7a7eb..7caed5d07 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -86,6 +86,8 @@ class PagedKVCacheParams: # Size of the cache on each device. device_block_count: int + prefix_sharing_algorithm: str = "none" # currently supporting none and trie + @dataclass_json(undefined=Undefined.RAISE) @dataclass diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 73134903c..39e19ebea 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -34,8 +34,12 @@ def pages(self) -> List[PageInfo]: pass @abstractmethod - def publish_pages(self, up_to_page_index) -> None: - """Makes pages[0:up_to_page_index] available to other requests.""" + def publish_pages_for_tokens( + self, tokens, *, publish_incomplete_page=False + ) -> None: + """ + Makes pages available to other requests. 
For details, see the derived class in trie_attention_cache.py.
+        """
         pass
 
     @abstractmethod
@@ -43,8 +47,15 @@ def release_pages(self) -> None:
         """Releases the allocation's reference to pages."""
         pass
 
+    @abstractmethod
+    def extend_allocation(self, tokens, *, extra_token_slots=0) -> None:
+        """
+        Extends the allocation to include additional tokens. For details, see the derived class in trie_attention_cache.py.
+        """
+        pass
+
 
-class BasePageAttentionCacheAllocation(PageAllocation):
+class BasePagedAttentionCacheAllocation(PageAllocation):
     """Represents a page allocation in the cache."""
 
     def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
@@ -56,18 +67,33 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
     def pages(self) -> List[PageInfo]:
         return list(self._pages)
 
-    def publish_pages(self, up_to_page_index) -> None:
+    def publish_pages_for_tokens(
+        self, tokens, *, publish_incomplete_page=False
+    ) -> None:
         pass
 
     def release_pages(self) -> None:
         if self._is_released:
             logger.warning("Releasing already-released allocation")
             return
-        self._cache.page_pool.release_pages(self._pages)
+        self._cache.page_pool.free_pages(self._pages)
         self._is_released = True
 
+    def extend_allocation(self, tokens, *, extra_token_slots=0) -> None:
+        # incoming tokens are expected to extend this allocation's current tokens; the prefix check is left to callers
+        # if we don't have enough pages to hold the tokens, we need to allocate more pages
+        token_count = len(tokens) + extra_token_slots
+        pages_needed = math.ceil(token_count / self._cache.tokens_per_page)
+        if pages_needed > len(self._pages):
+            new_pages = self._cache.page_pool.acquire_free_pages(
+                pages_needed - len(self._pages)
+            )
+            if new_pages is None:
+                raise CacheAllocationFailure()
+            self._pages += tuple(new_pages)
+
-    def __rerp__(self) -> str:
-        return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"
+    def __repr__(self) -> str:
+        return f"BasePagedAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"
 
 
 class BasePagedAttentionCache:
@@ -117,4 +143,4 @@ def acquire_pages_for_tokens(
         if pages is None:
             raise CacheAllocationFailure()
 
-        return BasePageAttentionCacheAllocation(pages, cache=self)
+        return BasePagedAttentionCacheAllocation(pages, cache=self)
diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
index 1686370c0..0c2cb3f67 100644
--- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
+++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
@@ -86,7 +86,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
             for i in range(self.config.alloc_page_count)
         ]
 
-        self.attn_page_free = list(self.attn_page_entries)
+        self.available_pages = list(self.attn_page_entries)
 
         # Initialize a page table on each device.
page_table_shape = [ @@ -108,14 +108,14 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig def acquire_free_pages(self, count: int) -> list[PageInfo] | None: with self._lock: - available = len(self.attn_page_free) + available = len(self.available_pages) if count > available: return None - return [self.attn_page_free.pop() for _ in range(count)] + return [self.available_pages.pop() for _ in range(count)] - def release_pages(self, pages: list[PageInfo]): + def free_pages(self, pages: list[PageInfo]): with self._lock: - self.attn_page_free.extend(pages) + self.available_pages.extend(pages) def copy_page(self, src_page: PageInfo) -> PageInfo: """ @@ -148,7 +148,7 @@ def copy_page(self, src_page: PageInfo) -> PageInfo: def __repr__(self): # No need to lock for repr (list is internally synchronized). - free_pages = len(self.attn_page_free) + free_pages = len(self.available_pages) total_pages = len(self.attn_page_entries) return ( f"PagePool({total_pages - free_pages}/{total_pages} pages in use: " diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py new file mode 100644 index 000000000..fbb008005 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -0,0 +1,426 @@ +from typing import Dict, Set, List, Tuple, Optional +from dataclasses import dataclass +import time +import math +import heapq +from .page_pool import PagePool, PageInfo +from .base_attention_cache import ( + BasePagedAttentionCache, + PageAllocation, + CacheAllocationFailure, +) + + +@dataclass +class RefCount: + """ + A reference counter to replace simple int. + """ + + count: int = 0 + + def increment(self) -> int: + self.count += 1 + return self.count + + def decrement(self) -> int: + self.count -= 1 + return self.count + + def is_empty(self) -> bool: + return self.count <= 0 + + +@dataclass +class TrieNode: + """Node of the block trie for paged attention cache. + + Each node represents a page of tokens in the cache, with edges representing + token sequences that can follow. This allows prefix sharing between sequences + that have common prefixes. + + Attributes: + tokens: Tuple of tokens stored in this node's page + page: PageInfo object containing the actual cache page + children: Dict mapping token sequences to child nodes + parent: Parent node in the trie (None for root) + ref_count: Number of active references to this node + access_time: Last access timestamp for LRU eviction + """ + + tokens: Tuple[int, ...] + page: PageInfo + children: Optional[Dict[Tuple[int, ...], "TrieNode"]] = None + parent: Optional["TrieNode"] = None + ref_count: RefCount = None + access_time: float = 0.0 + + def __post_init__(self) -> None: + """Initialize children dict and access time if not provided.""" + if self.children is None: + self.children = {} + self.access_time = time.monotonic() + self.ref_count = RefCount() + + def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode": + """Create a new child node with the given tokens and page. 
+
+        Args:
+            tokens: Sequence of tokens for the new node
+            page: PageInfo for the new node's cache page
+
+        Returns:
+            The newly created child node
+        """
+        new_node = TrieNode(tokens=tokens, page=page, parent=self)
+        self.children[tokens] = new_node
+        return new_node
+
+    def unlink(self) -> None:
+        """Remove this node from its parent's children."""
+        if self.parent is not None:
+            del self.parent.children[self.tokens]
+            self.parent = None
+
+    def __hash__(self) -> int:
+        """Nodes are uniquely identified by their memory address."""
+        return id(self)
+
+    def __eq__(self, other: object) -> bool:
+        """Nodes are equal only if they are the same object."""
+        return self is other
+
+
+class TriePagedAttentionCacheAllocation(PageAllocation):
+    """Represents a page allocation in the trie-based cache.
+
+    Tracks the sequence of pages and which of them are already published to the
+    cache, implementing the PageAllocation protocol for the trie cache.
+
+    Attributes:
+        cache: The parent cache this allocation belongs to
+        tokens: Complete sequence of tokens this allocation represents
+        last_cached_node: Last matched node in the trie
+        pages: List of all pages in allocation
+        number_of_published_pages: Number of pages that are published to the cache
+    """
+
+    def __init__(
+        self,
+        cache: "TriePagedAttentionCache",
+        tokens: List[int],
+        last_cached_node: TrieNode,
+        cached_pages: List[PageInfo],
+        newly_acquired_pages: List[PageInfo],
+    ):
+        self.cache = cache
+        self.tokens = tokens
+        self.last_cached_node = last_cached_node
+        self._pages = cached_pages + newly_acquired_pages
+        self.number_of_published_pages = len(cached_pages)
+        self._is_released = False
+
+    @property
+    def pages(self) -> List[PageInfo]:
+        return self._pages
+
+    def publish_pages_for_tokens(
+        self, tokens, *, publish_incomplete_page=False
+    ) -> None:
+        """Make pages available in the cache for the specified tokens.
+
+        Args:
+            tokens: Tokens to publish to the cache
+
+        Raises:
+            ValueError: If tokens don't match allocation or exceed available pages
+        """
+        # If this allocation already holds more tokens, publish pages only up to the incoming tokens.
+        # If the incoming sequence is longer, adopt it as this allocation's tokens and publish up to it.
+
+        def has_common_prefix(tokens1, tokens2):
+            for t1, t2 in zip(tokens1, tokens2):
+                if t1 != t2:
+                    return False
+            return True
+
+        if not has_common_prefix(self.tokens, tokens):
+            raise ValueError(
+                "Tokens provided in publish_pages_for_tokens do not match tokens in allocation"
+            )
+
+        if len(tokens) > len(self.tokens):
+            self.tokens = tokens
+
+        tokens_per_page = self.cache.tokens_per_page
+
+        if publish_incomplete_page:
+            number_of_pages_to_publish = -(
+                len(tokens) // -tokens_per_page
+            )  # ceil division
+        else:
+            number_of_pages_to_publish = len(tokens) // tokens_per_page
+
+        # Create token blocks for unpublished pages
+        start_token_index = self.number_of_published_pages * tokens_per_page
+        unpublished_tokens = [
+            tuple(self.tokens[i : i + tokens_per_page])
+            for i in range(start_token_index, len(self.tokens), tokens_per_page)
+        ]
+
+        unpublished_pages = self._pages[
+            self.number_of_published_pages : number_of_pages_to_publish
+        ]
+
+        # Add unpublished pages to trie
+        if publish_incomplete_page:
+            raise NotImplementedError(
+                "Additional work needed here to support publishing incomplete pages to ensure that we finish up a page before attaching child nodes to it."
+ ) + cur_node = self.last_cached_node + for token_block, page in zip(unpublished_tokens, unpublished_pages): + new_node = cur_node.create_child(token_block, page) + cur_node = new_node + + if cur_node is not self.cache.root: + self.cache.leaves.add(cur_node) + + # Update reference counts + if unpublished_tokens: + cur_node.ref_count.increment() + self.last_cached_node.ref_count.decrement() + self.last_cached_node = cur_node + + self.number_of_published_pages = number_of_pages_to_publish + + def release_pages(self) -> None: + """Release the allocation's reference to its pages. + + Decrements reference count of the last cached node. When count + reaches zero, the node becomes eligible for eviction. + """ + if self._is_released: + return + + self.last_cached_node.ref_count.decrement() + self._is_released = True + + def extend_allocation(self, tokens: List[int], *, extra_token_slots=0) -> None: + """Extend the current allocation to accommodate additional tokens. + + Args: + tokens: New token sequence to extend the allocation to + + Raises: + ValueError: If new tokens don't extend current allocation's tokens + """ + # Verify new tokens extend current tokens + if len(tokens) < len(self.tokens): + raise ValueError("New tokens must be longer than current tokens") + + # Check that current tokens are a prefix of new tokens + for old_token, new_token in zip(self.tokens, tokens): + if old_token != new_token: + raise ValueError("New tokens must extend current token sequence") + + # If tokens are identical, no extension needed + if len(tokens) == len(self.tokens): + return + + # Calculate how many new pages we need + tokens_per_page = self.cache.tokens_per_page + current_pages = len(self._pages) + total_tokens = len(tokens) + extra_token_slots + total_pages_needed = math.ceil(total_tokens / tokens_per_page) + new_pages_needed = total_pages_needed - current_pages + + if new_pages_needed <= 0: + self.tokens = tokens + return + + # Acquire new pages + new_pages = self.cache.page_pool.acquire_free_pages(new_pages_needed) + + if new_pages is None: + # Try eviction if initial allocation fails + self.cache._evict_pages( + new_pages_needed - len(self.cache.page_pool.available_pages) + ) + new_pages = self.cache.page_pool.acquire_free_pages(new_pages_needed) + + if new_pages is None: + raise CacheAllocationFailure( + "Failed to acquire pages for allocation extension even after attempting eviction" + ) + + # Extend our page list + self._pages.extend(new_pages) + + # Update tokens + self.tokens = tokens + + +class TriePagedAttentionCache(BasePagedAttentionCache): + """Trie-based paged attention cache implementation. + + Implements prefix sharing through a trie structure where each node + represents a page of tokens. Common prefixes between sequences share + the same nodes/pages, reducing memory usage. + + Attributes: + root: Root node of the trie + leaves: Set of leaf nodes for efficient eviction + page_pool: Pool providing page allocations + tokens_per_page: Number of tokens that fit in each page + """ + + def __init__(self, page_pool: PagePool, tokens_per_page: int): + """Initialize the trie cache. 
+ + Args: + page_pool: Pool to allocate pages from + tokens_per_page: Number of tokens per page + + Raises: + ValueError: If tokens_per_page <= 0 + """ + if tokens_per_page <= 0: + raise ValueError("tokens_per_page must be positive") + + super().__init__(page_pool, tokens_per_page) + + # Create root node with dummy page + dummy_page = PageInfo( + index=0, # Root uses reserved index 0 + pool=self.page_pool, + token_offset=0, + token_count=0, + ) + self.root = TrieNode(tokens=tuple(), page=dummy_page) + self.leaves: Set[TrieNode] = set() + + def _match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]: + """ + Find the longest prefix match in the trie. + + Walks the trie following the token sequence as far as possible, + collecting matched pages along the way. + + Args: + tokens: Sequence of tokens to match + + Returns: + Tuple of (last matched node, list of matched pages) + """ + tokens = tuple(tokens) + matched_pages = [] + cur = self.root + + for i in range(0, len(tokens), self.tokens_per_page): + token_block = tokens[i : i + self.tokens_per_page] + + if token_block not in cur.children: + break + cur = cur.children[token_block] + cur.access_time = time.monotonic() + matched_pages.append(cur.page) + + return cur, matched_pages + + def acquire_pages_for_tokens( + self, + tokens: List[int], + extra_token_slots: int = 0, + ) -> PageAllocation: + """Acquire pages for a sequence of tokens. + + Attempts to reuse existing cached pages where possible through + prefix matching, allocating new pages only for the uncached suffix. + + Args: + tokens: Sequence of tokens needing pages + extra_token_slots: Additional token slots to allocate beyond tokens + + Returns: + PageAllocation containing both cached and newly allocated pages + + Raises: + CacheAllocationFailure: If unable to allocate required pages + """ + tokens = tuple(tokens) + + cur_node, matched_pages = self._match(tokens) + cur_node.ref_count.increment() + + n_cached_tokens = len(matched_pages) * self.tokens_per_page + remaining_length = len(tokens) - n_cached_tokens + extra_token_slots + n_empty_pages = math.ceil(remaining_length / self.tokens_per_page) + + new_pages = self.page_pool.acquire_free_pages(n_empty_pages) + + if new_pages is not None: + return TriePagedAttentionCacheAllocation( + cache=self, + tokens=tokens, + last_cached_node=cur_node, + cached_pages=matched_pages, + newly_acquired_pages=new_pages, + ) + + # Try eviction + self._evict_pages(n_empty_pages - len(self.page_pool.available_pages)) + new_pages = self.page_pool.acquire_free_pages(n_empty_pages) + + if new_pages is None: + raise CacheAllocationFailure( + "Failed to acquire pages even after attempting eviction from LRU leaves" + ) + + return TriePagedAttentionCacheAllocation( + cache=self, + tokens=tokens, + last_cached_node=cur_node, + cached_pages=matched_pages, + newly_acquired_pages=new_pages, + ) + + def _evict_pages(self, max_pages: int) -> int: + """Evict up to max_pages pages using LRU strategy. + + Evicts from unreferenced leaf nodes first, working up the trie + as nodes become childless. 
+ + Args: + max_pages: Maximum number of pages to evict + + Returns: + Number of pages actually evicted + """ + pages_to_evict = [] + + # Initialize heap with unreferenced leaves + unused_leaf_heap = [ + (leaf.access_time, leaf) + for leaf in self.leaves + if leaf.ref_count.is_empty() + ] + heapq.heapify(unused_leaf_heap) + + # Evict least recently used nodes + while unused_leaf_heap and len(pages_to_evict) < max_pages: + _, leaf = heapq.heappop(unused_leaf_heap) + pages_to_evict.append(leaf.page) + parent = leaf.parent + leaf.unlink() + self.leaves.remove(leaf) + + # If parent becomes childless, it becomes a leaf + if parent is not self.root and not parent.children: + self.leaves.add(parent) + if parent.ref_count.is_empty(): + heapq.heappush(unused_leaf_heap, (parent.access_time, parent)) + + if pages_to_evict: + self.page_pool.free_pages(pages_to_evict) + + return len(pages_to_evict) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index c03900782..9e2ab7179 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -52,8 +52,6 @@ def reset(self, phase: InferencePhase): self.return_all_logits = False self.return_host_array = True self.result_logits = None - self.allocation.release_pages() - self.allocation = None def cache_page_indices(self, max_len: int) -> list[int]: if not self.allocation: @@ -63,11 +61,14 @@ def cache_page_indices(self, max_len: int) -> list[int]: def publish_allocated_pages(self, up_to_page_index: int): assert self.allocation - self.allocation.publish_pages(up_to_page_index) + self.allocation.publish_pages_for_tokens( + self.input_token_ids, publish_incomplete_page=False + ) def free_cache_pages(self): if self.allocation: self.allocation.release_pages() + self.allocation = None class StrobeMessage(sf.Message): diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 2f942aec7..5b43c1310 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -16,6 +16,7 @@ CacheAllocationFailure, PageAllocation, ) +from .kvcache.trie_attention_cache import TriePagedAttentionCache from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo from .config_struct import ModelParams from .manager import SystemManager @@ -67,10 +68,20 @@ def __init__( page_pool = PagePool( devices=self.main_fiber.devices_dict.values(), config=page_pool_config ) - self.page_cache = BasePagedAttentionCache( - page_pool=page_pool, - tokens_per_page=model_params.paged_kv_cache.block_seq_stride, - ) + if model_params.paged_kv_cache.prefix_sharing_algorithm == "trie": + self.page_cache = TriePagedAttentionCache( + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, + ) + elif model_params.paged_kv_cache.prefix_sharing_algorithm == "none": + self.page_cache = BasePagedAttentionCache( + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, + ) + else: + raise ValueError( + f"Unknown model_params.paged_kv_cache.prefix_sharing_algorithm {model_params.paged_kv_cache.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'." 
+ ) self.program_isolation = PROG_ISOLATIONS[program_isolation] @@ -270,21 +281,9 @@ def board_decodes(self, cache: BasePagedAttentionCache): assert decode_request.phase == InferencePhase.DECODE if len(exec_process.exec_requests) >= self.ideal_batch_size: break - incoming_token_count = len(decode_request.input_token_ids) - - try: - allocation = cache.acquire_pages_for_tokens( - decode_request.input_token_ids, - extra_token_slots=1, # need 1 extra slot to write result. - ) - except CacheAllocationFailure: - logger.debug( - "Cannot fulfill request for %d tokens", - len(decode_request.input_token_ids), - ) - - decode_request.free_cache_pages() - decode_request.allocation = allocation + decode_request.allocation.extend_allocation( + decode_request.input_token_ids, extra_token_slots=1 + ) # Can flight this request. exec_process.exec_requests.append(decode_request) diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py index 113da6912..8f2c4c060 100644 --- a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -10,7 +10,7 @@ from shortfin_apps.llm.components.kvcache.base_attention_cache import ( BasePagedAttentionCache, - BasePageAttentionCacheAllocation, + BasePagedAttentionCacheAllocation, CacheAllocationFailure, ) from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PageInfo @@ -32,7 +32,7 @@ def acquire_free_pages(self, count: int) -> List[PageInfo]: except queue.Empty: return None - def release_pages(self, pages): + def free_pages(self, pages): for page in pages: self._queue.put(page) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py new file mode 100644 index 000000000..0f49efda8 --- /dev/null +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -0,0 +1,432 @@ +import pytest +from typing import List, Tuple +import shortfin as sf +import shortfin.array as sfnp +from unittest.mock import Mock, MagicMock +import threading +import time +from dataclasses import dataclass + +from shortfin_apps.llm.components.kvcache.trie_attention_cache import ( + TriePagedAttentionCache, +) +from shortfin_apps.llm.components.kvcache.base_attention_cache import ( + CacheAllocationFailure, +) +from shortfin_apps.llm.components.kvcache.page_pool import ( + PagePool, + PageInfo, + PagePoolConfig, +) + + +# Test constants +TEST_PAGE_SIZE = 16 # Tokens per page +TEST_POOL_CAPACITY = 10 + + +@dataclass +class TokenSequence: + """Helper class for test parameterization""" + + tokens: List[int] + description: str + expected_pages: int + expected_cached: int = 0 + + def __str__(self): + return self.description + + +class MockScopedDevice: + """A proper mock for ScopedDevice that implements required interface""" + + def __init__(self): + self._mock = Mock(spec=sf.ScopedDevice) + # Add any necessary attributes/methods the real ScopedDevice has + self._mock.device_id = 0 + self._mock.device_type = "CPU" + + def __repr__(self): + return f"MockScopedDevice(device_id={self._mock.device_id})" + + +@pytest.fixture +def mock_device_array(): + """Create mock device array with proper interface implementation""" + + class MockDeviceArray: + def __init__(self): + self.shape = None + self.dtype = None + + def view(self, *args): + return Mock() + + def copy_from(self, src): + pass + + 
return MockDeviceArray() + + +@pytest.fixture +def mock_device(): + """Create properly structured mock device""" + return MockScopedDevice() + + +@pytest.fixture +def page_pool(mock_device, mock_device_array): + """Create PagePool with properly structured mock components""" + # Mock the device array creation + original_for_device = sf.array.device_array.for_device + + def mock_for_device(device, shape, dtype): + mock_array = mock_device_array + mock_array.shape = shape + mock_array.dtype = dtype + return mock_array + + sf.array.device_array.for_device = mock_for_device + + try: + config = PagePoolConfig( + dtype=sfnp.float16, + alloc_page_count=TEST_POOL_CAPACITY, + paged_kv_block_size_elements=128, + ) + + pool = PagePool(devices=[mock_device], config=config) + pool.page_tables = [mock_device_array] + return pool + finally: + # Restore original function + sf.array.device_array.for_device = original_for_device + + +@pytest.fixture +def trie_cache(page_pool): + """Create TriePagedAttentionCache instance""" + return TriePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) + + +@pytest.fixture +def published_sequence(trie_cache): + """Helper fixture that returns a function to publish token sequences""" + + def _publish_sequence(tokens: List[int]) -> None: + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages_for_tokens(alloc.tokens) + alloc.release_pages() + + return _publish_sequence + + +def print_tree_state(cache, prefix=""): + """Helper function to print current tree state in a readable format""" + if not hasattr(cache, "root"): + print(f"{prefix}Unable to access trie structure") + return + + def node_info(node): + token_str = f"tokens={list(node.tokens) if node.tokens else 'root'}" + return f"{token_str}, ref_count={node.ref_count}, page_index={node.page.index}" + + def print_node(node, depth=0): + indent = " " * depth + print(f"{prefix}{indent}- {node_info(node)}") + if node.children: + for child in node.children.values(): + print_node(child, depth + 1) + + print(f"{prefix}Tree state:") + print_node(cache.root) + + +basic_sequences = [ + {"tokens": [], "description": "empty_sequence", "expected_pages": 0}, + { + "tokens": list(range(TEST_PAGE_SIZE // 2)), + "description": "partial_page", + "expected_pages": 1, + }, + { + "tokens": list(range(TEST_PAGE_SIZE)), + "description": "exact_page", + "expected_pages": 1, + }, + { + "tokens": list(range(TEST_PAGE_SIZE + 1)), + "description": "overflow_page", + "expected_pages": 2, + }, + { + "tokens": list(range(TEST_PAGE_SIZE * 2)), + "description": "multiple_pages", + "expected_pages": 2, + }, +] + + +@pytest.mark.parametrize("test_sequence", basic_sequences) +def test_basic_allocation(trie_cache, test_sequence): + """Test basic page allocation without reuse""" + allocation = trie_cache.acquire_pages_for_tokens( + test_sequence["tokens"], extra_token_slots=0 + ) + assert len(allocation.pages) == test_sequence["expected_pages"] + assert allocation.number_of_published_pages == 0 + assert ( + len(allocation.pages) - allocation.number_of_published_pages + == test_sequence["expected_pages"] + ) + allocation.publish_pages_for_tokens(allocation.tokens) + allocation.release_pages() + + +reuse_sequences = [ + { + "initial_tokens": list(range(TEST_PAGE_SIZE)), + "reuse_tokens": list(range(TEST_PAGE_SIZE)), + "description": "exact_match", + "total_pages": 1, + "expected_cached": 1, + }, + { + "initial_tokens": list(range(TEST_PAGE_SIZE * 2)), + "reuse_tokens": list(range(TEST_PAGE_SIZE * 2)), 
+ "description": "multi_page_match", + "total_pages": 2, + "expected_cached": 2, + }, + { + "initial_tokens": list(range(TEST_PAGE_SIZE * 2)), + "reuse_tokens": list(range(TEST_PAGE_SIZE)) + + list(range(100, 100 + TEST_PAGE_SIZE)), + "description": "prefix_match", + "total_pages": 2, + "expected_cached": 1, + }, + { + "initial_tokens": list(range(TEST_PAGE_SIZE)), + "reuse_tokens": list(range(50, 50 + TEST_PAGE_SIZE)), + "description": "no_match", + "total_pages": 1, + "expected_cached": 0, + }, +] + + +@pytest.mark.parametrize("test_sequences", reuse_sequences) +def test_page_reuse(trie_cache, published_sequence, test_sequences): + """Test page reuse scenarios""" + # Publish initial sequence + published_sequence(test_sequences["initial_tokens"]) + + # Try to reuse + allocation = trie_cache.acquire_pages_for_tokens( + test_sequences["reuse_tokens"], extra_token_slots=0 + ) + assert len(allocation.pages) == test_sequences["total_pages"] + assert allocation.number_of_published_pages == test_sequences["expected_cached"] + assert ( + len(allocation.pages) - allocation.number_of_published_pages + == test_sequences["total_pages"] - test_sequences["expected_cached"] + ) + allocation.publish_pages_for_tokens(allocation.tokens) + allocation.release_pages() + + +@pytest.fixture +def filled_cache(trie_cache, published_sequence): + """Fixture that fills cache with numbered sequences""" + sequences = [] + for i in range(TEST_POOL_CAPACITY): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + published_sequence(tokens) + sequences.append(tokens) + return sequences + + +@pytest.mark.parametrize( + "access_count", [1, TEST_POOL_CAPACITY // 2, TEST_POOL_CAPACITY - 1] +) +def test_lru_eviction(trie_cache, access_count): + """Test LRU eviction with different access patterns""" + print(f"\nStarting test_lru_eviction with access_count={access_count}") + + # Create mix of published and unpublished sequences + keep_published = 3 # Number of sequences to keep published + sequences = [] + + # First add some sequences we'll keep published + print("\nPublishing sequences to keep active:") + for i in range(keep_published): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE]) + sequences.append(tokens) + print(f"Published sequence {i} (keeping active)") + print_tree_state(trie_cache, " ") + + # Then add sequences we'll publish but release (evictable) + print("\nAdding releasable sequences:") + for i in range(keep_published, TEST_POOL_CAPACITY): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE]) + alloc.release_pages() # These can be evicted + sequences.append(tokens) + print(f"Added releasable sequence {i}") + print_tree_state(trie_cache, " ") + + print("\nCache state before accessing sequences:") + print_tree_state(trie_cache, " ") + + # Access some sequences to update their LRU status + print(f"\nAccessing {access_count} sequences to update LRU order:") + for i in range(access_count): + print(f"\nAccessing sequence {i}:") + alloc = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0) + print_tree_state(trie_cache, " ") + alloc.release_pages() + print(f"After releasing allocation {i}:") + print_tree_state(trie_cache, " ") + + print("\nCache state before attempting new allocation:") + 
print_tree_state(trie_cache, "  ")
+    print("\nAvailable pages in pool:", len(trie_cache.page_pool.available_pages))
+
+    # Try to allocate new sequence - should evict least recently used unpublished sequence
+    new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE))
+    print(f"\nAttempting to allocate new sequence: {new_tokens}")
+    new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0)
+    print("\nNew allocation succeeded")
+    print("\nCache state after new allocation:")
+    print_tree_state(trie_cache, "  ")
+    new_alloc.release_pages()
+
+    # Verify recently accessed sequences AND published sequences weren't evicted
+    print("\nVerifying preserved sequences:")
+    for i in range(max(access_count, keep_published)):
+        print(f"\nChecking sequence {i}:")
+        recheck = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0)
+        cached_pages = recheck.number_of_published_pages
+        print(f"- Cached pages found: {cached_pages}")
+        assert (
+            cached_pages == 1
+        ), f"Sequence {i} was evicted but should have been preserved"
+        recheck.release_pages()
+
+
+@pytest.mark.parametrize("publish_steps", [1, 2, 3])
+def test_progressive_publish(trie_cache, publish_steps):
+    """Test publishing pages progressively"""
+    print(f"\nStarting test_progressive_publish with publish_steps={publish_steps}")
+
+    tokens = tuple(range(TEST_PAGE_SIZE * 3))  # Three pages
+    print(f"\nInitial tokens: {tokens}")
+    print(f"Tokens per page: {TEST_PAGE_SIZE}")
+    print(
+        f"Expected total pages: {len(tokens) // TEST_PAGE_SIZE + (1 if len(tokens) % TEST_PAGE_SIZE else 0)}"
+    )
+
+    print("\nInitial cache state:")
+    print_tree_state(trie_cache)
+
+    print("\nAcquiring initial allocation...")
+    alloc = trie_cache.acquire_pages_for_tokens(tokens)
+    print(f"Initial allocation pages: {[p.index for p in alloc.pages]}")
+    print("\nCache state after initial allocation:")
+    print_tree_state(trie_cache)
+
+    for step in range(1, publish_steps + 1):
+        print(f"\n--- Step {step} of {publish_steps} ---")
+
+        # Publish next page
+        print(f"Publishing up to page {step}")
+        # Publishing is expressed in tokens: pass the first `step` pages' worth
+        alloc.publish_pages_for_tokens(alloc.tokens[: step * TEST_PAGE_SIZE])
+        print("\nCache state after publish:")
+        print_tree_state(trie_cache)
+
+        # Verify reuse up to published point
+        reuse_tokens = tokens[: step * TEST_PAGE_SIZE]
+        print(f"\nAttempting to reuse tokens: {reuse_tokens}")
+        print(f"Expected cached pages: {step}")
+
+        reuse_alloc = trie_cache.acquire_pages_for_tokens(reuse_tokens)
+        print(f"Reuse allocation total pages: {len(reuse_alloc.pages)}")
+        print(f"Reuse allocation cached pages: {reuse_alloc.number_of_published_pages}")
+
+        print("\nCache state after reuse attempt:")
+        print_tree_state(trie_cache)
+
+        try:
+            assert reuse_alloc.number_of_published_pages == step
+        except AssertionError:
+            print("\nASSERTION FAILED!")
+            print(
+                f"Expected {step} cached pages but got {reuse_alloc.number_of_published_pages}"
+            )
+            raise
+
+        reuse_alloc.release_pages()
+        print("\nCache state after releasing reuse allocation:")
+        print_tree_state(trie_cache)
+
+    alloc.release_pages()
+    print("\nFinal cache state after releasing initial allocation:")
+    print_tree_state(trie_cache)
+
+
+@pytest.mark.parametrize("ref_count", [1, 2, 5])
+def test_reference_counting(trie_cache, ref_count):
+    """Test reference counting with different counts"""
+    tokens = list(range(TEST_PAGE_SIZE))
+    allocations = []
+
+    # Create initial allocation and publish
+    first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
+    # Publish the first allocation's tokens so later requests can share them
+    first_alloc.publish_pages_for_tokens(first_alloc.tokens)
+    allocations.append(first_alloc)
+    print("\nInitial allocation created")
+    print_tree_state(trie_cache, "  ")
+
+    # Create additional references
+    for i in range(ref_count - 1):
+        alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
+        allocations.append(alloc)
+        print(f"\nCreated reference {i+1}")
+        print_tree_state(trie_cache, "  ")
+
+    # Fill remaining cache
+    remaining = TEST_POOL_CAPACITY - 1
+    fill_allocations = []
+    for i in range(remaining):
+        fill_tokens = list(
+            range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE)
+        )
+        alloc = trie_cache.acquire_pages_for_tokens(fill_tokens, extra_token_slots=0)
+        alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE])
+        fill_allocations.append(alloc)
+        print(f"\nFilled cache slot {i+1}/{remaining}")
+        print_tree_state(trie_cache, "  ")
+
+    print("\nAttempting allocation that should fail...")
+    try:
+        new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE))
+        new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0)
+        print("ERROR: Allocation succeeded when it should have failed!")
+        print("\nPost-allocation state:")
+        print_tree_state(trie_cache, "  ")
+        new_alloc.release_pages()
+        pytest.fail("Expected CacheAllocationFailure was not raised")
+    except CacheAllocationFailure:
+        print("Success: CacheAllocationFailure raised as expected")
+
+    # Cleanup
+    print("\nCleaning up allocations...")
+    for alloc in allocations + fill_allocations:
+        alloc.release_pages()
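

Usage note: the new prefix_sharing_algorithm field is read from the model config
JSON into PagedKVCacheParams and dispatched on in service.py. A minimal sketch of
the relevant config fragment, reusing the values the integration test writes in
conftest.py; the literal values are the test's choices, not requirements:

    # Sketch: the "paged_kv_cache" section of the server model config.
    # "prefix_sharing_algorithm" defaults to "none" (see PagedKVCacheParams);
    # any value other than "trie" or "none" raises a ValueError in service.py.
    config = {
        "paged_kv_cache": {
            "block_seq_stride": 16,  # tokens per KV cache page
            "device_block_count": 256,  # pages allocated on each device
            "prefix_sharing_algorithm": "trie",  # or "none"
        }
    }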
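Usage note: the PageAllocation surface changes from publish_pages(up_to_page_index)
to token-based calls. A minimal lifecycle sketch against the signatures introduced
in this diff; the PagePool ("pool") is assumed to be already constructed, since
building one needs real or mocked devices (see trie_attention_cache_test.py for a
mocked setup), and the token values are illustrative only:

    from shortfin_apps.llm.components.kvcache.trie_attention_cache import (
        TriePagedAttentionCache,
    )
    from shortfin_apps.llm.components.kvcache.base_attention_cache import (
        CacheAllocationFailure,
    )

    def run_request(pool, prompt_tokens: list[int]) -> None:
        cache = TriePagedAttentionCache(page_pool=pool, tokens_per_page=16)

        # Prefill: reuse cached pages for any shared prefix, acquire the rest;
        # one extra slot leaves room for the first generated token.
        alloc = cache.acquire_pages_for_tokens(prompt_tokens, extra_token_slots=1)

        # Publish only full pages so other requests can match this prefix.
        alloc.publish_pages_for_tokens(prompt_tokens, publish_incomplete_page=False)

        # Decode: grow the same allocation instead of re-acquiring, mirroring
        # what board_decodes in service.py now does.
        tokens = list(prompt_tokens) + [42]  # hypothetical next token id
        try:
            alloc.extend_allocation(tokens, extra_token_slots=1)
        except CacheAllocationFailure:
            pass  # pool exhausted even after LRU eviction of unreferenced leaves

        # Drop the trie reference; pages become evictable rather than freed.
        alloc.release_pages()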