Add Gaudi documentation for 1.17
kzawora-intel committed Aug 1, 2024 · commit b0c4d4c
docs/source/getting_started/gaudi-installation.rst
The following configurations have been validated to work on Gaudi2 devices. Configurations that are not listed may or may not work.

- `meta-llama/Meta-Llama-3-8B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-8B <https://huggingface.co/meta-llama/Meta-Llama-3.1-8B>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-8B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Llama-2-70b <https://huggingface.co/meta-llama/Llama-2-70b>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Llama-2-70b-chat-hf <https://huggingface.co/meta-llama/Llama-2-70b-chat-hf>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3-70B <https://huggingface.co/meta-llama/Meta-Llama-3-70B>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3-70B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-70B <https://huggingface.co/meta-llama/Meta-Llama-3.1-70B>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-70B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `mistralai/Mistral-7B-Instruct-v0.3 <https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3>`__
on single HPU or with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling
- `mistralai/Mixtral-8x7B-Instruct-v0.1 <https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1>`__
with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling

Performance Tuning
==================

Execution modes
---------------

TODO: describe torch.compile, HPUGraphs, and lazy vs. eager execution modes.

Bucketing mechanism
-------------------

Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler <https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime>`__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution.
In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently, this is achieved by "bucketing" the model's forward pass across two dimensions: ``batch_size`` and ``sequence_length``.

.. note::
   Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation or device code generation - that is done during the warmup and HPUGraph capture phases.

Whenever executing vLLM on HPU, the following log can be observed:

.. code-block::

   INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024]
   INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)]
   INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048]
   INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]

In this scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across the batch and sequence length dimensions to the smallest possible bucket.

.. warning::
   If a request exceeds the maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such a scenario.

As an example, if a request with 3 sequences and a max sequence length of 412 comes in to an idle vLLM server, it will be padded and executed as a ``(4, 512)`` prefill bucket, as ``batch_size`` (number of sequences) will be padded to 4 (the closest batch size bucket larger than 3), and the max sequence length will be padded to 512 (the closest sequence length bucket larger than 412). After the prefill stage, it will be executed as a ``(4, 512)`` decode bucket and will continue as that bucket until either the batch dimension changes (e.g. due to a request being finished) - in which case it will become a ``(2, 512)`` bucket - or the context length grows above 512 tokens, in which case it will become a ``(4, 640)`` bucket.

.. note::
   Bucketing is transparent to the client - padding in the sequence length dimension is never returned to the client, and padding in the batch dimension does not create new requests.
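
To make the padding rule concrete, here is a small sketch; ``next_bucket`` is a hypothetical helper (not vLLM's actual function) and assumes the prompt bucket configuration shown in the log above:

.. code-block:: python

   import math


   def next_bucket(value: int, bmin: int, bstep: int, bmax: int) -> int:
       """Hypothetical helper: round `value` up to the nearest bucket boundary.

       Below `bstep` the boundaries are powers of two starting at `bmin`;
       from `bstep` upwards they are multiples of `bstep`.  Values above
       `bmax` are returned unpadded (see the warning above)."""
       if value > bmax:
           return value
       if value <= bmin:
           return bmin
       if value < bstep:
           return min(bstep, bmin * 2**math.ceil(math.log2(value / bmin)))
       return math.ceil(value / bstep) * bstep


   # A request of 3 sequences with max sequence length 412, using the prompt
   # bucket config from the log above (bs [1, 32, 4], seq [128, 128, 1024]):
   print(next_bucket(3, bmin=1, bstep=32, bmax=4))          # 4
   print(next_bucket(412, bmin=128, bstep=128, bmax=1024))  # 512
   # -> the request is padded and executed as the (4, 512) prefill bucket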


Bucket ranges are determined by three parameters - ``min``, ``step`` and ``max``. They can be set separately for the prompt and decode phase, and for the batch size and sequence length dimension, via environment variables of the form ``VLLM_{phase}_{dim}_BUCKET_{param}`` (for example, ``VLLM_DECODE_BS_BUCKET_STEP=128``), where ``{phase}`` is either ``prompt`` or ``decode``, ``{dim}`` is either ``bs`` (batch size) or ``seq`` (sequence length), and ``{param}`` is ``min``, ``step`` or ``max``.

The warmup range for each dimension is generated from these three values as follows: start from ``min`` and multiply it by 2 until ``step`` is reached, then increase the value by ``step`` until ``max`` is reached. For example, for ``min = 2``, ``step = 32`` and ``max = 64``, the ramp-up part of the range is ``(2, 4, 8, 16)``, the stable part is ``(32, 64)``, and the resulting range is ``(2, 4, 8, 16, 32, 64)``.


Warmup
------------

Warmup is an optional, but highly recommended step that occurs before the vLLM server starts listening. It executes a forward pass for each bucket using dummy data. The goal is to pre-compile all graphs and avoid any graph compilation overhead within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup:

.. code-block::

   INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB
   INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB
   INFO 08-01 22:26:48 habana_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB
   ...
   INFO 08-01 22:26:59 habana_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
   INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB
   INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB
   INFO 08-01 22:27:01 habana_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB
   ...
   INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB
   INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB

This example uses the same buckets as in the *Bucketing mechanism* section. Each output line corresponds to the execution of a single bucket. Whenever a bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations.

.. tip::
   Compiling all the buckets might take some time and can be turned off with the ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do so, you may face graph compilations when a given bucket is executed for the first time. Disabling warmup is fine for development, but it is highly recommended to enable it in deployment.

HPUGraph capture
----------------

TODO: describe ``VLLM_GRAPH_MEM_MARGIN``, how and why HPUGraphs are captured, and add a couple of sentences about memory allocations.


Recommended vLLM Parameters
---------------------------

- We recommend running inference on Gaudi 2 with ``block_size`` of 128
  for BF16 data type. Using default values (16, 32) might lead to
  sub-optimal performance due to Matrix Multiplication Engine
  under-utilization.
- For maximum throughput, we recommend running with a batch size
  of 128 or 256 and a max context length of 2048 with HPU Graphs enabled.
  If you encounter out-of-memory issues, see the troubleshooting section.
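
As a starting point, these recommendations can be expressed through vLLM's offline ``LLM`` API (a sketch; the model name and exact values are placeholders to adapt to your deployment):

.. code-block:: python

   from vllm import LLM, SamplingParams

   # Placeholder model and values following the recommendations above;
   # HPU Graphs are used unless enforce_eager is set.
   llm = LLM(
       model="meta-llama/Meta-Llama-3-8B-Instruct",
       dtype="bfloat16",
       block_size=128,      # recommended block size for BF16 on Gaudi 2
       max_model_len=2048,  # max context length
       max_num_seqs=128,    # batch size
   )

   outputs = llm.generate(["Hello, my name is"],
                          SamplingParams(temperature=0.8, max_tokens=32))
   print(outputs[0].outputs[0].text)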

Environment variables
---------------------

vLLM for HPU supports the following environment variables for performance tuning; an example of overriding the bucketing defaults is shown right after the list:

- ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default
- ``VLLM_GRAPH_MEM_MARGIN``: TODO
- ``VLLM_GRAPH_PROMPT_RATIO``: fraction of reserved graph memory dedicated to prompt graphs, ``0.5`` by default
- ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining the order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default
- ``VLLM_{phase}_{dim}_BUCKET_{param}``: a collection of 12 environment variables configuring the ranges of the bucketing mechanism

- ``{phase}`` is either ``PROMPT`` or ``DECODE``
- ``{dim}`` is either ``BS`` or ``SEQ``
- ``{param}`` is either ``MIN``, ``STEP`` or ``MAX``
- Default values:

- Prompt:
- batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32``
- batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)``
- sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024``

- Decode:
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128``
- batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs``
- sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048``
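
For instance, the bucketing defaults above could be overridden before vLLM is initialized (a sketch; the values are only illustrative):

.. code-block:: python

   import os

   # Illustrative overrides of the defaults listed above; set them
   # before vLLM is imported/initialized so they take effect.
   os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = "2048"  # allow longer prompts without recompilation
   os.environ["VLLM_DECODE_BS_BUCKET_STEP"] = "64"    # finer-grained decode batch buckets
   os.environ["VLLM_SKIP_WARMUP"] = "false"           # keep warmup enabled

   from vllm import LLM

   llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", dtype="bfloat16")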


Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution; a tensor parallel example follows the list:

- ``PT_HPU_LAZY_MODE``: if ``0``, the PyTorch eager backend for Gaudi will be used; if ``1``, the PyTorch lazy backend for Gaudi will be used; ``1`` is the default
- ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPUGraphs
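
A sketch of a tensor parallel run with HPUGraphs, setting the variables in-process before vLLM starts (the model and parallelism degree are placeholders):

.. code-block:: python

   import os

   # Required for tensor parallel inference with HPUGraphs (see above).
   os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"
   # PT_HPU_LAZY_MODE=1 is the default; shown here only for clarity.
   os.environ["PT_HPU_LAZY_MODE"] = "1"

   from vllm import LLM

   # Placeholder model; tensor_parallel_size=8 shards it across 8 HPUs,
   # matching the validated 70B configurations listed above.
   llm = LLM(model="meta-llama/Meta-Llama-3-70B-Instruct",
             dtype="bfloat16",
             tensor_parallel_size=8)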

Troubleshooting: Tweaking HPU Graphs
====================================
