From 19c96b8a6489810f1e9877103270b96869c5c7f4 Mon Sep 17 00:00:00 2001 From: Nir David Date: Sun, 11 Aug 2024 12:00:09 +0300 Subject: [PATCH] Inc on vLLM - Fix CR comments --- setup.py | 4 +- vllm/attention/backends/habana_attn.py | 28 ++++++++------ vllm/attention/ops/habana_paged_attn.py | 8 ++-- vllm/config.py | 4 +- vllm/engine/arg_utils.py | 7 ++-- vllm/hpu/cache_ops.py | 9 +++-- vllm/hpu/ops.py | 25 ++++++------ vllm/hpu/utils.py | 25 +++++++++--- vllm/model_executor/layers/layernorm.py | 3 +- .../layers/quantization/__init__.py | 2 +- .../model_executor/layers/quantization/inc.py | 16 ++++---- vllm/model_executor/model_loader/loader.py | 6 +-- vllm/model_executor/models/llama.py | 7 ++-- vllm/utils.py | 2 +- vllm/worker/cache_engine.py | 3 +- vllm/worker/habana_model_runner.py | 38 ++++++++++++------- vllm/worker/habana_worker.py | 3 -- 17 files changed, 110 insertions(+), 80 deletions(-) diff --git a/setup.py b/setup.py index 0470e172eb3f6..f7bec65f4cf4e 100644 --- a/setup.py +++ b/setup.py @@ -238,7 +238,7 @@ def _is_hpu() -> bool: is_hpu_available = True try: subprocess.run(["hl-smi"], capture_output=True, check=True) - except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError): + except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): if not os.path.exists('/dev/accel/accel0') and not os.path.exists( '/dev/accel/accel_controlD0'): # last resort... @@ -267,7 +267,7 @@ def _is_neuron() -> bool: torch_neuronx_installed = True try: subprocess.run(["neuron-ls"], capture_output=True, check=True) - except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError): + except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): torch_neuronx_installed = False return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron" diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 6b1082325e026..7a867e79b203d 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -12,8 +12,8 @@ AttentionMetadata, AttentionType) from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) -from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache from vllm.hpu import cache_ops +from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache from vllm.logger import init_logger logger = init_logger(__name__) @@ -144,11 +144,11 @@ def __init__( self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.qk_matmul = Matmul() + self.matmul_qk = Matmul() self.softmax = Softmax() - self.av_matmul = Matmul() - self.key_cache = VLLMKVCache() - self.value_cache = VLLMKVCache() + self.matmul_av = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window self.position_bias = None @@ -212,9 +212,13 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - num_kv_cache_passes, num_slots_available, indices, offsets = cache_ops.prepare_to_cache(key_cache, attn_metadata.slot_mapping) - key_cache = self.key_cache(key, key_cache, num_kv_cache_passes, num_slots_available, indices, offsets) - value_cache = self.value_cache(value, value_cache, num_kv_cache_passes, num_slots_available, indices, offsets) + num_kv_cache_passes, num_slots_available, indices, offsets = \ + cache_ops.prepare_to_cache(key_cache, + attn_metadata.slot_mapping) + key_cache = self.k_cache(key, key_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) + value_cache = self.v_cache(value, value_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) if attn_metadata.is_prompt: # Prompt run. @@ -240,9 +244,9 @@ def forward( attn_bias=attn_bias, p=0.0, scale=self.scale, - qk_matmul_op=self.qk_matmul, + matmul_qk_op=self.matmul_qk, softmax_op=self.softmax, - av_matmul_op=self.av_matmul, + matmul_av_op=self.matmul_av, ) output = out.reshape(batch_size, seq_len, hidden_size) else: @@ -266,8 +270,8 @@ def forward( query, key_cache, value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, self.kv_cache_dtype, self.num_kv_heads, self.scale, self.position_bias, k_scale, - v_scale, self.qk_matmul, self.softmax, self.av_matmul, - self.key_cache, self.value_cache) + v_scale, self.matmul_qk, self.softmax, self.matmul_av, + self.k_cache, self.v_cache) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 286fda8cd500c..9602886299c47 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -75,9 +75,9 @@ def forward_decode( alibi_slopes: Optional[torch.Tensor], k_scale: float, v_scale: float, - qk_matmul_op, + matmul_qk_op, softmax_op, - av_matmul_op, + matmul_av_op, k_cache_cls, v_cache_cls, ) -> torch.Tensor: @@ -93,9 +93,9 @@ def forward_decode( block_size, alibi_slopes, kv_cache_dtype, - qk_matmul_op, + matmul_qk_op, softmax_op, - av_matmul_op, + matmul_av_op, k_cache_cls, v_cache_cls, ) diff --git a/vllm/config.py b/vllm/config.py index ec7e8fed30fdb..6acb70ad047b2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -474,13 +474,13 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "hf8"): + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " "scaling factor. " - "FP8_E4M3 is also supported on hpu (hf8).") + "Intel Gaudi (HPU) supports fp8 (using fp8_inc).") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 29160143ef469..d6c544750afea 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -229,12 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'hf8'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). ' - 'FP8_E4M3 is also supported on hpu (hf8).') + 'Intel Gaudi (HPU) supports fp8 (using fp8_inc).') parser.add_argument( '--quantization-param-path', type=nullable_str, @@ -842,7 +842,8 @@ def create_engine_config(self, ) -> EngineConfig: self.model_loader_extra_config[ "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path - device = device_config.device if self.weights_load_device is None else self.weights_load_device + device = device_config.device if self.weights_load_device is None else \ + self.weights_load_device load_config = LoadConfig( load_format=self.load_format, download_dir=self.download_dir, diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py index 9527354719aba..98f109accea06 100644 --- a/vllm/hpu/cache_ops.py +++ b/vllm/hpu/cache_ops.py @@ -64,13 +64,14 @@ def prepare_to_cache(cache, slot_mapping): return num_kv_cache_passes, num_slots_available, indices, offsets -def insert_or_update_cache(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offsets): +def insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, block_offsets): for i in range(num_kv_cache_passes): start_idx = i * num_slots_available end_idx = (i + 1) * num_slots_available - cache.index_put_( - (block_indices[start_idx:end_idx], block_offsets[start_idx:end_idx]), - input[start_idx:end_idx]) + cache.index_put_((block_indices[start_idx:end_idx], + block_offsets[start_idx:end_idx]), + input[start_idx:end_idx]) def swap_blocks(src, dst, block_mapping): diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 3384729a1e479..f9e560de5b6b9 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -11,7 +11,6 @@ import torch import torch.nn.functional as F -import vllm.hpu.utils as hpu_utils from vllm.logger import init_logger logger = init_logger() @@ -43,9 +42,9 @@ def paged_attention_v1(query, block_size, alibi_slopes=None, kv_cache_dtype=None, - qk_matmul_op=torch.matmul, + matmul_qk_op=torch.matmul, softmax_op=torch.softmax, - av_matmul_op=torch.matmul, + matmul_av_op=torch.matmul, k_cache_cls=None, v_cache_cls=None) -> None: seq_len = block_tables.size(1) @@ -60,20 +59,22 @@ def paged_attention_v1(query, batch_size, 1, 1, -1)) query.mul_(scale) query = query.unsqueeze(-2) - fetch_keys = fetch_from_cache if k_cache_cls is None else k_cache_cls.fetch_from_cache + fetch_keys = fetch_from_cache if k_cache_cls is None else \ + k_cache_cls.fetch_from_cache keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] mask = mask.unsqueeze(2) - attn_weights = torch.cat([qk_matmul_op(query, k) for k in keys], dim=-1) + attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1) if alibi_slopes is not None: attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, -attn_weights.size(3):]) attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - fetch_values = fetch_from_cache if v_cache_cls is None else k_cache_cls.fetch_from_cache + fetch_values = fetch_from_cache if v_cache_cls is None else \ + v_cache_cls.fetch_from_cache values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) if PA_SPLIT_VALUE: attn_weights = attn_weights.split(block_size, dim=-1) @@ -82,7 +83,7 @@ def paged_attention_v1(query, attn_weights = [attn_weights] if query_heads != kv_heads: values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [av_matmul_op(a, v) for a, v in zip(attn_weights, values)] + attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)] if query_heads != kv_heads: attn_weights = [a.flatten(1, 2) for a in attn_weights] attn_weights = sum(attn_weights) @@ -132,9 +133,9 @@ def prompt_attention( attn_bias: Optional[torch.Tensor] = None, p: float = 0.0, scale: Optional[float] = None, - qk_matmul_op = torch.matmul, - softmax_op = torch.softmax, - av_matmul_op = torch.matmul, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, ) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -147,11 +148,11 @@ def prompt_attention( value = value.unflatten(1, (kv_heads, 1)) if attn_bias is not None: attn_bias = attn_bias.unsqueeze(2) - attn_weights = qk_matmul_op(query * scale, key.transpose(-1, -2)) + attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) if attn_bias is not None: attn_weights.add_(attn_bias) attn_weights = softmax_op(attn_weights, dim=-1) - attn_weights = av_matmul_op(attn_weights, value) + attn_weights = matmul_av_op(attn_weights, value) if query_heads != kv_heads: attn_weights = attn_weights.flatten(1, 2) attn_weights = attn_weights.transpose(1, 2) diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index 8013f014ebd94..3d9c7cb1c4c22 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -5,12 +5,14 @@ # LICENSE file in the root directory of this source tree. ############################################################################### -import torch from functools import wraps import habana_frameworks.torch as htorch +import torch + from vllm.hpu.cache_ops import insert_or_update_cache + def with_mark_steps(fn): @wraps(fn) @@ -24,7 +26,9 @@ def wrapped(*args, **kwargs): return wrapped + class Matmul(torch.nn.Module): + def __init__(self): super(Matmul, self).__init__() @@ -33,19 +37,28 @@ def forward(self, x, y): class Softmax(torch.nn.Module): - def __init__(self): + + def __init__(self): super().__init__() - def forward(self, x, dim = None, inv_head = None): + def forward(self, x, dim=None, inv_head=None): return torch.softmax(x, dim) + class VLLMKVCache(torch.nn.Module): + def __init__(self): super(VLLMKVCache, self).__init__() - def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset): - insert_or_update_cache(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset) + def forward(self, input, cache, num_kv_cache_passes, num_slots_available, + block_indices, block_offset): + insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, + block_offset) return cache def fetch_from_cache(self, cache, blocks, permutations): - return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))] + return [ + cache.index_select(0, blocks[:, i]).permute(permutations) + for i in range(blocks.size(1)) + ] diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 7434d02b60ada..c12668c14887d 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -86,8 +86,7 @@ def forward_hpu( self.variance_epsilon) return x.view(orig_shape), residual - x = HPUFusedRMSNorm.apply(x, self.weight, - self.variance_epsilon) + x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) return x def forward_xpu( diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 746fa726354ba..7590d3e980275 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -18,9 +18,9 @@ GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.inc import INCConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig -from vllm.model_executor.layers.quantization.inc import INCConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py index 931c9eefe741f..f6718ec2ac9e7 100644 --- a/vllm/model_executor/layers/quantization/inc.py +++ b/vllm/model_executor/layers/quantization/inc.py @@ -1,11 +1,9 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional import torch -from torch.nn import Module -from torch.nn.parameter import Parameter import torch.nn.functional as F +from torch.nn.parameter import Parameter -from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( @@ -59,7 +57,8 @@ def get_quant_method(self, layer: torch.nn.Module, def get_scaled_act_names(self) -> List[str]: return [] - def get_min_capability(self) -> int: + @classmethod + def get_min_capability(cls) -> int: # The AWQ kernel only supports Turing or newer GPUs. return 75 @@ -67,6 +66,7 @@ def get_min_capability(self) -> int: def get_config_filenames() -> List[str]: return [] + class INCLinearMethod(LinearMethodBase): """Linear method for FP8. Supports loading FP8 checkpoints with static weight scale and @@ -83,7 +83,9 @@ class INCLinearMethod(LinearMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: INCConfig, separate_bias_add: bool = False): + def __init__(self, + quant_config: INCConfig, + separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add self.quant_config = quant_config @@ -110,4 +112,4 @@ def apply(self, if bias is not None: return F.linear(x, weight) + bias return F.linear(x, weight) - return F.linear(x, weight, bias) \ No newline at end of file + return F.linear(x, weight, bias) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bb40a5835c3c8..06048d97088e1 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -37,7 +37,7 @@ supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_tpu, is_hpu +from vllm.utils import is_hpu, is_tpu logger = init_logger(__name__) @@ -53,8 +53,8 @@ def _get_quantization_config( capability = capability[0] * 10 + capability[1] if capability < quant_config.get_min_capability(): raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " f"Minimum capability: {quant_config.get_min_capability()}. " f"Current capability: {capability}.") supported_dtypes = quant_config.get_supported_act_dtypes() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b71a4ee7e3b9d..676a51ce67f96 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,8 +48,9 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import is_hip, is_hpu +from vllm.utils import is_hip from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -317,7 +318,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if is_hpu(): + if current_platform.is_hpu(): import habana_frameworks.torch as htorch htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): @@ -329,7 +330,7 @@ def forward( attn_metadata, residual, ) - if is_hpu(): + if current_platform.is_hpu(): htorch.core.mark_step() if not get_pp_group().is_last_rank: diff --git a/vllm/utils.py b/vllm/utils.py index af8ddc294aa95..fe84253feb172 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -39,7 +39,7 @@ "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, - "hf8": torch.float8_e4m3fn, + "fp8_inc": torch.float8_e4m3fn, } TORCH_DTYPE_TO_NUMPY_DTYPE = { diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 8e41cbfd511ff..ec0b8c2369210 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -91,7 +91,8 @@ def _allocate_kv_cache( # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. - dtype = torch.int8 if self.dtype == torch.float8_e4m3fn else self.dtype + dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else \ + self.dtype kv_cache.append( torch.zeros(kv_cache_shape, dtype=dtype, diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 2588da84a3d6c..72aba42ae8553 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -413,6 +413,9 @@ def __init__( self._setup_buckets() def load_model(self) -> None: + import habana_frameworks.torch.core as htcore + if self.model_config.quantization == 'inc': + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model( @@ -429,18 +432,21 @@ def load_model(self) -> None: f"took {m_getmodel.get_summary_string()}") logger.info(msg) - import habana_frameworks.torch.core as htcore if self.model_config.quantization == 'inc': logger.info("Preparing model with INC..") with HabanaMemoryProfiler() as m_inc: - from neural_compressor.torch.quantization import FP8Config, convert, prepare - config = FP8Config.from_json_file(os.getenv("QUANT_CONFIG", "")) + from neural_compressor.torch.quantization import ( + FP8Config, convert, prepare) + config = FP8Config.from_json_file( + os.getenv("QUANT_CONFIG", "")) if config.measure: self.model = prepare(self.model, config) elif config.quantize: self.model = convert(self.model, config) - htcore.hpu_initialize(self.model, mark_only_scales_as_const=True) - logger.info(f"Preparing model with INC took {m_inc.get_summary_string()}") + htcore.hpu_initialize(self.model, + mark_only_scales_as_const=True) + logger.info("Preparing model with INC took %s", + m_inc.get_summary_string()) else: self.model = self.model.to("hpu") htcore.mark_step() @@ -1425,7 +1431,10 @@ def execute_model( if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) if htorch.utils.internal.is_lazy(): - execute_model_kwargs.update({"bypass_hpu_graphs":not use_graphs, "warmup_mode":warmup_mode}) + execute_model_kwargs.update({ + "bypass_hpu_graphs": not use_graphs, + "warmup_mode": warmup_mode + }) htorch.core.mark_step() if self.is_driver_worker: @@ -1439,8 +1448,8 @@ def execute_model( with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, - selected_token_indices=sampling_metadata. - selected_token_indices) + selected_token_indices=sampling_metadata.selected_token_indices + ) # Compute the logits. with self.profiler.record_event( @@ -1485,12 +1494,13 @@ def execute_model( def shutdown_inc(self): print('inc shutdown') - if model_config := getattr(self, "model_config", None): - if getattr(model_config, "quantization", None) == 'inc': - print('inc shutdown start') - from neural_compressor.torch.quantization import finalize_calibration - finalize_calibration(self.model.model) - print('inc shutdown') + if (model_config := getattr(self, "model_config", None)) and \ + getattr(model_config, "quantization", None) == 'inc': + print('inc shutdown start') + from neural_compressor.torch.quantization import ( + finalize_calibration) + finalize_calibration(self.model.model) + print('inc shutdown') def __del__(self): self.shutdown_inc() diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index 73278162dce03..87122c03d3c8f 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -118,9 +118,6 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): - if self.model_config.quantization == 'inc': - import habana_frameworks.torch.core as htcore - htcore.hpu_set_env() self.model_runner.load_model() @torch.inference_mode()