
Support Mixtral quantization using INC #188

Closed · wants to merge 9 commits
84 changes: 53 additions & 31 deletions vllm/hpu/ops.py
@@ -88,37 +88,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
return F.silu(x[..., :d]) * x[..., d:]


def static_fused_moe(hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
num_experts = w1.shape[0]
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights = torch.zeros((B, num_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

htorch.core.mark_step()

for expert_idx in range(num_experts):
w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
w_output = silu_and_mul(w_output)
w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
final_hidden_states += w_output * padded_weights[expert_idx]
htorch.core.mark_step()

return final_hidden_states.view(-1, D)


@hpu_utils.with_mark_steps
def prompt_attention(
query: torch.Tensor,
@@ -148,3 +117,56 @@ def prompt_attention(
attn_weights = attn_weights.flatten(1, 2)
attn_weights = attn_weights.transpose(1, 2)
return attn_weights


class MoeMatmul(torch.nn.Module):
def __init__(self):
super().__init__()

def set_weight(self, w):
self.weight = w

def calc(self, state, expert_id, w):
self.weight = w[expert_id].transpose(0, 1)
return self.forward(state)

def forward(self, state):
return torch.matmul(state, self.weight)


class StaticFusedMOE(torch.nn.Module):
def __init__(self, num_total_experts):
super().__init__()
self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
self.num_total_experts = num_total_experts


def forward(self, hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights = torch.zeros((B, self.num_total_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, B, self.num_total_experts)
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
htorch.core.mark_step()

for expert_idx in range(self.num_total_experts):
padded_weight = padded_weights[expert_idx]
current_state_static = hidden_states.reshape(-1, D)
w_output = self.w13_list[expert_idx].calc(current_state_static, expert_idx, w1)
w_output = silu_and_mul(w_output)
w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
current_hidden_states_static = w_output * padded_weight
final_hidden_states += current_hidden_states_static
htorch.core.mark_step()

return final_hidden_states.view(-1, D)
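
For context, the routing math implemented by StaticFusedMOE can be sketched in plain PyTorch as follows. This is a minimal sketch with toy shapes and an invented helper name, not part of the PR; the real module additionally calls htorch.core.mark_step() between experts and uses silu_and_mul from vllm.hpu.ops.

import torch
import torch.nn.functional as F

def moe_routing_sketch(hidden_states, w1, w2, score, topk):
    # hidden_states: [B, D]; w1: [E, 2*I, D]; w2: [E, D, I]; score: [B, E]
    B, D = hidden_states.shape
    num_experts = w1.shape[0]
    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)
    # Scatter the sparse top-k weights into a dense [E, 1, B, 1] tensor so that
    # every expert runs with static shapes; unselected experts get weight 0.
    padded_weights = torch.zeros((B, num_experts), dtype=hidden_states.dtype)
    padded_weights.scatter_(-1, selected_experts, routing_weights)
    padded_weights = padded_weights.reshape(-1, B, num_experts).permute(2, 0, 1).unsqueeze(-1)
    final_hidden_states = torch.zeros((1, B, D), dtype=hidden_states.dtype)
    for e in range(num_experts):
        gate_up = hidden_states @ w1[e].transpose(0, 1)            # [B, 2*I]
        half = gate_up.shape[-1] // 2
        act = F.silu(gate_up[..., :half]) * gate_up[..., half:]    # silu_and_mul
        final_hidden_states += (act @ w2[e].transpose(0, 1)) * padded_weights[e]
    return final_hidden_states.view(-1, D)

# Tiny smoke test with random weights (not real Mixtral shapes).
x = torch.randn(4, 8)
out = moe_routing_sketch(x, torch.randn(3, 32, 8), torch.randn(3, 8, 16), torch.randn(4, 3), topk=2)
print(out.shape)  # torch.Size([4, 8])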
20 changes: 15 additions & 5 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -13,8 +13,6 @@
from vllm.model_executor.utils import set_weight_attrs


Review comment: @kzawora-intel could you check if these changes are upstreamable?

from vllm.utils import is_hpu

if is_hpu():
from vllm.hpu.ops import static_fused_moe

logger = init_logger(__name__)

@@ -78,7 +76,7 @@ def apply(
) -> torch.Tensor:
return self.forward(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k, renormalize,
use_grouped_topk, num_expert_group, topk_group)
use_grouped_topk, num_expert_group, topk_group, layer)

def forward_cuda(
self,
@@ -91,6 +89,7 @@ def forward_cuda(
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
return fused_moe(x,
@@ -107,12 +106,12 @@
def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
router_logits: torch.Tensor, top_k: int, renormalize: bool,
use_grouped_topk: bool, num_expert_group: Optional[int],
topk_group: Optional[int]):
topk_group: Optional[int], layer: Optional[torch.nn.Module],):
assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
assert num_expert_group is None, ('num_expert_group is '
'not supported on HPU')
assert topk_group is None, 'topk_group is not supported on HPU'
return static_fused_moe(x, w1, w2, router_logits, top_k)
return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k)

def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
@@ -129,6 +128,7 @@ def forward_tpu(
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
@@ -191,6 +191,9 @@ def __init__(
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
if is_hpu():
from vllm.hpu.ops import StaticFusedMOE
self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -207,6 +210,7 @@ def __init__(
params_dtype=params_dtype,
weight_loader=self.weight_loader)


def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int):
@@ -245,13 +249,19 @@ def weight_loader(self, param: torch.nn.Parameter,
if shard_id == 0:
param_data[expert_id,
0:shard_size, :] = loaded_weight[shard, :]
if is_hpu():
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size:2 *
shard_size, :] = loaded_weight[shard, :]
if is_hpu():
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
if is_hpu():
self.hpu_static_fused_moe.w2_list[expert_id].set_weight(param_data[expert_id])
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
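
The net effect of these weight_loader hooks is that each expert's fused gate/up projection and its down projection live inside named MoeMatmul submodules rather than only in one fused parameter tensor, which is presumably what lets INC address each expert matmul individually. A minimal sketch of the wiring, with toy shapes and no real checkpoint (the import only works on an HPU build of this fork):

import torch
from vllm.hpu.ops import StaticFusedMOE  # HPU-only module added by this PR

num_experts, hidden, intermediate = 4, 16, 32
moe = StaticFusedMOE(num_experts)
w13 = torch.randn(num_experts, 2 * intermediate, hidden)  # fused gate/up projections
w2 = torch.randn(num_experts, hidden, intermediate)       # down projections
for e in range(num_experts):
    moe.w13_list[e].set_weight(w13[e])
    moe.w2_list[e].set_weight(w2[e])

# Every expert matmul is now addressable as a named submodule, e.g. "w13_list.3".
print([name for name, _ in moe.named_modules() if name])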
116 changes: 116 additions & 0 deletions vllm/model_executor/layers/quantization/inc.py
@@ -0,0 +1,116 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class INCConfig(QuantizationConfig):
"""Config class for FP8."""

def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change.")
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme

@classmethod
def get_name(cls) -> str:
return "inc"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "INCConfig":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["INCLinearMethod"]:
if isinstance(layer, LinearBase):
return INCLinearMethod(self)
elif isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod()
return None

def get_scaled_act_names(self) -> List[str]:
return []

def get_min_capability(self) -> int:
# Minimum CUDA compute capability (75 = Turing); kept for interface
# compatibility and not directly meaningful on HPU.
return 75

@staticmethod
def get_config_filenames() -> List[str]:
return []

class INCLinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

Args:
quant_config: The quantization config.
"""

def __init__(self, quant_config: INCConfig, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
if self.separate_bias_add:
if bias is not None:
return F.linear(x, weight) + bias
return F.linear(x, weight)
return F.linear(x, weight, bias)
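
A short sketch of how the new config class behaves, using an illustrative config dict (real INC/FP8 checkpoints carry their own metadata):

from vllm.model_executor.layers.quantization.inc import INCConfig

cfg = INCConfig.from_config({
    "quant_method": "fp8",          # anything containing "fp8" marks the checkpoint as FP8-serialized
    "activation_scheme": "static",  # must be one of ACTIVATION_SCHEMES
})
print(cfg.get_name())                        # "inc"
print(cfg.is_checkpoint_fp8_serialized)      # True
print(INCConfig.get_supported_act_dtypes())  # [torch.bfloat16]

Note that INCLinearMethod itself is a pass-through (plain F.linear); the actual FP8 conversion is presumably applied by INC outside this file, while FusedMoE layers fall back to UnquantizedFusedMoEMethod and the HPU StaticFusedMOE path above.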
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -24,7 +24,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and model_config.quantization != "fp8"
and model_config.quantization not in ["fp8", "inc"]
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

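
Assuming "inc" is registered as a quantization method in this fork (the registration itself is not shown in this diff), end-to-end usage would look roughly like the sketch below; the quantization value is therefore an assumption, while the rest is standard vLLM API.

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1",
          quantization="inc",   # assumed method name, matching INCConfig.get_name()
          dtype="bfloat16")
outputs = llm.generate(["What does INC quantize?"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)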