This repository has been archived by the owner on May 28, 2024. It is now read-only.

Update to latest vLLM upstream and Support vLLM on CPU #149

Open · wants to merge 11 commits into base: master
41 changes: 41 additions & 0 deletions models/continuous_batching/cpu/meta-llama--Llama-2-7b-chat-hf.yaml
@@ -0,0 +1,41 @@
deployment_config:
  autoscaling_config:
    min_replicas: 1
    initial_replicas: 1
    max_replicas: 8
    target_num_ongoing_requests_per_replica: 24
    metrics_interval_s: 10.0
    look_back_period_s: 30.0
    smoothing_factor: 0.6
    downscale_delay_s: 300.0
    upscale_delay_s: 15.0
  max_concurrent_queries: 64
  ray_actor_options: {}
engine_config:
  model_id: meta-llama/Llama-2-7b-chat-hf
  hf_model_id: meta-llama/Llama-2-7b-chat-hf
  type: VLLMEngine
  engine_kwargs:
    device: "cpu"
    dtype: "float32"
    trust_remote_code: true
    max_num_batched_tokens: 4096
    max_num_seqs: 64
    gpu_memory_utilization: 0.95
  max_total_tokens: 4096
  generation:
    prompt_format:
      system: "<<SYS>>\n{instruction}\n<</SYS>>\n\n"
      assistant: " {instruction} </s><s>"
      trailing_assistant: ""
      user: "[INST] {system}{instruction} [/INST]"
      system_in_user: true
      default_system_message: ""
    stopping_sequences: []
scaling_config:
  num_workers: 1
  num_gpus_per_worker: 0
  num_cpus_per_worker: 32
  placement_strategy: "STRICT_PACK"
  resources_per_worker: {}
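
As a quick sanity check on the new CPU config, a minimal sketch like the following (assuming PyYAML is available and using the file path from this diff) loads the YAML and confirms the CPU-specific settings are consistent:

```python
# Illustrative check only; the path and key layout follow the YAML above.
import yaml

with open("models/continuous_batching/cpu/meta-llama--Llama-2-7b-chat-hf.yaml") as f:
    app = yaml.safe_load(f)

engine_kwargs = app["engine_config"]["engine_kwargs"]
assert engine_kwargs["device"] == "cpu"
# CPU serving should not reserve any GPUs for the worker.
assert app["scaling_config"]["num_gpus_per_worker"] == 0
print("CPU config OK:", app["engine_config"]["model_id"])
```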

10 changes: 4 additions & 6 deletions models/continuous_batching/meta-llama--Llama-2-7b-chat-hf.yaml
@@ -10,9 +10,7 @@ deployment_config:
    downscale_delay_s: 300.0
    upscale_delay_s: 15.0
  max_concurrent_queries: 64
-  ray_actor_options:
-    resources:
-      accelerator_type_a10: 0.01
+  ray_actor_options: {}
engine_config:
  model_id: meta-llama/Llama-2-7b-chat-hf
  hf_model_id: meta-llama/Llama-2-7b-chat-hf
@@ -34,8 +32,8 @@ engine_config:
    stopping_sequences: []
scaling_config:
  num_workers: 1
-  num_gpus_per_worker: 1
+  num_gpus_per_worker: 0
  num_cpus_per_worker: 8
  placement_strategy: "STRICT_PACK"
-  resources_per_worker:
-    accelerator_type_a10: 0.01
+  resources_per_worker: {}

4 changes: 2 additions & 2 deletions rayllm/backend/llm/embedding/embedding_models.py
@@ -11,6 +11,7 @@
    ModelType,
    S3MirrorConfig,
)
+from pydantic import ConfigDict

logger = logging.getLogger(__name__)

@@ -20,8 +21,7 @@ class EmbeddingOptimize(str, Enum):


class EmbeddingEngineConfig(EngineConfig):
-    class Config:
-        use_enum_values = True
+    model_config = ConfigDict(use_enum_values=True)

    type: EngineType = EngineType.EmbeddingEngine
    model_type: ModelType = ModelType.embedding
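
For reference, this is the standard pydantic v1-to-v2 migration pattern applied here and in the TensorRT-LLM config below; the `Color`/`Item` names are illustrative only, not types from this repo:

```python
from enum import Enum

from pydantic import BaseModel, ConfigDict


class Color(str, Enum):
    RED = "red"
    BLUE = "blue"


class Item(BaseModel):
    # pydantic v2 replaces the nested `class Config` with `model_config`.
    model_config = ConfigDict(use_enum_values=True)

    color: Color


# With use_enum_values=True the stored value is the plain string, not the enum member.
assert Item(color=Color.RED).color == "red"
```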
5 changes: 2 additions & 3 deletions rayllm/backend/llm/trtllm/trtllm_models.py
@@ -10,6 +10,7 @@
    S3MirrorConfig,
    SamplingParams,
)
+from pydantic import ConfigDict

try:
    from tensorrt_llm.libs import trt_llm_engine_py as trt_py
@@ -23,9 +24,7 @@ class TRTLLMGPTServeConfig(BaseModelExtended):
    max_tokens_in_paged_kv_cache: int = None
    kv_cache_free_gpu_mem_fraction: float = None
    enable_trt_overlap: bool = None

-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)

    @classmethod
    def from_engine_config(
42 changes: 12 additions & 30 deletions rayllm/backend/llm/vllm/vllm_compatibility.py
@@ -1,14 +1,10 @@
import logging
import time
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Type, Union

import ray
from ray.util.placement_group import PlacementGroup
from transformers.dynamic_module_utils import init_hf_modules
-from vllm.config import CacheConfig as VllmCacheConfig
-from vllm.config import ModelConfig as VllmModelConfig
-from vllm.config import ParallelConfig as VllmParallelConfig
-from vllm.config import SchedulerConfig as VllmSchedulerConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine, AsyncStream, _AsyncLLMEngine

@@ -32,13 +28,8 @@

logger = logging.getLogger(__name__)

-VllmConfigs = Tuple[
-    VllmCacheConfig, VllmModelConfig, VllmParallelConfig, VllmSchedulerConfig
-]


class AviaryLLMEngine(_AsyncLLMEngine):
-    def __init__(self, *args, runtime_env: dict, **kwargs):
+    def __init__(self, *args, runtime_env: dict = {}, **kwargs):
        self.runtime_env = runtime_env
        super().__init__(*args, **kwargs)

@@ -106,15 +97,17 @@ def _record_system_stats(self, *args, **kwargs):
        return last_stats


-def _get_vllm_engine_config(vllm_app) -> Tuple[AsyncEngineArgs, VllmConfigs]:
-    # Generate engine arguments and engine configs
+def _get_vllm_engine_config(vllm_app) -> AsyncEngineArgs:
+    device = vllm_app.engine_config.engine_kwargs.get("device", "gpu")
+
+    # Generate engine arguments and engine configs
    async_engine_args = AsyncEngineArgs(
        # This is the local path on disk, or the hf model id
        # If it is the hf_model_id, vllm automatically downloads the correct model.
        **dict(
            model=vllm_app.engine_config.actual_hf_model_id,
-            worker_use_ray=True,
+            # vLLM for CPU doesn't support Ray workers for model parallelism yet
+            worker_use_ray=True if device == "gpu" else False,
            engine_use_ray=False,
            tensor_parallel_size=vllm_app.placement_config.world_size,
            max_model_len=vllm_app.engine_config.max_total_tokens,
@@ -123,8 +116,8 @@ def _get_vllm_engine_config(vllm_app) -> Tuple[AsyncEngineArgs, VllmConfigs]:
            **vllm_app.engine_config.get_initialization_kwargs(),
        )
    )
-    configs = async_engine_args.create_engine_configs()
-    return async_engine_args, configs

+    return async_engine_args


class AviaryAsyncLLMEngine(AsyncLLMEngine):
@@ -167,23 +160,12 @@ def from_llm_app(
        # torch to have access to CUDA devices. We use a remote task
        # with `num_gpus` set here, so the type check happens in an environment
        # with `CUDA_VISIBLE_DEVICES` set.
-        engine_args, engine_configs = ray.get(
+        engine_args = ray.get(
            ray.remote(_get_vllm_engine_config)
            .options(**scaling_config)
            .remote(vllm_app)
        )

        # Create the async LLM engine.
-        engine = cls(
-            engine_args.worker_use_ray,
-            engine_args.engine_use_ray,
-            *engine_configs,
-            None,
-            placement_group,
-            runtime_env=runtime_env,
-            log_requests=not engine_args.disable_log_requests,
-            log_stats=not engine_args.disable_log_stats,
-            max_log_len=engine_args.max_log_len,
-            start_engine_loop=True,
-        )
+        engine = cls.from_engine_args(engine_args, start_engine_loop=True)

        return engine
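
To make the new construction path concrete, here is a hedged, self-contained sketch of the device-dependent engine setup. The model id and sizes are placeholders; the vLLM APIs used (`AsyncEngineArgs`, `AsyncLLMEngine.from_engine_args`) are the ones already imported above, assuming the vLLM version targeted by this PR:

```python
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


def build_engine(device: str = "cpu") -> AsyncLLMEngine:
    # Mirror the PR's logic: Ray workers are only needed for GPU tensor parallelism,
    # so worker_use_ray is disabled when serving on CPU.
    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-2-7b-chat-hf",  # placeholder HF model id
        worker_use_ray=(device == "gpu"),
        engine_use_ray=False,
        tensor_parallel_size=1,
        max_model_len=4096,
    )
    # The engine is now built entirely from AsyncEngineArgs instead of
    # hand-assembled vLLM config objects.
    return AsyncLLMEngine.from_engine_args(engine_args, start_engine_loop=True)
```

This is also why the hand-built `VllmConfigs` tuple can be dropped: `from_engine_args` derives the cache, model, parallel, and scheduler configs from the arguments internally.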
4 changes: 3 additions & 1 deletion rayllm/backend/llm/vllm/vllm_engine.py
@@ -59,7 +59,9 @@ def __init__(
        self.llm_app = llm_app.copy(deep=True)
        self.engine_config = llm_app.engine_config
        self.placement_config = llm_app.placement_config
-        if not (self.placement_config.scaling_config.num_gpus_per_worker > 0):
+
+        device = self.engine_config.engine_kwargs.get("device", "gpu")
+        if device == "gpu" and not (self.placement_config.scaling_config.num_gpus_per_worker > 0):
            raise ValueError("The VLLM Engine Requires > 0 GPUs to run.")

        self.node_initializer = node_initializer or LLMNodeInitializer(
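
A compact sketch of the intended validation behavior; the class and field names below are simplified stand-ins, not the actual rayllm types:

```python
from dataclasses import dataclass


@dataclass
class ScalingConfig:
    num_gpus_per_worker: float = 0


def validate_device(device: str, scaling: ScalingConfig) -> None:
    # GPU serving still requires at least one GPU per worker;
    # CPU serving is now allowed to run with zero GPUs.
    if device == "gpu" and not scaling.num_gpus_per_worker > 0:
        raise ValueError("The vLLM engine requires > 0 GPUs when device is 'gpu'.")


validate_device("cpu", ScalingConfig(num_gpus_per_worker=0))  # passes for CPU serving
```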
86 changes: 86 additions & 0 deletions rayllm/backend/server/config_models.py
@@ -0,0 +1,86 @@
from typing import Optional
from pydantic import (
    BaseModel,
    Field,
    NonNegativeFloat,
    NonNegativeInt,
    PositiveFloat,
    PositiveInt,
    field_validator,
)

# Adapted from ray.serve.config.AutoscalingConfig
# Ported here because the original AutoscalingConfig model is pydantic v1.
class AutoscalingConfig(BaseModel):
"""Config for the Serve Autoscaler."""

# Please keep these options in sync with those in
# `src/ray/protobuf/serve.proto`.

# Publicly exposed options
min_replicas: NonNegativeInt = 1
initial_replicas: Optional[NonNegativeInt] = None
max_replicas: PositiveInt = 1

# DEPRECATED: replaced by target_ongoing_requests
target_num_ongoing_requests_per_replica: PositiveFloat = Field(
default=1.0,
description="[DEPRECATED] Please use `target_ongoing_requests` instead.",
)
# Will default to 1.0 in the future.
target_ongoing_requests: Optional[PositiveFloat] = None

# How often to scrape for metrics
metrics_interval_s: PositiveFloat = 10.0
# Time window to average over for metrics.
look_back_period_s: PositiveFloat = 30.0

    # DEPRECATED
    smoothing_factor: PositiveFloat = 1.0
    # DEPRECATED: replaced by `upscaling_factor`
    upscale_smoothing_factor: Optional[PositiveFloat] = Field(
        default=None, description="[DEPRECATED] Please use `upscaling_factor` instead."
    )
    # DEPRECATED: replaced by `downscaling_factor`
    downscale_smoothing_factor: Optional[PositiveFloat] = Field(
        default=None,
        description="[DEPRECATED] Please use `downscaling_factor` instead.",
    )

    # Multiplicative "gain" factor to limit scaling decisions
    upscaling_factor: Optional[PositiveFloat] = None
    downscaling_factor: Optional[PositiveFloat] = None

    # How frequently to make autoscaling decisions
    # loop_period_s: float = CONTROL_LOOP_PERIOD_S
    # How long to wait before scaling down replicas
    downscale_delay_s: NonNegativeFloat = 600.0
    # How long to wait before scaling up replicas
    upscale_delay_s: NonNegativeFloat = 30.0

    @field_validator("max_replicas")
    def replicas_settings_valid(cls, max_replicas, values):
        min_replicas = values.data.get("min_replicas")
        initial_replicas = values.data.get("initial_replicas")
        if min_replicas is not None and max_replicas < min_replicas:
            raise ValueError(
                f"max_replicas ({max_replicas}) must be greater than "
                f"or equal to min_replicas ({min_replicas})!"
            )

        if initial_replicas is not None:
            if initial_replicas < min_replicas:
                raise ValueError(
                    f"min_replicas ({min_replicas}) must be less than "
                    f"or equal to initial_replicas ({initial_replicas})!"
                )
            elif initial_replicas > max_replicas:
                raise ValueError(
                    f"max_replicas ({max_replicas}) must be greater than "
                    f"or equal to initial_replicas ({initial_replicas})!"
                )

        return max_replicas

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
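
A brief usage sketch of the ported model, assuming it is importable from the new rayllm.backend.server.config_models module; the values mirror the autoscaling section of the CPU YAML at the top of this diff:

```python
from pydantic import ValidationError

from rayllm.backend.server.config_models import AutoscalingConfig

cfg = AutoscalingConfig(
    min_replicas=1,
    initial_replicas=1,
    max_replicas=8,
    target_num_ongoing_requests_per_replica=24,
    metrics_interval_s=10.0,
    look_back_period_s=30.0,
    smoothing_factor=0.6,
    downscale_delay_s=300.0,
    upscale_delay_s=15.0,
)
print(cfg.max_replicas)  # 8

# The field_validator rejects inconsistent replica bounds.
try:
    AutoscalingConfig(min_replicas=4, max_replicas=2)
except ValidationError as err:
    # The error message includes the max_replicas/min_replicas mismatch.
    print(err)
```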