Skip to content

Commit

Permalink
Implement PageAllocation as a handle into a PagedAttentionCache, allowing publishing and releasing an allocation via handle rather than cache (#608)
Browse files Browse the repository at this point in the history

Deinitialization looks wonky for now. Will test extensively to get
deinit right once I merge #600

Closes #607
  • Loading branch information
renxida authored Dec 2, 2024
1 parent b89814e commit 8cd3f85
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 75 deletions.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,66 @@
Base class for kv caches.
"""

from typing import List
from typing import List, Iterable, Protocol
from .page_pool import PageInfo
import math
from abc import ABC, abstractmethod
from .page_pool import PagePool

# logging
import logging

logger = logging.getLogger(__name__)

# exception for when cache allocation failed
class CacheAllocationFailure(Exception):
    """Raised when the page pool cannot supply the pages requested for an allocation."""

    pass


class PageAllocation(ABC):
    """Abstract base class for page allocations in the cache.

    A PageAllocation is a handle to a set of KV-cache pages: it exposes
    the allocated pages, lets the holder publish completed pages to other
    requests, and releases the pages when the holder is done with them.
    """

    @property
    @abstractmethod
    def pages(self) -> List[PageInfo]:
        """Returns the list of pages that were allocated."""
        pass

    @abstractmethod
    def publish_pages(self, up_to_page_index: int) -> None:
        """Makes pages[0:up_to_page_index] available to other requests."""
        pass

    @abstractmethod
    def release_pages(self) -> None:
        """Releases the allocation's reference to pages."""
        pass


class BasePageAttentionCacheAllocation(PageAllocation):
    """Concrete PageAllocation backed by a BasePagedAttentionCache.

    Holds an immutable tuple of pages and returns them to the owning
    cache's page pool exactly once, even if released multiple times.
    """

    def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
        # Freeze the page sequence so callers cannot mutate it after handoff.
        self._pages = tuple(pages)
        self._cache = cache
        self._is_released = False

    @property
    def pages(self) -> List[PageInfo]:
        # Return a fresh list so callers cannot alter internal state.
        return list(self._pages)

    def publish_pages(self, up_to_page_index: int) -> None:
        # The base cache does not share pages between requests, so there
        # is nothing to publish here.
        pass

    def release_pages(self) -> None:
        """Return the pages to the pool; safe to call more than once."""
        if self._is_released:
            # Double-release is tolerated but logged for debugging.
            logger.warning("Releasing already-released allocation")
            return
        self._cache.page_pool.release_pages(self._pages)
        self._is_released = True

    def __repr__(self) -> str:
        # BUGFIX: was misspelled as __rerp__, so repr() never hit this
        # method and fell back to the default object repr.
        return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"


class BasePagedAttentionCache:
Expand All @@ -33,13 +90,13 @@ class BasePagedAttentionCache:
- Reference counting prevents eviction of in-use pages
"""

def __init__(self, page_pool, tokens_per_page):
def __init__(self, page_pool: PagePool, tokens_per_page: int):
self.page_pool = page_pool
self.tokens_per_page = tokens_per_page

def acquire_pages_for_tokens(
self, tokens: List[int], extra_token_slots: int = 1
) -> tuple[list[PageInfo], int]:
) -> PageAllocation:
"""
Given a list of tokens, return a list of pages and a start position to continue generation from.
Expand All @@ -57,24 +114,7 @@ def acquire_pages_for_tokens(
pages_needed = math.ceil(token_count / self.tokens_per_page)
pages = self.page_pool.acquire_free_pages(pages_needed)

n_cached_tokens = 0

return pages, n_cached_tokens

def publish_pages(self, tokens, pages) -> None:
"""
Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests.
Associates the tokens with the pages, and mark them as done writing.
It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)].
"""

pass # the base implementation doesn't cache unfinished requests.
if pages is None:
raise CacheAllocationFailure()

def release_pages(self, tokens, pages):
"""
Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction.
"""
# in the base implementation, the pages can be owned by 1 request max, so they can be instantly release
self.page_pool.release_pages(pages)
return BasePageAttentionCacheAllocation(pages, cache=self)
37 changes: 12 additions & 25 deletions shortfin/python/shortfin_apps/llm/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import shortfin as sf
import shortfin.array as sfnp

from .kvcache.base_attention_cache import BasePagedAttentionCache
from .kvcache.base_attention_cache import BasePagedAttentionCache, PageAllocation
from .kvcache.page_pool import PageInfo


Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]):

# Cache pages that have been locked for this request.
self._cache: BasePagedAttentionCache | None = None
self.locked_pages: list[PageInfo] | None = None
self.allocation: PageAllocation | None = None

