address cr

kzawora-intel committed Aug 14, 2024
1 parent 6a3a0c3 commit 64262a2
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions docs/source/getting_started/gaudi-installation.rst
@@ -124,9 +124,9 @@ Gaudi2 devices. Configurations that are not listed may or may not work.
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3-70B <https://huggingface.co/meta-llama/Meta-Llama-3-70B>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3-70B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-70B <https://huggingface.co/meta-llama/Meta-Llama-3.1-70B>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-70B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
@@ -176,7 +176,7 @@ In a dynamic inference serving scenario, there is a need to minimize the number
.. note::
    Bucketing significantly reduces the number of required graphs, but it does not handle graph compilation or device code generation - those happen during the warmup and HPUGraph capture phases.

Bucketing ranges are determined with 3 parameters - ``min``, ``step`` and ``max``. They can be set separately for the prompt and decode phase, and for the batch size and sequence length dimensions. These parameters can be observed in logs during vLLM startup:

.. code-block::

    INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024]
    INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)]
    INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048]
    INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]
``min`` determines the lowest value of the bucket. ``step`` determines the interval between buckets, and ``max`` determines the upper bound of the bucket. Furthermore, the interval between ``min`` and ``step`` has special handling - ``min`` is multiplied by consecutive powers of two until ``step`` is reached. We call this the ramp-up phase, and it is used for handling lower batch sizes with minimum wastage, while allowing larger padding on larger batch sizes.

Example (with ramp-up)

.. code-block::

    min = 2, step = 32, max = 64
    => ramp_up = (2, 4, 8, 16)
    => stable = (32, 64)
    => buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64)

Example (without ramp-up)

.. code-block::

    min = 128, step = 128, max = 512
    => ramp_up = ()
    => stable = (128, 256, 384, 512)
    => buckets = ramp_up + stable => (128, 256, 384, 512)
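
The snippet below is a minimal Python sketch of this rule. ``bucket_range`` is a hypothetical helper written for illustration - it is not the actual vLLM implementation, but it reproduces the ramp-up and stable phases described above:

.. code-block:: python

    def bucket_range(bmin: int, bstep: int, bmax: int) -> list[int]:
        """Illustrative reconstruction of the min/step/max bucketing rule."""
        # Ramp-up phase: multiply min by consecutive powers of two,
        # stopping once step (or max) is exceeded.
        ramp_up = []
        value = bmin
        while value < bstep and value <= bmax:
            ramp_up.append(value)
            value *= 2
        # Stable phase: multiples of step, up to and including max.
        stable = list(range(bstep, bmax + 1, bstep))
        return ramp_up + stable

    # Both documented examples hold:
    assert bucket_range(2, 32, 64) == [2, 4, 8, 16, 32, 64]
    assert bucket_range(128, 128, 512) == [128, 256, 384, 512]

    # The logged prompt config (bs [1, 32, 4], seq [128, 128, 1024]) yields
    # 3 batch-size buckets x 8 sequence-length buckets = 24 buckets.
    assert len(bucket_range(1, 32, 4)) * len(bucket_range(128, 128, 1024)) == 24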
In the logged scenario, 24 buckets were generated for prompt (prefill) runs (3 batch sizes × 8 sequence lengths), and 48 buckets for decode runs (3 batch sizes × 16 sequence lengths). Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across the batch and sequence length dimensions to the smallest possible bucket.

.. warning::
    If a request exceeds the maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such a scenario.

As an example, if a request of 3 sequences with a maximum sequence length of 412 comes in to an idle vLLM server, it will be padded and executed as a ``(4, 512)`` prefill bucket, as ``batch_size`` (number of sequences) will be padded to 4 (the closest batch size bucket higher than 3), and the maximum sequence length will be padded to 512 (the closest sequence length bucket higher than 412). After the prefill stage, it will be executed as a ``(4, 512)`` decode bucket and will continue as that bucket until either the batch dimension changes (due to a request being finished) - in which case it will become a ``(2, 512)`` bucket - or the context length increases above 512 tokens, in which case it will become a ``(4, 640)`` bucket.
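
The selection logic can be sketched as follows; ``find_bucket`` is a hypothetical helper for illustration, not the actual vLLM code:

.. code-block:: python

    def find_bucket(batch_size: int, seq_len: int, buckets: list[tuple[int, int]]):
        """Return the smallest bucket that fits the batch, or None if none fits."""
        fitting = [(bs, seq) for bs, seq in buckets if bs >= batch_size and seq >= seq_len]
        return min(fitting) if fitting else None

    # Prompt buckets from the logged scenario: bs in (1, 2, 4), seq in 128..1024.
    prompt_buckets = [(bs, seq) for bs in (1, 2, 4) for seq in range(128, 1025, 128)]
    # 3 sequences with a maximum length of 412 tokens land in the (4, 512) bucket.
    assert find_bucket(3, 412, prompt_buckets) == (4, 512)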

.. note::
    Bucketing is transparent to the client - padding in the sequence length dimension is never returned to the client, and padding in the batch dimension does not create new requests.

Warmup
------------

@@ -315,16 +306,17 @@ Environment variables
**Diagnostic and profiling knobs:**

- ``VLLM_PROFILER_ENABLED``: if ``true``, the high-level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai <https://perfetto.habana.ai/#!/viewer>`__. Disabled by default.
- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, logs graph compilations for each vLLM engine step, but only when any compilation occurs - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS=1`` (see the example below). Disabled by default.
- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL``: if ``true``, logs graph compilations for each vLLM engine step, even if none occurred. Disabled by default.
- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, logs CPU fallbacks for each vLLM engine step, but only when any fallback occurs. Disabled by default.
- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, logs CPU fallbacks for each vLLM engine step, even if none occurred. Disabled by default.
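
For example, a hypothetical script that enables step-level graph compilation logging could set these variables before importing vLLM (the model name below is purely illustrative):

.. code-block:: python

    import os

    # Enable per-step graph compilation logging before vLLM is imported.
    os.environ["VLLM_HPU_LOG_STEP_GRAPH_COMPILATION"] = "true"
    os.environ["PT_HPU_METRICS_GC_DETAILS"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Meta-Llama-3-8B")  # illustrative model choice
    outputs = llm.generate(["Hello, HPU!"], SamplingParams(max_tokens=32))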

**Performance tuning knobs:**

- ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default
- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.4`` by default
- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default
- ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default
- ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default
- ``VLLM_{phase}_{dim}_BUCKET_{param}``: collection of 12 environment variables configuring ranges of the bucketing mechanism (see the sketch after this list)
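
Assuming ``{phase}`` expands to ``PROMPT`` or ``DECODE``, ``{dim}`` to ``BS`` or ``SEQ``, and ``{param}`` to ``MIN``, ``STEP`` or ``MAX`` (2 × 2 × 3 = 12 variables), a sketch of overriding two of them could look like this:

.. code-block:: python

    import os

    # Hypothetical overrides: start prompt batch-size ramp-up at 1 and
    # widen decode sequence-length buckets up to 4096 tokens.
    os.environ["VLLM_PROMPT_BS_BUCKET_MIN"] = "1"
    os.environ["VLLM_DECODE_SEQ_BUCKET_MAX"] = "4096"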
