diff --git a/README_GAUDI.md b/README_GAUDI.md
index 0ef30d5f96e64..04e2ff22f96e5 100644
--- a/README_GAUDI.md
+++ b/README_GAUDI.md
@@ -321,7 +321,7 @@ for graph capture (later referred to as "usable graph memory"), and the
 remaining 90% will be utilized for KV cache. Environment variable
 `VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory
 reserved for prefill and decode graphs. By default
-(`VLLM_GRAPH_PROMPT_RATIO=0.5`), both stages have equal memory
+(`VLLM_GRAPH_PROMPT_RATIO=0.3`), the prefill stage has tighter memory
 constraints. Lower value corresponds to less usable graph memory
 reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will
 reserve 20% of usable graph memory for prefill graphs, and 80% of usable
@@ -388,7 +388,7 @@ INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 G
     INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB
     ...
     INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB
-    INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5)
+    INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 4.755 GiB for prompt and 11.095 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3)
     INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
     ...
     INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB
@@ -448,7 +448,7 @@ Environment variables
 -   `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for
     HPUGraph capture, `0.1` by default
 -   `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory
-    dedicated for prompt graphs, `0.5` by default
+    dedicated for prompt graphs, `0.3` by default
 -   `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt
     graph capture, `min_tokens` or `max_bs`, `min_tokens` by default
 -   `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode
@@ -472,15 +472,15 @@ Environment variables
         `max_model_len`

 -   Decode:

-    -   batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)`
+    -   batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1`
     -   batch size step (`VLLM_DECODE_BS_BUCKET_STEP`):
         `min(max_num_seqs, 32)`
     -   batch size max (`VLLM_DECODE_BS_BUCKET_MAX`): `max_num_seqs`
     -   block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`):
-        `128`
+        `block_size`
     -   block size step
-        (`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128`
+        (`VLLM_DECODE_BLOCK_BUCKET_STEP`): `block_size`
     -   block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`):
         `max(128, (max_num_seqs*max_model_len)/block_size)`
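As a quick sanity check of the updated warmup log above: with the new `VLLM_GRAPH_PROMPT_RATIO=0.3` default, the usable graph memory reported at warmup is split 30/70 between prompt and decode graphs. The snippet below only reproduces that arithmetic from the logged numbers; the authoritative accounting lives in `habana_model_runner.py`.

```python
# Illustrative arithmetic only -- values taken from the example warmup log.
usable_graph_mem_gib = 15.85   # "usable graph memory" reported at warmup
prompt_ratio = 0.3             # new VLLM_GRAPH_PROMPT_RATIO default

prompt_graph_mem = prompt_ratio * usable_graph_mem_gib        # ~4.755 GiB
decode_graph_mem = usable_graph_mem_gib - prompt_graph_mem    # ~11.095 GiB
print(f"{prompt_graph_mem:.3f} GiB for prompt, {decode_graph_mem:.3f} GiB for decode")
```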
diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst
index 8c4905e2a488a..db1d8666e4800 100644
--- a/docs/source/getting_started/gaudi-installation.rst
+++ b/docs/source/getting_started/gaudi-installation.rst
@@ -245,7 +245,7 @@ Only after that, ``gpu_memory_utilization`` flag is utilized - at its default va
 Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured.
 Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture.
 With its default value (``VLLM_GRAPH_RESERVED_MEM=0.1``), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache.
-Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints.
+Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.3``), 30% of usable graph memory is reserved for prefill graphs and 70% for decode graphs.
 Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs.

 .. note::
@@ -280,7 +280,7 @@ Each described step is logged by vLLM server, as follows (negative values corres
     INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB
     ...
     INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB
-    INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5)
+    INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 4.755 GiB for prompt and 11.095 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3)
     INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
     ...
     INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB
@@ -324,7 +324,7 @@ Environment variables

 - ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default
 - ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.1`` by default
-- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default
+- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.3`` by default
 - ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default
 - ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default
 - ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism
@@ -343,11 +343,11 @@ Environment variables
     - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len``

   - Decode:

-    - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)``
+    - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1``
     - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
     - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs``
-    - sequence length min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``128``
-    - sequence length step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``128``
+    - sequence length min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``block_size``
+    - sequence length step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``block_size``
     - sequence length max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``
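To make the new decode-bucket defaults concrete, here is how they resolve for a hypothetical configuration (`max_num_seqs=256`, `max_model_len=2048`, `block_size=128`). The values are derived purely from the formulas listed above, not from running vLLM.

```python
# Hypothetical configuration -- chosen only to illustrate the documented defaults.
max_num_seqs = 256
max_model_len = 2048
block_size = 128

decode_bs_min = 1                        # VLLM_DECODE_BS_BUCKET_MIN (new default)
decode_bs_step = min(max_num_seqs, 32)   # VLLM_DECODE_BS_BUCKET_STEP -> 32
decode_bs_max = max_num_seqs             # VLLM_DECODE_BS_BUCKET_MAX  -> 256

decode_block_min = block_size            # VLLM_DECODE_BLOCK_BUCKET_MIN (new default) -> 128
decode_block_step = block_size           # VLLM_DECODE_BLOCK_BUCKET_STEP (new default) -> 128
decode_block_max = max(128, (max_num_seqs * max_model_len) // block_size)  # -> 4096
```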
diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py
index 44226fc898218..e4bd54f8849b3 100644
--- a/vllm/executor/habana_executor.py
+++ b/vllm/executor/habana_executor.py
@@ -195,9 +195,6 @@ def check_health(self) -> None:
     def shutdown(self) -> None:
         self.driver_worker.shutdown_inc()

-    def __del__(self):
-        self.shutdown()
-

 class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase):

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 3e01112eaa14d..cf17f1e240e47 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -13,6 +13,10 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform

+if current_platform.is_hpu():
+    from vllm_hpu_extension.ops import scaled_fp8_quant
+    ops.scaled_fp8_quant = scaled_fp8_quant
+
 logger = init_logger(__name__)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index e536fae45c845..252ad864ced3e 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -243,8 +243,10 @@ def _get_scheme_from_parts(
         # TODO @dsikka: clean-up conditions
         if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
-                is_fp8_w8a8_supported = self._check_scheme_supported(
-                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
+                is_fp8_w8a8_supported = current_platform.is_hpu() or \
+                    self._check_scheme_supported(
+                        CompressedTensorsW8A8Fp8.get_min_capability(),
+                        error=False)
                 if is_fp8_w8a8_supported:
                     return CompressedTensorsW8A8Fp8(
                         strategy=weight_quant.strategy,
@@ -314,7 +316,8 @@ def get_scheme(

         # Raise error if device does not support the scheme
         # (e.g. fp8 needs ada lovelace)
-        self._check_scheme_supported(scheme.get_min_capability())
+        if not current_platform.is_hpu():
+            self._check_scheme_supported(scheme.get_min_capability())

         return scheme

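The HPU build swaps `ops.scaled_fp8_quant` for the implementation shipped in `vllm_hpu_extension` (see the import block added to `fused_moe.py` above and the matching ones in `fp8.py` and `w8a8_utils.py` below). That extension kernel is not shown here; as orientation for what a per-tensor `scaled_fp8_quant`-style op computes, a plain-PyTorch sketch of the dynamic/static scale handling looks roughly like this (reference semantics only, not the vLLM or HPU kernel):

```python
from typing import Optional, Tuple

import torch


def per_tensor_fp8_quant_reference(
        x: torch.Tensor,
        scale: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference-only sketch of per-tensor FP8 (e4m3) quantization semantics."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    if scale is None:
        # Dynamic quantization: derive the scale from the tensor's absolute max.
        scale = x.abs().max().clamp(min=1e-12).to(torch.float32) / finfo.max
    # Quantize: scale down, clamp to the representable range, cast to FP8.
    q = (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale
```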
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 5931ec36c97d5..29f3228c0dc5d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -13,6 +13,7 @@
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
+from vllm.platforms import current_platform
 from vllm.utils import is_hip

 __all__ = ["CompressedTensorsW8A8Fp8"]
@@ -23,7 +24,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.cutlass_fp8_supported = not current_platform.is_hpu() and \
+            cutlass_fp8_supported()

     @classmethod
     def get_min_capability(cls) -> int:
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b5feb55db0e74..88915942220ca 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -28,6 +28,10 @@
 from vllm.platforms import current_platform
 from vllm.utils import is_hip, print_warning_once

+if current_platform.is_hpu():
+    from vllm_hpu_extension.ops import scaled_fp8_quant
+    ops.scaled_fp8_quant = scaled_fp8_quant
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = init_logger(__name__)
@@ -116,14 +120,18 @@ class Fp8LinearMethod(LinearMethodBase):

     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
-
-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
-        # Disable marlin for rocm
-        if is_hip():
+        if current_platform.is_cuda_alike():
+            self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+            # For GPUs that lack FP8 hardware support, we can leverage the
+            # Marlin kernel for fast weight-only FP8 quantization
+            self.use_marlin = (not current_platform.has_device_capability(89)
+                               or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+            # Disable marlin for rocm
+            if is_hip():
+                self.use_marlin = False
+        else:
+            self.cutlass_fp8_supported = False
             self.use_marlin = False

     def create_weights(
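Taken together with the `compressed_tensors` changes above and the `w8a8_utils.py` changes below, the patch effectively selects the FP8 linear path per platform: HPU bypasses the CUDA capability checks entirely, while Marlin and cutlass remain CUDA-only options. The helper below is a simplified illustration of that decision flow, not actual vLLM code:

```python
def select_fp8_linear_path(is_hpu: bool, is_rocm: bool, cutlass_supported: bool,
                           device_capability: int, force_marlin: bool = False) -> str:
    """Illustrative decision flow only; names and ordering are an approximation."""
    if is_hpu:
        # torch._scaled_mm is unavailable on HPU (SW-197036); the patch routes
        # FP8 GEMMs to torch.ops.hpu.fp8_gemm_v2 instead (see w8a8_utils.py below).
        return "hpu_fp8_gemm"
    if not is_rocm and (device_capability < 89 or force_marlin):
        # Pre-Ada GPUs lack native FP8 GEMM support; fall back to weight-only Marlin.
        return "marlin_weight_only"
    if cutlass_supported:
        return "cutlass_scaled_mm"
    return "torch_scaled_mm"
```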
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index fb263d121fe55..048962721e26b 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -10,6 +10,11 @@
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None

+if current_platform.is_hpu():
+    import habana_frameworks.torch.utils.experimental as htexp
+    from vllm_hpu_extension.ops import scaled_fp8_quant
+    ops.scaled_fp8_quant = scaled_fp8_quant
+

 def cutlass_fp8_supported() -> bool:
     # cutlass is not supported on Rocm
@@ -25,7 +30,15 @@ def cutlass_fp8_supported() -> bool:
 def per_tensor_dequantize(
         tensor: torch.Tensor,
         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
-    fake_qweight = tensor.to(torch.float16)
+    dtype = torch.float16
+    device = tensor.device
+    if current_platform.is_hpu():
+        dtype = torch.bfloat16
+        if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+            #dequant on cpu to avoid nan on gaudi2
+            tensor = tensor.to('cpu')
+
+    fake_qweight = tensor.to(dtype).to(device)
     dq_weight = fake_qweight * inv_scale
     return dq_weight

@@ -58,7 +71,10 @@ def requantize_with_max_scale(
         logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
     max_w_scale = weight_scale.max()
-
+    if current_platform.is_hpu() and htexp._get_device_type(
+    ) == htexp.synDeviceType.synDeviceGaudi2:
+        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
+                                     torch.finfo(torch.float8_e4m3fnuz).max)
     # QKV / MLP is fused in the on disk checkpoint if any of the
     # weight scales are still set to the default since we initialize
     # N weight scales for N shards but we only load 1 weight scale
@@ -129,12 +145,20 @@ def apply_fp8_linear(

     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output = torch._scaled_mm(qinput,
-                                  weight,
-                                  out_dtype=input.dtype,
-                                  scale_a=x_scale,
-                                  scale_b=weight_scale,
-                                  bias=bias)
+        if current_platform.is_hpu():
+            #hpu does not support torch._scaled_mm (SW-197036)
+            output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
+                                               False, None, input.dtype,
+                                               x_scale, weight_scale, None,
+                                               False)
+        else:
+            output = torch._scaled_mm(qinput,
+                                      weight,
+                                      out_dtype=input.dtype,
+                                      scale_a=x_scale,
+                                      scale_b=weight_scale,
+                                      bias=bias)
+
         # A fix for discrepancy in scaled_mm which returns tuple
         # for torch < 2.5 and a single value in torch >= 2.5
         if type(output) is tuple and len(output) == 2:
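The Gaudi2 branch in `requantize_with_max_scale` rescales checkpoint weight scales by the ratio of the two FP8 formats' representable maxima (448 for `float8_e4m3fn`, 240 for `float8_e4m3fnuz`), i.e. by roughly 1.87. The snippet below just evaluates that same expression with a hypothetical scale value:

```python
import torch

e4m3fn_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0
e4m3fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max  # 240.0

max_w_scale = torch.tensor(0.01)                      # hypothetical checkpoint scale
adjusted = max_w_scale * (e4m3fn_max / e4m3fnuz_max)  # 0.01 * ~1.8667 ~= 0.0187
print(f"factor: {e4m3fn_max / e4m3fnuz_max:.4f}, adjusted scale: {adjusted.item():.5f}")
```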
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index c038dfe42bf5f..d3d2973688843 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -561,6 +561,7 @@ def __init__(
         # Lazy initialization
         self.lora_manager: LRUCacheWorkerLoRAManager = None
         self.model: torch.nn.Module = None
+        self.inc_initialized_successfully = False

         # Profiler stats
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
@@ -597,8 +598,7 @@ def _set_gc_threshold(self) -> None:

     def load_model(self) -> None:
         import habana_frameworks.torch.core as htcore
-        if self.model_config.quantization == 'inc':
-            htcore.hpu_set_env()
+        htcore.hpu_set_env()
         with HabanaMemoryProfiler() as m:
             with HabanaMemoryProfiler() as m_getmodel:
                 self.model = get_model(model_config=self.model_config,
@@ -643,6 +643,7 @@ def load_model(self) -> None:
                     self.model = convert(self.model, config)
                     htcore.hpu_initialize(self.model,
                                           mark_only_scales_as_const=True)
+                    self.inc_initialized_successfully = True
                     logger.info("Preparing model with INC took %s",
                                 m_inc.get_summary_string())
                 elif not is_fake_hpu():
@@ -679,7 +680,6 @@ def _setup_buckets(self) -> None:
         if self.lora_config and \
                 max_bucket_cfg > self.max_num_batched_tokens // self.block_size:
             max_bucket_cfg = self.max_num_batched_tokens // self.block_size
-        blocks_step = 128
         #FIXME: The default values should be max_model_len
         max_prompt_seq = 1024
         max_decode_seq = 2048
@@ -691,7 +691,7 @@ def _setup_buckets(self) -> None:
                                                          max=align_bs(max_bucket_cfg))
         self.decode_bs_bucket_cfg = read_bucket_settings('decode',
                                                          'bs',
-                                                         min=align_bs(32),
+                                                         min=1,
                                                          step=align_bs(32),
                                                          max=self.max_num_seqs)
         self.prompt_seq_bucket_cfg = read_bucket_settings('prompt',
@@ -702,9 +702,9 @@ def _setup_buckets(self) -> None:
         self.decode_block_bucket_cfg = read_bucket_settings(
             'decode',
             'block',
-            min=blocks_step,
-            step=blocks_step,
-            max=max(blocks_step,
+            min=self.block_size,
+            step=self.block_size,
+            max=max(self.block_size,
                     self.max_num_seqs * max_decode_seq // self.block_size))
         self.graphed_buckets: Set[Any] = set()

@@ -1571,6 +1571,17 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                     len(self.decode_buckets),
                     list(sorted(self.decode_buckets)))

+        if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
+            cache_size_limit = len(self.prompt_buckets) + len(
+                self.decode_buckets) + 1
+            torch._dynamo.config.cache_size_limit = max(
+                cache_size_limit, torch._dynamo.config.cache_size_limit)
+            # Multiply by 8 to follow the original default ratio between
+            # the cache_size_limit and accumulated_cache_size_limit
+            torch._dynamo.config.accumulated_cache_size_limit = max(
+                cache_size_limit * 8,
+                torch._dynamo.config.accumulated_cache_size_limit)
+
         start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()

@@ -1601,7 +1612,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         graph_free_mem = align_workers(graph_free_mem,
                                        torch.distributed.ReduceOp.MIN)
         prompt_graph_mem_ratio = float(
-            os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
+            os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3'))
         prompt_available_memory = (prompt_graph_mem_ratio *
                                    graph_free_mem)
         decode_available_memory = (graph_free_mem -
@@ -1799,6 +1810,7 @@ def make_model_input_from_broadcasted_tensor_dict(
                 attn_backend=self.attn_backend,
             ))

+    @torch.inference_mode()
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -1958,14 +1970,18 @@ def execute_model(
         return [output]

     def shutdown_inc(self):
-        print('inc shutdown')
-        if (model_config := getattr(self, "model_config", None)) and \
-            getattr(model_config, "quantization", None) == 'inc':
-            print('inc shutdown start')
+        can_finalize_inc = False
+        from contextlib import suppress
+        with suppress(AttributeError):
+            can_finalize_inc = (self.model_config.quantization == 'inc') and \
+                (self.model.model is not None) and \
+                self.inc_initialized_successfully and \
+                not getattr(self, "_is_inc_finalized", False)
+        if can_finalize_inc:
             from neural_compressor.torch.quantization import (
                 finalize_calibration)
             finalize_calibration(self.model.model)
-            print('inc shutdown')
+            self._is_inc_finalized = True

     def __del__(self):
         self.shutdown_inc()
diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py
index 8cdbba02fbb33..2e4dfeac42c3e 100644
--- a/vllm/worker/habana_worker.py
+++ b/vllm/worker/habana_worker.py
@@ -320,9 +320,6 @@ def list_prompt_adapters(self) -> Set[int]:
     def shutdown_inc(self):
         self.model_runner.shutdown_inc()

-    def __del__(self):
-        self.shutdown_inc()
-
     @property
     def max_model_len(self) -> int:
         return self.model_config.max_model_len
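For context on the new `torch._dynamo` cache sizing in `warmup_model`: the limit is derived from the number of warmed-up buckets, and the patch only ever raises the configured limits. Using the bucket counts from the example warmup log in the docs above (24 prompt and 48 decode buckets), the arithmetic works out as follows (assumes a PyTorch version that exposes `torch._dynamo.config.accumulated_cache_size_limit`, which the patch itself relies on):

```python
import torch

num_prompt_buckets = 24   # from the example warmup log
num_decode_buckets = 48   # from the example warmup log

cache_size_limit = num_prompt_buckets + num_decode_buckets + 1   # 73
accumulated_limit = cache_size_limit * 8                         # 584, keeps the default 1:8 ratio

# Only raise the limits, never lower them below PyTorch's defaults.
torch._dynamo.config.cache_size_limit = max(
    cache_size_limit, torch._dynamo.config.cache_size_limit)
torch._dynamo.config.accumulated_cache_size_limit = max(
    accumulated_limit, torch._dynamo.config.accumulated_cache_size_limit)
```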