From 74e42f5c17fd5c65b836f1be80c832f79e678437 Mon Sep 17 00:00:00 2001
From: Sebastian Schoennenbeck
Date: Tue, 1 Oct 2024 11:58:06 +0200
Subject: [PATCH] [Core] [Frontend] Priority scheduling for embeddings and in
 the OpenAI-API (#8965)

Signed-off-by: Alvant
---
 vllm/engine/async_llm_engine.py               |  4 ++++
 vllm/engine/multiprocessing/__init__.py       |  5 +++++
 vllm/engine/multiprocessing/client.py         | 20 +++++++++++++----
 vllm/engine/protocol.py                       |  4 +++-
 vllm/entrypoints/openai/protocol.py           | 22 +++++++++++++++++++
 vllm/entrypoints/openai/serving_chat.py       |  1 +
 vllm/entrypoints/openai/serving_completion.py |  1 +
 vllm/entrypoints/openai/serving_embedding.py  |  1 +
 8 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 8d7f6e857d284..a5f27998ecba7 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1085,6 +1085,7 @@ async def encode(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
 
@@ -1099,6 +1100,8 @@ async def encode(
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
             trace_headers: OpenTelemetry trace headers.
+            priority: The priority of the request.
+                Only applicable with priority scheduling.
 
         Yields:
             The output `EmbeddingRequestOutput` objects from the LLMEngine
@@ -1151,6 +1154,7 @@ async def encode(
             pooling_params,
             lora_request=lora_request,
             trace_headers=trace_headers,
+            priority=priority,
         ):
             yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
 
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 6d6d7895b2101..34c161e9395ae 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -30,6 +30,7 @@ class RPCProcessRequest:
     lora_request: Optional[LoRARequest] = None
     trace_headers: Optional[Mapping[str, str]] = None
     prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    priority: int = 0
 
     @overload  # DEPRECATED
     def __init__(
@@ -41,6 +42,7 @@ def __init__(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
 
@@ -53,6 +55,7 @@ def __init__(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
@@ -68,6 +71,7 @@ def __init__(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -84,6 +88,7 @@ def __init__(
         self.lora_request = lora_request
         self.trace_headers = trace_headers
         self.prompt_adapter_request = prompt_adapter_request
+        self.priority = priority
 
 
 @dataclass
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 79da0be97fdbf..b0d061dbab4a1 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -380,6 +380,7 @@ def generate(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         ...
 
@@ -392,6 +393,7 @@ def generate(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         ...
 
@@ -407,6 +409,7 @@ def generate(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None  # DEPRECATED
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -425,6 +428,9 @@ def generate(
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
+            priority: Priority of the request (lower means earlier handling).
+                Any priority other than 0 will lead to an error if the
+                scheduling policy is not "priority".
         """
         if inputs is not None:
             prompt = inputs
@@ -433,7 +439,7 @@ def generate(
 
         return self._process_request(prompt, sampling_params, request_id,
                                      lora_request, trace_headers,
-                                     prompt_adapter_request)
+                                     prompt_adapter_request, priority)
 
     @overload  # DEPRECATED
     def encode(
@@ -444,6 +450,7 @@ def encode(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         ...
 
@@ -455,6 +462,7 @@ def encode(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         ...
@@ -469,6 +477,7 @@ def encode(
         request_id: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None  # DEPRECATED
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
@@ -496,7 +505,7 @@ def encode(
                 and request_id is not None)
 
         return self._process_request(prompt, pooling_params, request_id,
-                                     lora_request, trace_headers)
+                                     lora_request, trace_headers, priority)
 
     async def _process_request(
         self,
@@ -505,7 +514,8 @@ async def _process_request(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
             EmbeddingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -550,7 +560,9 @@ async def _process_request(
                     request_id=request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
-                    prompt_adapter_request=prompt_adapter_request))
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=priority,
+                ))
 
             # 3) Send the RPCGenerateRequest to the MQLLMEngine.
             parts = (request_bytes,
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index d0bbeb357b506..d7ff743e0ada6 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -40,7 +40,8 @@ def generate(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request."""
         ...
@@ -52,6 +53,7 @@ def encode(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c3101ca2b6900..623f1180bb443 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default whitespace pattern "
             "for guided json decoding."))
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."))
 
     # doc: end-chat-completion-extra-params
 
@@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default whitespace pattern "
             "for guided json decoding."))
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."))
 
     # doc: end-completion-extra-params
 
@@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):
 
     # doc: end-embedding-pooling-params
 
+    # doc: begin-embedding-extra-params
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."))
+
+    # doc: end-embedding-extra-params
+
     def to_pooling_params(self):
         return PoolingParams(additional_data=self.additional_data)
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 29a5b11b595c7..41f131f56b51f 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -235,6 +235,7 @@ async def create_chat_completion(
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
+                priority=request.priority,
             )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index a0161611288de..59e69121deb9e 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -148,6 +148,7 @@ async def create_completion(
                     lora_request=lora_request,
                     prompt_adapter_request=prompt_adapter_request,
                     trace_headers=trace_headers,
+                    priority=request.priority,
                 )
 
                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 5d95e1369b884..d6f337a7236d6 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -148,6 +148,7 @@ async def create_embedding(
                     pooling_params,
                     request_id_item,
                     lora_request=lora_request,
+                    priority=request.priority,
                 )
 
                 generators.append(generator)
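
Usage sketch (illustrative only, not introduced by this patch): the snippet below is one way the new "priority" field on ChatCompletionRequest, CompletionRequest and EmbeddingRequest could be exercised from a client. It assumes an OpenAI-compatible vLLM server that is already running with the "priority" scheduling policy enabled (enabling that policy is outside this diff), and the base URL, API key and model names are placeholders rather than values defined by this patch.

    # Illustrative client-side sketch; assumes a locally running
    # OpenAI-compatible vLLM server with priority scheduling enabled.
    # Base URL, API key and model names below are placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # "priority" is not part of the upstream OpenAI schema, so the official
    # openai client forwards it via extra_body. Lower values are handled
    # earlier; per the field descriptions in this patch, any non-zero value
    # is rejected if the served model does not use priority scheduling.
    chat = client.chat.completions.create(
        model="my-model",  # placeholder model name
        messages=[{"role": "user", "content": "Hello!"}],
        extra_body={"priority": -1},
    )
    print(chat.choices[0].message.content)

    # The same extra field works for the embeddings endpoint, which now
    # forwards request.priority to the engine's encode call.
    emb = client.embeddings.create(
        model="my-embedding-model",  # placeholder model name
        input=["some text to embed"],
        extra_body={"priority": 0},
    )
    print(len(emb.data[0].embedding))

Requests issued directly against the encode() coroutine in vllm/engine/async_llm_engine.py, or against the generate()/encode() methods patched in vllm/engine/multiprocessing/client.py, accept the same value through the new priority keyword argument.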