[v1] EngineArgs for better config handling for v1 #10382

Open. Wants to merge 7 commits into base: main. Changes from 6 commits.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -171,7 +171,7 @@ steps:
- vllm/
- tests/v1
commands:
- pytest -v -s v1
- VLLM_USE_V1=1 pytest -v -s v1

- label: Examples Test # 15min
working_dir: "/vllm-workspace/examples"
3 changes: 3 additions & 0 deletions tests/v1/engine/test_async_llm.py
@@ -32,6 +32,9 @@ async def generate(engine: AsyncLLM, request_id: str,

@pytest.mark.asyncio
async def test_load(monkeypatch):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the
# tests.
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

42 changes: 42 additions & 0 deletions tests/v1/engine/test_engine_args.py
@@ -0,0 +1,42 @@
import pytest

from vllm import envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext

if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)


def test_defaults():
engine_args = EngineArgs(model="facebook/opt-125m")

# Assert V1 defaults
assert (engine_args.enable_prefix_caching
), "V1 turns on prefix caching by default"


def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(
UsageContext.LLM_CLASS)

assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 8192

engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config = engine_args.create_engine_config(
UsageContext.OPENAI_API_SERVER)
assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 2048


def test_prefix_cache_disabled_with_multimodel():
engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")

vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert not vllm_config.cache_config.enable_prefix_caching
3 changes: 2 additions & 1 deletion tests/v1/engine/test_engine_core.py
@@ -43,7 +43,8 @@ def test_engine_core(monkeypatch):
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)

engine_core = EngineCore(vllm_config=vllm_config,
6 changes: 4 additions & 2 deletions tests/v1/engine/test_engine_core_client.py
@@ -82,7 +82,8 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)
client = EngineCoreClient.make_client(
vllm_config,
@@ -153,7 +154,8 @@ async def test_engine_core_client_asyncio(monkeypatch):
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)
client = EngineCoreClient.make_client(
vllm_config,
53 changes: 50 additions & 3 deletions vllm/engine/arg_utils.py
@@ -20,6 +20,7 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean

if TYPE_CHECKING:
@@ -113,7 +114,7 @@ class EngineArgs:
# NOTE(kzawora): default block size for Gaudi should be 128
# smaller sizes still work, but very inefficiently
block_size: int = 16 if not current_platform.is_hpu() else 128
enable_prefix_caching: bool = False
enable_prefix_caching: Optional[bool] = None
disable_sliding_window: bool = False
use_v2_block_manager: bool = True
swap_space: float = 4 # GiB
@@ -196,6 +197,11 @@ def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model

# Override the default value of enable_prefix_caching if it's not set
# by user.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
@@ -936,7 +942,12 @@ def create_load_config(self) -> LoadConfig:
ignore_patterns=self.ignore_patterns,
)

def create_engine_config(self) -> VllmConfig:
def create_engine_config(self,
usage_context: Optional[UsageContext] = None
) -> VllmConfig:
if envs.VLLM_USE_V1:
self._override_v1_engine_args(usage_context)

# gguf file needs a specific model loader and doesn't use hf_repo
if check_gguf_file(self.model):
self.quantization = self.load_format = "gguf"
@@ -1146,7 +1157,7 @@ def create_engine_config(self) -> VllmConfig:
or "all" in detailed_trace_modules,
)

return VllmConfig(
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
@@ -1161,6 +1172,42 @@ def create_engine_config(self) -> VllmConfig:
compilation_config=self.compilation_config,
)

if envs.VLLM_USE_V1:
self._override_v1_engine_config(config)
return config

def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
"""
Override the EngineArgs's args based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"

if self.max_num_batched_tokens is None:
# When no user override, set the default values based on the
# usage context.
if usage_context == UsageContext.LLM_CLASS:
logger.warning("Setting max_num_batched_tokens to 8192 "
"for LLM_CLASS usage context.")
self.max_num_seqs = 1024
self.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
logger.warning("Setting max_num_batched_tokens to 2048 "
"for OPENAI_API_SERVER usage context.")
self.max_num_seqs = 1024
self.max_num_batched_tokens = 2048

def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
"""
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
# TODO (ywang96): Enable APC by default when VLM supports it.
if engine_config.model_config.is_multimodal_model:
logger.warning(
"Prefix caching is currently not supported for multimodal "
"models and has been disabled.")
engine_config.cache_config.enable_prefix_caching = False


@dataclass
class AsyncEngineArgs(EngineArgs):
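For reference, a minimal sketch (not part of this diff) of how callers are expected to pass the new usage_context argument, using the defaults wired up in _override_v1_engine_args above and exercised in tests/v1/engine/test_engine_args.py; it assumes VLLM_USE_V1=1 is set in the environment and reuses the small facebook/opt-125m model from those tests:

# Sketch only: mirrors the V1 defaults introduced in this PR, assuming VLLM_USE_V1=1.
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext

# No explicit max_num_batched_tokens, so the usage context picks the default.
engine_args = EngineArgs(model="facebook/opt-125m")

# LLM class entrypoint: max_num_seqs=1024, max_num_batched_tokens=8192.
llm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)

# OpenAI API server entrypoint: max_num_seqs=1024, max_num_batched_tokens=2048.
api_config = EngineArgs(model="facebook/opt-125m").create_engine_config(
    UsageContext.OPENAI_API_SERVER)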
2 changes: 1 addition & 1 deletion vllm/engine/async_llm_engine.py
@@ -680,7 +680,7 @@ def from_engine_args(
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
if engine_config is None:
engine_config = engine_args.create_engine_config()
engine_config = engine_args.create_engine_config(usage_context)

executor_class = cls._get_executor_cls(engine_config)

2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -579,7 +579,7 @@ def from_engine_args(
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
engine_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
2 changes: 1 addition & 1 deletion vllm/engine/multiprocessing/engine.py
@@ -111,7 +111,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,
from vllm.plugins import load_general_plugins
load_general_plugins()

engine_config = engine_args.create_engine_config()
engine_config = engine_args.create_engine_config(usage_context)
executor_class = LLMEngine._get_executor_cls(engine_config)

use_async_sockets = engine_config.model_config.use_async_output_proc
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/api_server.py
@@ -133,8 +133,8 @@ async def build_async_engine_client_from_engine_args(
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):

engine_config = engine_args.create_engine_config()
engine_config = engine_args.create_engine_config(
UsageContext.OPENAI_API_SERVER)
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)

2 changes: 1 addition & 1 deletion vllm/v1/engine/async_llm.py
@@ -94,7 +94,7 @@ def from_engine_args(

# Create the engine configs.
if engine_config is None:
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(usage_context)
else:
vllm_config = engine_config

13 changes: 0 additions & 13 deletions vllm/v1/engine/core.py
@@ -39,19 +39,6 @@ def __init__(
executor_class: Type[GPUExecutor],
usage_context: UsageContext,
):
# Override the configs for V1.
# FIXME
if usage_context == UsageContext.LLM_CLASS:
vllm_config.scheduler_config.max_num_seqs = 1024
vllm_config.scheduler_config.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
vllm_config.scheduler_config.max_num_seqs = 1024
vllm_config.scheduler_config.max_num_batched_tokens = 2048

# TODO (ywang96): Enable APC by default when VLM supports it.
if not vllm_config.model_config.is_multimodal_model:
vllm_config.cache_config.enable_prefix_caching = True

assert vllm_config.model_config.task != "embedding"

logger.info("Initializing an LLM engine (v%s) with config: %s",
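Note: the usage-context scheduler defaults and the prefix-caching toggle removed here are the same values now applied in EngineArgs._override_v1_engine_args and EngineArgs._override_v1_engine_config (see the vllm/engine/arg_utils.py changes above), so EngineCore no longer mutates the config it is given.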
2 changes: 1 addition & 1 deletion vllm/v1/engine/llm_engine.py
@@ -82,7 +82,7 @@ def from_engine_args(
"""Creates an LLM engine from the engine arguments."""

# Create the engine configs.
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(vllm_config)

if VLLM_ENABLE_V1_MULTIPROCESSING: