From f3f1f93b6af654771c20b943d556167f9765a8a8 Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz Date: Fri, 30 Aug 2024 10:35:53 +0300 Subject: [PATCH 1/6] Port not warmed-up configurations log warnings --- vllm/worker/habana_model_runner.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a975dba6f5136..133706c18aed6 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -448,6 +448,7 @@ def __init__( # Profiler stats self.profiler_counter_helper = HabanaProfilerCounterHelper() + self.seen_configs = set() self._mem_margin: Optional[int] = None self._setup_buckets() @@ -1560,6 +1561,14 @@ def finish_measurements(self): from neural_compressor.torch.quantization import finalize_calibration finalize_calibration(self.model.model) + def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): + cfg = (batch_size, seq_len, is_prompt) + seen = cfg in self.seen_configs + self.seen_configs.add(cfg) + if not seen and not warmup_mode: + phase = 'prompt' if is_prompt else 'decode' + logger.warning(f'Configuration: ({phase}, {batch_size}, {seq_len}) was not warmed-up!') + @torch.inference_mode() def execute_model( self, @@ -1594,6 +1603,7 @@ def execute_model( batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + self._check_config(batch_size, seq_len, is_prompt, warmup_mode) execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, @@ -1605,8 +1615,7 @@ def execute_model( execute_model_kwargs.update(multi_modal_input) if htorch.utils.internal.is_lazy(): execute_model_kwargs.update({ - "bypass_hpu_graphs": not use_graphs, - "warmup_mode": warmup_mode + "bypass_hpu_graphs": not use_graphs }) htorch.core.mark_step() From fd38e5d2fa7a6fb6f8c11dfb5bf8ee801b90451b Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz Date: Fri, 30 Aug 2024 11:47:58 +0300 Subject: [PATCH 2/6] Formating for log warnings --- vllm/worker/habana_model_runner.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 133706c18aed6..0100076aec8e2 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -448,7 +448,7 @@ def __init__( # Profiler stats self.profiler_counter_helper = HabanaProfilerCounterHelper() - self.seen_configs = set() + self.seen_configs: set = set() self._mem_margin: Optional[int] = None self._setup_buckets() @@ -1567,7 +1567,8 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): self.seen_configs.add(cfg) if not seen and not warmup_mode: phase = 'prompt' if is_prompt else 'decode' - logger.warning(f'Configuration: ({phase}, {batch_size}, {seq_len}) was not warmed-up!') + logger.warning('Configuration: (', phase, ', ', batch_size, ', ', + seq_len, ') was not warmed-up!') @torch.inference_mode() def execute_model( @@ -1614,9 +1615,7 @@ def execute_model( if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) if htorch.utils.internal.is_lazy(): - execute_model_kwargs.update({ - "bypass_hpu_graphs": not use_graphs - }) + execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs}) htorch.core.mark_step() if self.is_driver_worker: From a032ea2781583756f1fca8bdaa6284fa2693b841 Mon Sep 17 00:00:00 2001 From: Liran Bachar Date: Sun, 1 Sep 2024 12:23:16 +0300 Subject: [PATCH 3/6] 
support loading autofp8 checkpoint fix gaudi2 weight range to +=240 avoid cuda code in hpu path replace _scaled_mm with hpu op --- vllm/_custom_ops/__init__.py | 75 +++++ .../_cuda_ops.py} | 0 vllm/_custom_ops/_hpu_ops.py | 317 ++++++++++++++++++ vllm/{ => _custom_ops}/_ipex_ops.py | 0 .../compressed_tensors/compressed_tensors.py | 5 +- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../model_executor/layers/quantization/fp8.py | 19 +- .../layers/quantization/utils/w8a8_utils.py | 41 ++- vllm/model_executor/models/llama.py | 7 + vllm/utils.py | 58 +--- vllm/worker/habana_model_runner.py | 3 +- vllm/worker/habana_worker.py | 3 +- 12 files changed, 458 insertions(+), 72 deletions(-) create mode 100644 vllm/_custom_ops/__init__.py rename vllm/{_custom_ops.py => _custom_ops/_cuda_ops.py} (100%) create mode 100644 vllm/_custom_ops/_hpu_ops.py rename vllm/{ => _custom_ops}/_ipex_ops.py (100%) diff --git a/vllm/_custom_ops/__init__.py b/vllm/_custom_ops/__init__.py new file mode 100644 index 0000000000000..2411a1465c187 --- /dev/null +++ b/vllm/_custom_ops/__init__.py @@ -0,0 +1,75 @@ + +from functools import lru_cache + +@lru_cache(maxsize=None) +def is_hip() -> bool: + return torch.version.hip is not None + + +@lru_cache(maxsize=None) +def is_cpu() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "cpu" in version("vllm") + except PackageNotFoundError: + return False + + +@lru_cache(maxsize=None) +def is_openvino() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "openvino" in version("vllm") + except PackageNotFoundError: + return False + + +@lru_cache(maxsize=None) +def is_neuron() -> bool: + try: + import transformers_neuronx + except ImportError: + transformers_neuronx = None + return transformers_neuronx is not None + + +@lru_cache(maxsize=None) +def is_hpu() -> bool: + from importlib import util + return util.find_spec('habana_frameworks') is not None + + +@lru_cache(maxsize=None) +def is_tpu() -> bool: + try: + import libtpu + except ImportError: + libtpu = None + return libtpu is not None + + +@lru_cache(maxsize=None) +def is_xpu() -> bool: + from importlib.metadata import version + is_xpu_flag = "xpu" in version("vllm") + # vllm is not build with xpu + if not is_xpu_flag: + return False + try: + import intel_extension_for_pytorch as ipex # noqa: F401 + _import_ipex = True + except ImportError as e: + logger.warning("Import Error for IPEX: %s", e.msg) + _import_ipex = False + # ipex dependency is not ready + if not _import_ipex: + logger.warning("not found ipex lib") + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + +if is_xpu(): + from ._ipex_ops import * +elif is_hpu(): + from ._hpu_ops import * +else: + from ._cuda_ops import * \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops/_cuda_ops.py similarity index 100% rename from vllm/_custom_ops.py rename to vllm/_custom_ops/_cuda_ops.py diff --git a/vllm/_custom_ops/_hpu_ops.py b/vllm/_custom_ops/_hpu_ops.py new file mode 100644 index 0000000000000..d553540f9e25a --- /dev/null +++ b/vllm/_custom_ops/_hpu_ops.py @@ -0,0 +1,317 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+############################################################################### +import os +from typing import Optional, Tuple + +import habana_frameworks.torch as htorch +import torch +import torch.nn.functional as F + +import vllm.hpu.utils as hpu_utils + +PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1') + + +def silu_and_mul(output, input): + d = input.shape[-1] // 2 + silu = torch.nn.SiLU().to(input.device) + x, y = torch.split(input, d, dim=-1) + output.copy_(silu(x) * y) + + +def fetch_from_cache(cache, blocks, permutations): + return [ + cache.index_select(0, blocks[:, i]).permute(permutations) + for i in range(blocks.size(1)) + ] + + +def paged_attention_v1(query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + alibi_slopes=None, + kv_cache_dtype=None, + qk_matmul_op=torch.matmul, + softmax_op=torch.softmax, + av_matmul_op=torch.matmul, + k_cache_cls=None, + v_cache_cls=None) -> None: + seq_len = block_tables.size(1) + batch_size, query_heads, _ = query.shape + _, _, kv_heads, _ = key_cache.shape + min_inf = torch.finfo(query.dtype).min + mask = (torch.arange(0, + seq_len * block_size, + dtype=torch.int32, + device=key_cache.device).view(1, -1).expand( + batch_size, -1).ge(context_lens.view(-1, 1)).view( + batch_size, 1, 1, -1)) + query.mul_(scale) + query = query.unsqueeze(-2) + fetch_keys = fetch_from_cache if k_cache_cls is None else k_cache_cls.fetch_from_cache + keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) + if query_heads != kv_heads: + query = query.unflatten(1, (kv_heads, -1)) + keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] + mask = mask.unsqueeze(2) + + attn_weights = [qk_matmul_op(query, k) for k in keys] + attn_weights = torch.cat(attn_weights, dim=-1) + if alibi_slopes is not None: + attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, + -attn_weights.size(3):]) + attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) + + fetch_values = fetch_from_cache if v_cache_cls is None else k_cache_cls.fetch_from_cache + values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) + if PA_SPLIT_VALUE: + attn_weights = attn_weights.split(block_size, dim=-1) + else: + values = [torch.cat(values, dim=-2)] + attn_weights = [attn_weights] + if query_heads != kv_heads: + values = [v.unflatten(1, (kv_heads, 1)) for v in values] + attn_weights = [av_matmul_op(a, v) for a, v in zip(attn_weights, values)] + if query_heads != kv_heads: + attn_weights = [a.flatten(1, 2) for a in attn_weights] + attn_weights = sum(attn_weights) + return attn_weights.squeeze(-2) + + +def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + silu_and_mul(out, x) + return out + + +def static_fused_moe(hidden_states, w1, w2, score, topk): + B, D = hidden_states.shape + num_experts = w1.shape[0] + routing_weights = F.softmax(score, dim=1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk(routing_weights, + topk, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + final_hidden_states = torch.zeros((1, B, D), + dtype=hidden_states.dtype, + device=hidden_states.device) + padded_weights = torch.zeros((B, num_experts), + dtype=hidden_states.dtype, + device=hidden_states.device) + padded_weights.scatter_(-1, selected_experts, routing_weights) + padded_weights = 
padded_weights.reshape(-1, B, w1.shape[0]) + padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) + + htorch.core.mark_step() + + for expert_idx in range(num_experts): + padded_weight = padded_weights[expert_idx] + current_state_static = hidden_states.reshape(-1, D) + w_output = silu_and_mul_wrapper( + torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1))) + w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1)) + current_hidden_states_static = w_output * padded_weight + final_hidden_states += current_hidden_states_static + htorch.core.mark_step() + + return final_hidden_states.view(-1, D) + + +def prompt_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + p: float = 0.0, + scale: Optional[float] = None, + qk_matmul_op = torch.matmul, + softmax_op = torch.softmax, + av_matmul_op = torch.matmul, +) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + query_heads = query.size(1) + kv_heads = key.size(1) + if query_heads != kv_heads: + query = query.unflatten(1, (kv_heads, -1)) + key = key.unflatten(1, (kv_heads, 1)) + value = value.unflatten(1, (kv_heads, 1)) + attn_bias = attn_bias.unsqueeze(2) + attn_weights = qk_matmul_op(query * scale, key.transpose(-1, -2)) + if attn_bias is not None: + attn_weights.add_(attn_bias) + attn_weights = softmax_op(attn_weights, dim=-1) + attn_weights = av_matmul_op(attn_weights, value) + if query_heads != kv_heads: + attn_weights = attn_weights.flatten(1, 2) + attn_weights = attn_weights.transpose(1, 2) + return attn_weights + + + + +def reshape_and_cache(key, + value, + key_cache, + value_cache, + slot_mapping, + dtype, + is_prompt=False): + num_blocks = key_cache.size(0) + block_size = key_cache.size(1) + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + offsets = torch.fmod(slot_mapping, block_size) + num_slots_requested = slot_mapping.size(0) + num_slots_available = num_blocks * block_size + # NOTE(kzawora): HPU PT bridge crashes with + # RuntimeError: Invalid inputs for scatter_nd_onnx + # on index_put when num_slots_requested > num_slots_available. + # This case might occur when we have little kv cache blocks and + # lots of padding, or are doing warmup. + # This loop is a workaround for this issue. Please remove it + # once key_cache.index_put_(indices, offsets), key) works. + num_kv_cache_passes = torch.div(num_slots_requested, + num_slots_available).ceil().int().item() + for i in range(num_kv_cache_passes): + start_idx = i * num_slots_available + end_idx = (i + 1) * num_slots_available + key_cache.index_put_( + (indices[start_idx:end_idx], offsets[start_idx:end_idx]), + key[start_idx:end_idx]) + value_cache.index_put_( + (indices[start_idx:end_idx], offsets[start_idx:end_idx]), + value[start_idx:end_idx]) + + +def prepare_to_cache(cache, slot_mapping): + num_blocks = cache.size(0) + block_size = cache.size(1) + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + offsets = torch.fmod(slot_mapping, block_size) + num_slots_requested = slot_mapping.size(0) + num_slots_available = num_blocks * block_size + # NOTE(kzawora): HPU PT bridge crashes with + # RuntimeError: Invalid inputs for scatter_nd_onnx + # on index_put when num_slots_requested > num_slots_available. + # This case might occur when we have little kv cache blocks and + # lots of padding, or are doing warmup. 
+ # This loop is a workaround for this issue. Please remove it + # once key_cache.index_put_(indices, offsets), key) works. + num_kv_cache_passes = torch.div(num_slots_requested, + num_slots_available).ceil().int().item() + + return num_kv_cache_passes, num_slots_available, indices, offsets + + +def insert_or_update_cache(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offsets): + for i in range(num_kv_cache_passes): + start_idx = i * num_slots_available + end_idx = (i + 1) * num_slots_available + cache.index_put_( + (block_indices[start_idx:end_idx], block_offsets[start_idx:end_idx]), + input[start_idx:end_idx]) + + +def swap_blocks(src, dst, block_mapping): + index_src = torch.zeros((1, ), dtype=torch.int32, device=src.device) + index_dst = torch.zeros((1, ), dtype=torch.int32, device=dst.device) + for src_idx, dst_idx in block_mapping.items(): + index_src[0] = src_idx + index_dst[0] = dst_idx + dst.index_put_([index_dst], src.index_select(0, index_src)) + if dst.device.type == 'hpu': + htorch.core.mark_step() + torch.hpu.synchronize() + + +def copy_blocks(key_caches, value_caches, block_mapping): + index_src = torch.zeros((1, ), + dtype=torch.int32, + device=key_caches[0].device) + index_dst = torch.zeros((1, ), + dtype=torch.int32, + device=key_caches[0].device) + for src, dsts in block_mapping.items(): + index_src[0] = src + for dst in dsts: + index_dst[0] = dst + for key_cache in key_caches: + key_cache.index_copy_(0, index_dst, + key_cache.index_select(0, index_src)) + for value_cache in value_caches: + value_cache.index_copy_(0, index_dst, + value_cache.index_select(0, index_src)) + if key_caches[0].device.type == 'hpu': + htorch.core.mark_step() + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + batch_dim_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensor for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + batch_dim_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. 
+ """ + if batch_dim_padding: + shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + if scale is None: + raise "dynamic scaled_fp8_quant not implemented for HPU" + #TODO: calculate scale to match gaudi2 240 range instead of 448 + if use_per_token_if_dynamic: + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant( + output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + else: + output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, dtype=torch.float8_e4m3fn)[0] + + return output, scale diff --git a/vllm/_ipex_ops.py b/vllm/_custom_ops/_ipex_ops.py similarity index 100% rename from vllm/_ipex_ops.py rename to vllm/_custom_ops/_ipex_ops.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 39d00bd5733ff..badb29af1f5f6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -233,7 +233,7 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, @@ -306,7 +306,8 @@ def get_scheme( # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) - self._check_scheme_supported(scheme.get_min_capability()) + if torch.cuda.is_available(): + self._check_scheme_supported(scheme.get_min_capability()) return scheme diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index cc9d71db140c2..631774994b5c0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -21,7 +21,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = cutlass_fp8_supported() + self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c829cb836ee4c..8e2ed041adf0b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -112,13 +112,18 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config - self.cutlass_fp8_supported = cutlass_fp8_supported() - - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 + + if torch.cuda.is_available(): + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 + else: + self.cutlass_fp8_supported = False + self.use_marlin = False def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 20100c76bd690..de5cd810b2a94 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,6 +6,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +if current_platform.is_hpu(): + import habana_frameworks.torch.utils.experimental as htexp def cutlass_fp8_supported() -> bool: @@ -18,8 +20,17 @@ def cutlass_fp8_supported() -> bool: def per_tensor_dequantize( tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]) -> torch.Tensor: - fake_qweight = tensor.to(torch.float16) + dtype = torch.float16 + device = tensor.device + if current_platform.is_hpu(): + dtype = torch.bfloat16 + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + #dequant on cpu to avoid nan on gaudi2 + tensor = tensor.to('cpu') + + fake_qweight = tensor.to(dtype).to(device) dq_weight = fake_qweight * inv_scale + return dq_weight @@ -76,6 +87,9 @@ def requantize_with_max_scale( logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. 
max_w_scale = weight_scale.max() + if current_platform.is_hpu() and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + max_w_scale = max_w_scale * (448/240) + # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize @@ -147,12 +161,25 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + if current_platform.is_hpu(): + #hpu does not support torch._scaled_mm (SW-197036) + output = torch.ops.hpu.fp8_gemm_v2(qinput, + False, + weight, + False, + None, + input.dtype, + x_scale, + weight_scale, + None, + False) + else: + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) return torch.narrow(output, 0, 0, input.shape[0]) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 676a51ce67f96..f02609aa9ff3b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,6 +54,9 @@ from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from vllm.platforms import current_platform +if current_platform.is_hpu(): + import habana_frameworks.torch.core as htcore class LlamaMLP(nn.Module): @@ -518,6 +521,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) + #Avoid OOM due to large graph when loading weights + if current_platform.is_hpu(): + htcore.mark_step() + # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). 
Thus, handled exceptions should # make sure to leave KV cache scale factors in a known good (dummy) state diff --git a/vllm/utils.py b/vllm/utils.py index fa6e132dd3522..661d5d62e069b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -176,69 +176,25 @@ def clear(self): def is_hip() -> bool: - return torch.version.hip is not None + return ops.is_hip() - -@lru_cache(maxsize=None) def is_cpu() -> bool: - from importlib.metadata import PackageNotFoundError, version - try: - return "cpu" in version("vllm") - except PackageNotFoundError: - return False - + return ops.is_cpu() -@lru_cache(maxsize=None) def is_openvino() -> bool: - from importlib.metadata import PackageNotFoundError, version - try: - return "openvino" in version("vllm") - except PackageNotFoundError: - return False - + return ops.is_openvino() -@lru_cache(maxsize=None) def is_neuron() -> bool: - try: - import transformers_neuronx - except ImportError: - transformers_neuronx = None - return transformers_neuronx is not None + return ops.is_neuron() - -@lru_cache(maxsize=None) def is_hpu() -> bool: - from importlib import util - return util.find_spec('habana_frameworks') is not None - + return ops.is_hpu() -@lru_cache(maxsize=None) def is_tpu() -> bool: - try: - import libtpu - except ImportError: - libtpu = None - return libtpu is not None + return ops.is_tpu() - -@lru_cache(maxsize=None) def is_xpu() -> bool: - from importlib.metadata import version - is_xpu_flag = "xpu" in version("vllm") - # vllm is not build with xpu - if not is_xpu_flag: - return False - try: - import intel_extension_for_pytorch as ipex # noqa: F401 - _import_ipex = True - except ImportError as e: - logger.warning("Import Error for IPEX: %s", e.msg) - _import_ipex = False - # ipex dependency is not ready - if not _import_ipex: - logger.warning("not found ipex lib") - return False - return hasattr(torch, "xpu") and torch.xpu.is_available() + return ops.is_xpu() @lru_cache(maxsize=None) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a975dba6f5136..a2c7a96757faa 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -453,8 +453,7 @@ def __init__( def load_model(self) -> None: import habana_frameworks.torch.core as htcore - if self.model_config.quantization == 'inc': - htcore.hpu_set_env() + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model( diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index 9d083915041fe..bf285c93cdd47 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -109,8 +109,7 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - if self.model_config.quantization == 'inc': - self._set_env_vars() + self._set_env_vars() init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) From 221eb5600f7523c957ebad318e54d908af6c8332 Mon Sep 17 00:00:00 2001 From: Liran Bachar Date: Sun, 1 Sep 2024 13:57:58 +0300 Subject: [PATCH 4/6] Revert "support loading autofp8 checkpoint" This reverts commit a032ea2781583756f1fca8bdaa6284fa2693b841. 
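
For reference, the change being reverted widened the per-tensor weight scale on Gaudi2 because its FP8-E4M3 representation saturates at +/-240, while float8_e4m3fn on CUDA-class devices follows the OCP +/-448 range. Below is a minimal standalone sketch of that adjustment only; the helper name and the running_on_gaudi2 flag are illustrative and not part of the vLLM API, where the logic sits inside requantize_with_max_scale().

import torch

FP8_MAX_OCP = 448.0     # float8_e4m3fn max on CUDA-class devices
FP8_MAX_GAUDI2 = 240.0  # float8_e4m3fn max on Gaudi2

def adjust_max_scale_for_gaudi2(weight_scale: torch.Tensor,
                                running_on_gaudi2: bool) -> torch.Tensor:
    """Return the per-layer max scale, widened for Gaudi2's narrower range."""
    # Scales computed against the 448 range are stretched by 448/240 so the
    # requantized weights stay representable on Gaudi2 hardware.
    max_w_scale = weight_scale.max()
    if running_on_gaudi2:
        max_w_scale = max_w_scale * (FP8_MAX_OCP / FP8_MAX_GAUDI2)
    return max_w_scale
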
--- .../_cuda_ops.py => _custom_ops.py} | 0 vllm/_custom_ops/__init__.py | 75 ----- vllm/_custom_ops/_hpu_ops.py | 317 ------------------ vllm/{_custom_ops => }/_ipex_ops.py | 0 .../compressed_tensors/compressed_tensors.py | 5 +- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../model_executor/layers/quantization/fp8.py | 19 +- .../layers/quantization/utils/w8a8_utils.py | 41 +-- vllm/model_executor/models/llama.py | 7 - vllm/utils.py | 58 +++- vllm/worker/habana_model_runner.py | 3 +- vllm/worker/habana_worker.py | 3 +- 12 files changed, 72 insertions(+), 458 deletions(-) rename vllm/{_custom_ops/_cuda_ops.py => _custom_ops.py} (100%) delete mode 100644 vllm/_custom_ops/__init__.py delete mode 100644 vllm/_custom_ops/_hpu_ops.py rename vllm/{_custom_ops => }/_ipex_ops.py (100%) diff --git a/vllm/_custom_ops/_cuda_ops.py b/vllm/_custom_ops.py similarity index 100% rename from vllm/_custom_ops/_cuda_ops.py rename to vllm/_custom_ops.py diff --git a/vllm/_custom_ops/__init__.py b/vllm/_custom_ops/__init__.py deleted file mode 100644 index 2411a1465c187..0000000000000 --- a/vllm/_custom_ops/__init__.py +++ /dev/null @@ -1,75 +0,0 @@ - -from functools import lru_cache - -@lru_cache(maxsize=None) -def is_hip() -> bool: - return torch.version.hip is not None - - -@lru_cache(maxsize=None) -def is_cpu() -> bool: - from importlib.metadata import PackageNotFoundError, version - try: - return "cpu" in version("vllm") - except PackageNotFoundError: - return False - - -@lru_cache(maxsize=None) -def is_openvino() -> bool: - from importlib.metadata import PackageNotFoundError, version - try: - return "openvino" in version("vllm") - except PackageNotFoundError: - return False - - -@lru_cache(maxsize=None) -def is_neuron() -> bool: - try: - import transformers_neuronx - except ImportError: - transformers_neuronx = None - return transformers_neuronx is not None - - -@lru_cache(maxsize=None) -def is_hpu() -> bool: - from importlib import util - return util.find_spec('habana_frameworks') is not None - - -@lru_cache(maxsize=None) -def is_tpu() -> bool: - try: - import libtpu - except ImportError: - libtpu = None - return libtpu is not None - - -@lru_cache(maxsize=None) -def is_xpu() -> bool: - from importlib.metadata import version - is_xpu_flag = "xpu" in version("vllm") - # vllm is not build with xpu - if not is_xpu_flag: - return False - try: - import intel_extension_for_pytorch as ipex # noqa: F401 - _import_ipex = True - except ImportError as e: - logger.warning("Import Error for IPEX: %s", e.msg) - _import_ipex = False - # ipex dependency is not ready - if not _import_ipex: - logger.warning("not found ipex lib") - return False - return hasattr(torch, "xpu") and torch.xpu.is_available() - -if is_xpu(): - from ._ipex_ops import * -elif is_hpu(): - from ._hpu_ops import * -else: - from ._cuda_ops import * \ No newline at end of file diff --git a/vllm/_custom_ops/_hpu_ops.py b/vllm/_custom_ops/_hpu_ops.py deleted file mode 100644 index d553540f9e25a..0000000000000 --- a/vllm/_custom_ops/_hpu_ops.py +++ /dev/null @@ -1,317 +0,0 @@ -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-############################################################################### -import os -from typing import Optional, Tuple - -import habana_frameworks.torch as htorch -import torch -import torch.nn.functional as F - -import vllm.hpu.utils as hpu_utils - -PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1') - - -def silu_and_mul(output, input): - d = input.shape[-1] // 2 - silu = torch.nn.SiLU().to(input.device) - x, y = torch.split(input, d, dim=-1) - output.copy_(silu(x) * y) - - -def fetch_from_cache(cache, blocks, permutations): - return [ - cache.index_select(0, blocks[:, i]).permute(permutations) - for i in range(blocks.size(1)) - ] - - -def paged_attention_v1(query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - alibi_slopes=None, - kv_cache_dtype=None, - qk_matmul_op=torch.matmul, - softmax_op=torch.softmax, - av_matmul_op=torch.matmul, - k_cache_cls=None, - v_cache_cls=None) -> None: - seq_len = block_tables.size(1) - batch_size, query_heads, _ = query.shape - _, _, kv_heads, _ = key_cache.shape - min_inf = torch.finfo(query.dtype).min - mask = (torch.arange(0, - seq_len * block_size, - dtype=torch.int32, - device=key_cache.device).view(1, -1).expand( - batch_size, -1).ge(context_lens.view(-1, 1)).view( - batch_size, 1, 1, -1)) - query.mul_(scale) - query = query.unsqueeze(-2) - fetch_keys = fetch_from_cache if k_cache_cls is None else k_cache_cls.fetch_from_cache - keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) - if query_heads != kv_heads: - query = query.unflatten(1, (kv_heads, -1)) - keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] - mask = mask.unsqueeze(2) - - attn_weights = [qk_matmul_op(query, k) for k in keys] - attn_weights = torch.cat(attn_weights, dim=-1) - if alibi_slopes is not None: - attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, - -attn_weights.size(3):]) - attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - - fetch_values = fetch_from_cache if v_cache_cls is None else k_cache_cls.fetch_from_cache - values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) - if PA_SPLIT_VALUE: - attn_weights = attn_weights.split(block_size, dim=-1) - else: - values = [torch.cat(values, dim=-2)] - attn_weights = [attn_weights] - if query_heads != kv_heads: - values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [av_matmul_op(a, v) for a, v in zip(attn_weights, values)] - if query_heads != kv_heads: - attn_weights = [a.flatten(1, 2) for a in attn_weights] - attn_weights = sum(attn_weights) - return attn_weights.squeeze(-2) - - -def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) - out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - silu_and_mul(out, x) - return out - - -def static_fused_moe(hidden_states, w1, w2, score, topk): - B, D = hidden_states.shape - num_experts = w1.shape[0] - routing_weights = F.softmax(score, dim=1, dtype=torch.float32) - routing_weights, selected_experts = torch.topk(routing_weights, - topk, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = torch.zeros((1, B, D), - dtype=hidden_states.dtype, - device=hidden_states.device) - padded_weights = torch.zeros((B, num_experts), - dtype=hidden_states.dtype, - device=hidden_states.device) - padded_weights.scatter_(-1, selected_experts, routing_weights) - padded_weights = 
padded_weights.reshape(-1, B, w1.shape[0]) - padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) - - htorch.core.mark_step() - - for expert_idx in range(num_experts): - padded_weight = padded_weights[expert_idx] - current_state_static = hidden_states.reshape(-1, D) - w_output = silu_and_mul_wrapper( - torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1))) - w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1)) - current_hidden_states_static = w_output * padded_weight - final_hidden_states += current_hidden_states_static - htorch.core.mark_step() - - return final_hidden_states.view(-1, D) - - -def prompt_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, - p: float = 0.0, - scale: Optional[float] = None, - qk_matmul_op = torch.matmul, - softmax_op = torch.softmax, - av_matmul_op = torch.matmul, -) -> torch.Tensor: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - query_heads = query.size(1) - kv_heads = key.size(1) - if query_heads != kv_heads: - query = query.unflatten(1, (kv_heads, -1)) - key = key.unflatten(1, (kv_heads, 1)) - value = value.unflatten(1, (kv_heads, 1)) - attn_bias = attn_bias.unsqueeze(2) - attn_weights = qk_matmul_op(query * scale, key.transpose(-1, -2)) - if attn_bias is not None: - attn_weights.add_(attn_bias) - attn_weights = softmax_op(attn_weights, dim=-1) - attn_weights = av_matmul_op(attn_weights, value) - if query_heads != kv_heads: - attn_weights = attn_weights.flatten(1, 2) - attn_weights = attn_weights.transpose(1, 2) - return attn_weights - - - - -def reshape_and_cache(key, - value, - key_cache, - value_cache, - slot_mapping, - dtype, - is_prompt=False): - num_blocks = key_cache.size(0) - block_size = key_cache.size(1) - slot_mapping = slot_mapping.flatten() - indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - offsets = torch.fmod(slot_mapping, block_size) - num_slots_requested = slot_mapping.size(0) - num_slots_available = num_blocks * block_size - # NOTE(kzawora): HPU PT bridge crashes with - # RuntimeError: Invalid inputs for scatter_nd_onnx - # on index_put when num_slots_requested > num_slots_available. - # This case might occur when we have little kv cache blocks and - # lots of padding, or are doing warmup. - # This loop is a workaround for this issue. Please remove it - # once key_cache.index_put_(indices, offsets), key) works. - num_kv_cache_passes = torch.div(num_slots_requested, - num_slots_available).ceil().int().item() - for i in range(num_kv_cache_passes): - start_idx = i * num_slots_available - end_idx = (i + 1) * num_slots_available - key_cache.index_put_( - (indices[start_idx:end_idx], offsets[start_idx:end_idx]), - key[start_idx:end_idx]) - value_cache.index_put_( - (indices[start_idx:end_idx], offsets[start_idx:end_idx]), - value[start_idx:end_idx]) - - -def prepare_to_cache(cache, slot_mapping): - num_blocks = cache.size(0) - block_size = cache.size(1) - slot_mapping = slot_mapping.flatten() - indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - offsets = torch.fmod(slot_mapping, block_size) - num_slots_requested = slot_mapping.size(0) - num_slots_available = num_blocks * block_size - # NOTE(kzawora): HPU PT bridge crashes with - # RuntimeError: Invalid inputs for scatter_nd_onnx - # on index_put when num_slots_requested > num_slots_available. - # This case might occur when we have little kv cache blocks and - # lots of padding, or are doing warmup. 
- # This loop is a workaround for this issue. Please remove it - # once key_cache.index_put_(indices, offsets), key) works. - num_kv_cache_passes = torch.div(num_slots_requested, - num_slots_available).ceil().int().item() - - return num_kv_cache_passes, num_slots_available, indices, offsets - - -def insert_or_update_cache(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offsets): - for i in range(num_kv_cache_passes): - start_idx = i * num_slots_available - end_idx = (i + 1) * num_slots_available - cache.index_put_( - (block_indices[start_idx:end_idx], block_offsets[start_idx:end_idx]), - input[start_idx:end_idx]) - - -def swap_blocks(src, dst, block_mapping): - index_src = torch.zeros((1, ), dtype=torch.int32, device=src.device) - index_dst = torch.zeros((1, ), dtype=torch.int32, device=dst.device) - for src_idx, dst_idx in block_mapping.items(): - index_src[0] = src_idx - index_dst[0] = dst_idx - dst.index_put_([index_dst], src.index_select(0, index_src)) - if dst.device.type == 'hpu': - htorch.core.mark_step() - torch.hpu.synchronize() - - -def copy_blocks(key_caches, value_caches, block_mapping): - index_src = torch.zeros((1, ), - dtype=torch.int32, - device=key_caches[0].device) - index_dst = torch.zeros((1, ), - dtype=torch.int32, - device=key_caches[0].device) - for src, dsts in block_mapping.items(): - index_src[0] = src - for dst in dsts: - index_dst[0] = dst - for key_cache in key_caches: - key_cache.index_copy_(0, index_dst, - key_cache.index_select(0, index_src)) - for value_cache in value_caches: - value_cache.index_copy_(0, index_dst, - value_cache.index_select(0, index_src)) - if key_caches[0].device.type == 'hpu': - htorch.core.mark_step() - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - batch_dim_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensor for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - batch_dim_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. 
- """ - if batch_dim_padding: - shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) - output = torch.empty(shape, - device=input.device, - dtype=torch.float8_e4m3fn) - else: - output = torch.empty_like(input, dtype=torch.float8_e4m3fn) - if scale is None: - raise "dynamic scaled_fp8_quant not implemented for HPU" - #TODO: calculate scale to match gaudi2 240 range instead of 448 - if use_per_token_if_dynamic: - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) - else: - output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, dtype=torch.float8_e4m3fn)[0] - - return output, scale diff --git a/vllm/_custom_ops/_ipex_ops.py b/vllm/_ipex_ops.py similarity index 100% rename from vllm/_custom_ops/_ipex_ops.py rename to vllm/_ipex_ops.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index badb29af1f5f6..39d00bd5733ff 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -233,7 +233,7 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, @@ -306,8 +306,7 @@ def get_scheme( # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) - if torch.cuda.is_available(): - self._check_scheme_supported(scheme.get_min_capability()) + self._check_scheme_supported(scheme.get_min_capability()) return scheme diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 631774994b5c0..cc9d71db140c2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -21,7 +21,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False + self.cutlass_fp8_supported = cutlass_fp8_supported() @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8e2ed041adf0b..c829cb836ee4c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -112,18 +112,13 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config - - if torch.cuda.is_available(): - self.cutlass_fp8_supported = cutlass_fp8_supported() - - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 - else: - self.cutlass_fp8_supported = False - self.use_marlin = False + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index de5cd810b2a94..20100c76bd690 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,8 +6,6 @@ from vllm import _custom_ops as ops from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -if current_platform.is_hpu(): - import habana_frameworks.torch.utils.experimental as htexp def cutlass_fp8_supported() -> bool: @@ -20,17 +18,8 @@ def cutlass_fp8_supported() -> bool: def per_tensor_dequantize( tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]) -> torch.Tensor: - dtype = torch.float16 - device = tensor.device - if current_platform.is_hpu(): - dtype = torch.bfloat16 - if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: - #dequant on cpu to avoid nan on gaudi2 - tensor = tensor.to('cpu') - - fake_qweight = tensor.to(dtype).to(device) + fake_qweight = tensor.to(torch.float16) dq_weight = fake_qweight * inv_scale - return dq_weight @@ -87,9 +76,6 @@ def requantize_with_max_scale( logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. 
max_w_scale = weight_scale.max() - if current_platform.is_hpu() and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: - max_w_scale = max_w_scale * (448/240) - # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize @@ -161,25 +147,12 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - if current_platform.is_hpu(): - #hpu does not support torch._scaled_mm (SW-197036) - output = torch.ops.hpu.fp8_gemm_v2(qinput, - False, - weight, - False, - None, - input.dtype, - x_scale, - weight_scale, - None, - False) - else: - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) return torch.narrow(output, 0, 0, input.shape[0]) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f02609aa9ff3b..676a51ce67f96 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,9 +54,6 @@ from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -from vllm.platforms import current_platform -if current_platform.is_hpu(): - import habana_frameworks.torch.core as htcore class LlamaMLP(nn.Module): @@ -521,10 +518,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) - #Avoid OOM due to large graph when loading weights - if current_platform.is_hpu(): - htcore.mark_step() - # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). 
Thus, handled exceptions should # make sure to leave KV cache scale factors in a known good (dummy) state diff --git a/vllm/utils.py b/vllm/utils.py index 661d5d62e069b..fa6e132dd3522 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -176,25 +176,69 @@ def clear(self): def is_hip() -> bool: - return ops.is_hip() + return torch.version.hip is not None + +@lru_cache(maxsize=None) def is_cpu() -> bool: - return ops.is_cpu() + from importlib.metadata import PackageNotFoundError, version + try: + return "cpu" in version("vllm") + except PackageNotFoundError: + return False + +@lru_cache(maxsize=None) def is_openvino() -> bool: - return ops.is_openvino() + from importlib.metadata import PackageNotFoundError, version + try: + return "openvino" in version("vllm") + except PackageNotFoundError: + return False + +@lru_cache(maxsize=None) def is_neuron() -> bool: - return ops.is_neuron() + try: + import transformers_neuronx + except ImportError: + transformers_neuronx = None + return transformers_neuronx is not None + +@lru_cache(maxsize=None) def is_hpu() -> bool: - return ops.is_hpu() + from importlib import util + return util.find_spec('habana_frameworks') is not None + +@lru_cache(maxsize=None) def is_tpu() -> bool: - return ops.is_tpu() + try: + import libtpu + except ImportError: + libtpu = None + return libtpu is not None + +@lru_cache(maxsize=None) def is_xpu() -> bool: - return ops.is_xpu() + from importlib.metadata import version + is_xpu_flag = "xpu" in version("vllm") + # vllm is not build with xpu + if not is_xpu_flag: + return False + try: + import intel_extension_for_pytorch as ipex # noqa: F401 + _import_ipex = True + except ImportError as e: + logger.warning("Import Error for IPEX: %s", e.msg) + _import_ipex = False + # ipex dependency is not ready + if not _import_ipex: + logger.warning("not found ipex lib") + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() @lru_cache(maxsize=None) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a2c7a96757faa..a975dba6f5136 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -453,7 +453,8 @@ def __init__( def load_model(self) -> None: import habana_frameworks.torch.core as htcore - htcore.hpu_set_env() + if self.model_config.quantization == 'inc': + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model( diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index bf285c93cdd47..9d083915041fe 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -109,7 +109,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. 
- self._set_env_vars() + if self.model_config.quantization == 'inc': + self._set_env_vars() init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) From c899aef31c064523daa5c38746d203dc148518cc Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz Date: Mon, 2 Sep 2024 12:54:54 +0300 Subject: [PATCH 5/6] warmup_mode kward restore --- vllm/worker/habana_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 0100076aec8e2..241980f32f097 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1615,7 +1615,10 @@ def execute_model( if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) if htorch.utils.internal.is_lazy(): - execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs}) + execute_model_kwargs.update({ + "bypass_hpu_graphs": not use_graphs, + "warmup_mode": warmup_mode + }) htorch.core.mark_step() if self.is_driver_worker: From 4eedfb91c8ef33a601b9e203a7ad8048d854222f Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz Date: Mon, 2 Sep 2024 14:24:59 +0300 Subject: [PATCH 6/6] change format --- vllm/worker/habana_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 241980f32f097..dec1b65858eb4 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1567,8 +1567,8 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): self.seen_configs.add(cfg) if not seen and not warmup_mode: phase = 'prompt' if is_prompt else 'decode' - logger.warning('Configuration: (', phase, ', ', batch_size, ', ', - seq_len, ') was not warmed-up!') + logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", + phase, batch_size, seq_len) @torch.inference_mode() def execute_model(
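
Taken together, patches 1, 2, 5 and 6 leave the warmup-configuration check in the shape sketched below. This is a condensed, self-contained illustration of that behavior, not the actual vLLM code path: the tracker class and the module-level logger are stand-ins for HabanaModelRunner._check_config() and vLLM's own logger.

import logging

logger = logging.getLogger(__name__)

class WarmupConfigTracker:
    """Tracks (batch_size, seq_len, is_prompt) buckets seen at runtime."""

    def __init__(self) -> None:
        self.seen_configs: set = set()

    def check(self, batch_size: int, seq_len: int, is_prompt: bool,
              warmup_mode: bool) -> None:
        cfg = (batch_size, seq_len, is_prompt)
        seen = cfg in self.seen_configs
        self.seen_configs.add(cfg)
        if not seen and not warmup_mode:
            # Lazy %-style formatting, as required by the vLLM logger.
            phase = 'prompt' if is_prompt else 'decode'
            logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
                           phase, batch_size, seq_len)

Executing a bucket that was never exercised during warmup emits a message such as "Configuration: (decode, 4, 128) was not warmed-up!" once per configuration; subsequent executions of the same bucket stay silent because it is already recorded in seen_configs.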