Skip to content

Commit

Permalink
rename cache_type to prefix_sharing_algorithm
Browse files (browse the repository at this point in the history)
  • Loading branch information
renxida committed Dec 2, 2024
1 parent c774c0d commit 926d483
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
6 changes: 5 additions & 1 deletion app_tests/integration_tests/llm/shortfin/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ def model_test_dir(request, tmp_path_factory):
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 26,
- "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+ "paged_kv_cache": {
+ "block_seq_stride": 16,
+ "device_block_count": 256,
+ "prefix_sharing_algorithm": "none",
+ },
}
logger.info(f"Saving edited config to: {edited_config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
Expand Down
2 changes: 1 addition & 1 deletion [file path not captured]
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class PagedKVCacheParams:
# Size of the cache on each device.
device_block_count: int

- cache_type: str = "base" # currently supporting base and trie
+ prefix_sharing_algorithm: str = "none" # currently supporting none and trie


@dataclass_json(undefined=Undefined.RAISE)
Expand Down
6 changes: 3 additions & 3 deletions shortfin/python/shortfin_apps/llm/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,19 @@ def __init__(
page_pool = PagePool(
devices=self.main_fiber.devices_dict.values(), config=page_pool_config
)
- if model_params.paged_kv_cache.cache_type == "trie":
+ if model_params.paged_kv_cache.prefix_sharing_algorithm == "trie":
self.page_cache = TriePagedAttentionCache(
page_pool=page_pool,
tokens_per_page=model_params.paged_kv_cache.block_seq_stride,
)
- elif model_params.paged_kv_cache.cache_type == "base":
+ elif model_params.paged_kv_cache.prefix_sharing_algorithm == "none":
self.page_cache = BasePagedAttentionCache(
page_pool=page_pool,
tokens_per_page=model_params.paged_kv_cache.block_seq_stride,
)
else:
raise ValueError(
- f"Unknown model_params.paged_kv_cache.cache_type {model_params.paged_kv_cache.cache_type}. Currently only supporting 'trie' and 'base'."
+ f"Unknown model_params.paged_kv_cache.prefix_sharing_algorithm {model_params.paged_kv_cache.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'."
)

self.program_isolation = PROG_ISOLATIONS[program_isolation]
Expand Down

0 comments on commit 926d483

Please sign in to comment.