From b0112c3a9a075e83f5bb98127586d925402f3614 Mon Sep 17 00:00:00 2001
From: Nir David <124874956+nirda7@users.noreply.github.com>
Date: Wed, 14 Aug 2024 19:34:25 +0300
Subject: [PATCH] Support FP8 INC in vLLM (#144)

This PR adds FP8 inference support on Intel Gaudi (HPU) via Intel Neural Compressor (INC): a new `inc` quantization method, an `fp8_inc` KV-cache dtype, a `--weights-load-device` flag for choosing where model weights are loaded, and calibration hooks (`finish_measurements`, `shutdown_inc`) wired through the engine, executor, worker, and model runner. Attention matmul, softmax, and KV-cache ops are wrapped in `torch.nn.Module`s so INC can patch them.

FIX #xxxx (*link existing issues this PR will resolve*)

**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**

---
PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria; this helps vLLM maintain code quality and improves the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change (e.g., [Bugfix], [Core], [Kernel], [Doc], [Model], [Misc]).

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC, excluding kernel/data/config/test code), we expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and might not go through it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient, and to make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

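Below is a minimal usage sketch (not part of the patch) showing how the knobs added by this PR could be exercised for an FP8 measurement/calibration run on Gaudi. The model name and `QUANT_CONFIG` path are placeholders, and `weights_load_device` is assumed to be forwarded from `LLM(**kwargs)` to `EngineArgs` as in the diff below.

```python
import os

from vllm import LLM, SamplingParams

# Placeholder path: an Intel Neural Compressor FP8 JSON config
# (measurement or quantization mode), read by the HPU model runner.
os.environ["QUANT_CONFIG"] = "/path/to/inc_measure_config.json"

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",   # placeholder model
    quantization="inc",                 # new "inc" quantization method
    kv_cache_dtype="fp8_inc",           # new HPU fp8 KV-cache dtype
    weights_load_device="cpu",          # new flag: device for weight loading
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))

# After a measurement (calibration) run, flush INC statistics to disk.
llm.finish_measurements()
```

`finish_measurements()` propagates through the engine, executor, and worker down to the model runner, where it calls INC's `finalize_calibration`.
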
--- README_GAUDI.md | 3 +- .../getting_started/gaudi-installation.rst | 3 +- vllm/attention/backends/habana_attn.py | 26 +++- vllm/attention/ops/habana_paged_attn.py | 10 ++ vllm/config.py | 8 +- vllm/engine/arg_utils.py | 14 ++- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/llm.py | 3 + vllm/executor/habana_executor.py | 9 ++ vllm/executor/ray_habana_executor.py | 3 + vllm/hpu/cache_ops.py | 31 +++++ vllm/hpu/ops.py | 33 +++-- vllm/hpu/utils.py | 40 ++++++ vllm/model_executor/layers/layernorm.py | 11 +- vllm/model_executor/layers/linear.py | 10 +- .../layers/quantization/__init__.py | 2 + .../model_executor/layers/quantization/inc.py | 115 ++++++++++++++++++ vllm/model_executor/model_loader/loader.py | 22 ++-- vllm/model_executor/models/llama.py | 6 + vllm/utils.py | 1 + vllm/worker/cache_engine.py | 4 +- vllm/worker/habana_model_runner.py | 57 ++++++++- vllm/worker/habana_worker.py | 21 ++++ 23 files changed, 387 insertions(+), 51 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/inc.py diff --git a/README_GAUDI.md b/README_GAUDI.md index a569d6314acf8..9ea30a2e43f69 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -26,7 +26,8 @@ To verify that the Intel Gaudi software was correctly installed, run: ``` {.console} $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed -$ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed +$ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed +$ pip list | grep neural # verify that neural-compressor is installed ``` Refer to [Intel Gaudi Software Stack diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 7af291d62efc6..ddbac022a8d9d 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -26,7 +26,8 @@ To verify that the Intel Gaudi software was correctly installed, run: $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed - $ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed + $ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed + $ pip list | grep neural # verify that neural_compressor is installed Refer to `Intel Gaudi Software Stack Verification `__ diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 33b6e2e538b13..7a867e79b203d 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -12,6 +12,8 @@ AttentionMetadata, AttentionType) from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) +from vllm.hpu import cache_ops +from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache from vllm.logger import init_logger logger = init_logger(__name__) @@ -108,7 +110,7 @@ def __post_init__(self): self.attn_bias: 
Optional[torch.Tensor] = None -class HabanaAttentionImpl(AttentionImpl): +class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -137,10 +139,16 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, ) -> None: + super(AttentionImpl, self).__init__() self.kv_cache_dtype = kv_cache_dtype self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window self.position_bias = None @@ -204,9 +212,13 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - HabanaPagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, attn_metadata.is_prompt) + num_kv_cache_passes, num_slots_available, indices, offsets = \ + cache_ops.prepare_to_cache(key_cache, + attn_metadata.slot_mapping) + key_cache = self.k_cache(key, key_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) + value_cache = self.v_cache(value, value_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) if attn_metadata.is_prompt: # Prompt run. @@ -232,6 +244,9 @@ def forward( attn_bias=attn_bias, p=0.0, scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, ) output = out.reshape(batch_size, seq_len, hidden_size) else: @@ -255,7 +270,8 @@ def forward( query, key_cache, value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, self.kv_cache_dtype, self.num_kv_heads, self.scale, self.position_bias, k_scale, - v_scale) + v_scale, self.matmul_qk, self.softmax, self.matmul_av, + self.k_cache, self.v_cache) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 7dd701c7a0cdf..9602886299c47 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -75,6 +75,11 @@ def forward_decode( alibi_slopes: Optional[torch.Tensor], k_scale: float, v_scale: float, + matmul_qk_op, + softmax_op, + matmul_av_op, + k_cache_cls, + v_cache_cls, ) -> torch.Tensor: block_size = value_cache.shape[1] return ops.paged_attention_v1( @@ -88,6 +93,11 @@ def forward_decode( block_size, alibi_slopes, kv_cache_dtype, + matmul_qk_op, + softmax_op, + matmul_av_op, + k_cache_cls, + v_cache_cls, ) @staticmethod diff --git a/vllm/config.py b/vllm/config.py index f16bea16fe646..6acb70ad047b2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -474,12 +474,13 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor") + "scaling factor. 
" + "Intel Gaudi (HPU) supports fp8 (using fp8_inc).") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -600,11 +601,12 @@ class LoadConfig: ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - + device: Device on which weights are loaded. """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO download_dir: Optional[str] = None + device: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field( default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e4b223a1b505f..d6c544750afea 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -38,6 +38,7 @@ class EngineArgs: trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' + weights_load_device: Optional[str] = None dtype: str = 'auto' kv_cache_dtype: str = 'auto' quantization_param_path: Optional[str] = None @@ -205,6 +206,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument("--weights-load-device", + type=str, + default=EngineArgs.weights_load_device, + choices=["cuda", "neuron", "hpu", "cpu"], + help='Device on which weights are loaded.') parser.add_argument( '--dtype', type=str, @@ -223,11 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). 
' + 'Intel Gaudi (HPU) supports fp8 (using fp8_inc).') parser.add_argument( '--quantization-param-path', type=nullable_str, @@ -835,9 +842,12 @@ def create_engine_config(self, ) -> EngineConfig: self.model_loader_extra_config[ "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + device = device_config.device if self.weights_load_device is None else \ + self.weights_load_device load_config = LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + device=device, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f7e0a7a4dc53..f8b9c48bc9589 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -182,7 +182,7 @@ def __init__( "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "pipeline_parallel_size=%d, " "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " + "weights_load_device=%s, enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " @@ -206,6 +206,7 @@ def __init__( parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, model_config.quantization, + load_config.device, model_config.enforce_eager, cache_config.cache_dtype, model_config.quantization_param_path, @@ -853,6 +854,9 @@ def _process_model_outputs( request_outputs.append(request_output) return request_outputs + def finish_measurements(self): + self.model_executor.finish_measurements() + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1d..fc9f118ff14b2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -173,6 +173,9 @@ def set_tokenizer( self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( tokenizer) + def finish_measurements(self): + self.llm_engine.finish_measurements() + @overload # LEGACY: single (prompt + optional token ids) def generate( self, diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py index f5cf26b687053..80f8037a2d043 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/habana_executor.py @@ -90,6 +90,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: msg = f"init_cache_engine took {cache_init_m.get_summary_string()}" logger.info(msg) + def finish_measurements(self): + self.driver_worker.finish_measurements() + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: @@ -180,6 +183,12 @@ def check_health(self) -> None: # it's running. 
return + def shutdown(self) -> None: + self.driver_worker.shutdown_inc() + + def __del__(self): + self.shutdown() + class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_habana_executor.py index 9e0a89cbeb8aa..17e3414a96b57 100644 --- a/vllm/executor/ray_habana_executor.py +++ b/vllm/executor/ray_habana_executor.py @@ -237,6 +237,9 @@ def _driver_execute_model( return self.driver_worker.execute_method("execute_model", execute_model_req) + def finish_measurements(self): + self._run_workers("finish_measurements") + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py index 14824945aa53a..98f109accea06 100644 --- a/vllm/hpu/cache_ops.py +++ b/vllm/hpu/cache_ops.py @@ -43,6 +43,37 @@ def reshape_and_cache(key, value[start_idx:end_idx]) +def prepare_to_cache(cache, slot_mapping): + num_blocks = cache.size(0) + block_size = cache.size(1) + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + offsets = torch.fmod(slot_mapping, block_size) + num_slots_requested = slot_mapping.size(0) + num_slots_available = num_blocks * block_size + # NOTE(kzawora): HPU PT bridge crashes with + # RuntimeError: Invalid inputs for scatter_nd_onnx + # on index_put when num_slots_requested > num_slots_available. + # This case might occur when we have little kv cache blocks and + # lots of padding, or are doing warmup. + # This loop is a workaround for this issue. Please remove it + # once key_cache.index_put_(indices, offsets), key) works. + num_kv_cache_passes = torch.div(num_slots_requested, + num_slots_available).ceil().int().item() + + return num_kv_cache_passes, num_slots_available, indices, offsets + + +def insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, block_offsets): + for i in range(num_kv_cache_passes): + start_idx = i * num_slots_available + end_idx = (i + 1) * num_slots_available + cache.index_put_((block_indices[start_idx:end_idx], + block_offsets[start_idx:end_idx]), + input[start_idx:end_idx]) + + def swap_blocks(src, dst, block_mapping): index_src = torch.zeros((1, ), dtype=torch.int32, device=src.device) index_dst = torch.zeros((1, ), dtype=torch.int32, device=dst.device) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index c8f00c1cbd59d..23f6964723d3f 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -11,7 +11,6 @@ import torch import torch.nn.functional as F -import vllm.hpu.utils as hpu_utils from vllm.logger import init_logger logger = init_logger(__name__) @@ -33,7 +32,6 @@ def fetch_from_cache(cache, blocks, permutations): ] -@hpu_utils.with_mark_steps def paged_attention_v1(query, key_cache, value_cache, @@ -43,7 +41,12 @@ def paged_attention_v1(query, context_lens, block_size, alibi_slopes=None, - kv_cache_dtype=None) -> None: + kv_cache_dtype=None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, + k_cache_cls=None, + v_cache_cls=None) -> None: seq_len = block_tables.size(1) batch_size, query_heads, _ = query.shape _, _, kv_heads, _ = key_cache.shape @@ -56,19 +59,23 @@ def paged_attention_v1(query, batch_size, 1, 1, -1)) query.mul_(scale) query = query.unsqueeze(-2) - keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) + fetch_keys = fetch_from_cache if k_cache_cls is None else \ + k_cache_cls.fetch_from_cache + keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 
1)) if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] mask = mask.unsqueeze(2) - attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1) + attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1) if alibi_slopes is not None: attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, -attn_weights.size(3):]) - attn_weights = (attn_weights.masked_fill(mask, min_inf).softmax(dim=-1)) + attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3)) + fetch_values = fetch_from_cache if v_cache_cls is None else \ + v_cache_cls.fetch_from_cache + values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) if PA_SPLIT_VALUE: attn_weights = attn_weights.split(block_size, dim=-1) else: @@ -76,7 +83,7 @@ def paged_attention_v1(query, attn_weights = [attn_weights] if query_heads != kv_heads: values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)] + attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)] if query_heads != kv_heads: attn_weights = [a.flatten(1, 2) for a in attn_weights] attn_weights = sum(attn_weights) @@ -119,7 +126,6 @@ def static_fused_moe(hidden_states, w1, w2, score, topk): return final_hidden_states.view(-1, D) -@hpu_utils.with_mark_steps def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -127,6 +133,9 @@ def prompt_attention( attn_bias: Optional[torch.Tensor] = None, p: float = 0.0, scale: Optional[float] = None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, ) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -139,11 +148,11 @@ def prompt_attention( value = value.unflatten(1, (kv_heads, 1)) if attn_bias is not None: attn_bias = attn_bias.unsqueeze(2) - attn_weights = torch.matmul(query * scale, key.transpose(-1, -2)) + attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) if attn_bias is not None: attn_weights.add_(attn_bias) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_weights = torch.matmul(attn_weights, value) + attn_weights = softmax_op(attn_weights, dim=-1) + attn_weights = matmul_av_op(attn_weights, value) if query_heads != kv_heads: attn_weights = attn_weights.flatten(1, 2) attn_weights = attn_weights.transpose(1, 2) diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index b7b435c50c295..3d9c7cb1c4c22 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -8,6 +8,9 @@ from functools import wraps import habana_frameworks.torch as htorch +import torch + +from vllm.hpu.cache_ops import insert_or_update_cache def with_mark_steps(fn): @@ -22,3 +25,40 @@ def wrapped(*args, **kwargs): return result return wrapped + + +class Matmul(torch.nn.Module): + + def __init__(self): + super(Matmul, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class Softmax(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, inv_head=None): + return torch.softmax(x, dim) + + +class VLLMKVCache(torch.nn.Module): + + def __init__(self): + super(VLLMKVCache, self).__init__() + + def forward(self, input, cache, num_kv_cache_passes, num_slots_available, + block_indices, block_offset): + insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, + block_offset) + return cache + + def fetch_from_cache(self, 
cache, blocks, permutations): + return [ + cache.index_select(0, blocks[:, i]).permute(permutations) + for i in range(blocks.size(1)) + ] diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 55cbbabd7da44..c12668c14887d 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -79,18 +79,15 @@ def forward_hpu( if HPUFusedRMSNorm is None: return self.forward_native(x, residual) if residual is not None: - orig_dtype = x.dtype orig_shape = x.shape residual += x.view(residual.shape) # Note: HPUFusedRMSNorm requires 3D tensors as inputs - x = HPUFusedRMSNorm.apply(residual.float(), self.weight.float(), + x = HPUFusedRMSNorm.apply(residual, self.weight, self.variance_epsilon) - return x.to(orig_dtype).view(orig_shape), residual + return x.view(orig_shape), residual - orig_dtype = x.dtype - x = HPUFusedRMSNorm.apply(x.float(), self.weight.float(), - self.variance_epsilon) - return x.to(orig_dtype) + x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) + return x def forward_xpu( self, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b6e280ae65049..10c8a95f838da 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -273,6 +273,7 @@ def __init__(self, quant_config, prefix) self.gather_output = gather_output + self.collective_func = tensor_model_parallel_all_gather # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() @@ -334,7 +335,7 @@ def forward(self, input_): output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) + output = self.collective_func(output_parallel) else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None @@ -723,6 +724,7 @@ def __init__(self, self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results + self.collective_func = tensor_model_parallel_all_reduce # Divide the weight matrix along the last dimension. self.tp_rank = get_tensor_model_parallel_rank() @@ -770,7 +772,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def forward(self, input_): + def resolve_input(self, input_): if self.input_is_parallel: input_parallel = input_ else: @@ -778,6 +780,10 @@ def forward(self, input_): splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) input_parallel = splitted_input[tp_rank].contiguous() + return input_parallel + + def forward(self, input_): + input_parallel = self.resolve_input(input_) # Matrix multiply. 
assert self.quant_method is not None diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index bd574512e3431..7590d3e980275 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -18,6 +18,7 @@ GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.inc import INCConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig @@ -37,6 +38,7 @@ "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, + "inc": INCConfig, } diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py new file mode 100644 index 0000000000000..f6718ec2ac9e7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/inc.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class INCConfig(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + @classmethod + def get_name(cls) -> str: + return "inc" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "INCConfig": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["INCLinearMethod"]: + if isinstance(layer, LinearBase): + return INCLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_min_capability(cls) -> int: + # The AWQ kernel only supports Turing or newer GPUs. + return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + +class INCLinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. 
Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, + quant_config: INCConfig, + separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bbe49655020da..06048d97088e1 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -37,7 +37,7 @@ supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_tpu +from vllm.utils import is_hpu, is_tpu logger = init_logger(__name__) @@ -48,14 +48,15 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + if not is_hpu(): + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. 
" + f"Current capability: {capability}.") supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( @@ -276,10 +277,11 @@ def load_model(self, *, model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with torch.device(self.load_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) + logger.info("Loading weights on %s ...", self.load_config.device) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 306d22e42ed1d..676a51ce67f96 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip @@ -317,6 +318,9 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( @@ -326,6 +330,8 @@ def forward( attn_metadata, residual, ) + if current_platform.is_hpu(): + htorch.core.mark_step() if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/utils.py b/vllm/utils.py index 8a1bc5de03eb7..fe84253feb172 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -39,6 +39,7 @@ "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, + "fp8_inc": torch.float8_e4m3fn, } TORCH_DTYPE_TO_NUMPY_DTYPE = { diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 93be2f4c321fe..ec0b8c2369210 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -91,9 +91,11 @@ def _allocate_kv_cache( # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. 
+ dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else \ + self.dtype kv_cache.append( torch.zeros(kv_cache_shape, - dtype=self.dtype, + dtype=dtype, pin_memory=pin_memory, device=device)) return kv_cache diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index cf91c69069ed6..72aba42ae8553 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -182,8 +182,8 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') - if 'bypass_hpu_graphs' in kwargs: - kwargs.pop('bypass_hpu_graphs') # required for PT eager + if 'warmup_mode' in kwargs: + kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], input_ids.size(0), @@ -413,6 +413,9 @@ def __init__( self._setup_buckets() def load_model(self) -> None: + import habana_frameworks.torch.core as htcore + if self.model_config.quantization == 'inc': + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model( @@ -429,6 +432,26 @@ def load_model(self) -> None: f"took {m_getmodel.get_summary_string()}") logger.info(msg) + if self.model_config.quantization == 'inc': + logger.info("Preparing model with INC..") + with HabanaMemoryProfiler() as m_inc: + from neural_compressor.torch.quantization import ( + FP8Config, convert, prepare) + config = FP8Config.from_json_file( + os.getenv("QUANT_CONFIG", "")) + if config.measure: + self.model = prepare(self.model, config) + elif config.quantize: + self.model = convert(self.model, config) + htcore.hpu_initialize(self.model, + mark_only_scales_as_const=True) + logger.info("Preparing model with INC took %s", + m_inc.get_summary_string()) + else: + self.model = self.model.to("hpu") + htcore.mark_step() + torch.hpu.synchronize() + # FIXME: Running with disable_tensor_cache=True causes # RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: @@ -1051,7 +1074,7 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt, torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches) + self.execute_model(inputs, kv_caches, warmup_mode=True) torch.hpu.synchronize() self.profiler.end() gc.collect() @@ -1362,6 +1385,10 @@ def prepare_model_input( is_prompt=is_prompt, virtual_engine=virtual_engine) + def finish_measurements(self): + from neural_compressor.torch.quantization import finalize_calibration + finalize_calibration(self.model.model) + @torch.inference_mode() def execute_model( self, @@ -1369,6 +1396,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + warmup_mode=False, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( @@ -1402,6 +1430,11 @@ def execute_model( } if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) + if htorch.utils.internal.is_lazy(): + execute_model_kwargs.update({ + "bypass_hpu_graphs": not use_graphs, + "warmup_mode": warmup_mode + }) htorch.core.mark_step() if self.is_driver_worker: @@ -1415,9 +1448,8 @@ def execute_model( with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, - selected_token_indices=sampling_metadata. 
- selected_token_indices, - bypass_hpu_graphs=not use_graphs) + selected_token_indices=sampling_metadata.selected_token_indices + ) # Compute the logits. with self.profiler.record_event( @@ -1459,3 +1491,16 @@ def execute_model( is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) return [output] + + def shutdown_inc(self): + print('inc shutdown') + if (model_config := getattr(self, "model_config", None)) and \ + getattr(model_config, "quantization", None) == 'inc': + print('inc shutdown start') + from neural_compressor.torch.quantization import ( + finalize_calibration) + finalize_calibration(self.model.model) + print('inc shutdown') + + def __del__(self): + self.shutdown_inc() diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index f3fdc4dcc63c6..87122c03d3c8f 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -91,6 +91,16 @@ def __init__( # Initialize gpu_cache as embedding models don't initialize kv_caches self.hpu_cache: Optional[List[List[torch.tensor]]] = None + def _set_env_vars(self): + local_rank = self.local_rank + if self.parallel_config.world_size == 1: + local_rank = -1 + import os + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["ID"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size) + os.environ["RANK"] = str(self.rank) + def init_device(self) -> None: if self.device_config.device.type == "hpu": self.device = torch.device("hpu") @@ -99,6 +109,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. + if self.model_config.quantization == 'inc': + self._set_env_vars() init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) @@ -211,6 +223,9 @@ def _warm_up_model(self) -> None: # the model initialization and profiling. set_random_seed(self.model_config.seed) + def finish_measurements(self): + self.model_runner.finish_measurements() + @property def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @@ -288,6 +303,12 @@ def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> Set[int]: raise NotImplementedError("LoRA is not implemented for HPU backend.") + def shutdown_inc(self): + self.model_runner.shutdown_inc() + + def __del__(self): + self.shutdown_inc() + @property def max_model_len(self) -> int: return self.model_config.max_model_len
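
For reviewers, here is a rough, self-contained sketch of the INC prepare/convert flow that `HabanaModelRunner.load_model()` and `finish_measurements()` wire in above. It assumes the `neural_compressor.torch.quantization` API exactly as used in the diff; the helper name `apply_inc` and the commented calibration loop are illustrative only.

```python
# Illustrative sketch of the INC flow used in habana_model_runner.py above;
# apply_inc() is a hypothetical helper, not part of the patch.
import os

import torch
from neural_compressor.torch.quantization import (FP8Config, convert,
                                                  finalize_calibration,
                                                  prepare)


def apply_inc(model: torch.nn.Module) -> torch.nn.Module:
    config = FP8Config.from_json_file(os.getenv("QUANT_CONFIG", ""))
    if config.measure:
        # Measurement mode: instrument the model to record activation ranges.
        return prepare(model, config)
    if config.quantize:
        # Quantization mode: convert layers to FP8 using saved measurements.
        return convert(model, config)
    return model


# model = apply_inc(model)
# ... run calibration prompts through the prepared model ...
# finalize_calibration(model)  # dump measurement files (finish_measurements)
```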