From 29fb5edd1df36aa4fa0ff95c7b2cbb711b8cb035 Mon Sep 17 00:00:00 2001
From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com>
Date: Wed, 25 Sep 2024 19:19:40 +0300
Subject: [PATCH] Support loading checkpoints quantized using AutoFP8 (#286)

Support loading FP8 checkpoints from
https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127

Skip CUDA capability checks on HPU.
Use scaled_fp8_quant instead of torch._scaled_mm on HPU.
Fix weights and weight_scale for the Gaudi2 float8_e4m3fn range.

---------

Co-authored-by: Nir David
Co-authored-by: Konrad Zawora
---
 requirements-hpu.txt                          |  3 +-
 .../layers/fused_moe/fused_moe.py             |  4 ++
 .../compressed_tensors/compressed_tensors.py  |  9 +++--
 .../schemes/compressed_tensors_w8a8_fp8.py    |  4 +-
 .../model_executor/layers/quantization/fp8.py | 24 +++++++----
 .../layers/quantization/utils/w8a8_utils.py   | 40 +++++++++++++++----
 vllm/worker/habana_model_runner.py            |  3 +-
 7 files changed, 64 insertions(+), 23 deletions(-)

diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index c7376a7c504f..1af5460128fb 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -6,4 +6,5 @@ ray == 2.32.0
 triton
 pandas
 tabulate
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab
+
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 3e01112eaa14..cf17f1e240e4 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -13,6 +13,10 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+if current_platform.is_hpu():
+    from vllm_hpu_extension.ops import scaled_fp8_quant
+    ops.scaled_fp8_quant = scaled_fp8_quant
+
 logger = init_logger(__name__)
 
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index e536fae45c84..252ad864ced3 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -243,8 +243,10 @@ def _get_scheme_from_parts(
         # TODO @dsikka: clean-up conditions
         if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
-                is_fp8_w8a8_supported = self._check_scheme_supported(
-                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
+                is_fp8_w8a8_supported = current_platform.is_hpu() or \
+                    self._check_scheme_supported(
+                        CompressedTensorsW8A8Fp8.get_min_capability(),
+                        error=False)
                 if is_fp8_w8a8_supported:
                     return CompressedTensorsW8A8Fp8(
                         strategy=weight_quant.strategy,
@@ -314,7 +316,8 @@ def get_scheme(
 
         # Raise error if device does not support the scheme
         # (e.g. fp8 needs ada lovelace)
-        self._check_scheme_supported(scheme.get_min_capability())
+        if not current_platform.is_hpu():
+            self._check_scheme_supported(scheme.get_min_capability())
 
         return scheme
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 5931ec36c97d..29f3228c0dc5 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -13,6 +13,7 @@
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
+from vllm.platforms import current_platform
 from vllm.utils import is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
@@ -23,7 +24,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.cutlass_fp8_supported = not current_platform.is_hpu() and \
+            cutlass_fp8_supported()
 
     @classmethod
     def get_min_capability(cls) -> int:
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b5feb55db0e7..88915942220c 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -28,6 +28,10 @@
 from vllm.platforms import current_platform
 from vllm.utils import is_hip, print_warning_once
 
+if current_platform.is_hpu():
+    from vllm_hpu_extension.ops import scaled_fp8_quant
+    ops.scaled_fp8_quant = scaled_fp8_quant
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
@@ -116,14 +120,18 @@ class Fp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
-
-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
-        # Disable marlin for rocm
-        if is_hip():
+        if current_platform.is_cuda_alike():
+            self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+            # For GPUs that lack FP8 hardware support, we can leverage the
+            # Marlin kernel for fast weight-only FP8 quantization
+            self.use_marlin = (not current_platform.has_device_capability(89)
+                               or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+            # Disable marlin for rocm
+            if is_hip():
+                self.use_marlin = False
+        else:
+            self.cutlass_fp8_supported = False
             self.use_marlin = False
 
     def create_weights(
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index fb263d121fe5..048962721e26 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -10,6 +10,11 @@
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
 
+if current_platform.is_hpu():
+    import habana_frameworks.torch.utils.experimental as htexp
+    from vllm_hpu_extension.ops import scaled_fp8_quant
+    ops.scaled_fp8_quant = scaled_fp8_quant
+
 
 def cutlass_fp8_supported() -> bool:
     # cutlass is not supported on Rocm
@@ -25,7 +30,15 @@ def cutlass_fp8_supported() -> bool:
 def per_tensor_dequantize(
         tensor: torch.Tensor,
         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
-    fake_qweight = tensor.to(torch.float16)
+    dtype = torch.float16
+    device = tensor.device
+    if current_platform.is_hpu():
+        dtype = torch.bfloat16
+        if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+            # dequantize on CPU to avoid NaN on Gaudi2
+            tensor = tensor.to('cpu')
+
+    fake_qweight = tensor.to(dtype).to(device)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
 
@@ -58,7 +71,10 @@ def requantize_with_max_scale(
         logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
     max_w_scale = weight_scale.max()
-
+    if current_platform.is_hpu() and htexp._get_device_type(
+    ) == htexp.synDeviceType.synDeviceGaudi2:
+        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
+                                     torch.finfo(torch.float8_e4m3fnuz).max)
     # QKV / MLP is fused in the on disk checkpoint if any of the
     # weight scales are still set to the default since we initialize
     # N weight scales for N shards but we only load 1 weight scale
@@ -129,12 +145,20 @@ def apply_fp8_linear(
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output = torch._scaled_mm(qinput,
-                                  weight,
-                                  out_dtype=input.dtype,
-                                  scale_a=x_scale,
-                                  scale_b=weight_scale,
-                                  bias=bias)
+        if current_platform.is_hpu():
+            # HPU does not support torch._scaled_mm (SW-197036)
+            output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
+                                               False, None, input.dtype,
+                                               x_scale, weight_scale, None,
+                                               False)
+        else:
+            output = torch._scaled_mm(qinput,
+                                      weight,
+                                      out_dtype=input.dtype,
+                                      scale_a=x_scale,
+                                      scale_b=weight_scale,
+                                      bias=bias)
+
         # A fix for discrepancy in scaled_mm which returns tuple
         # for torch < 2.5 and a single value in torch >= 2.5
         if type(output) is tuple and len(output) == 2:
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index e80df4e7c8c1..c43acdf04923 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -587,8 +587,7 @@ def _set_gc_threshold(self) -> None:
 
     def load_model(self) -> None:
        import habana_frameworks.torch.core as htcore
-        if self.model_config.quantization == 'inc':
-            htcore.hpu_set_env()
+        htcore.hpu_set_env()
         with HabanaMemoryProfiler() as m:
             with HabanaMemoryProfiler() as m_getmodel:
                 self.model = get_model(model_config=self.model_config,
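
Note (editorial, not part of the patch): the Gaudi2-specific factor applied in the requantize_with_max_scale() hunk above widens the checkpoint's weight scale because AutoFP8 checkpoints target the OCP float8_e4m3fn range (max 448.0), while the patch treats Gaudi2's FP8 range as that of float8_e4m3fnuz (max 240.0). A minimal standalone sketch of that adjustment; the helper name is hypothetical:

    import torch

    # Range maxima of the two FP8 formats: 448.0 for float8_e4m3fn,
    # 240.0 for float8_e4m3fnuz (the range the patch assumes for Gaudi2).
    FN_MAX = torch.finfo(torch.float8_e4m3fn).max
    FNUZ_MAX = torch.finfo(torch.float8_e4m3fnuz).max

    def widen_scale_for_gaudi2(weight_scale: torch.Tensor) -> torch.Tensor:
        # Multiplying the scale by 448/240 keeps the requantized weights
        # inside the narrower range, mirroring the hunk above.
        return weight_scale * (FN_MAX / FNUZ_MAX)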
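
For context, once this patch is applied, loading a checkpoint from the linked collection on HPU is expected to go through the regular vLLM entry point, since the fp8 quantization method is read from the checkpoint config. Illustrative usage only; the model name below is just one example from that collection:

    from vllm import LLM, SamplingParams

    # Example AutoFP8 checkpoint from the Neural Magic collection linked above.
    llm = LLM(model="neuralmagic/Meta-Llama-3-8B-Instruct-FP8")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)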