From a1946b8fe2410fcfa060ff4a1a9d0f0724a26d57 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 16:13:33 -0800 Subject: [PATCH 01/23] add unit tests and fix numerous small problems --- .../kvcache/base_attention_cache.py | 2 +- .../llm/components/kvcache/page_pool.py | 12 +- .../kvcache/trie_attention_cache.py | 325 ++++++++++++++ .../kvcache/trie_attention_cache_test.py | 395 ++++++++++++++++++ 4 files changed, 727 insertions(+), 7 deletions(-) create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py create mode 100644 shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 73134903c..c86379368 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -63,7 +63,7 @@ def release_pages(self) -> None: if self._is_released: logger.warning("Releasing already-released allocation") return - self._cache.page_pool.release_pages(self._pages) + self._cache.page_pool.free_pages(self._pages) self._is_released = True def __rerp__(self) -> str: diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py index 1686370c0..0c2cb3f67 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -86,7 +86,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig for i in range(self.config.alloc_page_count) ] - self.attn_page_free = list(self.attn_page_entries) + self.available_pages = list(self.attn_page_entries) # Initialize a page table on each device. page_table_shape = [ @@ -108,14 +108,14 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig def acquire_free_pages(self, count: int) -> list[PageInfo] | None: with self._lock: - available = len(self.attn_page_free) + available = len(self.available_pages) if count > available: return None - return [self.attn_page_free.pop() for _ in range(count)] + return [self.available_pages.pop() for _ in range(count)] - def release_pages(self, pages: list[PageInfo]): + def free_pages(self, pages: list[PageInfo]): with self._lock: - self.attn_page_free.extend(pages) + self.available_pages.extend(pages) def copy_page(self, src_page: PageInfo) -> PageInfo: """ @@ -148,7 +148,7 @@ def copy_page(self, src_page: PageInfo) -> PageInfo: def __repr__(self): # No need to lock for repr (list is internally synchronized). 
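# A minimal usage sketch of the PagePool contract settled by the renames in this
# patch (illustrative only; `device` and `config` stand in for a real
# sf.ScopedDevice and PagePoolConfig built as above):
#
#     pool = PagePool(devices=[device], config=config)
#     pages = pool.acquire_free_pages(4)
#     if pages is None:
#         ...  # acquire_free_pages signals exhaustion with None rather than
#              # raising; the caller decides whether to fail, retry, or evict
#     else:
#         pool.free_pages(pages)  # free_pages is the exact inverse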
- free_pages = len(self.attn_page_free) + free_pages = len(self.available_pages) total_pages = len(self.attn_page_entries) return ( f"PagePool({total_pages - free_pages}/{total_pages} pages in use: " diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py new file mode 100644 index 000000000..f4614c8d8 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -0,0 +1,325 @@ +from typing import Dict, Set, List, Tuple, Optional +from dataclasses import dataclass +import time +import math +import heapq +from .page_pool import PagePool, PageInfo +from .base_attention_cache import ( + BasePagedAttentionCache, + PageAllocation, + CacheAllocationFailure, +) + + +@dataclass +class TrieNode: + """Node of the block trie for paged attention cache. + + Each node represents a page of tokens in the cache, with edges representing + token sequences that can follow. This allows prefix sharing between sequences + that have common prefixes. + + Attributes: + tokens: Tuple of tokens stored in this node's page + page: PageInfo object containing the actual cache page + children: Dict mapping token sequences to child nodes + parent: Parent node in the trie (None for root) + ref_count: Number of active references to this node + access_time: Last access timestamp for LRU eviction + """ + + tokens: Tuple[int, ...] + page: PageInfo + children: Optional[Dict[Tuple[int, ...], "TrieNode"]] = None + parent: Optional["TrieNode"] = None + ref_count: int = 0 + access_time: float = 0.0 + + def __post_init__(self) -> None: + """Initialize children dict and access time if not provided.""" + if self.children is None: + self.children = {} + self.access_time = time.monotonic() + + def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode": + """Create a new child node with the given tokens and page. + + Args: + tokens: Sequence of tokens for the new node + page: PageInfo for the new node's cache page + + Returns: + The newly created child node + """ + new_node = TrieNode(tokens=tokens, page=page, parent=self) + self.children[tokens] = new_node + return new_node + + def unlink(self) -> None: + """Remove this node from its parent's children.""" + if self.parent is not None: + del self.parent.children[self.tokens] + self.parent = None + + def __hash__(self) -> int: + """Nodes are uniquely identified by their memory address.""" + return id(self) + + def __eq__(self, other: object) -> bool: + """Nodes are equal only if they are the same object.""" + return self is other + + +class TriePageAttentionCacheAllocation(PageAllocation): + """Represents a page allocation in the trie-based cache. + + Tracks both previously cached pages and newly allocated pages, + implementing the PageAllocation protocol for the trie cache. 
+ + Attributes: + cache: The parent cache this allocation belongs to + tokens: Complete sequence of tokens this allocation represents + last_cached_node: Last matched node in the trie + cached_pages: List of pages already in cache + newly_acquired_pages: List of newly allocated pages + start_index: Index where cached tokens end and new tokens begin + """ + + def __init__( + self, + cache: "TriePagedAttentionCache", + tokens: List[int], + last_cached_node: TrieNode, + cached_pages: List[PageInfo], + newly_acquired_pages: List[PageInfo], + start_index: int, + ): + self.cache = cache + self.tokens = tokens + self.last_cached_node = last_cached_node + self.cached_pages = cached_pages + self.newly_acquired_pages = newly_acquired_pages + self.start_index = start_index + self._is_released = False + + @property + def pages(self) -> List[PageInfo]: + """List all pages in this allocation, both cached and new. + + Returns: + Combined list of cached and newly acquired pages + """ + return self.cached_pages + self.newly_acquired_pages + + def publish_pages(self, up_to_page_index: int) -> None: + """Make pages available in the cache up to the specified index. + + Args: + up_to_page_index: Number of pages to publish, starting from the beginning + """ + tokens_per_page = self.cache.tokens_per_page + + publish_token_count = min(len(self.tokens), up_to_page_index * tokens_per_page) + + cur_node = self.last_cached_node + first_uncached_page_index = self.start_index // tokens_per_page + + uncached_tokens = [ + tuple(self.tokens[i : i + tokens_per_page]) + for i in range( + first_uncached_page_index * tokens_per_page, + publish_token_count, + tokens_per_page, + ) + ] + + uncached_pages = self.newly_acquired_pages[: len(uncached_tokens)] + + for token_block, page in zip(uncached_tokens, uncached_pages): + new_node = cur_node.create_child(token_block, page) + cur_node = new_node + + if cur_node is not self.cache.root: + self.cache.leaves.add(cur_node) + + cur_node.ref_count += 1 + self.last_cached_node.ref_count -= 1 + self.last_cached_node = cur_node + + def release_pages(self) -> None: + """Release the allocation's reference to its pages. + + Decrements reference count of the last cached node. When count + reaches zero, the node becomes eligible for eviction. + """ + if self._is_released: + return + + self.last_cached_node.ref_count -= 1 + self._is_released = True + + +class TriePagedAttentionCache(BasePagedAttentionCache): + """Trie-based paged attention cache implementation. + + Implements prefix sharing through a trie structure where each node + represents a page of tokens. Common prefixes between sequences share + the same nodes/pages, reducing memory usage. + + Attributes: + root: Root node of the trie + leaves: Set of leaf nodes for efficient eviction + page_pool: Pool providing page allocations + tokens_per_page: Number of tokens that fit in each page + """ + + def __init__(self, page_pool: PagePool, tokens_per_page: int): + """Initialize the trie cache. 
+ + Args: + page_pool: Pool to allocate pages from + tokens_per_page: Number of tokens per page + + Raises: + ValueError: If tokens_per_page <= 0 + """ + if tokens_per_page <= 0: + raise ValueError("tokens_per_page must be positive") + + super().__init__(page_pool, tokens_per_page) + + # Create root node with dummy page + dummy_page = PageInfo( + index=0, # Root uses reserved index 0 + pool=self.page_pool, + token_offset=0, + token_count=0, + ) + self.root = TrieNode(tokens=tuple(), page=dummy_page) + self.leaves: Set[TrieNode] = set() + + def _match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]: + """ + Find the longest prefix match in the trie. + + Walks the trie following the token sequence as far as possible, + collecting matched pages along the way. + + Args: + tokens: Sequence of tokens to match + + Returns: + Tuple of (last matched node, list of matched pages) + """ + tokens = tuple(tokens) + matched_pages = [] + cur = self.root + + for i in range(0, len(tokens), self.tokens_per_page): + token_block = tokens[i : i + self.tokens_per_page] + + if token_block not in cur.children: + break + cur = cur.children[token_block] + cur.access_time = time.monotonic() + matched_pages.append(cur.page) + + return cur, matched_pages + + def acquire_pages_for_tokens( + self, + tokens: List[int], + extra_token_slots: int = 0, + ) -> PageAllocation: + """Acquire pages for a sequence of tokens. + + Attempts to reuse existing cached pages where possible through + prefix matching, allocating new pages only for the uncached suffix. + + Args: + tokens: Sequence of tokens needing pages + extra_token_slots: Additional token slots to allocate beyond tokens + + Returns: + PageAllocation containing both cached and newly allocated pages + + Raises: + CacheAllocationFailure: If unable to allocate required pages + """ + tokens = tuple(tokens) + + cur_node, matched_pages = self._match(tokens) + cur_node.ref_count += 1 + + n_cached_tokens = len(matched_pages) * self.tokens_per_page + remaining_length = len(tokens) - n_cached_tokens + extra_token_slots + n_empty_pages = math.ceil(remaining_length / self.tokens_per_page) + + new_pages = self.page_pool.acquire_free_pages(n_empty_pages) + + if new_pages is not None: + return TriePageAttentionCacheAllocation( + cache=self, + tokens=tokens, + last_cached_node=cur_node, + cached_pages=matched_pages, + newly_acquired_pages=new_pages, + start_index=n_cached_tokens, + ) + + # Try eviction + self._evict_pages(n_empty_pages - len(self.page_pool.available_pages)) + new_pages = self.page_pool.acquire_free_pages(n_empty_pages) + + if new_pages is None: + raise CacheAllocationFailure( + "Failed to acquire pages even after attempting eviction from LRU leaves" + ) + + return TriePageAttentionCacheAllocation( + cache=self, + tokens=tokens, + last_cached_node=cur_node, + cached_pages=matched_pages, + newly_acquired_pages=new_pages, + start_index=n_cached_tokens, + ) + + def _evict_pages(self, max_pages: int) -> int: + """Evict up to max_pages pages using LRU strategy. + + Evicts from unreferenced leaf nodes first, working up the trie + as nodes become childless. 
+ + Args: + max_pages: Maximum number of pages to evict + + Returns: + Number of pages actually evicted + """ + pages_to_evict = [] + + # Initialize heap with unreferenced leaves + unused_leaf_heap = [ + (leaf.access_time, leaf) for leaf in self.leaves if leaf.ref_count == 0 + ] + heapq.heapify(unused_leaf_heap) + + # Evict least recently used nodes + while unused_leaf_heap and len(pages_to_evict) < max_pages: + _, leaf = heapq.heappop(unused_leaf_heap) + pages_to_evict.append(leaf.page) + parent = leaf.parent + leaf.unlink() + self.leaves.remove(leaf) + + # If parent becomes childless, it becomes a leaf + if parent is not self.root and not parent.children: + self.leaves.add(parent) + if parent.ref_count == 0: + heapq.heappush(unused_leaf_heap, (parent.access_time, parent)) + + if pages_to_evict: + self.page_pool.free_pages(pages_to_evict) + + return len(pages_to_evict) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py new file mode 100644 index 000000000..72f1d666e --- /dev/null +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -0,0 +1,395 @@ +import pytest +from typing import List, Tuple +import shortfin as sf +import shortfin.array as sfnp +from unittest.mock import Mock, MagicMock +import threading +import time +from dataclasses import dataclass + +from shortfin_apps.llm.components.kvcache.trie_attention_cache import ( + TriePagedAttentionCache, +) +from shortfin_apps.llm.components.kvcache.base_attention_cache import ( + CacheAllocationFailure, +) +from shortfin_apps.llm.components.kvcache.page_pool import ( + PagePool, + PageInfo, + PagePoolConfig, +) + + +# Test constants +TEST_PAGE_SIZE = 16 # Tokens per page +TEST_POOL_CAPACITY = 10 + + +@dataclass +class TokenSequence: + """Helper class for test parameterization""" + + tokens: List[int] + description: str + expected_pages: int + expected_cached: int = 0 + + def __str__(self): + return self.description + + +class MockScopedDevice: + """A proper mock for ScopedDevice that implements required interface""" + + def __init__(self): + self._mock = Mock(spec=sf.ScopedDevice) + # Add any necessary attributes/methods the real ScopedDevice has + self._mock.device_id = 0 + self._mock.device_type = "CPU" + + def __repr__(self): + return f"MockScopedDevice(device_id={self._mock.device_id})" + + +@pytest.fixture +def mock_device_array(): + """Create mock device array with proper interface implementation""" + + class MockDeviceArray: + def __init__(self): + self.shape = None + self.dtype = None + + def view(self, *args): + return Mock() + + def copy_from(self, src): + pass + + return MockDeviceArray() + + +@pytest.fixture +def mock_device(): + """Create properly structured mock device""" + return MockScopedDevice() + + +@pytest.fixture +def page_pool(mock_device, mock_device_array): + """Create PagePool with properly structured mock components""" + # Mock the device array creation + original_for_device = sf.array.device_array.for_device + + def mock_for_device(device, shape, dtype): + mock_array = mock_device_array + mock_array.shape = shape + mock_array.dtype = dtype + return mock_array + + sf.array.device_array.for_device = mock_for_device + + try: + config = PagePoolConfig( + dtype=sfnp.float16, + alloc_page_count=TEST_POOL_CAPACITY, + paged_kv_block_size_elements=128, + ) + + pool = PagePool(devices=[mock_device], config=config) + pool.page_tables = [mock_device_array] + return pool + finally: + # 
Restore original function + sf.array.device_array.for_device = original_for_device + + +@pytest.fixture +def trie_cache(page_pool): + """Create TriePagedAttentionCache instance""" + return TriePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) + + +@pytest.fixture +def published_sequence(trie_cache): + """Helper fixture that returns a function to publish token sequences""" + + def _publish_sequence(tokens: List[int]) -> None: + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages(len(alloc.pages)) + alloc.release_pages() + + return _publish_sequence + + +def print_tree_state(cache, prefix=""): + """Helper function to print current tree state in a readable format""" + if not hasattr(cache, "root"): + print(f"{prefix}Unable to access trie structure") + return + + def node_info(node): + token_str = f"tokens={list(node.tokens) if node.tokens else 'root'}" + return f"{token_str}, ref_count={node.ref_count}, page_index={node.page.index}" + + def print_node(node, depth=0): + indent = " " * depth + print(f"{prefix}{indent}- {node_info(node)}") + if node.children: + for child in node.children.values(): + print_node(child, depth + 1) + + print(f"{prefix}Tree state:") + print_node(cache.root) + + +# Test sequences for parameterization +basic_sequences = [ + TokenSequence(tokens=[], description="empty_sequence", expected_pages=0), + TokenSequence( + tokens=list(range(TEST_PAGE_SIZE // 2)), + description="partial_page", + expected_pages=1, + ), + TokenSequence( + tokens=list(range(TEST_PAGE_SIZE)), description="exact_page", expected_pages=1 + ), + TokenSequence( + tokens=list(range(TEST_PAGE_SIZE + 1)), + description="overflow_page", + expected_pages=2, + ), + TokenSequence( + tokens=list(range(TEST_PAGE_SIZE * 2)), + description="multiple_pages", + expected_pages=2, + ), +] + +reuse_sequences = [ + (list(range(TEST_PAGE_SIZE)), list(range(TEST_PAGE_SIZE)), "exact_match", 1, 1), + ( + list(range(TEST_PAGE_SIZE * 2)), + list(range(TEST_PAGE_SIZE * 2)), + "multi_page_match", + 2, + 2, + ), + ( + list(range(TEST_PAGE_SIZE * 2)), + list(range(TEST_PAGE_SIZE)) + list(range(100, 100 + TEST_PAGE_SIZE)), + "prefix_match", + 2, + 1, + ), + ( + list(range(TEST_PAGE_SIZE)), + list(range(50, 50 + TEST_PAGE_SIZE)), + "no_match", + 1, + 0, + ), +] + + +@pytest.mark.parametrize("seq", basic_sequences) +def test_basic_allocation(trie_cache, seq): + """Test basic page allocation without reuse""" + allocation = trie_cache.acquire_pages_for_tokens(seq.tokens, extra_token_slots=0) + assert len(allocation.pages) == seq.expected_pages + assert len(allocation.cached_pages) == 0 + assert len(allocation.newly_acquired_pages) == seq.expected_pages + allocation.release_pages() + + +@pytest.mark.parametrize( + "initial_tokens,reuse_tokens,description,total_pages,expected_cached", + reuse_sequences, +) +def test_page_reuse( + trie_cache, + published_sequence, + initial_tokens, + reuse_tokens, + description, + total_pages, + expected_cached, +): + """Test page reuse scenarios""" + # Publish initial sequence + published_sequence(initial_tokens) + + # Try to reuse + allocation = trie_cache.acquire_pages_for_tokens(reuse_tokens, extra_token_slots=0) + assert len(allocation.pages) == total_pages + assert len(allocation.cached_pages) == expected_cached + assert len(allocation.newly_acquired_pages) == total_pages - expected_cached + allocation.release_pages() + + +@pytest.fixture +def filled_cache(trie_cache, published_sequence): + """Fixture that fills cache with numbered 
sequences""" + sequences = [] + for i in range(TEST_POOL_CAPACITY): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + published_sequence(tokens) + sequences.append(tokens) + return sequences + + +@pytest.mark.parametrize( + "access_count", [1, TEST_POOL_CAPACITY // 2, TEST_POOL_CAPACITY - 1] +) +def test_lru_eviction(trie_cache, access_count): + """Test LRU eviction with different access patterns""" + print(f"\nStarting test_lru_eviction with access_count={access_count}") + + # Create mix of published and unpublished sequences + keep_published = 3 # Number of sequences to keep published + sequences = [] + + # First add some sequences we'll keep published + print("\nPublishing sequences to keep active:") + for i in range(keep_published): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages(1) # Don't release these - they should stay in cache + sequences.append(tokens) + print(f"Published sequence {i} (keeping active)") + print_tree_state(trie_cache, " ") + + # Then add sequences we'll publish but release (evictable) + print("\nAdding releasable sequences:") + for i in range(keep_published, TEST_POOL_CAPACITY): + tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + alloc.publish_pages(1) + alloc.release_pages() # These can be evicted + sequences.append(tokens) + print(f"Added releasable sequence {i}") + print_tree_state(trie_cache, " ") + + print("\nCache state before accessing sequences:") + print_tree_state(trie_cache, " ") + + # Access some sequences to update their LRU status + print(f"\nAccessing {access_count} sequences to update LRU order:") + for i in range(access_count): + print(f"\nAccessing sequence {i}:") + alloc = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0) + print_tree_state(trie_cache, " ") + alloc.release_pages() + print(f"After releasing allocation {i}:") + print_tree_state(trie_cache, " ") + + print("\nCache state before attempting new allocation:") + print_tree_state(trie_cache, " ") + print("\nAvailable pages in pool:", len(trie_cache.page_pool.available_pages)) + + # Try to allocate new sequence - should evict least recently used unpublished sequence + new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) + print(f"\nAttempting to allocate new sequence: {new_tokens}") + new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0) + print("\nNew allocation succeeded:") + print(f"- Allocated {len(new_alloc.pages)} new pages") + print(f"- Cached pages: {len(new_alloc.cached_pages)}") + print(f"- Newly acquired pages: {len(new_alloc.newly_acquired_pages)}") + print("\nCache state after new allocation:") + print_tree_state(trie_cache, " ") + new_alloc.release_pages() + + # Verify recently accessed sequences AND published sequences weren't evicted + print("\nVerifying preserved sequences:") + for i in range(max(access_count, keep_published)): + print(f"\nChecking sequence {i}:") + recheck = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0) + cached_pages = len(recheck.cached_pages) + print(f"- Cached pages found: {cached_pages}") + assert ( + cached_pages == 1 + ), f"Sequence {i} was evicted but should have been preserved" + recheck.release_pages() + + +@pytest.mark.parametrize("publish_steps", [1, 2, 3]) +def test_progressive_publish(trie_cache, publish_steps): + """Test publishing pages progressively""" + tokens = 
list(range(TEST_PAGE_SIZE * 3)) # Three pages + alloc = trie_cache.acquire_pages_for_tokens(tokens) + + for step in range(publish_steps): + # Publish next page + alloc.publish_pages(step + 1) + + # Verify reuse up to published point + reuse_tokens = tokens[: (step + 1) * TEST_PAGE_SIZE] + reuse_alloc = trie_cache.acquire_pages_for_tokens(reuse_tokens) + assert len(reuse_alloc.cached_pages) == step + 1 + reuse_alloc.release_pages() + + alloc.release_pages() + + +@pytest.mark.parametrize("ref_count", [1, 2, 5]) +def test_reference_counting(trie_cache, ref_count): + """Test reference counting with different counts""" + tokens = list(range(TEST_PAGE_SIZE)) + allocations = [] + + # Create initial allocation and publish + first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + first_alloc.publish_pages(1) + allocations.append(first_alloc) + print("\nInitial allocation created") + print_tree_state(trie_cache, " ") + + # Create additional references + for i in range(ref_count - 1): + alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) + allocations.append(alloc) + print(f"\nCreated reference {i+1}") + print_tree_state(trie_cache, " ") + + # Fill remaining cache + remaining = TEST_POOL_CAPACITY - 1 + fill_allocations = [] + for i in range(remaining): + fill_tokens = list( + range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE) + ) + alloc = trie_cache.acquire_pages_for_tokens(fill_tokens, extra_token_slots=0) + alloc.publish_pages(1) + fill_allocations.append(alloc) + print(f"\nFilled cache slot {i+1}/{remaining}") + print_tree_state(trie_cache, " ") + + print("\nAttempting allocation that should fail...") + try: + new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) + new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0) + print("ERROR: Allocation succeeded when it should have failed!") + print(f"- Allocated {len(new_alloc.pages)} new pages") + print(f"- Cached pages: {len(new_alloc.cached_pages)}") + print( + f"- Number of newly acquired pages: {len(new_alloc.newly_acquired_pages)}" + ) + print(f"- Newly acquired pages: {new_alloc.newly_acquired_pages}") + print("\nPost-allocation state:") + print_tree_state(trie_cache, " ") + new_alloc.release_pages() + pytest.fail("Expected CacheAllocationFailure was not raised") + except CacheAllocationFailure: + print("Success: CacheAllocationFailure raised as expected") + + # Cleanup + print("\nCleaning up allocations...") + for alloc in allocations + fill_allocations: + alloc.release_pages() + + +@pytest.mark.parametrize("tokens_per_page", [0, -1, -100]) +def test_invalid_init(page_pool, tokens_per_page): + """Test validation in __init__""" + with pytest.raises(ValueError): + TriePagedAttentionCache(page_pool=page_pool, tokens_per_page=tokens_per_page) From 8437a7f3f3f67a3186650b7a7c9956c161295ae8 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 16:26:31 -0800 Subject: [PATCH 02/23] all tests passing --- .../kvcache/trie_attention_cache.py | 5 +- .../kvcache/trie_attention_cache_test.py | 60 +++++++++++++++++-- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index f4614c8d8..ca009ed3a 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -122,7 +122,7 @@ def 
publish_pages(self, up_to_page_index: int) -> None: publish_token_count = min(len(self.tokens), up_to_page_index * tokens_per_page) cur_node = self.last_cached_node - first_uncached_page_index = self.start_index // tokens_per_page + first_uncached_page_index = len(self.cached_pages) uncached_tokens = [ tuple(self.tokens[i : i + tokens_per_page]) @@ -139,6 +139,9 @@ def publish_pages(self, up_to_page_index: int) -> None: new_node = cur_node.create_child(token_block, page) cur_node = new_node + self.cached_pages.extend(uncached_pages) + self.newly_acquired_pages = self.newly_acquired_pages[len(uncached_pages) :] + if cur_node is not self.cache.root: self.cache.leaves.add(cur_node) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index 72f1d666e..aa271dcec 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -315,20 +315,70 @@ def test_lru_eviction(trie_cache, access_count): @pytest.mark.parametrize("publish_steps", [1, 2, 3]) def test_progressive_publish(trie_cache, publish_steps): """Test publishing pages progressively""" - tokens = list(range(TEST_PAGE_SIZE * 3)) # Three pages + print(f"\nStarting test_progressive_publish with publish_steps={publish_steps}") + + tokens = tuple(range(TEST_PAGE_SIZE * 3)) # Three pages + print(f"\nInitial tokens: {tokens}") + print(f"Tokens per page: {TEST_PAGE_SIZE}") + print( + f"Expected total pages: {len(tokens) // TEST_PAGE_SIZE + (1 if len(tokens) % TEST_PAGE_SIZE else 0)}" + ) + + print("\nInitial cache state:") + print_tree_state(trie_cache) + + print("\nAcquiring initial allocation...") alloc = trie_cache.acquire_pages_for_tokens(tokens) + print(f"Initial allocation pages: {[p.index for p in alloc.pages]}") + print("\nCache state after initial allocation:") + print_tree_state(trie_cache) + + for step in range(1, publish_steps + 1): + print(f"\n--- Step {step} of {publish_steps} ---") - for step in range(publish_steps): # Publish next page - alloc.publish_pages(step + 1) + print(f"Publishing up to page {step}") + alloc.publish_pages(step) + print("\nCache state after publish:") + print_tree_state(trie_cache) # Verify reuse up to published point - reuse_tokens = tokens[: (step + 1) * TEST_PAGE_SIZE] + reuse_tokens = tokens[: (step) * TEST_PAGE_SIZE] + print(f"\nAttempting to reuse tokens: {reuse_tokens}") + print(f"Expected cached pages: {step}") + reuse_alloc = trie_cache.acquire_pages_for_tokens(reuse_tokens) - assert len(reuse_alloc.cached_pages) == step + 1 + print(f"Reuse allocation total pages: {len(reuse_alloc.pages)}") + print(f"Reuse allocation cached pages: {len(reuse_alloc.cached_pages)}") + print(f"Cached page indices: {[p.index for p in reuse_alloc.cached_pages]}") + print( + f"New page indices: {[p.index for p in reuse_alloc.newly_acquired_pages]}" + ) + + print("\nCache state after reuse attempt:") + print_tree_state(trie_cache) + + try: + assert len(reuse_alloc.cached_pages) == step + except AssertionError: + print("\nASSERTION FAILED!") + print( + f"Expected {step} cached pages but got {len(reuse_alloc.cached_pages)}" + ) + print("Cached pages details:") + for i, page in enumerate(reuse_alloc.cached_pages): + print( + f"Page {i}: index={page.index}, token_offset={page.token_offset}, token_count={page.token_count}" + ) + raise + reuse_alloc.release_pages() + print("\nCache state after releasing reuse allocation:") + 
print_tree_state(trie_cache) alloc.release_pages() + print("\nFinal cache state after releasing initial allocation:") + print_tree_state(trie_cache) @pytest.mark.parametrize("ref_count", [1, 2, 5]) From 24a4313e7e307b7a321af7a89fd29d487a248fa0 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 16:56:05 -0800 Subject: [PATCH 03/23] all but publishing working --- .../kvcache/trie_attention_cache.py | 66 ++++++++----------- .../shortfin_apps/llm/components/messages.py | 2 +- .../kvcache/trie_attention_cache_test.py | 48 ++++++-------- 3 files changed, 51 insertions(+), 65 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index ca009ed3a..bcfe0f04f 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -73,16 +73,15 @@ def __eq__(self, other: object) -> bool: class TriePageAttentionCacheAllocation(PageAllocation): """Represents a page allocation in the trie-based cache. - Tracks both previously cached pages and newly allocated pages, + Tracks sequence of pages and which ones are already published to the cache, implementing the PageAllocation protocol for the trie cache. Attributes: cache: The parent cache this allocation belongs to tokens: Complete sequence of tokens this allocation represents last_cached_node: Last matched node in the trie - cached_pages: List of pages already in cache - newly_acquired_pages: List of newly allocated pages - start_index: Index where cached tokens end and new tokens begin + pages: List of all pages in allocation + number_of_published_pages: Number of pages that are published to the cache """ def __init__( @@ -92,62 +91,57 @@ def __init__( last_cached_node: TrieNode, cached_pages: List[PageInfo], newly_acquired_pages: List[PageInfo], - start_index: int, ): self.cache = cache self.tokens = tokens self.last_cached_node = last_cached_node - self.cached_pages = cached_pages - self.newly_acquired_pages = newly_acquired_pages - self.start_index = start_index + self._pages = cached_pages + newly_acquired_pages + self.number_of_published_pages = len(cached_pages) self._is_released = False @property def pages(self) -> List[PageInfo]: - """List all pages in this allocation, both cached and new. + return self._pages - Returns: - Combined list of cached and newly acquired pages - """ - return self.cached_pages + self.newly_acquired_pages - - def publish_pages(self, up_to_page_index: int) -> None: - """Make pages available in the cache up to the specified index. + def publish_pages(self, up_to_page_index) -> None: + """Make pages available in the cache for the specified tokens. 
Args: - up_to_page_index: Number of pages to publish, starting from the beginning + tokens_to_publish: Tokens to publish to the cache + + Raises: + ValueError: If tokens don't match allocation or exceed available pages """ tokens_per_page = self.cache.tokens_per_page - publish_token_count = min(len(self.tokens), up_to_page_index * tokens_per_page) - - cur_node = self.last_cached_node - first_uncached_page_index = len(self.cached_pages) + # Create token blocks for unpublished pages + start_token = self.number_of_published_pages * tokens_per_page - uncached_tokens = [ + unpublished_tokens = [ tuple(self.tokens[i : i + tokens_per_page]) - for i in range( - first_uncached_page_index * tokens_per_page, - publish_token_count, - tokens_per_page, - ) + for i in range(start_token, tokens_per_page) ] - uncached_pages = self.newly_acquired_pages[: len(uncached_tokens)] + unpublished_pages = self._pages[ + self.number_of_published_pages : up_to_page_index + ] - for token_block, page in zip(uncached_tokens, uncached_pages): + # Add unpublished pages to trie + cur_node = self.last_cached_node + for token_block, page in zip(unpublished_tokens, unpublished_pages): new_node = cur_node.create_child(token_block, page) cur_node = new_node - self.cached_pages.extend(uncached_pages) - self.newly_acquired_pages = self.newly_acquired_pages[len(uncached_pages) :] - if cur_node is not self.cache.root: self.cache.leaves.add(cur_node) - cur_node.ref_count += 1 - self.last_cached_node.ref_count -= 1 - self.last_cached_node = cur_node + # Update reference counts + if unpublished_tokens: + cur_node.ref_count += 1 + self.last_cached_node.ref_count -= 1 + self.last_cached_node = cur_node + + self.number_of_published_pages = up_to_page_index def release_pages(self) -> None: """Release the allocation's reference to its pages. 
@@ -267,7 +261,6 @@ def acquire_pages_for_tokens( last_cached_node=cur_node, cached_pages=matched_pages, newly_acquired_pages=new_pages, - start_index=n_cached_tokens, ) # Try eviction @@ -285,7 +278,6 @@ def acquire_pages_for_tokens( last_cached_node=cur_node, cached_pages=matched_pages, newly_acquired_pages=new_pages, - start_index=n_cached_tokens, ) def _evict_pages(self, max_pages: int) -> int: diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index c03900782..bc1f851e2 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -58,7 +58,7 @@ def reset(self, phase: InferencePhase): def cache_page_indices(self, max_len: int) -> list[int]: if not self.allocation: return [] - indices = [p.index for p in self.allocation.pages[:max_len]] + indices = [p.index for p in self.allocation._pages[:max_len]] return indices def publish_allocated_pages(self, up_to_page_index: int): diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index aa271dcec..4ce71a9a7 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -197,8 +197,13 @@ def test_basic_allocation(trie_cache, seq): """Test basic page allocation without reuse""" allocation = trie_cache.acquire_pages_for_tokens(seq.tokens, extra_token_slots=0) assert len(allocation.pages) == seq.expected_pages - assert len(allocation.cached_pages) == 0 - assert len(allocation.newly_acquired_pages) == seq.expected_pages + assert allocation.number_of_published_pages == 0 + assert ( + len(allocation.pages) - allocation.number_of_published_pages + == seq.expected_pages + ) + # Replace publishing with tokens + allocation.publish_pages(len(allocation.pages)) allocation.release_pages() @@ -222,8 +227,13 @@ def test_page_reuse( # Try to reuse allocation = trie_cache.acquire_pages_for_tokens(reuse_tokens, extra_token_slots=0) assert len(allocation.pages) == total_pages - assert len(allocation.cached_pages) == expected_cached - assert len(allocation.newly_acquired_pages) == total_pages - expected_cached + assert allocation.number_of_published_pages == expected_cached + assert ( + len(allocation.pages) - allocation.number_of_published_pages + == total_pages - expected_cached + ) + # Replace publishing with tokens + allocation.publish_pages(len(allocation.pages)) allocation.release_pages() @@ -292,9 +302,6 @@ def test_lru_eviction(trie_cache, access_count): print(f"\nAttempting to allocate new sequence: {new_tokens}") new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0) print("\nNew allocation succeeded:") - print(f"- Allocated {len(new_alloc.pages)} new pages") - print(f"- Cached pages: {len(new_alloc.cached_pages)}") - print(f"- Newly acquired pages: {len(new_alloc.newly_acquired_pages)}") print("\nCache state after new allocation:") print_tree_state(trie_cache, " ") new_alloc.release_pages() @@ -304,7 +311,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(max(access_count, keep_published)): print(f"\nChecking sequence {i}:") recheck = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0) - cached_pages = len(recheck.cached_pages) + cached_pages = recheck.number_of_published_pages print(f"- Cached pages found: {cached_pages}") assert ( 
cached_pages == 1 @@ -338,6 +345,7 @@ def test_progressive_publish(trie_cache, publish_steps): # Publish next page print(f"Publishing up to page {step}") + # Replace publishing with tokens alloc.publish_pages(step) print("\nCache state after publish:") print_tree_state(trie_cache) @@ -349,27 +357,18 @@ def test_progressive_publish(trie_cache, publish_steps): reuse_alloc = trie_cache.acquire_pages_for_tokens(reuse_tokens) print(f"Reuse allocation total pages: {len(reuse_alloc.pages)}") - print(f"Reuse allocation cached pages: {len(reuse_alloc.cached_pages)}") - print(f"Cached page indices: {[p.index for p in reuse_alloc.cached_pages]}") - print( - f"New page indices: {[p.index for p in reuse_alloc.newly_acquired_pages]}" - ) + print(f"Reuse allocation cached pages: {reuse_alloc.number_of_published_pages}") print("\nCache state after reuse attempt:") print_tree_state(trie_cache) try: - assert len(reuse_alloc.cached_pages) == step + assert reuse_alloc.number_of_published_pages == step except AssertionError: print("\nASSERTION FAILED!") print( - f"Expected {step} cached pages but got {len(reuse_alloc.cached_pages)}" + f"Expected {step} cached pages but got {reuse_alloc.number_of_published_pages}" ) - print("Cached pages details:") - for i, page in enumerate(reuse_alloc.cached_pages): - print( - f"Page {i}: index={page.index}, token_offset={page.token_offset}, token_count={page.token_count}" - ) raise reuse_alloc.release_pages() @@ -389,7 +388,8 @@ def test_reference_counting(trie_cache, ref_count): # Create initial allocation and publish first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - first_alloc.publish_pages(1) + # Replace publishing with tokens + first_alloc.publish_pages(len(first_alloc.pages)) allocations.append(first_alloc) print("\nInitial allocation created") print_tree_state(trie_cache, " ") @@ -419,12 +419,6 @@ def test_reference_counting(trie_cache, ref_count): new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE)) new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0) print("ERROR: Allocation succeeded when it should have failed!") - print(f"- Allocated {len(new_alloc.pages)} new pages") - print(f"- Cached pages: {len(new_alloc.cached_pages)}") - print( - f"- Number of newly acquired pages: {len(new_alloc.newly_acquired_pages)}" - ) - print(f"- Newly acquired pages: {new_alloc.newly_acquired_pages}") print("\nPost-allocation state:") print_tree_state(trie_cache, " ") new_alloc.release_pages() From 5c87b2bc20f7a2e4fb71cef5b557ee0f9a4ccbd2 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 16:57:44 -0800 Subject: [PATCH 04/23] naming consistency Page -> Paged --- .../llm/components/kvcache/trie_attention_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index bcfe0f04f..2295f6455 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -70,7 +70,7 @@ def __eq__(self, other: object) -> bool: return self is other -class TriePageAttentionCacheAllocation(PageAllocation): +class TriePagedAttentionCacheAllocation(PageAllocation): """Represents a page allocation in the trie-based cache. 
Tracks sequence of pages and which ones are already published to the cache, @@ -255,7 +255,7 @@ def acquire_pages_for_tokens( new_pages = self.page_pool.acquire_free_pages(n_empty_pages) if new_pages is not None: - return TriePageAttentionCacheAllocation( + return TriePagedAttentionCacheAllocation( cache=self, tokens=tokens, last_cached_node=cur_node, @@ -272,7 +272,7 @@ def acquire_pages_for_tokens( "Failed to acquire pages even after attempting eviction from LRU leaves" ) - return TriePageAttentionCacheAllocation( + return TriePagedAttentionCacheAllocation( cache=self, tokens=tokens, last_cached_node=cur_node, From 07d30408bc8309338060d6ab580ed3895ffab579 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 17:13:57 -0800 Subject: [PATCH 05/23] passing all tests now --- .../kvcache/base_attention_cache.py | 4 +- .../kvcache/trie_attention_cache.py | 38 ++++++++++++++++--- .../shortfin_apps/llm/components/messages.py | 4 +- .../kvcache/trie_attention_cache_test.py | 18 ++++----- 4 files changed, 45 insertions(+), 19 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index c86379368..f571907de 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -34,7 +34,7 @@ def pages(self) -> List[PageInfo]: pass @abstractmethod - def publish_pages(self, up_to_page_index) -> None: + def publish_pages(self, tokens, publish_incomplete_pages=False) -> None: """Makes pages[0:up_to_page_index] available to other requests.""" pass @@ -56,7 +56,7 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): def pages(self) -> List[PageInfo]: return list(self._pages) - def publish_pages(self, up_to_page_index) -> None: + def publish_pages(self, tokens, publish_incomplete_pages=False) -> None: pass def release_pages(self) -> None: diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index 2295f6455..283076c33 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -103,7 +103,7 @@ def __init__( def pages(self) -> List[PageInfo]: return self._pages - def publish_pages(self, up_to_page_index) -> None: + def publish_pages(self, tokens, publish_incomplete_page=False) -> None: """Make pages available in the cache for the specified tokens. Args: @@ -112,21 +112,47 @@ def publish_pages(self, up_to_page_index) -> None: Raises: ValueError: If tokens don't match allocation or exceed available pages """ + # If we have more tokens, publish pages up to the incoming tokens. + # If incoming has more tokens, replace our tokens with incoming tokens and publish pages up to the incoming tokens. 
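# Illustrative cases for the prefix check defined just below (token values are
# hypothetical). Because zip() stops at the shorter sequence, the helper really
# tests that one sequence is a prefix of the other:
#
#     has_common_prefix([1, 2, 3], [1, 2, 3, 4])  # True: incoming extends ours
#     has_common_prefix([1, 2, 3], [1, 9, 3])     # False: mismatch at index 1
#     has_common_prefix([1, 2, 3], [])            # True: vacuously, zip is empty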
+ + def has_common_prefix(tokens1, tokens2): + for t1, t2 in zip(tokens1, tokens2): + if t1 != t2: + return False + return True + + if not has_common_prefix(self.tokens, tokens): + raise ValueError( + "Tokens provided in publish_pages do not match tokens in allocation" + ) + + if len(tokens) > len(self.tokens): + self.tokens = tokens + tokens_per_page = self.cache.tokens_per_page - # Create token blocks for unpublished pages - start_token = self.number_of_published_pages * tokens_per_page + number_of_pages_to_publish = len(tokens) / tokens_per_page + if publish_incomplete_page: + number_of_pages_to_publish = math.ceil(number_of_pages_to_publish) + else: + number_of_pages_to_publish = math.floor(number_of_pages_to_publish) + # Create token blocks for unpublished pages + start_token_index = self.number_of_published_pages * tokens_per_page unpublished_tokens = [ tuple(self.tokens[i : i + tokens_per_page]) - for i in range(start_token, tokens_per_page) + for i in range(start_token_index, len(self.tokens), tokens_per_page) ] unpublished_pages = self._pages[ - self.number_of_published_pages : up_to_page_index + self.number_of_published_pages : number_of_pages_to_publish ] # Add unpublished pages to trie + if publish_incomplete_page: + raise NotImplementedError( + "Additional work needed here to support publishing incomplete pages to ensure that we finish up a page before attaching child nodes to it." + ) cur_node = self.last_cached_node for token_block, page in zip(unpublished_tokens, unpublished_pages): new_node = cur_node.create_child(token_block, page) @@ -141,7 +167,7 @@ def publish_pages(self, up_to_page_index) -> None: self.last_cached_node.ref_count -= 1 self.last_cached_node = cur_node - self.number_of_published_pages = up_to_page_index + self.number_of_published_pages = number_of_pages_to_publish def release_pages(self) -> None: """Release the allocation's reference to its pages. 
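# A worked example of the publish arithmetic above, assuming tokens_per_page =
# 16 (matching TEST_PAGE_SIZE in the tests) and a 40-token allocation:
#
#     40 / 16 = 2.5 pages
#     publish_incomplete_page=False -> math.floor(2.5) == 2 full pages published
#     publish_incomplete_page=True  -> math.ceil(2.5)  == 3 pages, but that
#     branch currently raises NotImplementedError until a partially filled page
#     can be finished before child nodes are attached to it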
diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index bc1f851e2..2f41c9834 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -63,7 +63,9 @@ def cache_page_indices(self, max_len: int) -> list[int]: def publish_allocated_pages(self, up_to_page_index: int): assert self.allocation - self.allocation.publish_pages(up_to_page_index) + self.allocation.publish_pages( + self.input_token_ids, publish_incomplete_pages=False + ) def free_cache_pages(self): if self.allocation: diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index 4ce71a9a7..b06ea5d4e 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -116,7 +116,7 @@ def published_sequence(trie_cache): def _publish_sequence(tokens: List[int]) -> None: alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(len(alloc.pages)) + alloc.publish_pages(alloc.tokens) alloc.release_pages() return _publish_sequence @@ -202,8 +202,7 @@ def test_basic_allocation(trie_cache, seq): len(allocation.pages) - allocation.number_of_published_pages == seq.expected_pages ) - # Replace publishing with tokens - allocation.publish_pages(len(allocation.pages)) + allocation.publish_pages(allocation.tokens) allocation.release_pages() @@ -232,8 +231,7 @@ def test_page_reuse( len(allocation.pages) - allocation.number_of_published_pages == total_pages - expected_cached ) - # Replace publishing with tokens - allocation.publish_pages(len(allocation.pages)) + allocation.publish_pages(allocation.tokens) allocation.release_pages() @@ -264,7 +262,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(keep_published): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(1) # Don't release these - they should stay in cache + alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) sequences.append(tokens) print(f"Published sequence {i} (keeping active)") print_tree_state(trie_cache, " ") @@ -274,7 +272,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(keep_published, TEST_POOL_CAPACITY): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(1) + alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) alloc.release_pages() # These can be evicted sequences.append(tokens) print(f"Added releasable sequence {i}") @@ -346,7 +344,7 @@ def test_progressive_publish(trie_cache, publish_steps): # Publish next page print(f"Publishing up to page {step}") # Replace publishing with tokens - alloc.publish_pages(step) + alloc.publish_pages(alloc.tokens[: (step) * TEST_PAGE_SIZE]) print("\nCache state after publish:") print_tree_state(trie_cache) @@ -389,7 +387,7 @@ def test_reference_counting(trie_cache, ref_count): # Create initial allocation and publish first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) # Replace publishing with tokens - first_alloc.publish_pages(len(first_alloc.pages)) + first_alloc.publish_pages(first_alloc.tokens) allocations.append(first_alloc) print("\nInitial allocation created") print_tree_state(trie_cache, " ") @@ 
-409,7 +407,7 @@ def test_reference_counting(trie_cache, ref_count): range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE) ) alloc = trie_cache.acquire_pages_for_tokens(fill_tokens, extra_token_slots=0) - alloc.publish_pages(1) + alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) fill_allocations.append(alloc) print(f"\nFilled cache slot {i+1}/{remaining}") print_tree_state(trie_cache, " ") From 8f83118db1f78acc70a19e683d9844827d047242 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 17:18:38 -0800 Subject: [PATCH 06/23] name and documentation update publish_pages -> publish_pages_for_tokens --- .../components/kvcache/base_attention_cache.py | 8 +++++--- .../components/kvcache/trie_attention_cache.py | 2 +- .../shortfin_apps/llm/components/messages.py | 2 +- .../kvcache/trie_attention_cache_test.py | 16 ++++++++-------- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index f571907de..a7ee9c369 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -34,8 +34,10 @@ def pages(self) -> List[PageInfo]: pass @abstractmethod - def publish_pages(self, tokens, publish_incomplete_pages=False) -> None: - """Makes pages[0:up_to_page_index] available to other requests.""" + def publish_pages_for_tokens(self, tokens, publish_incomplete_pages=False) -> None: + """ + Makes pages available to other requests. For details, reference the derived class in trie_attention_cache.py. + """ pass @abstractmethod @@ -56,7 +58,7 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): def pages(self) -> List[PageInfo]: return list(self._pages) - def publish_pages(self, tokens, publish_incomplete_pages=False) -> None: + def publish_pages_for_tokens(self, tokens, publish_incomplete_pages=False) -> None: pass def release_pages(self) -> None: diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index 283076c33..1c9872cda 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -103,7 +103,7 @@ def __init__( def pages(self) -> List[PageInfo]: return self._pages - def publish_pages(self, tokens, publish_incomplete_page=False) -> None: + def publish_pages_for_tokens(self, tokens, publish_incomplete_page=False) -> None: """Make pages available in the cache for the specified tokens. 
Args: diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 2f41c9834..118ae2225 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -63,7 +63,7 @@ def cache_page_indices(self, max_len: int) -> list[int]: def publish_allocated_pages(self, up_to_page_index: int): assert self.allocation - self.allocation.publish_pages( + self.allocation.publish_pages_for_tokens( self.input_token_ids, publish_incomplete_pages=False ) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index b06ea5d4e..6d216c790 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -116,7 +116,7 @@ def published_sequence(trie_cache): def _publish_sequence(tokens: List[int]) -> None: alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(alloc.tokens) + alloc.publish_pages_for_tokens(alloc.tokens) alloc.release_pages() return _publish_sequence @@ -202,7 +202,7 @@ def test_basic_allocation(trie_cache, seq): len(allocation.pages) - allocation.number_of_published_pages == seq.expected_pages ) - allocation.publish_pages(allocation.tokens) + allocation.publish_pages_for_tokens(allocation.tokens) allocation.release_pages() @@ -231,7 +231,7 @@ def test_page_reuse( len(allocation.pages) - allocation.number_of_published_pages == total_pages - expected_cached ) - allocation.publish_pages(allocation.tokens) + allocation.publish_pages_for_tokens(allocation.tokens) allocation.release_pages() @@ -262,7 +262,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(keep_published): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) + alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE]) sequences.append(tokens) print(f"Published sequence {i} (keeping active)") print_tree_state(trie_cache, " ") @@ -272,7 +272,7 @@ def test_lru_eviction(trie_cache, access_count): for i in range(keep_published, TEST_POOL_CAPACITY): tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE)) alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) - alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) + alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE]) alloc.release_pages() # These can be evicted sequences.append(tokens) print(f"Added releasable sequence {i}") @@ -344,7 +344,7 @@ def test_progressive_publish(trie_cache, publish_steps): # Publish next page print(f"Publishing up to page {step}") # Replace publishing with tokens - alloc.publish_pages(alloc.tokens[: (step) * TEST_PAGE_SIZE]) + alloc.publish_pages_for_tokens(alloc.tokens[: (step) * TEST_PAGE_SIZE]) print("\nCache state after publish:") print_tree_state(trie_cache) @@ -387,7 +387,7 @@ def test_reference_counting(trie_cache, ref_count): # Create initial allocation and publish first_alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0) # Replace publishing with tokens - first_alloc.publish_pages(first_alloc.tokens) + first_alloc.publish_pages_for_tokens(first_alloc.tokens) allocations.append(first_alloc) print("\nInitial allocation created") print_tree_state(trie_cache, " ") @@ -407,7 +407,7 @@ def 
test_reference_counting(trie_cache, ref_count): range(100 + i * TEST_PAGE_SIZE, 100 + (i + 1) * TEST_PAGE_SIZE) ) alloc = trie_cache.acquire_pages_for_tokens(fill_tokens, extra_token_slots=0) - alloc.publish_pages(alloc.tokens[:TEST_PAGE_SIZE]) + alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE]) fill_allocations.append(alloc) print(f"\nFilled cache slot {i+1}/{remaining}") print_tree_state(trie_cache, " ") From 5c376667470cc6973c48897ab43a2f9979b6794c Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 17:23:27 -0800 Subject: [PATCH 07/23] undo accidental edit of messages.py --- shortfin/python/shortfin_apps/llm/components/messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 118ae2225..a8d0f871b 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -58,7 +58,7 @@ def reset(self, phase: InferencePhase): def cache_page_indices(self, max_len: int) -> list[int]: if not self.allocation: return [] - indices = [p.index for p in self.allocation._pages[:max_len]] + indices = [p.index for p in self.allocation.pages[:max_len]] return indices def publish_allocated_pages(self, up_to_page_index: int): From 802c4e28bb9e714677c150760636f95a4fc9d5be Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 27 Nov 2024 17:36:38 -0800 Subject: [PATCH 08/23] remove not-very-useful test case --- .../llm/components/kvcache/trie_attention_cache_test.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index 6d216c790..ce0025419 100644 --- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -428,10 +428,3 @@ def test_reference_counting(trie_cache, ref_count): print("\nCleaning up allocations...") for alloc in allocations + fill_allocations: alloc.release_pages() - - -@pytest.mark.parametrize("tokens_per_page", [0, -1, -100]) -def test_invalid_init(page_pool, tokens_per_page): - """Test validation in __init__""" - with pytest.raises(ValueError): - TriePagedAttentionCache(page_pool=page_pool, tokens_per_page=tokens_per_page) From 3c7cc6c061c20c236fc439132595b5960ec8a51d Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 11:29:39 -0800 Subject: [PATCH 09/23] fix missing rename --- .../apps/llm/components/kvcache/base_attention_cache_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py index 113da6912..181cc8d9f 100644 --- a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -32,7 +32,7 @@ def acquire_free_pages(self, count: int) -> List[PageInfo]: except queue.Empty: return None - def release_pages(self, pages): + def free_pages(self, pages): for page in pages: self._queue.put(page) From f59485e52dfef620a1d4dcad743b1c45634b4dba Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 11:39:19 -0800 Subject: [PATCH 10/23] use trie by default --- .../shortfin_apps/llm/components/config_struct.py | 2 ++ .../shortfin_apps/llm/components/service.py | 15 
+++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 141c7a7eb..e791b4be2 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -86,6 +86,8 @@ class PagedKVCacheParams: # Size of the cache on each device. device_block_count: int + cache_type: str = "trie" # currently supporting base and trie + @dataclass_json(undefined=Undefined.RAISE) @dataclass diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 2f942aec7..93b922c65 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -16,6 +16,7 @@ CacheAllocationFailure, PageAllocation, ) +from .kvcache.trie_attention_cache import TriePagedAttentionCache from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo from .config_struct import ModelParams from .manager import SystemManager @@ -67,10 +68,16 @@ def __init__( page_pool = PagePool( devices=self.main_fiber.devices_dict.values(), config=page_pool_config ) - self.page_cache = BasePagedAttentionCache( - page_pool=page_pool, - tokens_per_page=model_params.paged_kv_cache.block_seq_stride, - ) + if model_params.paged_kv_cache.cache_type == "trie": + self.page_cache = TriePagedAttentionCache( + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, + ) + elif model_params.paged_kv_cache.cache_type == "base": + self.page_cache = BasePagedAttentionCache( + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, + ) self.program_isolation = PROG_ISOLATIONS[program_isolation] From 646c167256812c24b286a75df78c2908d32ee1b3 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 11:40:28 -0800 Subject: [PATCH 11/23] add errmsg for unknown kvcache type --- shortfin/python/shortfin_apps/llm/components/service.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 93b922c65..8b9c39c68 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -78,6 +78,10 @@ def __init__( page_pool=page_pool, tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) + else: + raise ValueError( + f"Unknown model_params.paged_kv_cache.cache_type {model_params.paged_kv_cache.cache_type}. Currently only supporting 'trie' and 'base'." 
+            )
 
         self.program_isolation = PROG_ISOLATIONS[program_isolation]

From 5801d2765845212d434fdf077d1f8dd34a9cf6b9 Mon Sep 17 00:00:00 2001
From: Cedar
Date: Mon, 2 Dec 2024 11:51:13 -0800
Subject: [PATCH 12/23] fix typos

---
 .../llm/components/kvcache/base_attention_cache.py | 4 ++--
 shortfin/python/shortfin_apps/llm/components/messages.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py
index a7ee9c369..b189284b8 100644
--- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py
+++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py
@@ -34,7 +34,7 @@ def pages(self) -> List[PageInfo]:
         pass
 
     @abstractmethod
-    def publish_pages_for_tokens(self, tokens, publish_incomplete_pages=False) -> None:
+    def publish_pages_for_tokens(self, tokens, publish_incomplete_page=False) -> None:
         """
         Makes pages available to other requests. For details, reference the derived class in trie_attention_cache.py.
         """
@@ -58,7 +58,7 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
     def pages(self) -> List[PageInfo]:
         return list(self._pages)
 
-    def publish_pages_for_tokens(self, tokens, publish_incomplete_pages=False) -> None:
+    def publish_pages_for_tokens(self, tokens, publish_incomplete_page=False) -> None:
         pass
 
     def release_pages(self) -> None:
diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py
index a8d0f871b..8de450882 100644
--- a/shortfin/python/shortfin_apps/llm/components/messages.py
+++ b/shortfin/python/shortfin_apps/llm/components/messages.py
@@ -64,7 +64,7 @@ def cache_page_indices(self, max_len: int) -> list[int]:
     def publish_allocated_pages(self, up_to_page_index: int):
         assert self.allocation
         self.allocation.publish_pages_for_tokens(
-            self.input_token_ids, publish_incomplete_pages=False
+            self.input_token_ids, publish_incomplete_page=False
         )
 
     def free_cache_pages(self):

From 7f6e5a32da0aa60b98e585e66d6e6fe7ad04fa4b Mon Sep 17 00:00:00 2001
From: Cedar
Date: Mon, 2 Dec 2024 13:56:55 -0800
Subject: [PATCH 13/23] trie doesn't work yet; reverting to base for default

---
 shortfin/python/shortfin_apps/llm/components/config_struct.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py
index e791b4be2..eb93f017e 100644
--- a/shortfin/python/shortfin_apps/llm/components/config_struct.py
+++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -86,7 +86,7 @@ class PagedKVCacheParams:
 
     # Size of the cache on each device.
device_block_count: int - cache_type: str = "trie" # currently supporting base and trie + cache_type: str = "base" # currently supporting base and trie @dataclass_json(undefined=Undefined.RAISE) From ca29397a8f70520aeb3a8f8de6d367627a026b6c Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 13:59:32 -0800 Subject: [PATCH 14/23] rename cache_type to prefix_sharing_algorithm --- app_tests/integration_tests/llm/shortfin/conftest.py | 6 +++++- .../python/shortfin_apps/llm/components/config_struct.py | 2 +- shortfin/python/shortfin_apps/llm/components/service.py | 6 +++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py index 0d40119c7..9724485cd 100644 --- a/app_tests/integration_tests/llm/shortfin/conftest.py +++ b/app_tests/integration_tests/llm/shortfin/conftest.py @@ -83,7 +83,11 @@ def model_test_dir(request, tmp_path_factory): "prefill_batch_sizes": batch_sizes, "decode_batch_sizes": batch_sizes, "transformer_block_count": 26, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + "paged_kv_cache": { + "block_seq_stride": 16, + "device_block_count": 256, + "prefix_sharing_algorithm": "none", + }, } logger.info(f"Saving edited config to: {edited_config_path}\n") logger.info(f"Config: {json.dumps(config, indent=2)}") diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index eb93f017e..7caed5d07 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -86,7 +86,7 @@ class PagedKVCacheParams: # Size of the cache on each device. device_block_count: int - cache_type: str = "base" # currently supporting base and trie + prefix_sharing_algorithm: str = "none" # currently supporting none and trie @dataclass_json(undefined=Undefined.RAISE) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 8b9c39c68..ed1be03db 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -68,19 +68,19 @@ def __init__( page_pool = PagePool( devices=self.main_fiber.devices_dict.values(), config=page_pool_config ) - if model_params.paged_kv_cache.cache_type == "trie": + if model_params.paged_kv_cache.prefix_sharing_algorithm == "trie": self.page_cache = TriePagedAttentionCache( page_pool=page_pool, tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) - elif model_params.paged_kv_cache.cache_type == "base": + elif model_params.paged_kv_cache.prefix_sharing_algorithm == "none": self.page_cache = BasePagedAttentionCache( page_pool=page_pool, tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) else: raise ValueError( - f"Unknown model_params.paged_kv_cache.cache_type {model_params.paged_kv_cache.cache_type}. Currently only supporting 'trie' and 'base'." + f"Unknown model_params.paged_kv_cache.prefix_sharing_algorithm {model_params.paged_kv_cache.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'." 
) self.program_isolation = PROG_ISOLATIONS[program_isolation] From 7abaa7c7f8928f9d8d9a8f0acffa0b47e9acfd45 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 14:02:24 -0800 Subject: [PATCH 15/23] add another test case for trie and xfail it --- .../integration_tests/llm/shortfin/conftest.py | 1 + .../llm/shortfin/cpu_llm_server_test.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py index 9724485cd..438914b31 100644 --- a/app_tests/integration_tests/llm/shortfin/conftest.py +++ b/app_tests/integration_tests/llm/shortfin/conftest.py @@ -51,6 +51,7 @@ def model_test_dir(request, tmp_path_factory): tokenizer_id = request.param["tokenizer_id"] settings = request.param["settings"] batch_sizes = request.param["batch_sizes"] + prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"] tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test") hf_home = os.environ.get("HF_HOME", None) diff --git a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py index c4da9e4eb..77de2ad82 100644 --- a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py +++ b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py @@ -72,9 +72,24 @@ def do_generate(prompt, port): "tokenizer_id": "openlm-research/open_llama_3b_v2", "settings": CPU_SETTINGS, "batch_sizes": [1, 4], + "prefix_sharing_algorithm": "none", }, {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - ) + ), + pytest.param( + { + "repo_id": "SlyEcho/open_llama_3b_v2_gguf", + "model_file": "open-llama-3b-v2-f16.gguf", + "tokenizer_id": "openlm-research/open_llama_3b_v2", + "settings": CPU_SETTINGS, + "batch_sizes": [1, 4], + "prefix_sharing_algorithm": "trie", + }, + {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, + marks=pytest.mark.xfail( + reason="Trie-based prefix sharing not yet supported" + ), + ), ], indirect=True, ) From 1e35634b496f157546c0cc9cbc7d894af8f47e97 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 14:10:25 -0800 Subject: [PATCH 16/23] fix xpass by fixing a place where i forgot to pass prefix_sharing_algorithm in conftest.py --- app_tests/integration_tests/llm/shortfin/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py index 438914b31..f6a23960a 100644 --- a/app_tests/integration_tests/llm/shortfin/conftest.py +++ b/app_tests/integration_tests/llm/shortfin/conftest.py @@ -87,7 +87,7 @@ def model_test_dir(request, tmp_path_factory): "paged_kv_cache": { "block_seq_stride": 16, "device_block_count": 256, - "prefix_sharing_algorithm": "none", + "prefix_sharing_algorithm": prefix_sharing_algorithm, }, } logger.info(f"Saving edited config to: {edited_config_path}\n") From 1a419f120bd7a0e1ddd8eb2dc9b1d44d6ef74656 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 14:27:08 -0800 Subject: [PATCH 17/23] organize test cases --- .../kvcache/trie_attention_cache_test.py | 138 +++++++++--------- 1 file changed, 70 insertions(+), 68 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py index ce0025419..0f49efda8 100644 --- 
a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py @@ -143,93 +143,95 @@ def print_node(node, depth=0): print_node(cache.root) -# Test sequences for parameterization basic_sequences = [ - TokenSequence(tokens=[], description="empty_sequence", expected_pages=0), - TokenSequence( - tokens=list(range(TEST_PAGE_SIZE // 2)), - description="partial_page", - expected_pages=1, - ), - TokenSequence( - tokens=list(range(TEST_PAGE_SIZE)), description="exact_page", expected_pages=1 - ), - TokenSequence( - tokens=list(range(TEST_PAGE_SIZE + 1)), - description="overflow_page", - expected_pages=2, - ), - TokenSequence( - tokens=list(range(TEST_PAGE_SIZE * 2)), - description="multiple_pages", - expected_pages=2, - ), -] - -reuse_sequences = [ - (list(range(TEST_PAGE_SIZE)), list(range(TEST_PAGE_SIZE)), "exact_match", 1, 1), - ( - list(range(TEST_PAGE_SIZE * 2)), - list(range(TEST_PAGE_SIZE * 2)), - "multi_page_match", - 2, - 2, - ), - ( - list(range(TEST_PAGE_SIZE * 2)), - list(range(TEST_PAGE_SIZE)) + list(range(100, 100 + TEST_PAGE_SIZE)), - "prefix_match", - 2, - 1, - ), - ( - list(range(TEST_PAGE_SIZE)), - list(range(50, 50 + TEST_PAGE_SIZE)), - "no_match", - 1, - 0, - ), + {"tokens": [], "description": "empty_sequence", "expected_pages": 0}, + { + "tokens": list(range(TEST_PAGE_SIZE // 2)), + "description": "partial_page", + "expected_pages": 1, + }, + { + "tokens": list(range(TEST_PAGE_SIZE)), + "description": "exact_page", + "expected_pages": 1, + }, + { + "tokens": list(range(TEST_PAGE_SIZE + 1)), + "description": "overflow_page", + "expected_pages": 2, + }, + { + "tokens": list(range(TEST_PAGE_SIZE * 2)), + "description": "multiple_pages", + "expected_pages": 2, + }, ] -@pytest.mark.parametrize("seq", basic_sequences) -def test_basic_allocation(trie_cache, seq): +@pytest.mark.parametrize("test_sequence", basic_sequences) +def test_basic_allocation(trie_cache, test_sequence): """Test basic page allocation without reuse""" - allocation = trie_cache.acquire_pages_for_tokens(seq.tokens, extra_token_slots=0) - assert len(allocation.pages) == seq.expected_pages + allocation = trie_cache.acquire_pages_for_tokens( + test_sequence["tokens"], extra_token_slots=0 + ) + assert len(allocation.pages) == test_sequence["expected_pages"] assert allocation.number_of_published_pages == 0 assert ( len(allocation.pages) - allocation.number_of_published_pages - == seq.expected_pages + == test_sequence["expected_pages"] ) allocation.publish_pages_for_tokens(allocation.tokens) allocation.release_pages() -@pytest.mark.parametrize( - "initial_tokens,reuse_tokens,description,total_pages,expected_cached", - reuse_sequences, -) -def test_page_reuse( - trie_cache, - published_sequence, - initial_tokens, - reuse_tokens, - description, - total_pages, - expected_cached, -): +reuse_sequences = [ + { + "initial_tokens": list(range(TEST_PAGE_SIZE)), + "reuse_tokens": list(range(TEST_PAGE_SIZE)), + "description": "exact_match", + "total_pages": 1, + "expected_cached": 1, + }, + { + "initial_tokens": list(range(TEST_PAGE_SIZE * 2)), + "reuse_tokens": list(range(TEST_PAGE_SIZE * 2)), + "description": "multi_page_match", + "total_pages": 2, + "expected_cached": 2, + }, + { + "initial_tokens": list(range(TEST_PAGE_SIZE * 2)), + "reuse_tokens": list(range(TEST_PAGE_SIZE)) + + list(range(100, 100 + TEST_PAGE_SIZE)), + "description": "prefix_match", + "total_pages": 2, + "expected_cached": 1, + }, + { + "initial_tokens": 
list(range(TEST_PAGE_SIZE)), + "reuse_tokens": list(range(50, 50 + TEST_PAGE_SIZE)), + "description": "no_match", + "total_pages": 1, + "expected_cached": 0, + }, +] + + +@pytest.mark.parametrize("test_sequences", reuse_sequences) +def test_page_reuse(trie_cache, published_sequence, test_sequences): """Test page reuse scenarios""" # Publish initial sequence - published_sequence(initial_tokens) + published_sequence(test_sequences["initial_tokens"]) # Try to reuse - allocation = trie_cache.acquire_pages_for_tokens(reuse_tokens, extra_token_slots=0) - assert len(allocation.pages) == total_pages - assert allocation.number_of_published_pages == expected_cached + allocation = trie_cache.acquire_pages_for_tokens( + test_sequences["reuse_tokens"], extra_token_slots=0 + ) + assert len(allocation.pages) == test_sequences["total_pages"] + assert allocation.number_of_published_pages == test_sequences["expected_cached"] assert ( len(allocation.pages) - allocation.number_of_published_pages - == total_pages - expected_cached + == test_sequences["total_pages"] - test_sequences["expected_cached"] ) allocation.publish_pages_for_tokens(allocation.tokens) allocation.release_pages() From 85bc5fc785588c7d7b690dab4e059fd15701a113 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 15:38:53 -0800 Subject: [PATCH 18/23] add an extend_allocation method to cache allocations and make decode requests use that instead --- .../llm/shortfin/cpu_llm_server_test.py | 12 ++--- .../kvcache/base_attention_cache.py | 20 +++++++ .../kvcache/trie_attention_cache.py | 54 +++++++++++++++++++ .../shortfin_apps/llm/components/messages.py | 3 +- .../shortfin_apps/llm/components/service.py | 18 ++----- 5 files changed, 84 insertions(+), 23 deletions(-) diff --git a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py index 77de2ad82..3064961cf 100644 --- a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py +++ b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py @@ -65,16 +65,19 @@ def do_generate(prompt, port): @pytest.mark.parametrize( "model_test_dir,llm_server", [ - ( + pytest.param( { "repo_id": "SlyEcho/open_llama_3b_v2_gguf", "model_file": "open-llama-3b-v2-f16.gguf", "tokenizer_id": "openlm-research/open_llama_3b_v2", "settings": CPU_SETTINGS, "batch_sizes": [1, 4], - "prefix_sharing_algorithm": "none", + "prefix_sharing_algorithm": "trie", }, {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, + marks=pytest.mark.xfail( + reason="Trie-based prefix sharing not yet supported" + ), ), pytest.param( { @@ -83,12 +86,9 @@ def do_generate(prompt, port): "tokenizer_id": "openlm-research/open_llama_3b_v2", "settings": CPU_SETTINGS, "batch_sizes": [1, 4], - "prefix_sharing_algorithm": "trie", + "prefix_sharing_algorithm": "none", }, {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - marks=pytest.mark.xfail( - reason="Trie-based prefix sharing not yet supported" - ), ), ], indirect=True, diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index b189284b8..c24234d4d 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -45,6 +45,13 @@ def release_pages(self) -> None: """Releases the allocation's reference to pages.""" pass + @abstractmethod + def 
extend_allocation(self, tokens, extra_token_slots=0) -> None: + """ + Extends the allocation to include additional tokens. For details, reference the derived class in trie_attention_cache.py. + """ + pass + class BasePageAttentionCacheAllocation(PageAllocation): """Represents a page allocation in the cache.""" @@ -68,6 +75,19 @@ def release_pages(self) -> None: self._cache.page_pool.free_pages(self._pages) self._is_released = True + def extend_allocation(self, tokens, extra_token_slots=0) -> None: + # assert old tokens are a prefix of incoming tokens + # if we don't have enough pages to hold the tokens, we need to allocate more pages + token_count = len(tokens) + extra_token_slots + pages_needed = math.ceil(token_count / self._cache.tokens_per_page) + if pages_needed > len(self._pages): + new_pages = self._cache.page_pool.acquire_free_pages( + pages_needed - len(self._pages) + ) + if new_pages is None: + raise CacheAllocationFailure() + self._pages += tuple(new_pages) + def __rerp__(self) -> str: return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})" diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index 1c9872cda..aadbca167 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -181,6 +181,60 @@ def release_pages(self) -> None: self.last_cached_node.ref_count -= 1 self._is_released = True + def extend_allocation(self, tokens: List[int], extra_token_slots=0) -> None: + """Extend the current allocation to accommodate additional tokens. + + Args: + tokens: New token sequence to extend the allocation to + + Raises: + ValueError: If new tokens don't extend current allocation's tokens + """ + # Verify new tokens extend current tokens + if len(tokens) < len(self.tokens): + raise ValueError("New tokens must be longer than current tokens") + + # Check that current tokens are a prefix of new tokens + for old_token, new_token in zip(self.tokens, tokens): + if old_token != new_token: + raise ValueError("New tokens must extend current token sequence") + + # If tokens are identical, no extension needed + if len(tokens) == len(self.tokens): + return + + # Calculate how many new pages we need + tokens_per_page = self.cache.tokens_per_page + current_pages = len(self._pages) + total_tokens = len(tokens) + extra_token_slots + total_pages_needed = math.ceil(total_tokens / tokens_per_page) + new_pages_needed = total_pages_needed - current_pages + + if new_pages_needed <= 0: + self.tokens = tokens + return + + # Acquire new pages + new_pages = self.cache.page_pool.acquire_free_pages(new_pages_needed) + + if new_pages is None: + # Try eviction if initial allocation fails + self.cache._evict_pages( + new_pages_needed - len(self.cache.page_pool.available_pages) + ) + new_pages = self.cache.page_pool.acquire_free_pages(new_pages_needed) + + if new_pages is None: + raise CacheAllocationFailure( + "Failed to acquire pages for allocation extension even after attempting eviction" + ) + + # Extend our page list + self._pages.extend(new_pages) + + # Update tokens + self.tokens = tokens + class TriePagedAttentionCache(BasePagedAttentionCache): """Trie-based paged attention cache implementation. 
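
The page arithmetic shared by both extend_allocation implementations above is simple: an allocation only grows when the requested token count, plus any extra slots, spills past the pages it already holds. A standalone sketch of that calculation (new_pages_needed is a hypothetical helper, not part of this patch):

    import math

    def new_pages_needed(current_pages: int, token_count: int,
                         extra_token_slots: int, tokens_per_page: int) -> int:
        """Pages to acquire so all tokens plus scratch slots fit."""
        total_tokens = token_count + extra_token_slots
        total_pages = math.ceil(total_tokens / tokens_per_page)
        return max(0, total_pages - current_pages)

    # With 16-token pages, a 2-page allocation extended to 33 tokens plus
    # 1 extra slot needs ceil(34 / 16) = 3 pages, i.e. 1 new page.
    assert new_pages_needed(2, 33, 1, 16) == 1
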
diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 8de450882..9e2ab7179 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -52,8 +52,6 @@ def reset(self, phase: InferencePhase): self.return_all_logits = False self.return_host_array = True self.result_logits = None - self.allocation.release_pages() - self.allocation = None def cache_page_indices(self, max_len: int) -> list[int]: if not self.allocation: @@ -70,6 +68,7 @@ def publish_allocated_pages(self, up_to_page_index: int): def free_cache_pages(self): if self.allocation: self.allocation.release_pages() + self.allocation = None class StrobeMessage(sf.Message): diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index ed1be03db..5b43c1310 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -281,21 +281,9 @@ def board_decodes(self, cache: BasePagedAttentionCache): assert decode_request.phase == InferencePhase.DECODE if len(exec_process.exec_requests) >= self.ideal_batch_size: break - incoming_token_count = len(decode_request.input_token_ids) - - try: - allocation = cache.acquire_pages_for_tokens( - decode_request.input_token_ids, - extra_token_slots=1, # need 1 extra slot to write result. - ) - except CacheAllocationFailure: - logger.debug( - "Cannot fulfill request for %d tokens", - len(decode_request.input_token_ids), - ) - - decode_request.free_cache_pages() - decode_request.allocation = allocation + decode_request.allocation.extend_allocation( + decode_request.input_token_ids, extra_token_slots=1 + ) # Can flight this request. exec_process.exec_requests.append(decode_request) From feb3aa9736de09d5c441389039e8b907778d73d5 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 15:39:45 -0800 Subject: [PATCH 19/23] remove xfail because we are now PASSING! 
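
The trie configuration passes now because decode requests keep a single
allocation alive and grow it via extend_allocation (previous patch) instead
of releasing and re-acquiring pages on every step. The trie allocation only
accepts extensions that preserve the existing token prefix; that check
reduces to the comparison sketched below (validate_extension is an
illustrative stand-in for the real method):

    from typing import List

    def validate_extension(old_tokens: List[int], new_tokens: List[int]) -> None:
        # Mirrors the checks in TriePagedAttentionCacheAllocation.extend_allocation.
        if len(new_tokens) < len(old_tokens):
            raise ValueError("New tokens must be longer than current tokens")
        for old, new in zip(old_tokens, new_tokens):
            if old != new:
                raise ValueError("New tokens must extend current token sequence")

    validate_extension([1, 2, 3], [1, 2, 3, 4])  # ok: strict extension
    try:
        validate_extension([1, 2, 3], [1, 9, 3, 4])
    except ValueError:
        pass  # diverging sequences are rejected
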
--- .../integration_tests/llm/shortfin/cpu_llm_server_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py index 3064961cf..99ce2d802 100644 --- a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py +++ b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py @@ -75,9 +75,6 @@ def do_generate(prompt, port): "prefix_sharing_algorithm": "trie", }, {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - marks=pytest.mark.xfail( - reason="Trie-based prefix sharing not yet supported" - ), ), pytest.param( { From 7e07cc2ba451b8eb7a5f2be52055e197b64f365d Mon Sep 17 00:00:00 2001 From: Cedar Date: Tue, 3 Dec 2024 16:16:48 -0800 Subject: [PATCH 20/23] fix typo; force certain args to be kwargs --- .../components/kvcache/base_attention_cache.py | 18 +++++++++++------- .../components/kvcache/trie_attention_cache.py | 6 ++++-- .../kvcache/base_attention_cache_test.py | 2 +- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index c24234d4d..39e19ebea 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -34,7 +34,9 @@ def pages(self) -> List[PageInfo]: pass @abstractmethod - def publish_pages_for_tokens(self, tokens, publish_incomplete_page=False) -> None: + def publish_pages_for_tokens( + self, tokens, *, publish_incomplete_page=False + ) -> None: """ Makes pages available to other requests. For details, reference the derived class in trie_attention_cache.py. """ @@ -46,14 +48,14 @@ def release_pages(self) -> None: pass @abstractmethod - def extend_allocation(self, tokens, extra_token_slots=0) -> None: + def extend_allocation(self, tokens, *, extra_token_slots=0) -> None: """ Extends the allocation to include additional tokens. For details, reference the derived class in trie_attention_cache.py. 
""" pass -class BasePageAttentionCacheAllocation(PageAllocation): +class BasePagedAttentionCacheAllocation(PageAllocation): """Represents a page allocation in the cache.""" def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): @@ -65,7 +67,9 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): def pages(self) -> List[PageInfo]: return list(self._pages) - def publish_pages_for_tokens(self, tokens, publish_incomplete_page=False) -> None: + def publish_pages_for_tokens( + self, tokens, *, publish_incomplete_page=False + ) -> None: pass def release_pages(self) -> None: @@ -75,7 +79,7 @@ def release_pages(self) -> None: self._cache.page_pool.free_pages(self._pages) self._is_released = True - def extend_allocation(self, tokens, extra_token_slots=0) -> None: + def extend_allocation(self, tokens, *, extra_token_slots=0) -> None: # assert old tokens are a prefix of incoming tokens # if we don't have enough pages to hold the tokens, we need to allocate more pages token_count = len(tokens) + extra_token_slots @@ -89,7 +93,7 @@ def extend_allocation(self, tokens, extra_token_slots=0) -> None: self._pages += tuple(new_pages) def __rerp__(self) -> str: - return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})" + return f"BasePagedAttentionCacheAllocation(pages={self._pages}, cache={self._cache})" class BasePagedAttentionCache: @@ -139,4 +143,4 @@ def acquire_pages_for_tokens( if pages is None: raise CacheAllocationFailure() - return BasePageAttentionCacheAllocation(pages, cache=self) + return BasePagedAttentionCacheAllocation(pages, cache=self) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index aadbca167..d41e35b8b 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -103,7 +103,9 @@ def __init__( def pages(self) -> List[PageInfo]: return self._pages - def publish_pages_for_tokens(self, tokens, publish_incomplete_page=False) -> None: + def publish_pages_for_tokens( + self, tokens, *, publish_incomplete_page=False + ) -> None: """Make pages available in the cache for the specified tokens. Args: @@ -181,7 +183,7 @@ def release_pages(self) -> None: self.last_cached_node.ref_count -= 1 self._is_released = True - def extend_allocation(self, tokens: List[int], extra_token_slots=0) -> None: + def extend_allocation(self, tokens: List[int], *, extra_token_slots=0) -> None: """Extend the current allocation to accommodate additional tokens. 
Args: diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py index 181cc8d9f..8f2c4c060 100644 --- a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -10,7 +10,7 @@ from shortfin_apps.llm.components.kvcache.base_attention_cache import ( BasePagedAttentionCache, - BasePageAttentionCacheAllocation, + BasePagedAttentionCacheAllocation, CacheAllocationFailure, ) from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PageInfo From d71826fa245c6a146c3e140d4fa4ceb2d1fa6d51 Mon Sep 17 00:00:00 2001 From: Cedar Date: Tue, 3 Dec 2024 16:29:42 -0800 Subject: [PATCH 21/23] use a reference counter class to replace int for ref_count in trie node --- .../kvcache/trie_attention_cache.py | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index d41e35b8b..7690857ae 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -11,6 +11,26 @@ ) +@dataclass +class RefCount: + """ + A reference counter to replace simple int. + """ + + count: int = 0 + + def increment(self) -> int: + self.count += 1 + return self.count + + def decrement(self) -> int: + self.count -= 1 + return self.count + + def is_empty(self) -> bool: + return self.count <= 0 + + @dataclass class TrieNode: """Node of the block trie for paged attention cache. @@ -32,7 +52,7 @@ class TrieNode: page: PageInfo children: Optional[Dict[Tuple[int, ...], "TrieNode"]] = None parent: Optional["TrieNode"] = None - ref_count: int = 0 + ref_count: RefCount access_time: float = 0.0 def __post_init__(self) -> None: @@ -40,6 +60,7 @@ def __post_init__(self) -> None: if self.children is None: self.children = {} self.access_time = time.monotonic() + self.ref_count = RefCount() def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode": """Create a new child node with the given tokens and page. 
@@ -165,8 +186,8 @@ def has_common_prefix(tokens1, tokens2): # Update reference counts if unpublished_tokens: - cur_node.ref_count += 1 - self.last_cached_node.ref_count -= 1 + cur_node.ref_count.increment() + self.last_cached_node.ref_count.decrement() self.last_cached_node = cur_node self.number_of_published_pages = number_of_pages_to_publish @@ -180,7 +201,7 @@ def release_pages(self) -> None: if self._is_released: return - self.last_cached_node.ref_count -= 1 + self.last_cached_node.ref_count.decrement() self._is_released = True def extend_allocation(self, tokens: List[int], *, extra_token_slots=0) -> None: @@ -328,7 +349,7 @@ def acquire_pages_for_tokens( tokens = tuple(tokens) cur_node, matched_pages = self._match(tokens) - cur_node.ref_count += 1 + cur_node.ref_count.increment() n_cached_tokens = len(matched_pages) * self.tokens_per_page remaining_length = len(tokens) - n_cached_tokens + extra_token_slots @@ -378,7 +399,9 @@ def _evict_pages(self, max_pages: int) -> int: # Initialize heap with unreferenced leaves unused_leaf_heap = [ - (leaf.access_time, leaf) for leaf in self.leaves if leaf.ref_count == 0 + (leaf.access_time, leaf) + for leaf in self.leaves + if leaf.ref_count.is_empty() ] heapq.heapify(unused_leaf_heap) @@ -393,7 +416,7 @@ def _evict_pages(self, max_pages: int) -> int: # If parent becomes childless, it becomes a leaf if parent is not self.root and not parent.children: self.leaves.add(parent) - if parent.ref_count == 0: + if parent.ref_count.is_empty(): heapq.heappush(unused_leaf_heap, (parent.access_time, parent)) if pages_to_evict: From 871b6225bb579a1f19bbd6c6484fa72ddec70d73 Mon Sep 17 00:00:00 2001 From: Cedar Date: Tue, 3 Dec 2024 16:49:25 -0800 Subject: [PATCH 22/23] fixed problem with refcount not having a default --- .../llm/components/kvcache/trie_attention_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index 7690857ae..c2a31039d 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -52,7 +52,7 @@ class TrieNode: page: PageInfo children: Optional[Dict[Tuple[int, ...], "TrieNode"]] = None parent: Optional["TrieNode"] = None - ref_count: RefCount + ref_count: RefCount = None access_time: float = 0.0 def __post_init__(self) -> None: From c32be80dae115e39b902b9ada40223161557c2d9 Mon Sep 17 00:00:00 2001 From: Cedar Date: Tue, 3 Dec 2024 19:42:39 -0800 Subject: [PATCH 23/23] use integer ops instead of float with ceil / floor --- .../llm/components/kvcache/trie_attention_cache.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py index c2a31039d..fbb008005 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/trie_attention_cache.py @@ -154,11 +154,12 @@ def has_common_prefix(tokens1, tokens2): tokens_per_page = self.cache.tokens_per_page - number_of_pages_to_publish = len(tokens) / tokens_per_page if publish_incomplete_page: - number_of_pages_to_publish = math.ceil(number_of_pages_to_publish) + number_of_pages_to_publish = -( + len(tokens) // -tokens_per_page + ) # ceil division else: - 
number_of_pages_to_publish = math.floor(number_of_pages_to_publish) + number_of_pages_to_publish = len(tokens) // tokens_per_page # Create token blocks for unpublished pages start_token_index = self.number_of_published_pages * tokens_per_page
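
A note on the bare * added to publish_pages_for_tokens and extend_allocation
in patch 20: it makes the trailing parameters keyword-only, so boolean flags
stay self-documenting at call sites. A toy signature showing the behavior
(independent of shortfin, any Python 3 interpreter):

    def publish(tokens, *, publish_incomplete_page=False):
        return publish_incomplete_page

    assert publish([1, 2, 3], publish_incomplete_page=True) is True
    # publish([1, 2, 3], True) would raise TypeError: the flag must be named.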
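
The eviction path touched in patch 21 keeps unreferenced leaves in a min-heap
keyed by access_time, so the least recently used page is always popped first.
A toy reduction of that ordering (the sequence names are illustrative only):

    import heapq

    # (access_time, leaf_name) pairs, as in _evict_pages.
    unused_leaf_heap = [(100.0, "seq_a"), (42.0, "seq_b"), (77.0, "seq_c")]
    heapq.heapify(unused_leaf_heap)
    assert heapq.heappop(unused_leaf_heap)[1] == "seq_b"  # stalest first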
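
Finally, the double negation in the last hunk is the standard integer
ceiling-division idiom: for a positive divisor d, -(n // -d) equals
math.ceil(n / d) with no float round trip. A quick equivalence check:

    import math

    tokens_per_page = 16
    for n_tokens in range(10 * tokens_per_page + 1):
        assert -(n_tokens // -tokens_per_page) == math.ceil(n_tokens / tokens_per_page)
        assert n_tokens // tokens_per_page == math.floor(n_tokens / tokens_per_page)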