diff --git a/README_GAUDI.md b/README_GAUDI.md index 1a1b2d9cc6e36..91bcbe49405eb 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -1,69 +1,79 @@ -# vLLM with Intel® Gaudi® 2 AI Accelerators +vLLM with Intel® Gaudi® AI Accelerators +======================================= -This README provides instructions on running vLLM with Intel Gaudi devices. +This README provides instructions on running vLLM with Intel Gaudi +devices. Requirements and Installation -============================== +============================= -Please follow the instructions provided in the [Gaudi Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) -to set up the environment. To achieve the best performance, please follow the methods outlined in the -[Optimizing Training Platform Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). - -> [!NOTE] -> In this release (1.16.0), we are only targeting functionality and -> accuracy. Performance will be improved in next releases. +Please follow the instructions provided in the [Gaudi Installation +Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) +to set up the environment. To achieve the best performance, please +follow the methods outlined in the [Optimizing Training Platform +Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). Requirements -------------- +------------ - OS: Ubuntu 22.04 LTS - Python: 3.10 -- Intel Gaudi 2 accelerator -- Intel Gaudi software version 1.16.0 +- Intel Gaudi accelerator +- Intel Gaudi software version 1.17.0 To verify that the Intel Gaudi software was correctly installed, run: ``` {.console} $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed -$ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed +$ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed +$ pip list | grep neural # verify that neural-compressor is installed ``` -Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) for more details. +Refer to [Intel Gaudi Software Stack +Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) +for more details. Run Docker Image ------------------- +---------------- -It is highly recommended to use the latest Docker image from Intel -Gaudi vault. Refer to the [Intel Gaudi documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) for more details. +It is highly recommended to use the latest Docker image from Intel Gaudi +vault. Refer to the [Intel Gaudi +documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) +for more details. Use the following commands to run a Docker image: ``` {.console} -$ docker pull vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest -$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.16.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest - ``` +$ docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest +$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest +``` -Build and Install vLLM-fork ------------------------------ +Build and Install vLLM +---------------------- -To build and install vLLM-fork from source, run: +Currently, the latest features and performance optimizations are +developed in Gaudi\'s [vLLM-fork](https://github.com/HabanaAI/vllm-fork) +and we periodically upstream them to vLLM main repo. To install latest +[HabanaAI/vLLM-fork](https://github.com/HabanaAI/vllm-fork), run the +following: ``` {.console} $ git clone https://github.com/HabanaAI/vllm-fork.git $ cd vllm-fork -# git checkout v0.4.2-Gaudi-1.16.0 -$ pip install -e . # This may take 5-10 minutes. +$ git checkout habana_main +$ pip install -e . ``` Supported Features ================== -- [Offline batched inference](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#offline-batched-inference) -- Online inference via [OpenAI-Compatible Server](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server) +- [Offline batched + inference](https://github.com/HabanaAI/vllm-fork/blob/habana_main/docs/source/getting_started/quickstart.rst#offline-batched-inference) +- Online inference via [OpenAI-Compatible + Server](https://github.com/HabanaAI/vllm-fork/blob/habana_main/docs/source/getting_started/quickstart.rst#openai-compatible-server) - HPU autodetection - no need to manually select device within vLLM -- Paged KV cache with algorithms enabled for Intel Gaudi 2 - accelerators +- Paged KV cache with algorithms enabled for Intel Gaudi accelerators - Custom Intel Gaudi implementations of Paged Attention, KV cache ops, prefill attention, Root Mean Square Layer Normalization, Rotary Positional Encoding @@ -72,7 +82,6 @@ Supported Features Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) for accelerating low-batch latency and throughput - Unsupported Features ==================== @@ -82,11 +91,11 @@ Unsupported Features - Quantization (AWQ, FP8 E5M2, FP8 E4M3) - Prefill chunking (mixed-batch inferencing) - Supported Configurations ======================== -The following configurations have been validated to be function with Gaudi devices. Configurations that are not listed may or may not work. +The following configurations have been validated to be function with +Gaudi2 devices. Configurations that are not listed may or may not work. - [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 @@ -94,47 +103,412 @@ The following configurations have been validated to be function with Gaudi devic - [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling +- [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling - [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) - with tensor parallelism on 8x HPU, BF16 datatype with random - or greedy sampling + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling - [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) - with tensor parallelism 8x HPU, BF16 datatype with random - or greedy sampling + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling +- [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling +- [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling +- [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling +- [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) + with tensor parallelism on 8x HPU, BF16 datatype with random or + greedy sampling - [mistralai/Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) - on single HPU or with tensor parallelism 2x HPU, BF16 datatype with random or greedy sampling + on single HPU or with tensor parallelism on 2x HPU, BF16 datatype + with random or greedy sampling - [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) - with tensor parallelism 2x HPU, BF16 datatype with random or greedy sampling + with tensor parallelism on 2x HPU, BF16 datatype with random or + greedy sampling + +Performance Tuning +================ +Execution modes +----------------------------- +Currently in vLLM for HPU we support four execution modes, depending on +selected HPU PyTorch Bridge backend (via `PT_HPU_LAZY_MODE` environment +variable), and `--enforce-eager` flag. -Performance Tips -================ +| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode | +|--- |--- |--- | +| 0 | 0 | torch.compile | +| 0 | 1 | PyTorch eager mode | +| 1 | 0 | HPU Graphs | +| 1 | 1 | PyTorch lazy mode | + + +> [!WARNING] +> In 1.17.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly +> experimental and should be only used for validating functional +> correctness. Their performance will be improved in the next releases. +> For obtaining the best performance in 1.17.0, please use HPU Graphs, or +> PyTorch lazy mode. + +Bucketing mechanism +----------------------------- + +Intel Gaudi accelerators work best when operating on models with fixed +tensor shapes. [Intel Gaudi Graph +Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime) +is responsible for generating optimized binary code that implements the +given model topology on Gaudi. In its default configuration, the +produced binary code may be heavily dependent on input and output tensor +shapes, and can require graph recompilation when encountering +differently shaped tensors within the same topology. While the resulting +binaries utilize Gaudi efficiently, the compilation itself may introduce +a noticeable overhead in end-to-end execution. In a dynamic inference +serving scenario, there is a need to minimize the number of graph +compilations and reduce the risk of graph compilation occurring during +server runtime. Currently it is achieved by \"bucketing\" model\'s +forward pass across two dimensions - `batch_size` and `sequence_length`. + +> [!NOTE] +> Bucketing allows us to reduce the number of required graphs +> significantly, but it does not handle any graph compilation and device +> code generation - this is done in warmup and HPUGraph capture phase. + +Bucketing ranges are determined with 3 parameters - `min`, `step` and +`max`. They can be set separately for prompt and decode phase, and for +batch size and sequence length dimension. These parameters can be +observed in logs during vLLM startup: + +``` {.} +INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] +INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] +INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] +INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +``` -- We recommend running inference on Gaudi 2 with - `block_size` of 128 for BF16 data type. Using default - values (16, 32) might lead to sub-optimal performance due to Matrix - Multiplication Engine under-utilization (see [Gaudi +`min` determines the lowest value of the bucket. `step` determines the +interval between buckets, and `max` determines the upper bound of the +bucket. Furthermore, interval between `min` and `step` has special +handling - `min` gets multiplied by consecutive powers of two, until +`step` gets reached. We call this the ramp-up phase and it is used for +handling lower batch sizes with minimum wastage, while allowing larger +padding on larger batch sizes. + +Example (with ramp-up) + +``` {.} +min = 2, step = 32, max = 64 +=> ramp_up = (2, 4, 8, 16) +=> stable = (32, 64) +=> buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) +``` + +Example (without ramp-up) + +``` {.} +min = 128, step = 128, max = 512 +=> ramp_up = () +=> stable = (128, 256, 384, 512) +=> buckets = ramp_up + stable => (128, 256, 384, 512) +``` + +In the logged scenario, 24 buckets were generated for prompt (prefill) +runs, and 48 buckets for decode runs. Each bucket corresponds to a +separate optimized device binary for a given model with specified tensor +shapes. Whenever a batch of requests is processed, it is padded across +batch and sequence length dimension to the smallest possible bucket. + +> [!WARNING] +> If a request exceeds maximum bucket size in any dimension, it will be +> processed without padding, and its processing may require a graph +> compilation, potentially significantly increasing end-to-end latency. +> The boundaries of the buckets are user-configurable via environment +> variables, and upper bucket boundaries can be increased to avoid such +> scenario. + +As an example, if a request of 3 sequences, with max sequence length of +412 comes in to an idle vLLM server, it will be padded executed as +`(4, 512)` prefill bucket, as `batch_size` (number of sequences) will be +padded to 4 (closest batch\_size dimension higher than 3), and max +sequence length will be padded to 512 (closest sequence length dimension +higher than 412). After prefill stage, it will be executed as `(4, 512)` +decode bucket and will continue as that bucket until either batch +dimension changes (due to request being finished) - in which case it +will become a `(2, 512)` bucket, or context length increases above 512 +tokens, in which case it will become `(4, 640)` bucket. + +> [!NOTE] +> Bucketing is transparent to a client - padding in sequence length +> dimension is never returned to the client, and padding in batch +> dimension does not create new requests. + +Warmup +------ + +Warmup is an optional, but highly recommended step occurring before vLLM +server starts listening. It executes a forward pass for each bucket with +dummy data. The goal is to pre-compile all graphs and not incur any +graph compilation overheads within bucket boundaries during server +runtime. Each warmup step is logged during vLLM startup: + +``` {.} +INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB +INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB +INFO 08-01 22:26:48 habana_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB +... +INFO 08-01 22:26:59 habana_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB +INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB +INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB +INFO 08-01 22:27:01 habana_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB +... +INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB +INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB +``` + +This example uses the same buckets as in *Bucketing mechanism* section. +Each output line corresponds to execution of a single bucket. When +bucket is executed for the first time, its graph is compiled and can be +reused later on, skipping further graph compilations. + +> [!TIP] +> Compiling all the buckets might take some time and can be turned off +> with `VLLM_SKIP_WARMUP=true` environment variable. Keep in mind that if +> you do that, you may face graph compilations once executing a given +> bucket for the first time. It is fine to disable warmup for development, +> but it\'s highly recommended to enable it in deployment. + +HPU Graph capture +----------------------------- + +[HPU +Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) +are currently the most performant execution method of vLLM on Intel +Gaudi. When HPU Graphs are enabled, execution graphs will be traced +(recorded) ahead of time (after performing warmup), to be later replayed +during inference, significantly reducing host overheads. Recording can +take large amounts of memory, which needs to be taken into account when +allocating KV cache. Enabling HPU Graphs will impact the number of +available KV cache blocks, but vLLM provides user-configurable variables +to control memory management. + +When HPU Graphs are being used, they share the common memory pool +(\"usable memory\") as KV cache, determined by `gpu_memory_utilization` +flag (`0.9` by default). Before KV cache gets allocated, model weights +are loaded onto the device, and a forward pass of the model is executed +on dummy data, to estimate memory usage. Only after that, +`gpu_memory_utilization` flag is utilized - at its default value, will +mark 90% of free device memory at that point as usable. Next, KV cache +gets allocated, model is warmed up, and HPU Graphs are captured. +Environment variable `VLLM_GRAPH_RESERVED_MEM` defines the ratio of +memory reserved for HPU Graphs capture. With its default value +(`VLLM_GRAPH_RESERVED_MEM=0.4`), 40% of usable memory will be reserved +for graph capture (later referred to as \"usable graph memory\"), and +the remaining 60% will be utilized for KV cache. Environment variable +`VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory +reserved for prefill and decode graphs. By default +(`VLLM_GRAPH_PROMPT_RATIO=0.5`), both stages have equal memory +constraints. Lower value corresponds to less usable graph memory +reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will +reserve 20% of usable graph memory for prefill graphs, and 80% of usable +graph memory for decode graphs. + +> [!NOTE] +> `gpu_memory_utilization` does not correspond to the absolute memory +> usage across HPU. It specifies the memory margin after loading the model +> and performing a profile run. If device has 100 GiB of total memory, and +> 50 GiB of free memory after loading model weights and executing +> profiling run, `gpu_memory_utilization` at its default value will mark +> 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total +> device memory. + +User can also configure the strategy for capturing HPU Graphs for prompt +and decode stages separately. Strategy affects the order of capturing +graphs. There are two strategies implemented: - `max_bs` - graph capture +queue will sorted in descending order by their batch sizes. Buckets with +equal batch sizes are sorted by sequence length in ascending order (e.g. +`(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, +`(1,256)`), default strategy for decode - `min_tokens` - graph capture +queue will be sorted in ascending order by the number of tokens each +graph processes (`batch_size*sequence_length`), default strategy for +prompt + +When there\'s large amount of requests pending, vLLM scheduler will +attempt to fill the maximum batch size for decode as soon as possible. +When a request is finished, decode batch size decreases. When that +happens, vLLM will attempt to schedule a prefill iteration for requests +in the waiting queue, to fill the decode batch size to its previous +state. This means that in a full load scenario, decode batch size is +often at its maximum, which makes large batch size HPU Graphs crucial to +capture, as reflected by `max_bs` strategy. On the other hand, prefills +will be executed most frequently with very low batch sizes (1-4), which +is reflected in `min_tokens` strategy. + +> [!NOTE] +> `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by +> graphs for each stage (prefill and decode). vLLM will first attempt to +> use up entirety of usable prefill graph memory (usable graph memory \* +> `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it +> will attempt do the same for decode graphs and usable decode graph +> memory pool. If one stage is fully captured, and there is unused memory +> left within usable graph memory pool, vLLM will attempt further graph +> capture for the other stage, until no more HPU Graphs can be captured +> without exceeding reserved memory pool. The behavior on that mechanism +> can be observed in the example below. + +Each described step is logged by vLLM server, as follows (negative +values correspond to memory being released): + +``` {.} +INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] +INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] +INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] +INFO 08-02 17:37:44 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +INFO 08-02 17:37:52 habana_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:52 habana_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:52 habana_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:54 habana_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) +INFO 08-02 17:37:54 habana_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache +INFO 08-02 17:37:54 habana_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 +INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) +INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB +... +INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB +INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5) +INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB +... +INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB +INFO 08-02 17:38:27 habana_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB +... +INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB +INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB +INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB +INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB +INFO 08-02 17:38:43 habana_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB +INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] +INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +INFO 08-02 17:38:43 habana_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory +INFO 08-02 17:38:43 habana_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) +``` + +Recommended vLLM Parameters +----------------------------- + +- We recommend running inference on Gaudi 2 with `block_size` of 128 + for BF16 data type. Using default values (16, 32) might lead to + sub-optimal performance due to Matrix Multiplication Engine + under-utilization (see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). - For max throughput on Llama 7B, we recommend running with batch size - of 128 or 256 and max context length of 2048 with HPU Graphs enabled. - If you encounter out-of-memory issues, see troubleshooting section. + of 128 or 256 and max context length of 2048 with HPU Graphs + enabled. If you encounter out-of-memory issues, see troubleshooting + section. + +Environment variables +----------------------------- + +**Diagnostic and profiling knobs:** + +- `VLLM_PROFILER_ENABLED`: if `true`, high level profiler will be + enabled. Resulting JSON traces can be viewed in + [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). Disabled + by default. +- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: if `true`, will log graph + compilations per each vLLM engine step, only when there was any - + highly recommended to use alongside `PT_HPU_METRICS_GC_DETAILS=1`. + Disabled by default. +- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: if `true`, will log graph + compilations per each vLLM engine step, always, even if there were + none. Disabled by default. +- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: if `true`, will log cpu fallbacks + per each vLLM engine step, only when there was any. Disabled by + default. +- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true`, will log cpu + fallbacks per each vLLM engine step, always, even if there were + none. Disabled by default. + +**Performance tuning knobs:** + +- `VLLM_SKIP_WARMUP`: if `true`, warmup will be skipped, `false` by + default +- `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for + HPUGraph capture, `0.4` by default +- `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory + dedicated for prompt graphs, `0.5` by default +- `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt + graph capture, `min_tokens` or `max_bs`, `min_tokens` by default +- `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode + graph capture, `min_tokens` or `max_bs`, `max_bs` by default +- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment + variables configuring ranges of bucketing mechanism + - `{phase}` is either `PROMPT` or `DECODE` + - `{dim}` is either `BS` or `SEQ` + - `{param}` is either `MIN`, `STEP` or `MAX` + - Default values: + - Prompt: + - batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1` + - batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32` + - batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`): + `min(max_num_seqs, 64)` + - sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`): + `block_size` + - sequence length step + (`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size` + - sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`): + `1024` + + - Decode: + - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1` + - batch size step (`VLLM_DECODE_BS_BUCKET_STEP`): + `128` + - batch size max (`VLLM_DECODE_BS_BUCKET_MAX`): + `max_num_seqs` + - sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`): + `block_size` + - sequence length step + (`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size` + - sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`): + `2048` + +Additionally, there are HPU PyTorch Bridge environment variables +impacting vLLM execution: + +- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be + used, if `1` PyTorch Lazy backend for Gaudi will be used, `1` is + default +- `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor + parallel inference with HPU Graphs Troubleshooting: Tweaking HPU Graphs ==================================== -If you experience device out-of-memory issues or want to attempt inference at higher batch sizes, try tweaking HPU Graphs by following the below: - -- Tweak `gpu_memory_utilization` knob. It - will decrease the allocation of KV cache, leaving some headroom for - capturing graphs with larger batch size. By default `gpu_memory_utilization` is set to 0.9. - It attempts to allocate \~90% of HBM left for KV cache after short - profiling run. Note that decreasing reduces the number of KV - cache blocks you have available, and therefore reduces the effective - maximum number of tokens you can handle at a given time. - -- If this method is not efficient, you can disable `HPUGraph` completely. With - HPU Graphs disabled, you are trading latency and throughput at lower - batches for potentially higher throughput on higher batches. You can do - that by adding `--enforce-eager` flag to server (for - online inference), or by passing `enforce_eager=True` - argument to LLM constructor (for offline inference). +If you experience device out-of-memory issues or want to attempt +inference at higher batch sizes, try tweaking HPU Graphs by following +the below: + +- Tweak `gpu_memory_utilization` knob. It will decrease the allocation + of KV cache, leaving some headroom for capturing graphs with larger + batch size. By default `gpu_memory_utilization` is set to 0.9. It + attempts to allocate \~90% of HBM left for KV cache after short + profiling run. Note that decreasing reduces the number of KV cache + blocks you have available, and therefore reduces the effective + maximum number of tokens you can handle at a given time. +- If this method is not efficient, you can disable `HPUGraph` + completely. With HPU Graphs disabled, you are trading latency and + throughput at lower batches for potentially higher throughput on + higher batches. You can do that by adding `--enforce-eager` flag to + server (for online inference), or by passing `enforce_eager=True` + argument to LLM constructor (for offline inference). diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index a9f3ebdf274f6..b3234d10b3115 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -18,7 +18,7 @@ Requirements - OS: Ubuntu 22.04 LTS - Python: 3.10 - Intel Gaudi accelerator -- Intel Gaudi software version 1.16.0 or newer +- Intel Gaudi software version 1.17.0 To verify that the Intel Gaudi software was correctly installed, run: @@ -26,10 +26,11 @@ To verify that the Intel Gaudi software was correctly installed, run: $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed - $ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed + $ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed + $ pip list | grep neural # verify that neural_compressor is installed Refer to `Intel Gaudi Software Stack -Verification `__ +Verification `__ for more details. Run Docker Image @@ -44,21 +45,12 @@ Use the following commands to run a Docker image: .. code:: console - $ docker pull vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest - $ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest + $ docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest + $ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest Build and Install vLLM --------------------------- -To build and install vLLM from source, run: - -.. code:: console - - $ git clone https://github.com/vllm-project/vllm.git - $ cd vllm - $ python setup.py develop - - Currently, the latest features and performance optimizations are developed in Gaudi's `vLLM-fork `__ and we periodically upstream them to vLLM main repo. To install latest `HabanaAI/vLLM-fork `__, run the following: .. code:: console @@ -66,16 +58,16 @@ Currently, the latest features and performance optimizations are developed in Ga $ git clone https://github.com/HabanaAI/vllm-fork.git $ cd vllm-fork $ git checkout habana_main - $ python setup.py develop + $ pip install -e . Supported Features ================== - `Offline batched - inference `__ + inference `__ - Online inference via `OpenAI-Compatible - Server `__ + Server `__ - HPU autodetection - no need to manually select device within vLLM - Paged KV cache with algorithms enabled for Intel Gaudi accelerators - Custom Intel Gaudi implementations of Paged Attention, KV cache ops, @@ -112,6 +104,12 @@ Gaudi2 devices. Configurations that are not listed may or may not work. - `meta-llama/Meta-Llama-3-8B-Instruct `__ on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-8B `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-8B-Instruct `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling - `meta-llama/Llama-2-70b `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `meta-llama/Llama-2-70b-chat-hf `__ @@ -120,14 +118,187 @@ Gaudi2 devices. Configurations that are not listed may or may not work. with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `meta-llama/Meta-Llama-3-70B-Instruct `__ with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-70B `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-70B-Instruct `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling - `mistralai/Mistral-7B-Instruct-v0.3 `__ on single HPU or with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling - `mistralai/Mixtral-8x7B-Instruct-v0.1 `__ with tensor parallelism on 2x HPU, BF16 datatype with random or greedy sampling -Performance Tips +Performance Tuning ================ +Execution modes +------------ + +Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via ``PT_HPU_LAZY_MODE`` environment variable), and ``--enforce-eager`` flag. + +.. list-table:: vLLM execution modes + :widths: 25 25 50 + :header-rows: 1 + + * - ``PT_HPU_LAZY_MODE`` + - ``enforce_eager`` + - execution mode + * - 0 + - 0 + - torch.compile + * - 0 + - 1 + - PyTorch eager mode + * - 1 + - 0 + - HPU Graphs + * - 1 + - 1 + - PyTorch lazy mode + +.. warning:: + In 1.17.0, all modes utilizing ``PT_HPU_LAZY_MODE=0`` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.17.0, please use HPU Graphs, or PyTorch lazy mode. + + +Bucketing mechanism +------------ + +Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler `__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. +In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``. + +.. note:: + Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. + +Bucketing ranges are determined with 3 parameters - ``min``, ``step`` and ``max``. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: + +.. code-block:: + + INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + +``min`` determines the lowest value of the bucket. ``step`` determines the interval between buckets, and ``max`` determines the upper bound of the bucket. Furthermore, interval between ``min`` and ``step`` has special handling - ``min`` gets multiplied by consecutive powers of two, until ``step`` gets reached. We call this the ramp-up phase and it is used for handling lower batch sizes with minimum wastage, while allowing larger padding on larger batch sizes. + +Example (with ramp-up) + +.. code-block:: + + min = 2, step = 32, max = 64 + => ramp_up = (2, 4, 8, 16) + => stable = (32, 64) + => buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) + +Example (without ramp-up) + +.. code-block:: + + min = 128, step = 128, max = 512 + => ramp_up = () + => stable = (128, 256, 384, 512) + => buckets = ramp_up + stable => (128, 256, 384, 512) + + +In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. + +.. warning:: + If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. + +As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as ``(4, 512)`` prefill bucket, as ``batch_size`` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as ``(4, 512)`` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a ``(2, 512)`` bucket, or context length increases above 512 tokens, in which case it will become ``(4, 640)`` bucket. + +.. note:: + Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. + +Warmup +------------ + +Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: + +.. code-block:: + + INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB + INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB + INFO 08-01 22:26:48 habana_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB + ... + INFO 08-01 22:26:59 habana_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB + INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB + INFO 08-01 22:27:01 habana_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB + ... + INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + +This example uses the same buckets as in *Bucketing mechanism* section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. + +.. tip:: + Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. + +HPU Graph capture +------------ + +`HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. + + +When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). +Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. +Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. +Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. +Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. +With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. +Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. +Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. + +.. note:: + ``gpu_memory_utilization`` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, ``gpu_memory_utilization`` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. + +User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: +- ``max_bs`` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. ``(64, 128)``, ``(64, 256)``, ``(32, 128)``, ``(32, 256)``, ``(1, 128)``, ``(1,256)``), default strategy for decode +- ``min_tokens`` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (``batch_size*sequence_length``), default strategy for prompt + +When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by ``max_bs`` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in ``min_tokens`` strategy. + + +.. note:: + ``VLLM_GRAPH_PROMPT_RATIO`` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * ``VLLM_GRAPH_PROMPT_RATIO``) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. + + +Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): + +.. code-block:: + + INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-02 17:37:44 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:37:52 habana_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 habana_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 habana_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache + INFO 08-02 17:37:54 habana_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 + INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB + ... + INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5) + INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + ... + INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB + INFO 08-02 17:38:27 habana_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB + ... + INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB + INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB + INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB + INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB + INFO 08-02 17:38:43 habana_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB + INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] + INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:38:43 habana_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory + INFO 08-02 17:38:43 habana_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) + + +Recommended vLLM Parameters +------------ + - We recommend running inference on Gaudi 2 with ``block_size`` of 128 for BF16 data type. Using default values (16, 32) might lead to sub-optimal performance due to Matrix Multiplication Engine @@ -137,6 +308,53 @@ Performance Tips of 128 or 256 and max context length of 2048 with HPU Graphs enabled. If you encounter out-of-memory issues, see troubleshooting section. +Environment variables +------------ + +**Diagnostic and profiling knobs:** + +- ``VLLM_PROFILER_ENABLED``: if ``true``, high level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai `__. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS=1``. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL``: if ``true``, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default. +- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. +- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default. + +**Performance tuning knobs:** + +- ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default +- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.4`` by default +- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default +- ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default +- ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default +- ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism + + - ``{phase}`` is either ``PROMPT`` or ``DECODE`` + - ``{dim}`` is either ``BS`` or ``SEQ`` + - ``{param}`` is either ``MIN``, ``STEP`` or ``MAX`` + - Default values: + + - Prompt: + - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1`` + - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32`` + - batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)`` + - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size`` + - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size`` + - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024`` + + - Decode: + - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1`` + - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128`` + - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs`` + - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size`` + - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size`` + - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048`` + + +Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: + +- ``PT_HPU_LAZY_MODE``: if ``0``, PyTorch Eager backend for Gaudi will be used, if ``1`` PyTorch Lazy backend for Gaudi will be used, ``1`` is default +- ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPU Graphs + Troubleshooting: Tweaking HPU Graphs ==================================== diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst index 89bdc247c5e8e..8cfde76adf5fa 100644 --- a/docs/source/getting_started/quickstart.rst +++ b/docs/source/getting_started/quickstart.rst @@ -9,7 +9,7 @@ This guide shows how to use vLLM to: * build an API server for a large language model; * start an OpenAI-compatible API server. -Be sure to complete the :ref:`installation instructions ` before continuing with this guide. +Be sure to complete the `Gaudi installation instructions `_ before continuing with this guide. .. note:: diff --git a/examples/lora_inference_hpu.py b/examples/lora_inference_hpu.py new file mode 100644 index 0000000000000..b8154a29a82bb --- /dev/null +++ b/examples/lora_inference_hpu.py @@ -0,0 +1,47 @@ +from huggingface_hub import snapshot_download + +from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest + +sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + +llm = LLM(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_num_seqs=2, + dtype='bfloat16') + +sampling_params = SamplingParams(temperature=0, + max_tokens=1024, + stop=["[/assistant]"]) + +prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 +] + +expected_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501 + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 +] + +outputs = llm.generate(prompts, + sampling_params, + lora_request=LoRARequest("sql_adapter", 1, + sql_lora_path)) + +for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + match = expected_output[i] == generated_text + if not match: + print( + f"Comparison failed for request_id::{i}\n\t[PROMPT]{prompt!r}\n\t[GENERATED]{generated_text!r}\n\t[EXPECTED]{expected_output[i]!r}" # noqa: E501 + ) diff --git a/tests/conftest.py b/tests/conftest.py index 59510075b0063..cfb7cf56b519a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -590,9 +590,17 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog +def is_hpu(): + from importlib import util + return util.find_spec('habana_frameworks') is not None + + @pytest.fixture(scope="session") def num_gpus_available(): """Get number of GPUs without initializing the CUDA context in current process.""" + if is_hpu(): + return torch.hpu.device_count() + return cuda_device_count_stateless() diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 0bcae5b0c96dc..3e4c8be6dbaa3 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -48,13 +48,19 @@ class ContextInfo(TypedDict): }] +def is_hpu(): + from importlib import util + return util.find_spec('habana_frameworks') is not None + + def cleanup(): destroy_model_parallel() destroy_distributed_environment() with contextlib.suppress(AssertionError): torch.distributed.destroy_process_group() gc.collect() - torch.cuda.empty_cache() + if not is_hpu(): + torch.cuda.empty_cache() ray.shutdown() diff --git a/tests/lora/test_llama_hpu.py b/tests/lora/test_llama_hpu.py new file mode 100644 index 0000000000000..dfd551f2ca043 --- /dev/null +++ b/tests/lora/test_llama_hpu.py @@ -0,0 +1,100 @@ +from multiprocessing import Process +from typing import List + +from conftest import cleanup + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def _test_llama_lora(sql_lora_files, tp_size): + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + dtype='float32', + tensor_parallel_size=tp_size) + + expected_no_lora_output = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501 + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501 + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501 + ] + expected_lora_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 + ] + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output + + print("removing lora") + cleanup() + + +def test_llama_lora_1x(sql_lora_files): + p = Process(target=_test_llama_lora, args=(sql_lora_files, 1)) + p.start() + p.join() + assert p.exitcode == 0 + + +def test_llama_lora_2x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_lora, args=(sql_lora_files, 2)) + p.start() + p.join() + assert p.exitcode == 0 + + +def test_llama_lora_4x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_lora, args=(sql_lora_files, 4)) + p.start() + p.join() + assert p.exitcode == 0 diff --git a/tests/lora/test_lora_hpu.py b/tests/lora/test_lora_hpu.py new file mode 100644 index 0000000000000..ddbab66e166b3 --- /dev/null +++ b/tests/lora/test_lora_hpu.py @@ -0,0 +1,221 @@ +import pytest +import torch + +from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice + +from .utils import DummyLoRAManager + +TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] +QKV_TENSOR_SIZES = [ + (8192, 1024, 1024), + (8192 // 8, 1024 // 8, 1024 // 8), + (4096, 4096, 4096), + (4096 // 2, 4096 // 2, 4096 // 2), +] +BATCH_SIZES = [8, 32, 256] +RANKS = [8] +DTYPES = [torch.bfloat16] +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} +MAX_LORAS = 8 + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora(m, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m, n], device="hpu", dtype=dtype) + + manager.init_random_lora(module_name, weight, rank=rank) + lora = manager.get_module_lora(module_name) + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = input @ lora.lora_a @ lora.lora_b * lora.scaling + + lora_a_stack = torch.zeros(MAX_LORAS + 1, + 1, + lora.lora_a.shape[1], + lora.lora_a.shape[0], + device="hpu", + dtype=dtype) + lora_b_stack = torch.zeros(MAX_LORAS + 1, + 1, + lora.lora_b.shape[1], + lora.lora_b.shape[0], + device="hpu", + dtype=dtype) + for i in range(MAX_LORAS): + lora_a_stack[i][0] = lora.lora_a.T + lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T + + output = torch.zeros(k, m, device="hpu", dtype=dtype) + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), + output) + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.full((len(input), ), -1, device="hpu"), output) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: + if m % 2 != 0: + pytest.skip("m must be divisible by 2") + if m // 2 not in TENSOR_SIZES: + pytest.skip("m//2 must be in TENSOR_SIZES") + + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m // 2, n], device="hpu", dtype=dtype) + + manager.init_random_lora(module_name + "1", weight, rank=rank) + lora_1 = manager.get_module_lora(module_name + "1") + manager.init_random_lora(module_name + "2", weight, rank=rank) + lora_2 = manager.get_module_lora(module_name + "2") + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = torch.cat([ + input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, + input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_1.lora_a.shape[1], + lora_1.lora_a.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_1.lora_b.shape[1], + lora_1.lora_b.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + for i in range(MAX_LORAS): + lora_a_stacks[0][i][0] = lora_1.lora_a.T + lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T + lora_a_stacks[1][i][0] = lora_2.lora_a.T + lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T + + output = torch.zeros(k, m, device="hpu", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output, + (m // 2, m // 2)) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="hpu"), + output, (m // 2, m // 2)) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight_q = torch.empty(qkv[0], n, device="hpu", dtype=dtype) + weight_kv = torch.empty(qkv[1], n, device="hpu", dtype=dtype) + + manager.init_random_lora(module_name + "q", weight_q, rank=rank) + lora_q = manager.get_module_lora(module_name + "q") + manager.init_random_lora(module_name + "k", weight_kv, rank=rank) + lora_k = manager.get_module_lora(module_name + "k") + manager.init_random_lora(module_name + "v", weight_kv, rank=rank) + lora_v = manager.get_module_lora(module_name + "v") + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = torch.cat([ + input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, + input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, + input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_q.lora_a.shape[1], + lora_q.lora_a.shape[0], + device="hpu", + dtype=dtype) + ] + [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_k.lora_a.shape[1], + lora_k.lora_a.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_q.lora_b.shape[1], + lora_q.lora_b.shape[0], + device="hpu", + dtype=dtype) + ] + [ + torch.zeros(MAX_LORAS + 1, + 1, + lora_k.lora_b.shape[1], + lora_k.lora_b.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + for i in range(MAX_LORAS): + lora_a_stacks[0][i][0] = lora_q.lora_a.T + lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T + lora_a_stacks[1][i][0] = lora_k.lora_a.T + lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T + lora_a_stacks[2][i][0] = lora_v.lora_a.T + lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T + + output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output, + (qkv[0], qkv[1], qkv[2])) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="hpu"), + output, (qkv[0], qkv[1], qkv[2])) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() diff --git a/tests/lora/test_multilora_hpu.py b/tests/lora/test_multilora_hpu.py new file mode 100644 index 0000000000000..edca64fd5a2ae --- /dev/null +++ b/tests/lora/test_multilora_hpu.py @@ -0,0 +1,130 @@ +from multiprocessing import Process +from typing import List, Optional, Tuple + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.lora.request import LoRARequest + + +def create_test_prompts( + lora_path: str +) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: + """Create a list of test prompts with their sampling parameters. + + 2 requests for base model, 4 requests for the LoRA. We define 2 + different LoRA adapters (using the same model for demo purposes). + """ + return [ + ("A robot may not injure a human being", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), None), + ("To be or not to be,", + SamplingParams(temperature=0.8, + top_k=5, + presence_penalty=0.2, + max_tokens=128), None), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora2", 2, lora_path)), + ( + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + SamplingParams(temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + result = {} + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + result[ + request_output.request_id] = request_output.outputs[0].text + return result + + +expected_output = [ + " or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194", # noqa: E501 + " that is the question.\nIt is the most famous line in all of Shakespeare's plays and one of the most famous in English literature. The question is not whether or not to be, but rather the question of who to be.\nIn Hamlet's case, the question is whether or not to be a good person. He is torn between the goodness of his father and the evil of his mother.\nThe question is a difficult one, and one that has been asked many times before. In Hamlet's case, the question is whether or not to be a good person, and he is torn between the", # noqa: E501 + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' " # noqa: E501 +] + + +def _test_llama_multilora(sql_lora_files, tp_size): + """Main function that sets up and runs the prompt processing.""" + engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_loras=2, + max_lora_rank=8, + max_num_seqs=16, + dtype='float32', + tensor_parallel_size=tp_size) + engine = LLMEngine.from_engine_args(engine_args) + test_prompts = create_test_prompts(sql_lora_files) + results = process_requests(engine, test_prompts) + generated_texts = [results[key] for key in sorted(results)] + assert generated_texts == expected_output + + +def test_llama_multilora_1x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_multilora, args=(sql_lora_files, 1)) + p.start() + p.join() + assert p.exitcode == 0 + + +def test_llama_multilora_2x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_multilora, args=(sql_lora_files, 2)) + p.start() + p.join() + assert p.exitcode == 0 + + +def test_llama_multilora_4x(sql_lora_files): + # Work-around to resolve stalling issue in multi-card scenario + p = Process(target=_test_llama_multilora, args=(sql_lora_files, 4)) + p.start() + p.join() + assert p.exitcode == 0 diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b73cf5bf55324..6ed985e72e6b3 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -3,6 +3,7 @@ import torch from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.utils import get_device class DummyLoRAManager: @@ -28,16 +29,16 @@ def init_random_lora(self, lora_alpha=1, lora_a=torch.rand([weight.shape[1], rank], dtype=weight.dtype, - device="cuda"), + device=get_device()), lora_b=torch.rand([rank, weight.shape[0]], dtype=weight.dtype, - device="cuda"), + device=get_device()), ) if generate_embeddings_tensor: lora.embeddings_tensor = torch.rand(5, generate_embeddings_tensor, dtype=weight.dtype, - device="cuda") + device=get_device()) self.set_module_lora(module_name, lora) return lora @@ -53,8 +54,8 @@ def init_lora(self, module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([input_dim, rank], device="cuda"), - lora_b=torch.rand([rank, output_dim], device="cuda"), + lora_a=torch.rand([input_dim, rank], device=get_device()), + lora_b=torch.rand([rank, output_dim], device=get_device()), embeddings_tensor=embeddings_tensor, ) self.set_module_lora(module_name, lora) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 33b6e2e538b13..2259630fa10b7 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -2,6 +2,7 @@ # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company ############################################################################### +import os from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type @@ -12,6 +13,8 @@ AttentionMetadata, AttentionType) from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) +from vllm.hpu import cache_ops +from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache from vllm.logger import init_logger logger = init_logger(__name__) @@ -108,7 +111,7 @@ def __post_init__(self): self.attn_bias: Optional[torch.Tensor] = None -class HabanaAttentionImpl(AttentionImpl): +class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -137,10 +140,16 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, ) -> None: + super(AttentionImpl, self).__init__() self.kv_cache_dtype = kv_cache_dtype self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window self.position_bias = None @@ -158,6 +167,12 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '0').lower() in ['1', 'true'] + if self.prefill_usefusedsdpa: + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + suppored_head_sizes = HabanaPagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: raise ValueError( @@ -204,22 +219,29 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - HabanaPagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, attn_metadata.is_prompt) + num_kv_cache_passes, num_slots_available, indices, offsets = \ + cache_ops.prepare_to_cache(key_cache, + attn_metadata.slot_mapping) + key_cache = self.k_cache(key, key_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) + value_cache = self.v_cache(value, value_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) if attn_metadata.is_prompt: # Prompt run. if kv_cache is None or attn_metadata.block_tables.numel() == 0: - # TODO: move this outside of model - assert attn_metadata.attn_bias is not None, \ - 'attn_bias must be set before calling model.forward!' - attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None and \ - self.position_bias is not None: - attn_bias.add_(self.position_bias[:, :, - -attn_bias.size(2):, - -attn_bias.size(3):]) + if not self.prefill_usefusedsdpa: + # TODO: move this outside of model + assert attn_metadata.attn_bias is not None, \ + 'attn_bias must be set before calling model.forward!' + attn_bias = attn_metadata.attn_bias + if self.alibi_slopes is not None and \ + self.position_bias is not None: + attn_bias.add_(self.position_bias[:, :, + -attn_bias.size(2):, + -attn_bias.size(3):]) + else: + attn_bias = None query_shape = (batch_size, seq_len, self.num_heads, self.head_size) @@ -232,6 +254,10 @@ def forward( attn_bias=attn_bias, p=0.0, scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + valid_seq_lengths=attn_metadata.seq_lens_tensor, ) output = out.reshape(batch_size, seq_len, hidden_size) else: @@ -255,7 +281,8 @@ def forward( query, key_cache, value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, self.kv_cache_dtype, self.num_kv_heads, self.scale, self.position_bias, k_scale, - v_scale) + v_scale, self.matmul_qk, self.softmax, self.matmul_av, + self.k_cache, self.v_cache) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 7dd701c7a0cdf..9602886299c47 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -75,6 +75,11 @@ def forward_decode( alibi_slopes: Optional[torch.Tensor], k_scale: float, v_scale: float, + matmul_qk_op, + softmax_op, + matmul_av_op, + k_cache_cls, + v_cache_cls, ) -> torch.Tensor: block_size = value_cache.shape[1] return ops.paged_attention_v1( @@ -88,6 +93,11 @@ def forward_decode( block_size, alibi_slopes, kv_cache_dtype, + matmul_qk_op, + softmax_op, + matmul_av_op, + k_cache_cls, + v_cache_cls, ) @staticmethod diff --git a/vllm/config.py b/vllm/config.py index f16bea16fe646..6acb70ad047b2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -474,12 +474,13 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor") + "scaling factor. " + "Intel Gaudi (HPU) supports fp8 (using fp8_inc).") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -600,11 +601,12 @@ class LoadConfig: ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - + device: Device on which weights are loaded. """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO download_dir: Optional[str] = None + device: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field( default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e4b223a1b505f..d6c544750afea 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -38,6 +38,7 @@ class EngineArgs: trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' + weights_load_device: Optional[str] = None dtype: str = 'auto' kv_cache_dtype: str = 'auto' quantization_param_path: Optional[str] = None @@ -205,6 +206,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument("--weights-load-device", + type=str, + default=EngineArgs.weights_load_device, + choices=["cuda", "neuron", "hpu", "cpu"], + help='Device on which weights are loaded.') parser.add_argument( '--dtype', type=str, @@ -223,11 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). ' + 'Intel Gaudi (HPU) supports fp8 (using fp8_inc).') parser.add_argument( '--quantization-param-path', type=nullable_str, @@ -835,9 +842,12 @@ def create_engine_config(self, ) -> EngineConfig: self.model_loader_extra_config[ "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + device = device_config.device if self.weights_load_device is None else \ + self.weights_load_device load_config = LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + device=device, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f7e0a7a4dc53..f8b9c48bc9589 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -182,7 +182,7 @@ def __init__( "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "pipeline_parallel_size=%d, " "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " + "weights_load_device=%s, enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " @@ -206,6 +206,7 @@ def __init__( parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, model_config.quantization, + load_config.device, model_config.enforce_eager, cache_config.cache_dtype, model_config.quantization_param_path, @@ -853,6 +854,9 @@ def _process_model_outputs( request_outputs.append(request_output) return request_outputs + def finish_measurements(self): + self.model_executor.finish_measurements() + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1d..fc9f118ff14b2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -173,6 +173,9 @@ def set_tokenizer( self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( tokenizer) + def finish_measurements(self): + self.llm_engine.finish_measurements() + @overload # LEGACY: single (prompt + optional token ids) def generate( self, diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py index f5cf26b687053..baeaec5afa371 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/habana_executor.py @@ -90,6 +90,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: msg = f"init_cache_engine took {cache_init_m.get_summary_string()}" logger.info(msg) + def finish_measurements(self): + self.driver_worker.finish_measurements() + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: @@ -151,35 +154,48 @@ def execute_model( return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self.driver_worker.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") - - def list_loras(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def check_health(self) -> None: # GPUExecutor will always be healthy as long as # it's running. return + def shutdown(self) -> None: + self.driver_worker.shutdown_inc() + + def __del__(self): + self.shutdown() + class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_habana_executor.py index c45513e3e5c91..2a8e2df37f031 100644 --- a/vllm/executor/ray_habana_executor.py +++ b/vllm/executor/ray_habana_executor.py @@ -239,6 +239,9 @@ def _driver_execute_model( return self.driver_worker.execute_method("execute_model", execute_model_req) + def finish_measurements(self): + self._run_workers("finish_measurements") + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py index 481da6403d73d..e59ab02bd0b45 100644 --- a/vllm/hpu/cache_ops.py +++ b/vllm/hpu/cache_ops.py @@ -46,6 +46,37 @@ def reshape_and_cache(key, value[start_idx:end_idx]) +def prepare_to_cache(cache, slot_mapping): + num_blocks = cache.size(0) + block_size = cache.size(1) + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + offsets = torch.fmod(slot_mapping, block_size) + num_slots_requested = slot_mapping.size(0) + num_slots_available = num_blocks * block_size + # NOTE(kzawora): HPU PT bridge crashes with + # RuntimeError: Invalid inputs for scatter_nd_onnx + # on index_put when num_slots_requested > num_slots_available. + # This case might occur when we have little kv cache blocks and + # lots of padding, or are doing warmup. + # This loop is a workaround for this issue. Please remove it + # once key_cache.index_put_(indices, offsets), key) works. + num_kv_cache_passes = torch.div(num_slots_requested, + num_slots_available).ceil().int().item() + + return num_kv_cache_passes, num_slots_available, indices, offsets + + +def insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, block_offsets): + for i in range(num_kv_cache_passes): + start_idx = i * num_slots_available + end_idx = (i + 1) * num_slots_available + cache.index_put_((block_indices[start_idx:end_idx], + block_offsets[start_idx:end_idx]), + input[start_idx:end_idx]) + + def swap_blocks(src, dst, block_mapping): index_src = torch.zeros((1, ), dtype=torch.int32, device=src.device) index_dst = torch.zeros((1, ), dtype=torch.int32, device=dst.device) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 1dcc3e64d05c2..be082a5a4f2a4 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -15,7 +15,6 @@ import torch import torch.nn.functional as F -import vllm.hpu.utils as hpu_utils from vllm.logger import init_logger logger = init_logger(__name__) @@ -26,6 +25,13 @@ except ImportError: logger.warning("Could not import HPU FusedRMSNorm kernel. " "vLLM will use forward_native implementation of RMSNorm.") +HPUFusedSDPA = None +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA + HPUFusedSDPA = FusedSDPA +except ImportError: + logger.warning("Could not import HPU FusedSDPA kernel. " + "vLLM will use native implementation.") PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1') @@ -37,7 +43,6 @@ def fetch_from_cache(cache, blocks, permutations): ] -@hpu_utils.with_mark_steps def paged_attention_v1(query, key_cache, value_cache, @@ -47,7 +52,12 @@ def paged_attention_v1(query, context_lens, block_size, alibi_slopes=None, - kv_cache_dtype=None) -> None: + kv_cache_dtype=None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, + k_cache_cls=None, + v_cache_cls=None) -> None: seq_len = block_tables.size(1) batch_size, query_heads, _ = query.shape _, _, kv_heads, _ = key_cache.shape @@ -60,19 +70,23 @@ def paged_attention_v1(query, batch_size, 1, 1, -1)) query.mul_(scale) query = query.unsqueeze(-2) - keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) + fetch_keys = fetch_from_cache if k_cache_cls is None else \ + k_cache_cls.fetch_from_cache + keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] mask = mask.unsqueeze(2) - attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1) + attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1) if alibi_slopes is not None: attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, -attn_weights.size(3):]) - attn_weights = (attn_weights.masked_fill(mask, min_inf).softmax(dim=-1)) + attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3)) + fetch_values = fetch_from_cache if v_cache_cls is None else \ + v_cache_cls.fetch_from_cache + values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) if PA_SPLIT_VALUE: attn_weights = attn_weights.split(block_size, dim=-1) else: @@ -80,7 +94,7 @@ def paged_attention_v1(query, attn_weights = [attn_weights] if query_heads != kv_heads: values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)] + attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)] if query_heads != kv_heads: attn_weights = [a.flatten(1, 2) for a in attn_weights] attn_weights = sum(attn_weights) @@ -123,7 +137,21 @@ def static_fused_moe(hidden_states, w1, w2, score, topk): return final_hidden_states.view(-1, D) -@hpu_utils.with_mark_steps +#TODO: remove after fusedsdpa fix for query_head != kv_head +def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + The kv go from (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = kv.shape + if n_rep == 1: + return kv + kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, + head_dim) + return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -131,24 +159,114 @@ def prompt_attention( attn_bias: Optional[torch.Tensor] = None, p: float = 0.0, scale: Optional[float] = None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, + valid_seq_lengths: Optional[torch.Tensor] = None, ) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) query_heads = query.size(1) kv_heads = key.size(1) - if query_heads != kv_heads: - query = query.unflatten(1, (kv_heads, -1)) - key = key.unflatten(1, (kv_heads, 1)) - value = value.unflatten(1, (kv_heads, 1)) + if attn_bias is not None or HPUFusedSDPA is None: + if query_heads != kv_heads: + query = query.unflatten(1, (kv_heads, -1)) + key = key.unflatten(1, (kv_heads, 1)) + value = value.unflatten(1, (kv_heads, 1)) + if attn_bias is not None: + attn_bias = attn_bias.unsqueeze(2) + attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) if attn_bias is not None: - attn_bias = attn_bias.unsqueeze(2) - attn_weights = torch.matmul(query * scale, key.transpose(-1, -2)) - if attn_bias is not None: - attn_weights.add_(attn_bias) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_weights = torch.matmul(attn_weights, value) - if query_heads != kv_heads: - attn_weights = attn_weights.flatten(1, 2) + attn_weights.add_(attn_bias) + attn_weights = softmax_op(attn_weights, dim=-1) + attn_weights = matmul_av_op(attn_weights, value) + if query_heads != kv_heads: + attn_weights = attn_weights.flatten(1, 2) + else: + #TODO: remove after fusedsdpa fix for query_heads != kv_heads + if query_heads != kv_heads: + key = repeat_kv(key, int(query_heads // kv_heads)) + value = repeat_kv(value, int(query_heads // kv_heads)) + softmax_mode = 'fast' + recompute_mode = True + attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True, + scale, softmax_mode, recompute_mode, + valid_seq_lengths, 'right') attn_weights = attn_weights.transpose(1, 2) return attn_weights + + +def dispatch_bgmv_linear( + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indices: torch.LongTensor, + layer_idx: int, + scale: float, +): + """ + `wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices + stacked into single tensors, assuming same rank. HPU handles no-LoRA + requests using zero valued A and B tensors. These zero valued tensors are + appended at the end of `wa_t_all` and `wb_t_all` during initialization. For + custom BGMV, the corresponding `wa` and `wb` for each batch is created + based on the lora_index of each sample. + + For example: + `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank, + hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles + no-LoRA case. The `wa` tensor for a batch of size batch_Size will have + a shape of (batch_size, num_layers, hidden_dim, lora_rank) + + This method avoids for-loop as well as graph breaks. + """ + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + max_loras = wa_t_all.size(0) + # Wrap-around for negative indices + indices = indices % max_loras + wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + + x = x.unsqueeze(1) + out = x @ wa + out = out @ wb + out = out.squeeze(1) + y += out * scale + + +def dispatch_bgmv_embedding( + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + indices: torch.LongTensor, + layer_idx: int, + scale: float, +): + """ + `wa_t_all` contains all LoRA A weight matrices stacked into a single tensor + assuming same rank. HPU handles no-LoRA requests using zero valued A + tensor. This zero valued tensor is appended at the end of `wa_t_all` during + initialization. For custom BGMV, the corresponding wa for each batch is + created based on the lora_index of the sample. + + For example: + `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank, + hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles + no-LoRA case. The wa tensor for a batch of size batch_Size will have a + shape of (batch_size, num_layers, lora_rank, hidden_dim) + + + This method avoids for-loop as well as graph breaks. + """ + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + max_loras = wa_t_all.size(0) + # Wrap-around for negative indices + indices = indices % max_loras + wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + + x = x.unsqueeze(1) + out = x @ wa + out = out.squeeze(1) + y += out * scale \ No newline at end of file diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index 39c66d2e7e824..790cb864f329b 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -11,6 +11,10 @@ if current_platform.is_hpu(): import habana_frameworks.torch as htorch + +import torch + +from vllm.hpu.cache_ops import insert_or_update_cache def with_mark_steps(fn): @@ -25,3 +29,40 @@ def wrapped(*args, **kwargs): return result return wrapped + + +class Matmul(torch.nn.Module): + + def __init__(self): + super(Matmul, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class Softmax(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, inv_head=None): + return torch.softmax(x, dim) + + +class VLLMKVCache(torch.nn.Module): + + def __init__(self): + super(VLLMKVCache, self).__init__() + + def forward(self, input, cache, num_kv_cache_passes, num_slots_available, + block_indices, block_offset): + insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, + block_offset) + return cache + + def fetch_from_cache(self, cache, blocks, permutations): + return [ + cache.index_select(0, blocks[:, i]).permute(permutations) + for i in range(blocks.size(1)) + ] diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 87de285a373a2..4a45f3fda88f1 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -27,6 +27,10 @@ LinearScalingRotaryEmbedding, RotaryEmbedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.utils import is_hpu + +if is_hpu(): + from vllm.hpu.ops import dispatch_bgmv_embedding, dispatch_bgmv_linear if TYPE_CHECKING: pass @@ -89,7 +93,11 @@ def _apply_lora( x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) - add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + if is_hpu(): + dispatch_bgmv_linear(output, x, lora_a_stacked, lora_b_stacked, + indices, 0, 1.0) + else: + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) return output.view_as(org_output) @@ -127,9 +135,15 @@ def _apply_lora_packed_nslice( indices = indices.view(-1) offset_left = 0 for slice_idx in range(len(output_slices)): - add_lora_slice(output, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, - output_slices[slice_idx]) + if is_hpu(): + dispatch_bgmv_linear( + output[:, offset_left:offset_left + output_slices[slice_idx]], + x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], + indices, 0, 1.0) + else: + add_lora_slice(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, 0, 1.0, + offset_left, output_slices[slice_idx]) offset_left += output_slices[slice_idx] return output.view_as(org_output) @@ -330,8 +344,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings = full_lora_a_embeddings.view( full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1], -1) - bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + if is_hpu(): + dispatch_bgmv_embedding(full_output, full_lora_a_embeddings, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + else: + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) return full_output.view_as(full_output_org) @classmethod diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e1ede7d4d710a..30d2fd9502977 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -24,7 +24,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.utils import is_pin_memory_available +from vllm.utils import get_device, is_hpu, is_pin_memory_available logger = init_logger(__name__) @@ -93,7 +93,7 @@ def convert_mapping( long_lora_offsets: Optional[torch.Tensor] = None if long_lora_context: long_lora_offsets = torch.zeros(len(index_mapping_indices), - device="cuda", + device=get_device(), dtype=torch.long) prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 @@ -118,9 +118,9 @@ def convert_mapping( if long_lora_context: assert long_lora_offsets is not None indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + indices = torch.tensor(indices_list, dtype=torch.long, device=get_device()) prompt_mapping_tensor = torch.tensor(prompt_mapping, - device="cuda", + device=get_device(), dtype=torch.long) embeddings_indices = torch.stack([ indices[2] * extra_vocab_size, @@ -131,10 +131,10 @@ def convert_mapping( sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = ( - torch.arange( - 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + - (sampler_indices_padded * len(sampler_indices_padded))) + sampler_indices_padded = (torch.arange( + 0, len(sampler_indices_padded), device=get_device(), dtype=torch.long) + + (sampler_indices_padded * + len(sampler_indices_padded))) long_lora_indices = None long_lora_indices_len: Optional[int] = None if long_lora_context: @@ -424,20 +424,20 @@ def __init__( self.long_lora_context: Optional[LongContextLoRAContext] = None self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) self.sampler_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) self.embeddings_indices = torch.empty(2, self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) self.long_lora_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, - device="cuda") + device=get_device()) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} @@ -465,11 +465,25 @@ def __init__( @property def capacity(self) -> int: - return self.lora_config.max_cpu_loras + if is_hpu(): + # HPU handles no LoRA requests using zero valued A and B tensors. + # These zero valued tensors are appended at the end of A and B, + # making total number of loras to be lora_config.max_cpu_loras + 1. + # This demands the total number of max_cpu_loras to be + # lora_config.max_cpu_loras + 1 + return self.lora_config.max_cpu_loras + 1 + else: + return self.lora_config.max_cpu_loras @property def lora_slots(self) -> int: - return self.lora_config.max_loras + if is_hpu(): + # HPU handles no LoRA requests using zero valued A and B tensors. + # These zero valued tensors are appended at the end of A and B, + # making total number of loras to be lora_config.max_cpu_loras + 1. + return self.lora_config.max_loras + 1 + else: + return self.lora_config.max_loras @property def adapter_slots(self) -> int: diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 55cbbabd7da44..c12668c14887d 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -79,18 +79,15 @@ def forward_hpu( if HPUFusedRMSNorm is None: return self.forward_native(x, residual) if residual is not None: - orig_dtype = x.dtype orig_shape = x.shape residual += x.view(residual.shape) # Note: HPUFusedRMSNorm requires 3D tensors as inputs - x = HPUFusedRMSNorm.apply(residual.float(), self.weight.float(), + x = HPUFusedRMSNorm.apply(residual, self.weight, self.variance_epsilon) - return x.to(orig_dtype).view(orig_shape), residual + return x.view(orig_shape), residual - orig_dtype = x.dtype - x = HPUFusedRMSNorm.apply(x.float(), self.weight.float(), - self.variance_epsilon) - return x.to(orig_dtype) + x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) + return x def forward_xpu( self, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b6e280ae65049..10c8a95f838da 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -273,6 +273,7 @@ def __init__(self, quant_config, prefix) self.gather_output = gather_output + self.collective_func = tensor_model_parallel_all_gather # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() @@ -334,7 +335,7 @@ def forward(self, input_): output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) + output = self.collective_func(output_parallel) else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None @@ -723,6 +724,7 @@ def __init__(self, self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results + self.collective_func = tensor_model_parallel_all_reduce # Divide the weight matrix along the last dimension. self.tp_rank = get_tensor_model_parallel_rank() @@ -770,7 +772,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def forward(self, input_): + def resolve_input(self, input_): if self.input_is_parallel: input_parallel = input_ else: @@ -778,6 +780,10 @@ def forward(self, input_): splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) input_parallel = splitted_input[tp_rank].contiguous() + return input_parallel + + def forward(self, input_): + input_parallel = self.resolve_input(input_) # Matrix multiply. assert self.quant_method is not None diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index bd574512e3431..7590d3e980275 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -18,6 +18,7 @@ GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.inc import INCConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig @@ -37,6 +38,7 @@ "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, + "inc": INCConfig, } diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py new file mode 100644 index 0000000000000..f6718ec2ac9e7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/inc.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class INCConfig(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + @classmethod + def get_name(cls) -> str: + return "inc" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "INCConfig": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["INCLinearMethod"]: + if isinstance(layer, LinearBase): + return INCLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_min_capability(cls) -> int: + # The AWQ kernel only supports Turing or newer GPUs. + return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + +class INCLinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, + quant_config: INCConfig, + separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bbe49655020da..06048d97088e1 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -37,7 +37,7 @@ supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_tpu +from vllm.utils import is_hpu, is_tpu logger = init_logger(__name__) @@ -48,14 +48,15 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + if not is_hpu(): + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( @@ -276,10 +277,11 @@ def load_model(self, *, model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with torch.device(self.load_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) + logger.info("Loading weights on %s ...", self.load_config.device) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index fc4e13bbb0e68..3ae3c8c8f712c 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -39,6 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -224,9 +225,14 @@ def forward( position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds + if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(len(self.h)): layer = self.h[i] hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + if current_platform.is_hpu(): + htorch.core.mark_step() hidden_states = self.ln_f(hidden_states) return hidden_states diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 306d22e42ed1d..676a51ce67f96 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip @@ -317,6 +318,9 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( @@ -326,6 +330,8 @@ def forward( attn_metadata, residual, ) + if current_platform.is_hpu(): + htorch.core.mark_step() if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/utils.py b/vllm/utils.py index 21f1b39d4c3dd..b6b0ac8285842 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -39,6 +39,7 @@ "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, + "fp8_inc": torch.float8_e4m3fn, } TORCH_DTYPE_TO_NUMPY_DTYPE = { @@ -995,6 +996,12 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) +def get_device() -> str: + if is_hpu(): + return "hpu" + return "cuda" + + def error_on_invalid_device_count_status(): cache_entries = 0 with contextlib.suppress(Exception): diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 950b896c3b1b6..f678d44f71dd3 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -91,9 +91,11 @@ def _allocate_kv_cache( # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. + dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else \ + self.dtype kv_cache.append( torch.zeros(kv_cache_shape, - dtype=self.dtype, + dtype=dtype, pin_memory=pin_memory, device=device)) return kv_cache diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 10473017f5334..4fa495769d844 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -98,10 +98,46 @@ def warmup_range(config: Tuple[int, int, int]): return list(ramp_up_tw) + list(stable) -def warmup_buckets(bs_bucket_config, seq_bucket_config): - buckets = itertools.product(warmup_range(bs_bucket_config), - warmup_range(seq_bucket_config)) - return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) +def warmup_buckets(bs_bucket_config, seq_bucket_config, + max_num_batched_tokens): + buckets = list( + itertools.product(warmup_range(bs_bucket_config), + warmup_range(seq_bucket_config))) + if len(buckets) == 0: + msg = ("No buckets could be captured with following config " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config}") + raise ValueError(msg) + + # Remove buckets exceeding batch token budget + filtered_buckets = list( + filter(lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, + buckets)) + + if len(filtered_buckets) == 0: + # legacy case - we can handle this if we ignore max_num_batched_tokens + min_bucket_bs, min_bucket_seq = min(buckets, + key=lambda b: (b[0] * b[1])) + min_reqd_budget = min_bucket_bs * min_bucket_seq + msg = ( + "The current bucketing configuration " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config} cannot be used with specified " + f"max_num_batched_tokens ({max_num_batched_tokens}), as the " + f"smallest bucket ({min_reqd_budget}) would exceed token budget. " + "Please increase max_num_batched_tokens or decrease bucket minimum " + "Ignoring max_num_batched_tokens at risk of out-of-memory errors.") + logger.error(msg) + return list(sorted(buckets, key=lambda b: + (b[0] * b[1], b[1], b[0]))), [] + + captured_buckets = list( + sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + omitted_buckets = list( + sorted([x for x in buckets if x not in filtered_buckets])) + return captured_buckets, omitted_buckets def next_pow2(value: int): @@ -155,6 +191,9 @@ class HpuModelAdapter(): def __init__(self, model, enforce_eager): self.model = model + self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '0').lower() in ['1', 'true'] + if not is_fake_hpu() and not htorch.utils.internal.is_lazy( ) and not enforce_eager: self.model = torch.compile(self.model, @@ -164,7 +203,7 @@ def __init__(self, model, enforce_eager): def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): prefill_metadata = attn_metadata - if prefill_metadata is None: + if prefill_metadata is None or self.prefill_use_fusedsdpa: return attn_metadata seq_lens_t = prefill_metadata.seq_lens_tensor @@ -187,8 +226,8 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') - if 'bypass_hpu_graphs' in kwargs: - kwargs.pop('bypass_hpu_graphs') # required for PT eager + if 'warmup_mode' in kwargs: + kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], input_ids.size(0), @@ -416,10 +455,14 @@ def __init__( # Profiler stats self.profiler_counter_helper = HabanaProfilerCounterHelper() + self.seen_configs: set = set() self._mem_margin: Optional[int] = None self._setup_buckets() def load_model(self) -> None: + import habana_frameworks.torch.core as htcore + if self.model_config.quantization == 'inc': + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model( @@ -436,6 +479,43 @@ def load_model(self) -> None: f"took {m_getmodel.get_summary_string()}") logger.info(msg) + if self.lora_config: + assert hasattr(self.model, "supported_lora_modules" + ) and self.model.supported_lora_modules, ( + "Model does not support LoRA") + assert hasattr(self.model, "embedding_modules" + ), "Model does not have embedding_modules" + assert hasattr( + self.model, "embedding_padding_modules" + ), "Model does not have embedding_padding_modules" + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, self.lora_config, self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.model_config.quantization == 'inc': + logger.info("Preparing model with INC..") + with HabanaMemoryProfiler() as m_inc: + from neural_compressor.torch.quantization import ( + FP8Config, convert, prepare) + config = FP8Config.from_json_file( + os.getenv("QUANT_CONFIG", "")) + if config.measure: + self.model = prepare(self.model, config) + elif config.quantize: + self.model = convert(self.model, config) + htcore.hpu_initialize(self.model, + mark_only_scales_as_const=True) + logger.info("Preparing model with INC took %s", + m_inc.get_summary_string()) + else: + self.model = self.model.to("hpu") + htcore.mark_step() + torch.hpu.synchronize() + # FIXME: Running with disable_tensor_cache=True causes # RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: @@ -448,35 +528,26 @@ def load_model(self) -> None: msg = f"Loading model weights took in total {m.get_summary_string()}" logger.info(msg) - if self.lora_config: - assert hasattr(self.model, "supported_lora_modules" - ) and self.model.supported_lora_modules, ( - "Model does not support LoRA") - assert hasattr( - self.model, - "embedding_modules"), "Model does not have embedding_modules" - assert hasattr(self.model, "embedding_padding_modules" - ), "Model does not have embedding_padding_modules" - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.vocab_size, - self.lora_config, self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) - self.model = self.lora_manager.create_lora_manager(self.model) - def _use_graphs(self, batch_size, seq_len, is_prompt): if self.enforce_eager: return False return (batch_size, seq_len, is_prompt) in self.graphed_buckets + def _is_valid_bucket(self, bucket): + return bucket[0] * bucket[1] <= self.max_num_batched_tokens + def _setup_buckets(self) -> None: + max_bucket_cfg = 64 + if self.lora_config and \ + max_bucket_cfg > self.max_num_batched_tokens // self.block_size: + max_bucket_cfg = self.max_num_batched_tokens // self.block_size self.prompt_bs_bucket_cfg = read_bucket_settings('prompt', 'bs', min=1, step=32, max=min( self.max_num_seqs, - 64)) + max_bucket_cfg)) self.decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', min=1, @@ -498,23 +569,52 @@ def _setup_buckets(self) -> None: f"bs:{self.prompt_bs_bucket_cfg}, " f"seq:{self.prompt_seq_bucket_cfg}") logger.info(msg) - self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg, - self.prompt_seq_bucket_cfg) + self.prompt_buckets, prompt_omitted_buckets = warmup_buckets( + self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) + + if self.lora_config: + self.prompt_buckets[:] = [ + bucket for bucket in self.prompt_buckets + if self._is_valid_bucket(bucket) + ] msg = (f"Generated {len(self.prompt_buckets)} " f"prompt buckets: {list(sorted(self.prompt_buckets))}") logger.info(msg) + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{self.decode_bs_bucket_cfg}, " f"seq:{self.decode_seq_bucket_cfg}") logger.info(msg) - self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg, - self.decode_seq_bucket_cfg) + self.decode_buckets, decode_omitted_buckets = warmup_buckets( + self.decode_bs_bucket_cfg, self.decode_seq_bucket_cfg, + self.max_num_batched_tokens) + if self.lora_config: + self.decode_buckets[:] = [ + bucket for bucket in self.decode_buckets + if self._is_valid_bucket(bucket) + ] msg = (f"Generated {len(self.decode_buckets)} decode buckets: " f"{list(sorted(self.decode_buckets))}") logger.info(msg) + msg = (f"Omitted {len(decode_omitted_buckets)} " + "decode buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted decode buckets: {list(sorted(decode_omitted_buckets))}" + logger.debug(msg) + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -583,21 +683,10 @@ def _prepare_prompt( # actual prompt lens context_lens.append(context_len) query_lens.append(seq_len - context_len) - input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.append(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.append( - [lora_id] * - (seq_len - context_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( @@ -657,6 +746,19 @@ def _prepare_prompt( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), self.block_size) + for seq_group_metadata, context_len in zip(seq_group_metadata_list, + context_lens): + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (max_prompt_len - context_len) + lora_prompt_mapping.extend( + [lora_id] * + (max_prompt_len - context_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + input_tokens = make_tensor_with_pad(input_tokens, max_len=max_prompt_len, pad=0, @@ -918,8 +1020,13 @@ def prepare_input_tensors( paddings = [max_len - s for s in seq_lens] paddings = [0] + paddings[:-1] paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.sampling_params.prompt_logprobs is not None \ + and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) paddings = torch.tensor( - paddings, + paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, dtype=sampling_metadata.selected_token_indices.dtype, device=sampling_metadata.selected_token_indices.device) sampling_metadata.selected_token_indices.add_(paddings) @@ -1010,7 +1117,11 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: ]) return attention_metadata - def create_dummy_seq_group_metadata(self, group_id, seq_len, is_prompt): + def create_dummy_seq_group_metadata(self, + group_id, + seq_len, + is_prompt, + lora_request=None): sampling_params = SamplingParams(temperature=0) num_blocks = math.ceil(seq_len / self.block_size) if is_prompt: @@ -1025,44 +1136,119 @@ def create_dummy_seq_group_metadata(self, group_id, seq_len, is_prompt): output_token_ids = [1] * output_len seq_data = SequenceData(prompt_token_ids) seq_data.output_token_ids = output_token_ids - return SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=(output_len == 0), - seq_data={group_id: seq_data}, - sampling_params=sampling_params, - block_tables=block_tables, - ) + return SequenceGroupMetadata(request_id=str(group_id), + is_prompt=(output_len == 0), + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + lora_request=lora_request) def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers max_batch_size = self.prompt_bs_bucket_cfg[-1] max_seq_len = self.prompt_seq_bucket_cfg[-1] - - self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches) - - def warmup_scenario(self, batch_size, seq_len, is_prompt, - kv_caches) -> None: + if self.lora_config: + max_seq_len = self.max_num_batched_tokens // max_batch_size + + self.warmup_scenario(max_batch_size, + max_seq_len, + True, + kv_caches, + is_profile_run=True) + + def warmup_scenario(self, + batch_size, + seq_len, + is_prompt, + kv_caches, + is_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" f"bs{batch_size}_" f"seq{seq_len}_" f"graphs{'T' if use_graphs else 'F'}") + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config and is_profile_run: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] self.profiler.start('internal', scenario_name) times = 3 if use_graphs else 1 + if self.lora_config and not is_profile_run: + lora_mapping = LoRAMapping( + [0] * batch_size * seq_len, + [0] * batch_size * seq_len, + ) + self.set_active_loras(set(), lora_mapping) seqs = [ - self.create_dummy_seq_group_metadata(i, seq_len, is_prompt) + self.create_dummy_seq_group_metadata( + i, + seq_len, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) for i in range(batch_size) ] torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches) + self.execute_model(inputs, kv_caches, warmup_mode=True) torch.hpu.synchronize() self.profiler.end() gc.collect() + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + def log_warmup(self, phase, i, max_i, batch_size, seq_len): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) @@ -1325,6 +1511,15 @@ def get_counter_dict(self, cache_config, duration, seq_len, return counters +def unwrap_model(model): + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + return unwrap_model(model._orig_mod) + else: + model = list(vars(model)['_modules'].values())[0] + modules = list(vars(model)['_modules'].values()) + return modules + + class HabanaModelRunner( HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): """ @@ -1372,6 +1567,19 @@ def prepare_model_input( is_prompt=is_prompt, virtual_engine=virtual_engine) + def finish_measurements(self): + from neural_compressor.torch.quantization import finalize_calibration + finalize_calibration(self.model.model) + + def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): + cfg = (batch_size, seq_len, is_prompt) + seen = cfg in self.seen_configs + self.seen_configs.add(cfg) + if not seen and not warmup_mode: + phase = 'prompt' if is_prompt else 'decode' + logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", + phase, batch_size, seq_len) + @torch.inference_mode() def execute_model( self, @@ -1379,14 +1587,17 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + warmup_mode=False, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( "num_steps > 1 is not supported in HabanaModelRunner") - # NOTE(kzawora): Need to restore this after adding LoRA - # if self.lora_config: - # self.set_active_loras(lora_requests, lora_mapping) + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata @@ -1403,6 +1614,7 @@ def execute_model( batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + self._check_config(batch_size, seq_len, is_prompt, warmup_mode) execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, @@ -1412,6 +1624,11 @@ def execute_model( } if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) + if htorch.utils.internal.is_lazy(): + execute_model_kwargs.update({ + "bypass_hpu_graphs": not use_graphs, + "warmup_mode": warmup_mode + }) htorch.core.mark_step() if self.is_driver_worker: @@ -1425,9 +1642,18 @@ def execute_model( with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, - selected_token_indices=sampling_metadata. - selected_token_indices, - bypass_hpu_graphs=not use_graphs) + selected_token_indices=sampling_metadata.selected_token_indices + ) + + if self.lora_config: + from vllm.lora.layers import VocabParallelEmbeddingWithLoRA + modules = unwrap_model(self.model.model) + for module in modules: + if isinstance(module, VocabParallelEmbeddingWithLoRA): + for i in range(0, len(module.indices_len)): + module.indices_len[ + i] = sampling_metadata.selected_token_indices.numel( + ) # Compute the logits. with self.profiler.record_event( @@ -1469,3 +1695,16 @@ def execute_model( is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) return [output] + + def shutdown_inc(self): + print('inc shutdown') + if (model_config := getattr(self, "model_config", None)) and \ + getattr(model_config, "quantization", None) == 'inc': + print('inc shutdown start') + from neural_compressor.torch.quantization import ( + finalize_calibration) + finalize_calibration(self.model.model) + print('inc shutdown') + + def __del__(self): + self.shutdown_inc() diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index 2e815b4fe1579..984d18eb2da3f 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -95,6 +95,16 @@ def __init__( # Initialize gpu_cache as embedding models don't initialize kv_caches self.hpu_cache: Optional[List[List[torch.tensor]]] = None + def _set_env_vars(self): + local_rank = self.local_rank + if self.parallel_config.world_size == 1: + local_rank = -1 + import os + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["ID"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size) + os.environ["RANK"] = str(self.rank) + def init_device(self) -> None: if self.device_config.device.type == "hpu": self.device = torch.device("hpu") @@ -105,6 +115,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. + if self.model_config.quantization == 'inc': + self._set_env_vars() init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) @@ -173,9 +185,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_hpu_blocks = max(num_hpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - # NOTE(kzawora): Restore this once LoRA support is added - # if self.model_runner.lora_manager: - # self.model_runner.remove_all_loras() + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() gc.collect() return num_hpu_blocks, num_cpu_blocks @@ -222,6 +233,9 @@ def _warm_up_model(self) -> None: # the model initialization and profiling. set_random_seed(self.model_config.seed) + def finish_measurements(self): + self.model_runner.finish_measurements() + @property def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @@ -275,29 +289,39 @@ def execute_worker(self, worker_input: WorkerInput) -> None: self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + return self.model_runner.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") - - def list_loras(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + return self.model_runner.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for HPU backend.") + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def shutdown_inc(self): + self.model_runner.shutdown_inc() + + def __del__(self): + self.shutdown_inc() @property def max_model_len(self) -> int: