From 66eae9e75e6e70a69eeefbe24e8a1f0499524a3b Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 13 Aug 2024 13:46:19 +0200 Subject: [PATCH 1/7] remove reminder_comment.yml (#179) --- .github/workflows/reminder_comment.yml | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 .github/workflows/reminder_comment.yml diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml deleted file mode 100644 index 390c88bb6530..000000000000 --- a/.github/workflows/reminder_comment.yml +++ /dev/null @@ -1,21 +0,0 @@ -name: PR Reminder Comment Bot -on: - pull_request_target: - types: [opened] - -jobs: - pr_reminder: - runs-on: ubuntu-latest - steps: - - name: Remind to run full CI on PR - uses: actions/github-script@v6 - with: - script: | - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: 'šŸ‘‹ Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\nšŸš€' - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} From c0984334c495762b10ee37dc817afad9fec0ef57 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 13 Aug 2024 13:46:33 +0200 Subject: [PATCH 2/7] Fix logger initialization in ops.py (#178) --- vllm/hpu/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 7a40e6e72025..c8f00c1cbd59 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -14,7 +14,7 @@ import vllm.hpu.utils as hpu_utils from vllm.logger import init_logger -logger = init_logger() +logger = init_logger(__name__) HPUFusedRMSNorm = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm From 6f047d864ba3f7b409eeaedfd1e92f61389d31da Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Aug 2024 14:53:48 +0200 Subject: [PATCH 3/7] 1.17 documentation update (#172) --- .../getting_started/gaudi-installation.rst | 234 +++++++++++++++++- 1 file changed, 230 insertions(+), 4 deletions(-) diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index a9f3ebdf274f..7af291d62efc 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -18,7 +18,7 @@ Requirements - OS: Ubuntu 22.04 LTS - Python: 3.10 - Intel Gaudi accelerator -- Intel Gaudi software version 1.16.0 or newer +- Intel Gaudi software version 1.17.0 To verify that the Intel Gaudi software was correctly installed, run: @@ -44,8 +44,8 @@ Use the following commands to run a Docker image: .. code:: console - $ docker pull vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - $ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest + $ docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest + $ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest Build and Install vLLM --------------------------- @@ -112,6 +112,12 @@ Gaudi2 devices. Configurations that are not listed may or may not work. - `meta-llama/Meta-Llama-3-8B-Instruct `__ on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-8B `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-8B-Instruct `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling - `meta-llama/Llama-2-70b `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `meta-llama/Llama-2-70b-chat-hf `__ @@ -120,14 +126,187 @@ Gaudi2 devices. Configurations that are not listed may or may not work. with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `meta-llama/Meta-Llama-3-70B-Instruct `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-70B `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-70B-Instruct `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `mistralai/Mistral-7B-Instruct-v0.3 `__ on single HPU or with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling - `mistralai/Mixtral-8x7B-Instruct-v0.1 `__ with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling -Performance Tips +Performance Tuning ================ +Execution modes +------------ + +Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via ``PT_HPU_LAZY_MODE`` environment variable), and ``--enforce-eager`` flag. + +.. list-table:: vLLM execution modes + :widths: 25 25 50 + :header-rows: 1 + + * - ``PT_HPU_LAZY_MODE`` + - ``enforce_eager`` + - execution mode + * - 0 + - 0 + - torch.compile + * - 0 + - 1 + - PyTorch eager mode + * - 1 + - 0 + - HPU Graphs + * - 1 + - 1 + - PyTorch lazy mode + +.. warning:: + In 1.17.0, all modes utilizing ``PT_HPU_LAZY_MODE=0`` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.17.0, please use HPU Graphs, or PyTorch lazy mode. + + +Bucketing mechanism +------------ + +Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler `__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. +In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``. + +.. note:: + Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. + +Bucketing ranges are determined with 3 parameters - ``min``, ``step`` and ``max``. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: + +.. code-block:: + + INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + +``min`` determines the lowest value of the bucket. ``step`` determines the interval between buckets, and ``max`` determines the upper bound of the bucket. Furthermore, interval between ``min`` and ``step`` has special handling - ``min`` gets multiplied by consecutive powers of two, until ``step`` gets reached. We call this the ramp-up phase and it is used for handling lower batch sizes with minimum wastage, while allowing larger padding on larger batch sizes. + +Example (with ramp-up) + +.. code-block:: + + min = 2, step = 32, max = 64 + => ramp_up = (2, 4, 8, 16) + => stable = (32, 64) + => buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) + +Example (without ramp-up) + +.. code-block:: + + min = 128, step = 128, max = 512 + => ramp_up = () + => stable = (128, 256, 384, 512) + => buckets = ramp_up + stable => (128, 256, 384, 512) + + +In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. + +.. warning:: + If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. + +As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as ``(4, 512)`` prefill bucket, as ``batch_size`` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as ``(4, 512)`` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a ``(2, 512)`` bucket, or context length increases above 512 tokens, in which case it will become ``(4, 640)`` bucket. + +.. note:: + Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. + +Warmup +------------ + +Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: + +.. code-block:: + + INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB + INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB + INFO 08-01 22:26:48 habana_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB + ... + INFO 08-01 22:26:59 habana_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB + INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB + INFO 08-01 22:27:01 habana_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB + ... + INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + +This example uses the same buckets as in *Bucketing mechanism* section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. + +.. tip:: + Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. + +HPU Graph capture +------------ + +`HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. + + +When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). +Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. +Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. +Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. +Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. +With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. +Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. +Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. + +.. note:: + ``gpu_memory_utilization`` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, ``gpu_memory_utilization`` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. + +User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: +- ``max_bs`` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. ``(64, 128)``, ``(64, 256)``, ``(32, 128)``, ``(32, 256)``, ``(1, 128)``, ``(1,256)``), default strategy for decode +- ``min_tokens`` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (``batch_size*sequence_length``), default strategy for prompt + +When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by ``max_bs`` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in ``min_tokens`` strategy. + + +.. note:: + ``VLLM_GRAPH_PROMPT_RATIO`` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * ``VLLM_GRAPH_PROMPT_RATIO``) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. + + +Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): + +.. code-block:: + + INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-02 17:37:44 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:37:52 habana_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 habana_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 habana_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache + INFO 08-02 17:37:54 habana_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 + INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB + ... + INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5) + INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + ... + INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB + INFO 08-02 17:38:27 habana_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB + ... + INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB + INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB + INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB + INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB + INFO 08-02 17:38:43 habana_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB + INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] + INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:38:43 habana_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory + INFO 08-02 17:38:43 habana_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) + + +Recommended vLLM Parameters +------------ + - We recommend running inference on Gaudi 2 with ``block_size`` of 128 for BF16 data type. Using default values (16, 32) might lead to sub-optimal performance due to Matrix Multiplication Engine @@ -137,6 +316,53 @@ Performance Tips of 128 or 256 and max context length of 2048 with HPU Graphs enabled. If you encounter out-of-memory issues, see troubleshooting section. +Environment variables +------------ + +**Diagnostic and profiling knobs:** + +- ``VLLM_PROFILER_ENABLED``: if ``true``, high level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai `__. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS=1``. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL``: if ``true``, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default. +- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. +- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default. + +**Performance tuning knobs:** + +- ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default +- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.4`` by default +- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default +- ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default +- ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default +- ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism + + - ``{phase}`` is either ``PROMPT`` or ``DECODE`` + - ``{dim}`` is either ``BS`` or ``SEQ`` + - ``{param}`` is either ``MIN``, ``STEP`` or ``MAX`` + - Default values: + + - Prompt: + - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1`` + - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32`` + - batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)`` + - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size`` + - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size`` + - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024`` + + - Decode: + - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1`` + - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128`` + - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs`` + - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size`` + - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size`` + - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048`` + + +Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: + +- ``PT_HPU_LAZY_MODE``: if ``0``, PyTorch Eager backend for Gaudi will be used, if ``1`` PyTorch Lazy backend for Gaudi will be used, ``1`` is default +- ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPU Graphs + Troubleshooting: Tweaking HPU Graphs ==================================== From 1e0e492e1400114f9156d61ffdd73585181ed119 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Aug 2024 15:06:19 +0200 Subject: [PATCH 4/7] Readme 1.17 update (#186) FILL IN THE PR DESCRIPTION HERE FIX #xxxx (*link existing issues this PR will resolve*) **BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE** ---
PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

--- README_GAUDI.md | 497 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 435 insertions(+), 62 deletions(-) diff --git a/README_GAUDI.md b/README_GAUDI.md index 1a1b2d9cc6e3..a569d6314acf 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -1,25 +1,25 @@ -# vLLM with IntelĀ® GaudiĀ® 2 AI Accelerators +vLLM with IntelĀ® GaudiĀ® AI Accelerators +======================================= -This README provides instructions on running vLLM with Intel Gaudi devices. +This README provides instructions on running vLLM with Intel Gaudi +devices. Requirements and Installation -============================== +============================= -Please follow the instructions provided in the [Gaudi Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) -to set up the environment. To achieve the best performance, please follow the methods outlined in the -[Optimizing Training Platform Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). - -> [!NOTE] -> In this release (1.16.0), we are only targeting functionality and -> accuracy. Performance will be improved in next releases. +Please follow the instructions provided in the [Gaudi Installation +Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) +to set up the environment. To achieve the best performance, please +follow the methods outlined in the [Optimizing Training Platform +Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). Requirements -------------- +------------ - OS: Ubuntu 22.04 LTS - Python: 3.10 -- Intel Gaudi 2 accelerator -- Intel Gaudi software version 1.16.0 +- Intel Gaudi accelerator +- Intel Gaudi software version 1.17.0 To verify that the Intel Gaudi software was correctly installed, run: @@ -29,41 +29,50 @@ $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, ha $ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed ``` -Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) for more details. +Refer to [Intel Gaudi Software Stack +Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) +for more details. Run Docker Image ------------------- +---------------- -It is highly recommended to use the latest Docker image from Intel -Gaudi vault. Refer to the [Intel Gaudi documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) for more details. +It is highly recommended to use the latest Docker image from Intel Gaudi +vault. Refer to the [Intel Gaudi +documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) +for more details. Use the following commands to run a Docker image: ``` {.console} -$ docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest -$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest - ``` +$ docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest +$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest +``` -Build and Install vLLM-fork ------------------------------ +Build and Install vLLM +---------------------- -To build and install vLLM-fork from source, run: +Currently, the latest features and performance optimizations are +developed in Gaudi\'s [vLLM-fork](https://github.com/HabanaAI/vllm-fork) +and we periodically upstream them to vLLM main repo. To install latest +[HabanaAI/vLLM-fork](https://github.com/HabanaAI/vllm-fork), run the +following: ``` {.console} $ git clone https://github.com/HabanaAI/vllm-fork.git $ cd vllm-fork -# git checkout v0.4.2-Gaudi-1.16.0 -$ pip install -e . # This may take 5-10 minutes. +$ git checkout habana_main +$ python setup.py develop ``` Supported Features ================== -- [Offline batched inference](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#offline-batched-inference) -- Online inference via [OpenAI-Compatible Server](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server) +- [Offline batched + inference](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#offline-batched-inference) +- Online inference via [OpenAI-Compatible + Server](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server) - HPU autodetection - no need to manually select device within vLLM -- Paged KV cache with algorithms enabled for Intel Gaudi 2 - accelerators +- Paged KV cache with algorithms enabled for Intel Gaudi accelerators - Custom Intel Gaudi implementations of Paged Attention, KV cache ops, prefill attention, Root Mean Square Layer Normalization, Rotary Positional Encoding @@ -72,7 +81,6 @@ Supported Features Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) for accelerating low-batch latency and throughput - Unsupported Features ==================== @@ -82,11 +90,11 @@ Unsupported Features - Quantization (AWQ, FP8 E5M2, FP8 E4M3) - Prefill chunking (mixed-batch inferencing) - Supported Configurations ======================== -The following configurations have been validated to be function with Gaudi devices. Configurations that are not listed may or may not work. +The following configurations have been validated to be function with +Gaudi2 devices. Configurations that are not listed may or may not work. - [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 @@ -94,47 +102,412 @@ The following configurations have been validated to be function with Gaudi devic - [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling +- [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling - [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) - with tensor parallelism on 8x HPU, BF16 datatype with random - or greedy sampling + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling - [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) - with tensor parallelism 8x HPU, BF16 datatype with random - or greedy sampling + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling +- [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling +- [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling +- [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling +- [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling - [mistralai/Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) - on single HPU or with tensor parallelism 2x HPU, BF16 datatype with random or greedy sampling + on single HPU or with tensor parallelism on 2x HPU, BF16 datatype + with random or greedy sampling - [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) - with tensor parallelism 2x HPU, BF16 datatype with random or greedy sampling + with tensor parallelism on 2x HPU, BF16 datatype with random or + greedy sampling + +Performance Tuning +================ +Execution modes +----------------------------- +Currently in vLLM for HPU we support four execution modes, depending on +selected HPU PyTorch Bridge backend (via `PT_HPU_LAZY_MODE` environment +variable), and `--enforce-eager` flag. -Performance Tips -================ +| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode | +|--- |--- |--- | +| 0 | 0 | torch.compile | +| 0 | 1 | PyTorch eager mode | +| 1 | 0 | HPU Graphs | +| 1 | 1 | PyTorch lazy mode | + + +> [!WARNING] +> In 1.17.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly +> experimental and should be only used for validating functional +> correctness. Their performance will be improved in the next releases. +> For obtaining the best performance in 1.17.0, please use HPU Graphs, or +> PyTorch lazy mode. + +Bucketing mechanism +----------------------------- + +Intel Gaudi accelerators work best when operating on models with fixed +tensor shapes. [Intel Gaudi Graph +Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime) +is responsible for generating optimized binary code that implements the +given model topology on Gaudi. In its default configuration, the +produced binary code may be heavily dependent on input and output tensor +shapes, and can require graph recompilation when encountering +differently shaped tensors within the same topology. While the resulting +binaries utilize Gaudi efficiently, the compilation itself may introduce +a noticeable overhead in end-to-end execution. In a dynamic inference +serving scenario, there is a need to minimize the number of graph +compilations and reduce the risk of graph compilation occurring during +server runtime. Currently it is achieved by \"bucketing\" model\'s +forward pass across two dimensions - `batch_size` and `sequence_length`. + +> [!NOTE] +> Bucketing allows us to reduce the number of required graphs +> significantly, but it does not handle any graph compilation and device +> code generation - this is done in warmup and HPUGraph capture phase. + +Bucketing ranges are determined with 3 parameters - `min`, `step` and +`max`. They can be set separately for prompt and decode phase, and for +batch size and sequence length dimension. These parameters can be +observed in logs during vLLM startup: + +``` {.} +INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] +INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] +INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] +INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +``` -- We recommend running inference on Gaudi 2 with - `block_size` of 128 for BF16 data type. Using default - values (16, 32) might lead to sub-optimal performance due to Matrix - Multiplication Engine under-utilization (see [Gaudi +`min` determines the lowest value of the bucket. `step` determines the +interval between buckets, and `max` determines the upper bound of the +bucket. Furthermore, interval between `min` and `step` has special +handling - `min` gets multiplied by consecutive powers of two, until +`step` gets reached. We call this the ramp-up phase and it is used for +handling lower batch sizes with minimum wastage, while allowing larger +padding on larger batch sizes. + +Example (with ramp-up) + +``` {.} +min = 2, step = 32, max = 64 +=> ramp_up = (2, 4, 8, 16) +=> stable = (32, 64) +=> buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) +``` + +Example (without ramp-up) + +``` {.} +min = 128, step = 128, max = 512 +=> ramp_up = () +=> stable = (128, 256, 384, 512) +=> buckets = ramp_up + stable => (128, 256, 384, 512) +``` + +In the logged scenario, 24 buckets were generated for prompt (prefill) +runs, and 48 buckets for decode runs. Each bucket corresponds to a +separate optimized device binary for a given model with specified tensor +shapes. Whenever a batch of requests is processed, it is padded across +batch and sequence length dimension to the smallest possible bucket. + +> [!WARNING] +> If a request exceeds maximum bucket size in any dimension, it will be +> processed without padding, and its processing may require a graph +> compilation, potentially significantly increasing end-to-end latency. +> The boundaries of the buckets are user-configurable via environment +> variables, and upper bucket boundaries can be increased to avoid such +> scenario. + +As an example, if a request of 3 sequences, with max sequence length of +412 comes in to an idle vLLM server, it will be padded executed as +`(4, 512)` prefill bucket, as `batch_size` (number of sequences) will be +padded to 4 (closest batch\_size dimension higher than 3), and max +sequence length will be padded to 512 (closest sequence length dimension +higher than 412). After prefill stage, it will be executed as `(4, 512)` +decode bucket and will continue as that bucket until either batch +dimension changes (due to request being finished) - in which case it +will become a `(2, 512)` bucket, or context length increases above 512 +tokens, in which case it will become `(4, 640)` bucket. + +> [!NOTE] +> Bucketing is transparent to a client - padding in sequence length +> dimension is never returned to the client, and padding in batch +> dimension does not create new requests. + +Warmup +------ + +Warmup is an optional, but highly recommended step occurring before vLLM +server starts listening. It executes a forward pass for each bucket with +dummy data. The goal is to pre-compile all graphs and not incur any +graph compilation overheads within bucket boundaries during server +runtime. Each warmup step is logged during vLLM startup: + +``` {.} +INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB +INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB +INFO 08-01 22:26:48 habana_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB +... +INFO 08-01 22:26:59 habana_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB +INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB +INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB +INFO 08-01 22:27:01 habana_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB +... +INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB +INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB +``` + +This example uses the same buckets as in *Bucketing mechanism* section. +Each output line corresponds to execution of a single bucket. When +bucket is executed for the first time, its graph is compiled and can be +reused later on, skipping further graph compilations. + +> [!TIP] +> Compiling all the buckets might take some time and can be turned off +> with `VLLM_SKIP_WARMUP=true` environment variable. Keep in mind that if +> you do that, you may face graph compilations once executing a given +> bucket for the first time. It is fine to disable warmup for development, +> but it\'s highly recommended to enable it in deployment. + +HPU Graph capture +----------------------------- + +[HPU +Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) +are currently the most performant execution method of vLLM on Intel +Gaudi. When HPU Graphs are enabled, execution graphs will be traced +(recorded) ahead of time (after performing warmup), to be later replayed +during inference, significantly reducing host overheads. Recording can +take large amounts of memory, which needs to be taken into account when +allocating KV cache. Enabling HPU Graphs will impact the number of +available KV cache blocks, but vLLM provides user-configurable variables +to control memory management. + +When HPU Graphs are being used, they share the common memory pool +(\"usable memory\") as KV cache, determined by `gpu_memory_utilization` +flag (`0.9` by default). Before KV cache gets allocated, model weights +are loaded onto the device, and a forward pass of the model is executed +on dummy data, to estimate memory usage. Only after that, +`gpu_memory_utilization` flag is utilized - at its default value, will +mark 90% of free device memory at that point as usable. Next, KV cache +gets allocated, model is warmed up, and HPU Graphs are captured. +Environment variable `VLLM_GRAPH_RESERVED_MEM` defines the ratio of +memory reserved for HPU Graphs capture. With its default value +(`VLLM_GRAPH_RESERVED_MEM=0.4`), 40% of usable memory will be reserved +for graph capture (later referred to as \"usable graph memory\"), and +the remaining 60% will be utilized for KV cache. Environment variable +`VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory +reserved for prefill and decode graphs. By default +(`VLLM_GRAPH_PROMPT_RATIO=0.5`), both stages have equal memory +constraints. Lower value corresponds to less usable graph memory +reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will +reserve 20% of usable graph memory for prefill graphs, and 80% of usable +graph memory for decode graphs. + +> [!NOTE] +> `gpu_memory_utilization` does not correspond to the absolute memory +> usage across HPU. It specifies the memory margin after loading the model +> and performing a profile run. If device has 100 GiB of total memory, and +> 50 GiB of free memory after loading model weights and executing +> profiling run, `gpu_memory_utilization` at its default value will mark +> 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total +> device memory. + +User can also configure the strategy for capturing HPU Graphs for prompt +and decode stages separately. Strategy affects the order of capturing +graphs. There are two strategies implemented: - `max_bs` - graph capture +queue will sorted in descending order by their batch sizes. Buckets with +equal batch sizes are sorted by sequence length in ascending order (e.g. +`(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, +`(1,256)`), default strategy for decode - `min_tokens` - graph capture +queue will be sorted in ascending order by the number of tokens each +graph processes (`batch_size*sequence_length`), default strategy for +prompt + +When there\'s large amount of requests pending, vLLM scheduler will +attempt to fill the maximum batch size for decode as soon as possible. +When a request is finished, decode batch size decreases. When that +happens, vLLM will attempt to schedule a prefill iteration for requests +in the waiting queue, to fill the decode batch size to its previous +state. This means that in a full load scenario, decode batch size is +often at its maximum, which makes large batch size HPU Graphs crucial to +capture, as reflected by `max_bs` strategy. On the other hand, prefills +will be executed most frequently with very low batch sizes (1-4), which +is reflected in `min_tokens` strategy. + +> [!NOTE] +> `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by +> graphs for each stage (prefill and decode). vLLM will first attempt to +> use up entirety of usable prefill graph memory (usable graph memory \* +> `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it +> will attempt do the same for decode graphs and usable decode graph +> memory pool. If one stage is fully captured, and there is unused memory +> left within usable graph memory pool, vLLM will attempt further graph +> capture for the other stage, until no more HPU Graphs can be captured +> without exceeding reserved memory pool. The behavior on that mechanism +> can be observed in the example below. + +Each described step is logged by vLLM server, as follows (negative +values correspond to memory being released): + +``` {.} +INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] +INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] +INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] +INFO 08-02 17:37:44 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +INFO 08-02 17:37:52 habana_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:52 habana_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:52 habana_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:54 habana_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) +INFO 08-02 17:37:54 habana_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache +INFO 08-02 17:37:54 habana_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 +INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) +INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB +... +INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB +INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5) +INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB +... +INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB +INFO 08-02 17:38:27 habana_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB +... +INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB +INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB +INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB +INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB +INFO 08-02 17:38:43 habana_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB +INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] +INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +INFO 08-02 17:38:43 habana_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory +INFO 08-02 17:38:43 habana_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) +``` + +Recommended vLLM Parameters +----------------------------- + +- We recommend running inference on Gaudi 2 with `block_size` of 128 + for BF16 data type. Using default values (16, 32) might lead to + sub-optimal performance due to Matrix Multiplication Engine + under-utilization (see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). - For max throughput on Llama 7B, we recommend running with batch size - of 128 or 256 and max context length of 2048 with HPU Graphs enabled. - If you encounter out-of-memory issues, see troubleshooting section. + of 128 or 256 and max context length of 2048 with HPU Graphs + enabled. If you encounter out-of-memory issues, see troubleshooting + section. + +Environment variables +----------------------------- + +**Diagnostic and profiling knobs:** + +- `VLLM_PROFILER_ENABLED`: if `true`, high level profiler will be + enabled. Resulting JSON traces can be viewed in + [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). Disabled + by default. +- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: if `true`, will log graph + compilations per each vLLM engine step, only when there was any - + highly recommended to use alongside `PT_HPU_METRICS_GC_DETAILS=1`. + Disabled by default. +- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: if `true`, will log graph + compilations per each vLLM engine step, always, even if there were + none. Disabled by default. +- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: if `true`, will log cpu fallbacks + per each vLLM engine step, only when there was any. Disabled by + default. +- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true`, will log cpu + fallbacks per each vLLM engine step, always, even if there were + none. Disabled by default. + +**Performance tuning knobs:** + +- `VLLM_SKIP_WARMUP`: if `true`, warmup will be skipped, `false` by + default +- `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for + HPUGraph capture, `0.4` by default +- `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory + dedicated for prompt graphs, `0.5` by default +- `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt + graph capture, `min_tokens` or `max_bs`, `min_tokens` by default +- `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode + graph capture, `min_tokens` or `max_bs`, `max_bs` by default +- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment + variables configuring ranges of bucketing mechanism + - `{phase}` is either `PROMPT` or `DECODE` + - `{dim}` is either `BS` or `SEQ` + - `{param}` is either `MIN`, `STEP` or `MAX` + - Default values: + - Prompt: + - batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1` + - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32` + - batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`): + `min(max_num_seqs, 64)` + - sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`): + `block_size` + - sequence length step + (`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size` + - sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`): + `1024` + + - Decode: + - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1` + - batch size step (`VLLM_DECODE_BS_BUCKET_STEP`): + `128` + - batch size max (`VLLM_DECODE_BS_BUCKET_MAX`): + `max_num_seqs` + - sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`): + `block_size` + - sequence length step + (`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size` + - sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`): + `2048` + +Additionally, there are HPU PyTorch Bridge environment variables +impacting vLLM execution: + +- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be + used, if `1` PyTorch Lazy backend for Gaudi will be used, `1` is + default +- `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor + parallel inference with HPU Graphs Troubleshooting: Tweaking HPU Graphs ==================================== -If you experience device out-of-memory issues or want to attempt inference at higher batch sizes, try tweaking HPU Graphs by following the below: - -- Tweak `gpu_memory_utilization` knob. It - will decrease the allocation of KV cache, leaving some headroom for - capturing graphs with larger batch size. By default `gpu_memory_utilization` is set to 0.9. - It attempts to allocate \~90% of HBM left for KV cache after short - profiling run. Note that decreasing reduces the number of KV - cache blocks you have available, and therefore reduces the effective - maximum number of tokens you can handle at a given time. - -- If this method is not efficient, you can disable `HPUGraph` completely. With - HPU Graphs disabled, you are trading latency and throughput at lower - batches for potentially higher throughput on higher batches. You can do - that by adding `--enforce-eager` flag to server (for - online inference), or by passing `enforce_eager=True` - argument to LLM constructor (for offline inference). +If you experience device out-of-memory issues or want to attempt +inference at higher batch sizes, try tweaking HPU Graphs by following +the below: + +- Tweak `gpu_memory_utilization` knob. It will decrease the allocation + of KV cache, leaving some headroom for capturing graphs with larger + batch size. By default `gpu_memory_utilization` is set to 0.9. It + attempts to allocate \~90% of HBM left for KV cache after short + profiling run. Note that decreasing reduces the number of KV cache + blocks you have available, and therefore reduces the effective + maximum number of tokens you can handle at a given time. +- If this method is not efficient, you can disable `HPUGraph` + completely. With HPU Graphs disabled, you are trading latency and + throughput at lower batches for potentially higher throughput on + higher batches. You can do that by adding `--enforce-eager` flag to + server (for online inference), or by passing `enforce_eager=True` + argument to LLM constructor (for offline inference). From b0112c3a9a075e83f5bb98127586d925402f3614 Mon Sep 17 00:00:00 2001 From: Nir David <124874956+nirda7@users.noreply.github.com> Date: Wed, 14 Aug 2024 19:34:25 +0300 Subject: [PATCH 5/7] Support FP8 INC in vLLM (#144) FILL IN THE PR DESCRIPTION HERE FIX #xxxx (*link existing issues this PR will resolve*) **BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE** ---
PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

--- README_GAUDI.md | 3 +- .../getting_started/gaudi-installation.rst | 3 +- vllm/attention/backends/habana_attn.py | 26 +++- vllm/attention/ops/habana_paged_attn.py | 10 ++ vllm/config.py | 8 +- vllm/engine/arg_utils.py | 14 ++- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/llm.py | 3 + vllm/executor/habana_executor.py | 9 ++ vllm/executor/ray_habana_executor.py | 3 + vllm/hpu/cache_ops.py | 31 +++++ vllm/hpu/ops.py | 33 +++-- vllm/hpu/utils.py | 40 ++++++ vllm/model_executor/layers/layernorm.py | 11 +- vllm/model_executor/layers/linear.py | 10 +- .../layers/quantization/__init__.py | 2 + .../model_executor/layers/quantization/inc.py | 115 ++++++++++++++++++ vllm/model_executor/model_loader/loader.py | 22 ++-- vllm/model_executor/models/llama.py | 6 + vllm/utils.py | 1 + vllm/worker/cache_engine.py | 4 +- vllm/worker/habana_model_runner.py | 57 ++++++++- vllm/worker/habana_worker.py | 21 ++++ 23 files changed, 387 insertions(+), 51 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/inc.py diff --git a/README_GAUDI.md b/README_GAUDI.md index a569d6314acf..9ea30a2e43f6 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -26,7 +26,8 @@ To verify that the Intel Gaudi software was correctly installed, run: ``` {.console} $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed -$ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed +$ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed +$ pip list | grep neural # verify that neural-compressor is installed ``` Refer to [Intel Gaudi Software Stack diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 7af291d62efc..ddbac022a8d9 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -26,7 +26,8 @@ To verify that the Intel Gaudi software was correctly installed, run: $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed - $ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed + $ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed + $ pip list | grep neural # verify that neural_compressor is installed Refer to `Intel Gaudi Software Stack Verification `__ diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 33b6e2e538b1..7a867e79b203 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -12,6 +12,8 @@ AttentionMetadata, AttentionType) from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) +from vllm.hpu import cache_ops +from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache from vllm.logger import init_logger logger = init_logger(__name__) @@ -108,7 +110,7 @@ def __post_init__(self): self.attn_bias: Optional[torch.Tensor] = None -class HabanaAttentionImpl(AttentionImpl): +class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -137,10 +139,16 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, ) -> None: + super(AttentionImpl, self).__init__() self.kv_cache_dtype = kv_cache_dtype self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window self.position_bias = None @@ -204,9 +212,13 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - HabanaPagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, attn_metadata.is_prompt) + num_kv_cache_passes, num_slots_available, indices, offsets = \ + cache_ops.prepare_to_cache(key_cache, + attn_metadata.slot_mapping) + key_cache = self.k_cache(key, key_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) + value_cache = self.v_cache(value, value_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) if attn_metadata.is_prompt: # Prompt run. @@ -232,6 +244,9 @@ def forward( attn_bias=attn_bias, p=0.0, scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, ) output = out.reshape(batch_size, seq_len, hidden_size) else: @@ -255,7 +270,8 @@ def forward( query, key_cache, value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, self.kv_cache_dtype, self.num_kv_heads, self.scale, self.position_bias, k_scale, - v_scale) + v_scale, self.matmul_qk, self.softmax, self.matmul_av, + self.k_cache, self.v_cache) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 7dd701c7a0cd..9602886299c4 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -75,6 +75,11 @@ def forward_decode( alibi_slopes: Optional[torch.Tensor], k_scale: float, v_scale: float, + matmul_qk_op, + softmax_op, + matmul_av_op, + k_cache_cls, + v_cache_cls, ) -> torch.Tensor: block_size = value_cache.shape[1] return ops.paged_attention_v1( @@ -88,6 +93,11 @@ def forward_decode( block_size, alibi_slopes, kv_cache_dtype, + matmul_qk_op, + softmax_op, + matmul_av_op, + k_cache_cls, + v_cache_cls, ) @staticmethod diff --git a/vllm/config.py b/vllm/config.py index f16bea16fe64..6acb70ad047b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -474,12 +474,13 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor") + "scaling factor. " + "Intel Gaudi (HPU) supports fp8 (using fp8_inc).") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -600,11 +601,12 @@ class LoadConfig: ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - + device: Device on which weights are loaded. """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO download_dir: Optional[str] = None + device: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field( default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e4b223a1b505..d6c544750afe 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -38,6 +38,7 @@ class EngineArgs: trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' + weights_load_device: Optional[str] = None dtype: str = 'auto' kv_cache_dtype: str = 'auto' quantization_param_path: Optional[str] = None @@ -205,6 +206,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument("--weights-load-device", + type=str, + default=EngineArgs.weights_load_device, + choices=["cuda", "neuron", "hpu", "cpu"], + help='Device on which weights are loaded.') parser.add_argument( '--dtype', type=str, @@ -223,11 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). ' + 'Intel Gaudi (HPU) supports fp8 (using fp8_inc).') parser.add_argument( '--quantization-param-path', type=nullable_str, @@ -835,9 +842,12 @@ def create_engine_config(self, ) -> EngineConfig: self.model_loader_extra_config[ "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + device = device_config.device if self.weights_load_device is None else \ + self.weights_load_device load_config = LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + device=device, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f7e0a7a4dc5..f8b9c48bc958 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -182,7 +182,7 @@ def __init__( "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "pipeline_parallel_size=%d, " "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " + "weights_load_device=%s, enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " @@ -206,6 +206,7 @@ def __init__( parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, model_config.quantization, + load_config.device, model_config.enforce_eager, cache_config.cache_dtype, model_config.quantization_param_path, @@ -853,6 +854,9 @@ def _process_model_outputs( request_outputs.append(request_output) return request_outputs + def finish_measurements(self): + self.model_executor.finish_measurements() + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1..fc9f118ff14b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -173,6 +173,9 @@ def set_tokenizer( self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( tokenizer) + def finish_measurements(self): + self.llm_engine.finish_measurements() + @overload # LEGACY: single (prompt + optional token ids) def generate( self, diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py index f5cf26b68705..80f8037a2d04 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/habana_executor.py @@ -90,6 +90,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: msg = f"init_cache_engine took {cache_init_m.get_summary_string()}" logger.info(msg) + def finish_measurements(self): + self.driver_worker.finish_measurements() + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: @@ -180,6 +183,12 @@ def check_health(self) -> None: # it's running. return + def shutdown(self) -> None: + self.driver_worker.shutdown_inc() + + def __del__(self): + self.shutdown() + class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_habana_executor.py index 9e0a89cbeb8a..17e3414a96b5 100644 --- a/vllm/executor/ray_habana_executor.py +++ b/vllm/executor/ray_habana_executor.py @@ -237,6 +237,9 @@ def _driver_execute_model( return self.driver_worker.execute_method("execute_model", execute_model_req) + def finish_measurements(self): + self._run_workers("finish_measurements") + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py index 14824945aa53..98f109accea0 100644 --- a/vllm/hpu/cache_ops.py +++ b/vllm/hpu/cache_ops.py @@ -43,6 +43,37 @@ def reshape_and_cache(key, value[start_idx:end_idx]) +def prepare_to_cache(cache, slot_mapping): + num_blocks = cache.size(0) + block_size = cache.size(1) + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + offsets = torch.fmod(slot_mapping, block_size) + num_slots_requested = slot_mapping.size(0) + num_slots_available = num_blocks * block_size + # NOTE(kzawora): HPU PT bridge crashes with + # RuntimeError: Invalid inputs for scatter_nd_onnx + # on index_put when num_slots_requested > num_slots_available. + # This case might occur when we have little kv cache blocks and + # lots of padding, or are doing warmup. + # This loop is a workaround for this issue. Please remove it + # once key_cache.index_put_(indices, offsets), key) works. + num_kv_cache_passes = torch.div(num_slots_requested, + num_slots_available).ceil().int().item() + + return num_kv_cache_passes, num_slots_available, indices, offsets + + +def insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, block_offsets): + for i in range(num_kv_cache_passes): + start_idx = i * num_slots_available + end_idx = (i + 1) * num_slots_available + cache.index_put_((block_indices[start_idx:end_idx], + block_offsets[start_idx:end_idx]), + input[start_idx:end_idx]) + + def swap_blocks(src, dst, block_mapping): index_src = torch.zeros((1, ), dtype=torch.int32, device=src.device) index_dst = torch.zeros((1, ), dtype=torch.int32, device=dst.device) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index c8f00c1cbd59..23f6964723d3 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -11,7 +11,6 @@ import torch import torch.nn.functional as F -import vllm.hpu.utils as hpu_utils from vllm.logger import init_logger logger = init_logger(__name__) @@ -33,7 +32,6 @@ def fetch_from_cache(cache, blocks, permutations): ] -@hpu_utils.with_mark_steps def paged_attention_v1(query, key_cache, value_cache, @@ -43,7 +41,12 @@ def paged_attention_v1(query, context_lens, block_size, alibi_slopes=None, - kv_cache_dtype=None) -> None: + kv_cache_dtype=None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, + k_cache_cls=None, + v_cache_cls=None) -> None: seq_len = block_tables.size(1) batch_size, query_heads, _ = query.shape _, _, kv_heads, _ = key_cache.shape @@ -56,19 +59,23 @@ def paged_attention_v1(query, batch_size, 1, 1, -1)) query.mul_(scale) query = query.unsqueeze(-2) - keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) + fetch_keys = fetch_from_cache if k_cache_cls is None else \ + k_cache_cls.fetch_from_cache + keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] mask = mask.unsqueeze(2) - attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1) + attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1) if alibi_slopes is not None: attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, -attn_weights.size(3):]) - attn_weights = (attn_weights.masked_fill(mask, min_inf).softmax(dim=-1)) + attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3)) + fetch_values = fetch_from_cache if v_cache_cls is None else \ + v_cache_cls.fetch_from_cache + values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) if PA_SPLIT_VALUE: attn_weights = attn_weights.split(block_size, dim=-1) else: @@ -76,7 +83,7 @@ def paged_attention_v1(query, attn_weights = [attn_weights] if query_heads != kv_heads: values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)] + attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)] if query_heads != kv_heads: attn_weights = [a.flatten(1, 2) for a in attn_weights] attn_weights = sum(attn_weights) @@ -119,7 +126,6 @@ def static_fused_moe(hidden_states, w1, w2, score, topk): return final_hidden_states.view(-1, D) -@hpu_utils.with_mark_steps def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -127,6 +133,9 @@ def prompt_attention( attn_bias: Optional[torch.Tensor] = None, p: float = 0.0, scale: Optional[float] = None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, ) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -139,11 +148,11 @@ def prompt_attention( value = value.unflatten(1, (kv_heads, 1)) if attn_bias is not None: attn_bias = attn_bias.unsqueeze(2) - attn_weights = torch.matmul(query * scale, key.transpose(-1, -2)) + attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) if attn_bias is not None: attn_weights.add_(attn_bias) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_weights = torch.matmul(attn_weights, value) + attn_weights = softmax_op(attn_weights, dim=-1) + attn_weights = matmul_av_op(attn_weights, value) if query_heads != kv_heads: attn_weights = attn_weights.flatten(1, 2) attn_weights = attn_weights.transpose(1, 2) diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index b7b435c50c29..3d9c7cb1c4c2 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -8,6 +8,9 @@ from functools import wraps import habana_frameworks.torch as htorch +import torch + +from vllm.hpu.cache_ops import insert_or_update_cache def with_mark_steps(fn): @@ -22,3 +25,40 @@ def wrapped(*args, **kwargs): return result return wrapped + + +class Matmul(torch.nn.Module): + + def __init__(self): + super(Matmul, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class Softmax(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, inv_head=None): + return torch.softmax(x, dim) + + +class VLLMKVCache(torch.nn.Module): + + def __init__(self): + super(VLLMKVCache, self).__init__() + + def forward(self, input, cache, num_kv_cache_passes, num_slots_available, + block_indices, block_offset): + insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, + block_offset) + return cache + + def fetch_from_cache(self, cache, blocks, permutations): + return [ + cache.index_select(0, blocks[:, i]).permute(permutations) + for i in range(blocks.size(1)) + ] diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 55cbbabd7da4..c12668c14887 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -79,18 +79,15 @@ def forward_hpu( if HPUFusedRMSNorm is None: return self.forward_native(x, residual) if residual is not None: - orig_dtype = x.dtype orig_shape = x.shape residual += x.view(residual.shape) # Note: HPUFusedRMSNorm requires 3D tensors as inputs - x = HPUFusedRMSNorm.apply(residual.float(), self.weight.float(), + x = HPUFusedRMSNorm.apply(residual, self.weight, self.variance_epsilon) - return x.to(orig_dtype).view(orig_shape), residual + return x.view(orig_shape), residual - orig_dtype = x.dtype - x = HPUFusedRMSNorm.apply(x.float(), self.weight.float(), - self.variance_epsilon) - return x.to(orig_dtype) + x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) + return x def forward_xpu( self, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b6e280ae6504..10c8a95f838d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -273,6 +273,7 @@ def __init__(self, quant_config, prefix) self.gather_output = gather_output + self.collective_func = tensor_model_parallel_all_gather # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() @@ -334,7 +335,7 @@ def forward(self, input_): output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) + output = self.collective_func(output_parallel) else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None @@ -723,6 +724,7 @@ def __init__(self, self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results + self.collective_func = tensor_model_parallel_all_reduce # Divide the weight matrix along the last dimension. self.tp_rank = get_tensor_model_parallel_rank() @@ -770,7 +772,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def forward(self, input_): + def resolve_input(self, input_): if self.input_is_parallel: input_parallel = input_ else: @@ -778,6 +780,10 @@ def forward(self, input_): splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) input_parallel = splitted_input[tp_rank].contiguous() + return input_parallel + + def forward(self, input_): + input_parallel = self.resolve_input(input_) # Matrix multiply. assert self.quant_method is not None diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index bd574512e343..7590d3e98027 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -18,6 +18,7 @@ GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.inc import INCConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig @@ -37,6 +38,7 @@ "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, + "inc": INCConfig, } diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py new file mode 100644 index 000000000000..f6718ec2ac9e --- /dev/null +++ b/vllm/model_executor/layers/quantization/inc.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class INCConfig(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + @classmethod + def get_name(cls) -> str: + return "inc" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "INCConfig": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["INCLinearMethod"]: + if isinstance(layer, LinearBase): + return INCLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_min_capability(cls) -> int: + # The AWQ kernel only supports Turing or newer GPUs. + return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + +class INCLinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, + quant_config: INCConfig, + separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bbe49655020d..06048d97088e 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -37,7 +37,7 @@ supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_tpu +from vllm.utils import is_hpu, is_tpu logger = init_logger(__name__) @@ -48,14 +48,15 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + if not is_hpu(): + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( @@ -276,10 +277,11 @@ def load_model(self, *, model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with torch.device(self.load_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) + logger.info("Loading weights on %s ...", self.load_config.device) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 306d22e42ed1..676a51ce67f9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip @@ -317,6 +318,9 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( @@ -326,6 +330,8 @@ def forward( attn_metadata, residual, ) + if current_platform.is_hpu(): + htorch.core.mark_step() if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/utils.py b/vllm/utils.py index 8a1bc5de03eb..fe84253feb17 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -39,6 +39,7 @@ "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, + "fp8_inc": torch.float8_e4m3fn, } TORCH_DTYPE_TO_NUMPY_DTYPE = { diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 93be2f4c321f..ec0b8c236921 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -91,9 +91,11 @@ def _allocate_kv_cache( # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. + dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else \ + self.dtype kv_cache.append( torch.zeros(kv_cache_shape, - dtype=self.dtype, + dtype=dtype, pin_memory=pin_memory, device=device)) return kv_cache diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index cf91c69069ed..72aba42ae855 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -182,8 +182,8 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') - if 'bypass_hpu_graphs' in kwargs: - kwargs.pop('bypass_hpu_graphs') # required for PT eager + if 'warmup_mode' in kwargs: + kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], input_ids.size(0), @@ -413,6 +413,9 @@ def __init__( self._setup_buckets() def load_model(self) -> None: + import habana_frameworks.torch.core as htcore + if self.model_config.quantization == 'inc': + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model( @@ -429,6 +432,26 @@ def load_model(self) -> None: f"took {m_getmodel.get_summary_string()}") logger.info(msg) + if self.model_config.quantization == 'inc': + logger.info("Preparing model with INC..") + with HabanaMemoryProfiler() as m_inc: + from neural_compressor.torch.quantization import ( + FP8Config, convert, prepare) + config = FP8Config.from_json_file( + os.getenv("QUANT_CONFIG", "")) + if config.measure: + self.model = prepare(self.model, config) + elif config.quantize: + self.model = convert(self.model, config) + htcore.hpu_initialize(self.model, + mark_only_scales_as_const=True) + logger.info("Preparing model with INC took %s", + m_inc.get_summary_string()) + else: + self.model = self.model.to("hpu") + htcore.mark_step() + torch.hpu.synchronize() + # FIXME: Running with disable_tensor_cache=True causes # RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: @@ -1051,7 +1074,7 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt, torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches) + self.execute_model(inputs, kv_caches, warmup_mode=True) torch.hpu.synchronize() self.profiler.end() gc.collect() @@ -1362,6 +1385,10 @@ def prepare_model_input( is_prompt=is_prompt, virtual_engine=virtual_engine) + def finish_measurements(self): + from neural_compressor.torch.quantization import finalize_calibration + finalize_calibration(self.model.model) + @torch.inference_mode() def execute_model( self, @@ -1369,6 +1396,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + warmup_mode=False, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( @@ -1402,6 +1430,11 @@ def execute_model( } if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) + if htorch.utils.internal.is_lazy(): + execute_model_kwargs.update({ + "bypass_hpu_graphs": not use_graphs, + "warmup_mode": warmup_mode + }) htorch.core.mark_step() if self.is_driver_worker: @@ -1415,9 +1448,8 @@ def execute_model( with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, - selected_token_indices=sampling_metadata. - selected_token_indices, - bypass_hpu_graphs=not use_graphs) + selected_token_indices=sampling_metadata.selected_token_indices + ) # Compute the logits. with self.profiler.record_event( @@ -1459,3 +1491,16 @@ def execute_model( is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) return [output] + + def shutdown_inc(self): + print('inc shutdown') + if (model_config := getattr(self, "model_config", None)) and \ + getattr(model_config, "quantization", None) == 'inc': + print('inc shutdown start') + from neural_compressor.torch.quantization import ( + finalize_calibration) + finalize_calibration(self.model.model) + print('inc shutdown') + + def __del__(self): + self.shutdown_inc() diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index f3fdc4dcc63c..87122c03d3c8 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -91,6 +91,16 @@ def __init__( # Initialize gpu_cache as embedding models don't initialize kv_caches self.hpu_cache: Optional[List[List[torch.tensor]]] = None + def _set_env_vars(self): + local_rank = self.local_rank + if self.parallel_config.world_size == 1: + local_rank = -1 + import os + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["ID"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size) + os.environ["RANK"] = str(self.rank) + def init_device(self) -> None: if self.device_config.device.type == "hpu": self.device = torch.device("hpu") @@ -99,6 +109,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. + if self.model_config.quantization == 'inc': + self._set_env_vars() init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) @@ -211,6 +223,9 @@ def _warm_up_model(self) -> None: # the model initialization and profiling. set_random_seed(self.model_config.seed) + def finish_measurements(self): + self.model_runner.finish_measurements() + @property def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @@ -288,6 +303,12 @@ def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> Set[int]: raise NotImplementedError("LoRA is not implemented for HPU backend.") + def shutdown_inc(self): + self.model_runner.shutdown_inc() + + def __del__(self): + self.shutdown_inc() + @property def max_model_len(self) -> int: return self.model_config.max_model_len From 8185d760325a7699c5c07f7cd0e28d443a36051b Mon Sep 17 00:00:00 2001 From: Mohit Deopujari Date: Sun, 18 Aug 2024 23:30:38 -0700 Subject: [PATCH 6/7] [Doc][BugFix] Update setup instructions and reference links (#191) 1. Replaced the non-working setup instruction with the correct command. 2. Fixed broken links and updated references in documentation. --- README_GAUDI.md | 6 +++--- .../getting_started/gaudi-installation.rst | 17 ++++------------- docs/source/getting_started/quickstart.rst | 2 +- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/README_GAUDI.md b/README_GAUDI.md index 9ea30a2e43f6..91bcbe49405e 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -62,16 +62,16 @@ following: $ git clone https://github.com/HabanaAI/vllm-fork.git $ cd vllm-fork $ git checkout habana_main -$ python setup.py develop +$ pip install -e . ``` Supported Features ================== - [Offline batched - inference](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#offline-batched-inference) + inference](https://github.com/HabanaAI/vllm-fork/blob/habana_main/docs/source/getting_started/quickstart.rst#offline-batched-inference) - Online inference via [OpenAI-Compatible - Server](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server) + Server](https://github.com/HabanaAI/vllm-fork/blob/habana_main/docs/source/getting_started/quickstart.rst#openai-compatible-server) - HPU autodetection - no need to manually select device within vLLM - Paged KV cache with algorithms enabled for Intel Gaudi accelerators - Custom Intel Gaudi implementations of Paged Attention, KV cache ops, diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index ddbac022a8d9..b3234d10b311 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -30,7 +30,7 @@ To verify that the Intel Gaudi software was correctly installed, run: $ pip list | grep neural # verify that neural_compressor is installed Refer to `Intel Gaudi Software Stack -Verification `__ +Verification `__ for more details. Run Docker Image @@ -51,15 +51,6 @@ Use the following commands to run a Docker image: Build and Install vLLM --------------------------- -To build and install vLLM from source, run: - -.. code:: console - - $ git clone https://github.com/vllm-project/vllm.git - $ cd vllm - $ python setup.py develop - - Currently, the latest features and performance optimizations are developed in Gaudi's `vLLM-fork `__ and we periodically upstream them to vLLM main repo. To install latest `HabanaAI/vLLM-fork `__, run the following: .. code:: console @@ -67,16 +58,16 @@ Currently, the latest features and performance optimizations are developed in Ga $ git clone https://github.com/HabanaAI/vllm-fork.git $ cd vllm-fork $ git checkout habana_main - $ python setup.py develop + $ pip install -e . Supported Features ================== - `Offline batched - inference `__ + inference `__ - Online inference via `OpenAI-Compatible - Server `__ + Server `__ - HPU autodetection - no need to manually select device within vLLM - Paged KV cache with algorithms enabled for Intel Gaudi accelerators - Custom Intel Gaudi implementations of Paged Attention, KV cache ops, diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst index 89bdc247c5e8..8cfde76adf5f 100644 --- a/docs/source/getting_started/quickstart.rst +++ b/docs/source/getting_started/quickstart.rst @@ -9,7 +9,7 @@ This guide shows how to use vLLM to: * build an API server for a large language model; * start an OpenAI-compatible API server. -Be sure to complete the :ref:`installation instructions ` before continuing with this guide. +Be sure to complete the `Gaudi installation instructions `_ before continuing with this guide. .. note:: From f7dd91d88e6b9e68479af0817431949f665507a7 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Mon, 19 Aug 2024 00:46:21 -0700 Subject: [PATCH 7/7] split gptbigcode forward (#194) --- vllm/model_executor/models/gpt_bigcode.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index fc4e13bbb0e6..3ae3c8c8f712 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -39,6 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -224,9 +225,14 @@ def forward( position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds + if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(len(self.h)): layer = self.h[i] hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + if current_platform.is_hpu(): + htorch.core.mark_step() hidden_states = self.ln_f(hidden_states) return hidden_states