Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TriePagedAttentionCache #632

Merged
merged 23 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a1946b8
add unit tests and fix numerous small problems
renxida Nov 28, 2024
8437a7f
all tests passing
renxida Nov 28, 2024
24a4313
all but publishing working
renxida Nov 28, 2024
5c87b2b
naming consistency Page -> Paged
renxida Nov 28, 2024
07d3040
passing all tests now
renxida Nov 28, 2024
8f83118
name and documentation update publish_pages -> publish_pages_for_tokens
renxida Nov 28, 2024
5c37666
undo accidental edit of messages.py
renxida Nov 28, 2024
802c4e2
remove not-very-useful test case
renxida Nov 28, 2024
3c7cc6c
fix missing rename
renxida Dec 2, 2024
f59485e
use trie by default
renxida Dec 2, 2024
646c167
add errmsg for unknown kvcache type
renxida Dec 2, 2024
5801d27
fix typoes
renxida Dec 2, 2024
7f6e5a3
trie doesn't work yet reverting to base for default
renxida Dec 2, 2024
ca29397
rename cache_type to prefix_sharing_algorithm
renxida Dec 2, 2024
7abaa7c
add another test case for trie and xfail it
renxida Dec 2, 2024
1e35634
fix xpass by fixing a place where i forgot to pass prefix_sharing_alg…
renxida Dec 2, 2024
1a419f1
organize test cases
renxida Dec 2, 2024
85bc5fc
add an extend_allocation method to cache allocations and make decode …
renxida Dec 2, 2024
feb3aa9
remove xfail because we are now PASSING!
renxida Dec 2, 2024
7e07cc2
fix typo; force certain args to be kwargs
renxida Dec 4, 2024
d71826f
use a reference counter class to replace int for ref_count in trie node
renxida Dec 4, 2024
871b622
fixed problem with refcount not having a default
renxida Dec 4, 2024
c32be80
use integer ops instead of float with ceil / floor
renxida Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion app_tests/integration_tests/llm/shortfin/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def model_test_dir(request, tmp_path_factory):
tokenizer_id = request.param["tokenizer_id"]
settings = request.param["settings"]
batch_sizes = request.param["batch_sizes"]
prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]

tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
hf_home = os.environ.get("HF_HOME", None)
Expand Down Expand Up @@ -83,7 +84,11 @@ def model_test_dir(request, tmp_path_factory):
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 26,
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
"paged_kv_cache": {
"block_seq_stride": 16,
"device_block_count": 256,
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
logger.info(f"Saving edited config to: {edited_config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
Expand Down
16 changes: 14 additions & 2 deletions app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,28 @@ def do_generate(prompt, port):
@pytest.mark.parametrize(
"model_test_dir,llm_server",
[
(
pytest.param(
{
"repo_id": "SlyEcho/open_llama_3b_v2_gguf",
"model_file": "open-llama-3b-v2-f16.gguf",
"tokenizer_id": "openlm-research/open_llama_3b_v2",
"settings": CPU_SETTINGS,
"batch_sizes": [1, 4],
"prefix_sharing_algorithm": "trie",
},
{"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
)
),
pytest.param(
{
"repo_id": "SlyEcho/open_llama_3b_v2_gguf",
"model_file": "open-llama-3b-v2-f16.gguf",
"tokenizer_id": "openlm-research/open_llama_3b_v2",
"settings": CPU_SETTINGS,
"batch_sizes": [1, 4],
"prefix_sharing_algorithm": "none",
},
{"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
),
],
indirect=True,
)
Expand Down
2 changes: 2 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/config_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class PagedKVCacheParams:
# Size of the cache on each device.
device_block_count: int

prefix_sharing_algorithm: str = "none" # currently supporting none and trie


@dataclass_json(undefined=Undefined.RAISE)
@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,28 @@ def pages(self) -> List[PageInfo]:
pass

@abstractmethod
def publish_pages_for_tokens(
    self, tokens, *, publish_incomplete_page=False
) -> None:
    """
    Makes the pages backing `tokens` available to other requests.

    `publish_incomplete_page` presumably controls whether a trailing,
    partially-filled page is also published — confirm against the
    derived class in trie_attention_cache.py, which holds the full
    semantics of this operation.
    """
    pass

@abstractmethod
def release_pages(self) -> None:
    """Releases the allocation's reference to pages.

    After release the pages may be reclaimed by the cache for other
    requests; the allocation must not be used to access them again.
    """
    pass

@abstractmethod
def extend_allocation(self, tokens, *, extra_token_slots=0) -> None:
    """
    Extends the allocation to include additional tokens. For details, reference the derived class in trie_attention_cache.py.

    `extra_token_slots` reserves capacity beyond `len(tokens)` —
    presumably for tokens about to be generated; verify against callers.
    """
    pass


class BasePageAttentionCacheAllocation(PageAllocation):
class BasePagedAttentionCacheAllocation(PageAllocation):
"""Represents a page allocation in the cache."""

def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
Expand All @@ -56,18 +67,33 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
def pages(self) -> List[PageInfo]:
return list(self._pages)

def publish_pages_for_tokens(
    self, tokens, *, publish_incomplete_page=False
) -> None:
    """No-op in the base cache.

    The base cache does no prefix sharing, so there is nothing to make
    visible to other requests. Sharing implementations (see
    trie_attention_cache.py) override this.
    """
    pass

def release_pages(self) -> None:
    """Return this allocation's pages to the pool.

    Idempotent: a second release is a logged no-op rather than a
    double-free, since it usually indicates a caller lifecycle bug.
    """
    if self._is_released:
        logger.warning("Releasing already-released allocation")
        return
    self._cache.page_pool.free_pages(self._pages)
    self._is_released = True

def extend_allocation(self, tokens, *, extra_token_slots=0) -> None:
    """Grow the allocation to hold `tokens` plus `extra_token_slots`
    additional token positions.

    Acquires extra pages from the pool only when the current page list
    is too small; raises CacheAllocationFailure if the pool cannot
    supply them. NOTE(review): the old comment says the previously-held
    tokens are assumed to be a prefix of `tokens`, but this is not
    verified here — confirm at call sites.
    """
    token_count = len(tokens) + extra_token_slots
    # Exact integer ceiling division. Avoids math.ceil(a / b), whose
    # intermediate float can round incorrectly for very large counts.
    pages_needed = -(-token_count // self._cache.tokens_per_page)
    if pages_needed > len(self._pages):
        new_pages = self._cache.page_pool.acquire_free_pages(
            pages_needed - len(self._pages)
        )
        if new_pages is None:
            raise CacheAllocationFailure()
        self._pages += tuple(new_pages)

def __repr__(self) -> str:
    # Was misspelled "__rerp__", which silently disabled the custom repr.
    return f"BasePagedAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"


class BasePagedAttentionCache:
Expand Down Expand Up @@ -117,4 +143,4 @@ def acquire_pages_for_tokens(
if pages is None:
raise CacheAllocationFailure()

return BasePageAttentionCacheAllocation(pages, cache=self)
return BasePagedAttentionCacheAllocation(pages, cache=self)
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
for i in range(self.config.alloc_page_count)
]

self.attn_page_free = list(self.attn_page_entries)
self.available_pages = list(self.attn_page_entries)

# Initialize a page table on each device.
page_table_shape = [
Expand All @@ -108,14 +108,14 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig

def acquire_free_pages(self, count: int) -> list[PageInfo] | None:
    """Atomically take `count` pages from the free list.

    Returns None (and takes nothing) when fewer than `count` pages are
    available, so callers never observe a partial allocation.
    """
    with self._lock:
        if count > len(self.available_pages):
            return None
        return [self.available_pages.pop() for _ in range(count)]

def free_pages(self, pages: list[PageInfo]) -> None:
    """Return `pages` to the free list, making them available for
    future acquire_free_pages calls."""
    with self._lock:
        self.available_pages.extend(pages)

def copy_page(self, src_page: PageInfo) -> PageInfo:
"""
Expand Down Expand Up @@ -148,7 +148,7 @@ def copy_page(self, src_page: PageInfo) -> PageInfo:

def __repr__(self):
# No need to lock for repr (list is internally synchronized).
free_pages = len(self.attn_page_free)
free_pages = len(self.available_pages)
total_pages = len(self.attn_page_entries)
return (
f"PagePool({total_pages - free_pages}/{total_pages} pages in use: "
Expand Down
Loading
Loading