From c3efa257f5c699f0e45a368c6417a8f9879dbbb2 Mon Sep 17 00:00:00 2001
From: rickyx
Date: Wed, 20 Nov 2024 22:32:58 +0000
Subject: [PATCH 1/6] new

Signed-off-by: rickyx
---
 .buildkite/test-pipeline.yaml              |  2 +-
 tests/v1/engine/test_async_llm.py          |  3 ++
 tests/v1/engine/test_engine_args.py        | 19 ++++++++
 tests/v1/engine/test_engine_core.py        |  3 +-
 tests/v1/engine/test_engine_core_client.py |  3 +-
 vllm/engine/arg_utils.py                   | 53 +++++++++++++++++++++-
 vllm/engine/llm_engine.py                  |  2 +-
 vllm/v1/engine/async_llm.py                |  2 +-
 vllm/v1/engine/core.py                     | 13 ------
 vllm/v1/engine/llm_engine.py               |  2 +-
 10 files changed, 82 insertions(+), 20 deletions(-)
 create mode 100644 tests/v1/engine/test_engine_args.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 24bf223fb12c0..545b253c07db0 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -171,7 +171,7 @@ steps:
   - vllm/
   - tests/v1
   commands:
-    - pytest -v -s v1
+    - VLLM_USE_V1=1 pytest -v -s v1
 
 - label: Examples Test # 15min
   working_dir: "/vllm-workspace/examples"
diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py
index 1f26fe0fc892f..fffb5b8100ec7 100644
--- a/tests/v1/engine/test_async_llm.py
+++ b/tests/v1/engine/test_async_llm.py
@@ -32,6 +32,9 @@ async def generate(engine: AsyncLLM, request_id: str,
 
 @pytest.mark.asyncio
 async def test_load(monkeypatch):
+    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
+    # so that in the future when we switch, we don't have to change all the
+    # tests.
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py
new file mode 100644
index 0000000000000..b9ec030b4a052
--- /dev/null
+++ b/tests/v1/engine/test_engine_args.py
@@ -0,0 +1,19 @@
+import pytest
+
+from vllm import envs
+from vllm.engine.arg_utils import EngineArgs
+
+if not envs.VLLM_USE_V1:
+    pytest.skip(
+        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+        allow_module_level=True,
+    )
+
+
+def test_v1_defaults():
+    engine_args = EngineArgs(model="facebook/opt-125m")
+
+    # Assert V1 defaults
+    assert engine_args.enable_prefix_caching
+    assert engine_args.max_num_seqs == 1024
+    assert engine_args.max_num_batched_tokens is None
diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index b3692b594326a..bd11ff1877064 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -43,7 +43,8 @@ def test_engine_core(monkeypatch):
         m.setenv("VLLM_USE_V1", "1")
         """Setup the EngineCore."""
         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
 
         engine_core = EngineCore(vllm_config=vllm_config,
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 7b241bf836a0e..2cf2b786e12f0 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -153,7 +153,8 @@ async def test_engine_core_client_asyncio(monkeypatch):
         m.setenv("VLLM_USE_V1", "1")
 
         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
         client = EngineCoreClient.make_client(
             vllm_config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index a3ae1889774f3..5d88052c38fed 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -20,6 +20,7 @@
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform
 from vllm.transformers_utils.utils import check_gguf_file
+from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, StoreBoolean
 
 if TYPE_CHECKING:
@@ -936,7 +937,9 @@ def create_load_config(self) -> LoadConfig:
             ignore_patterns=self.ignore_patterns,
         )
 
-    def create_engine_config(self) -> VllmConfig:
+    def create_engine_config(self,
+                             usage_context: Optional[UsageContext] = None
+                             ) -> VllmConfig:
         # gguf file needs a specific model loader and doesn't use hf_repo
         if check_gguf_file(self.model):
             self.quantization = self.load_format = "gguf"
@@ -1162,6 +1165,54 @@ def create_engine_config(self) -> VllmConfig:
         )
 
 
+@dataclass
+class EngineArgsV1(EngineArgs):
+    """Arguments for vLLM engine v1."""
+
+    # V1's default values that differ from the default values in EngineArgs.
+    # This allows to switch between V1 and V0's default behaviour transparently.
+    enable_prefix_caching: bool = True
+    max_num_seqs: int = 1024
+    max_num_batched_tokens: Optional[int] = None
+
+    @staticmethod
+    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+        parser = EngineArgs.add_cli_args(parser)
+        return parser
+
+    def create_engine_config(self,
+                             usage_context: Optional[UsageContext] = None
+                             ) -> VllmConfig:
+        assert (usage_context
+                is not None), "usage_context must be provided for EngineArgsV1"
+
+        if self.max_num_batched_tokens is None:
+            if usage_context == UsageContext.LLM_CLASS:
+                logger.warning("Setting max_num_batched_tokens to 8192 "
+                               "for LLM_CLASS usage context.")
+                self.max_num_batched_tokens = 8192
+            elif usage_context == UsageContext.OPENAI_API_SERVER:
+                logger.warning("Setting max_num_batched_tokens to 2048 "
+                               "for OPENAI_API_SERVER usage context.")
+                self.max_num_batched_tokens = 2048
+
+        engine_config = super().create_engine_config(usage_context)
+
+        # TODO (ywang96): Enable APC by default when VLM supports it.
+        if engine_config.model_config.is_multimodal_model:
+            logger.warning(
+                "Prefix caching is currently not supported for multimodal "
+                "models and has been disabled.")
+            engine_config.cache_config.enable_prefix_caching = False
+        return engine_config
+
+
+if envs.VLLM_USE_V1:
+    # Overwrite EngineArgs to use EngineArgsV1
+    # This has to be done before `AsyncEngineArgs` is imported.
+    EngineArgs = EngineArgsV1  # type: ignore
+
+
 @dataclass
 class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous vLLM engine."""
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 2a5eaf1340762..8cf5c4c308c47 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -579,7 +579,7 @@ def from_engine_args(
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(usage_context)
         executor_class = cls._get_executor_cls(engine_config)
         # Create the LLM engine.
         engine = cls(
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 09bff9655a882..b5428bc82f742 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -94,7 +94,7 @@ def from_engine_args(
 
         # Create the engine configs.
         if engine_config is None:
-            vllm_config = engine_args.create_engine_config()
+            vllm_config = engine_args.create_engine_config(usage_context)
         else:
             vllm_config = engine_config
 
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 35ed131d50de9..495c4e3222649 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -39,19 +39,6 @@ def __init__(
         executor_class: Type[GPUExecutor],
         usage_context: UsageContext,
     ):
-        # Override the configs for V1.
-        # FIXME
-        if usage_context == UsageContext.LLM_CLASS:
-            vllm_config.scheduler_config.max_num_seqs = 1024
-            vllm_config.scheduler_config.max_num_batched_tokens = 8192
-        elif usage_context == UsageContext.OPENAI_API_SERVER:
-            vllm_config.scheduler_config.max_num_seqs = 1024
-            vllm_config.scheduler_config.max_num_batched_tokens = 2048
-
-        # TODO (ywang96): Enable APC by default when VLM supports it.
-        if not vllm_config.model_config.is_multimodal_model:
-            vllm_config.cache_config.enable_prefix_caching = True
-
         assert vllm_config.model_config.task != "embedding"
 
         logger.info("Initializing an LLM engine (v%s) with config: %s",
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 75a77be750acd..7a5482f03b6fa 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -82,7 +82,7 @@ def from_engine_args(
         """Creates an LLM engine from the engine arguments."""
 
         # Create the engine configs.
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(usage_context)
         executor_class = cls._get_executor_cls(vllm_config)
 
         if VLLM_ENABLE_V1_MULTIPROCESSING:

From 376a7d2d10cffbca22227a5dcdf58432472e8a50 Mon Sep 17 00:00:00 2001
From: rickyx
Date: Thu, 21 Nov 2024 23:54:53 +0000
Subject: [PATCH 2/6] fix

Signed-off-by: rickyx
---
 tests/v1/engine/test_engine_args.py        | 28 +++++++++++++++++++---
 tests/v1/engine/test_engine_core_client.py |  3 ++-
 vllm/engine/arg_utils.py                   | 14 +++++++----
 3 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py
index b9ec030b4a052..7876c487fc14c 100644
--- a/tests/v1/engine/test_engine_args.py
+++ b/tests/v1/engine/test_engine_args.py
@@ -1,7 +1,9 @@
 import pytest
 
 from vllm import envs
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
+from vllm.usage.usage_lib import UsageContext
 
 if not envs.VLLM_USE_V1:
     pytest.skip(
@@ -10,10 +12,30 @@
     )
 
 
-def test_v1_defaults():
+def test_defaults():
     engine_args = EngineArgs(model="facebook/opt-125m")
 
     # Assert V1 defaults
     assert engine_args.enable_prefix_caching
-    assert engine_args.max_num_seqs == 1024
-    assert engine_args.max_num_batched_tokens is None
+
+
+def test_defaults_with_usage_context():
+    engine_args = EngineArgs(model="facebook/opt-125m")
+    vllm_config: VllmConfig = engine_args.create_engine_config(
+        UsageContext.LLM_CLASS)
+
+    assert vllm_config.scheduler_config.max_num_seqs == 1024
+    assert vllm_config.scheduler_config.max_num_batched_tokens == 8192
+
+    engine_args = EngineArgs(model="facebook/opt-125m")
+    vllm_config = engine_args.create_engine_config(
+        UsageContext.OPENAI_API_SERVER)
+    assert vllm_config.scheduler_config.max_num_seqs == 1024
+    assert vllm_config.scheduler_config.max_num_batched_tokens == 2048
+
+
+def test_prefix_cache_disabled_with_multimodel():
+    engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
+
+    vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
+    assert not vllm_config.cache_config.enable_prefix_caching
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 2cf2b786e12f0..ee7c64138264c 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -82,7 +82,8 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
         m.setenv("VLLM_USE_V1", "1")
 
         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
         client = EngineCoreClient.make_client(
             vllm_config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 5d88052c38fed..043370a74104e 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -84,7 +84,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
 
 @dataclass
-class EngineArgs:
+class _EngineArgs:
     """Arguments for vLLM engine."""
     model: str = 'facebook/opt-125m'
     served_model_name: Optional[Union[str, List[str]]] = None
@@ -1166,18 +1166,16 @@ def create_engine_config(self,
 
 
 @dataclass
-class EngineArgsV1(EngineArgs):
+class EngineArgsV1(_EngineArgs):
     """Arguments for vLLM engine v1."""
 
     # V1's default values that differ from the default values in EngineArgs.
     # This allows to switch between V1 and V0's default behaviour transparently.
     enable_prefix_caching: bool = True
-    max_num_seqs: int = 1024
-    max_num_batched_tokens: Optional[int] = None
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
-        parser = EngineArgs.add_cli_args(parser)
+        parser = _EngineArgs.add_cli_args(parser)
         return parser
 
     def create_engine_config(self,
@@ -1187,13 +1185,17 @@ def create_engine_config(self,
                 is not None), "usage_context must be provided for EngineArgsV1"
 
         if self.max_num_batched_tokens is None:
+            # When no user override, set the default values based on the
+            # usage context.
             if usage_context == UsageContext.LLM_CLASS:
                 logger.warning("Setting max_num_batched_tokens to 8192 "
                                "for LLM_CLASS usage context.")
+                self.max_num_seqs = 1024
                 self.max_num_batched_tokens = 8192
             elif usage_context == UsageContext.OPENAI_API_SERVER:
                 logger.warning("Setting max_num_batched_tokens to 2048 "
                                "for OPENAI_API_SERVER usage context.")
+                self.max_num_seqs = 1024
                 self.max_num_batched_tokens = 2048
 
         engine_config = super().create_engine_config(usage_context)
@@ -1207,6 +1209,8 @@ def create_engine_config(self,
         return engine_config
 
 
+EngineArgs = _EngineArgs  # type: ignore
+
 if envs.VLLM_USE_V1:
     # Overwrite EngineArgs to use EngineArgsV1
     # This has to be done before `AsyncEngineArgs` is imported.

From b15d4f9abdd1295a12dfb284906dd37faf11f299 Mon Sep 17 00:00:00 2001
From: rickyx
Date: Fri, 22 Nov 2024 08:25:01 +0000
Subject: [PATCH 3/6] up

Signed-off-by: rickyx
---
 vllm/engine/async_llm_engine.py       | 2 +-
 vllm/engine/multiprocessing/engine.py | 2 +-
 vllm/entrypoints/openai/api_server.py | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 5a5388708b1c6..3224577c567f8 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -680,7 +680,7 @@ def from_engine_args(
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         if engine_config is None:
-            engine_config = engine_args.create_engine_config()
+            engine_config = engine_args.create_engine_config(usage_context)
 
         executor_class = cls._get_executor_cls(engine_config)
 
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index 7de23643a2e1c..49a90b321dac4 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -111,7 +111,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,
         from vllm.plugins import load_general_plugins
         load_general_plugins()
 
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(usage_context)
         executor_class = LLMEngine._get_executor_cls(engine_config)
 
         use_async_sockets = engine_config.model_config.use_async_output_proc
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index b0fe061f5db4a..0751a60f524a3 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -133,8 +133,8 @@ async def build_async_engine_client_from_engine_args(
     # TODO: fill out feature matrix.
     if (MQLLMEngineClient.is_unsupported_config(engine_args)
             or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
-
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(
+            UsageContext.OPENAI_API_SERVER)
         uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
                            "uses_ray", False)
 

From a102588e87409b0fd4acd9336bcd2ddcb3a69654 Mon Sep 17 00:00:00 2001
From: rickyx
Date: Sun, 24 Nov 2024 01:04:18 +0000
Subject: [PATCH 4/6] comments

Signed-off-by: rickyx
---
 vllm/engine/arg_utils.py | 50 +++++++++++++++-------------------------
 1 file changed, 19 insertions(+), 31 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 043370a74104e..e4c6020518e05 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -84,7 +84,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
 
 @dataclass
-class _EngineArgs:
+class EngineArgs:
     """Arguments for vLLM engine."""
     model: str = 'facebook/opt-125m'
     served_model_name: Optional[Union[str, List[str]]] = None
@@ -114,7 +114,7 @@ class _EngineArgs:
     # NOTE(kzawora): default block size for Gaudi should be 128
     # smaller sizes still work, but very inefficiently
     block_size: int = 16 if not current_platform.is_hpu() else 128
-    enable_prefix_caching: bool = False
+    enable_prefix_caching: bool = bool(envs.VLLM_USE_V1)
     disable_sliding_window: bool = False
     use_v2_block_manager: bool = True
     swap_space: float = 4  # GiB
@@ -940,6 +940,9 @@ def create_load_config(self) -> LoadConfig:
     def create_engine_config(self,
                              usage_context: Optional[UsageContext] = None
                              ) -> VllmConfig:
+        if envs.VLLM_USE_V1:
+            self._override_v1_args(usage_context)
+
         # gguf file needs a specific model loader and doesn't use hf_repo
         if check_gguf_file(self.model):
             self.quantization = self.load_format = "gguf"
@@ -1149,7 +1152,7 @@ def create_engine_config(self,
             or "all" in detailed_trace_modules,
         )
 
-        return VllmConfig(
+        config = VllmConfig(
             model_config=model_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
@@ -1164,25 +1167,15 @@ def create_engine_config(self,
             compilation_config=self.compilation_config,
         )
 
+        if envs.VLLM_USE_V1:
+            config = self._override_v1_configs(config)
+        return config
 
-@dataclass
-class EngineArgsV1(_EngineArgs):
-    """Arguments for vLLM engine v1."""
-
-    # V1's default values that differ from the default values in EngineArgs.
-    # This allows to switch between V1 and V0's default behaviour transparently.
-    enable_prefix_caching: bool = True
-
-    @staticmethod
-    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
-        parser = _EngineArgs.add_cli_args(parser)
-        return parser
-
-    def create_engine_config(self,
-                             usage_context: Optional[UsageContext] = None
-                             ) -> VllmConfig:
-        assert (usage_context
-                is not None), "usage_context must be provided for EngineArgsV1"
+    def _override_v1_args(self, usage_context: UsageContext):
+        """
+        Override the EngineArgs's args based on the usage context for V1.
+        """
+        assert envs.VLLM_USE_V1, "V1 is not enabled"
 
         if self.max_num_batched_tokens is None:
             # When no user override, set the default values based on the
@@ -1198,8 +1191,11 @@ def create_engine_config(self,
                                "for OPENAI_API_SERVER usage context.")
                 self.max_num_seqs = 1024
                 self.max_num_batched_tokens = 2048
-        engine_config = super().create_engine_config(usage_context)
-
+    def _override_v1_configs(self, engine_config: VllmConfig):
+        """
+        Override the EngineConfig's configs based on the usage context for V1.
+        """
+        assert envs.VLLM_USE_V1, "V1 is not enabled"
         # TODO (ywang96): Enable APC by default when VLM supports it.
         if engine_config.model_config.is_multimodal_model:
             logger.warning(

From 26a654071864f1f1c15d7f343cbbb887140cfcf2 Mon Sep 17 00:00:00 2001
From: rickyx
Date: Sun, 24 Nov 2024 01:11:13 +0000
Subject: [PATCH 5/6] nits

Signed-off-by: rickyx
---
 vllm/engine/arg_utils.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index e4c6020518e05..03043d00ab1be 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -941,7 +941,7 @@ def create_engine_config(self,
                              usage_context: Optional[UsageContext] = None
                              ) -> VllmConfig:
         if envs.VLLM_USE_V1:
-            self._override_v1_args(usage_context)
+            self._override_v1_engine_args(usage_context)
         # gguf file needs a specific model loader and doesn't use hf_repo
         if check_gguf_file(self.model):
             self.quantization = self.load_format = "gguf"
@@ -1168,10 +1168,10 @@ def create_engine_config(self,
         )
 
         if envs.VLLM_USE_V1:
-            config = self._override_v1_configs(config)
+            self._override_v1_engine_config(config)
         return config
 
-    def _override_v1_args(self, usage_context: UsageContext):
+    def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
         """
         Override the EngineArgs's args based on the usage context for V1.
         """
@@ -1191,7 +1191,7 @@ def _override_v1_args(self, usage_context: UsageContext):
                 self.max_num_seqs = 1024
                 self.max_num_batched_tokens = 2048
 
-    def _override_v1_configs(self, engine_config: VllmConfig):
+    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
         """
         Override the EngineConfig's configs based on the usage context for V1.
         """
@@ -1202,7 +1202,6 @@ def _override_v1_configs(self, engine_config: VllmConfig):
                 "Prefix caching is currently not supported for multimodal "
                 "models and has been disabled.")
             engine_config.cache_config.enable_prefix_caching = False
-        return engine_config

From 31927bbb643ca3fac244e2f58bdcb105aa60d7d2 Mon Sep 17 00:00:00 2001
From: rickyx
Date: Sun, 24 Nov 2024 01:55:35 +0000
Subject: [PATCH 6/6] comments

Signed-off-by: rickyx
---
 tests/v1/engine/test_engine_args.py | 3 ++-
 vllm/engine/arg_utils.py            | 7 ++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py
index 7876c487fc14c..69cfdf5a395c1 100644
--- a/tests/v1/engine/test_engine_args.py
+++ b/tests/v1/engine/test_engine_args.py
@@ -16,7 +16,8 @@ def test_defaults():
     engine_args = EngineArgs(model="facebook/opt-125m")
 
     # Assert V1 defaults
-    assert engine_args.enable_prefix_caching
+    assert (engine_args.enable_prefix_caching
+            ), "V1 turns on prefix caching by default"
 
 
 def test_defaults_with_usage_context():
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 03043d00ab1be..002b67e635bf4 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -114,7 +114,7 @@ class EngineArgs:
     # NOTE(kzawora): default block size for Gaudi should be 128
     # smaller sizes still work, but very inefficiently
     block_size: int = 16 if not current_platform.is_hpu() else 128
-    enable_prefix_caching: bool = bool(envs.VLLM_USE_V1)
+    enable_prefix_caching: Optional[bool] = None
    disable_sliding_window: bool = False
     use_v2_block_manager: bool = True
     swap_space: float = 4  # GiB
@@ -197,6 +197,11 @@ def __post_init__(self):
         if not self.tokenizer:
             self.tokenizer = self.model
 
+        # Override the default value of enable_prefix_caching if it's not set
+        # by user.
+        if self.enable_prefix_caching is None:
+            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
+
         # Setup plugins
         from vllm.plugins import load_general_plugins
         load_general_plugins()