From a6f8dee2b3da8b708b1e9ffff8346292442da8a6 Mon Sep 17 00:00:00 2001 From: Nir David Date: Tue, 6 Aug 2024 15:12:13 +0300 Subject: [PATCH 01/24] Inc on vLLM - Split qk and v calculations --- vllm/config.py | 3 ++ vllm/engine/arg_utils.py | 6 +++ vllm/engine/llm_engine.py | 25 +++++-------- vllm/model_executor/layers/linear.py | 55 ++++++++++++++++++++++------ vllm/model_executor/models/llama.py | 22 +++++++++-- 5 files changed, 80 insertions(+), 31 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 7aa3977a497e..243018b5f01c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -431,6 +431,7 @@ class CacheConfig: cache_dtype: Data type for kv cache storage. num_gpu_blocks_override: Number of GPU blocks to use. This overrides the profiled num_gpu_blocks if specified. Does nothing if None. + split_qk_v: Whether to split qk and v calculations. """ def __init__( @@ -443,6 +444,7 @@ def __init__( sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, + split_qk_v: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -452,6 +454,7 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb + self.split_qk_v = split_qk_v self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d6c544750afe..983d010b92ca 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -62,6 +62,7 @@ class EngineArgs: swap_space: int = 4 # GiB cpu_offload_gb: int = 0 # GiB gpu_memory_utilization: float = 0.90 + split_qk_v: bool = False max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API @@ -358,6 +359,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, help='If specified, ignore GPU profiling result and use this number' 'of GPU blocks. Used for testing preemption.') + parser.add_argument('--split-qk-v', + type=bool, + default=EngineArgs.split_qk_v, + help='Whether to separate qk and v calculations.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -734,6 +739,7 @@ def create_engine_config(self, ) -> EngineConfig: swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, num_gpu_blocks_override=self.num_gpu_blocks_override, + split_qk_v=self.split_qk_v, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f8b9c48bc958..a68e86ede217 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -186,7 +186,7 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "enable_prefix_caching=%s)", + "enable_prefix_caching=%s, split_qk_v=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -217,6 +217,7 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, cache_config.enable_prefix_caching, + cache_config.split_qk_v, ) # TODO(woosuk): Print more configs in debug mode. 
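# The hunks above only thread the new split_qk_v flag through CacheConfig,
# EngineArgs and the engine logging/usage telemetry; the behavioural change
# lives in QKVParallelLinear and the llama attention further down. A minimal,
# framework-free sketch of the idea (plain torch.nn.Linear stand-ins with
# made-up sizes, not the vLLM classes themselves):
import torch
import torch.nn as nn

hidden, q_size, kv_size = 64, 64, 16
x = torch.randn(2, hidden)

# split_qk_v=False: one fused projection, split three ways afterwards.
qkv_proj = nn.Linear(hidden, q_size + 2 * kv_size, bias=False)
q, k, v = qkv_proj(x).split([q_size, kv_size, kv_size], dim=-1)

# split_qk_v=True: q and k stay fused while v gets its own projection, so the
# qk part and the v part can be computed (and quantized) independently.
qk_proj = nn.Linear(hidden, q_size + kv_size, bias=False)
v_proj = nn.Linear(hidden, kv_size, bias=False)
q2, k2 = qk_proj(x).split([q_size, kv_size], dim=-1)
v2 = v_proj(x)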
@@ -274,32 +275,26 @@ def __init__( usage_context, extra_kvs={ # Common configuration - "dtype": - str(model_config.dtype), + "dtype": str(model_config.dtype), "tensor_parallel_size": parallel_config.tensor_parallel_size, - "block_size": - cache_config.block_size, + "block_size": cache_config.block_size, "gpu_memory_utilization": cache_config.gpu_memory_utilization, # Quantization - "quantization": - model_config.quantization, - "kv_cache_dtype": - str(cache_config.cache_dtype), + "quantization": model_config.quantization, + "kv_cache_dtype": str(cache_config.cache_dtype), # Feature flags - "enable_lora": - bool(lora_config), - "enable_prompt_adapter": - bool(prompt_adapter_config), + "enable_lora": bool(lora_config), + "enable_prompt_adapter": bool(prompt_adapter_config), "enable_prefix_caching": cache_config.enable_prefix_caching, - "enforce_eager": - model_config.enforce_eager, + "enforce_eager": model_config.enforce_eager, "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, + "split_qk_v": cache_config.split_qk_v, }) if self.tokenizer: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 10c8a95f838d..f3fb477a6c88 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -525,7 +525,8 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + split_qk_v: bool = False): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -542,14 +543,21 @@ def __init__(self, else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 + self.split_qk_v = split_qk_v + self.q_size = self.num_heads * self.head_size * tp_size + self.kv_size = self.num_kv_heads * self.head_size * tp_size input_size = self.hidden_size - output_size = (self.num_heads + - 2 * self.num_kv_heads) * tp_size * self.head_size self.output_sizes = [ - self.num_heads * self.head_size * tp_size, # q_proj - self.num_kv_heads * self.head_size * tp_size, # k_proj - self.num_kv_heads * self.head_size * tp_size, # v_proj + self.q_size, # q_proj + self.kv_size, # k_proj ] + if split_qk_v: + output_size = (self.num_heads + + self.num_kv_heads) * tp_size * self.head_size + else: + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes.append(self.kv_size) # v_proj super().__init__(input_size=input_size, output_size=output_size, @@ -560,6 +568,16 @@ def __init__(self, quant_config=quant_config, prefix=prefix) + if split_qk_v: + self.v_proj = ColumnParallelLinear(input_size=input_size, + output_size=self.kv_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, @@ -641,13 +659,19 @@ def weight_loader(self, "q": (0, self.num_heads * self.head_size), "k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size), - "v": - ((self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size), - "total": - ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, - 0) } + if self.split_qk_v: + orig_qkv_offsets["total"] = ( + (self.num_heads + self.num_kv_heads) * self.head_size, + 0) + else: + orig_qkv_offsets["v"] = ( + (self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads 
* self.head_size) + orig_qkv_offsets["total"] = ( + (self.num_heads + 2 * self.num_kv_heads) * + self.head_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_shard( param, orig_qkv_offsets, loaded_shard_id) @@ -682,6 +706,13 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + def forward(self, input_): + output, output_bias = super().forward(input_) + if not self.split_qk_v: + return output, output_bias + v, _ = self.v_proj(input_) + return output, v, output_bias + class RowParallelLinear(LinearBase): """Linear layer with row parallelism. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 51716b12513d..760844d12a98 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -133,6 +133,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.split_qk_v = cache_config.split_qk_v self.qkv_proj = QKVParallelLinear( hidden_size=hidden_size, @@ -142,6 +143,7 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", + split_qk_v=self.split_qk_v, ) self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, @@ -172,8 +174,13 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.split_qk_v: + qk, v, _ = self.qkv_proj(hidden_states) + q, k = qk.split([self.q_size, self.kv_size], dim=-1) + else: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) @@ -388,6 +395,7 @@ def __init__( self.config = config self.lora_config = lora_config + self.split_qk_v = cache_config.split_qk_v self.model = LlamaModel(config, cache_config, @@ -466,10 +474,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + if self.split_qk_v: + stacked_params_mapping.append((".qkv_proj.v_proj", ".v_proj", "v")) + else: + stacked_params_mapping.append((".qkv_proj", ".v_proj", "v")) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -500,7 +511,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + if self.split_qk_v and shard_id == "v": + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) break else: From 23e931b188a04fe0036864ca4783b924e6953bab Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 16 Sep 2024 15:27:58 +0300 Subject: [PATCH 02/24] Support loading checkpoints quantized using Autofp8 --- vllm/hpu/ops.py | 55 ++++++++++++++++++- .../layers/fused_moe/fused_moe.py | 5 ++ .../compressed_tensors/compressed_tensors.py | 5 +- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../model_executor/layers/quantization/fp8.py | 21 ++++--- .../layers/quantization/utils/w8a8_utils.py | 37 ++++++++++--- vllm/model_executor/models/llama.py | 6 +- 
vllm/worker/habana_model_runner.py | 3 +- 8 files changed, 110 insertions(+), 24 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 939d195a12b0..323a33e9fa2a 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. ############################################################################### -from typing import Optional +from typing import Optional, Tuple import habana_frameworks.torch as htorch import torch @@ -291,3 +291,56 @@ def forward(self, hidden_states, w1, w2, score, topk): final_hidden_states += current_hidden_states_static return final_hidden_states.view(-1, D) + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + batch_dim_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensor for downstream kernels that + will benefit from padding. + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + batch_dim_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. 
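A minimal usage sketch for this docstring (illustrative only: it assumes an
HPU device, a bfloat16 input and a precomputed per-tensor scale, since the
dynamic path in the body below raises on HPU; see also the TODO there about
Gaudi2's fp8 range of 240 versus the 448 of torch.float8_e4m3fn):

    x = torch.randn(4, 128, dtype=torch.bfloat16, device="hpu")
    scale = torch.tensor([0.5], device="hpu", dtype=torch.float32)
    x_fp8, scale_out = scaled_fp8_quant(x, scale=scale)
    # x_fp8.dtype is torch.float8_e4m3fn; scale_out is the scale passed in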
+ """ + if batch_dim_padding: + shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + if scale is None: + raise "dynamic scaled_fp8_quant not implemented for HPU" + #TODO: calculate scale to match gaudi2 240 range instead of 448 + if use_per_token_if_dynamic: + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant( + output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + else: + output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, dtype=torch.float8_e4m3fn)[0] + + return output, scale \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 413c0b6d0924..3682362c5a86 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -11,6 +11,11 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.platforms import current_platform + +if current_platform.is_hpu(): + from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 39d00bd5733f..badb29af1f5f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -233,7 +233,7 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, @@ -306,7 +306,8 @@ def get_scheme( # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) - self._check_scheme_supported(scheme.get_min_capability()) + if torch.cuda.is_available(): + self._check_scheme_supported(scheme.get_min_capability()) return scheme diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index cc9d71db140c..631774994b5c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -21,7 +21,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = cutlass_fp8_supported() + self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c829cb836ee4..f3e304ce141c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,6 +23,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once +if current_platform.is_hpu(): + from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -112,13 +115,17 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config - self.cutlass_fp8_supported = cutlass_fp8_supported() - - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 + if torch.cuda.is_available(): + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 + else: + self.cutlass_fp8_supported = False + self.use_marlin = False def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 20100c76bd69..8904c9fa1789 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,7 +6,10 @@ from vllm import _custom_ops as ops from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform - +if current_platform.is_hpu(): + import habana_frameworks.torch.utils.experimental as htexp + from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant def cutlass_fp8_supported() -> bool: capability = current_platform.get_device_capability() @@ -18,7 +21,15 @@ def cutlass_fp8_supported() -> bool: def per_tensor_dequantize( tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]) -> torch.Tensor: - fake_qweight = tensor.to(torch.float16) + dtype = torch.float16 + device = tensor.device + if 
current_platform.is_hpu(): + dtype = torch.bfloat16 + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + #dequant on cpu to avoid nan on gaudi2 + tensor = tensor.to('cpu') + + fake_qweight = tensor.to(dtype).to(device) dq_weight = fake_qweight * inv_scale return dq_weight @@ -76,7 +87,8 @@ def requantize_with_max_scale( logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. max_w_scale = weight_scale.max() - + if current_platform.is_hpu() and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max) # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize # N weight scales for N shards but we only load 1 weight scale @@ -147,12 +159,19 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + if current_platform.is_hpu(): + #hpu does not support torch._scaled_mm (SW-197036) + output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight, + False, None, input.dtype, + x_scale, weight_scale, None, + False) + else: + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) return torch.narrow(output, 0, 0, input.shape[0]) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 51716b12513d..8ccefe7be33f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,8 +54,9 @@ from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers - -is_hpu = current_platform.is_hpu() +from vllm.platforms import current_platform +if current_platform.is_hpu(): + import habana_frameworks.torch.core as htcore class LlamaMLP(nn.Module): @@ -521,6 +522,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) if current_platform.is_hpu(): torch.hpu.synchronize() + htcore.mark_step() # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). 
Thus, handled exceptions should diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index ce3848ae0a6d..b0b9114ac2d0 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -562,8 +562,7 @@ def _set_gc_threshold(self) -> None: def load_model(self) -> None: import habana_frameworks.torch.core as htcore - if self.model_config.quantization == 'inc': - htcore.hpu_set_env() + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model( From 363de3ce0b1b8ba0efd352c49d59b90303a4d66b Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 16 Sep 2024 15:39:29 +0300 Subject: [PATCH 03/24] ruff fixes --- vllm/hpu/ops.py | 5 +++-- .../quantization/compressed_tensors/compressed_tensors.py | 3 ++- .../schemes/compressed_tensors_w8a8_fp8.py | 3 ++- vllm/model_executor/layers/quantization/fp8.py | 4 ++-- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 6 ++++-- vllm/model_executor/models/llama.py | 5 ++--- 6 files changed, 15 insertions(+), 11 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 323a33e9fa2a..e46877f5abfa 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -329,7 +329,7 @@ def scaled_fp8_quant( else: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: - raise "dynamic scaled_fp8_quant not implemented for HPU" + raise RuntimeError("dynamic scaled_fp8_quant not implemented for HPU") #TODO: calculate scale to match gaudi2 240 range instead of 448 if use_per_token_if_dynamic: scale = torch.empty((input.numel() // input.shape[-1], 1), @@ -341,6 +341,7 @@ def scaled_fp8_quant( scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, dtype=torch.float8_e4m3fn)[0] + output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, + dtype=torch.float8_e4m3fn)[0] return output, scale \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index badb29af1f5f..b1b594f09a90 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -233,7 +233,8 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) \ + if torch.cuda.is_available() else True if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 631774994b5c..9bcc155c5072 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -21,7 +21,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: 
bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False + self.cutlass_fp8_supported = cutlass_fp8_supported() \ + if torch.cuda.is_available() else False @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f3e304ce141c..4704988fae00 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -118,8 +118,8 @@ def __init__(self, quant_config: Fp8Config): if torch.cuda.is_available(): self.cutlass_fp8_supported = cutlass_fp8_supported() - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] self.use_marlin = capability < 89 diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 8904c9fa1789..e909039cbc0b 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -87,8 +87,10 @@ def requantize_with_max_scale( logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. max_w_scale = weight_scale.max() - if current_platform.is_hpu() and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: - max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max) + if current_platform.is_hpu() and \ + htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + max_w_scale = max_w_scale * \ + (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max) # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize # N weight scales for N shards but we only load 1 weight scale diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8ccefe7be33f..2766a525f677 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,7 +48,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip @@ -321,7 +320,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if is_hpu: + if current_platform.is_hpu(): import habana_frameworks.torch as htorch htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): @@ -333,7 +332,7 @@ def forward( attn_metadata, residual, ) - if is_hpu: + if current_platform.is_hpu(): htorch.core.mark_step() if not get_pp_group().is_last_rank: From e4fc78b7458fe560cd1975b65482daaea369f8ff Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 16 Sep 2024 15:41:09 +0300 Subject: [PATCH 04/24] ruff fixes --- .../quantization/compressed_tensors/compressed_tensors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b1b594f09a90..36b655935bcc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -233,7 +233,8 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) \ + CompressedTensorsW8A8Fp8.get_min_capability(), + error=False) \ if torch.cuda.is_available() else True if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( From d165c6e89479ee0966bd60d177162122472d7696 Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 16 Sep 2024 15:44:06 +0300 Subject: [PATCH 05/24] isort fixes --- vllm/model_executor/layers/quantization/fp8.py | 1 + vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 ++ vllm/model_executor/models/llama.py | 3 ++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 4704988fae00..9682b14d968f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once + if current_platform.is_hpu(): from vllm.hpu.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index e909039cbc0b..5cd998d7326e 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,8 +6,10 @@ from vllm import _custom_ops as ops from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform + if current_platform.is_hpu(): import habana_frameworks.torch.utils.experimental as htexp + from vllm.hpu.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2766a525f677..eda04bc1f120 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,12 +48,13 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -from vllm.platforms import current_platform + if current_platform.is_hpu(): import habana_frameworks.torch.core as htcore From 6f0016bfb165a4252309b4ccc4b61905fa5c3178 Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 16 Sep 2024 15:50:35 +0300 Subject: [PATCH 06/24] ruff format --- vllm/hpu/ops.py | 130 +- .../layers/fused_moe/fused_moe.py | 396 ++--- .../compressed_tensors/compressed_tensors.py | 282 ++-- .../schemes/compressed_tensors_w8a8_fp8.py | 88 +- .../model_executor/layers/quantization/fp8.py | 
342 +++-- .../layers/quantization/utils/w8a8_utils.py | 153 +- vllm/model_executor/models/llama.py | 257 ++-- vllm/worker/habana_model_runner.py | 1331 ++++++++++------- 8 files changed, 1804 insertions(+), 1175 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index e46877f5abfa..4878b3c7ee05 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -16,17 +16,23 @@ HPUFusedRMSNorm = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm + HPUFusedRMSNorm = FusedRMSNorm except ImportError: - logger.warning("Could not import HPU FusedRMSNorm kernel. " - "vLLM will use forward_native implementation of RMSNorm.") + logger.warning( + "Could not import HPU FusedRMSNorm kernel. " + "vLLM will use forward_native implementation of RMSNorm." + ) HPUFusedSDPA = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA + HPUFusedSDPA = FusedSDPA except ImportError: - logger.warning("Could not import HPU FusedSDPA kernel. " - "vLLM will use native implementation.") + logger.warning( + "Could not import HPU FusedSDPA kernel. " + "vLLM will use native implementation." + ) def batch2block(tensor, block_mapping): @@ -61,9 +67,19 @@ def block_softmax(batch_size, attn, block_mapping): return attn -def flat_pa(query, key_cache, value_cache, block_list, block_mapping, - block_bias, scale, matmul_qk_op, matmul_av_op, keys_fetch_func, - values_fetch_func): +def flat_pa( + query, + key_cache, + value_cache, + block_list, + block_mapping, + block_bias, + scale, + matmul_qk_op, + matmul_av_op, + keys_fetch_func, + values_fetch_func, +): batch_size = query.size(0) q_heads = query.size(1) kv_heads = key_cache.size(2) @@ -97,7 +113,7 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor: return F.silu(x[..., :d]) * x[..., d:] -#TODO: remove after fusedsdpa fix for query_head != kv_head +# TODO: remove after fusedsdpa fix for query_head != kv_head def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). @@ -107,8 +123,9 @@ def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = kv.shape if n_rep == 1: return kv - kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) + kv = kv[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -144,15 +161,25 @@ def prompt_attention( if query_heads != kv_heads: attn_weights = attn_weights.flatten(1, 2) else: - #TODO: remove after fusedsdpa fix for query_heads != kv_heads + # TODO: remove after fusedsdpa fix for query_heads != kv_heads if query_heads != kv_heads: key = repeat_kv(key, int(query_heads // kv_heads)) value = repeat_kv(value, int(query_heads // kv_heads)) - softmax_mode = 'fast' + softmax_mode = "fast" recompute_mode = True - attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True, - scale, softmax_mode, recompute_mode, - valid_seq_lengths, 'right') + attn_weights = FusedSDPA.apply( + query, + key, + value, + None, + 0.0, + True, + scale, + softmax_mode, + recompute_mode, + valid_seq_lengths, + "right", + ) attn_weights = attn_weights.transpose(1, 2) return attn_weights @@ -190,7 +217,7 @@ def dispatch_bgmv_linear( the final output. 
""" - assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + assert layer_idx == 0, f"layer_idx should be 0, but got {layer_idx}" mask = LoraMask.getLoraMask() wa = wa_t_all[:, 0, :, :] @@ -199,7 +226,7 @@ def dispatch_bgmv_linear( wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2]) out = x @ wa - assert (out.shape == mask.shape) + assert out.shape == mask.shape out = out * mask out = out @ wb y += out * scale @@ -224,7 +251,7 @@ def dispatch_bgmv_embedding( output. """ - assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + assert layer_idx == 0, f"layer_idx should be 0, but got {layer_idx}" max_loras = wb_t_all.size(0) x = x.repeat(1, max_loras) @@ -236,7 +263,6 @@ def dispatch_bgmv_embedding( class MoeMatmul(torch.nn.Module): - def __init__(self): super().__init__() @@ -252,29 +278,32 @@ def forward(self, state): class StaticFusedMOE(torch.nn.Module): - def __init__(self, num_total_experts): super().__init__() self.w13_list = torch.nn.ModuleList( - [MoeMatmul() for _ in range(num_total_experts)]) + [MoeMatmul() for _ in range(num_total_experts)] + ) self.w2_list = torch.nn.ModuleList( - [MoeMatmul() for _ in range(num_total_experts)]) + [MoeMatmul() for _ in range(num_total_experts)] + ) self.num_total_experts = num_total_experts def forward(self, hidden_states, w1, w2, score, topk): B, D = hidden_states.shape routing_weights = F.softmax(score, dim=1, dtype=torch.float32) - routing_weights, selected_experts = torch.topk(routing_weights, - topk, - dim=-1) + routing_weights, selected_experts = torch.topk( + routing_weights, topk, dim=-1 + ) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = torch.zeros((1, B, D), - dtype=hidden_states.dtype, - device=hidden_states.device) - padded_weights = torch.zeros((B, self.num_total_experts), - dtype=hidden_states.dtype, - device=hidden_states.device) + final_hidden_states = torch.zeros( + (1, B, D), dtype=hidden_states.dtype, device=hidden_states.device + ) + padded_weights = torch.zeros( + (B, self.num_total_experts), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) padded_weights.scatter_(-1, selected_experts, routing_weights) padded_weights = padded_weights.reshape(-1, B, self.num_total_experts) padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) @@ -283,8 +312,9 @@ def forward(self, hidden_states, w1, w2, score, topk): for expert_idx in range(self.num_total_experts): padded_weight = padded_weights[expert_idx] current_state_static = hidden_states.reshape(-1, D) - w_output = self.w13_list[expert_idx].calc(current_state_static, - expert_idx, w1) + w_output = self.w13_list[expert_idx].calc( + current_state_static, expert_idx, w1 + ) w_output = silu_and_mul(w_output) w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2) current_hidden_states_static = w_output * padded_weight @@ -292,6 +322,7 @@ def forward(self, hidden_states, w1, w2, score, topk): return final_hidden_states.view(-1, D) + # fp8 def scaled_fp8_quant( input: torch.Tensor, @@ -300,7 +331,6 @@ def scaled_fp8_quant( scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Quantize input tensor to FP8 and return quantized tensor and scale. 
This function supports both static and dynamic quantization: If you @@ -311,11 +341,11 @@ def scaled_fp8_quant( Args: input: The input tensor to be quantized to FP8 scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic + scale_ub: Optional upper bound for scaling factor in dynamic per token case batch_dim_padding: If specified, pad the first dimension of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token + use_per_token_if_dynamic: Whether to do per_tensor or per_token in the dynamic quantization case. Returns: Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and @@ -323,25 +353,29 @@ def scaled_fp8_quant( """ if batch_dim_padding: shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) - output = torch.empty(shape, - device=input.device, - dtype=torch.float8_e4m3fn) + output = torch.empty( + shape, device=input.device, dtype=torch.float8_e4m3fn + ) else: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: - raise RuntimeError("dynamic scaled_fp8_quant not implemented for HPU") - #TODO: calculate scale to match gaudi2 240 range instead of 448 + raise RuntimeError("dynamic scaled_fp8_quant not implemented for HPU") + # TODO: calculate scale to match gaudi2 240 range instead of 448 if use_per_token_if_dynamic: - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty( + (input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32, + ) torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) + output, input, scale, scale_ub + ) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, - dtype=torch.float8_e4m3fn)[0] + output = torch.ops.hpu.cast_to_fp8_v2( + input, 1 / scale, False, False, dtype=torch.float8_e4m3fn + )[0] - return output, scale \ No newline at end of file + return output, scale diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3682362c5a86..585b5e0c64c7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,4 +1,5 @@ """Fused MoE kernel.""" + import functools import json import os @@ -15,6 +16,7 @@ if current_platform.is_hpu(): from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant logger = init_logger(__name__) @@ -112,12 +114,16 @@ def fused_moe_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) if use_fp8: a_scale = tl.load(a_scale_ptr) @@ -133,13 +139,14 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. 
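# Note on the masked loads below: offs_k indexes one BLOCK_SIZE_K-wide tile,
# and on the final iteration K - k * BLOCK_SIZE_K can be smaller than
# BLOCK_SIZE_K, so the masks zero-fill the out-of-range elements of A and B
# instead of reading past the end of the K dimension (token_mask additionally
# guards padded token rows of A). The reformat only rewraps these calls; the
# masking logic itself is unchanged.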
- a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) # We accumulate along the K dimension. if use_fp8: accumulator = tl.dot(a, b, acc=accumulator) @@ -150,9 +157,9 @@ def fused_moe_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load( + topk_weights_ptr + offs_token, mask=token_mask, other=0 + ) accumulator = accumulator * moe_weight[:, None] if use_fp8: @@ -162,15 +169,16 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = ( + c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + ) c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, - num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + topk_ids: torch.Tensor, block_size: int, num_experts: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. @@ -209,32 +217,45 @@ def moe_align_block_size( by block_size for proper block matrix operations. """ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype, - use_fp8: bool) -> None: +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + 
top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, + use_fp8: bool, +) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -245,8 +266,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + grid = lambda META: ( + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), + ) fused_moe_kernel[grid]( A, @@ -284,8 +307,9 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @functools.lru_cache -def get_moe_configs(E: int, N: int, - dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, N: int, dtype: Optional[str] +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -300,11 +324,13 @@ def get_moe_configs(E: int, N: int, json_file_name = get_config_file_name(E, N, dtype) config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info("Using configuration from %s for MoE layer.", - config_file_path) + logger.info( + "Using configuration from %s for MoE layer.", config_file_path + ) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} @@ -322,17 +348,17 @@ def get_default_config( dtype: Optional[str], ) -> Dict[str, int]: config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, } if M <= E: config = { - 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, } return config @@ -368,23 +394,21 @@ def fused_topk( topk: int, renormalize: bool, ): - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") + assert ( + hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" M, _ = hidden_states.shape - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) + token_expert_indicies = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) ops.topk_softmax( topk_weights, topk_ids, @@ -399,59 +423,66 @@ def fused_topk( # This is used by the Deepseek-V2 model -def grouped_topk(hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0): - - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +): + assert ( + 
hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" scores = torch.softmax(gating_output, dim=-1) num_token = scores.shape[0] - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, - sorted=False)[1] # [n, top_k_group] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + num_token, num_expert_group, scores.shape[-1] // num_expert_group + ) + .reshape(num_token, -1) + ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=False + ) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids -def fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None): +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): # Check constraints. 
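# A small usage sketch for grouped_topk above (illustrative values, assuming
# torch and grouped_topk from this module are in scope; per the comment
# above, the routine is only exercised by Deepseek-V2 style models): 8
# experts arranged in 4 groups of 2, keep the best 2 groups per token, then
# pick the top-2 experts inside the surviving groups.
gating = torch.randn(2, 8)                     # 2 tokens, 8 experts
weights, ids = grouped_topk(
    hidden_states=torch.zeros(2, 16),          # only the batch dim is checked
    gating_output=gating,
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
)
# weights: [2, 2] normalized routing weights, ids: [2, 2] expert indices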
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] num_tokens, _ = hidden_states.shape E, N, _ = w1.shape @@ -471,18 +502,25 @@ def fused_experts(hidden_states: torch.Tensor, config = get_config_func(M) - intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) - compute_type = (tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16) + compute_type = ( + tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + ) if inplace: out_hidden_states = hidden_states @@ -490,9 +528,10 @@ def fused_experts(hidden_states: torch.Tensor, out_hidden_states = torch.empty_like(hidden_states) for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape @@ -513,45 +552,52 @@ def fused_experts(hidden_states: torch.Tensor, curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) - - invoke_fused_moe_kernel(curr_hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8) + moe_align_block_size(curr_topk_ids, config["BLOCK_SIZE_M"], E) + ) + + invoke_fused_moe_kernel( + curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8) - - torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) + 
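# Shape note for the second kernel launch below: intermediate_cache1 holds
# the fused gate/up projection of width N per selected expert,
# ops.silu_and_mul folds it to N // 2 in intermediate_cache2, and the w2 GEMM
# (with the routing weights applied, since mul_routed_weight is True) writes
# the per-expert outputs into intermediate_cache3, which torch.sum then
# reduces over the top-k dimension into out_hidden_states.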
invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) return out_hidden_states @@ -608,22 +654,30 @@ def fused_moe( if use_grouped_topk: assert num_expert_group is not None and topk_group is not None - topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, - topk, renormalize, - num_expert_group, topk_group) + topk_weights, topk_ids = grouped_topk( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, + ) else: - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - return fused_experts(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - override_config=override_config, - use_fp8=use_fp8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + topk_weights, topk_ids = fused_topk( + hidden_states, gating_output, topk, renormalize + ) + + return fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 36b655935bcc..41e09231f86a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -5,29 +5,41 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, - CompressedTensorsScheme, CompressedTensorsUnquantized, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + W4A16SPARSE24_SUPPORTED_BITS, + WNA16_SUPPORTED_BITS, + CompressedTensorsScheme, + CompressedTensorsUnquantized, + CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16, +) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - CompressionFormat, QuantizationArgs, QuantizationStrategy, - QuantizationType, find_matched_target, is_activation_quantization_format, - should_ignore_layer) + CompressionFormat, + QuantizationArgs, + QuantizationStrategy, + QuantizationType, + find_matched_target, + is_activation_quantization_format, + should_ignore_layer, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import current_platform class CompressedTensorsConfig(QuantizationConfig): - - def __init__(self, - target_scheme_map: Dict[str, Any], - ignore: List[str], - quant_format: str, - kv_cache_scheme: Optional[Dict[str, Any]] = None): - + def __init__( + self, + target_scheme_map: Dict[str, Any], + ignore: 
List[str], + quant_format: str, + kv_cache_scheme: Optional[Dict[str, Any]] = None, + ): self.ignore = ignore self.quant_format = quant_format # Map from [target -> scheme] @@ -58,6 +70,7 @@ def get_quant_method( prefix: str, ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): return CompressedTensorsLinearMethod(self) if isinstance(layer, Attention): @@ -82,28 +95,32 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": targets = quant_config.get("targets") for target in targets: target_scheme_map[target] = {} - target_scheme_map[target][ - "weights"] = QuantizationArgs.parse_obj( - quant_config.get("weights")) + target_scheme_map[target]["weights"] = ( + QuantizationArgs.parse_obj(quant_config.get("weights")) + ) try: - target_scheme_map[target][ - "input_activations"] = QuantizationArgs.parse_obj( - quant_config.get("input_activations")) + target_scheme_map[target]["input_activations"] = ( + QuantizationArgs.parse_obj( + quant_config.get("input_activations") + ) + ) except Exception: target_scheme_map[target]["input_activations"] = None - return cls(target_scheme_map=target_scheme_map, - ignore=ignore, - quant_format=quant_format, - kv_cache_scheme=config.get("kv_cache_scheme")) + return cls( + target_scheme_map=target_scheme_map, + ignore=ignore, + quant_format=quant_format, + kv_cache_scheme=config.get("kv_cache_scheme"), + ) @classmethod def get_config_filenames(cls) -> List[str]: return [] - def _check_scheme_supported(self, - min_capability: int, - error: bool = True) -> bool: + def _check_scheme_supported( + self, min_capability: int, error: bool = True + ) -> bool: capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] supported = capability >= min_capability @@ -111,54 +128,70 @@ def _check_scheme_supported(self, raise RuntimeError( "Quantization scheme is not supported for ", f"the current GPU. Min capability: {min_capability}. 
", - f"Current capability: {capability}.") + f"Current capability: {capability}.", + ) return supported - def _is_static_tensor_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_static_tensor_w8a8( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_tensor = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TENSOR.value) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_tensor = ( + weight_strategy + and input_quant.strategy == QuantizationStrategy.TENSOR.value + ) is_symmetric = weight_quant.symmetric and input_quant.symmetric is_static = not weight_quant.dynamic and not input_quant.dynamic return is_8_bits and is_tensor and is_symmetric and is_static - def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_dynamic_token_w8a8( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_token = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TOKEN.value) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_token = ( + weight_strategy + and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) is_symmetric = weight_quant.symmetric and input_quant.symmetric is_dynamic = not weight_quant.dynamic and input_quant.dynamic return is_8_bits and is_token and is_symmetric and is_dynamic - def _is_fp8_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: # Confirm weights and activations quantized. if weight_quant is None or input_quant is None: return False # Confirm we have floating points. - if not (weight_quant.type == QuantizationType.FLOAT - and input_quant.type == QuantizationType.FLOAT): + if not ( + weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT + ): return False # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL - ]) - if not (is_symmetric_weight and is_static_weight - and is_per_tensor_or_channel_weight): + is_per_tensor_or_channel_weight = weight_quant.strategy in [ + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + ] + if not ( + is_symmetric_weight + and is_static_weight + and is_per_tensor_or_channel_weight + ): return False # Dynamic quantization is always supported if weights supported. @@ -168,15 +201,17 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, # Confirm activation scheme is supported. is_symmetric_activation = input_quant.symmetric is_per_tensor_activation = ( - input_quant.strategy == QuantizationStrategy.TENSOR) + input_quant.strategy == QuantizationStrategy.TENSOR + ) if not (is_symmetric_activation and is_per_tensor_activation): return False # All conditions satisfied. 
return True - def _is_fp8_w8a16(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a16( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: # Confirm weights quantized. if weight_quant is None: return False @@ -188,87 +223,108 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL - ]) - if not (is_symmetric_weight and is_static_weight - and is_per_tensor_or_channel_weight): + is_per_tensor_or_channel_weight = weight_quant.strategy in [ + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + ] + if not ( + is_symmetric_weight + and is_static_weight + and is_per_tensor_or_channel_weight + ): return False # All conditions satisfied. return True - def _is_wNa16_group_channel(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_wNa16_group_channel( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: input_quant_none = input_quant is None is_symmetric = weight_quant.symmetric is_channel_group = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value - or weight_quant.strategy == QuantizationStrategy.GROUP.value) + or weight_quant.strategy == QuantizationStrategy.GROUP.value + ) is_static = not weight_quant.dynamic - return (is_channel_group and input_quant_none and is_symmetric - and is_static) + return ( + is_channel_group and input_quant_none and is_symmetric and is_static + ) def _get_scheme_from_parts( - self, weight_quant: BaseModel, - input_quant: BaseModel) -> "CompressedTensorsScheme": - + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> "CompressedTensorsScheme": # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): - if (self.quant_format == CompressionFormat.marlin_24.value - and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): + if ( + self.quant_format == CompressionFormat.marlin_24.value + and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS + ): return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, - group_size=weight_quant.group_size) - if (self.quant_format == CompressionFormat.pack_quantized.value - and weight_quant.num_bits in WNA16_SUPPORTED_BITS): + group_size=weight_quant.group_size, + ) + if ( + self.quant_format == CompressionFormat.pack_quantized.value + and weight_quant.num_bits in WNA16_SUPPORTED_BITS + ): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, - group_size=weight_quant.group_size) + group_size=weight_quant.group_size, + ) # Detect If Activation Quantization. 
if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): - is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), - error=False) \ - if torch.cuda.is_available() else True + is_fp8_w8a8_supported = ( + self._check_scheme_supported( + CompressedTensorsW8A8Fp8.get_min_capability(), + error=False, + ) + if torch.cuda.is_available() + else True + ) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=(not input_quant.dynamic)) + is_static_input_scheme=(not input_quant.dynamic), + ) else: return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=(input_quant - and not input_quant.dynamic)) + is_static_input_scheme=( + input_quant and not input_quant.dynamic + ), + ) if self._is_fp8_w8a16(weight_quant, input_quant): return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=(input_quant - and not input_quant.dynamic)) + is_static_input_scheme=( + input_quant and not input_quant.dynamic + ), + ) if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( - strategy=weight_quant.strategy, - is_static_input_scheme=True) + strategy=weight_quant.strategy, is_static_input_scheme=True + ) if self._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( - strategy=weight_quant.strategy, - is_static_input_scheme=False) + strategy=weight_quant.strategy, is_static_input_scheme=False + ) raise NotImplementedError( - "No compressed-tensors compatible scheme was found.") + "No compressed-tensors compatible scheme was found." + ) def get_scheme( - self, - layer: torch.nn.Module, - layer_name: Optional[str] = None) -> "CompressedTensorsScheme": + self, layer: torch.nn.Module, layer_name: Optional[str] = None + ) -> "CompressedTensorsScheme": """ compressed-tensors supports non uniform in the following way: @@ -298,13 +354,15 @@ def get_scheme( matched_target = find_matched_target( layer_name=layer_name, module=layer, - targets=self.target_scheme_map.keys()) + targets=self.target_scheme_map.keys(), + ) # Find the quant_scheme scheme_dict = self.target_scheme_map[matched_target] scheme = self._get_scheme_from_parts( weight_quant=scheme_dict["weights"], - input_quant=scheme_dict["input_activations"]) + input_quant=scheme_dict["input_activations"], + ) # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) @@ -315,20 +373,24 @@ def get_scheme( class CompressedTensorsLinearMethod(LinearMethodBase): - def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): """ - Use the CompressedTensorsScheme associated with each layer to create + Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. 
See LinearMethodBase for param details """ @@ -343,17 +405,20 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes=output_partition_sizes, output_size=output_size, params_dtype=params_dtype, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.scheme = scheme - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None): + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): """ - Use the output of create_weights and the CompressedTensorsScheme - associated with the layer to apply the forward pass with the + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the layer input. See LinearMethodBase for param details """ @@ -391,18 +456,21 @@ def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]): raise NotImplementedError( "Currently supported kv cache quantization is " "num_bits=8, type=float, however " - f"received num_bits={num_bits}, type={type_}") + f"received num_bits={num_bits}, type={type_}" + ) strategy = kv_cache_scheme.get("strategy") if strategy != "tensor": raise NotImplementedError( "Only support per-tensor scaling factor " "for compressed-tensors KV cache. " - f"Expected strategy: tensor, found strategy: {strategy}") + f"Expected strategy: tensor, found strategy: {strategy}" + ) is_symmetric = kv_cache_scheme.get("symmetric") if not is_symmetric: raise NotImplementedError( "Only support symmetric scaling factor " "for compressed-tensors KV cache. " - f"However found symmetric: {is_symmetric}") + f"However found symmetric: {is_symmetric}" + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 9bcc155c5072..6dfb2a59f851 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -4,25 +4,30 @@ from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - QuantizationStrategy) + QuantizationStrategy, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, create_per_channel_scale_param, - create_per_tensor_scale_param, cutlass_fp8_supported, - requantize_with_max_scale) + apply_fp8_linear, + create_per_channel_scale_param, + create_per_tensor_scale_param, + cutlass_fp8_supported, + requantize_with_max_scale, +) from vllm.model_executor.utils import set_weight_attrs __all__ = ["CompressedTensorsW8A8Fp8"] class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = cutlass_fp8_supported() \ - if torch.cuda.is_available() else False + self.cutlass_fp8_supported = ( + cutlass_fp8_supported() if torch.cuda.is_available() else False + ) @classmethod def get_min_capability(cls) -> int: @@ -53,53 +58,69 @@ def process_weights_after_loading(self, layer) -> None: # INPUT SCALE if self.is_static_input_scheme: - layer.input_scale = 
Parameter(layer.input_scale.max(), - requires_grad=False) + layer.input_scale = Parameter( + layer.input_scale.max(), requires_grad=False + ) else: layer.input_scale = None - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes # WEIGHT - weight = torch.nn.Parameter(torch.empty(output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - requires_grad=False) + weight = torch.nn.Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - "weight_loader": weight_loader, - }) + set_weight_attrs( + weight, + { + "input_dim": 1, + "output_dim": 0, + "weight_loader": weight_loader, + }, + ) # WEIGHT SCALE layer_kwargs = {"weight_loader": weight_loader} if self.strategy == QuantizationStrategy.CHANNEL: weight_scale = create_per_channel_scale_param( - output_partition_sizes, **layer_kwargs) + output_partition_sizes, **layer_kwargs + ) else: assert self.strategy == QuantizationStrategy.TENSOR weight_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs) + output_partition_sizes, **layer_kwargs + ) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: input_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs) + output_partition_sizes, **layer_kwargs + ) layer.register_parameter("input_scale", input_scale) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return apply_fp8_linear( input=x, weight=layer.weight, @@ -107,4 +128,5 @@ def apply_weights(self, input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, - use_per_token_if_dynamic=True) + use_per_token_if_dynamic=True, + ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9682b14d968f..d43653bf827f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,25 +7,39 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) + apply_fp8_marlin_linear, + 
prepare_fp8_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + is_layer_skipped, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, apply_fp8_linear, convert_to_channelwise, - create_per_tensor_scale_param, cutlass_fp8_supported, - per_tensor_dequantize, requantize_with_max_scale) + all_close_1d, + apply_fp8_linear, + convert_to_channelwise, + create_per_tensor_scale_param, + cutlass_fp8_supported, + per_tensor_dequantize, + requantize_with_max_scale, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once if current_platform.is_hpu(): from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -44,11 +58,14 @@ def __init__( ) -> None: self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: - logger.warning("Detected fp8 checkpoint. Please note that the " - "format is experimental and subject to change.") + logger.warning( + "Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change." + ) if activation_scheme not in ACTIVATION_SCHEMES: raise ValueError( - f"Unsupported activation scheme {activation_scheme}") + f"Unsupported activation scheme {activation_scheme}" + ) self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] @@ -71,15 +88,18 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) - is_checkpoint_fp8_serialized = ("fp8" in quant_method) + is_checkpoint_fp8_serialized = "fp8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) - return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, - activation_scheme=activation_scheme, - ignored_layers=ignored_layers) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): @@ -148,39 +168,51 @@ def create_weights( layer.orig_dtype = params_dtype # WEIGHT - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized else - params_dtype) - weight = Parameter(torch.empty(output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype), - requires_grad=False) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype, + ), + requires_grad=False, + ) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - **extra_weight_attrs, - "input_dim": 1, - "output_dim": 0, - }) + set_weight_attrs( + weight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) # If checkpoint is serialized fp8, load them. # Otherwise, wait until process_weights_after_loading. 
if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - scale = create_per_tensor_scale_param(output_partition_sizes, - **extra_weight_attrs) + scale = create_per_tensor_scale_param( + output_partition_sizes, **extra_weight_attrs + ) layer.register_parameter("weight_scale", scale) # INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static": - scale = create_per_tensor_scale_param(output_partition_sizes, - **extra_weight_attrs) + scale = create_per_tensor_scale_param( + output_partition_sizes, **extra_weight_attrs + ) layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, - scale=None) + qweight, weight_scale = ops.scaled_fp8_quant( + layer.weight, scale=None + ) # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) @@ -194,8 +226,9 @@ def process_weights_after_loading(self, layer: Module) -> None: # so extend the weight scales to be channelwise. if self.use_marlin: weight = layer.weight - weight_scale = convert_to_channelwise(layer.weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise( + layer.weight_scale, layer.logical_widths + ) # If using w8a8, torch._scaled_mm needs per tensor, so # requantize the logical shards as a single weight. @@ -211,8 +244,9 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) if self.quant_config.activation_scheme == "static": - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + layer.input_scale = Parameter( + layer.input_scale.max(), requires_grad=False + ) else: layer.input_scale = None @@ -221,11 +255,12 @@ def process_weights_after_loading(self, layer: Module) -> None: # Activations not quantized for marlin. 
del layer.input_scale - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.use_marlin: return apply_fp8_marlin_linear( input=x, @@ -234,7 +269,8 @@ def apply(self, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - bias=bias) + bias=bias, + ) return apply_fp8_linear( input=x, @@ -243,7 +279,8 @@ def apply(self, input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, - use_per_token_if_dynamic=False) + use_per_token_if_dynamic=False, + ) class Fp8MoEMethod(FusedMoEMethodBase): @@ -262,42 +299,51 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, - intermediate_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. - w13_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) + w13_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_scale", w13_scale) - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_scale", w2_scale) # If loading fp8 checkpoint, pass the weight loaders. @@ -312,17 +358,20 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, if not self.quant_config.is_checkpoint_fp8_serialized: raise ValueError( "Found static activation scheme for checkpoint that " - "was not serialized fp8.") + "was not serialized fp8." 
+ ) - a13_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) + a13_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("a13_scale", a13_scale) set_weight_attrs(a13_scale, extra_weight_attrs) - a2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) + a2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("a2_scale", a2_scale) set_weight_attrs(a2_scale, extra_weight_attrs) else: @@ -330,32 +379,36 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, layer.a2_scale = None def process_weights_after_loading(self, layer: Module) -> None: - # If checkpoint is fp16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like(layer.w13_weight.data, - dtype=torch.float8_e4m3fn) - w2_weight = torch.empty_like(layer.w2_weight.data, - dtype=torch.float8_e4m3fn) + w13_weight = torch.empty_like( + layer.w13_weight.data, dtype=torch.float8_e4m3fn + ) + w2_weight = torch.empty_like( + layer.w2_weight.data, dtype=torch.float8_e4m3fn + ) # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter(torch.ones( - layer.num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) + layer.w13_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[ - expert] = ops.scaled_fp8_quant( - layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_scale[ - expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :]) - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) + w13_weight[expert, :, :], layer.w13_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter( + w13_weight, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) return # If checkpoint is fp8, we need to handle that the @@ -368,17 +421,22 @@ def process_weights_after_loading(self, layer: Module) -> None: if layer.a13_scale is None or layer.a2_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.a13_scale) - or not all_close_1d(layer.a2_scale)): + "activation scales are None." + ) + if not all_close_1d(layer.a13_scale) or not all_close_1d( + layer.a2_scale + ): print_warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") - layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), - requires_grad=False) - layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), - requires_grad=False) + "for each layer. " + ) + layer.a13_scale = torch.nn.Parameter( + layer.a13_scale.max(), requires_grad=False + ) + layer.a2_scale = torch.nn.Parameter( + layer.a2_scale.max(), requires_grad=False + ) # Fp8 moe kernel needs single weight scale for w13 per expert. 
# We take the max then dequant and requant each expert. @@ -389,44 +447,56 @@ def process_weights_after_loading(self, layer: Module) -> None: start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) + layer.w13_weight[expert_id][ + start : start + shard_size, : + ], + layer.w13_scale[expert_id][shard_id], + ) + ( + layer.w13_weight[expert_id][ + start : start + shard_size, : + ], + _, + ) = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id] + ) start += shard_size - layer.w13_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) return - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_moe - return fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 5cd998d7326e..c5c1da179166 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -11,8 +11,10 @@ import habana_frameworks.torch.utils.experimental as htexp from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant + def cutlass_fp8_supported() -> bool: capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] @@ -21,15 +23,15 @@ def cutlass_fp8_supported() -> bool: def per_tensor_dequantize( - tensor: torch.Tensor, inv_scale: Union[float, - torch.Tensor]) -> torch.Tensor: + tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] +) -> torch.Tensor: dtype = torch.float16 device = tensor.device if current_platform.is_hpu(): dtype = torch.bfloat16 if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: - #dequant on cpu to avoid nan on gaudi2 - tensor = tensor.to('cpu') + # dequant on cpu to avoid nan on gaudi2 + tensor = tensor.to("cpu") fake_qweight = tensor.to(dtype).to(device) dq_weight = fake_qweight * inv_scale @@ -45,34 +47,38 @@ def create_per_tensor_scale_param( 
output_partition_sizes: List[int], **extra_weight_attrs, ) -> Parameter: - scale = Parameter(torch.empty(len(output_partition_sizes), - dtype=torch.float32), - requires_grad=False) + scale = Parameter( + torch.empty(len(output_partition_sizes), dtype=torch.float32), + requires_grad=False, + ) scale[:] = torch.finfo(torch.float32).min - set_weight_attrs(scale, { - "needs_scalar_to_array": True, - **extra_weight_attrs - }) + set_weight_attrs( + scale, {"needs_scalar_to_array": True, **extra_weight_attrs} + ) return scale -def create_per_channel_scale_param(output_partition_sizes: List[int], - **extra_weight_attrs) -> Parameter: - scale = Parameter(torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), - requires_grad=False) +def create_per_channel_scale_param( + output_partition_sizes: List[int], **extra_weight_attrs +) -> Parameter: + scale = Parameter( + torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + requires_grad=False, + ) scale[:] = torch.finfo(torch.float32).min set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs}) return scale def convert_to_channelwise( - weight_scale: torch.Tensor, - logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + weight_scale: torch.Tensor, logical_widths: List[int] +) -> Tuple[torch.Tensor, torch.Tensor]: # Create channelwise buffer - weight_scale_channel = torch.empty((sum(logical_widths), 1), - dtype=torch.float32, - device=weight_scale.device) + weight_scale_channel = torch.empty( + (sum(logical_widths), 1), + dtype=torch.float32, + device=weight_scale.device, + ) # Expand each scale to match the size of each logical matrix. start = 0 @@ -85,32 +91,39 @@ def convert_to_channelwise( def requantize_with_max_scale( - weight: torch.Tensor, weight_scale: torch.Tensor, - logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int] +) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. max_w_scale = weight_scale.max() - if current_platform.is_hpu() and \ - htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: - max_w_scale = max_w_scale * \ - (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max) + if ( + current_platform.is_hpu() + and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2 + ): + max_w_scale = max_w_scale * ( + torch.finfo(torch.float8_e4m3fn).max + / torch.finfo(torch.float8_e4m3fnuz).max + ) # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize # N weight scales for N shards but we only load 1 weight scale # from disk in this case. Skip requantization in this case (since) # we already are quantized with the single scale. # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 - unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo( - torch.float8_e4m3fn).min) + unfused_module_in_checkpoint = ( + weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min + ) # If unfused checkpoint, need requanize with the single scale. 
if unfused_module_in_checkpoint: start = 0 for idx, logical_width in enumerate(logical_widths): end = start + logical_width - weight_dq = per_tensor_dequantize(weight[start:end, :], - weight_scale[idx]) + weight_dq = per_tensor_dequantize( + weight[start:end, :], weight_scale[idx] + ) weight[start:end, :], _ = ops.scaled_fp8_quant( - weight_dq, max_w_scale) + weight_dq, max_w_scale + ) start = end return max_w_scale, weight @@ -136,15 +149,18 @@ def apply_fp8_linear( input, input_scale, scale_ub=input_scale_ub, - use_per_token_if_dynamic=use_per_token_if_dynamic) + use_per_token_if_dynamic=use_per_token_if_dynamic, + ) # Fused GEMM_DQ - return ops.cutlass_scaled_mm(qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + return ops.cutlass_scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) # torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token @@ -156,26 +172,37 @@ def apply_fp8_linear( input, input_scale, batch_dim_padding=17, - use_per_token_if_dynamic=use_per_token_if_dynamic) + use_per_token_if_dynamic=use_per_token_if_dynamic, + ) - per_tensor_weights = (weight_scale.numel() == 1) - per_tensor_activations = (x_scale.numel() == 1) + per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = x_scale.numel() == 1 if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ if current_platform.is_hpu(): - #hpu does not support torch._scaled_mm (SW-197036) - output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight, - False, None, input.dtype, - x_scale, weight_scale, None, - False) + # hpu does not support torch._scaled_mm (SW-197036) + output = torch.ops.hpu.fp8_gemm_v2( + qinput, + False, + weight, + False, + None, + input.dtype, + x_scale, + weight_scale, + None, + False, + ) else: - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + output, _ = torch._scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) return torch.narrow(output, 0, 0, input.shape[0]) else: @@ -197,9 +224,9 @@ def apply_fp8_linear( # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=torch.float32) + output, _ = torch._scaled_mm( + qinput, weight, out_dtype=torch.float32 + ) # Unpad (undo batch_dim_padding) output = torch.narrow(output, 0, 0, input.shape[0]) @@ -223,9 +250,11 @@ def apply_int8_linear( # * static, layer.input_scale is scalar and x_scale is input_scale. x_q, x_scale = ops.scaled_int8_quant(input, input_scale) - return ops.cutlass_scaled_mm(x_q, - weight, - scale_a=x_scale, - scale_b=weight_scale, - out_dtype=input.dtype, - bias=bias) + return ops.cutlass_scaled_mm( + x_q, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=input.dtype, + bias=bias, + ) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index eda04bc1f120..d6f99e1746f2 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only LLaMA model compatible with HuggingFace weights.""" + from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch @@ -29,24 +30,37 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, +) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - get_compressed_tensors_cache_scale) + get_compressed_tensors_cache_scale, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, + kv_cache_scales_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput @@ -60,7 +74,6 @@ class LlamaMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -76,15 +89,20 @@ def __init__( output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -95,7 +113,6 @@ def forward(self, x): class LlamaAttention(nn.Module): - def __init__( self, config: LlamaConfig, @@ -127,8 +144,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -159,12 +177,14 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) def forward( self, @@ -182,7 +202,6 @@ def forward( class LlamaDecoderLayer(nn.Module): - def __init__( self, config: LlamaConfig, @@ -195,21 +214,26 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr( + config, "max_position_embeddings", 8192 + ) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = LlamaAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -226,10 +250,12 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -245,7 +271,8 @@ def forward( hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -255,13 +282,13 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) return hidden_states, residual class LlamaModel(nn.Module): - def __init__( self, config: LlamaConfig, @@ -273,12 +300,16 @@ def __init__( super().__init__() 
self.config = config self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -288,11 +319,14 @@ def __init__( self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: LlamaDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") + lambda prefix: LlamaDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -323,6 +357,7 @@ def forward( if current_platform.is_hpu(): import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): layer = self.layers[i] @@ -337,10 +372,9 @@ def forward( htorch.core.mark_step() if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -361,8 +395,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", ] embedding_modules = { "embed_tokens": "input_embeddings", @@ -390,11 +428,13 @@ def __init__( self.config = config self.lora_config = lora_config - self.model = LlamaModel(config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model") + self.model = LlamaModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -406,16 +446,17 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) self.sampler = Sampler() else: self.lm_head = PPMissingLayer() @@ -427,17 +468,24 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - input_embeds: Optional[torch.Tensor] = None + input_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, 
IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - input_embeds) + model_output = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + input_embeds, + ) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = self.logits_processor( + self.lm_head, hidden_states, sampling_metadata + ) return logits def sample( @@ -449,18 +497,22 @@ def sample( return next_tokens def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + } + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -475,20 +527,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if ( + "rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name + ): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue if scale_name := get_compressed_tensors_cache_scale(name): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -517,8 +572,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) if current_platform.is_hpu(): torch.hpu.synchronize() @@ -531,9 +587,12 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, tp_rank, tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type): + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): if not isinstance(self.model.layers[layer_idx], nn.Identity): layer_self_attn = self.model.layers[layer_idx].self_attn @@ -546,5 +605,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if hasattr(layer_self_attn, "kv_scale"): layer_self_attn.attn._kv_scale = scaling_factor else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") + raise RuntimeError( + "Self attention has no KV cache scaling " + "factor attribute!" 
+ ) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index b0b9114ac2d0..8ca8c50d025f 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -13,17 +13,36 @@ import os import time from enum import IntEnum -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, - Optional, Set, Tuple, Type, TypeVar, Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) import habana_frameworks.torch as htorch import habana_frameworks.torch.internal.bridge_config as bc import torch from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + MultiModalConfig, + ParallelConfig, + SchedulerConfig, +) from vllm.distributed.parallel_state import get_world_group from vllm.hpu.ops import LoraMask as LoraMask from vllm.logger import init_logger @@ -33,16 +52,26 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData, - SequenceGroupMetadata) -from vllm.utils import (HabanaMemoryProfiler, format_bytes, - is_pin_memory_available, make_tensor_with_pad) +from vllm.sequence import ( + IntermediateTensors, + SamplerOutput, + SequenceData, + SequenceGroupMetadata, +) +from vllm.utils import ( + HabanaMemoryProfiler, + format_bytes, + is_pin_memory_available, + make_tensor_with_pad, +) from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerBase, + ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) + _init_sampling_metadata_from_tensor_dict, +) from .profiler import Profiler @@ -60,10 +89,12 @@ LORA_WARMUP_RANK = 8 -def subtuple(obj: object, - typename: str, - to_copy: List[str], - to_override: Optional[Dict[str, object]] = None): +def subtuple( + obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None, +): if obj is None: return None if to_override is None: @@ -71,8 +102,9 @@ def subtuple(obj: object, fields = set(to_copy) | set(to_override.keys()) values = {f: to_override.get(f, getattr(obj, f)) for f in fields} if typename not in _TYPE_CACHE: - _TYPE_CACHE[typename] = collections.namedtuple(typename, - ' '.join(fields)) + _TYPE_CACHE[typename] = collections.namedtuple( + typename, " ".join(fields) + ) return _TYPE_CACHE[typename](**values) @@ -84,14 +116,14 @@ def read_bucket_settings(phase: str, dim: str, **defaults): param is either 'min', 'step' or 'max' example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 """ - params = ['min', 'step', 'max'] - env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] + params = ["min", "step", "max"] + env_vars = [f"VLLM_{phase}_{dim}_BUCKET_{p}".upper() for p in params] default_values = [defaults[p] for p in params] values = [ int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) ] for e, v, d in zip(env_vars, values, defaults): - logger.info('%s=%s (default:%s)', e, v, d) + logger.info("%s=%s (default:%s)", e, v, d) return values 
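A minimal sketch of how the bucketing configuration above resolves, assuming the env-variable scheme of read_bucket_settings and the ramp-up logic documented in warmup_range below; the helper names and the default values used here are illustrative only, not part of the patch:

    import itertools
    import operator
    import os

    def resolve_bucket_config(phase: str, dim: str, **defaults):
        # Mirrors read_bucket_settings: VLLM_{PHASE}_{DIM}_BUCKET_{MIN,STEP,MAX}
        # override the given defaults, e.g. VLLM_DECODE_BS_BUCKET_STEP=128.
        params = ["min", "step", "max"]
        env_vars = [f"VLLM_{phase}_{dim}_BUCKET_{p}".upper() for p in params]
        return [int(os.environ.get(e, defaults[p]))
                for e, p in zip(env_vars, params)]

    def sketch_warmup_range(bmin: int, bstep: int, bmax: int):
        # Double from bmin until reaching bstep, then step by bstep up to bmax,
        # as in the warmup_range docstring: (2, 32, 64) -> [2, 4, 8, 16, 32, 64].
        ramp_up = itertools.takewhile(
            lambda x: x < bstep and x <= bmax,
            itertools.accumulate(itertools.repeat(2),
                                 func=operator.mul,
                                 initial=bmin))
        stable = range(bstep, bmax + 1, bstep)
        return [b for b in list(ramp_up) + list(stable) if b >= bmin]

    # Example: decode batch-size buckets with assumed defaults min=1, step=32, max=256.
    bmin, bstep, bmax = resolve_bucket_config("decode", "bs",
                                              min=1, step=32, max=256)
    assert sketch_warmup_range(2, 32, 64) == [2, 4, 8, 16, 32, 64]

The buckets produced this way are what the warmup code below iterates over, so widening min/step/max directly increases the number of graph shapes warmed up.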
@@ -99,7 +131,7 @@ def warmup_range(config: Tuple[int, int, int]): """Generate a warmup range. Start from bmin and multiply by 2 until you reach bstep. - Then, increase the values in the range by the value of bstep until you + Then, increase the values in the range by the value of bstep until you reach bmax. Example: @@ -109,29 +141,36 @@ def warmup_range(config: Tuple[int, int, int]): => return ramp_up + stable => (2, 4, 8, 16, 32, 64) """ bmin, bstep, bmax = config - assert bmin <= bmax, ("Min. batch size cannot be greater than max. " - "batch size. If you want to skip warmup, " - "set VLLM_SKIP_WARMUP=true") + assert bmin <= bmax, ( + "Min. batch size cannot be greater than max. " + "batch size. If you want to skip warmup, " + "set VLLM_SKIP_WARMUP=true" + ) base = itertools.repeat(2) ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin) - ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \ - ramp_up_acc) + ramp_up_tw = itertools.takewhile( + lambda x: x < bstep and x <= bmax, ramp_up_acc + ) stable = range(bstep, bmax + 1, bstep) buckets = list(ramp_up_tw) + list(stable) return list(filter(lambda bucket: bucket >= bmin, buckets)) -def generate_prompt_buckets(bs_bucket_config, - seq_bucket_config, - max_num_batched_tokens=None): +def generate_prompt_buckets( + bs_bucket_config, seq_bucket_config, max_num_batched_tokens=None +): buckets = list( - itertools.product(warmup_range(bs_bucket_config), - warmup_range(seq_bucket_config))) + itertools.product( + warmup_range(bs_bucket_config), warmup_range(seq_bucket_config) + ) + ) if len(buckets) == 0: - msg = ("No buckets could be captured with following config " - f"(min, step, max_warmup): " - f"bs:{bs_bucket_config}, " - f"seq:{seq_bucket_config}") + msg = ( + "No buckets could be captured with following config " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config}" + ) raise ValueError(msg) filtered_buckets = buckets @@ -140,12 +179,15 @@ def generate_prompt_buckets(bs_bucket_config, filtered_buckets = list( filter( lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, - buckets)) + buckets, + ) + ) if len(filtered_buckets) == 0: # we can handle this if we ignore max_num_batched_tokens - min_bucket_bs, min_bucket_seq = min(buckets, - key=lambda b: (b[0] * b[1])) + min_bucket_bs, min_bucket_seq = min( + buckets, key=lambda b: (b[0] * b[1]) + ) min_reqd_budget = min_bucket_bs * min_bucket_seq msg = ( "The current bucketing configuration " @@ -156,20 +198,23 @@ def generate_prompt_buckets(bs_bucket_config, f"smallest bucket ({min_reqd_budget}) would exceed token " "budget. Please increase max_num_batched_tokens or decrease " "bucket minimum Ignoring max_num_batched_tokens at risk of " - "out-of-memory errors.") + "out-of-memory errors." 
+ ) logger.error(msg) return list( - sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), [] + sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])) + ), [] captured_buckets = list( - sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])) + ) omitted_buckets = list( - sorted([x for x in buckets if x not in filtered_buckets])) + sorted([x for x in buckets if x not in filtered_buckets]) + ) return captured_buckets, omitted_buckets -def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, - max_blocks): +def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, max_blocks): buckets = [] for bs in warmup_range(bs_bucket_config): for blocks in warmup_range(blocks_bucket_config): @@ -205,28 +250,31 @@ def align_workers(value, op): world_size = torch.distributed.get_world_size() if world_size <= 1: return value - value_t = torch.tensor(value, device='cpu') + value_t = torch.tensor(value, device="cpu") torch.distributed.all_reduce(value_t, op=op, group=group) return value_t.item() def setup_profiler(): schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) - DEVICE = 'hpu' + DEVICE = "hpu" activities = [torch.profiler.ProfilerActivity.CPU] - activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == - 'hpu' else []) - #from habana_frameworks.torch.activity_profiler import DebugActivity - #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] + activities.extend( + [torch.profiler.ProfilerActivity.HPU] if DEVICE == "hpu" else [] + ) + # from habana_frameworks.torch.activity_profiler import DebugActivity + # debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] profiler = torch.profiler.profile( schedule=schedule, activities=activities, - #debug_activities=debug_activities, - on_trace_ready=torch.profiler.tensorboard_trace_handler('.', - use_gzip=True), + # debug_activities=debug_activities, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + ".", use_gzip=True + ), record_shapes=False, - with_stack=True) + with_stack=True, + ) return profiler @@ -236,77 +284,91 @@ def pad_list(list, k, v): return list + [v] * padding -class HpuModelAdapter(): - +class HpuModelAdapter: def __init__(self, model, block_size, dtype, enforce_eager): self.model = model - self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', - '0').lower() in ['1', 'true'] + self.prefill_use_fusedsdpa = os.getenv( + "VLLM_PROMPT_USE_FUSEDSDPA", "0" + ).lower() in ["1", "true"] self.block_size = block_size self.dtype = dtype if not htorch.utils.internal.is_lazy() and not enforce_eager: - self.model = torch.compile(self.model, - backend='hpu_backend', - dynamic=False) + self.model = torch.compile( + self.model, backend="hpu_backend", dynamic=False + ) - def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, - dtype): + def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): prefill_metadata = attn_metadata if prefill_metadata is None or self.prefill_use_fusedsdpa: return attn_metadata seq_lens_t = prefill_metadata.seq_lens_tensor - len_mask = (torch.arange(0, seq_len, device=device, - dtype=torch.int32).view(1, seq_len).ge( - seq_lens_t.unsqueeze(-1)).view( - batch_size, 1, 1, seq_len)) - causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), - device=device, - dtype=torch.bool), - diagonal=1) + len_mask = ( + torch.arange(0, seq_len, device=device, dtype=torch.int32) + .view(1, seq_len) + .ge(seq_lens_t.unsqueeze(-1)) + 
.view(batch_size, 1, 1, seq_len) + ) + causal_mask = torch.triu( + torch.ones( + (batch_size, 1, seq_len, seq_len), + device=device, + dtype=torch.bool, + ), + diagonal=1, + ) mask = causal_mask.logical_or(len_mask) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf)) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf + ) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata def _set_block_mapping(self, metadata, batch_size, device, dtype): - mask = torch.arange(0, - self.block_size, - device=device, - dtype=torch.int32).unsqueeze(0) + mask = torch.arange( + 0, self.block_size, device=device, dtype=torch.int32 + ).unsqueeze(0) mask = mask >= metadata.block_usage.unsqueeze(-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf)) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf + ) block_mapping = torch.nn.functional.one_hot( - metadata.block_mapping.to(torch.long), - num_classes=batch_size).to(dtype) - metadata = metadata._replace(block_mapping=block_mapping, - attn_bias=attn_bias) + metadata.block_mapping.to(torch.long), num_classes=batch_size + ).to(dtype) + metadata = metadata._replace( + block_mapping=block_mapping, attn_bias=attn_bias + ) return metadata - def _update_metadata(self, attn_metadata, batch_size, seq_len, device, - dtype): + def _update_metadata( + self, attn_metadata, batch_size, seq_len, device, dtype + ): if attn_metadata.is_prompt: meta = attn_metadata - attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, - device, dtype) + attn_metadata = self._set_attn_bias( + meta, batch_size, seq_len, device, dtype + ) else: meta = attn_metadata - attn_metadata = self._set_block_mapping(meta, batch_size, device, - dtype) + attn_metadata = self._set_block_mapping( + meta, batch_size, device, dtype + ) return attn_metadata def forward(self, *args, **kwargs): kwargs = kwargs.copy() - selected_token_indices = kwargs.pop('selected_token_indices') - if 'warmup_mode' in kwargs: - kwargs.pop('warmup_mode') - input_ids = kwargs['input_ids'] - kwargs['attn_metadata'] = self._update_metadata( - kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), - input_ids.device, self.dtype) - LoraMask.setLoraMask(kwargs.pop('lora_mask')) + selected_token_indices = kwargs.pop("selected_token_indices") + if "warmup_mode" in kwargs: + kwargs.pop("warmup_mode") + input_ids = kwargs["input_ids"] + kwargs["attn_metadata"] = self._update_metadata( + kwargs["attn_metadata"], + input_ids.size(0), + input_ids.size(1), + input_ids.device, + self.dtype, + ) + LoraMask.setLoraMask(kwargs.pop("lora_mask")) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.index_select(0, selected_token_indices) @@ -335,18 +397,20 @@ class PreparePromptMetadata(NamedTuple): @classmethod def empty(cls): - return PreparePromptMetadata(input_tokens=[], - input_positions=[], - attn_metadata=None, - seq_lens=[], - query_lens=[], - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - multi_modal_input=None, - slot_mapping=[], - lora_mask=None, - lora_logits_mask=None) + return PreparePromptMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + seq_lens=[], + query_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_input=None, + slot_mapping=[], + lora_mask=None, + lora_logits_mask=None, + ) 
class PrepareDecodeMetadata(NamedTuple): @@ -362,15 +426,17 @@ class PrepareDecodeMetadata(NamedTuple): @classmethod def empty(cls): - return PrepareDecodeMetadata(input_tokens=[], - input_positions=[], - attn_metadata=None, - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - slot_mapping=[], - lora_mask=None, - lora_logits_mask=None) + return PrepareDecodeMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + lora_mask=None, + lora_logits_mask=None, + ) # How batches are constructed. @@ -383,7 +449,7 @@ class BatchType(IntEnum): MIXED = 2 -TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU") +TModelInputForHPU = TypeVar("TModelInputForHPU", bound="ModelInputForHPU") @dataclasses.dataclass(frozen=True) @@ -394,6 +460,7 @@ class ModelInputForHPU(ModelRunnerInputBase): runners that run additional steps should subclass this method to add additional fields. """ + input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None @@ -432,7 +499,8 @@ def from_broadcasted_tensor_dict( ) -> TModelInputForHPU: if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) + attn_backend, tensor_dict + ) return cls(**tensor_dict) @@ -441,6 +509,7 @@ class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU): """ Used by the ModelRunner. """ + sampling_metadata: Optional["SamplingMetadata"] = None # Used for speculative decoding. We do not broadcast it because it is only # used by the driver worker. @@ -457,8 +526,9 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_logits_mask": self.lora_logits_mask, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) + _add_sampling_metadata_broadcastable_dict( + tensor_dict, self.sampling_metadata + ) return tensor_dict @classmethod @@ -471,7 +541,8 @@ def from_broadcasted_tensor_dict( # FIXME(kzawora): this fails for whatever reason - why? if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) + attn_backend, tensor_dict + ) return cls(**tensor_dict) @@ -479,6 +550,7 @@ class HabanaModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): """ Helper class for shared methods between GPU model runners. 
""" + _model_input_cls: Type[TModelInputForHPU] def __init__( @@ -503,17 +575,22 @@ def __init__( self.is_driver_worker = is_driver_worker self.profiler = Profiler() - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) + self.sliding_window = ( + model_config.get_sliding_window() + if model_config is not None + else None + ) + self.device_config = ( + device_config if device_config is not None else DeviceConfig() + ) self.device = self.device_config.device self.enforce_eager = self.model_config.enforce_eager self.max_num_seqs = self.scheduler_config.max_num_seqs self.max_model_len = self.scheduler_config.max_model_len - self.max_num_batched_tokens = \ + self.max_num_batched_tokens = ( self.scheduler_config.max_num_batched_tokens + ) self.block_size = cache_config.block_size self.pin_memory = is_pin_memory_available() @@ -551,17 +628,16 @@ def _set_gc_threshold(self) -> None: requested_gc_thrs = [0] * len(default_gc_thrs) for i in range(len(default_gc_thrs)): requested_gc_thrs[i] = int( - os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i])) + os.environ.get(f"VLLM_GC_THR_GEN{i}", default_gc_thrs[i]) + ) if requested_gc_thrs == default_gc_thrs: - gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER', - 2)) - requested_gc_thrs = [ - t * gc_thr_multiplier for t in default_gc_thrs - ] + gc_thr_multiplier = int(os.environ.get("VLLM_GC_THR_MULTIPLIER", 2)) + requested_gc_thrs = [t * gc_thr_multiplier for t in default_gc_thrs] gc.set_threshold(*requested_gc_thrs) def load_model(self) -> None: import habana_frameworks.torch.core as htcore + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: @@ -573,44 +649,60 @@ def load_model(self) -> None: multimodal_config=self.multimodal_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - cache_config=self.cache_config) - msg = ("Pre-loading model weights on " - f"{next(self.model.parameters()).device} " - f"took {m_getmodel.get_summary_string()}") + cache_config=self.cache_config, + ) + msg = ( + "Pre-loading model weights on " + f"{next(self.model.parameters()).device} " + f"took {m_getmodel.get_summary_string()}" + ) logger.info(msg) if self.lora_config: - assert hasattr(self.model, "supported_lora_modules" - ) and self.model.supported_lora_modules, ( - "Model does not support LoRA") - assert hasattr(self.model, "embedding_modules" - ), "Model does not have embedding_modules" + assert ( + hasattr(self.model, "supported_lora_modules") + and self.model.supported_lora_modules + ), "Model does not support LoRA" + assert hasattr( + self.model, "embedding_modules" + ), "Model does not have embedding_modules" assert hasattr( self.model, "embedding_padding_modules" ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, - self.vocab_size, self.lora_config, self.device, + self.vocab_size, + self.lora_config, + self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) + self.model.embedding_padding_modules, + ) self.model = self.lora_manager.create_lora_manager(self.model) - if self.model_config.quantization == 'inc': + if self.model_config.quantization == "inc": logger.info("Preparing model with INC..") with HabanaMemoryProfiler() as m_inc: from neural_compressor.torch.quantization import 
( - FP8Config, convert, prepare) + FP8Config, + convert, + prepare, + ) + config = FP8Config.from_json_file( - os.getenv("QUANT_CONFIG", "")) + os.getenv("QUANT_CONFIG", "") + ) if config.measure: self.model = prepare(self.model, config) elif config.quantize: self.model = convert(self.model, config) - htcore.hpu_initialize(self.model, - mark_only_scales_as_const=True) - logger.info("Preparing model with INC took %s", - m_inc.get_summary_string()) + htcore.hpu_initialize( + self.model, mark_only_scales_as_const=True + ) + logger.info( + "Preparing model with INC took %s", + m_inc.get_summary_string(), + ) else: self.model = self.model.to("hpu") htcore.mark_step() @@ -621,7 +713,8 @@ def load_model(self) -> None: self.model, self.block_size, dtype=self.model_config.dtype, - enforce_eager=self.enforce_eager) + enforce_eager=self.enforce_eager, + ) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) @@ -640,46 +733,60 @@ def _is_valid_bucket(self, bucket): def _setup_buckets(self) -> None: align_bs = lambda x: min(self.max_num_seqs, x) max_bucket_cfg = 64 - if self.lora_config and \ - max_bucket_cfg > self.max_num_batched_tokens // self.block_size: + if ( + self.lora_config + and max_bucket_cfg > self.max_num_batched_tokens // self.block_size + ): max_bucket_cfg = self.max_num_batched_tokens // self.block_size blocks_step = 128 - #FIXME: The default values should be max_model_len + # FIXME: The default values should be max_model_len max_prompt_seq = 1024 max_decode_seq = 2048 self.prompt_bs_bucket_cfg = read_bucket_settings( - 'prompt', - 'bs', + "prompt", + "bs", min=1, step=align_bs(32), - max=align_bs(max_bucket_cfg)) - self.decode_bs_bucket_cfg = read_bucket_settings('decode', - 'bs', - min=align_bs(32), - step=align_bs(32), - max=self.max_num_seqs) - self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', - 'seq', - min=self.block_size, - step=self.block_size, - max=max_prompt_seq) + max=align_bs(max_bucket_cfg), + ) + self.decode_bs_bucket_cfg = read_bucket_settings( + "decode", + "bs", + min=align_bs(32), + step=align_bs(32), + max=self.max_num_seqs, + ) + self.prompt_seq_bucket_cfg = read_bucket_settings( + "prompt", + "seq", + min=self.block_size, + step=self.block_size, + max=max_prompt_seq, + ) self.decode_block_bucket_cfg = read_bucket_settings( - 'decode', - 'block', + "decode", + "block", min=blocks_step, step=blocks_step, - max=max(blocks_step, - self.max_num_seqs * max_decode_seq // self.block_size)) + max=max( + blocks_step, + self.max_num_seqs * max_decode_seq // self.block_size, + ), + ) self.graphed_buckets: Set[Any] = set() - msg = ("Prompt bucket config (min, step, max_warmup) " - f"bs:{self.prompt_bs_bucket_cfg}, " - f"seq:{self.prompt_seq_bucket_cfg}") + msg = ( + "Prompt bucket config (min, step, max_warmup) " + f"bs:{self.prompt_bs_bucket_cfg}, " + f"seq:{self.prompt_seq_bucket_cfg}" + ) logger.info(msg) - msg = ("Decode bucket config (min, step, max_warmup) " - f"bs:{self.decode_bs_bucket_cfg}, " - f"block:{self.decode_block_bucket_cfg}") + msg = ( + "Decode bucket config (min, step, max_warmup) " + f"bs:{self.decode_bs_bucket_cfg}, " + f"block:{self.decode_block_bucket_cfg}" + ) logger.info(msg) def _prepare_prompt( @@ -709,13 +816,16 @@ def _prepare_prompt( seq_id = seq_ids[0] computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): + if ( + self.scheduler_config 
is not None + and self.scheduler_config.chunked_prefill_enabled + and not ( + computed_block_nums is None or computed_block_nums == [] + ) + ): raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") + "chunked prefill cannot be used with prefix caching " "now." + ) token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] @@ -727,8 +837,11 @@ def _prepare_prompt( seq_lens.append(seq_len) # NOTE: This only works for oooooooxxx style attention. - if computed_block_nums is not None and len( - computed_block_nums) > 0 and self.sliding_window is None: + if ( + computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + ): # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] @@ -757,7 +870,8 @@ def _prepare_prompt( if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) + seq_group_metadata.multi_modal_data.data + ) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -778,7 +892,8 @@ def _prepare_prompt( if self.sliding_window is not None: assert context_len == 0, ( "Prefix caching is currently not supported with " - "sliding window attention") + "sliding window attention" + ) start_idx = max(0, seq_len - self.sliding_window) for i in range(context_len, seq_len): if i < start_idx: @@ -798,15 +913,18 @@ def _prepare_prompt( if multi_modal_input_list: assert self.multimodal_config, ( "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) + "vision language models." 
+ ) + multi_modal_input = torch.cat(multi_modal_input_list, dim=0).to( + self.device + ) else: multi_modal_input = None max_prompt_len = max( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), - self.block_size) + self.block_size, + ) lora_mask: torch.Tensor = None lora_logits_mask: torch.Tensor = None @@ -815,20 +933,27 @@ def _prepare_prompt( lora_mask = torch.zeros( len(seq_group_metadata_list) * max_prompt_len, (self.lora_config.max_loras) * self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - lora_logits_mask = torch.zeros(len(seq_group_metadata_list), - (self.lora_config.max_loras) * - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - - ones = torch.ones(max_prompt_len, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - logit_ones = torch.ones(1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - for seq_group_metadata, context_len in zip(seq_group_metadata_list, - context_lens): + dtype=self.lora_config.lora_dtype, + ) + lora_logits_mask = torch.zeros( + len(seq_group_metadata_list), + (self.lora_config.max_loras) * self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype, + ) + + ones = torch.ones( + max_prompt_len, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype, + ) + logit_ones = torch.ones( + 1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype, + ) + for seq_group_metadata, context_len in zip( + seq_group_metadata_list, context_lens + ): lora_id = seq_group_metadata.lora_int_id if lora_id > 0: @@ -843,35 +968,45 @@ def _prepare_prompt( lora_index_mapping += [lora_id] * (max_prompt_len - context_len) lora_prompt_mapping.extend( - [lora_id] * - (max_prompt_len - context_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + [lora_id] + * ( + max_prompt_len - context_len + if seq_group_metadata.sampling_params.prompt_logprobs + else 1 + ) + ) if lora_mask is not None: - lora_mask = lora_mask.to('hpu') - lora_logits_mask = lora_logits_mask.to('hpu') - - input_tokens = make_tensor_with_pad(input_tokens, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - - input_positions = make_tensor_with_pad(input_positions, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - - slot_mapping = make_tensor_with_pad(slot_mapping, - max_len=max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.long, - device=self.device) + lora_mask = lora_mask.to("hpu") + lora_logits_mask = lora_logits_mask.to("hpu") + + input_tokens = make_tensor_with_pad( + input_tokens, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device, + ) + + input_positions = make_tensor_with_pad( + input_positions, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device, + ) + + slot_mapping = make_tensor_with_pad( + slot_mapping, + max_len=max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device, + ) + + seq_lens_tensor = torch.tensor( + seq_lens, dtype=torch.long, device=self.device + ) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -920,16 +1055,20 @@ def _prepare_decode( counter = 0 if self.lora_config: - lora_mask = torch.zeros(len(seq_group_metadata_list), - (self.lora_config.max_loras) * - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - ones = torch.ones(1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) + lora_mask 
= torch.zeros( + len(seq_group_metadata_list), + (self.lora_config.max_loras) * self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype, + ) + ones = torch.ones( + 1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype, + ) dummy_slots = itertools.cycle( - range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) + range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size) + ) for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -954,8 +1093,11 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) + seq_len = ( + seq_len + if self.sliding_window is None + else min(seq_len, self.sliding_window) + ) seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] @@ -970,20 +1112,21 @@ def _prepare_decode( lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) + sliding_window_blocks = ( + self.sliding_window // self.block_size + ) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) if lora_mask is not None: - lora_mask = lora_mask.to('hpu') + lora_mask = lora_mask.to("hpu") lora_logits_mask = lora_mask - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) + input_tokens = torch.tensor( + input_tokens, dtype=torch.long, device=self.device + ) + input_positions = torch.tensor( + input_positions, dtype=torch.long, device=self.device + ) num_decode_tokens = sum(seq_lens) @@ -993,34 +1136,38 @@ def _prepare_decode( [i] * b_u for i, b_u in enumerate(blocks_used) ] block_mapping: List[int] = list( - itertools.chain.from_iterable(block_mapping_nested)) + itertools.chain.from_iterable(block_mapping_nested) + ) last_block = [ sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) ] - block_usage = [[self.block_size] * (b_u - 1) + [lb] - for b_u, lb in zip(blocks_used, last_block)] + block_usage = [ + [self.block_size] * (b_u - 1) + [lb] + for b_u, lb in zip(blocks_used, last_block) + ] block_usage = list(itertools.chain(*block_usage)) - block_bucket_size = find_bucket(len(block_list), - self.decode_block_bucket_cfg) + block_bucket_size = find_bucket( + len(block_list), self.decode_block_bucket_cfg + ) block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) block_mapping = pad_list(block_mapping, block_bucket_size, 0) block_usage = pad_list(block_usage, block_bucket_size, 0) - block_list = torch.tensor(block_list, - dtype=torch.int, - device=self.device) - block_mapping = torch.tensor(block_mapping, - dtype=torch.int, - device=self.device) - block_usage = torch.tensor(block_usage, - dtype=torch.bfloat16, - device=self.device) + block_list = torch.tensor( + block_list, dtype=torch.int, device=self.device + ) + block_mapping = torch.tensor( + block_mapping, dtype=torch.int, device=self.device + ) + block_usage = torch.tensor( + block_usage, dtype=torch.bfloat16, device=self.device + ) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) + slot_mapping = torch.tensor( + slot_mapping, dtype=torch.long, device=self.device + ) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, @@ -1066,17 +1213,21 @@ def prepare_input_tensors( self.event_start = self.profiler.get_timestamp_us() is_prompt = 
seq_group_metadata_list[0].is_prompt - base_event_name = 'prompt' if is_prompt else 'decode' - self.profiler.start('internal', base_event_name) + base_event_name = "prompt" if is_prompt else "decode" + self.profiler.start("internal", base_event_name) real_batch_size = len(seq_group_metadata_list) - bucket_cfg = self.prompt_bs_bucket_cfg if is_prompt else \ - self.decode_bs_bucket_cfg + bucket_cfg = ( + self.prompt_bs_bucket_cfg + if is_prompt + else self.decode_bs_bucket_cfg + ) batch_size_padded = find_bucket(real_batch_size, bucket_cfg) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() - seq_group_metadata_list.extend(seq_group_metadata_list[0] - for _ in range(batch_size_padding)) + seq_group_metadata_list.extend( + seq_group_metadata_list[0] for _ in range(batch_size_padding) + ) prefill_reqs = [] decode_reqs = [] @@ -1112,10 +1263,13 @@ def prepare_input_tensors( decode_lora_mask, decode_lora_logits_mask, ) = self._prepare_decode(decode_reqs) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - seq_lens, query_lens, - self.device, - self.pin_memory) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens, + self.device, + self.pin_memory, + ) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 @@ -1128,8 +1282,8 @@ def prepare_input_tensors( # support mixed batches, so we either use decode or prefill # inputs, without coalescing. assert (num_prefills == 0 and num_decode_tokens > 0) or ( - num_prefills > 0 - and num_decode_tokens == 0), "HPU does not support mixed batches!" + num_prefills > 0 and num_decode_tokens == 0 + ), "HPU does not support mixed batches!" if num_decode_tokens > 0: input_tokens = decode_input_tokens input_positions = decode_input_positions @@ -1148,13 +1302,16 @@ def prepare_input_tensors( paddings = list(itertools.accumulate(paddings)) paddings_prompt_logprobs = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): - if seq_group_metadata.sampling_params.prompt_logprobs is not None \ - and seq_group_metadata.is_prompt: - paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) + if ( + seq_group_metadata.sampling_params.prompt_logprobs is not None + and seq_group_metadata.is_prompt + ): + paddings_prompt_logprobs += [paddings[i]] * seq_lens[i] paddings = torch.tensor( paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, dtype=sampling_metadata.selected_token_indices.dtype, - device=sampling_metadata.selected_token_indices.device) + device=sampling_metadata.selected_token_indices.device, + ) sampling_metadata.selected_token_indices.add_(paddings) if self.lora_config: @@ -1165,8 +1322,10 @@ def prepare_input_tensors( else: lora_mapping = None - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): + if ( + prefill_attn_metadata is not None + and decode_attn_metadata is not None + ): batch_type = BatchType.MIXED raise NotImplementedError("Mixed batch is not supported on HPU") elif prefill_attn_metadata is not None: @@ -1187,7 +1346,7 @@ def prepare_input_tensors( "num_prefills": num_prefills, "batch_type": batch_type, "seq_lens": seq_lens, - "query_lens": query_lens + "query_lens": query_lens, } if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) @@ -1195,22 +1354,26 @@ def prepare_input_tensors( assert decode_attn_metadata is not None 
metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) - attn_metadata = prefill_attn_metadata if \ - prefill_attn_metadata is not None else decode_attn_metadata - - return self._model_input_cls(input_tokens=input_tokens, - seq_lens=seq_lens, - query_lens=query_lens, - input_positions=input_positions, - attn_metadata=attn_metadata, - lora_requests=lora_requests, - lora_mapping=lora_mapping, - multi_modal_kwargs=multi_modal_input, - real_batch_size=real_batch_size, - batch_size_padded=batch_size_padded, - lora_mask=lora_mask, - lora_logits_mask=lora_logits_mask), \ - sampling_metadata + attn_metadata = ( + prefill_attn_metadata + if prefill_attn_metadata is not None + else decode_attn_metadata + ) + + return self._model_input_cls( + input_tokens=input_tokens, + seq_lens=seq_lens, + query_lens=query_lens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_requests=lora_requests, + lora_mapping=lora_mapping, + multi_modal_kwargs=multi_modal_input, + real_batch_size=real_batch_size, + batch_size_padded=batch_size_padded, + lora_mask=lora_mask, + lora_logits_mask=lora_logits_mask, + ), sampling_metadata def _seq_len(self, attn_metadata): if attn_metadata.num_prefills != 0: @@ -1239,17 +1402,24 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") - attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ - 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', - 'block_usage', 'slot_mapping', 'is_prompt' - ]) + attention_metadata = subtuple( + metadata, + "TrimmedAttentionMetadata", + [ + "attn_bias", + "seq_lens_tensor", + "block_list", + "block_mapping", + "block_usage", + "slot_mapping", + "is_prompt", + ], + ) return attention_metadata - def create_dummy_seq_group_metadata(self, - group_id, - seq_len, - is_prompt, - lora_request=None): + def create_dummy_seq_group_metadata( + self, group_id, seq_len, is_prompt, lora_request=None + ): sampling_params = SamplingParams(temperature=0) num_blocks = math.ceil(seq_len / self.block_size) if is_prompt: @@ -1264,35 +1434,38 @@ def create_dummy_seq_group_metadata(self, output_token_ids = [1] * output_len seq_data = SequenceData(prompt_token_ids) seq_data.output_token_ids = output_token_ids - return SequenceGroupMetadata(request_id=str(group_id), - is_prompt=(output_len == 0), - seq_data={group_id: seq_data}, - sampling_params=sampling_params, - block_tables=block_tables, - lora_request=lora_request) + return SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=(output_len == 0), + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + lora_request=lora_request, + ) def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers max_batch_size = self.prompt_bs_bucket_cfg[-1] - max_seq_len = min(self.prompt_seq_bucket_cfg[-1], - self.max_num_batched_tokens // max_batch_size) + max_seq_len = min( + self.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size, + ) self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches) return - def warmup_scenario(self, - batch_size, - seq_len, - is_prompt, - kv_caches, - is_profile_run=False) -> None: + def warmup_scenario( + self, batch_size, seq_len, is_prompt, kv_caches, is_profile_run=False + ) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) 
- scenario_name = ("warmup_" - f"{'prompt' if is_prompt else 'decode'}_" - f"bs{batch_size}_" - f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}") + scenario_name = ( + "warmup_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}" + ) max_num_seqs = self.scheduler_config.max_num_seqs # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory @@ -1310,14 +1483,15 @@ def warmup_scenario(self, lora_int_id=lora_id, lora_local_path="/not/a/real/path", ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) + self.lora_manager.add_dummy_lora( + dummy_lora_request, rank=LORA_WARMUP_RANK + ) dummy_lora_requests.append(dummy_lora_request) dummy_lora_requests_per_seq = [ dummy_lora_requests[idx % len(dummy_lora_requests)] for idx in range(max_num_seqs) ] - self.profiler.start('internal', scenario_name) + self.profiler.start("internal", scenario_name) times = 3 if use_graphs or is_profile_run else 1 if self.lora_config and not is_profile_run: lora_mapping = LoRAMapping( @@ -1332,7 +1506,9 @@ def warmup_scenario(self, seq_len, is_prompt, lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None) + if dummy_lora_requests_per_seq + else None, + ) for i in range(batch_size) ] else: @@ -1345,7 +1521,9 @@ def warmup_scenario(self, b * self.block_size - 1, is_prompt, lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None) + if dummy_lora_requests_per_seq + else None, + ) for i, b in enumerate(blocks) ] torch.hpu.synchronize() @@ -1369,8 +1547,9 @@ def remove_all_loras(self): raise RuntimeError("LoRA is not enabled.") self.lora_manager.remove_all_adapters() - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: + def set_active_loras( + self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping + ) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) @@ -1397,43 +1576,56 @@ def list_loras(self) -> Set[int]: def log_warmup(self, phase, i, max_i, batch_size, seq_len): free_mem = format_bytes( - HabanaMemoryProfiler.current_free_device_memory()) + HabanaMemoryProfiler.current_free_device_memory() + ) dim = "num_blocks" if phase == "Prompt": dim = "seq_len" - msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " - f"batch_size:{batch_size} " - f"{dim}:{seq_len} " - f"free_mem:{free_mem}") + msg = ( + f"[Warmup][{phase}][{i+1}/{max_i}] " + f"batch_size:{batch_size} " + f"{dim}:{seq_len} " + f"free_mem:{free_mem}" + ) logger.info(msg) def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): - self.log_warmup('Prompt' if is_prompt else 'Decode', i, - len(buckets), batch_size, seq_len) + self.log_warmup( + "Prompt" if is_prompt else "Decode", + i, + len(buckets), + batch_size, + seq_len, + ) self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - def warmup_graphs(self, - strategy, - buckets, - is_prompt, - kv_caches, - available_mem, - starting_mem=0, - total_batch_seq=0.001): + def warmup_graphs( + self, + strategy, + buckets, + is_prompt, + kv_caches, + available_mem, + starting_mem=0, + total_batch_seq=0.001, + ): total_mem = starting_mem idx = 0 phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' num_candidates = len(buckets) - ordering : Union[Callable[[Any], Tuple[Any, 
Any]], \ - Callable[[Any], Tuple[Any, Any, Any]]] - if strategy == 'min_tokens': + ordering: Union[ + Callable[[Any], Tuple[Any, Any]], + Callable[[Any], Tuple[Any, Any, Any]], + ] + if strategy == "min_tokens": ordering = lambda b: (b[0] * b[1], b[1], b[0]) - elif strategy == 'max_bs': + elif strategy == "max_bs": ordering = lambda b: (-b[0], b[1]) else: raise NotImplementedError( - f'Unsupported graph allocation strategy: {strategy}') + f"Unsupported graph allocation strategy: {strategy}" + ) buckets = list(sorted(buckets, key=ordering)) captured_all = True for idx, (batch_size, seq_len) in enumerate(buckets): @@ -1450,8 +1642,9 @@ def warmup_graphs(self, self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - used_mem = align_workers(mem_prof.consumed_device_memory, - torch.distributed.ReduceOp.MAX) + used_mem = align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX + ) available_mem -= used_mem total_mem += used_mem total_batch_seq += batch_seq @@ -1461,73 +1654,84 @@ def warmup_graphs(self, def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): num_candidates = len(buckets) phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' - graphed = list(c[:2] for c in self.graphed_buckets - if c[2] == is_prompt) + graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt) if num_candidates == 0: num_candidates = 1 - msg = (f'{phase} captured:{len(graphed)} ' - f'({100 * len(graphed) / num_candidates:.1f}%) ' - f'used_mem:{format_bytes(total_mem)} ' - f'buckets:{sorted(list(graphed))}') + msg = ( + f"{phase} captured:{len(graphed)} " + f"({100 * len(graphed) / num_candidates:.1f}%) " + f"used_mem:{format_bytes(total_mem)} " + f"buckets:{sorted(list(graphed))}" + ) logger.info(msg) @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: - if profile := os.environ.get('VLLM_PT_PROFILE', None): - phase, bs, seq_len, graph = profile.split('_') - is_prompt = phase == 'prompt' - graphs = graph == 't' + if profile := os.environ.get("VLLM_PT_PROFILE", None): + phase, bs, seq_len, graph = profile.split("_") + is_prompt = phase == "prompt" + graphs = graph == "t" if graphs: self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) - self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, - True) + self.warmup_scenario( + int(bs), int(seq_len), is_prompt, kv_caches, True + ) raise AssertionError("Finished profiling") - if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true': + if os.environ.get("VLLM_SKIP_WARMUP", "false").lower() == "true": logger.info("Skipping warmup...") return - self.profiler.start('internal', 'warmup') + self.profiler.start("internal", "warmup") max_blocks = kv_caches[0][0].size(0) self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( - self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) + self.prompt_bs_bucket_cfg, + self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens, + ) if self.lora_config: self.prompt_buckets[:] = [ - bucket for bucket in self.prompt_buckets + bucket + for bucket in self.prompt_buckets if self._is_valid_bucket(bucket) ] msg = ( f"Generated {len(self.prompt_buckets)} " - f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}") + f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}" + ) logger.info(msg) - msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets 
due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") + msg = ( + f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})" + ) logger.info(msg) msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" logger.debug(msg) self.decode_buckets = generate_decode_buckets( - self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, - max_blocks) + self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, max_blocks + ) if self.lora_config: self.decode_buckets[:] = [ - bucket for bucket in self.decode_buckets + bucket + for bucket in self.decode_buckets if self._is_valid_bucket(bucket) ] - logger.info("Generated %d decode buckets [bs, total_blocks]: %s", - len(self.decode_buckets), - list(sorted(self.decode_buckets))) + logger.info( + "Generated %d decode buckets [bs, total_blocks]: %s", + len(self.decode_buckets), + list(sorted(self.decode_buckets)), + ) start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() - compile_only_mode_context = functools.partial(bc.env_setting, - "PT_COMPILE_ONLY_MODE", - True) + compile_only_mode_context = functools.partial( + bc.env_setting, "PT_COMPILE_ONLY_MODE", True + ) can_use_compile_only_mode = True try: with compile_only_mode_context(): @@ -1535,83 +1739,120 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.debug("Using PT_COMPILE_ONLY_MODE.") except KeyError: can_use_compile_only_mode = False - logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' - 'Warmup time will be negatively impacted. ' - 'Please update Gaudi Software Suite.') - with compile_only_mode_context( - ) if can_use_compile_only_mode else contextlib.nullcontext(): + logger.warning( + "Cannot use PT_COMPILE_ONLY_MODE. " + "Warmup time will be negatively impacted. " + "Please update Gaudi Software Suite." + ) + with compile_only_mode_context() if can_use_compile_only_mode else contextlib.nullcontext(): self.warmup_all_buckets(self.prompt_buckets, True, kv_caches) self.warmup_all_buckets(self.decode_buckets, False, kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): - assert self.mem_margin is not None, \ - ("HabanaWorker.determine_num_available_blocks needs " - "to be called before warming up the model.") + assert self.mem_margin is not None, ( + "HabanaWorker.determine_num_available_blocks needs " + "to be called before warming up the model." 
+ ) free_mem = HabanaMemoryProfiler.current_free_device_memory() graph_free_mem = free_mem - self.mem_margin - graph_free_mem = align_workers(graph_free_mem, - torch.distributed.ReduceOp.MIN) + graph_free_mem = align_workers( + graph_free_mem, torch.distributed.ReduceOp.MIN + ) prompt_graph_mem_ratio = float( - os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5')) - prompt_available_memory = (prompt_graph_mem_ratio * - graph_free_mem) - decode_available_memory = (graph_free_mem - - prompt_available_memory) + os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.5") + ) + prompt_available_memory = ( + prompt_graph_mem_ratio * graph_free_mem + ) + decode_available_memory = ( + graph_free_mem - prompt_available_memory + ) msg = ( f"Using {format_bytes(graph_free_mem)}" f"/{format_bytes(free_mem)} " "of free device memory for HPUGraphs, " f"{format_bytes(prompt_available_memory)} for prompt and " f"{format_bytes(decode_available_memory)} for decode " - f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})") + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})" + ) logger.info(msg) - prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY', - 'min_tokens') - decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY', - 'max_bs') - mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ + prompt_strategy = os.environ.get( + "VLLM_GRAPH_PROMPT_STRATEGY", "min_tokens" + ) + decode_strategy = os.environ.get( + "VLLM_GRAPH_DECODE_STRATEGY", "max_bs" + ) + mem_post_prompt, prompt_batch_seq, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, self.prompt_buckets, True, kv_caches, - prompt_available_memory) - mem_post_decode, decode_batch_seq, decode_captured_all = \ + prompt_strategy, + self.prompt_buckets, + True, + kv_caches, + prompt_available_memory, + ) + ) + mem_post_decode, decode_batch_seq, decode_captured_all = ( self.warmup_graphs( - decode_strategy, self.decode_buckets, False, kv_caches, - decode_available_memory) + decode_strategy, + self.decode_buckets, + False, + kv_caches, + decode_available_memory, + ) + ) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space # left. Let's try to use it for capturing more prompt buckets. - if (mem_post_decode + mem_post_prompt < graph_free_mem - and not prompt_captured_all and decode_captured_all): + if ( + mem_post_decode + mem_post_prompt < graph_free_mem + and not prompt_captured_all + and decode_captured_all + ): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, self.prompt_buckets, True, + prompt_strategy, + self.prompt_buckets, + True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, - mem_post_prompt, prompt_batch_seq)) + mem_post_prompt, + prompt_batch_seq, + ) + ) # Not all decode buckets were captured, but all prompt buckets # were captured and we have some free graph-allocated space # left. Let's try to use it for capturing more decode buckets. 
- if mem_post_decode + mem_post_prompt < graph_free_mem \ - and not decode_captured_all \ - and prompt_captured_all: + if ( + mem_post_decode + mem_post_prompt < graph_free_mem + and not decode_captured_all + and prompt_captured_all + ): mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, self.decode_buckets, False, kv_caches, + decode_strategy, + self.decode_buckets, + False, + kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, - mem_post_decode, decode_batch_seq) + mem_post_decode, + decode_batch_seq, + ) - self.log_graph_warmup_summary(self.prompt_buckets, True, - mem_post_prompt) - self.log_graph_warmup_summary(self.decode_buckets, False, - mem_post_decode) + self.log_graph_warmup_summary( + self.prompt_buckets, True, mem_post_prompt + ) + self.log_graph_warmup_summary( + self.decode_buckets, False, mem_post_decode + ) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage() elapsed_time = end_time - start_time msg = ( f"Warmup finished in {elapsed_time:.0f} secs, " - f"allocated {format_bytes(end_mem - start_mem)} of device memory") + f"allocated {format_bytes(end_mem - start_mem)} of device memory" + ) logger.info(msg) self.profiler.end() @@ -1629,13 +1870,16 @@ def mem_margin(self, value): def _maybe_wrap_in_hpu_graph(*args, **kwargs): - return htorch.hpu.wrap_in_hpu_graph( - HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True - ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs) - + return ( + htorch.hpu.wrap_in_hpu_graph( + HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True + ) + if htorch.utils.internal.is_lazy() + else HpuModelAdapter(*args, **kwargs) + ) -class HabanaProfilerCounterHelper(): +class HabanaProfilerCounterHelper: def __init__(self): self.niter = 0 self.average_real_throughput = None @@ -1655,8 +1899,15 @@ def capture_seq_group_metadata_stats(self, seq_group_metadata_list): for seq_data in seq_group_metadata.seq_data.values() ] - def get_counter_dict(self, cache_config, duration, seq_len, - batch_size_padded, real_batch_size, is_prompt): + def get_counter_dict( + self, + cache_config, + duration, + seq_len, + batch_size_padded, + real_batch_size, + is_prompt, + ): throughput = batch_size_padded / (duration / 1e6) throughput_effective = real_batch_size / (duration / 1e6) @@ -1668,57 +1919,66 @@ def get_counter_dict(self, cache_config, duration, seq_len, self.average_real_throughput = throughput_effective else: # https://www.heikohoffmann.de/htmlthesis/node134.html self.average_real_throughput = self.average_real_throughput + 1 / ( - self.niter + 1) * (throughput_effective - - self.average_real_throughput) + self.niter + 1 + ) * (throughput_effective - self.average_real_throughput) phase = "prompt" if is_prompt else "decode" counters = { - f'{phase}_bucket_batch_size': batch_size_padded, - f'{phase}_batch_size': real_batch_size, - f'{phase}_bucket_seq_len': seq_len, - f'{phase}_seq_len': real_max_seq_len, - f'{phase}_bucket_gen_throughput': throughput, - f'{phase}_real_gen_throughput': throughput_effective, - f'{phase}_batch_token_utilization': batch_token_utilization, - 'average_real_throughput': self.average_real_throughput, - 'engine_iteration': self.niter, + f"{phase}_bucket_batch_size": batch_size_padded, + f"{phase}_batch_size": real_batch_size, + f"{phase}_bucket_seq_len": seq_len, + f"{phase}_seq_len": real_max_seq_len, + f"{phase}_bucket_gen_throughput": throughput, + f"{phase}_real_gen_throughput": throughput_effective, + f"{phase}_batch_token_utilization": 
batch_token_utilization, + "average_real_throughput": self.average_real_throughput, + "engine_iteration": self.niter, } self.niter += 1 if is_prompt: prompt_bucket_in_throughput = (seq_len * batch_size_padded) / ( - duration / 1e6) - prompt_real_in_throughput = sum( - self.prompt_seq_lens) / (duration / 1e6) - counters[ - f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput - counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput + duration / 1e6 + ) + prompt_real_in_throughput = sum(self.prompt_seq_lens) / ( + duration / 1e6 + ) + counters[f"{phase}_bucket_in_throughput"] = ( + prompt_bucket_in_throughput + ) + counters[f"{phase}_real_in_throughput"] = prompt_real_in_throughput # KV cache might not be created yet (e.g. for profiling run) - if cache_config.num_gpu_blocks is not None and \ - cache_config.num_gpu_blocks != 0: + if ( + cache_config.num_gpu_blocks is not None + and cache_config.num_gpu_blocks != 0 + ): cache_num_blocks_used = [ math.ceil(sl / cache_config.block_size) for sl in self.real_seq_lens ] cache_total_num_blocks_used = sum(cache_num_blocks_used) num_cache_blocks = cache_config.num_gpu_blocks - cache_total_num_free_blocks = \ + cache_total_num_free_blocks = ( num_cache_blocks - cache_total_num_blocks_used - cache_computed_utilization = \ + ) + cache_computed_utilization = ( cache_total_num_blocks_used / num_cache_blocks + ) max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size) batch_block_utilization = cache_total_num_blocks_used / ( - batch_size_padded * max_blocks_per_seq) - counters['cache_num_blocks_used'] = cache_total_num_blocks_used - counters['cache_num_free_blocks'] = cache_total_num_free_blocks - counters['cache_computed_utilization'] = cache_computed_utilization - counters[ - f'{phase}_batch_block_utilization'] = batch_block_utilization + batch_size_padded * max_blocks_per_seq + ) + counters["cache_num_blocks_used"] = cache_total_num_blocks_used + counters["cache_num_free_blocks"] = cache_total_num_free_blocks + counters["cache_computed_utilization"] = cache_computed_utilization + counters[f"{phase}_batch_block_utilization"] = ( + batch_block_utilization + ) if not self.logged_once: - counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks - counters[ - 'const_gpu_memory_utilization'] = \ - cache_config.gpu_memory_utilization - counters['const_block_size'] = cache_config.block_size + counters["const_cache_num_blocks"] = cache_config.num_gpu_blocks + counters["const_gpu_memory_utilization"] = ( + cache_config.gpu_memory_utilization + ) + counters["const_block_size"] = cache_config.block_size self.logged_once = True return counters @@ -1727,18 +1987,21 @@ def unwrap_model(model): if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): return unwrap_model(model._orig_mod) else: - model = list(vars(model)['_modules'].values())[0] - modules = list(vars(model)['_modules'].values()) + model = list(vars(model)["_modules"].values())[0] + modules = list(vars(model)["_modules"].values()) return modules class HabanaModelRunner( - HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): + HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata] +): """ GPU model runner with sampling step. 
""" + _model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = ( - ModelInputForHPUWithSamplingMetadata) + ModelInputForHPUWithSamplingMetadata + ) def make_model_input_from_broadcasted_tensor_dict( self, @@ -1748,13 +2011,14 @@ def make_model_input_from_broadcasted_tensor_dict( ModelInputForHPUWithSamplingMetadata.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, - )) + ) + ) def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None, ) -> ModelInputForHPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -1765,22 +2029,27 @@ def prepare_model_input( - input_tokens[num_prefill_tokens:] contains decode tokens. If cuda graph is required, this API automatically pads inputs. """ - with self.profiler.record_event('internal', 'prepare_input_tensors'): + with self.profiler.record_event("internal", "prepare_input_tensors"): assert seq_group_metadata_list is not None self.profiler_counter_helper.capture_seq_group_metadata_stats( - seq_group_metadata_list=seq_group_metadata_list) + seq_group_metadata_list=seq_group_metadata_list + ) model_input, sampling_metadata = self.prepare_input_tensors( - seq_group_metadata_list) + seq_group_metadata_list + ) assert model_input.attn_metadata is not None is_prompt = model_input.attn_metadata.is_prompt - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) + return dataclasses.replace( + model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine, + ) def finish_measurements(self): from neural_compressor.torch.quantization import finalize_calibration + finalize_calibration(self.model.model) def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): @@ -1788,9 +2057,13 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): seen = cfg in self.seen_configs self.seen_configs.add(cfg) if not seen and not warmup_mode: - phase = 'prompt' if is_prompt else 'decode' - logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", - phase, batch_size, seq_len) + phase = "prompt" if is_prompt else "decode" + logger.warning( + "Configuration: (%s, %s, %s) was not warmed-up!", + phase, + batch_size, + seq_len, + ) @torch.inference_mode() def execute_model( @@ -1803,13 +2076,15 @@ def execute_model( ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( - "num_steps > 1 is not supported in HabanaModelRunner") + "num_steps > 1 is not supported in HabanaModelRunner" + ) if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) + self.set_active_loras( + model_input.lora_requests, model_input.lora_mapping + ) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata @@ -1833,7 +2108,7 @@ def execute_model( "kv_caches": kv_caches, "attn_metadata": self.trim_attn_metadata(attn_metadata), "intermediate_tensors": intermediate_tensors, - "lora_mask": model_input.lora_mask + "lora_mask": model_input.lora_mask, } if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) @@ -1842,42 +2117,50 @@ def 
execute_model( htorch.core.mark_step() if self.is_driver_worker: - model_event_name = ("model_" - f"{'prompt' if is_prompt else 'decode'}_" - f"bs{batch_size}_" - f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}") + model_event_name = ( + "model_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}" + ) else: - model_event_name = 'model_executable' - with self.profiler.record_event('internal', model_event_name): + model_event_name = "model_executable" + with self.profiler.record_event("internal", model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, - selected_token_indices=sampling_metadata.selected_token_indices + selected_token_indices=sampling_metadata.selected_token_indices, ) if self.lora_config: from vllm.lora.layers import VocabParallelEmbeddingWithLoRA + modules = unwrap_model(self.model.model) for module in modules: if isinstance(module, VocabParallelEmbeddingWithLoRA): for i in range(0, len(module.indices_len)): - module.indices_len[ - i] = sampling_metadata.selected_token_indices.numel( - ) + module.indices_len[i] = ( + sampling_metadata.selected_token_indices.numel() + ) lora_logits_mask: torch.Tensor = model_input.lora_logits_mask LoraMask.setLoraMask( lora_logits_mask.index_select( - 0, sampling_metadata.selected_token_indices)) + 0, sampling_metadata.selected_token_indices + ) + ) # Compute the logits. with self.profiler.record_event( - 'internal', ('compute_logits_' - f'{"prompt" if is_prompt else "decode"}_bs' - f'{batch_size}_' - f'seq{seq_len}')): + "internal", + ( + 'compute_logits_' + f'{"prompt" if is_prompt else "decode"}_bs' + f'{batch_size}_' + f'seq{seq_len}' + ), + ): sampling_metadata.selected_token_indices = None - logits = self.model.compute_logits(hidden_states, - sampling_metadata) + logits = self.model.compute_logits(hidden_states, sampling_metadata) htorch.core.mark_step() # Only perform sampling in the driver worker. if not self.is_driver_worker: @@ -1885,10 +2168,14 @@ def execute_model( # Sample the next token. 
with self.profiler.record_event( - 'internal', ('sample_' - f'{"prompt" if is_prompt else "decode"}_' - f'bs{batch_size}_' - f'seq{seq_len}')): + "internal", + ( + 'sample_' + f'{"prompt" if is_prompt else "decode"}_' + f'bs{batch_size}_' + f'seq{seq_len}' + ), + ): output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, @@ -1906,19 +2193,23 @@ def execute_model( seq_len=seq_len, batch_size_padded=batch_size_padded, real_batch_size=real_batch_size, - is_prompt=is_prompt) + is_prompt=is_prompt, + ) self.profiler.record_counter(self.event_start, counters) return [output] def shutdown_inc(self): - print('inc shutdown') - if (model_config := getattr(self, "model_config", None)) and \ - getattr(model_config, "quantization", None) == 'inc': - print('inc shutdown start') + print("inc shutdown") + if (model_config := getattr(self, "model_config", None)) and getattr( + model_config, "quantization", None + ) == "inc": + print("inc shutdown start") from neural_compressor.torch.quantization import ( - finalize_calibration) + finalize_calibration, + ) + finalize_calibration(self.model.model) - print('inc shutdown') + print("inc shutdown") def __del__(self): self.shutdown_inc() From 7f587ebf94a46e07e3fc8685dd8dc8cdb2d7e1c6 Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:55:11 +0300 Subject: [PATCH 07/24] Update habana_model_runner.py --- vllm/worker/habana_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 8ca8c50d025f..d1ae61fc7ffb 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1744,7 +1744,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: "Warmup time will be negatively impacted. " "Please update Gaudi Software Suite." 
) - with compile_only_mode_context() if can_use_compile_only_mode else contextlib.nullcontext(): + with compile_only_mode_context() if can_use_compile_only_mode \ + else contextlib.nullcontext(): self.warmup_all_buckets(self.prompt_buckets, True, kv_caches) self.warmup_all_buckets(self.decode_buckets, False, kv_caches) From c204f3fb6066db0c3daa0fbebc27a3e789f9327e Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 16 Sep 2024 15:57:44 +0300 Subject: [PATCH 08/24] isort fixes --- .../compressed_tensors/compressed_tensors.py | 30 +++------- .../schemes/compressed_tensors_w8a8_fp8.py | 15 ++--- .../model_executor/layers/quantization/fp8.py | 29 +++------ vllm/model_executor/models/llama.py | 31 +++------- vllm/worker/habana_model_runner.py | 59 ++++--------------- 5 files changed, 44 insertions(+), 120 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 41e09231f86a..586d1c5291da 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -5,29 +5,17 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig, - QuantizeMethodBase, -) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - W4A16SPARSE24_SUPPORTED_BITS, - WNA16_SUPPORTED_BITS, - CompressedTensorsScheme, - CompressedTensorsUnquantized, - CompressedTensorsW4A16Sparse24, - CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, - CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16, -) + W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, + CompressedTensorsScheme, CompressedTensorsUnquantized, + CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - CompressionFormat, - QuantizationArgs, - QuantizationStrategy, - QuantizationType, - find_matched_target, - is_activation_quantization_format, - should_ignore_layer, -) + CompressionFormat, QuantizationArgs, QuantizationStrategy, + QuantizationType, find_matched_target, is_activation_quantization_format, + should_ignore_layer) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import current_platform diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 6dfb2a59f851..184a1ce0a679 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -4,18 +4,13 @@ from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, -) + CompressedTensorsScheme) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - QuantizationStrategy, -) + QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, - create_per_channel_scale_param, - create_per_tensor_scale_param, - 
cutlass_fp8_supported, - requantize_with_max_scale, -) + apply_fp8_linear, create_per_channel_scale_param, + create_per_tensor_scale_param, cutlass_fp8_supported, + requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs __all__ = ["CompressedTensorsW8A8Fp8"] diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d43653bf827f..d76f06698e2c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,32 +7,19 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase -from vllm.model_executor.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, -) + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped, -) + is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, - apply_fp8_linear, - convert_to_channelwise, - create_per_tensor_scale_param, - cutlass_fp8_supported, - per_tensor_dequantize, - requantize_with_max_scale, -) + all_close_1d, apply_fp8_linear, convert_to_channelwise, + create_per_tensor_scale_param, cutlass_fp8_supported, + per_tensor_dequantize, requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d6f99e1746f2..30c679fb6b67 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,37 +30,24 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import ( - get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, -) + QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - get_compressed_tensors_cache_scale, -) + get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, - ParallelLMHead, - 
VocabParallelEmbedding, -) + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - kv_cache_scales_loader, - maybe_remap_kv_scale_name, -) + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index d1ae61fc7ffb..c145e0f23d0a 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -13,36 +13,17 @@ import os import time from enum import IntEnum -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, + Optional, Set, Tuple, Type, TypeVar, Union) import habana_frameworks.torch as htorch import habana_frameworks.torch.internal.bridge_config as bc import torch from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import ( - CacheConfig, - DeviceConfig, - LoadConfig, - LoRAConfig, - ModelConfig, - MultiModalConfig, - ParallelConfig, - SchedulerConfig, -) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + SchedulerConfig) from vllm.distributed.parallel_state import get_world_group from vllm.hpu.ops import LoraMask as LoraMask from vllm.logger import init_logger @@ -52,26 +33,16 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingParams -from vllm.sequence import ( - IntermediateTensors, - SamplerOutput, - SequenceData, - SequenceGroupMetadata, -) -from vllm.utils import ( - HabanaMemoryProfiler, - format_bytes, - is_pin_memory_available, - make_tensor_with_pad, -) +from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData, + SequenceGroupMetadata) +from vllm.utils import (HabanaMemoryProfiler, format_bytes, + is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( - ModelRunnerBase, - ModelRunnerInputBase, + ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict, -) + _init_sampling_metadata_from_tensor_dict) from .profiler import Profiler @@ -684,10 +655,7 @@ def load_model(self) -> None: logger.info("Preparing model with INC..") with HabanaMemoryProfiler() as m_inc: from neural_compressor.torch.quantization import ( - FP8Config, - convert, - prepare, - ) + FP8Config, convert, prepare) config = FP8Config.from_json_file( os.getenv("QUANT_CONFIG", "") @@ -2206,8 +2174,7 @@ def shutdown_inc(self): ) == "inc": print("inc shutdown start") from neural_compressor.torch.quantization import ( - finalize_calibration, - ) + finalize_calibration) finalize_calibration(self.model.model) print("inc shutdown") From 2e0048631c3a16944abe067209f22089f5d80c0d Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 16 Sep 2024 16:06:55 +0300 Subject: [PATCH 09/24] yapf fixes --- vllm/hpu/ops.py | 59 +- .../layers/fused_moe/fused_moe.py | 148 ++-- .../compressed_tensors/compressed_tensors.py | 152 ++-- .../schemes/compressed_tensors_w8a8_fp8.py | 20 +- 
.../model_executor/layers/quantization/fp8.py | 133 ++-- .../layers/quantization/utils/w8a8_utils.py | 55 +- vllm/model_executor/models/llama.py | 152 ++-- vllm/worker/habana_model_runner.py | 701 +++++++----------- 8 files changed, 587 insertions(+), 833 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 4878b3c7ee05..ddb27de19d75 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -19,20 +19,16 @@ HPUFusedRMSNorm = FusedRMSNorm except ImportError: - logger.warning( - "Could not import HPU FusedRMSNorm kernel. " - "vLLM will use forward_native implementation of RMSNorm." - ) + logger.warning("Could not import HPU FusedRMSNorm kernel. " + "vLLM will use forward_native implementation of RMSNorm.") HPUFusedSDPA = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA HPUFusedSDPA = FusedSDPA except ImportError: - logger.warning( - "Could not import HPU FusedSDPA kernel. " - "vLLM will use native implementation." - ) + logger.warning("Could not import HPU FusedSDPA kernel. " + "vLLM will use native implementation.") def batch2block(tensor, block_mapping): @@ -123,9 +119,8 @@ def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = kv.shape if n_rep == 1: return kv - kv = kv[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) + kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, + head_dim) return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -263,6 +258,7 @@ def dispatch_bgmv_embedding( class MoeMatmul(torch.nn.Module): + def __init__(self): super().__init__() @@ -278,27 +274,26 @@ def forward(self, state): class StaticFusedMOE(torch.nn.Module): + def __init__(self, num_total_experts): super().__init__() self.w13_list = torch.nn.ModuleList( - [MoeMatmul() for _ in range(num_total_experts)] - ) + [MoeMatmul() for _ in range(num_total_experts)]) self.w2_list = torch.nn.ModuleList( - [MoeMatmul() for _ in range(num_total_experts)] - ) + [MoeMatmul() for _ in range(num_total_experts)]) self.num_total_experts = num_total_experts def forward(self, hidden_states, w1, w2, score, topk): B, D = hidden_states.shape routing_weights = F.softmax(score, dim=1, dtype=torch.float32) - routing_weights, selected_experts = torch.topk( - routing_weights, topk, dim=-1 - ) + routing_weights, selected_experts = torch.topk(routing_weights, + topk, + dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = torch.zeros( - (1, B, D), dtype=hidden_states.dtype, device=hidden_states.device - ) + final_hidden_states = torch.zeros((1, B, D), + dtype=hidden_states.dtype, + device=hidden_states.device) padded_weights = torch.zeros( (B, self.num_total_experts), dtype=hidden_states.dtype, @@ -312,9 +307,8 @@ def forward(self, hidden_states, w1, w2, score, topk): for expert_idx in range(self.num_total_experts): padded_weight = padded_weights[expert_idx] current_state_static = hidden_states.reshape(-1, D) - w_output = self.w13_list[expert_idx].calc( - current_state_static, expert_idx, w1 - ) + w_output = self.w13_list[expert_idx].calc(current_state_static, + expert_idx, w1) w_output = silu_and_mul(w_output) w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2) current_hidden_states_static = w_output * padded_weight @@ -353,9 +347,9 @@ def scaled_fp8_quant( """ if batch_dim_padding: shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) - output = torch.empty( - 
shape, device=input.device, dtype=torch.float8_e4m3fn - ) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) else: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: @@ -368,14 +362,15 @@ def scaled_fp8_quant( dtype=torch.float32, ) torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub - ) + output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - output = torch.ops.hpu.cast_to_fp8_v2( - input, 1 / scale, False, False, dtype=torch.float8_e4m3fn - )[0] + output = torch.ops.hpu.cast_to_fp8_v2(input, + 1 / scale, + False, + False, + dtype=torch.float8_e4m3fn)[0] return output, scale diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 585b5e0c64c7..39fa611e5157 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -114,16 +114,12 @@ def fused_moe_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak - ) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = ( - b_ptr - + off_experts * stride_be - + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - ) + b_ptrs = (b_ptr + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) if use_fp8: a_scale = tl.load(a_scale_ptr) @@ -141,12 +137,13 @@ def fused_moe_kernel( # K dimension. a = tl.load( a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0, ) - b = tl.load( - b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 - ) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) # We accumulate along the K dimension. if use_fp8: accumulator = tl.dot(a, b, acc=accumulator) @@ -157,9 +154,9 @@ def fused_moe_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load( - topk_weights_ptr + offs_token, mask=token_mask, other=0 - ) + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) accumulator = accumulator * moe_weight[:, None] if use_fp8: @@ -169,16 +166,15 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = ( - c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] - ) + c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + + stride_cn * offs_cn[None, :]) c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, num_experts: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. @@ -217,17 +213,17 @@ def moe_align_block_size( by block_size for proper block matrix operations. 
""" max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty( - (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device - ) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty( - (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device - ) - num_tokens_post_pad = torch.empty( - (1), dtype=torch.int32, device=topk_ids.device - ) + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) ops.moe_align_block_size( topk_ids, num_experts, @@ -266,10 +262,8 @@ def invoke_fused_moe_kernel( A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None - grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), - ) + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) fused_moe_kernel[grid]( A, @@ -307,9 +301,8 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @functools.lru_cache -def get_moe_configs( - E: int, N: int, dtype: Optional[str] -) -> Optional[Dict[int, Any]]: +def get_moe_configs(E: int, N: int, + dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -324,13 +317,11 @@ def get_moe_configs( json_file_name = get_config_file_name(E, N, dtype) config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name - ) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info( - "Using configuration from %s for MoE layer.", config_file_path - ) + logger.info("Using configuration from %s for MoE layer.", + config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} @@ -394,21 +385,23 @@ def fused_topk( topk: int, renormalize: bool, ): - assert ( - hidden_states.shape[0] == gating_output.shape[0] - ), "Number of tokens mismatch" + assert (hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" M, _ = hidden_states.shape - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) - token_expert_indicies = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) ops.topk_softmax( topk_weights, topk_ids, @@ -431,31 +424,25 @@ def grouped_topk( num_expert_group: int = 0, topk_group: int = 0, ): - assert ( - hidden_states.shape[0] == gating_output.shape[0] - ), "Number of tokens mismatch" + assert (hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" scores = torch.softmax(gating_output, dim=-1) num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, 
n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ - 1 - ] # [n, top_k_group] + group_scores = (scores.view(num_token, num_expert_group, + -1).max(dim=-1).values) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand( - num_token, num_expert_group, scores.shape[-1] // num_expert_group - ) - .reshape(num_token, -1) - ) # [n, e] + score_mask = (group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1)) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk( - tmp_scores, k=topk, dim=-1, sorted=False - ) + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -482,7 +469,9 @@ def fused_experts( assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] num_tokens, _ = hidden_states.shape E, N, _ = w1.shape @@ -518,9 +507,8 @@ def fused_experts( dtype=hidden_states.dtype, ) - compute_type = ( - tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - ) + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) if inplace: out_hidden_states = hidden_states @@ -552,8 +540,7 @@ def fused_experts( curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config["BLOCK_SIZE_M"], E) - ) + moe_align_block_size(curr_topk_ids, config["BLOCK_SIZE_M"], E)) invoke_fused_moe_kernel( curr_hidden_states, @@ -663,9 +650,8 @@ def fused_moe( topk_group, ) else: - topk_weights, topk_ids = fused_topk( - hidden_states, gating_output, topk, renormalize - ) + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) return fused_experts( hidden_states, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 586d1c5291da..f600f92efae6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -21,6 +21,7 @@ class CompressedTensorsConfig(QuantizationConfig): + def __init__( self, target_scheme_map: Dict[str, Any], @@ -84,14 +85,11 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": for target in targets: target_scheme_map[target] = {} target_scheme_map[target]["weights"] = ( - QuantizationArgs.parse_obj(quant_config.get("weights")) - ) + QuantizationArgs.parse_obj(quant_config.get("weights"))) try: target_scheme_map[target]["input_activations"] = ( QuantizationArgs.parse_obj( - quant_config.get("input_activations") - ) - ) + quant_config.get("input_activations"))) except Exception: target_scheme_map[target]["input_activations"] = None @@ -106,9 +104,9 @@ def from_config(cls, 
config: Dict[str, Any]) -> "CompressedTensorsConfig": def get_config_filenames(cls) -> List[str]: return [] - def _check_scheme_supported( - self, min_capability: int, error: bool = True - ) -> bool: + def _check_scheme_supported(self, + min_capability: int, + error: bool = True) -> bool: capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] supported = capability >= min_capability @@ -120,52 +118,41 @@ def _check_scheme_supported( ) return supported - def _is_static_tensor_w8a8( - self, weight_quant: BaseModel, input_quant: BaseModel - ) -> bool: + def _is_static_tensor_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value - ) - is_tensor = ( - weight_strategy - and input_quant.strategy == QuantizationStrategy.TENSOR.value - ) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) + is_tensor = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TENSOR.value) is_symmetric = weight_quant.symmetric and input_quant.symmetric is_static = not weight_quant.dynamic and not input_quant.dynamic return is_8_bits and is_tensor and is_symmetric and is_static - def _is_dynamic_token_w8a8( - self, weight_quant: BaseModel, input_quant: BaseModel - ) -> bool: + def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value - ) - is_token = ( - weight_strategy - and input_quant.strategy == QuantizationStrategy.TOKEN.value - ) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) is_symmetric = weight_quant.symmetric and input_quant.symmetric is_dynamic = not weight_quant.dynamic and input_quant.dynamic return is_8_bits and is_token and is_symmetric and is_dynamic - def _is_fp8_w8a8( - self, weight_quant: BaseModel, input_quant: BaseModel - ) -> bool: + def _is_fp8_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: # Confirm weights and activations quantized. if weight_quant is None or input_quant is None: return False # Confirm we have floating points. - if not ( - weight_quant.type == QuantizationType.FLOAT - and input_quant.type == QuantizationType.FLOAT - ): + if not (weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT): return False # Confirm weight scheme is supported. @@ -175,11 +162,8 @@ def _is_fp8_w8a8( QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, ] - if not ( - is_symmetric_weight - and is_static_weight - and is_per_tensor_or_channel_weight - ): + if not (is_symmetric_weight and is_static_weight + and is_per_tensor_or_channel_weight): return False # Dynamic quantization is always supported if weights supported. @@ -189,17 +173,15 @@ def _is_fp8_w8a8( # Confirm activation scheme is supported. is_symmetric_activation = input_quant.symmetric is_per_tensor_activation = ( - input_quant.strategy == QuantizationStrategy.TENSOR - ) + input_quant.strategy == QuantizationStrategy.TENSOR) if not (is_symmetric_activation and is_per_tensor_activation): return False # All conditions satisfied. 
return True - def _is_fp8_w8a16( - self, weight_quant: BaseModel, input_quant: BaseModel - ) -> bool: + def _is_fp8_w8a16(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: # Confirm weights quantized. if weight_quant is None: return False @@ -215,49 +197,39 @@ def _is_fp8_w8a16( QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, ] - if not ( - is_symmetric_weight - and is_static_weight - and is_per_tensor_or_channel_weight - ): + if not (is_symmetric_weight and is_static_weight + and is_per_tensor_or_channel_weight): return False # All conditions satisfied. return True - def _is_wNa16_group_channel( - self, weight_quant: BaseModel, input_quant: BaseModel - ) -> bool: + def _is_wNa16_group_channel(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: input_quant_none = input_quant is None is_symmetric = weight_quant.symmetric is_channel_group = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value - or weight_quant.strategy == QuantizationStrategy.GROUP.value - ) + or weight_quant.strategy == QuantizationStrategy.GROUP.value) is_static = not weight_quant.dynamic - return ( - is_channel_group and input_quant_none and is_symmetric and is_static - ) + return (is_channel_group and input_quant_none and is_symmetric + and is_static) def _get_scheme_from_parts( - self, weight_quant: BaseModel, input_quant: BaseModel - ) -> "CompressedTensorsScheme": + self, weight_quant: BaseModel, + input_quant: BaseModel) -> "CompressedTensorsScheme": # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): - if ( - self.quant_format == CompressionFormat.marlin_24.value - and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS - ): + if (self.quant_format == CompressionFormat.marlin_24.value + and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, group_size=weight_quant.group_size, ) - if ( - self.quant_format == CompressionFormat.pack_quantized.value - and weight_quant.num_bits in WNA16_SUPPORTED_BITS - ): + if (self.quant_format == CompressionFormat.pack_quantized.value + and weight_quant.num_bits in WNA16_SUPPORTED_BITS): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, @@ -267,14 +239,10 @@ def _get_scheme_from_parts( # Detect If Activation Quantization. 
if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): - is_fp8_w8a8_supported = ( - self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), - error=False, - ) - if torch.cuda.is_available() - else True - ) + is_fp8_w8a8_supported = (self._check_scheme_supported( + CompressedTensorsW8A8Fp8.get_min_capability(), + error=False, + ) if torch.cuda.is_available() else True) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, @@ -283,36 +251,34 @@ def _get_scheme_from_parts( else: return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=( - input_quant and not input_quant.dynamic - ), + is_static_input_scheme=(input_quant + and not input_quant.dynamic), ) if self._is_fp8_w8a16(weight_quant, input_quant): return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=( - input_quant and not input_quant.dynamic - ), + is_static_input_scheme=(input_quant + and not input_quant.dynamic), ) if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( - strategy=weight_quant.strategy, is_static_input_scheme=True - ) + strategy=weight_quant.strategy, + is_static_input_scheme=True) if self._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( - strategy=weight_quant.strategy, is_static_input_scheme=False - ) + strategy=weight_quant.strategy, + is_static_input_scheme=False) raise NotImplementedError( - "No compressed-tensors compatible scheme was found." - ) + "No compressed-tensors compatible scheme was found.") def get_scheme( - self, layer: torch.nn.Module, layer_name: Optional[str] = None - ) -> "CompressedTensorsScheme": + self, + layer: torch.nn.Module, + layer_name: Optional[str] = None) -> "CompressedTensorsScheme": """ compressed-tensors supports non uniform in the following way: @@ -361,6 +327,7 @@ def get_scheme( class CompressedTensorsLinearMethod(LinearMethodBase): + def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config @@ -444,21 +411,18 @@ def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]): raise NotImplementedError( "Currently supported kv cache quantization is " "num_bits=8, type=float, however " - f"received num_bits={num_bits}, type={type_}" - ) + f"received num_bits={num_bits}, type={type_}") strategy = kv_cache_scheme.get("strategy") if strategy != "tensor": raise NotImplementedError( "Only support per-tensor scaling factor " "for compressed-tensors KV cache. " - f"Expected strategy: tensor, found strategy: {strategy}" - ) + f"Expected strategy: tensor, found strategy: {strategy}") is_symmetric = kv_cache_scheme.get("symmetric") if not is_symmetric: raise NotImplementedError( "Only support symmetric scaling factor " "for compressed-tensors KV cache. 
" - f"However found symmetric: {is_symmetric}" - ) + f"However found symmetric: {is_symmetric}") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 184a1ce0a679..ee772c2951e8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -17,12 +17,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = ( - cutlass_fp8_supported() if torch.cuda.is_available() else False - ) + self.cutlass_fp8_supported = (cutlass_fp8_supported() + if torch.cuda.is_available() else False) @classmethod def get_min_capability(cls) -> int: @@ -53,9 +53,8 @@ def process_weights_after_loading(self, layer) -> None: # INPUT SCALE if self.is_static_input_scheme: - layer.input_scale = Parameter( - layer.input_scale.max(), requires_grad=False - ) + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) else: layer.input_scale = None @@ -94,20 +93,17 @@ def create_weights( layer_kwargs = {"weight_loader": weight_loader} if self.strategy == QuantizationStrategy.CHANNEL: weight_scale = create_per_channel_scale_param( - output_partition_sizes, **layer_kwargs - ) + output_partition_sizes, **layer_kwargs) else: assert self.strategy == QuantizationStrategy.TENSOR weight_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs - ) + output_partition_sizes, **layer_kwargs) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: input_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs - ) + output_partition_sizes, **layer_kwargs) layer.register_parameter("input_scale", input_scale) def apply_weights( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d76f06698e2c..28fcceb449af 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -45,14 +45,11 @@ def __init__( ) -> None: self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: - logger.warning( - "Detected fp8 checkpoint. Please note that the " - "format is experimental and subject to change." - ) + logger.warning("Detected fp8 checkpoint. 
Please note that the " + "format is experimental and subject to change.") if activation_scheme not in ACTIVATION_SCHEMES: raise ValueError( - f"Unsupported activation scheme {activation_scheme}" - ) + f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] @@ -84,9 +81,8 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": ignored_layers=ignored_layers, ) - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): @@ -155,11 +151,9 @@ def create_weights( layer.orig_dtype = params_dtype # WEIGHT - weight_dtype = ( - torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized - else params_dtype - ) + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) weight = Parameter( torch.empty( output_size_per_partition, @@ -182,24 +176,21 @@ def create_weights( # Otherwise, wait until process_weights_after_loading. if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - scale = create_per_tensor_scale_param( - output_partition_sizes, **extra_weight_attrs - ) + scale = create_per_tensor_scale_param(output_partition_sizes, + **extra_weight_attrs) layer.register_parameter("weight_scale", scale) # INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static": - scale = create_per_tensor_scale_param( - output_partition_sizes, **extra_weight_attrs - ) + scale = create_per_tensor_scale_param(output_partition_sizes, + **extra_weight_attrs) layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant( - layer.weight, scale=None - ) + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) @@ -213,9 +204,8 @@ def process_weights_after_loading(self, layer: Module) -> None: # so extend the weight scales to be channelwise. if self.use_marlin: weight = layer.weight - weight_scale = convert_to_channelwise( - layer.weight_scale, layer.logical_widths - ) + weight_scale = convert_to_channelwise(layer.weight_scale, + layer.logical_widths) # If using w8a8, torch._scaled_mm needs per tensor, so # requantize the logical shards as a single weight. 
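The fp8 linear path above leans on ops.scaled_fp8_quant with scale=None to handle checkpoints that were not serialized in fp8: a single per-tensor scale is derived from the weight's maximum magnitude and stored next to the fp8 weight. A naive stand-in in plain PyTorch is sketched below; naive_scaled_fp8_quant is a made-up name for illustration and is not the custom op the patch actually calls.

    import torch

    def naive_scaled_fp8_quant(weight: torch.Tensor):
        # Map the largest magnitude onto the fp8 e4m3 maximum and keep the
        # resulting per-tensor scale next to the quantized weight.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = weight.abs().max().float() / fp8_max
        qweight = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(
            torch.float8_e4m3fn)
        return qweight, scale

    w = torch.randn(128, 256, dtype=torch.bfloat16)
    qw, s = naive_scaled_fp8_quant(w)
    w_approx = qw.float() * s  # recovers w up to fp8 rounding error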
@@ -231,9 +221,8 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) if self.quant_config.activation_scheme == "static": - layer.input_scale = Parameter( - layer.input_scale.max(), requires_grad=False - ) + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) else: layer.input_scale = None @@ -312,9 +301,10 @@ def create_weights( set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, hidden_size, intermediate_size, dtype=params_dtype - ), + torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) @@ -323,14 +313,15 @@ def create_weights( # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. - w13_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) + w13_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) layer.register_parameter("w13_scale", w13_scale) - w2_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) layer.register_parameter("w2_scale", w2_scale) # If loading fp8 checkpoint, pass the weight loaders. @@ -345,8 +336,7 @@ def create_weights( if not self.quant_config.is_checkpoint_fp8_serialized: raise ValueError( "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) + "was not serialized fp8.") a13_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), @@ -368,12 +358,10 @@ def create_weights( def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint is fp16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like( - layer.w13_weight.data, dtype=torch.float8_e4m3fn - ) - w2_weight = torch.empty_like( - layer.w2_weight.data, dtype=torch.float8_e4m3fn - ) + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(layer.w2_weight.data, + dtype=torch.float8_e4m3fn) # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. 
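The next hunk re-initializes w13_scale and quantizes each expert's merged w13 and w2 weights directly, so every expert ends up with a single weight scale. A compact sketch of that per-expert pass in plain PyTorch follows; the shapes, the expert count, and the quantizer are illustrative only, not vLLM's ops.scaled_fp8_quant.

    import torch

    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max

    def quant_per_expert(w: torch.Tensor):
        # w: [num_experts, out_features, in_features] in bf16/fp16.
        # Returns fp8 weights plus one scale per expert.
        num_experts = w.shape[0]
        qw = torch.empty_like(w, dtype=fp8)
        scales = torch.ones(num_experts, dtype=torch.float32)
        for e in range(num_experts):
            scales[e] = w[e].abs().max().float() / fp8_max
            qw[e] = (w[e].float() / scales[e]).clamp(-fp8_max, fp8_max).to(fp8)
        return qw, scales

    w13 = torch.randn(8, 2 * 128, 64, dtype=torch.bfloat16)  # [E, 2*I, H]
    qw13, w13_scale = quant_per_expert(w13)                  # w13_scale: [E]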
@@ -387,15 +375,13 @@ def process_weights_after_loading(self, layer: Module) -> None: ) for expert in range(layer.num_experts): w13_weight[expert, :, :], layer.w13_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])) w2_weight[expert, :, :], layer.w2_scale[expert] = ( - ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - layer.w13_weight = torch.nn.Parameter( - w13_weight, requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) return # If checkpoint is fp8, we need to handle that the @@ -408,22 +394,17 @@ def process_weights_after_loading(self, layer: Module) -> None: if layer.a13_scale is None or layer.a2_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None." - ) + "activation scales are None.") if not all_close_1d(layer.a13_scale) or not all_close_1d( - layer.a2_scale - ): + layer.a2_scale): print_warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer. " - ) - layer.a13_scale = torch.nn.Parameter( - layer.a13_scale.max(), requires_grad=False - ) - layer.a2_scale = torch.nn.Parameter( - layer.a2_scale.max(), requires_grad=False - ) + "for each layer. ") + layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), + requires_grad=False) + layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), + requires_grad=False) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. 
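The loop in the following hunk performs that step shard by shard: take the maximum of the two w1/w3 scales, dequantize each shard with its own scale, and requantize it with the shared maximum so the fused MoE kernel only has to track one scale per expert. Sketched standalone below; merge_shard_scales is an illustrative name, and the real code routes through per_tensor_dequantize and ops.scaled_fp8_quant.

    import torch

    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max

    def merge_shard_scales(w13: torch.Tensor, shard_scales: torch.Tensor):
        # w13 stacks the w1 and w3 shards along dim 0 and is already fp8;
        # shard_scales holds one scale per shard.
        max_scale = shard_scales.max()
        shard_size = w13.shape[0] // 2
        merged = torch.empty_like(w13)
        for i in range(2):
            rows = slice(i * shard_size, (i + 1) * shard_size)
            dq = w13[rows].float() * shard_scales[i]      # dequantize shard
            merged[rows] = (dq / max_scale).clamp(-fp8_max, fp8_max).to(fp8)
        return merged, max_scale

The shard quantized with the smaller original scale gives up a little precision in exchange for a single per-expert scale the kernel can consume.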
@@ -434,24 +415,20 @@ def process_weights_after_loading(self, layer: Module) -> None: start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][ - start : start + shard_size, : - ], + layer.w13_weight[expert_id][start:start + + shard_size, :], layer.w13_scale[expert_id][shard_id], ) ( - layer.w13_weight[expert_id][ - start : start + shard_size, : - ], + layer.w13_weight[expert_id][start:start + + shard_size, :], _, - ) = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id] - ) + ) = ops.scaled_fp8_quant(dq_weight, + max_w13_scales[expert_id]) start += shard_size - layer.w13_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) + layer.w13_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) return def apply( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index c5c1da179166..2ac620742680 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -23,8 +23,8 @@ def cutlass_fp8_supported() -> bool: def per_tensor_dequantize( - tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] -) -> torch.Tensor: + tensor: torch.Tensor, inv_scale: Union[float, + torch.Tensor]) -> torch.Tensor: dtype = torch.float16 device = tensor.device if current_platform.is_hpu(): @@ -52,15 +52,15 @@ def create_per_tensor_scale_param( requires_grad=False, ) scale[:] = torch.finfo(torch.float32).min - set_weight_attrs( - scale, {"needs_scalar_to_array": True, **extra_weight_attrs} - ) + set_weight_attrs(scale, { + "needs_scalar_to_array": True, + **extra_weight_attrs + }) return scale -def create_per_channel_scale_param( - output_partition_sizes: List[int], **extra_weight_attrs -) -> Parameter: +def create_per_channel_scale_param(output_partition_sizes: List[int], + **extra_weight_attrs) -> Parameter: scale = Parameter( torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), requires_grad=False, @@ -71,8 +71,8 @@ def create_per_channel_scale_param( def convert_to_channelwise( - weight_scale: torch.Tensor, logical_widths: List[int] -) -> Tuple[torch.Tensor, torch.Tensor]: + weight_scale: torch.Tensor, + logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Create channelwise buffer weight_scale_channel = torch.empty( (sum(logical_widths), 1), @@ -91,39 +91,32 @@ def convert_to_channelwise( def requantize_with_max_scale( - weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int] -) -> Tuple[torch.Tensor, torch.Tensor]: + weight: torch.Tensor, weight_scale: torch.Tensor, + logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. max_w_scale = weight_scale.max() - if ( - current_platform.is_hpu() - and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2 - ): - max_w_scale = max_w_scale * ( - torch.finfo(torch.float8_e4m3fn).max - / torch.finfo(torch.float8_e4m3fnuz).max - ) + if (current_platform.is_hpu() and htexp._get_device_type() + == htexp.synDeviceType.synDeviceGaudi2): + max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max / + torch.finfo(torch.float8_e4m3fnuz).max) # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize # N weight scales for N shards but we only load 1 weight scale # from disk in this case. 
Skip requantization in this case (since) # we already are quantized with the single scale. # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 - unfused_module_in_checkpoint = ( - weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min - ) + unfused_module_in_checkpoint = (weight_scale[-1] + > torch.finfo(torch.float8_e4m3fn).min) # If unfused checkpoint, need requanize with the single scale. if unfused_module_in_checkpoint: start = 0 for idx, logical_width in enumerate(logical_widths): end = start + logical_width - weight_dq = per_tensor_dequantize( - weight[start:end, :], weight_scale[idx] - ) + weight_dq = per_tensor_dequantize(weight[start:end, :], + weight_scale[idx]) weight[start:end, :], _ = ops.scaled_fp8_quant( - weight_dq, max_w_scale - ) + weight_dq, max_w_scale) start = end return max_w_scale, weight @@ -224,9 +217,9 @@ def apply_fp8_linear( # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place - output, _ = torch._scaled_mm( - qinput, weight, out_dtype=torch.float32 - ) + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=torch.float32) # Unpad (undo batch_dim_padding) output = torch.narrow(output, 0, 0, input.shape[0]) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 30c679fb6b67..124386b61e4c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -61,6 +61,7 @@ class LlamaMLP(nn.Module): + def __init__( self, hidden_size: int, @@ -86,10 +87,8 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError( - f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now." - ) + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") self.act_fn = SiluAndMul() def forward(self, x): @@ -100,6 +99,7 @@ def forward(self, x): class LlamaAttention(nn.Module): + def __init__( self, config: LlamaConfig, @@ -131,9 +131,8 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr( - config, "head_dim", self.hidden_size // self.total_num_heads - ) + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -189,6 +188,7 @@ def forward( class LlamaDecoderLayer(nn.Module): + def __init__( self, config: LlamaConfig, @@ -201,26 +201,21 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None - ): + config, "original_max_position_embeddings", None): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings - ) - max_position_embeddings = getattr( - config, "max_position_embeddings", 8192 - ) + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False - ) + config, "bias", False) self.self_attn = LlamaAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr( - 
config, "num_key_value_heads", config.num_attention_heads - ), + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -237,12 +232,10 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -258,8 +251,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( - hidden_states, residual - ) + hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -269,13 +261,13 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class LlamaModel(nn.Module): + def __init__( self, config: LlamaConfig, @@ -287,16 +279,12 @@ def __init__( super().__init__() self.config = config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or ( - config.tie_word_embeddings and get_pp_group().is_last_rank - ): + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -359,9 +347,10 @@ def forward( htorch.core.mark_step() if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -433,17 +422,16 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, + if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) self.sampler = Sampler() else: self.lm_head = PPMissingLayer() @@ -467,12 +455,10 @@ def forward( ) return model_output - def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> torch.Tensor: - logits = self.logits_processor( - self.lm_head, hidden_states, sampling_metadata - ) + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) return logits def 
sample( @@ -484,22 +470,22 @@ def sample( return next_tokens def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, device: torch.device - ) -> IntermediateTensors: - return IntermediateTensors( - { - "hidden_states": torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - "residual": torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - } - ) + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + "residual": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + }) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -514,19 +500,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ( - "rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name - ): + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue if scale_name := get_compressed_tensors_cache_scale(name): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) continue @@ -559,9 +542,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) if current_platform.is_hpu(): torch.hpu.synchronize() @@ -574,11 +556,11 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, - tp_rank, - tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type, + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, ): if not isinstance(self.model.layers[layer_idx], nn.Identity): layer_self_attn = self.model.layers[layer_idx].self_attn @@ -592,7 +574,5 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if hasattr(layer_self_attn, "kv_scale"): layer_self_attn.attn._kv_scale = scaling_factor else: - raise RuntimeError( - "Self attention has no KV cache scaling " - "factor attribute!" 
- ) + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index c145e0f23d0a..9cb4a4915c01 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -73,9 +73,8 @@ def subtuple( fields = set(to_copy) | set(to_override.keys()) values = {f: to_override.get(f, getattr(obj, f)) for f in fields} if typename not in _TYPE_CACHE: - _TYPE_CACHE[typename] = collections.namedtuple( - typename, " ".join(fields) - ) + _TYPE_CACHE[typename] = collections.namedtuple(typename, + " ".join(fields)) return _TYPE_CACHE[typename](**values) @@ -112,36 +111,29 @@ def warmup_range(config: Tuple[int, int, int]): => return ramp_up + stable => (2, 4, 8, 16, 32, 64) """ bmin, bstep, bmax = config - assert bmin <= bmax, ( - "Min. batch size cannot be greater than max. " - "batch size. If you want to skip warmup, " - "set VLLM_SKIP_WARMUP=true" - ) + assert bmin <= bmax, ("Min. batch size cannot be greater than max. " + "batch size. If you want to skip warmup, " + "set VLLM_SKIP_WARMUP=true") base = itertools.repeat(2) ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin) - ramp_up_tw = itertools.takewhile( - lambda x: x < bstep and x <= bmax, ramp_up_acc - ) + ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, + ramp_up_acc) stable = range(bstep, bmax + 1, bstep) buckets = list(ramp_up_tw) + list(stable) return list(filter(lambda bucket: bucket >= bmin, buckets)) -def generate_prompt_buckets( - bs_bucket_config, seq_bucket_config, max_num_batched_tokens=None -): +def generate_prompt_buckets(bs_bucket_config, + seq_bucket_config, + max_num_batched_tokens=None): buckets = list( - itertools.product( - warmup_range(bs_bucket_config), warmup_range(seq_bucket_config) - ) - ) + itertools.product(warmup_range(bs_bucket_config), + warmup_range(seq_bucket_config))) if len(buckets) == 0: - msg = ( - "No buckets could be captured with following config " - f"(min, step, max_warmup): " - f"bs:{bs_bucket_config}, " - f"seq:{seq_bucket_config}" - ) + msg = ("No buckets could be captured with following config " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config}") raise ValueError(msg) filtered_buckets = buckets @@ -151,14 +143,12 @@ def generate_prompt_buckets( filter( lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, buckets, - ) - ) + )) if len(filtered_buckets) == 0: # we can handle this if we ignore max_num_batched_tokens - min_bucket_bs, min_bucket_seq = min( - buckets, key=lambda b: (b[0] * b[1]) - ) + min_bucket_bs, min_bucket_seq = min(buckets, + key=lambda b: (b[0] * b[1])) min_reqd_budget = min_bucket_bs * min_bucket_seq msg = ( "The current bucketing configuration " @@ -169,23 +159,20 @@ def generate_prompt_buckets( f"smallest bucket ({min_reqd_budget}) would exceed token " "budget. Please increase max_num_batched_tokens or decrease " "bucket minimum Ignoring max_num_batched_tokens at risk of " - "out-of-memory errors." 
- ) + "out-of-memory errors.") logger.error(msg) return list( - sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])) - ), [] + sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), [] captured_buckets = list( - sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])) - ) + sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) omitted_buckets = list( - sorted([x for x in buckets if x not in filtered_buckets]) - ) + sorted([x for x in buckets if x not in filtered_buckets])) return captured_buckets, omitted_buckets -def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, max_blocks): +def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, + max_blocks): buckets = [] for bs in warmup_range(bs_bucket_config): for blocks in warmup_range(blocks_bucket_config): @@ -230,9 +217,8 @@ def setup_profiler(): schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) DEVICE = "hpu" activities = [torch.profiler.ProfilerActivity.CPU] - activities.extend( - [torch.profiler.ProfilerActivity.HPU] if DEVICE == "hpu" else [] - ) + activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == + "hpu" else []) # from habana_frameworks.torch.activity_profiler import DebugActivity # debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] @@ -240,9 +226,8 @@ def setup_profiler(): schedule=schedule, activities=activities, # debug_activities=debug_activities, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - ".", use_gzip=True - ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(".", + use_gzip=True), record_shapes=False, with_stack=True, ) @@ -256,30 +241,29 @@ def pad_list(list, k, v): class HpuModelAdapter: + def __init__(self, model, block_size, dtype, enforce_eager): self.model = model - self.prefill_use_fusedsdpa = os.getenv( - "VLLM_PROMPT_USE_FUSEDSDPA", "0" - ).lower() in ["1", "true"] + self.prefill_use_fusedsdpa = os.getenv("VLLM_PROMPT_USE_FUSEDSDPA", + "0").lower() in ["1", "true"] self.block_size = block_size self.dtype = dtype if not htorch.utils.internal.is_lazy() and not enforce_eager: - self.model = torch.compile( - self.model, backend="hpu_backend", dynamic=False - ) + self.model = torch.compile(self.model, + backend="hpu_backend", + dynamic=False) - def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): + def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, + dtype): prefill_metadata = attn_metadata if prefill_metadata is None or self.prefill_use_fusedsdpa: return attn_metadata seq_lens_t = prefill_metadata.seq_lens_tensor - len_mask = ( - torch.arange(0, seq_len, device=device, dtype=torch.int32) - .view(1, seq_len) - .ge(seq_lens_t.unsqueeze(-1)) - .view(batch_size, 1, 1, seq_len) - ) + len_mask = (torch.arange(0, seq_len, device=device, + dtype=torch.int32).view(1, seq_len).ge( + seq_lens_t.unsqueeze(-1)).view( + batch_size, 1, 1, seq_len)) causal_mask = torch.triu( torch.ones( (batch_size, 1, seq_len, seq_len), @@ -290,40 +274,35 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): ) mask = causal_mask.logical_or(len_mask) attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf - ) + mask, -math.inf) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata def _set_block_mapping(self, metadata, batch_size, device, dtype): - mask = torch.arange( - 0, self.block_size, device=device, dtype=torch.int32 - ).unsqueeze(0) + mask = torch.arange(0, + self.block_size, + device=device, + 
dtype=torch.int32).unsqueeze(0) mask = mask >= metadata.block_usage.unsqueeze(-1) attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf - ) + mask, -math.inf) block_mapping = torch.nn.functional.one_hot( - metadata.block_mapping.to(torch.long), num_classes=batch_size - ).to(dtype) - metadata = metadata._replace( - block_mapping=block_mapping, attn_bias=attn_bias - ) + metadata.block_mapping.to(torch.long), + num_classes=batch_size).to(dtype) + metadata = metadata._replace(block_mapping=block_mapping, + attn_bias=attn_bias) return metadata - def _update_metadata( - self, attn_metadata, batch_size, seq_len, device, dtype - ): + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, + dtype): if attn_metadata.is_prompt: meta = attn_metadata - attn_metadata = self._set_attn_bias( - meta, batch_size, seq_len, device, dtype - ) + attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, + device, dtype) else: meta = attn_metadata - attn_metadata = self._set_block_mapping( - meta, batch_size, device, dtype - ) + attn_metadata = self._set_block_mapping(meta, batch_size, device, + dtype) return attn_metadata def forward(self, *args, **kwargs): @@ -470,8 +449,7 @@ def from_broadcasted_tensor_dict( ) -> TModelInputForHPU: if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict - ) + attn_backend, tensor_dict) return cls(**tensor_dict) @@ -497,9 +475,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_logits_mask": self.lora_logits_mask, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict( - tensor_dict, self.sampling_metadata - ) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) return tensor_dict @classmethod @@ -512,8 +489,7 @@ def from_broadcasted_tensor_dict( # FIXME(kzawora): this fails for whatever reason - why? 
if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict - ) + attn_backend, tensor_dict) return cls(**tensor_dict) @@ -546,22 +522,17 @@ def __init__( self.is_driver_worker = is_driver_worker self.profiler = Profiler() - self.sliding_window = ( - model_config.get_sliding_window() - if model_config is not None - else None - ) - self.device_config = ( - device_config if device_config is not None else DeviceConfig() - ) + self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) self.device = self.device_config.device self.enforce_eager = self.model_config.enforce_eager self.max_num_seqs = self.scheduler_config.max_num_seqs self.max_model_len = self.scheduler_config.max_model_len self.max_num_batched_tokens = ( - self.scheduler_config.max_num_batched_tokens - ) + self.scheduler_config.max_num_batched_tokens) self.block_size = cache_config.block_size self.pin_memory = is_pin_memory_available() @@ -599,11 +570,13 @@ def _set_gc_threshold(self) -> None: requested_gc_thrs = [0] * len(default_gc_thrs) for i in range(len(default_gc_thrs)): requested_gc_thrs[i] = int( - os.environ.get(f"VLLM_GC_THR_GEN{i}", default_gc_thrs[i]) - ) + os.environ.get(f"VLLM_GC_THR_GEN{i}", default_gc_thrs[i])) if requested_gc_thrs == default_gc_thrs: - gc_thr_multiplier = int(os.environ.get("VLLM_GC_THR_MULTIPLIER", 2)) - requested_gc_thrs = [t * gc_thr_multiplier for t in default_gc_thrs] + gc_thr_multiplier = int(os.environ.get("VLLM_GC_THR_MULTIPLIER", + 2)) + requested_gc_thrs = [ + t * gc_thr_multiplier for t in default_gc_thrs + ] gc.set_threshold(*requested_gc_thrs) def load_model(self) -> None: @@ -622,21 +595,17 @@ def load_model(self) -> None: scheduler_config=self.scheduler_config, cache_config=self.cache_config, ) - msg = ( - "Pre-loading model weights on " - f"{next(self.model.parameters()).device} " - f"took {m_getmodel.get_summary_string()}" - ) + msg = ("Pre-loading model weights on " + f"{next(self.model.parameters()).device} " + f"took {m_getmodel.get_summary_string()}") logger.info(msg) if self.lora_config: - assert ( - hasattr(self.model, "supported_lora_modules") - and self.model.supported_lora_modules - ), "Model does not support LoRA" - assert hasattr( - self.model, "embedding_modules" - ), "Model does not have embedding_modules" + assert (hasattr(self.model, "supported_lora_modules") + and self.model.supported_lora_modules + ), "Model does not support LoRA" + assert hasattr(self.model, "embedding_modules" + ), "Model does not have embedding_modules" assert hasattr( self.model, "embedding_padding_modules" ), "Model does not have embedding_padding_modules" @@ -658,15 +627,13 @@ def load_model(self) -> None: FP8Config, convert, prepare) config = FP8Config.from_json_file( - os.getenv("QUANT_CONFIG", "") - ) + os.getenv("QUANT_CONFIG", "")) if config.measure: self.model = prepare(self.model, config) elif config.quantize: self.model = convert(self.model, config) - htcore.hpu_initialize( - self.model, mark_only_scales_as_const=True - ) + htcore.hpu_initialize(self.model, + mark_only_scales_as_const=True) logger.info( "Preparing model with INC took %s", m_inc.get_summary_string(), @@ -701,10 +668,8 @@ def _is_valid_bucket(self, bucket): def _setup_buckets(self) -> None: align_bs = lambda x: min(self.max_num_seqs, x) max_bucket_cfg = 64 - if ( - self.lora_config - and max_bucket_cfg > self.max_num_batched_tokens // 
self.block_size - ): + if (self.lora_config and max_bucket_cfg + > self.max_num_batched_tokens // self.block_size): max_bucket_cfg = self.max_num_batched_tokens // self.block_size blocks_step = 128 # FIXME: The default values should be max_model_len @@ -743,18 +708,14 @@ def _setup_buckets(self) -> None: ) self.graphed_buckets: Set[Any] = set() - msg = ( - "Prompt bucket config (min, step, max_warmup) " - f"bs:{self.prompt_bs_bucket_cfg}, " - f"seq:{self.prompt_seq_bucket_cfg}" - ) + msg = ("Prompt bucket config (min, step, max_warmup) " + f"bs:{self.prompt_bs_bucket_cfg}, " + f"seq:{self.prompt_seq_bucket_cfg}") logger.info(msg) - msg = ( - "Decode bucket config (min, step, max_warmup) " - f"bs:{self.decode_bs_bucket_cfg}, " - f"block:{self.decode_block_bucket_cfg}" - ) + msg = ("Decode bucket config (min, step, max_warmup) " + f"bs:{self.decode_bs_bucket_cfg}, " + f"block:{self.decode_block_bucket_cfg}") logger.info(msg) def _prepare_prompt( @@ -784,16 +745,13 @@ def _prepare_prompt( seq_id = seq_ids[0] computed_block_nums = seq_group_metadata.computed_block_nums - if ( - self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not ( - computed_block_nums is None or computed_block_nums == [] - ) - ): + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): raise RuntimeError( - "chunked prefill cannot be used with prefix caching " "now." - ) + "chunked prefill cannot be used with prefix caching " + "now.") token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] @@ -805,11 +763,9 @@ def _prepare_prompt( seq_lens.append(seq_len) # NOTE: This only works for oooooooxxx style attention. - if ( - computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - ): + if (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None): # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] @@ -838,8 +794,7 @@ def _prepare_prompt( if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data - ) + seq_group_metadata.multi_modal_data.data) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -860,8 +815,7 @@ def _prepare_prompt( if self.sliding_window is not None: assert context_len == 0, ( "Prefix caching is currently not supported with " - "sliding window attention" - ) + "sliding window attention") start_idx = max(0, seq_len - self.sliding_window) for i in range(context_len, seq_len): if i < start_idx: @@ -881,11 +835,9 @@ def _prepare_prompt( if multi_modal_input_list: assert self.multimodal_config, ( "Multi-modal inputs are only supported by " - "vision language models." 
- ) - multi_modal_input = torch.cat(multi_modal_input_list, dim=0).to( - self.device - ) + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) else: multi_modal_input = None @@ -919,9 +871,8 @@ def _prepare_prompt( self.lora_config.max_lora_rank, dtype=self.lora_config.lora_dtype, ) - for seq_group_metadata, context_len in zip( - seq_group_metadata_list, context_lens - ): + for seq_group_metadata, context_len in zip(seq_group_metadata_list, + context_lens): lora_id = seq_group_metadata.lora_int_id if lora_id > 0: @@ -936,13 +887,9 @@ def _prepare_prompt( lora_index_mapping += [lora_id] * (max_prompt_len - context_len) lora_prompt_mapping.extend( - [lora_id] - * ( - max_prompt_len - context_len - if seq_group_metadata.sampling_params.prompt_logprobs - else 1 - ) - ) + [lora_id] * + (max_prompt_len - context_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if lora_mask is not None: lora_mask = lora_mask.to("hpu") @@ -972,9 +919,9 @@ def _prepare_prompt( device=self.device, ) - seq_lens_tensor = torch.tensor( - seq_lens, dtype=torch.long, device=self.device - ) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.long, + device=self.device) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -1035,8 +982,7 @@ def _prepare_decode( ) dummy_slots = itertools.cycle( - range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size) - ) + range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -1061,11 +1007,8 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - seq_len = ( - seq_len - if self.sliding_window is None - else min(seq_len, self.sliding_window) - ) + seq_len = (seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window)) seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] @@ -1080,21 +1023,20 @@ def _prepare_decode( lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: - sliding_window_blocks = ( - self.sliding_window // self.block_size - ) + sliding_window_blocks = (self.sliding_window // + self.block_size) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) if lora_mask is not None: lora_mask = lora_mask.to("hpu") lora_logits_mask = lora_mask - input_tokens = torch.tensor( - input_tokens, dtype=torch.long, device=self.device - ) - input_positions = torch.tensor( - input_positions, dtype=torch.long, device=self.device - ) + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) num_decode_tokens = sum(seq_lens) @@ -1104,38 +1046,34 @@ def _prepare_decode( [i] * b_u for i, b_u in enumerate(blocks_used) ] block_mapping: List[int] = list( - itertools.chain.from_iterable(block_mapping_nested) - ) + itertools.chain.from_iterable(block_mapping_nested)) last_block = [ sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) ] - block_usage = [ - [self.block_size] * (b_u - 1) + [lb] - for b_u, lb in zip(blocks_used, last_block) - ] + block_usage = [[self.block_size] * (b_u - 1) + [lb] + for b_u, lb in zip(blocks_used, last_block)] block_usage = list(itertools.chain(*block_usage)) - block_bucket_size = find_bucket( - len(block_list), self.decode_block_bucket_cfg - ) + block_bucket_size = find_bucket(len(block_list), + self.decode_block_bucket_cfg) block_list = 
pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) block_mapping = pad_list(block_mapping, block_bucket_size, 0) block_usage = pad_list(block_usage, block_bucket_size, 0) - block_list = torch.tensor( - block_list, dtype=torch.int, device=self.device - ) - block_mapping = torch.tensor( - block_mapping, dtype=torch.int, device=self.device - ) - block_usage = torch.tensor( - block_usage, dtype=torch.bfloat16, device=self.device - ) + block_list = torch.tensor(block_list, + dtype=torch.int, + device=self.device) + block_mapping = torch.tensor(block_mapping, + dtype=torch.int, + device=self.device) + block_usage = torch.tensor(block_usage, + dtype=torch.bfloat16, + device=self.device) - slot_mapping = torch.tensor( - slot_mapping, dtype=torch.long, device=self.device - ) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, @@ -1185,17 +1123,13 @@ def prepare_input_tensors( self.profiler.start("internal", base_event_name) real_batch_size = len(seq_group_metadata_list) - bucket_cfg = ( - self.prompt_bs_bucket_cfg - if is_prompt - else self.decode_bs_bucket_cfg - ) + bucket_cfg = (self.prompt_bs_bucket_cfg + if is_prompt else self.decode_bs_bucket_cfg) batch_size_padded = find_bucket(real_batch_size, bucket_cfg) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() - seq_group_metadata_list.extend( - seq_group_metadata_list[0] for _ in range(batch_size_padding) - ) + seq_group_metadata_list.extend(seq_group_metadata_list[0] + for _ in range(batch_size_padding)) prefill_reqs = [] decode_reqs = [] @@ -1250,8 +1184,8 @@ def prepare_input_tensors( # support mixed batches, so we either use decode or prefill # inputs, without coalescing. assert (num_prefills == 0 and num_decode_tokens > 0) or ( - num_prefills > 0 and num_decode_tokens == 0 - ), "HPU does not support mixed batches!" + num_prefills > 0 + and num_decode_tokens == 0), "HPU does not support mixed batches!" 
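# --- Illustrative aside (not part of the diff): how the decode metadata built
# above is laid out. A self-contained sketch with block_size=4 and two
# hypothetical sequences; the real code uses the scheduler's block tables and
# then pads the results up to the decode block bucket size.
import itertools

block_size = 4
block_tables = [[7, 8], [3]]     # physical block ids per sequence
slot_mapping = [[33], [14]]      # slot written by each sequence's decode step

blocks_used = [len(bt) for bt in block_tables]                   # [2, 1]
block_list = list(itertools.chain(*block_tables))                # [7, 8, 3]
block_mapping = list(itertools.chain.from_iterable(
    [i] * b_u for i, b_u in enumerate(blocks_used)))             # [0, 0, 1]
last_block = [sl % block_size + 1
              for sl in itertools.chain(*slot_mapping)]          # [2, 3]
block_usage = list(itertools.chain(*[[block_size] * (b_u - 1) + [lb]
                                     for b_u, lb in zip(blocks_used,
                                                        last_block)]))
# block_usage == [4, 2, 3]: fill level of every block in block_list, which is
# what the HPU attention path consumes after bucket padding.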
if num_decode_tokens > 0: input_tokens = decode_input_tokens input_positions = decode_input_positions @@ -1270,10 +1204,8 @@ def prepare_input_tensors( paddings = list(itertools.accumulate(paddings)) paddings_prompt_logprobs = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): - if ( - seq_group_metadata.sampling_params.prompt_logprobs is not None - and seq_group_metadata.is_prompt - ): + if (seq_group_metadata.sampling_params.prompt_logprobs is not None + and seq_group_metadata.is_prompt): paddings_prompt_logprobs += [paddings[i]] * seq_lens[i] paddings = torch.tensor( paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, @@ -1290,10 +1222,8 @@ def prepare_input_tensors( else: lora_mapping = None - if ( - prefill_attn_metadata is not None - and decode_attn_metadata is not None - ): + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): batch_type = BatchType.MIXED raise NotImplementedError("Mixed batch is not supported on HPU") elif prefill_attn_metadata is not None: @@ -1322,11 +1252,8 @@ def prepare_input_tensors( assert decode_attn_metadata is not None metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) - attn_metadata = ( - prefill_attn_metadata - if prefill_attn_metadata is not None - else decode_attn_metadata - ) + attn_metadata = (prefill_attn_metadata if prefill_attn_metadata + is not None else decode_attn_metadata) return self._model_input_cls( input_tokens=input_tokens, @@ -1385,9 +1312,11 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: ) return attention_metadata - def create_dummy_seq_group_metadata( - self, group_id, seq_len, is_prompt, lora_request=None - ): + def create_dummy_seq_group_metadata(self, + group_id, + seq_len, + is_prompt, + lora_request=None): sampling_params = SamplingParams(temperature=0) num_blocks = math.ceil(seq_len / self.block_size) if is_prompt: @@ -1423,17 +1352,18 @@ def profile_run(self) -> None: self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches) return - def warmup_scenario( - self, batch_size, seq_len, is_prompt, kv_caches, is_profile_run=False - ) -> None: + def warmup_scenario(self, + batch_size, + seq_len, + is_prompt, + kv_caches, + is_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) - scenario_name = ( - "warmup_" - f"{'prompt' if is_prompt else 'decode'}_" - f"bs{batch_size}_" - f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}" - ) + scenario_name = ("warmup_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}") max_num_seqs = self.scheduler_config.max_num_seqs # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory @@ -1451,9 +1381,8 @@ def warmup_scenario( lora_int_id=lora_id, lora_local_path="/not/a/real/path", ) - self.lora_manager.add_dummy_lora( - dummy_lora_request, rank=LORA_WARMUP_RANK - ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) dummy_lora_requests.append(dummy_lora_request) dummy_lora_requests_per_seq = [ dummy_lora_requests[idx % len(dummy_lora_requests)] @@ -1474,10 +1403,8 @@ def warmup_scenario( seq_len, is_prompt, lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq - else None, - ) - for i in range(batch_size) + if dummy_lora_requests_per_seq else None, + ) for i in range(batch_size) ] else: # FIXME: seq_len is actually number of blocks @@ -1489,10 +1416,8 @@ def 
warmup_scenario( b * self.block_size - 1, is_prompt, lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq - else None, - ) - for i, b in enumerate(blocks) + if dummy_lora_requests_per_seq else None, + ) for i, b in enumerate(blocks) ] torch.hpu.synchronize() profiler = None @@ -1515,9 +1440,8 @@ def remove_all_loras(self): raise RuntimeError("LoRA is not enabled.") self.lora_manager.remove_all_adapters() - def set_active_loras( - self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping - ) -> None: + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) @@ -1544,17 +1468,14 @@ def list_loras(self) -> Set[int]: def log_warmup(self, phase, i, max_i, batch_size, seq_len): free_mem = format_bytes( - HabanaMemoryProfiler.current_free_device_memory() - ) + HabanaMemoryProfiler.current_free_device_memory()) dim = "num_blocks" if phase == "Prompt": dim = "seq_len" - msg = ( - f"[Warmup][{phase}][{i+1}/{max_i}] " - f"batch_size:{batch_size} " - f"{dim}:{seq_len} " - f"free_mem:{free_mem}" - ) + msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " + f"batch_size:{batch_size} " + f"{dim}:{seq_len} " + f"free_mem:{free_mem}") logger.info(msg) def warmup_all_buckets(self, buckets, is_prompt, kv_caches): @@ -1592,8 +1513,7 @@ def warmup_graphs( ordering = lambda b: (-b[0], b[1]) else: raise NotImplementedError( - f"Unsupported graph allocation strategy: {strategy}" - ) + f"Unsupported graph allocation strategy: {strategy}") buckets = list(sorted(buckets, key=ordering)) captured_all = True for idx, (batch_size, seq_len) in enumerate(buckets): @@ -1610,9 +1530,8 @@ def warmup_graphs( self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - used_mem = align_workers( - mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX - ) + used_mem = align_workers(mem_prof.consumed_device_memory, + torch.distributed.ReduceOp.MAX) available_mem -= used_mem total_mem += used_mem total_batch_seq += batch_seq @@ -1622,15 +1541,14 @@ def warmup_graphs( def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): num_candidates = len(buckets) phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' - graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt) + graphed = list(c[:2] for c in self.graphed_buckets + if c[2] == is_prompt) if num_candidates == 0: num_candidates = 1 - msg = ( - f"{phase} captured:{len(graphed)} " - f"({100 * len(graphed) / num_candidates:.1f}%) " - f"used_mem:{format_bytes(total_mem)} " - f"buckets:{sorted(list(graphed))}" - ) + msg = (f"{phase} captured:{len(graphed)} " + f"({100 * len(graphed) / num_candidates:.1f}%) " + f"used_mem:{format_bytes(total_mem)} " + f"buckets:{sorted(list(graphed))}") logger.info(msg) @torch.inference_mode() @@ -1641,9 +1559,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: graphs = graph == "t" if graphs: self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) - self.warmup_scenario( - int(bs), int(seq_len), is_prompt, kv_caches, True - ) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, + True) raise AssertionError("Finished profiling") if os.environ.get("VLLM_SKIP_WARMUP", "false").lower() == "true": logger.info("Skipping warmup...") @@ -1658,34 +1575,29 @@ def warmup_model(self, 
kv_caches: List[torch.Tensor]) -> None: ) if self.lora_config: self.prompt_buckets[:] = [ - bucket - for bucket in self.prompt_buckets + bucket for bucket in self.prompt_buckets if self._is_valid_bucket(bucket) ] msg = ( f"Generated {len(self.prompt_buckets)} " - f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}" - ) + f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}") logger.info(msg) - msg = ( - f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})" - ) + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") logger.info(msg) msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" logger.debug(msg) self.decode_buckets = generate_decode_buckets( - self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, max_blocks - ) + self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, + max_blocks) if self.lora_config: self.decode_buckets[:] = [ - bucket - for bucket in self.decode_buckets + bucket for bucket in self.decode_buckets if self._is_valid_bucket(bucket) ] logger.info( @@ -1697,9 +1609,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() - compile_only_mode_context = functools.partial( - bc.env_setting, "PT_COMPILE_ONLY_MODE", True - ) + compile_only_mode_context = functools.partial(bc.env_setting, + "PT_COMPILE_ONLY_MODE", + True) can_use_compile_only_mode = True try: with compile_only_mode_context(): @@ -1707,11 +1619,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.debug("Using PT_COMPILE_ONLY_MODE.") except KeyError: can_use_compile_only_mode = False - logger.warning( - "Cannot use PT_COMPILE_ONLY_MODE. " - "Warmup time will be negatively impacted. " - "Please update Gaudi Software Suite." - ) + logger.warning("Cannot use PT_COMPILE_ONLY_MODE. " + "Warmup time will be negatively impacted. " + "Please update Gaudi Software Suite.") with compile_only_mode_context() if can_use_compile_only_mode \ else contextlib.nullcontext(): self.warmup_all_buckets(self.prompt_buckets, True, kv_caches) @@ -1720,37 +1630,29 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, ( "HabanaWorker.determine_num_available_blocks needs " - "to be called before warming up the model." 
- ) + "to be called before warming up the model.") free_mem = HabanaMemoryProfiler.current_free_device_memory() graph_free_mem = free_mem - self.mem_margin - graph_free_mem = align_workers( - graph_free_mem, torch.distributed.ReduceOp.MIN - ) + graph_free_mem = align_workers(graph_free_mem, + torch.distributed.ReduceOp.MIN) prompt_graph_mem_ratio = float( - os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.5") - ) - prompt_available_memory = ( - prompt_graph_mem_ratio * graph_free_mem - ) - decode_available_memory = ( - graph_free_mem - prompt_available_memory - ) + os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.5")) + prompt_available_memory = (prompt_graph_mem_ratio * + graph_free_mem) + decode_available_memory = (graph_free_mem - + prompt_available_memory) msg = ( f"Using {format_bytes(graph_free_mem)}" f"/{format_bytes(free_mem)} " "of free device memory for HPUGraphs, " f"{format_bytes(prompt_available_memory)} for prompt and " f"{format_bytes(decode_available_memory)} for decode " - f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})" - ) + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})") logger.info(msg) - prompt_strategy = os.environ.get( - "VLLM_GRAPH_PROMPT_STRATEGY", "min_tokens" - ) - decode_strategy = os.environ.get( - "VLLM_GRAPH_DECODE_STRATEGY", "max_bs" - ) + prompt_strategy = os.environ.get("VLLM_GRAPH_PROMPT_STRATEGY", + "min_tokens") + decode_strategy = os.environ.get("VLLM_GRAPH_DECODE_STRATEGY", + "max_bs") mem_post_prompt, prompt_batch_seq, prompt_captured_all = ( self.warmup_graphs( prompt_strategy, @@ -1758,8 +1660,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: True, kv_caches, prompt_available_memory, - ) - ) + )) mem_post_decode, decode_batch_seq, decode_captured_all = ( self.warmup_graphs( decode_strategy, @@ -1767,17 +1668,13 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: False, kv_caches, decode_available_memory, - ) - ) + )) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space # left. Let's try to use it for capturing more prompt buckets. - if ( - mem_post_decode + mem_post_prompt < graph_free_mem - and not prompt_captured_all - and decode_captured_all - ): + if (mem_post_decode + mem_post_prompt < graph_free_mem + and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( prompt_strategy, @@ -1787,17 +1684,13 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq, - ) - ) + )) # Not all decode buckets were captured, but all prompt buckets # were captured and we have some free graph-allocated space # left. Let's try to use it for capturing more decode buckets. 
- if ( - mem_post_decode + mem_post_prompt < graph_free_mem - and not decode_captured_all - and prompt_captured_all - ): + if (mem_post_decode + mem_post_prompt < graph_free_mem + and not decode_captured_all and prompt_captured_all): mem_post_decode, _, _ = self.warmup_graphs( decode_strategy, self.decode_buckets, @@ -1808,20 +1701,17 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: decode_batch_seq, ) - self.log_graph_warmup_summary( - self.prompt_buckets, True, mem_post_prompt - ) - self.log_graph_warmup_summary( - self.decode_buckets, False, mem_post_decode - ) + self.log_graph_warmup_summary(self.prompt_buckets, True, + mem_post_prompt) + self.log_graph_warmup_summary(self.decode_buckets, False, + mem_post_decode) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage() elapsed_time = end_time - start_time msg = ( f"Warmup finished in {elapsed_time:.0f} secs, " - f"allocated {format_bytes(end_mem - start_mem)} of device memory" - ) + f"allocated {format_bytes(end_mem - start_mem)} of device memory") logger.info(msg) self.profiler.end() @@ -1839,16 +1729,14 @@ def mem_margin(self, value): def _maybe_wrap_in_hpu_graph(*args, **kwargs): - return ( - htorch.hpu.wrap_in_hpu_graph( - HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True - ) - if htorch.utils.internal.is_lazy() - else HpuModelAdapter(*args, **kwargs) - ) + return (htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(*args, **kwargs), + disable_tensor_cache=True) + if htorch.utils.internal.is_lazy() else HpuModelAdapter( + *args, **kwargs)) class HabanaProfilerCounterHelper: + def __init__(self): self.niter = 0 self.average_real_throughput = None @@ -1888,8 +1776,8 @@ def get_counter_dict( self.average_real_throughput = throughput_effective else: # https://www.heikohoffmann.de/htmlthesis/node134.html self.average_real_throughput = self.average_real_throughput + 1 / ( - self.niter + 1 - ) * (throughput_effective - self.average_real_throughput) + self.niter + 1) * (throughput_effective - + self.average_real_throughput) phase = "prompt" if is_prompt else "decode" counters = { f"{phase}_bucket_batch_size": batch_size_padded, @@ -1905,48 +1793,38 @@ def get_counter_dict( self.niter += 1 if is_prompt: prompt_bucket_in_throughput = (seq_len * batch_size_padded) / ( - duration / 1e6 - ) - prompt_real_in_throughput = sum(self.prompt_seq_lens) / ( - duration / 1e6 - ) + duration / 1e6) + prompt_real_in_throughput = sum( + self.prompt_seq_lens) / (duration / 1e6) counters[f"{phase}_bucket_in_throughput"] = ( - prompt_bucket_in_throughput - ) + prompt_bucket_in_throughput) counters[f"{phase}_real_in_throughput"] = prompt_real_in_throughput # KV cache might not be created yet (e.g. 
for profiling run) - if ( - cache_config.num_gpu_blocks is not None - and cache_config.num_gpu_blocks != 0 - ): + if (cache_config.num_gpu_blocks is not None + and cache_config.num_gpu_blocks != 0): cache_num_blocks_used = [ math.ceil(sl / cache_config.block_size) for sl in self.real_seq_lens ] cache_total_num_blocks_used = sum(cache_num_blocks_used) num_cache_blocks = cache_config.num_gpu_blocks - cache_total_num_free_blocks = ( - num_cache_blocks - cache_total_num_blocks_used - ) - cache_computed_utilization = ( - cache_total_num_blocks_used / num_cache_blocks - ) + cache_total_num_free_blocks = (num_cache_blocks - + cache_total_num_blocks_used) + cache_computed_utilization = (cache_total_num_blocks_used / + num_cache_blocks) max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size) batch_block_utilization = cache_total_num_blocks_used / ( - batch_size_padded * max_blocks_per_seq - ) + batch_size_padded * max_blocks_per_seq) counters["cache_num_blocks_used"] = cache_total_num_blocks_used counters["cache_num_free_blocks"] = cache_total_num_free_blocks counters["cache_computed_utilization"] = cache_computed_utilization counters[f"{phase}_batch_block_utilization"] = ( - batch_block_utilization - ) + batch_block_utilization) if not self.logged_once: counters["const_cache_num_blocks"] = cache_config.num_gpu_blocks counters["const_gpu_memory_utilization"] = ( - cache_config.gpu_memory_utilization - ) + cache_config.gpu_memory_utilization) counters["const_block_size"] = cache_config.block_size self.logged_once = True return counters @@ -1962,15 +1840,13 @@ def unwrap_model(model): class HabanaModelRunner( - HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata] -): + HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): """ GPU model runner with sampling step. 
""" _model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = ( - ModelInputForHPUWithSamplingMetadata - ) + ModelInputForHPUWithSamplingMetadata) def make_model_input_from_broadcasted_tensor_dict( self, @@ -1980,8 +1856,7 @@ def make_model_input_from_broadcasted_tensor_dict( ModelInputForHPUWithSamplingMetadata.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, - ) - ) + )) def prepare_model_input( self, @@ -2001,11 +1876,9 @@ def prepare_model_input( with self.profiler.record_event("internal", "prepare_input_tensors"): assert seq_group_metadata_list is not None self.profiler_counter_helper.capture_seq_group_metadata_stats( - seq_group_metadata_list=seq_group_metadata_list - ) + seq_group_metadata_list=seq_group_metadata_list) model_input, sampling_metadata = self.prepare_input_tensors( - seq_group_metadata_list - ) + seq_group_metadata_list) assert model_input.attn_metadata is not None is_prompt = model_input.attn_metadata.is_prompt @@ -2045,15 +1918,13 @@ def execute_model( ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( - "num_steps > 1 is not supported in HabanaModelRunner" - ) + "num_steps > 1 is not supported in HabanaModelRunner") if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None - self.set_active_loras( - model_input.lora_requests, model_input.lora_mapping - ) + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata @@ -2086,19 +1957,18 @@ def execute_model( htorch.core.mark_step() if self.is_driver_worker: - model_event_name = ( - "model_" - f"{'prompt' if is_prompt else 'decode'}_" - f"bs{batch_size}_" - f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}" - ) + model_event_name = ("model_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = "model_executable" with self.profiler.record_event("internal", model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, - selected_token_indices=sampling_metadata.selected_token_indices, + selected_token_indices=sampling_metadata. + selected_token_indices, ) if self.lora_config: @@ -2109,27 +1979,23 @@ def execute_model( if isinstance(module, VocabParallelEmbeddingWithLoRA): for i in range(0, len(module.indices_len)): module.indices_len[i] = ( - sampling_metadata.selected_token_indices.numel() - ) + sampling_metadata.selected_token_indices.numel()) lora_logits_mask: torch.Tensor = model_input.lora_logits_mask LoraMask.setLoraMask( lora_logits_mask.index_select( - 0, sampling_metadata.selected_token_indices - ) - ) + 0, sampling_metadata.selected_token_indices)) # Compute the logits. with self.profiler.record_event( - "internal", - ( - 'compute_logits_' - f'{"prompt" if is_prompt else "decode"}_bs' - f'{batch_size}_' - f'seq{seq_len}' - ), + "internal", + ('compute_logits_' + f'{"prompt" if is_prompt else "decode"}_bs' + f'{batch_size}_' + f'seq{seq_len}'), ): sampling_metadata.selected_token_indices = None - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, + sampling_metadata) htorch.core.mark_step() # Only perform sampling in the driver worker. if not self.is_driver_worker: @@ -2137,13 +2003,11 @@ def execute_model( # Sample the next token. 
with self.profiler.record_event( - "internal", - ( - 'sample_' - f'{"prompt" if is_prompt else "decode"}_' - f'bs{batch_size}_' - f'seq{seq_len}' - ), + "internal", + ('sample_' + f'{"prompt" if is_prompt else "decode"}_' + f'bs{batch_size}_' + f'seq{seq_len}'), ): output = self.model.sample( logits=logits, @@ -2170,8 +2034,7 @@ def execute_model( def shutdown_inc(self): print("inc shutdown") if (model_config := getattr(self, "model_config", None)) and getattr( - model_config, "quantization", None - ) == "inc": + model_config, "quantization", None) == "inc": print("inc shutdown start") from neural_compressor.torch.quantization import ( finalize_calibration) From 0f40204b99b89b24995bd17b9ffe328510a156ab Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Tue, 17 Sep 2024 17:42:41 +0300 Subject: [PATCH 10/24] revert commit --- vllm/hpu/ops.py | 80 +- .../layers/fused_moe/fused_moe.py | 278 +++---- .../compressed_tensors/compressed_tensors.py | 114 ++- .../schemes/compressed_tensors_w8a8_fp8.py | 56 +- .../model_executor/layers/quantization/fp8.py | 205 ++--- .../layers/quantization/utils/w8a8_utils.py | 110 +-- vllm/model_executor/models/llama.py | 112 +-- vllm/worker/habana_model_runner.py | 738 ++++++++---------- 8 files changed, 688 insertions(+), 1005 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index ddb27de19d75..323a33e9fa2a 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -16,7 +16,6 @@ HPUFusedRMSNorm = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm - HPUFusedRMSNorm = FusedRMSNorm except ImportError: logger.warning("Could not import HPU FusedRMSNorm kernel. " @@ -24,7 +23,6 @@ HPUFusedSDPA = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA - HPUFusedSDPA = FusedSDPA except ImportError: logger.warning("Could not import HPU FusedSDPA kernel. " @@ -63,19 +61,9 @@ def block_softmax(batch_size, attn, block_mapping): return attn -def flat_pa( - query, - key_cache, - value_cache, - block_list, - block_mapping, - block_bias, - scale, - matmul_qk_op, - matmul_av_op, - keys_fetch_func, - values_fetch_func, -): +def flat_pa(query, key_cache, value_cache, block_list, block_mapping, + block_bias, scale, matmul_qk_op, matmul_av_op, keys_fetch_func, + values_fetch_func): batch_size = query.size(0) q_heads = query.size(1) kv_heads = key_cache.size(2) @@ -109,7 +97,7 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor: return F.silu(x[..., :d]) * x[..., d:] -# TODO: remove after fusedsdpa fix for query_head != kv_head +#TODO: remove after fusedsdpa fix for query_head != kv_head def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
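# --- Illustrative aside (not part of the diff): the docstring above states
# that repeat_kv is equivalent to torch.repeat_interleave(x, dim=1,
# repeats=n_rep). A minimal numerical check, assuming the usual
# expand/reshape implementation of that helper:
import torch

def repeat_kv_sketch(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, kv_heads, seq_len, head_dim = kv.shape
    if n_rep == 1:
        return kv
    kv = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return kv.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

kv = torch.randn(2, 4, 3, 8)   # (batch, kv_heads, seq_len, head_dim)
assert torch.equal(repeat_kv_sketch(kv, 2),
                   torch.repeat_interleave(kv, repeats=2, dim=1))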
@@ -156,25 +144,15 @@ def prompt_attention( if query_heads != kv_heads: attn_weights = attn_weights.flatten(1, 2) else: - # TODO: remove after fusedsdpa fix for query_heads != kv_heads + #TODO: remove after fusedsdpa fix for query_heads != kv_heads if query_heads != kv_heads: key = repeat_kv(key, int(query_heads // kv_heads)) value = repeat_kv(value, int(query_heads // kv_heads)) - softmax_mode = "fast" + softmax_mode = 'fast' recompute_mode = True - attn_weights = FusedSDPA.apply( - query, - key, - value, - None, - 0.0, - True, - scale, - softmax_mode, - recompute_mode, - valid_seq_lengths, - "right", - ) + attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True, + scale, softmax_mode, recompute_mode, + valid_seq_lengths, 'right') attn_weights = attn_weights.transpose(1, 2) return attn_weights @@ -212,7 +190,7 @@ def dispatch_bgmv_linear( the final output. """ - assert layer_idx == 0, f"layer_idx should be 0, but got {layer_idx}" + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' mask = LoraMask.getLoraMask() wa = wa_t_all[:, 0, :, :] @@ -221,7 +199,7 @@ def dispatch_bgmv_linear( wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2]) out = x @ wa - assert out.shape == mask.shape + assert (out.shape == mask.shape) out = out * mask out = out @ wb y += out * scale @@ -246,7 +224,7 @@ def dispatch_bgmv_embedding( output. """ - assert layer_idx == 0, f"layer_idx should be 0, but got {layer_idx}" + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' max_loras = wb_t_all.size(0) x = x.repeat(1, max_loras) @@ -294,11 +272,9 @@ def forward(self, hidden_states, w1, w2, score, topk): final_hidden_states = torch.zeros((1, B, D), dtype=hidden_states.dtype, device=hidden_states.device) - padded_weights = torch.zeros( - (B, self.num_total_experts), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + padded_weights = torch.zeros((B, self.num_total_experts), + dtype=hidden_states.dtype, + device=hidden_states.device) padded_weights.scatter_(-1, selected_experts, routing_weights) padded_weights = padded_weights.reshape(-1, B, self.num_total_experts) padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) @@ -316,7 +292,6 @@ def forward(self, hidden_states, w1, w2, score, topk): return final_hidden_states.view(-1, D) - # fp8 def scaled_fp8_quant( input: torch.Tensor, @@ -325,6 +300,7 @@ def scaled_fp8_quant( scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Quantize input tensor to FP8 and return quantized tensor and scale. This function supports both static and dynamic quantization: If you @@ -335,11 +311,11 @@ def scaled_fp8_quant( Args: input: The input tensor to be quantized to FP8 scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic + scale_ub: Optional upper bound for scaling factor in dynamic per token case batch_dim_padding: If specified, pad the first dimension of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token + use_per_token_if_dynamic: Whether to do per_tensor or per_token in the dynamic quantization case. 
Returns: Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and @@ -353,24 +329,18 @@ def scaled_fp8_quant( else: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: - raise RuntimeError("dynamic scaled_fp8_quant not implemented for HPU") - # TODO: calculate scale to match gaudi2 240 range instead of 448 + raise "dynamic scaled_fp8_quant not implemented for HPU" + #TODO: calculate scale to match gaudi2 240 range instead of 448 if use_per_token_if_dynamic: - scale = torch.empty( - (input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32, - ) + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant( output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - output = torch.ops.hpu.cast_to_fp8_v2(input, - 1 / scale, - False, - False, - dtype=torch.float8_e4m3fn)[0] + output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, dtype=torch.float8_e4m3fn)[0] - return output, scale + return output, scale \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 39fa611e5157..3682362c5a86 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,5 +1,4 @@ """Fused MoE kernel.""" - import functools import json import os @@ -16,7 +15,6 @@ if current_platform.is_hpu(): from vllm.hpu.ops import scaled_fp8_quant - ops.scaled_fp8_quant = scaled_fp8_quant logger = init_logger(__name__) @@ -118,8 +116,8 @@ def fused_moe_kernel( offs_k[None, :] * stride_ak) off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = (b_ptr + off_experts * stride_be + - (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) if use_fp8: a_scale = tl.load(a_scale_ptr) @@ -135,12 +133,10 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. 
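# --- Illustrative aside (not part of the diff): what the static branch of
# scaled_fp8_quant above computes, sketched with plain torch ops instead of
# the HPU cast_to_fp8_v2 kernel. This is only an approximation for
# illustration; the HPU kernel handles saturation to the device's FP8 range
# internally.
import torch

def static_fp8_quant_sketch(x: torch.Tensor,
                            scale: torch.Tensor) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (x.float() / scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn)

x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = x.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
x_fp8 = static_fp8_quant_sketch(x, scale)
x_approx = x_fp8.to(torch.bfloat16) * scale   # dequantized approximation of x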
- a = tl.load( - a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) @@ -166,8 +162,8 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + - stride_cn * offs_cn[None, :]) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @@ -224,34 +220,21 @@ def moe_align_block_size( num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - ) + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel( - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: Dict[str, Any], - compute_type: tl.dtype, - use_fp8: bool, -) -> None: +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8: bool) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -263,7 +246,7 @@ def invoke_fused_moe_kernel( assert B_scale is not None grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) fused_moe_kernel[grid]( A, @@ -339,17 +322,17 @@ def get_default_config( dtype: Optional[str], ) -> Dict[str, int]: config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 } if M <= E: config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 } return config @@ -385,8 +368,8 @@ def fused_topk( topk: int, renormalize: bool, ): - assert (hidden_states.shape[0] == gating_output.shape[0] - ), "Number of tokens mismatch" + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") M, _ = hidden_states.shape @@ -416,28 +399,27 @@ def fused_topk( # This is used by the Deepseek-V2 model -def grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, -): - assert (hidden_states.shape[0] == gating_output.shape[0] - ), "Number of tokens mismatch" +def 
grouped_topk(hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0): + + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") scores = torch.softmax(gating_output, dim=-1) num_token = scores.shape[0] - group_scores = (scores.view(num_token, num_expert_group, - -1).max(dim=-1).values) # [n, n_group] + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = (group_mask.unsqueeze(-1).expand( + score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, -1)) # [n, e] + scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, @@ -449,20 +431,18 @@ def grouped_topk( return topk_weights, topk_ids -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -): +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None): # Check constraints. 
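# --- Illustrative aside (not part of the diff): how grouped_topk above is
# meant to be called, assuming the grouped_topk defined in this module is in
# scope. A tiny hypothetical router with 8 experts split into 4 groups; only
# the batch dimension of hidden_states is inspected by the routing itself.
import torch

hidden_states = torch.randn(2, 16)      # 2 tokens
gating_output = torch.randn(2, 8)       # router logits: 2 tokens x 8 experts
topk_weights, topk_ids = grouped_topk(hidden_states,
                                      gating_output,
                                      topk=2,
                                      renormalize=True,
                                      num_expert_group=4,
                                      topk_group=2)
# topk_ids[t] holds two expert indices per token, all taken from the two
# best-scoring groups; with renormalize=True each row of topk_weights sums
# to 1.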
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" @@ -491,21 +471,15 @@ def fused_experts( config = get_config_func(M) - intermediate_cache1 = torch.empty( - (M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache3 = torch.empty( - (M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) compute_type = (tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16) @@ -516,10 +490,9 @@ def fused_experts( out_hidden_states = torch.empty_like(hidden_states) for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = ( - chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, num_tokens), - ) + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape @@ -540,51 +513,45 @@ def fused_experts( curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config["BLOCK_SIZE_M"], E)) - - invoke_fused_moe_kernel( - curr_hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) + + invoke_fused_moe_kernel(curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - torch.sum( - intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=out_hidden_states[begin_chunk_idx:end_chunk_idx], - ) + invoke_fused_moe_kernel(intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8) + + torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states @@ -641,29 +608,22 @@ def fused_moe( if use_grouped_topk: assert num_expert_group is not None and topk_group is not None - topk_weights, topk_ids = grouped_topk( - hidden_states, - gating_output, - topk, - 
renormalize, - num_expert_group, - topk_group, - ) + topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group) else: topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) - return fused_experts( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - override_config=override_config, - use_fp8=use_fp8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - ) + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index f600f92efae6..badb29af1f5f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -22,13 +22,12 @@ class CompressedTensorsConfig(QuantizationConfig): - def __init__( - self, - target_scheme_map: Dict[str, Any], - ignore: List[str], - quant_format: str, - kv_cache_scheme: Optional[Dict[str, Any]] = None, - ): + def __init__(self, + target_scheme_map: Dict[str, Any], + ignore: List[str], + quant_format: str, + kv_cache_scheme: Optional[Dict[str, Any]] = None): + self.ignore = ignore self.quant_format = quant_format # Map from [target -> scheme] @@ -59,7 +58,6 @@ def get_quant_method( prefix: str, ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import - if isinstance(layer, LinearBase): return CompressedTensorsLinearMethod(self) if isinstance(layer, Attention): @@ -84,21 +82,20 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": targets = quant_config.get("targets") for target in targets: target_scheme_map[target] = {} - target_scheme_map[target]["weights"] = ( - QuantizationArgs.parse_obj(quant_config.get("weights"))) + target_scheme_map[target][ + "weights"] = QuantizationArgs.parse_obj( + quant_config.get("weights")) try: - target_scheme_map[target]["input_activations"] = ( - QuantizationArgs.parse_obj( - quant_config.get("input_activations"))) + target_scheme_map[target][ + "input_activations"] = QuantizationArgs.parse_obj( + quant_config.get("input_activations")) except Exception: target_scheme_map[target]["input_activations"] = None - return cls( - target_scheme_map=target_scheme_map, - ignore=ignore, - quant_format=quant_format, - kv_cache_scheme=config.get("kv_cache_scheme"), - ) + return cls(target_scheme_map=target_scheme_map, + ignore=ignore, + quant_format=quant_format, + kv_cache_scheme=config.get("kv_cache_scheme")) @classmethod def get_config_filenames(cls) -> List[str]: @@ -114,8 +111,7 @@ def _check_scheme_supported(self, raise RuntimeError( "Quantization scheme is not supported for ", f"the current GPU. Min capability: {min_capability}. ", - f"Current capability: {capability}.", - ) + f"Current capability: {capability}.") return supported def _is_static_tensor_w8a8(self, weight_quant: BaseModel, @@ -158,10 +154,9 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, # Confirm weight scheme is supported. 
is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = weight_quant.strategy in [ - QuantizationStrategy.TENSOR, - QuantizationStrategy.CHANNEL, - ] + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) if not (is_symmetric_weight and is_static_weight and is_per_tensor_or_channel_weight): return False @@ -193,10 +188,9 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = weight_quant.strategy in [ - QuantizationStrategy.TENSOR, - QuantizationStrategy.CHANNEL, - ] + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) if not (is_symmetric_weight and is_static_weight and is_per_tensor_or_channel_weight): return False @@ -219,6 +213,7 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel, def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel) -> "CompressedTensorsScheme": + # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): if (self.quant_format == CompressionFormat.marlin_24.value @@ -226,41 +221,34 @@ def _get_scheme_from_parts( return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, - group_size=weight_quant.group_size, - ) + group_size=weight_quant.group_size) if (self.quant_format == CompressionFormat.pack_quantized.value and weight_quant.num_bits in WNA16_SUPPORTED_BITS): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, - group_size=weight_quant.group_size, - ) + group_size=weight_quant.group_size) # Detect If Activation Quantization. if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): - is_fp8_w8a8_supported = (self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), - error=False, - ) if torch.cuda.is_available() else True) + is_fp8_w8a8_supported = self._check_scheme_supported( + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=(not input_quant.dynamic), - ) + is_static_input_scheme=(not input_quant.dynamic)) else: return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, is_static_input_scheme=(input_quant - and not input_quant.dynamic), - ) + and not input_quant.dynamic)) if self._is_fp8_w8a16(weight_quant, input_quant): return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, is_static_input_scheme=(input_quant - and not input_quant.dynamic), - ) + and not input_quant.dynamic)) if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( @@ -308,15 +296,13 @@ def get_scheme( matched_target = find_matched_target( layer_name=layer_name, module=layer, - targets=self.target_scheme_map.keys(), - ) + targets=self.target_scheme_map.keys()) # Find the quant_scheme scheme_dict = self.target_scheme_map[matched_target] scheme = self._get_scheme_from_parts( weight_quant=scheme_dict["weights"], - input_quant=scheme_dict["input_activations"], - ) + input_quant=scheme_dict["input_activations"]) # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) @@ -334,18 +320,13 @@ def __init__(self, quantization_config: CompressedTensorsConfig): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): """ - Use the CompressedTensorsScheme associated with each layer to create + Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. See LinearMethodBase for param details """ @@ -360,20 +341,17 @@ def create_weights( output_partition_sizes=output_partition_sizes, output_size=output_size, params_dtype=params_dtype, - weight_loader=weight_loader, - ) + weight_loader=weight_loader) layer.scheme = scheme - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ): + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): """ - Use the output of create_weights and the CompressedTensorsScheme - associated with the layer to apply the forward pass with the + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the layer input. See LinearMethodBase for param details """ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ee772c2951e8..631774994b5c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -21,8 +21,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = (cutlass_fp8_supported() - if torch.cuda.is_available() else False) + self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False @classmethod def get_min_capability(cls) -> int: @@ -58,36 +57,25 @@ def process_weights_after_loading(self, layer) -> None: else: layer.input_scale = None - def create_weights( - self, - layer: torch.nn.Module, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, - weight_loader: Callable, - **kwargs, - ): + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes # WEIGHT - weight = torch.nn.Parameter( - torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn, - ), - requires_grad=False, - ) + weight = torch.nn.Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + requires_grad=False) layer.register_parameter("weight", weight) - 
set_weight_attrs( - weight, - { - "input_dim": 1, - "output_dim": 0, - "weight_loader": weight_loader, - }, - ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + "weight_loader": weight_loader, + }) # WEIGHT SCALE layer_kwargs = {"weight_loader": weight_loader} @@ -106,12 +94,11 @@ def create_weights( output_partition_sizes, **layer_kwargs) layer.register_parameter("input_scale", input_scale) - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return apply_fp8_linear( input=x, weight=layer.weight, @@ -119,5 +106,4 @@ def apply_weights( input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, - use_per_token_if_dynamic=True, - ) + use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 28fcceb449af..f3e304ce141c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,10 +23,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once - if current_platform.is_hpu(): from vllm.hpu.ops import scaled_fp8_quant - ops.scaled_fp8_quant = scaled_fp8_quant ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -72,14 +70,12 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) - is_checkpoint_fp8_serialized = "fp8" in quant_method + is_checkpoint_fp8_serialized = ("fp8" in quant_method) activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) - return cls( - is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, - activation_scheme=activation_scheme, - ignored_layers=ignored_layers, - ) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers) def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: @@ -122,8 +118,8 @@ def __init__(self, quant_config: Fp8Config): if torch.cuda.is_available(): self.cutlass_fp8_supported = cutlass_fp8_supported() - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] self.use_marlin = capability < 89 @@ -154,23 +150,16 @@ def create_weights( weight_dtype = (torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype) - weight = Parameter( - torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype, - ), - requires_grad=False, - ) + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + requires_grad=False) layer.register_parameter("weight", weight) - set_weight_attrs( - weight, - { - **extra_weight_attrs, - "input_dim": 1, - "output_dim": 0, - }, - ) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }) # If 
checkpoint is serialized fp8, load them. # Otherwise, wait until process_weights_after_loading. @@ -231,12 +220,11 @@ def process_weights_after_loading(self, layer: Module) -> None: # Activations not quantized for marlin. del layer.input_scale - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_marlin: return apply_fp8_marlin_linear( input=x, @@ -245,8 +233,7 @@ def apply( workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - bias=bias, - ) + bias=bias) return apply_fp8_linear( input=x, @@ -255,8 +242,7 @@ def apply( input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, - use_per_token_if_dynamic=False, - ) + use_per_token_if_dynamic=False) class Fp8MoEMethod(FusedMoEMethodBase): @@ -275,38 +261,27 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config - def create_weights( - self, - layer: Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter( - torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False, - ) + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -338,17 +313,15 @@ def create_weights( "Found static activation scheme for checkpoint that " "was not serialized fp8.") - a13_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), - requires_grad=False, - ) + a13_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) layer.register_parameter("a13_scale", a13_scale) set_weight_attrs(a13_scale, extra_weight_attrs) - a2_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), - requires_grad=False, - ) + a2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) layer.register_parameter("a2_scale", a2_scale) set_weight_attrs(a2_scale, extra_weight_attrs) else: @@ -356,6 +329,7 @@ def create_weights( layer.a2_scale = None def process_weights_after_loading(self, layer: Module) -> None: + # If checkpoint is fp16, quantize in place. 
if not self.quant_config.is_checkpoint_fp8_serialized: w13_weight = torch.empty_like(layer.w13_weight.data, @@ -365,19 +339,18 @@ def process_weights_after_loading(self, layer: Module) -> None: # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter( - torch.ones( - layer.num_experts, - dtype=torch.float32, - device=w13_weight.device, - ), - requires_grad=False, - ) + layer.w13_scale = torch.nn.Parameter(torch.ones( + layer.num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])) - w2_weight[expert, :, :], layer.w2_scale[expert] = ( - ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])) + w13_weight[expert, :, :], layer.w13_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, @@ -395,8 +368,8 @@ def process_weights_after_loading(self, layer: Module) -> None: raise ValueError( "QuantConfig has static quantization, but found " "activation scales are None.") - if not all_close_1d(layer.a13_scale) or not all_close_1d( - layer.a2_scale): + if (not all_close_1d(layer.a13_scale) + or not all_close_1d(layer.a2_scale)): print_warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " @@ -417,50 +390,42 @@ def process_weights_after_loading(self, layer: Module) -> None: dq_weight = per_tensor_dequantize( layer.w13_weight[expert_id][start:start + shard_size, :], - layer.w13_scale[expert_id][shard_id], - ) - ( - layer.w13_weight[expert_id][start:start + - shard_size, :], - _, - ) = ops.scaled_fp8_quant(dq_weight, - max_w13_scales[expert_id]) + layer.w13_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) start += shard_size layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) return - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_moe + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group, - ) + from vllm.model_executor.layers.fused_moe import fused_moe + return fused_moe(x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + 
a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 2ac620742680..8904c9fa1789 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,15 +6,11 @@ from vllm import _custom_ops as ops from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform - if current_platform.is_hpu(): import habana_frameworks.torch.utils.experimental as htexp - from vllm.hpu.ops import scaled_fp8_quant - ops.scaled_fp8_quant = scaled_fp8_quant - def cutlass_fp8_supported() -> bool: capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] @@ -30,8 +26,8 @@ def per_tensor_dequantize( if current_platform.is_hpu(): dtype = torch.bfloat16 if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: - # dequant on cpu to avoid nan on gaudi2 - tensor = tensor.to("cpu") + #dequant on cpu to avoid nan on gaudi2 + tensor = tensor.to('cpu') fake_qweight = tensor.to(dtype).to(device) dq_weight = fake_qweight * inv_scale @@ -47,10 +43,9 @@ def create_per_tensor_scale_param( output_partition_sizes: List[int], **extra_weight_attrs, ) -> Parameter: - scale = Parameter( - torch.empty(len(output_partition_sizes), dtype=torch.float32), - requires_grad=False, - ) + scale = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.float32), + requires_grad=False) scale[:] = torch.finfo(torch.float32).min set_weight_attrs(scale, { "needs_scalar_to_array": True, @@ -61,10 +56,9 @@ def create_per_tensor_scale_param( def create_per_channel_scale_param(output_partition_sizes: List[int], **extra_weight_attrs) -> Parameter: - scale = Parameter( - torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), - requires_grad=False, - ) + scale = Parameter(torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + requires_grad=False) scale[:] = torch.finfo(torch.float32).min set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs}) return scale @@ -74,11 +68,9 @@ def convert_to_channelwise( weight_scale: torch.Tensor, logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Create channelwise buffer - weight_scale_channel = torch.empty( - (sum(logical_widths), 1), - dtype=torch.float32, - device=weight_scale.device, - ) + weight_scale_channel = torch.empty((sum(logical_widths), 1), + dtype=torch.float32, + device=weight_scale.device) # Expand each scale to match the size of each logical matrix. start = 0 @@ -95,18 +87,16 @@ def requantize_with_max_scale( logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. 
max_w_scale = weight_scale.max() - if (current_platform.is_hpu() and htexp._get_device_type() - == htexp.synDeviceType.synDeviceGaudi2): - max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max / - torch.finfo(torch.float8_e4m3fnuz).max) + if current_platform.is_hpu() and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max) # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize # N weight scales for N shards but we only load 1 weight scale # from disk in this case. Skip requantization in this case (since) # we already are quantized with the single scale. # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 - unfused_module_in_checkpoint = (weight_scale[-1] - > torch.finfo(torch.float8_e4m3fn).min) + unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo( + torch.float8_e4m3fn).min) # If unfused checkpoint, need requanize with the single scale. if unfused_module_in_checkpoint: @@ -142,18 +132,15 @@ def apply_fp8_linear( input, input_scale, scale_ub=input_scale_ub, - use_per_token_if_dynamic=use_per_token_if_dynamic, - ) + use_per_token_if_dynamic=use_per_token_if_dynamic) # Fused GEMM_DQ - return ops.cutlass_scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - ) + return ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) # torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token @@ -165,37 +152,26 @@ def apply_fp8_linear( input, input_scale, batch_dim_padding=17, - use_per_token_if_dynamic=use_per_token_if_dynamic, - ) + use_per_token_if_dynamic=use_per_token_if_dynamic) - per_tensor_weights = weight_scale.numel() == 1 - per_tensor_activations = x_scale.numel() == 1 + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ if current_platform.is_hpu(): - # hpu does not support torch._scaled_mm (SW-197036) - output = torch.ops.hpu.fp8_gemm_v2( - qinput, - False, - weight, - False, - None, - input.dtype, - x_scale, - weight_scale, - None, - False, - ) + #hpu does not support torch._scaled_mm (SW-197036) + output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight, + False, None, input.dtype, + x_scale, weight_scale, None, + False) else: - output, _ = torch._scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - ) + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) return torch.narrow(output, 0, 0, input.shape[0]) else: @@ -243,11 +219,9 @@ def apply_int8_linear( # * static, layer.input_scale is scalar and x_scale is input_scale. 
x_q, x_scale = ops.scaled_int8_quant(input, input_scale) - return ops.cutlass_scaled_mm( - x_q, - weight, - scale_a=x_scale, - scale_b=weight_scale, - out_dtype=input.dtype, - bias=bias, - ) + return ops.cutlass_scaled_mm(x_q, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=input.dtype, + bias=bias) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 124386b61e4c..8ccefe7be33f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -21,7 +21,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" - from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch @@ -55,7 +54,7 @@ from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers - +from vllm.platforms import current_platform if current_platform.is_hpu(): import habana_frameworks.torch.core as htcore @@ -77,15 +76,12 @@ def __init__( output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj", - ) + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -163,14 +159,12 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -279,8 +273,8 @@ def __init__( super().__init__() self.config = config self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings @@ -294,14 +288,11 @@ def __init__( self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: LlamaDecoderLayer( - config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - ), - prefix=f"{prefix}.layers", - ) + lambda prefix: LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -330,9 +321,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if current_platform.is_hpu(): + if is_hpu: import habana_frameworks.torch as htorch - htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): layer = self.layers[i] @@ -343,7 +333,7 @@ def forward( 
attn_metadata, residual, ) - if current_platform.is_hpu(): + if is_hpu: htorch.core.mark_step() if not get_pp_group().is_last_rank: @@ -371,12 +361,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" ] embedding_modules = { "embed_tokens": "input_embeddings", @@ -404,13 +390,11 @@ def __init__( self.config = config self.lora_config = lora_config - self.model = LlamaModel( - config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model", - ) + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -443,16 +427,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - input_embeds: Optional[torch.Tensor] = None, + input_embeds: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors, - input_embeds, - ) + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + input_embeds) return model_output def compute_logits(self, hidden_states: torch.Tensor, @@ -474,17 +453,13 @@ def make_empty_intermediate_tensors( device: torch.device) -> IntermediateTensors: return IntermediateTensors({ "hidden_states": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), "residual": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), }) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -513,7 +488,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) continue - for param_name, weight_name, shard_id in stacked_params_mapping: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -556,12 +531,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, - tp_rank, - tp_size, + quantization_param_path, tp_rank, tp_size, self.config.num_hidden_layers, - self.config.__class__.model_type, - ): + self.config.__class__.model_type): if not isinstance(self.model.layers[layer_idx], nn.Identity): layer_self_attn = self.model.layers[layer_idx].self_attn diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 9cb4a4915c01..b0b9114ac2d0 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -60,12 +60,10 @@ LORA_WARMUP_RANK = 8 -def subtuple( - obj: object, - typename: str, - to_copy: List[str], - to_override: Optional[Dict[str, object]] = None, -): +def subtuple(obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None): if obj 
is None: return None if to_override is None: @@ -74,7 +72,7 @@ def subtuple( values = {f: to_override.get(f, getattr(obj, f)) for f in fields} if typename not in _TYPE_CACHE: _TYPE_CACHE[typename] = collections.namedtuple(typename, - " ".join(fields)) + ' '.join(fields)) return _TYPE_CACHE[typename](**values) @@ -86,14 +84,14 @@ def read_bucket_settings(phase: str, dim: str, **defaults): param is either 'min', 'step' or 'max' example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 """ - params = ["min", "step", "max"] - env_vars = [f"VLLM_{phase}_{dim}_BUCKET_{p}".upper() for p in params] + params = ['min', 'step', 'max'] + env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] default_values = [defaults[p] for p in params] values = [ int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) ] for e, v, d in zip(env_vars, values, defaults): - logger.info("%s=%s (default:%s)", e, v, d) + logger.info('%s=%s (default:%s)', e, v, d) return values @@ -101,7 +99,7 @@ def warmup_range(config: Tuple[int, int, int]): """Generate a warmup range. Start from bmin and multiply by 2 until you reach bstep. - Then, increase the values in the range by the value of bstep until you + Then, increase the values in the range by the value of bstep until you reach bmax. Example: @@ -116,8 +114,8 @@ def warmup_range(config: Tuple[int, int, int]): "set VLLM_SKIP_WARMUP=true") base = itertools.repeat(2) ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin) - ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, - ramp_up_acc) + ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \ + ramp_up_acc) stable = range(bstep, bmax + 1, bstep) buckets = list(ramp_up_tw) + list(stable) return list(filter(lambda bucket: bucket >= bmin, buckets)) @@ -142,8 +140,7 @@ def generate_prompt_buckets(bs_bucket_config, filtered_buckets = list( filter( lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, - buckets, - )) + buckets)) if len(filtered_buckets) == 0: # we can handle this if we ignore max_num_batched_tokens @@ -208,29 +205,28 @@ def align_workers(value, op): world_size = torch.distributed.get_world_size() if world_size <= 1: return value - value_t = torch.tensor(value, device="cpu") + value_t = torch.tensor(value, device='cpu') torch.distributed.all_reduce(value_t, op=op, group=group) return value_t.item() def setup_profiler(): schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) - DEVICE = "hpu" + DEVICE = 'hpu' activities = [torch.profiler.ProfilerActivity.CPU] activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == - "hpu" else []) - # from habana_frameworks.torch.activity_profiler import DebugActivity - # debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] + 'hpu' else []) + #from habana_frameworks.torch.activity_profiler import DebugActivity + #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] profiler = torch.profiler.profile( schedule=schedule, activities=activities, - # debug_activities=debug_activities, - on_trace_ready=torch.profiler.tensorboard_trace_handler(".", + #debug_activities=debug_activities, + on_trace_ready=torch.profiler.tensorboard_trace_handler('.', use_gzip=True), record_shapes=False, - with_stack=True, - ) + with_stack=True) return profiler @@ -240,17 +236,17 @@ def pad_list(list, k, v): return list + [v] * padding -class HpuModelAdapter: +class HpuModelAdapter(): def __init__(self, model, block_size, dtype, enforce_eager): self.model = model - self.prefill_use_fusedsdpa = 
os.getenv("VLLM_PROMPT_USE_FUSEDSDPA", - "0").lower() in ["1", "true"] + self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '0').lower() in ['1', 'true'] self.block_size = block_size self.dtype = dtype if not htorch.utils.internal.is_lazy() and not enforce_eager: self.model = torch.compile(self.model, - backend="hpu_backend", + backend='hpu_backend', dynamic=False) def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, @@ -264,17 +260,13 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype=torch.int32).view(1, seq_len).ge( seq_lens_t.unsqueeze(-1)).view( batch_size, 1, 1, seq_len)) - causal_mask = torch.triu( - torch.ones( - (batch_size, 1, seq_len, seq_len), - device=device, - dtype=torch.bool, - ), - diagonal=1, - ) + causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), + device=device, + dtype=torch.bool), + diagonal=1) mask = causal_mask.logical_or(len_mask) - attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata @@ -284,8 +276,8 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): device=device, dtype=torch.int32).unsqueeze(0) mask = mask >= metadata.block_usage.unsqueeze(-1) - attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) block_mapping = torch.nn.functional.one_hot( metadata.block_mapping.to(torch.long), num_classes=batch_size).to(dtype) @@ -307,18 +299,14 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device, def forward(self, *args, **kwargs): kwargs = kwargs.copy() - selected_token_indices = kwargs.pop("selected_token_indices") - if "warmup_mode" in kwargs: - kwargs.pop("warmup_mode") - input_ids = kwargs["input_ids"] - kwargs["attn_metadata"] = self._update_metadata( - kwargs["attn_metadata"], - input_ids.size(0), - input_ids.size(1), - input_ids.device, - self.dtype, - ) - LoraMask.setLoraMask(kwargs.pop("lora_mask")) + selected_token_indices = kwargs.pop('selected_token_indices') + if 'warmup_mode' in kwargs: + kwargs.pop('warmup_mode') + input_ids = kwargs['input_ids'] + kwargs['attn_metadata'] = self._update_metadata( + kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), + input_ids.device, self.dtype) + LoraMask.setLoraMask(kwargs.pop('lora_mask')) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.index_select(0, selected_token_indices) @@ -347,20 +335,18 @@ class PreparePromptMetadata(NamedTuple): @classmethod def empty(cls): - return PreparePromptMetadata( - input_tokens=[], - input_positions=[], - attn_metadata=None, - seq_lens=[], - query_lens=[], - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - multi_modal_input=None, - slot_mapping=[], - lora_mask=None, - lora_logits_mask=None, - ) + return PreparePromptMetadata(input_tokens=[], + input_positions=[], + attn_metadata=None, + seq_lens=[], + query_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_input=None, + slot_mapping=[], + lora_mask=None, + lora_logits_mask=None) class PrepareDecodeMetadata(NamedTuple): @@ -376,17 +362,15 @@ class PrepareDecodeMetadata(NamedTuple): @classmethod def empty(cls): - return PrepareDecodeMetadata( - 
input_tokens=[], - input_positions=[], - attn_metadata=None, - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - slot_mapping=[], - lora_mask=None, - lora_logits_mask=None, - ) + return PrepareDecodeMetadata(input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + lora_mask=None, + lora_logits_mask=None) # How batches are constructed. @@ -399,7 +383,7 @@ class BatchType(IntEnum): MIXED = 2 -TModelInputForHPU = TypeVar("TModelInputForHPU", bound="ModelInputForHPU") +TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU") @dataclasses.dataclass(frozen=True) @@ -410,7 +394,6 @@ class ModelInputForHPU(ModelRunnerInputBase): runners that run additional steps should subclass this method to add additional fields. """ - input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None @@ -458,7 +441,6 @@ class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU): """ Used by the ModelRunner. """ - sampling_metadata: Optional["SamplingMetadata"] = None # Used for speculative decoding. We do not broadcast it because it is only # used by the driver worker. @@ -497,7 +479,6 @@ class HabanaModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): """ Helper class for shared methods between GPU model runners. """ - _model_input_cls: Type[TModelInputForHPU] def __init__( @@ -531,8 +512,8 @@ def __init__( self.enforce_eager = self.model_config.enforce_eager self.max_num_seqs = self.scheduler_config.max_num_seqs self.max_model_len = self.scheduler_config.max_model_len - self.max_num_batched_tokens = ( - self.scheduler_config.max_num_batched_tokens) + self.max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens self.block_size = cache_config.block_size self.pin_memory = is_pin_memory_available() @@ -570,9 +551,9 @@ def _set_gc_threshold(self) -> None: requested_gc_thrs = [0] * len(default_gc_thrs) for i in range(len(default_gc_thrs)): requested_gc_thrs[i] = int( - os.environ.get(f"VLLM_GC_THR_GEN{i}", default_gc_thrs[i])) + os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i])) if requested_gc_thrs == default_gc_thrs: - gc_thr_multiplier = int(os.environ.get("VLLM_GC_THR_MULTIPLIER", + gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER', 2)) requested_gc_thrs = [ t * gc_thr_multiplier for t in default_gc_thrs @@ -581,7 +562,6 @@ def _set_gc_threshold(self) -> None: def load_model(self) -> None: import habana_frameworks.torch.core as htcore - htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: @@ -593,17 +573,16 @@ def load_model(self) -> None: multimodal_config=self.multimodal_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - cache_config=self.cache_config, - ) + cache_config=self.cache_config) msg = ("Pre-loading model weights on " f"{next(self.model.parameters()).device} " f"took {m_getmodel.get_summary_string()}") logger.info(msg) if self.lora_config: - assert (hasattr(self.model, "supported_lora_modules") - and self.model.supported_lora_modules - ), "Model does not support LoRA" + assert hasattr(self.model, "supported_lora_modules" + ) and self.model.supported_lora_modules, ( + "Model does not support LoRA") assert hasattr(self.model, "embedding_modules" ), "Model does not have embedding_modules" assert hasattr( @@ -612,20 +591,16 @@ def load_model(self) -> None: self.lora_manager = 
LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, + self.vocab_size, self.lora_config, self.device, self.model.embedding_modules, - self.model.embedding_padding_modules, - ) + self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) - if self.model_config.quantization == "inc": + if self.model_config.quantization == 'inc': logger.info("Preparing model with INC..") with HabanaMemoryProfiler() as m_inc: from neural_compressor.torch.quantization import ( FP8Config, convert, prepare) - config = FP8Config.from_json_file( os.getenv("QUANT_CONFIG", "")) if config.measure: @@ -634,10 +609,8 @@ def load_model(self) -> None: self.model = convert(self.model, config) htcore.hpu_initialize(self.model, mark_only_scales_as_const=True) - logger.info( - "Preparing model with INC took %s", - m_inc.get_summary_string(), - ) + logger.info("Preparing model with INC took %s", + m_inc.get_summary_string()) else: self.model = self.model.to("hpu") htcore.mark_step() @@ -648,8 +621,7 @@ def load_model(self) -> None: self.model, self.block_size, dtype=self.model_config.dtype, - enforce_eager=self.enforce_eager, - ) + enforce_eager=self.enforce_eager) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) @@ -668,44 +640,36 @@ def _is_valid_bucket(self, bucket): def _setup_buckets(self) -> None: align_bs = lambda x: min(self.max_num_seqs, x) max_bucket_cfg = 64 - if (self.lora_config and max_bucket_cfg - > self.max_num_batched_tokens // self.block_size): + if self.lora_config and \ + max_bucket_cfg > self.max_num_batched_tokens // self.block_size: max_bucket_cfg = self.max_num_batched_tokens // self.block_size blocks_step = 128 - # FIXME: The default values should be max_model_len + #FIXME: The default values should be max_model_len max_prompt_seq = 1024 max_decode_seq = 2048 self.prompt_bs_bucket_cfg = read_bucket_settings( - "prompt", - "bs", + 'prompt', + 'bs', min=1, step=align_bs(32), - max=align_bs(max_bucket_cfg), - ) - self.decode_bs_bucket_cfg = read_bucket_settings( - "decode", - "bs", - min=align_bs(32), - step=align_bs(32), - max=self.max_num_seqs, - ) - self.prompt_seq_bucket_cfg = read_bucket_settings( - "prompt", - "seq", - min=self.block_size, - step=self.block_size, - max=max_prompt_seq, - ) + max=align_bs(max_bucket_cfg)) + self.decode_bs_bucket_cfg = read_bucket_settings('decode', + 'bs', + min=align_bs(32), + step=align_bs(32), + max=self.max_num_seqs) + self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', + 'seq', + min=self.block_size, + step=self.block_size, + max=max_prompt_seq) self.decode_block_bucket_cfg = read_bucket_settings( - "decode", - "block", + 'decode', + 'block', min=blocks_step, step=blocks_step, - max=max( - blocks_step, - self.max_num_seqs * max_decode_seq // self.block_size, - ), - ) + max=max(blocks_step, + self.max_num_seqs * max_decode_seq // self.block_size)) self.graphed_buckets: Set[Any] = set() msg = ("Prompt bucket config (min, step, max_warmup) " @@ -763,9 +727,8 @@ def _prepare_prompt( seq_lens.append(seq_len) # NOTE: This only works for oooooooxxx style attention. 
- if (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None): + if computed_block_nums is not None and len( + computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] @@ -843,8 +806,7 @@ def _prepare_prompt( max_prompt_len = max( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), - self.block_size, - ) + self.block_size) lora_mask: torch.Tensor = None lora_logits_mask: torch.Tensor = None @@ -853,24 +815,18 @@ def _prepare_prompt( lora_mask = torch.zeros( len(seq_group_metadata_list) * max_prompt_len, (self.lora_config.max_loras) * self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype, - ) - lora_logits_mask = torch.zeros( - len(seq_group_metadata_list), - (self.lora_config.max_loras) * self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype, - ) - - ones = torch.ones( - max_prompt_len, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype, - ) - logit_ones = torch.ones( - 1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype, - ) + dtype=self.lora_config.lora_dtype) + lora_logits_mask = torch.zeros(len(seq_group_metadata_list), + (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + ones = torch.ones(max_prompt_len, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + logit_ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) for seq_group_metadata, context_len in zip(seq_group_metadata_list, context_lens): lora_id = seq_group_metadata.lora_int_id @@ -892,32 +848,26 @@ def _prepare_prompt( if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if lora_mask is not None: - lora_mask = lora_mask.to("hpu") - lora_logits_mask = lora_logits_mask.to("hpu") - - input_tokens = make_tensor_with_pad( - input_tokens, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device, - ) - - input_positions = make_tensor_with_pad( - input_positions, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device, - ) - - slot_mapping = make_tensor_with_pad( - slot_mapping, - max_len=max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device, - ) + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_logits_mask.to('hpu') + + input_tokens = make_tensor_with_pad(input_tokens, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + + input_positions = make_tensor_with_pad(input_positions, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + + slot_mapping = make_tensor_with_pad(slot_mapping, + max_len=max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, @@ -970,16 +920,13 @@ def _prepare_decode( counter = 0 if self.lora_config: - lora_mask = torch.zeros( - len(seq_group_metadata_list), - (self.lora_config.max_loras) * self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype, - ) - ones = torch.ones( - 1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype, - ) + lora_mask = torch.zeros(len(seq_group_metadata_list), + (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) dummy_slots = 
itertools.cycle( range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) @@ -1007,8 +954,8 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - seq_len = (seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window)) + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] @@ -1029,7 +976,7 @@ def _prepare_decode( block_tables.append(block_table) if lora_mask is not None: - lora_mask = lora_mask.to("hpu") + lora_mask = lora_mask.to('hpu') lora_logits_mask = lora_mask input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -1119,12 +1066,12 @@ def prepare_input_tensors( self.event_start = self.profiler.get_timestamp_us() is_prompt = seq_group_metadata_list[0].is_prompt - base_event_name = "prompt" if is_prompt else "decode" - self.profiler.start("internal", base_event_name) + base_event_name = 'prompt' if is_prompt else 'decode' + self.profiler.start('internal', base_event_name) real_batch_size = len(seq_group_metadata_list) - bucket_cfg = (self.prompt_bs_bucket_cfg - if is_prompt else self.decode_bs_bucket_cfg) + bucket_cfg = self.prompt_bs_bucket_cfg if is_prompt else \ + self.decode_bs_bucket_cfg batch_size_padded = find_bucket(real_batch_size, bucket_cfg) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() @@ -1165,13 +1112,10 @@ def prepare_input_tensors( decode_lora_mask, decode_lora_logits_mask, ) = self._prepare_decode(decode_reqs) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens, - self.device, - self.pin_memory, - ) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + seq_lens, query_lens, + self.device, + self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 @@ -1204,14 +1148,13 @@ def prepare_input_tensors( paddings = list(itertools.accumulate(paddings)) paddings_prompt_logprobs = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): - if (seq_group_metadata.sampling_params.prompt_logprobs is not None - and seq_group_metadata.is_prompt): - paddings_prompt_logprobs += [paddings[i]] * seq_lens[i] + if seq_group_metadata.sampling_params.prompt_logprobs is not None \ + and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) paddings = torch.tensor( paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, dtype=sampling_metadata.selected_token_indices.dtype, - device=sampling_metadata.selected_token_indices.device, - ) + device=sampling_metadata.selected_token_indices.device) sampling_metadata.selected_token_indices.add_(paddings) if self.lora_config: @@ -1244,7 +1187,7 @@ def prepare_input_tensors( "num_prefills": num_prefills, "batch_type": batch_type, "seq_lens": seq_lens, - "query_lens": query_lens, + "query_lens": query_lens } if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) @@ -1252,23 +1195,22 @@ def prepare_input_tensors( assert decode_attn_metadata is not None metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) - attn_metadata = (prefill_attn_metadata if prefill_attn_metadata - is not None else decode_attn_metadata) - - return self._model_input_cls( - input_tokens=input_tokens, - seq_lens=seq_lens, - query_lens=query_lens, - input_positions=input_positions, - attn_metadata=attn_metadata, 
- lora_requests=lora_requests, - lora_mapping=lora_mapping, - multi_modal_kwargs=multi_modal_input, - real_batch_size=real_batch_size, - batch_size_padded=batch_size_padded, - lora_mask=lora_mask, - lora_logits_mask=lora_logits_mask, - ), sampling_metadata + attn_metadata = prefill_attn_metadata if \ + prefill_attn_metadata is not None else decode_attn_metadata + + return self._model_input_cls(input_tokens=input_tokens, + seq_lens=seq_lens, + query_lens=query_lens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_requests=lora_requests, + lora_mapping=lora_mapping, + multi_modal_kwargs=multi_modal_input, + real_batch_size=real_batch_size, + batch_size_padded=batch_size_padded, + lora_mask=lora_mask, + lora_logits_mask=lora_logits_mask), \ + sampling_metadata def _seq_len(self, attn_metadata): if attn_metadata.num_prefills != 0: @@ -1297,19 +1239,10 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") - attention_metadata = subtuple( - metadata, - "TrimmedAttentionMetadata", - [ - "attn_bias", - "seq_lens_tensor", - "block_list", - "block_mapping", - "block_usage", - "slot_mapping", - "is_prompt", - ], - ) + attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ + 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', + 'block_usage', 'slot_mapping', 'is_prompt' + ]) return attention_metadata def create_dummy_seq_group_metadata(self, @@ -1331,23 +1264,19 @@ def create_dummy_seq_group_metadata(self, output_token_ids = [1] * output_len seq_data = SequenceData(prompt_token_ids) seq_data.output_token_ids = output_token_ids - return SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=(output_len == 0), - seq_data={group_id: seq_data}, - sampling_params=sampling_params, - block_tables=block_tables, - lora_request=lora_request, - ) + return SequenceGroupMetadata(request_id=str(group_id), + is_prompt=(output_len == 0), + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + lora_request=lora_request) def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers max_batch_size = self.prompt_bs_bucket_cfg[-1] - max_seq_len = min( - self.prompt_seq_bucket_cfg[-1], - self.max_num_batched_tokens // max_batch_size, - ) + max_seq_len = min(self.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size) self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches) return @@ -1388,7 +1317,7 @@ def warmup_scenario(self, dummy_lora_requests[idx % len(dummy_lora_requests)] for idx in range(max_num_seqs) ] - self.profiler.start("internal", scenario_name) + self.profiler.start('internal', scenario_name) times = 3 if use_graphs or is_profile_run else 1 if self.lora_config and not is_profile_run: lora_mapping = LoRAMapping( @@ -1403,8 +1332,8 @@ def warmup_scenario(self, seq_len, is_prompt, lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None, - ) for i in range(batch_size) + if dummy_lora_requests_per_seq else None) + for i in range(batch_size) ] else: # FIXME: seq_len is actually number of blocks @@ -1416,8 +1345,8 @@ def warmup_scenario(self, b * self.block_size - 1, is_prompt, lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None, - ) for i, b in enumerate(blocks) + if dummy_lora_requests_per_seq 
else None) + for i, b in enumerate(blocks) ] torch.hpu.synchronize() profiler = None @@ -1480,40 +1409,31 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): - self.log_warmup( - "Prompt" if is_prompt else "Decode", - i, - len(buckets), - batch_size, - seq_len, - ) + self.log_warmup('Prompt' if is_prompt else 'Decode', i, + len(buckets), batch_size, seq_len) self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - def warmup_graphs( - self, - strategy, - buckets, - is_prompt, - kv_caches, - available_mem, - starting_mem=0, - total_batch_seq=0.001, - ): + def warmup_graphs(self, + strategy, + buckets, + is_prompt, + kv_caches, + available_mem, + starting_mem=0, + total_batch_seq=0.001): total_mem = starting_mem idx = 0 phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' num_candidates = len(buckets) - ordering: Union[ - Callable[[Any], Tuple[Any, Any]], - Callable[[Any], Tuple[Any, Any, Any]], - ] - if strategy == "min_tokens": + ordering : Union[Callable[[Any], Tuple[Any, Any]], \ + Callable[[Any], Tuple[Any, Any, Any]]] + if strategy == 'min_tokens': ordering = lambda b: (b[0] * b[1], b[1], b[0]) - elif strategy == "max_bs": + elif strategy == 'max_bs': ordering = lambda b: (-b[0], b[1]) else: raise NotImplementedError( - f"Unsupported graph allocation strategy: {strategy}") + f'Unsupported graph allocation strategy: {strategy}') buckets = list(sorted(buckets, key=ordering)) captured_all = True for idx, (batch_size, seq_len) in enumerate(buckets): @@ -1545,34 +1465,32 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): if c[2] == is_prompt) if num_candidates == 0: num_candidates = 1 - msg = (f"{phase} captured:{len(graphed)} " - f"({100 * len(graphed) / num_candidates:.1f}%) " - f"used_mem:{format_bytes(total_mem)} " - f"buckets:{sorted(list(graphed))}") + msg = (f'{phase} captured:{len(graphed)} ' + f'({100 * len(graphed) / num_candidates:.1f}%) ' + f'used_mem:{format_bytes(total_mem)} ' + f'buckets:{sorted(list(graphed))}') logger.info(msg) @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: - if profile := os.environ.get("VLLM_PT_PROFILE", None): - phase, bs, seq_len, graph = profile.split("_") - is_prompt = phase == "prompt" - graphs = graph == "t" + if profile := os.environ.get('VLLM_PT_PROFILE', None): + phase, bs, seq_len, graph = profile.split('_') + is_prompt = phase == 'prompt' + graphs = graph == 't' if graphs: self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, True) raise AssertionError("Finished profiling") - if os.environ.get("VLLM_SKIP_WARMUP", "false").lower() == "true": + if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true': logger.info("Skipping warmup...") return - self.profiler.start("internal", "warmup") + self.profiler.start('internal', 'warmup') max_blocks = kv_caches[0][0].size(0) self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( - self.prompt_bs_bucket_cfg, - self.prompt_seq_bucket_cfg, - self.max_num_batched_tokens, - ) + self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.prompt_buckets[:] = [ bucket for bucket in self.prompt_buckets @@ -1600,11 +1518,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: bucket for bucket in self.decode_buckets if self._is_valid_bucket(bucket) ] - 
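# A small sketch of the two bucket orderings selected by `strategy` in
# warmup_graphs above, applied to a toy list of (batch_size, seq_len) buckets
# (values made up): "min_tokens" captures the cheapest graphs first, while
# "max_bs" captures the largest batch sizes first.
buckets = [(1, 128), (4, 128), (2, 256), (8, 128)]

min_tokens_order = sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))
max_bs_order = sorted(buckets, key=lambda b: (-b[0], b[1]))

print(min_tokens_order)  # [(1, 128), (4, 128), (2, 256), (8, 128)]
print(max_bs_order)      # [(8, 128), (4, 128), (2, 256), (1, 128)]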
logger.info( - "Generated %d decode buckets [bs, total_blocks]: %s", - len(self.decode_buckets), - list(sorted(self.decode_buckets)), - ) + logger.info("Generated %d decode buckets [bs, total_blocks]: %s", + len(self.decode_buckets), + list(sorted(self.decode_buckets))) start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() @@ -1619,24 +1535,24 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.debug("Using PT_COMPILE_ONLY_MODE.") except KeyError: can_use_compile_only_mode = False - logger.warning("Cannot use PT_COMPILE_ONLY_MODE. " - "Warmup time will be negatively impacted. " - "Please update Gaudi Software Suite.") - with compile_only_mode_context() if can_use_compile_only_mode \ - else contextlib.nullcontext(): + logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' + 'Warmup time will be negatively impacted. ' + 'Please update Gaudi Software Suite.') + with compile_only_mode_context( + ) if can_use_compile_only_mode else contextlib.nullcontext(): self.warmup_all_buckets(self.prompt_buckets, True, kv_caches) self.warmup_all_buckets(self.decode_buckets, False, kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): - assert self.mem_margin is not None, ( - "HabanaWorker.determine_num_available_blocks needs " + assert self.mem_margin is not None, \ + ("HabanaWorker.determine_num_available_blocks needs " "to be called before warming up the model.") free_mem = HabanaMemoryProfiler.current_free_device_memory() graph_free_mem = free_mem - self.mem_margin graph_free_mem = align_workers(graph_free_mem, torch.distributed.ReduceOp.MIN) prompt_graph_mem_ratio = float( - os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.5")) + os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5')) prompt_available_memory = (prompt_graph_mem_ratio * graph_free_mem) decode_available_memory = (graph_free_mem - @@ -1649,26 +1565,18 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: f"{format_bytes(decode_available_memory)} for decode " f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})") logger.info(msg) - prompt_strategy = os.environ.get("VLLM_GRAPH_PROMPT_STRATEGY", - "min_tokens") - decode_strategy = os.environ.get("VLLM_GRAPH_DECODE_STRATEGY", - "max_bs") - mem_post_prompt, prompt_batch_seq, prompt_captured_all = ( + prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY', + 'min_tokens') + decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY', + 'max_bs') + mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( - prompt_strategy, - self.prompt_buckets, - True, - kv_caches, - prompt_available_memory, - )) - mem_post_decode, decode_batch_seq, decode_captured_all = ( + prompt_strategy, self.prompt_buckets, True, kv_caches, + prompt_available_memory) + mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( - decode_strategy, - self.decode_buckets, - False, - kv_caches, - decode_available_memory, - )) + decode_strategy, self.decode_buckets, False, kv_caches, + decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1677,29 +1585,21 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, - self.prompt_buckets, - True, + prompt_strategy, self.prompt_buckets, True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, - 
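# A rough sketch (made-up numbers) of the memory-budget split performed above:
# VLLM_GRAPH_PROMPT_RATIO decides how much of the graph-capture budget goes to
# prompt graphs, decode graphs get the remainder, and leftover budget from one
# phase can later be spent on the other.
import os

free_mem = 20 * 2**30                 # free device memory after warmup (assumed)
mem_margin = 2 * 2**30                # margin reserved for the worker (assumed)
graph_free_mem = free_mem - mem_margin

ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.5"))
prompt_available_memory = ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
print(prompt_available_memory / 2**30, decode_available_memory / 2**30)  # 9.0 9.0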
mem_post_prompt, - prompt_batch_seq, - )) + mem_post_prompt, prompt_batch_seq)) # Not all decode buckets were captured, but all prompt buckets # were captured and we have some free graph-allocated space # left. Let's try to use it for capturing more decode buckets. - if (mem_post_decode + mem_post_prompt < graph_free_mem - and not decode_captured_all and prompt_captured_all): + if mem_post_decode + mem_post_prompt < graph_free_mem \ + and not decode_captured_all \ + and prompt_captured_all: mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, - self.decode_buckets, - False, - kv_caches, + decode_strategy, self.decode_buckets, False, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, - mem_post_decode, - decode_batch_seq, - ) + mem_post_decode, decode_batch_seq) self.log_graph_warmup_summary(self.prompt_buckets, True, mem_post_prompt) @@ -1729,13 +1629,12 @@ def mem_margin(self, value): def _maybe_wrap_in_hpu_graph(*args, **kwargs): - return (htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(*args, **kwargs), - disable_tensor_cache=True) - if htorch.utils.internal.is_lazy() else HpuModelAdapter( - *args, **kwargs)) + return htorch.hpu.wrap_in_hpu_graph( + HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True + ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs) -class HabanaProfilerCounterHelper: +class HabanaProfilerCounterHelper(): def __init__(self): self.niter = 0 @@ -1756,15 +1655,8 @@ def capture_seq_group_metadata_stats(self, seq_group_metadata_list): for seq_data in seq_group_metadata.seq_data.values() ] - def get_counter_dict( - self, - cache_config, - duration, - seq_len, - batch_size_padded, - real_batch_size, - is_prompt, - ): + def get_counter_dict(self, cache_config, duration, seq_len, + batch_size_padded, real_batch_size, is_prompt): throughput = batch_size_padded / (duration / 1e6) throughput_effective = real_batch_size / (duration / 1e6) @@ -1780,15 +1672,15 @@ def get_counter_dict( self.average_real_throughput) phase = "prompt" if is_prompt else "decode" counters = { - f"{phase}_bucket_batch_size": batch_size_padded, - f"{phase}_batch_size": real_batch_size, - f"{phase}_bucket_seq_len": seq_len, - f"{phase}_seq_len": real_max_seq_len, - f"{phase}_bucket_gen_throughput": throughput, - f"{phase}_real_gen_throughput": throughput_effective, - f"{phase}_batch_token_utilization": batch_token_utilization, - "average_real_throughput": self.average_real_throughput, - "engine_iteration": self.niter, + f'{phase}_bucket_batch_size': batch_size_padded, + f'{phase}_batch_size': real_batch_size, + f'{phase}_bucket_seq_len': seq_len, + f'{phase}_seq_len': real_max_seq_len, + f'{phase}_bucket_gen_throughput': throughput, + f'{phase}_real_gen_throughput': throughput_effective, + f'{phase}_batch_token_utilization': batch_token_utilization, + 'average_real_throughput': self.average_real_throughput, + 'engine_iteration': self.niter, } self.niter += 1 if is_prompt: @@ -1796,36 +1688,37 @@ def get_counter_dict( duration / 1e6) prompt_real_in_throughput = sum( self.prompt_seq_lens) / (duration / 1e6) - counters[f"{phase}_bucket_in_throughput"] = ( - prompt_bucket_in_throughput) - counters[f"{phase}_real_in_throughput"] = prompt_real_in_throughput + counters[ + f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput + counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput # KV cache might not be created yet (e.g. 
for profiling run) - if (cache_config.num_gpu_blocks is not None - and cache_config.num_gpu_blocks != 0): + if cache_config.num_gpu_blocks is not None and \ + cache_config.num_gpu_blocks != 0: cache_num_blocks_used = [ math.ceil(sl / cache_config.block_size) for sl in self.real_seq_lens ] cache_total_num_blocks_used = sum(cache_num_blocks_used) num_cache_blocks = cache_config.num_gpu_blocks - cache_total_num_free_blocks = (num_cache_blocks - - cache_total_num_blocks_used) - cache_computed_utilization = (cache_total_num_blocks_used / - num_cache_blocks) + cache_total_num_free_blocks = \ + num_cache_blocks - cache_total_num_blocks_used + cache_computed_utilization = \ + cache_total_num_blocks_used / num_cache_blocks max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size) batch_block_utilization = cache_total_num_blocks_used / ( batch_size_padded * max_blocks_per_seq) - counters["cache_num_blocks_used"] = cache_total_num_blocks_used - counters["cache_num_free_blocks"] = cache_total_num_free_blocks - counters["cache_computed_utilization"] = cache_computed_utilization - counters[f"{phase}_batch_block_utilization"] = ( - batch_block_utilization) + counters['cache_num_blocks_used'] = cache_total_num_blocks_used + counters['cache_num_free_blocks'] = cache_total_num_free_blocks + counters['cache_computed_utilization'] = cache_computed_utilization + counters[ + f'{phase}_batch_block_utilization'] = batch_block_utilization if not self.logged_once: - counters["const_cache_num_blocks"] = cache_config.num_gpu_blocks - counters["const_gpu_memory_utilization"] = ( - cache_config.gpu_memory_utilization) - counters["const_block_size"] = cache_config.block_size + counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks + counters[ + 'const_gpu_memory_utilization'] = \ + cache_config.gpu_memory_utilization + counters['const_block_size'] = cache_config.block_size self.logged_once = True return counters @@ -1834,8 +1727,8 @@ def unwrap_model(model): if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): return unwrap_model(model._orig_mod) else: - model = list(vars(model)["_modules"].values())[0] - modules = list(vars(model)["_modules"].values()) + model = list(vars(model)['_modules'].values())[0] + modules = list(vars(model)['_modules'].values()) return modules @@ -1844,7 +1737,6 @@ class HabanaModelRunner( """ GPU model runner with sampling step. """ - _model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = ( ModelInputForHPUWithSamplingMetadata) @@ -1862,7 +1754,7 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForHPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -1873,7 +1765,7 @@ def prepare_model_input( - input_tokens[num_prefill_tokens:] contains decode tokens. If cuda graph is required, this API automatically pads inputs. 
""" - with self.profiler.record_event("internal", "prepare_input_tensors"): + with self.profiler.record_event('internal', 'prepare_input_tensors'): assert seq_group_metadata_list is not None self.profiler_counter_helper.capture_seq_group_metadata_stats( seq_group_metadata_list=seq_group_metadata_list) @@ -1882,16 +1774,13 @@ def prepare_model_input( assert model_input.attn_metadata is not None is_prompt = model_input.attn_metadata.is_prompt - return dataclasses.replace( - model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine, - ) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) def finish_measurements(self): from neural_compressor.torch.quantization import finalize_calibration - finalize_calibration(self.model.model) def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): @@ -1899,13 +1788,9 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): seen = cfg in self.seen_configs self.seen_configs.add(cfg) if not seen and not warmup_mode: - phase = "prompt" if is_prompt else "decode" - logger.warning( - "Configuration: (%s, %s, %s) was not warmed-up!", - phase, - batch_size, - seq_len, - ) + phase = 'prompt' if is_prompt else 'decode' + logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", + phase, batch_size, seq_len) @torch.inference_mode() def execute_model( @@ -1948,7 +1833,7 @@ def execute_model( "kv_caches": kv_caches, "attn_metadata": self.trim_attn_metadata(attn_metadata), "intermediate_tensors": intermediate_tensors, - "lora_mask": model_input.lora_mask, + "lora_mask": model_input.lora_mask } if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) @@ -1963,23 +1848,22 @@ def execute_model( f"seq{seq_len}_" f"graphs{'T' if use_graphs else 'F'}") else: - model_event_name = "model_executable" - with self.profiler.record_event("internal", model_event_name): + model_event_name = 'model_executable' + with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, - selected_token_indices=sampling_metadata. - selected_token_indices, + selected_token_indices=sampling_metadata.selected_token_indices ) if self.lora_config: from vllm.lora.layers import VocabParallelEmbeddingWithLoRA - modules = unwrap_model(self.model.model) for module in modules: if isinstance(module, VocabParallelEmbeddingWithLoRA): for i in range(0, len(module.indices_len)): - module.indices_len[i] = ( - sampling_metadata.selected_token_indices.numel()) + module.indices_len[ + i] = sampling_metadata.selected_token_indices.numel( + ) lora_logits_mask: torch.Tensor = model_input.lora_logits_mask LoraMask.setLoraMask( lora_logits_mask.index_select( @@ -1987,12 +1871,10 @@ def execute_model( # Compute the logits. with self.profiler.record_event( - "internal", - ('compute_logits_' - f'{"prompt" if is_prompt else "decode"}_bs' - f'{batch_size}_' - f'seq{seq_len}'), - ): + 'internal', ('compute_logits_' + f'{"prompt" if is_prompt else "decode"}_bs' + f'{batch_size}_' + f'seq{seq_len}')): sampling_metadata.selected_token_indices = None logits = self.model.compute_logits(hidden_states, sampling_metadata) @@ -2003,12 +1885,10 @@ def execute_model( # Sample the next token. 
with self.profiler.record_event( - "internal", - ('sample_' - f'{"prompt" if is_prompt else "decode"}_' - f'bs{batch_size}_' - f'seq{seq_len}'), - ): + 'internal', ('sample_' + f'{"prompt" if is_prompt else "decode"}_' + f'bs{batch_size}_' + f'seq{seq_len}')): output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, @@ -2026,21 +1906,19 @@ def execute_model( seq_len=seq_len, batch_size_padded=batch_size_padded, real_batch_size=real_batch_size, - is_prompt=is_prompt, - ) + is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) return [output] def shutdown_inc(self): - print("inc shutdown") - if (model_config := getattr(self, "model_config", None)) and getattr( - model_config, "quantization", None) == "inc": - print("inc shutdown start") + print('inc shutdown') + if (model_config := getattr(self, "model_config", None)) and \ + getattr(model_config, "quantization", None) == 'inc': + print('inc shutdown start') from neural_compressor.torch.quantization import ( finalize_calibration) - finalize_calibration(self.model.model) - print("inc shutdown") + print('inc shutdown') def __del__(self): self.shutdown_inc() From 343b533c05a3c1fcf7337063f3e613d161180f26 Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Wed, 18 Sep 2024 15:24:43 +0300 Subject: [PATCH 11/24] Revert "Inc on vLLM - Split qk and v calculations" This reverts commit a6f8dee2b3da8b708b1e9ffff8346292442da8a6. --- vllm/config.py | 3 -- vllm/engine/arg_utils.py | 6 --- vllm/engine/llm_engine.py | 25 ++++++++----- vllm/model_executor/layers/linear.py | 55 ++++++---------------------- vllm/model_executor/models/llama.py | 22 ++--------- 5 files changed, 31 insertions(+), 80 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 243018b5f01c..7aa3977a497e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -431,7 +431,6 @@ class CacheConfig: cache_dtype: Data type for kv cache storage. num_gpu_blocks_override: Number of GPU blocks to use. This overrides the profiled num_gpu_blocks if specified. Does nothing if None. - split_qk_v: Whether to split qk and v calculations. """ def __init__( @@ -444,7 +443,6 @@ def __init__( sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, - split_qk_v: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -454,7 +452,6 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb - self.split_qk_v = split_qk_v self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 983d010b92ca..d6c544750afe 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -62,7 +62,6 @@ class EngineArgs: swap_space: int = 4 # GiB cpu_offload_gb: int = 0 # GiB gpu_memory_utilization: float = 0.90 - split_qk_v: bool = False max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API @@ -359,10 +358,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, help='If specified, ignore GPU profiling result and use this number' 'of GPU blocks. 
Used for testing preemption.') - parser.add_argument('--split-qk-v', - type=bool, - default=EngineArgs.split_qk_v, - help='Whether to separate qk and v calculations.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -739,7 +734,6 @@ def create_engine_config(self, ) -> EngineConfig: swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, num_gpu_blocks_override=self.num_gpu_blocks_override, - split_qk_v=self.split_qk_v, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a68e86ede217..f8b9c48bc958 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -186,7 +186,7 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "enable_prefix_caching=%s, split_qk_v=%s)", + "enable_prefix_caching=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -217,7 +217,6 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, cache_config.enable_prefix_caching, - cache_config.split_qk_v, ) # TODO(woosuk): Print more configs in debug mode. @@ -275,26 +274,32 @@ def __init__( usage_context, extra_kvs={ # Common configuration - "dtype": str(model_config.dtype), + "dtype": + str(model_config.dtype), "tensor_parallel_size": parallel_config.tensor_parallel_size, - "block_size": cache_config.block_size, + "block_size": + cache_config.block_size, "gpu_memory_utilization": cache_config.gpu_memory_utilization, # Quantization - "quantization": model_config.quantization, - "kv_cache_dtype": str(cache_config.cache_dtype), + "quantization": + model_config.quantization, + "kv_cache_dtype": + str(cache_config.cache_dtype), # Feature flags - "enable_lora": bool(lora_config), - "enable_prompt_adapter": bool(prompt_adapter_config), + "enable_lora": + bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), "enable_prefix_caching": cache_config.enable_prefix_caching, - "enforce_eager": model_config.enforce_eager, + "enforce_eager": + model_config.enforce_eager, "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, - "split_qk_v": cache_config.split_qk_v, }) if self.tokenizer: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3fb477a6c88..10c8a95f838d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -525,8 +525,7 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - split_qk_v: bool = False): + prefix: str = ""): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -543,21 +542,14 @@ def __init__(self, else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 - self.split_qk_v = split_qk_v - self.q_size = self.num_heads * self.head_size * tp_size - self.kv_size = self.num_kv_heads * self.head_size * tp_size input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size self.output_sizes = [ - self.q_size, # q_proj - self.kv_size, # k_proj + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + 
self.num_kv_heads * self.head_size * tp_size, # v_proj ] - if split_qk_v: - output_size = (self.num_heads + - self.num_kv_heads) * tp_size * self.head_size - else: - output_size = (self.num_heads + - 2 * self.num_kv_heads) * tp_size * self.head_size - self.output_sizes.append(self.kv_size) # v_proj super().__init__(input_size=input_size, output_size=output_size, @@ -568,16 +560,6 @@ def __init__(self, quant_config=quant_config, prefix=prefix) - if split_qk_v: - self.v_proj = ColumnParallelLinear(input_size=input_size, - output_size=self.kv_size, - bias=bias, - gather_output=False, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, @@ -659,19 +641,13 @@ def weight_loader(self, "q": (0, self.num_heads * self.head_size), "k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size), + "v": + ((self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size), + "total": + ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0) } - if self.split_qk_v: - orig_qkv_offsets["total"] = ( - (self.num_heads + self.num_kv_heads) * self.head_size, - 0) - else: - orig_qkv_offsets["v"] = ( - (self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size) - orig_qkv_offsets["total"] = ( - (self.num_heads + 2 * self.num_kv_heads) * - self.head_size, 0) - shard_size, shard_offset = adjust_bitsandbytes_shard( param, orig_qkv_offsets, loaded_shard_id) @@ -706,13 +682,6 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def forward(self, input_): - output, output_bias = super().forward(input_) - if not self.split_qk_v: - return output, output_bias - v, _ = self.v_proj(input_) - return output, v, output_bias - class RowParallelLinear(LinearBase): """Linear layer with row parallelism. 
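For reference, a minimal plain-PyTorch sketch of the fused-QKV layout that the revert above restores (sizes are made up and the Linear below stands in for the tensor-parallel layer, so this is not the vLLM QKVParallelLinear itself): a single matmul produces q, k and v concatenated on the last dimension, and they are recovered with one split, which is what `qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)` in the llama.py hunk below relies on.

import torch

hidden_size, num_heads, num_kv_heads, head_size = 16, 4, 2, 4
q_size = num_heads * head_size          # 16
kv_size = num_kv_heads * head_size      # 8

# one fused projection instead of separate q/k/v projections
qkv_proj = torch.nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)
x = torch.randn(3, hidden_size)

qkv = qkv_proj(x)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)        # [3, 16] [3, 8] [3, 8]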
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9850dc45068f..8ccefe7be33f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -134,7 +134,6 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.split_qk_v = cache_config.split_qk_v self.qkv_proj = QKVParallelLinear( hidden_size=hidden_size, @@ -144,7 +143,6 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", - split_qk_v=self.split_qk_v, ) self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, @@ -175,13 +173,8 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - if self.split_qk_v: - qk, v, _ = self.qkv_proj(hidden_states) - q, k = qk.split([self.q_size, self.kv_size], dim=-1) - else: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) @@ -396,7 +389,6 @@ def __init__( self.config = config self.lora_config = lora_config - self.split_qk_v = cache_config.split_qk_v self.model = LlamaModel(config, cache_config, @@ -475,13 +467,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - if self.split_qk_v: - stacked_params_mapping.append((".qkv_proj.v_proj", ".v_proj", "v")) - else: - stacked_params_mapping.append((".qkv_proj", ".v_proj", "v")) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -512,10 +501,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader - if self.split_qk_v and shard_id == "v": - weight_loader(param, loaded_weight) - else: - weight_loader(param, loaded_weight, shard_id) + weight_loader(param, loaded_weight, shard_id) break else: From 8657c4c856f92ce430c2a67332b38f06d7773f09 Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Wed, 18 Sep 2024 15:27:59 +0300 Subject: [PATCH 12/24] formnat.sh --- vllm/hpu/ops.py | 10 +++++++--- .../compressed_tensors/compressed_tensors.py | 3 ++- .../schemes/compressed_tensors_w8a8_fp8.py | 3 ++- .../layers/quantization/utils/w8a8_utils.py | 9 ++++++--- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 323a33e9fa2a..5f059e0dd7a9 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -292,6 +292,7 @@ def forward(self, hidden_states, w1, w2, score, topk): return final_hidden_states.view(-1, D) + # fp8 def scaled_fp8_quant( input: torch.Tensor, @@ -300,7 +301,6 @@ def scaled_fp8_quant( scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Quantize input tensor to FP8 and return quantized tensor and scale. 
This function supports both static and dynamic quantization: If you @@ -341,6 +341,10 @@ def scaled_fp8_quant( scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, dtype=torch.float8_e4m3fn)[0] + output = torch.ops.hpu.cast_to_fp8_v2(input, + 1 / scale, + False, + False, + dtype=torch.float8_e4m3fn)[0] - return output, scale \ No newline at end of file + return output, scale diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index badb29af1f5f..7fd5d9ef37ee 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -233,7 +233,8 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True + CompressedTensorsW8A8Fp8.get_min_capability(), + error=False) if torch.cuda.is_available() else True if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 631774994b5c..336bee27e957 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -21,7 +21,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False + self.cutlass_fp8_supported = cutlass_fp8_supported( + ) if torch.cuda.is_available() else False @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 8904c9fa1789..eef201d6ccfa 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -11,6 +11,7 @@ from vllm.hpu.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant + def cutlass_fp8_supported() -> bool: capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] @@ -26,7 +27,7 @@ def per_tensor_dequantize( if current_platform.is_hpu(): dtype = torch.bfloat16 if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: - #dequant on cpu to avoid nan on gaudi2 + #dequant on cpu to avoid nan on gaudi2 tensor = tensor.to('cpu') fake_qweight = tensor.to(dtype).to(device) @@ -87,8 +88,10 @@ def requantize_with_max_scale( logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. 
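# Why the Gaudi2 branch just below stretches max_w_scale (illustrative only;
# the hardware rationale is inferred from the surrounding code): Gaudi2's FP8
# range tops out near the e4m3fnuz maximum of 240 rather than e4m3fn's 448, so
# a scale computed for e4m3fn checkpoints is multiplied by the ratio of the
# two maxima before requantization.
import torch

fn_max = float(torch.finfo(torch.float8_e4m3fn).max)      # 448.0
fnuz_max = float(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
print(fn_max / fnuz_max)                                   # ~1.8667

max_w_scale = 0.01                                         # assumed checkpoint scale
print(max_w_scale * (fn_max / fnuz_max))                   # ~0.0187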
max_w_scale = weight_scale.max() - if current_platform.is_hpu() and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: - max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max) + if current_platform.is_hpu() and htexp._get_device_type( + ) == htexp.synDeviceType.synDeviceGaudi2: + max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max / + torch.finfo(torch.float8_e4m3fnuz).max) # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize # N weight scales for N shards but we only load 1 weight scale From 2e603ea87cad7e27dde29bb90b01b2f2ee90310d Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Wed, 18 Sep 2024 16:07:37 +0300 Subject: [PATCH 13/24] fix imports --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 6 +++--- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- vllm/model_executor/models/llama.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3682362c5a86..9964c8a9b590 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -14,7 +14,7 @@ from vllm.platforms import current_platform if current_platform.is_hpu(): - from vllm.hpu.ops import scaled_fp8_quant + from vllm_hpu_extension.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f3e304ce141c..fa50ff5125b6 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -24,7 +24,7 @@ from vllm.platforms import current_platform from vllm.utils import print_warning_once if current_platform.is_hpu(): - from vllm.hpu.ops import scaled_fp8_quant + from vllm_hpu_extension.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -118,8 +118,8 @@ def __init__(self, quant_config: Fp8Config): if torch.cuda.is_available(): self.cutlass_fp8_supported = cutlass_fp8_supported() - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] self.use_marlin = capability < 89 diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index eef201d6ccfa..ee95c3782dac 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -8,7 +8,7 @@ from vllm.platforms import current_platform if current_platform.is_hpu(): import habana_frameworks.torch.utils.experimental as htexp - from vllm.hpu.ops import scaled_fp8_quant + from vllm_hpu_extension.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8ccefe7be33f..eda04bc1f120 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,7 +54,7 @@ from .interfaces 
import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -from vllm.platforms import current_platform + if current_platform.is_hpu(): import habana_frameworks.torch.core as htcore @@ -321,7 +321,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if is_hpu: + if current_platform.is_hpu(): import habana_frameworks.torch as htorch htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): @@ -333,7 +333,7 @@ def forward( attn_metadata, residual, ) - if is_hpu: + if current_platform.is_hpu(): htorch.core.mark_step() if not get_pp_group().is_last_rank: From a7a036a05e8c6c91c50e543d8343b749e918600b Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Wed, 18 Sep 2024 16:24:31 +0300 Subject: [PATCH 14/24] isort fix --- vllm/model_executor/layers/quantization/fp8.py | 1 + vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fa50ff5125b6..e363d11c85b1 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once + if current_platform.is_hpu(): from vllm_hpu_extension.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index ee95c3782dac..8503a83c8bd5 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,6 +6,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform + if current_platform.is_hpu(): import habana_frameworks.torch.utils.experimental as htexp from vllm_hpu_extension.ops import scaled_fp8_quant From 2b4a196724b9c444ddc3c05477c761328d8cc0e7 Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Thu, 19 Sep 2024 12:52:29 +0300 Subject: [PATCH 15/24] update vllm-hpu-extension commit hash --- requirements-hpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index d451200aa114..485a9ed65555 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -6,4 +6,4 @@ ray == 2.32.0 triton pandas tabulate -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@30ee2d1 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@9f120e9 From 454acc9a340b532bd922b5bc9b0a094beee8471e Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 23 Sep 2024 12:58:33 +0300 Subject: [PATCH 16/24] pr fix --- .../quantization/compressed_tensors/compressed_tensors.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 7fd5d9ef37ee..8d6273a3f986 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -232,9 +232,10 @@ def _get_scheme_from_parts( # Detect If Activation Quantization. 
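# A sketch of the platform-gated override pattern repeated in the fp8.py,
# fused_moe.py and w8a8_utils.py hunks above (module names shortened; this is
# not the real vLLM import graph): the CUDA custom op is swapped for the HPU
# extension's implementation only when an HPU platform is detected.
import types

ops = types.SimpleNamespace(scaled_fp8_quant=lambda x, scale=None: ("cuda", x))

def is_hpu() -> bool:
    return True   # assume an HPU build for this example

if is_hpu():
    def hpu_scaled_fp8_quant(x, scale=None):
        return ("hpu", x)
    ops.scaled_fp8_quant = hpu_scaled_fp8_quant

print(ops.scaled_fp8_quant(3.0)[0])   # -> "hpu"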
if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): - is_fp8_w8a8_supported = self._check_scheme_supported( + is_fp8_w8a8_supported = True if current_platform.is_hpu() \ + else self._check_scheme_supported( CompressedTensorsW8A8Fp8.get_min_capability(), - error=False) if torch.cuda.is_available() else True + error=False) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, From e92abd6c4edeca7337142ac4bfc83135f65ac9ac Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:47:02 +0300 Subject: [PATCH 17/24] Update fp8.py --- vllm/model_executor/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3c0e228432e8..9fe8acc5b6f9 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -126,7 +126,7 @@ def __init__(self, quant_config: Fp8Config): # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm if is_hip(): self.use_marlin = False From f1508514def37b95e643f890dfe0d2d19851e6ea Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:47:44 +0300 Subject: [PATCH 18/24] Update fused_moe.py --- vllm/model_executor/layers/fused_moe/fused_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 089a70095611..cf17f1e240e4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -17,7 +17,6 @@ from vllm_hpu_extension.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant - logger = init_logger(__name__) From 3e8762e0fc92bc5e3c29b3181e4eb08ca5664fb6 Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:08:50 +0300 Subject: [PATCH 19/24] Update compressed_tensors.py --- .../quantization/compressed_tensors/compressed_tensors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 6d5207f7feaf..252ad864ced3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -243,8 +243,8 @@ def _get_scheme_from_parts( # TODO @dsikka: clean-up conditions if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): - is_fp8_w8a8_supported = True if current_platform.is_hpu() \ - else self._check_scheme_supported( + is_fp8_w8a8_supported = current_platform.is_hpu() or \ + self._check_scheme_supported( CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if is_fp8_w8a8_supported: @@ -316,7 +316,7 @@ def get_scheme( # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) - if torch.cuda.is_available(): + if not current_platform.is_hpu(): self._check_scheme_supported(scheme.get_min_capability()) return scheme From 57268012e72f7b95b35b39dca92febcfc9f1ccbd Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:10:43 +0300 Subject: [PATCH 20/24] Update compressed_tensors_w8a8_fp8.py --- .../compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index c8778f995b30..c4722554e9fc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -23,8 +23,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = cutlass_fp8_supported( - ) if torch.cuda.is_available() else False + self.cutlass_fp8_supported = not current_platform.is_hpu() and \ + cutlass_fp8_supported() @classmethod def get_min_capability(cls) -> int: From 426e8e15d3fad205b77e4db50ecd62eda3ab8057 Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:11:43 +0300 Subject: [PATCH 21/24] Update llama.py --- vllm/model_executor/models/llama.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 65bb84ed0905..eba607b93d63 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -55,8 +55,7 @@ from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -if current_platform.is_hpu(): - import habana_frameworks.torch.core as htcore +is_hpu = current_platform.is_hpu() class LlamaMLP(nn.Module): @@ -328,7 +327,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - if current_platform.is_hpu(): + if is_hpu: import habana_frameworks.torch as htorch htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): @@ -340,7 +339,7 @@ def forward( attn_metadata, residual, ) - if current_platform.is_hpu(): + if is_hpu: htorch.core.mark_step() if not get_pp_group().is_last_rank: @@ -556,7 +555,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) if current_platform.is_hpu(): torch.hpu.synchronize() - htcore.mark_step() # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). 
Thus, handled exceptions should From 4cf34f400e9b7038c7f904bd850793c2423b90a8 Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:17:28 +0300 Subject: [PATCH 22/24] Update compressed_tensors_w8a8_fp8.py --- .../compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index c4722554e9fc..29f3228c0dc5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -13,6 +13,7 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) +from vllm.platforms import current_platform from vllm.utils import is_hip __all__ = ["CompressedTensorsW8A8Fp8"] From f58d4c16c86c46d907b97de843da37f0f60f1547 Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Wed, 25 Sep 2024 13:17:39 +0300 Subject: [PATCH 23/24] Update vllm/model_executor/layers/quantization/fp8.py Co-authored-by: Konrad Zawora --- vllm/model_executor/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9fe8acc5b6f9..71aedeb01f3b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -120,7 +120,7 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config - if torch.cuda.is_available(): + if current_platform.is_cuda_alike() self.cutlass_fp8_supported = cutlass_fp8_supported() # For GPUs that lack FP8 hardware support, we can leverage the From db9affeae2bb6cad86bd77cd532e29319fbfa882 Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Wed, 25 Sep 2024 13:20:07 +0300 Subject: [PATCH 24/24] Update fp8.py --- vllm/model_executor/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 71aedeb01f3b..88915942220c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -120,7 +120,7 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config - if current_platform.is_cuda_alike() + if current_platform.is_cuda_alike(): self.cutlass_fp8_supported = cutlass_fp8_supported() # For GPUs that lack FP8 hardware support, we can leverage the