def reset(self, phase: InferencePhase):
"""Resets all per request state in preparation for an subsequent execution."""
Expand All @@ -52,35 +52,22 @@ def reset(self, phase: InferencePhase):
self.return_all_logits = False
self.return_host_array = True
self.result_logits = None
self.allocation.release_pages()
self.allocation = None

def cache_page_indices(self, max_len: int) -> list[int]:
if not self.locked_pages:
if not self.allocation:
return []
indices = [p.index for p in self.locked_pages]
if len(indices) > max_len:
return indices[0:max_len]
indices = [p.index for p in self.allocation.pages[:max_len]]
return indices

def publish_allocated_pages(self, up_to_page_index: int):
assert self.allocation
self.allocation.publish_pages(up_to_page_index)

def free_cache_pages(self):
cache = self._cache
if cache:
pages = self.locked_pages
self._cache = None
self.locked_pages = None
cache.release_pages(self.input_token_ids, pages)

def lock_initial_cache_pages(
self, cache: BasePagedAttentionCache, pages: list[PageInfo]
):
assert not self._cache
self._cache = cache
self.locked_pages = pages

def lock_new_cache_pages(
self, cache: BasePagedAttentionCache, pages: list[PageInfo]
):
assert self._cache is cache
self.locked_pages.extend(pages)
if self.allocation:
self.allocation.release_pages()


class StrobeMessage(sf.Message):
Expand Down
59 changes: 32 additions & 27 deletions shortfin/python/shortfin_apps/llm/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
import shortfin as sf
import shortfin.array as sfnp

from .kvcache.base_attention_cache import BasePagedAttentionCache
from .kvcache.page_pool import PagePoolConfig, PagePool
from .kvcache.base_attention_cache import (
BasePagedAttentionCache,
CacheAllocationFailure,
PageAllocation,
)
from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo
from .config_struct import ModelParams
from .manager import SystemManager
from .messages import InferenceExecRequest, InferencePhase, StrobeMessage
Expand Down Expand Up @@ -229,16 +233,17 @@ def board_prefills(self, cache: BasePagedAttentionCache):
len(prefill_request.input_token_ids) / self.page_seq_stride
)
# allocate kv cache pages
pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
prefill_request.input_token_ids,
extra_token_slots=0, # prefill needs no extra kvcache slots to write to
)
if pages is None:
try:
allocation = cache.acquire_pages_for_tokens(
prefill_request.input_token_ids,
extra_token_slots=0, # prefill needs no extra kvcache slots to write to
)
except CacheAllocationFailure:
logger.debug("Cannot fulfill request for %d pages", needed_pages)
continue
else:
logger.debug("Allocated %d cache pages to request", len(pages))
prefill_request.lock_initial_cache_pages(cache, pages)
logger.debug(f"Successfully acquired allocation: {allocation}")
prefill_request.free_cache_pages()
prefill_request.allocation = allocation

# Can flight this request.
exec_process.exec_requests.append(prefill_request)
Expand Down Expand Up @@ -266,26 +271,20 @@ def board_decodes(self, cache: BasePagedAttentionCache):
if len(exec_process.exec_requests) >= self.ideal_batch_size:
break
incoming_token_count = len(decode_request.input_token_ids)
needed_pages = math.ceil(
(decode_request.start_position + incoming_token_count)
/ self.page_seq_stride
)
if needed_pages > len(decode_request.locked_pages):
# allocate kv cache pages
pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(

try:
allocation = cache.acquire_pages_for_tokens(
decode_request.input_token_ids,
extra_token_slots=1, # need 1 extra slot to write result.
)
if pages is None:
logger.debug(
"Cannot fulfill decode request for %d pages", needed_pages
)
continue
else:
logger.debug(
"Allocated %d cache pages to decode request", len(pages)
)
decode_request.lock_new_cache_pages(cache, pages)
except CacheAllocationFailure:
logger.debug(
"Cannot fulfill request for %d tokens",
len(decode_request.input_token_ids),
)

decode_request.free_cache_pages()
decode_request.allocation = allocation

# Can flight this request.
exec_process.exec_requests.append(decode_request)
Expand Down Expand Up @@ -438,6 +437,12 @@ async def run(self):
# Invoke. Logits are of shape [bs, bsl, d].
(logits,) = await fn(*args, fiber=self.fiber)

# publish cache pages
for r in self.exec_requests:
total_tokens = r.start_position + len(r.input_token_ids)
number_of_complete_pages = total_tokens // seq_stride
r.publish_allocated_pages(number_of_complete_pages)

# Return results.
for i in range(req_count):
req = self.exec_requests[i]
Expand Down
Loading

0 comments on commit 8cd3f85

Please sign in to comment.