From 73390bf0867aafa5693568e9539e623eb5181170 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 30 Jul 2024 17:24:24 +0300 Subject: [PATCH 01/25] Remove redundant torch.device --- vllm/model_executor/model_loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index cbe9ebf35f4dd..bbe49655020da 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -276,7 +276,7 @@ def load_model(self, *, model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): - with torch.device(torch.device(device_config.device)): + with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) From 1546cd5c41ec0d39b07f34b8ee0e0cb89d2d6d76 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 31 Jul 2024 19:22:49 +0300 Subject: [PATCH 02/25] Add multiprocessing HPU executor --- vllm/engine/llm_engine.py | 4 + vllm/executor/habana_executor.py | 9 + vllm/executor/multiproc_hpu_executor.py | 272 ++++++++++++++++++++++++ 3 files changed, 285 insertions(+) create mode 100644 vllm/executor/multiproc_hpu_executor.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f7e0a7a4dc53..208b89e02de72 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -410,6 +410,10 @@ def _get_executor_cls(cls, initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_habana_executor import RayHabanaExecutor executor_class = RayHabanaExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_hpu_executor import ( + MultiprocessingHPUExecutor) + executor_class = MultiprocessingHPUExecutor else: from vllm.executor.habana_executor import HabanaExecutor executor_class = HabanaExecutor diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py index f5cf26b687053..50be225d0a9be 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/habana_executor.py @@ -18,6 +18,15 @@ logger = init_logger(__name__) +def create_worker(worker_module_name, worker_class_name, **kwargs): + wrapper = WorkerWrapperBase( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + ) + wrapper.init_worker(**kwargs) + return wrapper.worker + + class HabanaExecutor(ExecutorBase): uses_ray: bool = False diff --git a/vllm/executor/multiproc_hpu_executor.py b/vllm/executor/multiproc_hpu_executor.py new file mode 100644 index 0000000000000..b1bd49e279652 --- /dev/null +++ b/vllm/executor/multiproc_hpu_executor.py @@ -0,0 +1,272 @@ +import asyncio +import os +import signal +import threading +import weakref +from functools import partial +from typing import Any, List, Optional + +import torch + +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.habana_executor import create_worker +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.triton_utils import maybe_set_triton_cache_manager +from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, + error_on_invalid_device_count_status, + get_distributed_init_method, get_open_port, + get_vllm_instance_id, make_async, + update_environment_variables) + +logger = init_logger(__name__) + + +class MultiprocessingHPUExecutor(DistributedGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + uses_ray: bool = False + + def _init_executor(self) -> None: + # Create the parallel GPU workers. + world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + update_environment_variables({ + "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) + }) + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + # Disable torch async compiling which won't work with daemonic processes + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + + # Configure thread parallelism if OMP_NUM_THREADS isn't set + # + # Helps to avoid CPU contention. The default of spawning a thread per + # core combined with multiprocessing for each GPU can have a negative + # impact on performance. The contention is amplified when running in a + # container where CPU limits can cause throttling. + default_omp_num_threads = 1 + if "OMP_NUM_THREADS" not in os.environ and ( + current_parallelism := + torch.get_num_threads()) > default_omp_num_threads: + logger.warning( + "Reducing Torch parallelism from %d threads to %d to avoid " + "unnecessary CPU contention. Set OMP_NUM_THREADS in the " + "external environment to tune this value as needed.", + current_parallelism, default_omp_num_threads) + os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) + torch.set_num_threads(default_omp_num_threads) + + # workaround for https://github.com/vllm-project/vllm/issues/6103 + if world_size > 1: + maybe_set_triton_cache_manager() + + cuda_device_count = torch.hpu.device_count() + # Use confusing message for more common TP-only case. + assert tensor_parallel_size <= cuda_device_count, ( + f"please set tensor_parallel_size ({tensor_parallel_size}) " + f"to less than max local gpu count ({cuda_device_count})") + + assert world_size <= cuda_device_count, ( + f"please ensure that world_size ({world_size}) " + f"is less than than max local gpu count ({cuda_device_count})") + + error_on_invalid_device_count_status() + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + self.workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[ProcessWorkerWrapper] = [] + + if world_size == 1: + self.worker_monitor = None + else: + result_handler = ResultHandler() + for rank in range(1, world_size): + worker = ProcessWorkerWrapper( + result_handler, + partial( + create_worker, + **self._get_create_worker_kwargs( + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + ))) + self.workers.append(worker) + if rank % tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + # Set up signal handlers to shutdown the executor cleanly + # sometimes gc does not work well + + # Use weakref to avoid holding a reference to self + ref = weakref.ref(self) + + def shutdown(signum, frame): + if executor := ref(): + executor.shutdown() + + if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGINT, shutdown) + signal.signal(signal.SIGTERM, shutdown) + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """z + return self.driver_worker.execute_model(execute_model_req) + + def _get_create_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None): + worker_kwargs = self._get_worker_kwargs(local_rank, rank, + distributed_init_method) + worker_kwargs.update(worker_module_name="vllm.worker.habana_worker", + worker_class_name="HabanaWorker") + return worker_kwargs + + def _run_workers( + self, + method: str, + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + if async_run_tensor_parallel_workers_only: + # Run only non-driver workers and just return futures. + return [ + worker.execute_method(method, *args, **kwargs) + for worker in self.non_driver_workers + ] + + # Start all remote workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*args, **kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if self.worker_monitor is not None and not self.worker_monitor.is_alive( + ): + raise RuntimeError("Worker processes are not running") + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + + +class MultiprocessingHPUExecutorAsync(MultiprocessingHPUExecutor, + DistributedGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_model = make_async(self.driver_worker.execute_model) + self.pp_locks: Optional[List[asyncio.Lock]] = None + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + if not self.tp_driver_workers: + return await self.driver_exec_model(execute_model_req) + + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + tasks = [ + asyncio.create_task( + _run_task_with_lock(self.driver_exec_model, self.pp_locks[0], + execute_model_req)) + ] + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method_async, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. + return results[-1] + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method_async("start_worker_execution_loop") + for worker in self.non_driver_workers + ] + return await asyncio.gather(*coros) From 5b5dab1371a18783ac274b650fd85d2ad3e0d702 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 31 Jul 2024 19:25:12 +0300 Subject: [PATCH 03/25] typo --- vllm/executor/multiproc_hpu_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/executor/multiproc_hpu_executor.py b/vllm/executor/multiproc_hpu_executor.py index b1bd49e279652..1ac623a7bc1af 100644 --- a/vllm/executor/multiproc_hpu_executor.py +++ b/vllm/executor/multiproc_hpu_executor.py @@ -154,7 +154,7 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. - """z + """ return self.driver_worker.execute_model(execute_model_req) def _get_create_worker_kwargs( From b0c4d4cf5bbce8c7a9d37b9a45a8f5a26c5ef989 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Thu, 1 Aug 2024 22:51:00 +0300 Subject: [PATCH 04/25] Add Gaudi documentation for 1.17 --- .../getting_started/gaudi-installation.rst | 142 +++++++++++++++++- 1 file changed, 140 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index a9f3ebdf274f6..e450db6c3b1a1 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -112,22 +112,122 @@ Gaudi2 devices. Configurations that are not listed may or may not work. - `meta-llama/Meta-Llama-3-8B-Instruct `__ on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-8B `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-8B-Instruct `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling - `meta-llama/Llama-2-70b `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `meta-llama/Llama-2-70b-chat-hf `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `meta-llama/Meta-Llama-3-70B `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling -- `meta-llama/Meta-Llama-3-70B-Instruct `__ +- `meta-llama/Meta-Llama-3-70B-Instruct `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-70B `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-70B-Instruct `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `mistralai/Mistral-7B-Instruct-v0.3 `__ on single HPU or with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling - `mistralai/Mixtral-8x7B-Instruct-v0.1 `__ with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling -Performance Tips +Performance Tuning ================ +Execution modes +------------ + +TODO: t.compile, hpugraphs, lazy and eager + +Bucketing mechanism +------------ + +Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler `__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. +In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occuring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``. + +.. note:: + Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. + +Whenever executing vLLM on HPU, the following log can be observed: +.. code-block:: + + INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + +In this scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. + +.. warning:: + If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. + +As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as ``(4, 512)`` prefill bucket, as ``batch_size`` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as ``(4, 512)`` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a ``(2, 512)`` bucket, or context length increases above 512 tokens, in which case it will become ``(4, 640)`` bucket. + +.. note:: + Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. + + +TODO: how buckets are determined, here are some notes: + + """Read bucketing configuration from env variables. + + phase is either 'prompt' or 'decode' + dim is either 'bs' or 'block' + param is either 'min', 'step' or 'max' + example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 + """ + + """Generate a warmup range. + + Start from bmin and multiply by 2 until you reach bstep. + Then, increase the values in the range by the value of bstep until you + reach bmax. + + Example: + bmin = 2, bstep = 32, bmax = 64 + => ramp_up = (2, 4, 8, 16) + => stable = (32, 64) + => return ramp_up + stable => (2, 4, 8, 16, 32, 64) + """ + + +Warmup +------------ + +Warmup is an optional, but highly recommended step occuring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundries during server runtime. Each warmup step is logged during vLLM startup: + +.. code-block:: + + INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB + INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB + INFO 08-01 22:26:48 habana_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB + ... + INFO 08-01 22:26:59 habana_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB + INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB + INFO 08-01 22:27:01 habana_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB + ... + INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + +This example uses the same buckets as in *Bucketing mechanism* section. Each output line corresponds to execution of a single bucket. Whenever bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. + +.. tip:: + Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. + +HPUGraph capture +------------ + +TODO: VLLM_GRAPH_MEM_MARGIN, how and why, couple of sentences about mem allocations + + +Recommended vLLM Parameters +------------ + - We recommend running inference on Gaudi 2 with ``block_size`` of 128 for BF16 data type. Using default values (16, 32) might lead to sub-optimal performance due to Matrix Multiplication Engine @@ -137,6 +237,44 @@ Performance Tips of 128 or 256 and max context length of 2048 with HPU Graphs enabled. If you encounter out-of-memory issues, see troubleshooting section. +Environment variables +------------ + +vLLM for HPU supports following environment variables for performance tuning: + +- ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default +- ``VLLM_GRAPH_MEM_MARGIN``: TODO +- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default +- ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default +- ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism + + - ``{phase}`` is either ``PROMPT`` or ``DECODE`` + - ``{dim}`` is either ``BS`` or ``SEQ`` + - ``{param}`` is either ``MIN``, ``STEP`` or ``MAX`` + - Default values: + + - Prompt: + - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1`` + - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32`` + - batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)`` + - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size`` + - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size`` + - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024`` + + - Decode: + - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1`` + - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128`` + - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs`` + - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size`` + - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size`` + - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048`` + + +Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: + +- ``PT_HPU_LAZY_MODE``: if ``0``, PyTorch Eager backend for Gaudi will be used, if ``1`` PyTorch Lazy backend for Gaudi will be used, ``1`` is default +- ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPUGraphs + Troubleshooting: Tweaking HPU Graphs ==================================== From 739067dd0401343aa0bbb2b285fbdc0ea608f967 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 6 Aug 2024 18:00:07 +0300 Subject: [PATCH 05/25] update docs --- .../getting_started/gaudi-installation.rst | 100 ++++++++++++++---- 1 file changed, 79 insertions(+), 21 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index e450db6c3b1a1..382a7229605a3 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -141,8 +141,31 @@ Performance Tuning Execution modes ------------ -TODO: t.compile, hpugraphs, lazy and eager +Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via ``PT_HPU_LAZY_MODE`` environment variable), and ``--enforce-eager`` flag. + +.. list-table:: Title + :widths: 25 25 50 + :header-rows: 1 + + * - ``PT_HPU_LAZY_MODE`` + - ``enforce_eager`` + - execution mode + * - 0 + - 0 + - torch.compile + * - 0 + - 1 + - PyTorch eager mode + * - 1 + - 0 + - HPU Graphs + * - 1 + - 1 + - PyTorch lazy mode +.. warning:: + Currently all modes utilizing +``PT_HPU_LAZY_MODE`` Bucketing mechanism ------------ @@ -170,29 +193,34 @@ As an example, if a request of 3 sequences, with max sequence length of 412 come .. note:: Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. +Bucketing ranges are determined with 3 parameters - ``min``, ``step`` and ``max``. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: -TODO: how buckets are determined, here are some notes: - - """Read bucketing configuration from env variables. +.. code-block:: - phase is either 'prompt' or 'decode' - dim is either 'bs' or 'block' - param is either 'min', 'step' or 'max' - example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 - """ + INFO 08-02 15:30:53 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-02 15:30:53 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-02 15:30:53 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-02 15:30:53 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - """Generate a warmup range. +``min`` determines the lowest value of the bucket. ``step`` determines the interval between buckets, and ``max`` determines the upper bound of the bucket. Furthermore, interval between ``min`` and ``step`` has special handling - ``min`` gets multiplied by consecutive powers of two, until ``step`` gets reached. We call this the ramp-up phase and it is used for handling lower batch sizes with minimum wastage, while allowing larger padding on larger batch sizes. - Start from bmin and multiply by 2 until you reach bstep. - Then, increase the values in the range by the value of bstep until you - reach bmax. +Example (with ramp-up) - Example: - bmin = 2, bstep = 32, bmax = 64 +.. code-block:: + + min = 2, step = 32, max = 64 => ramp_up = (2, 4, 8, 16) => stable = (32, 64) - => return ramp_up + stable => (2, 4, 8, 16, 32, 64) - """ + => buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) + +Example (without ramp-up) + +.. code-block:: + + min = 128, step = 128, max = 512 + => ramp_up = () + => stable = (128, 256, 384, 512) + => buckets = ramp_up + stable => (128, 256, 384, 512) Warmup @@ -219,11 +247,41 @@ This example uses the same buckets as in *Bucketing mechanism* section. Each out .. tip:: Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. -HPUGraph capture +TODO: HPU Graph capture ------------ -TODO: VLLM_GRAPH_MEM_MARGIN, how and why, couple of sentences about mem allocations - +`HPU Graphs ` are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. +TODO +.. code-block:: + INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-02 17:37:44 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:37:52 habana_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 habana_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 habana_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache + INFO 08-02 17:37:54 habana_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 + INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB + ... + INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5) + INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + ... + INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB + INFO 08-02 17:38:27 habana_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB + ... + INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB + INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB + INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB + INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB + INFO 08-02 17:38:43 habana_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB + INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] + INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:38:43 habana_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory + INFO 08-02 17:38:43 habana_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) Recommended vLLM Parameters ------------ @@ -243,7 +301,7 @@ Environment variables vLLM for HPU supports following environment variables for performance tuning: - ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default -- ``VLLM_GRAPH_MEM_MARGIN``: TODO +- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture - ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default - ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default - ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism From 625796be276561156c7c3bb9fdd7d33a27d5f54b Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 6 Aug 2024 18:46:34 +0300 Subject: [PATCH 06/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 382a7229605a3..3b9163f4a94ce 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -164,8 +164,9 @@ Currently in vLLM for HPU we support four execution modes, depending on selected - PyTorch lazy mode .. warning:: - Currently all modes utilizing -``PT_HPU_LAZY_MODE`` + In 1.17.0, all modes utilizing ``PT_HPU_LAZY_MODE=0`` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.17.0, please use HPUGraphs, or PyTorch lazy mode. + + Bucketing mechanism ------------ From 254dab32c6a71a8e98d6d51819fef29560e660c0 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 6 Aug 2024 18:47:53 +0300 Subject: [PATCH 07/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 3b9163f4a94ce..50ebe5f98bbaa 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -143,7 +143,7 @@ Execution modes Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via ``PT_HPU_LAZY_MODE`` environment variable), and ``--enforce-eager`` flag. -.. list-table:: Title +.. list-table:: vLLM execution modes :widths: 25 25 50 :header-rows: 1 From 0ccb294129107bfe471655a830814f8f640ece42 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 6 Aug 2024 18:49:39 +0300 Subject: [PATCH 08/25] update link --- docs/source/getting_started/gaudi-installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 50ebe5f98bbaa..c4048643de8c1 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -251,7 +251,7 @@ This example uses the same buckets as in *Bucketing mechanism* section. Each out TODO: HPU Graph capture ------------ -`HPU Graphs ` are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. +`HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. TODO .. code-block:: INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] From f998014b97e5d3b199e1c9f6cb172ccb583d3934 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:06:59 +0300 Subject: [PATCH 09/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index c4048643de8c1..e1c3f2b265432 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -248,11 +248,21 @@ This example uses the same buckets as in *Bucketing mechanism* section. Each out .. tip:: Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. -TODO: HPU Graph capture +HPU Graph capture ------------ `HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. -TODO + + +Whenever HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. + + +.. note:: + ``VLLM_GRAPH_PROMPT_RATIO`` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * ``VLLM_GRAPH_PROMPT_RATIO``) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. + + +Each described step is logged by vLLM server, as follows: + .. code-block:: INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] From 26e485d4389d4b4f47afcd1f78d842cee6bf113a Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:07:49 +0300 Subject: [PATCH 10/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index e1c3f2b265432..c7a604d6f72da 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -263,7 +263,7 @@ Whenever HPU Graphs are being used, they share the common memory pool ("usable m Each described step is logged by vLLM server, as follows: -.. code-block:: +.. code-block:: INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] @@ -294,6 +294,7 @@ Each described step is logged by vLLM server, as follows: INFO 08-02 17:38:43 habana_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory INFO 08-02 17:38:43 habana_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) + Recommended vLLM Parameters ------------ From 2ef373ddf7427d10d97890dd8c6d284fc39f543e Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:08:33 +0300 Subject: [PATCH 11/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index c7a604d6f72da..0f546410bfa30 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -257,13 +257,14 @@ HPU Graph capture Whenever HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. -.. note:: +.. note:: ``VLLM_GRAPH_PROMPT_RATIO`` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * ``VLLM_GRAPH_PROMPT_RATIO``) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. Each described step is logged by vLLM server, as follows: .. code-block:: + INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] From ee01a08055e6c652b08418ef7848ce38b1a23636 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:11:10 +0300 Subject: [PATCH 12/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 0f546410bfa30..7a3c4e3c1a6fa 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -164,7 +164,7 @@ Currently in vLLM for HPU we support four execution modes, depending on selected - PyTorch lazy mode .. warning:: - In 1.17.0, all modes utilizing ``PT_HPU_LAZY_MODE=0`` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.17.0, please use HPUGraphs, or PyTorch lazy mode. + In 1.17.0, all modes utilizing ``PT_HPU_LAZY_MODE=0`` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.17.0, please use HPU Graphs, or PyTorch lazy mode. Bucketing mechanism @@ -177,6 +177,7 @@ In a dynamic inference serving scenario, there is a need to minimize the number Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. Whenever executing vLLM on HPU, the following log can be observed: + .. code-block:: INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] @@ -254,7 +255,7 @@ HPU Graph capture `HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. -Whenever HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. +Whenever HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. .. note:: @@ -344,7 +345,7 @@ vLLM for HPU supports following environment variables for performance tuning: Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: - ``PT_HPU_LAZY_MODE``: if ``0``, PyTorch Eager backend for Gaudi will be used, if ``1`` PyTorch Lazy backend for Gaudi will be used, ``1`` is default -- ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPUGraphs +- ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPU Graphs Troubleshooting: Tweaking HPU Graphs ==================================== From ae2409398a41aacd51d558fce4cbce6aab43f9f3 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:14:15 +0300 Subject: [PATCH 13/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 7a3c4e3c1a6fa..e095f2acb8fd9 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -312,9 +312,10 @@ Recommended vLLM Parameters Environment variables ------------ -vLLM for HPU supports following environment variables for performance tuning: +vLLM for HPU supports following environment variables: - ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default +- ``VLLM_PROFILED_ENABLED``: if ``true``, high level profiler will be enabled, ``false``. Resulting JSON traces can be viewed in `perfetto.habana.ai `__ - ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture - ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default - ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default From b1f9b0a219a9c3d86aded4c6949b32622e0bfd4f Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:19:25 +0300 Subject: [PATCH 14/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index e095f2acb8fd9..4290d635f82c3 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -314,8 +314,16 @@ Environment variables vLLM for HPU supports following environment variables: +Diagnostic knobs: + +- ``VLLM_PROFILER_ENABLED``: if ``true``, high level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai `__. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS``. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL``: if ``true``, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default. +- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. +- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default. +Performance knobs: + - ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default -- ``VLLM_PROFILED_ENABLED``: if ``true``, high level profiler will be enabled, ``false``. Resulting JSON traces can be viewed in `perfetto.habana.ai `__ - ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture - ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default - ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default From fa6c5caf612b7fc806f712da0d4914e00073e3fc Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:20:40 +0300 Subject: [PATCH 15/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 4290d635f82c3..bfc7c831fae5c 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -312,16 +312,15 @@ Recommended vLLM Parameters Environment variables ------------ -vLLM for HPU supports following environment variables: - -Diagnostic knobs: +Diagnostic and profiling knobs: - ``VLLM_PROFILER_ENABLED``: if ``true``, high level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai `__. Disabled by default. - ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS``. Disabled by default. - ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL``: if ``true``, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default. - ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. - ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default. -Performance knobs: + +Performance tuning knobs: - ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default - ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture @@ -333,7 +332,7 @@ Performance knobs: - ``{dim}`` is either ``BS`` or ``SEQ`` - ``{param}`` is either ``MIN``, ``STEP`` or ``MAX`` - Default values: - + - Prompt: - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1`` - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32`` @@ -341,7 +340,7 @@ Performance knobs: - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size`` - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size`` - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024`` - + - Decode: - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1`` - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128`` From cdd2839dcecf7a9e60b5fccc1e01c1a0da8e4763 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:21:49 +0300 Subject: [PATCH 16/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index bfc7c831fae5c..93a164e9ecd39 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -312,7 +312,7 @@ Recommended vLLM Parameters Environment variables ------------ -Diagnostic and profiling knobs: +**Diagnostic and profiling knobs:** - ``VLLM_PROFILER_ENABLED``: if ``true``, high level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai `__. Disabled by default. - ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS``. Disabled by default. @@ -320,7 +320,7 @@ Diagnostic and profiling knobs: - ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. - ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default. -Performance tuning knobs: +**Performance tuning knobs:** - ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default - ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture From 9dd2457e549404ba7118a2b801dabdbdb97892a3 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:25:59 +0300 Subject: [PATCH 17/25] Revert "Add multiprocessing HPU executor" This reverts commit 1546cd5c41ec0d39b07f34b8ee0e0cb89d2d6d76. --- vllm/engine/llm_engine.py | 4 - vllm/executor/habana_executor.py | 9 - vllm/executor/multiproc_hpu_executor.py | 272 ------------------------ 3 files changed, 285 deletions(-) delete mode 100644 vllm/executor/multiproc_hpu_executor.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 208b89e02de72..3f7e0a7a4dc53 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -410,10 +410,6 @@ def _get_executor_cls(cls, initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_habana_executor import RayHabanaExecutor executor_class = RayHabanaExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_hpu_executor import ( - MultiprocessingHPUExecutor) - executor_class = MultiprocessingHPUExecutor else: from vllm.executor.habana_executor import HabanaExecutor executor_class = HabanaExecutor diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py index 50be225d0a9be..f5cf26b687053 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/habana_executor.py @@ -18,15 +18,6 @@ logger = init_logger(__name__) -def create_worker(worker_module_name, worker_class_name, **kwargs): - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - ) - wrapper.init_worker(**kwargs) - return wrapper.worker - - class HabanaExecutor(ExecutorBase): uses_ray: bool = False diff --git a/vllm/executor/multiproc_hpu_executor.py b/vllm/executor/multiproc_hpu_executor.py deleted file mode 100644 index 1ac623a7bc1af..0000000000000 --- a/vllm/executor/multiproc_hpu_executor.py +++ /dev/null @@ -1,272 +0,0 @@ -import asyncio -import os -import signal -import threading -import weakref -from functools import partial -from typing import Any, List, Optional - -import torch - -from vllm.executor.distributed_gpu_executor import ( # yapf: disable - DistributedGPUExecutor, DistributedGPUExecutorAsync) -from vllm.executor.habana_executor import create_worker -from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, - ResultHandler, WorkerMonitor) -from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.triton_utils import maybe_set_triton_cache_manager -from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, - error_on_invalid_device_count_status, - get_distributed_init_method, get_open_port, - get_vllm_instance_id, make_async, - update_environment_variables) - -logger = init_logger(__name__) - - -class MultiprocessingHPUExecutor(DistributedGPUExecutor): - """Python multiprocessing-based multi-GPU executor""" - - uses_ray: bool = False - - def _init_executor(self) -> None: - # Create the parallel GPU workers. - world_size = self.parallel_config.world_size - tensor_parallel_size = self.parallel_config.tensor_parallel_size - - # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers - if "CUDA_VISIBLE_DEVICES" not in os.environ: - update_environment_variables({ - "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) - }) - - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() - - # Disable torch async compiling which won't work with daemonic processes - os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - - # Configure thread parallelism if OMP_NUM_THREADS isn't set - # - # Helps to avoid CPU contention. The default of spawning a thread per - # core combined with multiprocessing for each GPU can have a negative - # impact on performance. The contention is amplified when running in a - # container where CPU limits can cause throttling. - default_omp_num_threads = 1 - if "OMP_NUM_THREADS" not in os.environ and ( - current_parallelism := - torch.get_num_threads()) > default_omp_num_threads: - logger.warning( - "Reducing Torch parallelism from %d threads to %d to avoid " - "unnecessary CPU contention. Set OMP_NUM_THREADS in the " - "external environment to tune this value as needed.", - current_parallelism, default_omp_num_threads) - os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) - torch.set_num_threads(default_omp_num_threads) - - # workaround for https://github.com/vllm-project/vllm/issues/6103 - if world_size > 1: - maybe_set_triton_cache_manager() - - cuda_device_count = torch.hpu.device_count() - # Use confusing message for more common TP-only case. - assert tensor_parallel_size <= cuda_device_count, ( - f"please set tensor_parallel_size ({tensor_parallel_size}) " - f"to less than max local gpu count ({cuda_device_count})") - - assert world_size <= cuda_device_count, ( - f"please ensure that world_size ({world_size}) " - f"is less than than max local gpu count ({cuda_device_count})") - - error_on_invalid_device_count_status() - - # Multiprocessing-based executor does not support multi-node setting. - # Since it only works for single node, we can use the loopback address - # 127.0.0.1 for communication. - distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) - - self.workers: List[ProcessWorkerWrapper] = [] - # This is the list of workers that are rank 0 of each TP group EXCEPT - # global rank 0. These are the workers that will broadcast to the - # rest of the workers. - self.tp_driver_workers: List[ProcessWorkerWrapper] = [] - # This is the list of workers that are not drivers and not the first - # worker in a TP group. These are the workers that will be - # broadcasted to. - self.non_driver_workers: List[ProcessWorkerWrapper] = [] - - if world_size == 1: - self.worker_monitor = None - else: - result_handler = ResultHandler() - for rank in range(1, world_size): - worker = ProcessWorkerWrapper( - result_handler, - partial( - create_worker, - **self._get_create_worker_kwargs( - rank=rank, - local_rank=rank, - distributed_init_method=distributed_init_method, - ))) - self.workers.append(worker) - if rank % tensor_parallel_size == 0: - self.tp_driver_workers.append(worker) - else: - self.non_driver_workers.append(worker) - - self.worker_monitor = WorkerMonitor(self.workers, result_handler) - result_handler.start() - self.worker_monitor.start() - - # Set up signal handlers to shutdown the executor cleanly - # sometimes gc does not work well - - # Use weakref to avoid holding a reference to self - ref = weakref.ref(self) - - def shutdown(signum, frame): - if executor := ref(): - executor.shutdown() - - if threading.current_thread() is threading.main_thread(): - signal.signal(signal.SIGINT, shutdown) - signal.signal(signal.SIGTERM, shutdown) - - self.driver_worker = self._create_worker( - distributed_init_method=distributed_init_method) - self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers) - - def shutdown(self): - if (worker_monitor := getattr(self, "worker_monitor", - None)) is not None: - worker_monitor.close() - - def _driver_execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - return self.driver_worker.execute_model(execute_model_req) - - def _get_create_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None): - worker_kwargs = self._get_worker_kwargs(local_rank, rank, - distributed_init_method) - worker_kwargs.update(worker_module_name="vllm.worker.habana_worker", - worker_class_name="HabanaWorker") - return worker_kwargs - - def _run_workers( - self, - method: str, - *args, - async_run_tensor_parallel_workers_only: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. - - Args: - async_run_tensor_parallel_workers_only: If True the method will be - run only in the remote TP workers, not the driver worker. - It will also be run asynchronously and return a list of futures - rather than blocking on the results. - """ - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - if async_run_tensor_parallel_workers_only: - # Run only non-driver workers and just return futures. - return [ - worker.execute_method(method, *args, **kwargs) - for worker in self.non_driver_workers - ] - - # Start all remote workers first. - worker_outputs = [ - worker.execute_method(method, *args, **kwargs) - for worker in self.workers - ] - - driver_worker_method = getattr(self.driver_worker, method) - driver_worker_output = driver_worker_method(*args, **kwargs) - - # Get the results of the workers. - return [driver_worker_output - ] + [output.get() for output in worker_outputs] - - def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - if self.worker_monitor is not None and not self.worker_monitor.is_alive( - ): - raise RuntimeError("Worker processes are not running") - - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - for result in parallel_worker_tasks: - result.get() - - -class MultiprocessingHPUExecutorAsync(MultiprocessingHPUExecutor, - DistributedGPUExecutorAsync): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.driver_exec_model = make_async(self.driver_worker.execute_model) - self.pp_locks: Optional[List[asyncio.Lock]] = None - - async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - if not self.tp_driver_workers: - return await self.driver_exec_model(execute_model_req) - - if self.pp_locks is None: - # This locks each pipeline parallel stage so multiple virtual - # engines can't execute on the same stage at the same time - # We create the locks here to avoid creating them in the constructor - # which uses a different asyncio loop. - self.pp_locks = [ - asyncio.Lock() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - tasks = [ - asyncio.create_task( - _run_task_with_lock(self.driver_exec_model, self.pp_locks[0], - execute_model_req)) - ] - for pp_rank, driver_worker in enumerate(self.tp_driver_workers, - start=1): - tasks.append( - asyncio.create_task( - _run_task_with_lock(driver_worker.execute_method_async, - self.pp_locks[pp_rank], - "execute_model", execute_model_req))) - results = await asyncio.gather(*tasks) - - # Only the last PP stage has the final results. - return results[-1] - - async def _start_worker_execution_loop(self): - coros = [ - worker.execute_method_async("start_worker_execution_loop") - for worker in self.non_driver_workers - ] - return await asyncio.gather(*coros) From f4de97bdac18959bb4a77073ce6fa5f04c0e7bfb Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:27:40 +0300 Subject: [PATCH 18/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 93a164e9ecd39..a26f11a0c45ee 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -176,7 +176,7 @@ In a dynamic inference serving scenario, there is a need to minimize the number .. note:: Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. -Whenever executing vLLM on HPU, the following log can be observed: +When executing vLLM on HPU, the following log can be observed: .. code-block:: From 6a3a0c3af6388f2d8a9748308251ec6588e34007 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 12 Aug 2024 18:30:13 +0300 Subject: [PATCH 19/25] update docs --- docs/source/getting_started/gaudi-installation.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index a26f11a0c45ee..56c7e643bbc70 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -18,7 +18,7 @@ Requirements - OS: Ubuntu 22.04 LTS - Python: 3.10 - Intel Gaudi accelerator -- Intel Gaudi software version 1.16.0 or newer +- Intel Gaudi software version 1.17.0 To verify that the Intel Gaudi software was correctly installed, run: @@ -44,8 +44,8 @@ Use the following commands to run a Docker image: .. code:: console - $ docker pull vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - $ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest + $ docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest + $ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest Build and Install vLLM --------------------------- From 64262a2dbff798e3d70f8529508dfcb54670c290 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Aug 2024 14:56:59 +0300 Subject: [PATCH 20/25] address cr --- .../getting_started/gaudi-installation.rst | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 56c7e643bbc70..9c946376e51b8 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -124,9 +124,9 @@ Gaudi2 devices. Configurations that are not listed may or may not work. with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `meta-llama/Meta-Llama-3-70B `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling -- `meta-llama/Meta-Llama-3-70B-Instruct `__ +- `meta-llama/Meta-Llama-3-70B-Instruct `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling -- `meta-llama/Meta-Llama-3.1-70B `__ +- `meta-llama/Meta-Llama-3.1-70B `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `meta-llama/Meta-Llama-3.1-70B-Instruct `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling @@ -176,7 +176,7 @@ In a dynamic inference serving scenario, there is a need to minimize the number .. note:: Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. -When executing vLLM on HPU, the following log can be observed: +Bucketing ranges are determined with 3 parameters - ``min``, ``step`` and ``max``. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: .. code-block:: @@ -185,25 +185,6 @@ When executing vLLM on HPU, the following log can be observed: INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] -In this scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. - -.. warning:: - If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. - -As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as ``(4, 512)`` prefill bucket, as ``batch_size`` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as ``(4, 512)`` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a ``(2, 512)`` bucket, or context length increases above 512 tokens, in which case it will become ``(4, 640)`` bucket. - -.. note:: - Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. - -Bucketing ranges are determined with 3 parameters - ``min``, ``step`` and ``max``. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: - -.. code-block:: - - INFO 08-02 15:30:53 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] - INFO 08-02 15:30:53 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] - INFO 08-02 15:30:53 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] - INFO 08-02 15:30:53 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - ``min`` determines the lowest value of the bucket. ``step`` determines the interval between buckets, and ``max`` determines the upper bound of the bucket. Furthermore, interval between ``min`` and ``step`` has special handling - ``min`` gets multiplied by consecutive powers of two, until ``step`` gets reached. We call this the ramp-up phase and it is used for handling lower batch sizes with minimum wastage, while allowing larger padding on larger batch sizes. Example (with ramp-up) @@ -225,6 +206,16 @@ Example (without ramp-up) => buckets = ramp_up + stable => (128, 256, 384, 512) +In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. + +.. warning:: + If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. + +As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as ``(4, 512)`` prefill bucket, as ``batch_size`` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as ``(4, 512)`` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a ``(2, 512)`` bucket, or context length increases above 512 tokens, in which case it will become ``(4, 640)`` bucket. + +.. note:: + Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. + Warmup ------------ @@ -315,7 +306,7 @@ Environment variables **Diagnostic and profiling knobs:** - ``VLLM_PROFILER_ENABLED``: if ``true``, high level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai `__. Disabled by default. -- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS``. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS=1``. Disabled by default. - ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL``: if ``true``, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default. - ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. - ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default. @@ -323,8 +314,9 @@ Environment variables **Performance tuning knobs:** - ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default -- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture +- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.4`` by default - ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default +- ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default - ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default - ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism From 32c50f2eb0922e2c0ad5d21434d5af16faaf19e4 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Aug 2024 15:20:08 +0300 Subject: [PATCH 21/25] document strategies --- docs/source/getting_started/gaudi-installation.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 9c946376e51b8..ef36e8da03004 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -246,7 +246,13 @@ HPU Graph capture `HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. -Whenever HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. +Whenever HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. + +User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: +- ``max_bs`` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. ``(64, 128)``, ``(64, 256)``, ``(32, 128)``, ``(32, 256)``, ``(1, 128)``, ``(1,256)``), default strategy for decode +- ``min_tokens`` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (``batch_size*sequence_length``), default strategy for prompt + +When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. Whenever a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by ``max_bs`` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in ``min_tokens`` strategy. .. note:: From 340bc239851a21ff63bb6fd1b4ef7457c14c00a9 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Aug 2024 15:34:53 +0300 Subject: [PATCH 22/25] clarify how gpu_mem_utilization works --- .../getting_started/gaudi-installation.rst | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index ef36e8da03004..256c6b401ec10 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -235,7 +235,7 @@ Warmup is an optional, but highly recommended step occuring before vLLM server s INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB -This example uses the same buckets as in *Bucketing mechanism* section. Each output line corresponds to execution of a single bucket. Whenever bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. +This example uses the same buckets as in *Bucketing mechanism* section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. .. tip:: Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. @@ -246,20 +246,30 @@ HPU Graph capture `HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. -Whenever HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. +When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). +Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. +Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, it will mark 90% of free device memory at that point as usable. +Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. +Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. +With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. +Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. +Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. + +.. note:: + ``gpu_memory_utilization`` does not correspond to the absolute memory usage across HPU. It describes the memory margin after loading the model and performing a profile run. User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: - ``max_bs`` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. ``(64, 128)``, ``(64, 256)``, ``(32, 128)``, ``(32, 256)``, ``(1, 128)``, ``(1,256)``), default strategy for decode - ``min_tokens`` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (``batch_size*sequence_length``), default strategy for prompt -When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. Whenever a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by ``max_bs`` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in ``min_tokens`` strategy. +When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by ``max_bs`` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in ``min_tokens`` strategy. .. note:: ``VLLM_GRAPH_PROMPT_RATIO`` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * ``VLLM_GRAPH_PROMPT_RATIO``) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. -Each described step is logged by vLLM server, as follows: +Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): .. code-block:: From 312abe4d63c598edc0e2cb83f68b59a58b6abaa0 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Aug 2024 15:39:36 +0300 Subject: [PATCH 23/25] clarify how gpu_mem_utilization works --- docs/source/getting_started/gaudi-installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 256c6b401ec10..95e9d6db9331a 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -256,7 +256,7 @@ Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. .. note:: - ``gpu_memory_utilization`` does not correspond to the absolute memory usage across HPU. It describes the memory margin after loading the model and performing a profile run. + ``gpu_memory_utilization`` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, ``gpu_memory_utilization`` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: - ``max_bs`` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. ``(64, 128)``, ``(64, 256)``, ``(32, 128)``, ``(32, 256)``, ``(1, 128)``, ``(1,256)``), default strategy for decode From f10e161047a4df7f28a951b56eb86525eddb4ba5 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Aug 2024 15:42:29 +0300 Subject: [PATCH 24/25] clarify how gpu_mem_utilization works --- docs/source/getting_started/gaudi-installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 95e9d6db9331a..7e73b0a2bdffc 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -248,7 +248,7 @@ HPU Graph capture When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. -Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, it will mark 90% of free device memory at that point as usable. +Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, 90% of free device memory will be marked at that point as usable. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. From eac13856b80fc1a7a3529ff21baf0df39ff2d0d3 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Aug 2024 15:47:59 +0300 Subject: [PATCH 25/25] fix typos --- docs/source/getting_started/gaudi-installation.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 7e73b0a2bdffc..7af291d62efc6 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -171,7 +171,7 @@ Bucketing mechanism ------------ Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler `__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. -In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occuring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``. +In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``. .. note:: Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. @@ -219,7 +219,7 @@ As an example, if a request of 3 sequences, with max sequence length of 412 come Warmup ------------ -Warmup is an optional, but highly recommended step occuring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundries during server runtime. Each warmup step is logged during vLLM startup: +Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: .. code-block:: @@ -248,7 +248,7 @@ HPU Graph capture When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. -Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, 90% of free device memory will be marked at that point as usable. +Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache.