diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index c8f00c1cbd59..a3dc2922e5b5 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -88,37 +88,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     return F.silu(x[..., :d]) * x[..., d:]
 
 
-def static_fused_moe(hidden_states, w1, w2, score, topk):
-    B, D = hidden_states.shape
-    num_experts = w1.shape[0]
-    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
-    routing_weights, selected_experts = torch.topk(routing_weights,
-                                                   topk,
-                                                   dim=-1)
-    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-    routing_weights = routing_weights.to(hidden_states.dtype)
-    final_hidden_states = torch.zeros((1, B, D),
-                                      dtype=hidden_states.dtype,
-                                      device=hidden_states.device)
-    padded_weights = torch.zeros((B, num_experts),
-                                 dtype=hidden_states.dtype,
-                                 device=hidden_states.device)
-    padded_weights.scatter_(-1, selected_experts, routing_weights)
-    padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
-    padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
-
-    htorch.core.mark_step()
-
-    for expert_idx in range(num_experts):
-        w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
-        w_output = silu_and_mul(w_output)
-        w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
-        final_hidden_states += w_output * padded_weights[expert_idx]
-        htorch.core.mark_step()
-
-    return final_hidden_states.view(-1, D)
-
-
 @hpu_utils.with_mark_steps
 def prompt_attention(
     query: torch.Tensor,
@@ -148,3 +117,56 @@ def prompt_attention(
         attn_weights = attn_weights.flatten(1, 2)
     attn_weights = attn_weights.transpose(1, 2)
     return attn_weights
+
+
+class MoeMatmul(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def set_weight(self, w):
+        self.weight = w
+
+    def calc(self, state, expert_id, w):
+        self.weight = w[expert_id].transpose(0, 1)
+        return self.forward(state)
+
+    def forward(self, state):
+        return torch.matmul(state, self.weight)
+
+
+class StaticFusedMOE(torch.nn.Module):
+    def __init__(self, num_total_experts):
+        super().__init__()
+        self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
+        self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
+        self.num_total_experts = num_total_experts
+
+
+    def forward(self, hidden_states, w1, w2, score, topk):
+        B, D = hidden_states.shape
+        routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
+        routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        final_hidden_states = torch.zeros((1, B, D),
+                                          dtype=hidden_states.dtype,
+                                          device=hidden_states.device)
+        padded_weights = torch.zeros((B, self.num_total_experts),
+                                     dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+        padded_weights.scatter_(-1, selected_experts, routing_weights)
+        padded_weights = padded_weights.reshape(-1, B, self.num_total_experts)
+        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
+        htorch.core.mark_step()
+
+        for expert_idx in range(self.num_total_experts):
+            padded_weight = padded_weights[expert_idx]
+            current_state_static = hidden_states.reshape(-1, D)
+            w_output = self.w13_list[expert_idx].calc(current_state_static, expert_idx, w1)
+            w_output = silu_and_mul(w_output)
+            w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
+            current_hidden_states_static = w_output * padded_weight
+            final_hidden_states += current_hidden_states_static
+            htorch.core.mark_step()
+
+        return final_hidden_states.view(-1, D)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b49bf40d4746..5ded0c2b5ea6 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -13,8 +13,6 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import is_hpu
 
-if is_hpu():
-    from vllm.hpu.ops import static_fused_moe
 
 logger = init_logger(__name__)
 
@@ -78,7 +76,7 @@ def apply(
     ) -> torch.Tensor:
         return self.forward(x, layer.w13_weight, layer.w2_weight,
                             router_logits, top_k, renormalize,
-                            use_grouped_topk, num_expert_group, topk_group)
+                            use_grouped_topk, num_expert_group, topk_group, layer)
 
     def forward_cuda(
         self,
@@ -91,6 +89,7 @@ def forward_cuda(
         use_grouped_topk: bool,
         num_expert_group: Optional[int],
         topk_group: Optional[int],
+        layer: Optional[torch.nn.Module],
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
         return fused_moe(x,
@@ -107,12 +106,12 @@ def forward_cuda(
     def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                     router_logits: torch.Tensor, top_k: int, renormalize: bool,
                     use_grouped_topk: bool, num_expert_group: Optional[int],
-                    topk_group: Optional[int]):
+                    topk_group: Optional[int], layer: Optional[torch.nn.Module],):
         assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
         assert num_expert_group is None, ('num_expert_group is '
                                           'not supported on HPU')
         assert topk_group is None, 'topk_group is not supported on HPU'
-        return static_fused_moe(x, w1, w2, router_logits, top_k)
+        return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k)
 
     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
@@ -129,6 +128,7 @@ def forward_tpu(
         use_grouped_topk: bool,
         num_expert_group: Optional[int],
         topk_group: Optional[int],
+        layer: Optional[torch.nn.Module],
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
@@ -191,6 +191,9 @@ def __init__(
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
+        if is_hpu():
+            from vllm.hpu.ops import StaticFusedMOE
+            self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -207,6 +210,7 @@ def __init__(
             params_dtype=params_dtype,
             weight_loader=self.weight_loader)
 
+
     def weight_loader(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, weight_name: str,
                       shard_id: int, expert_id: int):
@@ -245,13 +249,19 @@ def weight_loader(self, param: torch.nn.Parameter,
         if shard_id == 0:
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+            if is_hpu():
+                self.hpu_static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
         # w3, up_proj case: Load into second shard of w13.
         elif shard_id == 2:
             param_data[expert_id,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+            if is_hpu():
+                self.hpu_static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
         # w2, down_proj case: Load into only shard of w2.
         elif shard_id == 1:
             param_data[expert_id, :, :] = loaded_weight[:, shard]
+            if is_hpu():
+                self.hpu_static_fused_moe.w2_list[expert_id].set_weight(param_data[expert_id])
         else:
             raise ValueError(
                 f"Shard id must be in [0,1,2] but got {shard_id}")
 
diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py
new file mode 100644
index 000000000000..d2cca285670d
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/inc.py
@@ -0,0 +1,116 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+
+from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.utils import set_weight_attrs
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = init_logger(__name__)
+
+
+class INCConfig(QuantizationConfig):
+    """Config class for FP8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        activation_scheme: str = "dynamic",
+    ) -> None:
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning("Detected fp8 checkpoint. Please note that the "
+                           "format is experimental and subject to change.")
+        if activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(
+                f"Unsupported activation scheme {activation_scheme}")
+        self.activation_scheme = activation_scheme
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "inc"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "INCConfig":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+                   activation_scheme=activation_scheme)
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["INCLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return INCLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return UnquantizedFusedMoEMethod()
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_min_capability(self) -> int:
+        # The AWQ kernel only supports Turing or newer GPUs.
+        return 75
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return []
+
+class INCLinearMethod(LinearMethodBase):
+    """Linear method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+    Limitations:
+    1. Only support per-tensor quantization due to torch._scaled_mm support.
+    2. Only support float8_e4m3fn data type due to the limitation of
+       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+    Args:
+        quant_config: The quantization config.
+ """ + + def __init__(self, quant_config: INCConfig, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) \ No newline at end of file diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f7e0f56c1a46..a8b0a7b07ed8 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,7 +24,7 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. if (model_config.quantization is not None - and model_config.quantization != "fp8" + and model_config.quantization not in ["fp8", "inc"] and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"]