Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Mixtral quantization using INC #188

Closed
wants to merge 9 commits into from
3 changes: 2 additions & 1 deletion vllm/executor/habana_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def check_health(self) -> None:
return

def shutdown(self) -> None:
self.driver_worker.shutdown_inc()
if hasattr(self, "driver_worker") and self.driver_worker is not None:
self.driver_worker.shutdown_inc()

def __del__(self):
self.shutdown()
Expand Down
88 changes: 57 additions & 31 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,36 +102,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
return F.silu(x[..., :d]) * x[..., d:]


def static_fused_moe(hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
num_experts = w1.shape[0]
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights = torch.zeros((B, num_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

htorch.core.mark_step()

for expert_idx in range(num_experts):
w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
w_output = silu_and_mul(w_output)
w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
final_hidden_states += w_output * padded_weights[expert_idx]

return final_hidden_states.view(-1, D)


#TODO: remove after fusedsdpa fix for query_head != kv_head
def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Expand Down Expand Up @@ -264,4 +234,60 @@ def dispatch_bgmv_embedding(
x = x.unsqueeze(1)
out = x @ wa
out = out.squeeze(1)
y += out * scale
y += out * scale


Tiefen-boop marked this conversation as resolved.
Show resolved Hide resolved
class MoeMatmul(torch.nn.Module):

def __init__(self):
super().__init__()

def set_weight(self, w):
self.weight = w

def calc(self, state, expert_id, w):
self.weight = w[expert_id].transpose(0, 1)
return self.forward(state)

def forward(self, state):
return torch.matmul(state, self.weight)


class StaticFusedMOE(torch.nn.Module):

def __init__(self, num_total_experts):
super().__init__()
self.w13_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)])
self.w2_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)])
self.num_total_experts = num_total_experts

def forward(self, hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights = torch.zeros((B, self.num_total_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, B, self.num_total_experts)
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
htorch.core.mark_step()

for expert_idx in range(self.num_total_experts):
padded_weight = padded_weights[expert_idx]
w_output = self.w13_list[expert_idx].calc(hidden_states,
expert_idx, w1)
w_output = silu_and_mul(w_output)
w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
final_hidden_states += w_output * padded_weight

return final_hidden_states.view(-1, D)
26 changes: 20 additions & 6 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
from vllm.model_executor.utils import set_weight_attrs

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kzawora-intel could you check if these changes are upstreamable?

from vllm.utils import is_hpu

if is_hpu():
from vllm.hpu.ops import static_fused_moe

logger = init_logger(__name__)


Expand Down Expand Up @@ -78,7 +75,8 @@ def apply(
) -> torch.Tensor:
return self.forward(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k, renormalize,
use_grouped_topk, num_expert_group, topk_group)
use_grouped_topk, num_expert_group, topk_group,
layer)

def forward_cuda(
self,
Expand All @@ -91,6 +89,7 @@ def forward_cuda(
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
return fused_moe(x,
Expand All @@ -107,12 +106,14 @@ def forward_cuda(
def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
router_logits: torch.Tensor, top_k: int, renormalize: bool,
use_grouped_topk: bool, num_expert_group: Optional[int],
topk_group: Optional[int]):
topk_group: Optional[int],
layer: Optional[torch.nn.Module]):
assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
assert num_expert_group is None, ('num_expert_group is '
'not supported on HPU')
assert topk_group is None, 'topk_group is not supported on HPU'
return static_fused_moe(x, w1, w2, router_logits, top_k)
assert layer is not None, 'layer has to be provided on HP'
return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k)

def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
Expand All @@ -129,6 +130,7 @@ def forward_tpu(
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
Expand Down Expand Up @@ -191,6 +193,9 @@ def __init__(
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
if is_hpu():
from vllm.hpu.ops import StaticFusedMOE
self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
Expand Down Expand Up @@ -245,13 +250,22 @@ def weight_loader(self, param: torch.nn.Parameter,
if shard_id == 0:
param_data[expert_id,
0:shard_size, :] = loaded_weight[shard, :]
if is_hpu():
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
param_data[expert_id])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size:2 *
shard_size, :] = loaded_weight[shard, :]
if is_hpu():
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
param_data[expert_id])
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
if is_hpu():
self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
param_data[expert_id])
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
Expand Down
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/quantization/inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
Expand Down Expand Up @@ -52,6 +54,8 @@ def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["INCLinearMethod"]:
if isinstance(layer, LinearBase):
return INCLinearMethod(self)
elif isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod()
return None

def get_scaled_act_names(self) -> List[str]:
Expand All @@ -78,7 +82,7 @@ class INCLinearMethod(LinearMethodBase):
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

Args:
quant_config: The quantization config.
"""
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and model_config.quantization != "fp8"
and model_config.quantization not in ["fp8", "inc"]
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

Expand Down
4 changes: 1 addition & 3 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import habana_frameworks.torch as htorch
import torch
from neural_compressor.torch.quantization import finalize_calibration

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This brakes taking fp8 measurements with TP>1 - please revert changes made to this file


from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
Expand Down Expand Up @@ -1559,7 +1560,6 @@ def prepare_model_input(
virtual_engine=virtual_engine)

def finish_measurements(self):
from neural_compressor.torch.quantization import finalize_calibration
finalize_calibration(self.model.model)

def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
Expand Down Expand Up @@ -1692,8 +1692,6 @@ def shutdown_inc(self):
if (model_config := getattr(self, "model_config", None)) and \
getattr(model_config, "quantization", None) == 'inc':
print('inc shutdown start')
from neural_compressor.torch.quantization import (
finalize_calibration)
finalize_calibration(self.model.model)
print('inc shutdown')

Expand Down
Loading