Commit d1e434f

Merge branch 'main' into users/stbaione/sgl-benchmark-add-baseline
stbaione authored Dec 2, 2024
2 parents ed37ef1 + 8cd3f85 commit d1e434f
Showing 8 changed files with 271 additions and 81 deletions.
4 changes: 3 additions & 1 deletion sharktank/tests/types/quantizers_test.py
@@ -72,7 +72,9 @@ def testPerAxisRoundtrip(self):
         )
         ssq = self._roundtrip(ssq, "_ssq")
         self.assertEqual(ssq.axis, 1)
-        torch.testing.assert_close(ssq.scale, torch.tensor([0.2, 0.4, 0.8]))
+        torch.testing.assert_close(
+            ssq.scale, torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32)
+        )
         torch.testing.assert_close(ssq.reciprocal_scale, torch.tensor([5.0, 2.5, 1.25]))
         self.assertIs(ssq.dtype, torch.float16)
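Aside: the dtype pin above matters because `torch.testing.assert_close` compares dtypes as well as values by default (`check_dtype=True`). A minimal sketch, not part of this commit, of the failure mode it guards against:

```python
# Minimal sketch: assert_close fails on dtype mismatch, so the expected
# tensor is pinned to float32 to stay independent of the default dtype.
import torch

actual = torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32)

torch.set_default_dtype(torch.float64)
pinned = torch.tensor([0.2, 0.4, 0.8], dtype=torch.float32)  # still float32
unpinned = torch.tensor([0.2, 0.4, 0.8])                     # now float64

torch.testing.assert_close(actual, pinned)      # passes
# torch.testing.assert_close(actual, unpinned)  # raises: dtype mismatch
```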

7 changes: 2 additions & 5 deletions shortfin/build_tools/build_linux_package.sh
@@ -36,11 +36,8 @@ REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)"
 SCRIPT_NAME="$(basename $0)"
 ARCH="$(uname -m)"
 
-# TODO(#130): Update to manylinux_2_28, upstream or a fork
-# * upstream uses a version of gcc that has build warnings/errors
-# * https://github.com/nod-ai/base-docker-images is a bit out of date but can include a recent clang
-# MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux_2_28_${ARCH}:latest}"
-MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux2014_${ARCH}:latest}"
+# Note: we can switch to https://github.com/nod-ai/base-docker-images as needed for extra deps.
+MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux_2_28_${ARCH}:latest}"
 PYTHON_VERSIONS="${OVERRIDE_PYTHON_VERSIONS:-cp311-cp311 cp312-cp312 cp313-cp313}"
 OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}"
Empty file.
Empty file.
shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py
@@ -8,9 +8,66 @@
 Base class for kv caches.
 """
 
-from typing import List
+from typing import List, Iterable, Protocol
 from .page_pool import PageInfo
 import math
+from abc import ABC, abstractmethod
+from .page_pool import PagePool
+
+# logging
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# exception for when cache allocation failed
+class CacheAllocationFailure(Exception):
+    pass
+
+
+class PageAllocation(ABC):
+    """Abstract base class for page allocations in the cache."""
+
+    @property
+    @abstractmethod
+    def pages(self) -> List[PageInfo]:
+        """Returns the list of pages that were allocated."""
+        pass
+
+    @abstractmethod
+    def publish_pages(self, up_to_page_index) -> None:
+        """Makes pages[0:up_to_page_index] available to other requests."""
+        pass
+
+    @abstractmethod
+    def release_pages(self) -> None:
+        """Releases the allocation's reference to pages."""
+        pass
+
+
+class BasePageAttentionCacheAllocation(PageAllocation):
+    """Represents a page allocation in the cache."""
+
+    def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
+        self._pages = tuple(pages)
+        self._cache = cache
+        self._is_released = False
+
+    @property
+    def pages(self) -> List[PageInfo]:
+        return list(self._pages)
+
+    def publish_pages(self, up_to_page_index) -> None:
+        pass
+
+    def release_pages(self) -> None:
+        if self._is_released:
+            logger.warning("Releasing already-released allocation")
+            return
+        self._cache.page_pool.release_pages(self._pages)
+        self._is_released = True
+
+    def __repr__(self) -> str:
+        return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"
+
+
 class BasePagedAttentionCache:
@@ -33,13 +90,13 @@ class BasePagedAttentionCache:
     - Reference counting prevents eviction of in-use pages
     """
 
-    def __init__(self, page_pool, tokens_per_page):
+    def __init__(self, page_pool: PagePool, tokens_per_page: int):
         self.page_pool = page_pool
         self.tokens_per_page = tokens_per_page
 
     def acquire_pages_for_tokens(
         self, tokens: List[int], extra_token_slots: int = 1
-    ) -> tuple[list[PageInfo], int]:
+    ) -> PageAllocation:
         """
         Given a list of tokens, return a list of pages and a start position to continue generation from.
@@ -57,24 +114,7 @@ def acquire_pages_for_tokens(
         pages_needed = math.ceil(token_count / self.tokens_per_page)
         pages = self.page_pool.acquire_free_pages(pages_needed)
 
-        n_cached_tokens = 0
-
-        return pages, n_cached_tokens
-
-    def publish_pages(self, tokens, pages) -> None:
-        """
-        Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests.
-
-        Associates the tokens with the pages, and marks them as done writing.
-
-        It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)].
-        """
-
-        pass  # the base implementation doesn't cache unfinished requests.
-
-    def release_pages(self, tokens, pages):
-        """
-        Decrement reference count for these pages. When the reference count reaches zero, they become eligible for eviction.
-        """
-        # in the base implementation, the pages can be owned by at most 1 request, so they can be released instantly
-        self.page_pool.release_pages(pages)
+        if pages is None:
+            raise CacheAllocationFailure()
+
+        return BasePageAttentionCacheAllocation(pages, cache=self)
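Net effect of this file's changes: page lifetime is now owned by a `PageAllocation` object rather than by loose page lists handed back to the cache. A hypothetical caller (a sketch with invented values; the `PagePool` named `pool` is assumed to exist and is not part of the commit) would drive the new API like this:

```python
# Sketch of the new allocation lifecycle, assuming a PagePool `pool` exists.
cache = BasePagedAttentionCache(page_pool=pool, tokens_per_page=16)

tokens = list(range(40))  # 40 tokens -> math.ceil(40 / 16) = 3 pages
try:
    allocation = cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
except CacheAllocationFailure:
    raise  # pool exhausted; a real scheduler would retry the request later

page_indices = [p.index for p in allocation.pages]

# Once the forward pass has filled the KV entries, publish the fully
# written pages (a no-op in this base class) ...
allocation.publish_pages(len(tokens) // cache.tokens_per_page)

# ... and return the pages to the pool exactly once when the request ends.
allocation.release_pages()
```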
37 changes: 12 additions & 25 deletions shortfin/python/shortfin_apps/llm/components/messages.py
@@ -9,7 +9,7 @@
 import shortfin as sf
 import shortfin.array as sfnp
 
-from .kvcache.base_attention_cache import BasePagedAttentionCache
+from .kvcache.base_attention_cache import BasePagedAttentionCache, PageAllocation
 from .kvcache.page_pool import PageInfo

@@ -43,7 +43,7 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]):

         # Cache pages that have been locked for this request.
         self._cache: BasePagedAttentionCache | None = None
-        self.locked_pages: list[PageInfo] | None = None
+        self.allocation: PageAllocation | None = None
 
     def reset(self, phase: InferencePhase):
         """Resets all per-request state in preparation for a subsequent execution."""
@@ -52,35 +52,22 @@ def reset(self, phase: InferencePhase):
         self.return_all_logits = False
         self.return_host_array = True
         self.result_logits = None
+        self.allocation.release_pages()
+        self.allocation = None
 
     def cache_page_indices(self, max_len: int) -> list[int]:
-        if not self.locked_pages:
+        if not self.allocation:
             return []
-        indices = [p.index for p in self.locked_pages]
-        if len(indices) > max_len:
-            return indices[0:max_len]
+        indices = [p.index for p in self.allocation.pages[:max_len]]
         return indices
 
+    def publish_allocated_pages(self, up_to_page_index: int):
+        assert self.allocation
+        self.allocation.publish_pages(up_to_page_index)
+
     def free_cache_pages(self):
-        cache = self._cache
-        if cache:
-            pages = self.locked_pages
-            self._cache = None
-            self.locked_pages = None
-            cache.release_pages(self.input_token_ids, pages)
-
-    def lock_initial_cache_pages(
-        self, cache: BasePagedAttentionCache, pages: list[PageInfo]
-    ):
-        assert not self._cache
-        self._cache = cache
-        self.locked_pages = pages
-
-    def lock_new_cache_pages(
-        self, cache: BasePagedAttentionCache, pages: list[PageInfo]
-    ):
-        assert self._cache is cache
-        self.locked_pages.extend(pages)
+        if self.allocation:
+            self.allocation.release_pages()


class StrobeMessage(sf.Message):
59 changes: 32 additions & 27 deletions shortfin/python/shortfin_apps/llm/components/service.py
@@ -11,8 +11,12 @@
 import shortfin as sf
 import shortfin.array as sfnp
 
-from .kvcache.base_attention_cache import BasePagedAttentionCache
-from .kvcache.page_pool import PagePoolConfig, PagePool
+from .kvcache.base_attention_cache import (
+    BasePagedAttentionCache,
+    CacheAllocationFailure,
+    PageAllocation,
+)
+from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo
 from .config_struct import ModelParams
 from .manager import SystemManager
 from .messages import InferenceExecRequest, InferencePhase, StrobeMessage
@@ -229,16 +233,17 @@ def board_prefills(self, cache: BasePagedAttentionCache):
                 len(prefill_request.input_token_ids) / self.page_seq_stride
             )
             # allocate kv cache pages
-            pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
-                prefill_request.input_token_ids,
-                extra_token_slots=0,  # prefill needs no extra kvcache slots to write to
-            )
-            if pages is None:
+            try:
+                allocation = cache.acquire_pages_for_tokens(
+                    prefill_request.input_token_ids,
+                    extra_token_slots=0,  # prefill needs no extra kvcache slots to write to
+                )
+            except CacheAllocationFailure:
                 logger.debug("Cannot fulfill request for %d pages", needed_pages)
                 continue
-            else:
-                logger.debug("Allocated %d cache pages to request", len(pages))
-                prefill_request.lock_initial_cache_pages(cache, pages)
+
+            logger.debug(f"Successfully acquired allocation: {allocation}")
+            prefill_request.free_cache_pages()
+            prefill_request.allocation = allocation
 
             # Can flight this request.
             exec_process.exec_requests.append(prefill_request)
@@ -266,26 +271,20 @@ def board_decodes(self, cache: BasePagedAttentionCache):
             if len(exec_process.exec_requests) >= self.ideal_batch_size:
                 break
-            incoming_token_count = len(decode_request.input_token_ids)
-            needed_pages = math.ceil(
-                (decode_request.start_position + incoming_token_count)
-                / self.page_seq_stride
-            )
-            if needed_pages > len(decode_request.locked_pages):
-                # allocate kv cache pages
-                pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
-                    decode_request.input_token_ids,
-                    extra_token_slots=1,  # need 1 extra slot to write result.
-                )
-                if pages is None:
-                    logger.debug(
-                        "Cannot fulfill decode request for %d pages", needed_pages
-                    )
-                    continue
-                else:
-                    logger.debug(
-                        "Allocated %d cache pages to decode request", len(pages)
-                    )
-                    decode_request.lock_new_cache_pages(cache, pages)
+
+            try:
+                allocation = cache.acquire_pages_for_tokens(
+                    decode_request.input_token_ids,
+                    extra_token_slots=1,  # need 1 extra slot to write result.
+                )
+            except CacheAllocationFailure:
+                logger.debug(
+                    "Cannot fulfill request for %d tokens",
+                    len(decode_request.input_token_ids),
+                )
+                continue
+
+            decode_request.free_cache_pages()
+            decode_request.allocation = allocation
 
             # Can flight this request.
             exec_process.exec_requests.append(decode_request)
@@ -438,6 +437,12 @@ async def run(self):
             # Invoke. Logits are of shape [bs, bsl, d].
             (logits,) = await fn(*args, fiber=self.fiber)
 
+            # publish cache pages
+            for r in self.exec_requests:
+                total_tokens = r.start_position + len(r.input_token_ids)
+                number_of_complete_pages = total_tokens // seq_stride
+                r.publish_allocated_pages(number_of_complete_pages)
+
             # Return results.
             for i in range(req_count):
                 req = self.exec_requests[i]
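Note the asymmetry in the page arithmetic: acquisition uses `math.ceil` (a partially filled page still needs a slot), while publishing uses floor division (only fully written pages are safe to share). A worked example with invented values:

```python
# Illustrative only: seq_stride and token counts are invented values.
seq_stride = 16              # tokens per KV cache page
start_position = 30          # tokens already in the cache
new_tokens = 5               # tokens written by this invocation

total_tokens = start_position + new_tokens   # 35
complete_pages = total_tokens // seq_stride  # 2: pages 0 and 1 are full
pages_held = -(-total_tokens // seq_stride)  # ceil(35/16) = 3 pages allocated

assert (complete_pages, pages_held) == (2, 3)
```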
(The diff for the one remaining changed file did not load in this view.)
