Support Mixtral quantization using INC
dudilester committed Aug 14, 2024
1 parent 6f047d8 commit 3e503d0
Showing 4 changed files with 185 additions and 37 deletions.
84 changes: 53 additions & 31 deletions vllm/hpu/ops.py
@@ -88,37 +88,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
return F.silu(x[..., :d]) * x[..., d:]


def static_fused_moe(hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
num_experts = w1.shape[0]
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights = torch.zeros((B, num_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

htorch.core.mark_step()

for expert_idx in range(num_experts):
w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
w_output = silu_and_mul(w_output)
w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
final_hidden_states += w_output * padded_weights[expert_idx]
htorch.core.mark_step()

return final_hidden_states.view(-1, D)


@hpu_utils.with_mark_steps
def prompt_attention(
query: torch.Tensor,
@@ -148,3 +117,56 @@ def prompt_attention(
attn_weights = attn_weights.flatten(1, 2)
attn_weights = attn_weights.transpose(1, 2)
return attn_weights


class MoeMatmul(torch.nn.Module):
def __init__(self):
super().__init__()

def set_weight(self, w):
self.weight = w

def calc(self, state, expert_id, w):
self.weight = w[expert_id].transpose(0, 1)
return self.forward(state)

def forward(self, state):
return torch.matmul(state, self.weight)


class StaticFusedMOE(torch.nn.Module):
def __init__(self, num_total_experts):
super().__init__()
self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
self.num_total_experts = num_total_experts


def forward(self, hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights = torch.zeros((B, self.num_total_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, B, self.num_total_experts)
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
htorch.core.mark_step()

for expert_idx in range(self.num_total_experts):
padded_weight = padded_weights[expert_idx]
current_state_static = hidden_states.reshape(-1, D)
w_output = self.w13_list[expert_idx].calc(current_state_static, expert_idx, w1)
w_output = silu_and_mul(w_output)
w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
current_hidden_states_static = w_output * padded_weight
final_hidden_states += current_hidden_states_static
htorch.core.mark_step()

return final_hidden_states.view(-1, D)
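
Note: the dense-routing trick in `StaticFusedMOE.forward` above (and in the removed `static_fused_moe` function) is what keeps shapes static on HPU: the top-k routing weights are scattered into a dense `(B, num_experts)` matrix, every expert is then applied to the whole batch, and non-selected experts simply contribute zero. A minimal CPU-only sketch of that padding step, with toy sizes and no HPU or INC dependencies:

```python
import torch
import torch.nn.functional as F

# Toy sizes: 4 tokens, hidden dim 8, 3 experts, top-2 routing.
B, D, E, topk = 4, 8, 3, 2
score = torch.randn(B, E)  # router logits

routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

# Scatter the top-k weights into a dense (B, E) matrix: every expert sees every
# token, and experts not selected for a token get weight 0.
padded_weights = torch.zeros((B, E), dtype=routing_weights.dtype)
padded_weights.scatter_(-1, selected_experts, routing_weights)

# (E, 1, B, 1): one broadcastable weight slice per expert, matching the loop above.
padded_weights = padded_weights.reshape(-1, B, E).permute(2, 0, 1).unsqueeze(-1)
print(padded_weights.shape)  # torch.Size([3, 1, 4, 1])
```

Wrapping each expert matmul in a `MoeMatmul` submodule instead of calling `torch.matmul` directly (as the removed function did) presumably lets INC locate and quantize the per-expert weights; that motivation is inferred from the commit title, not stated in the diff.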
20 changes: 15 additions & 5 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -13,8 +13,6 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hpu

if is_hpu():
from vllm.hpu.ops import static_fused_moe

logger = init_logger(__name__)

@@ -78,7 +76,7 @@ def apply(
) -> torch.Tensor:
return self.forward(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k, renormalize,
use_grouped_topk, num_expert_group, topk_group)
use_grouped_topk, num_expert_group, topk_group, layer)

def forward_cuda(
self,
@@ -91,6 +89,7 @@ def forward_cuda(
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
return fused_moe(x,
@@ -107,12 +106,12 @@ def forward_cuda(
def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
router_logits: torch.Tensor, top_k: int, renormalize: bool,
use_grouped_topk: bool, num_expert_group: Optional[int],
topk_group: Optional[int]):
topk_group: Optional[int], layer: Optional[torch.nn.Module],):
assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
assert num_expert_group is None, ('num_expert_group is '
'not supported on HPU')
assert topk_group is None, 'topk_group is not supported on HPU'
return static_fused_moe(x, w1, w2, router_logits, top_k)
return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k)

def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
@@ -129,6 +128,7 @@ def forward_tpu(
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
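
These signature changes thread the `FusedMoE` layer itself through `apply()` and the per-backend `forward_*` methods, so `forward_hpu` can call the layer's own `StaticFusedMOE` submodule rather than the removed free function. Keeping the expert matmuls as registered submodules is what makes them discoverable by module-level quantization tooling such as INC (an inference from the commit's purpose, not something the diff states). A framework-free sketch of that discoverability, using stand-in class names and stock PyTorch only:

```python
import torch


class TinyMatmul(torch.nn.Module):
    """Stand-in for MoeMatmul: holds one expert's weight as module state."""

    def set_weight(self, w):
        self.weight = w

    def forward(self, state):
        return torch.matmul(state, self.weight)


class TinyMoE(torch.nn.Module):
    """Stand-in for StaticFusedMOE: one matmul module per expert and projection."""

    def __init__(self, num_experts):
        super().__init__()
        self.w13_list = torch.nn.ModuleList(TinyMatmul() for _ in range(num_experts))
        self.w2_list = torch.nn.ModuleList(TinyMatmul() for _ in range(num_experts))


moe = TinyMoE(2)
print([name for name, _ in moe.named_modules() if name])
# ['w13_list', 'w13_list.0', 'w13_list.1', 'w2_list', 'w2_list.0', 'w2_list.1']
```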
@@ -191,6 +191,9 @@ def __init__(
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
if is_hpu():
from vllm.hpu.ops import StaticFusedMOE
self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -207,6 +210,7 @@ def __init__(
params_dtype=params_dtype,
weight_loader=self.weight_loader)


def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int):
@@ -245,13 +249,19 @@ def weight_loader(self, param: torch.nn.Parameter,
if shard_id == 0:
param_data[expert_id,
0:shard_size, :] = loaded_weight[shard, :]
if is_hpu():
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size:2 *
shard_size, :] = loaded_weight[shard, :]
if is_hpu():
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
if is_hpu():
self.hpu_static_fused_moe.w2_list[expert_id].set_weight(param_data[expert_id])
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
116 changes: 116 additions & 0 deletions vllm/model_executor/layers/quantization/inc.py
@@ -0,0 +1,116 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class INCConfig(QuantizationConfig):
"""Config class for FP8."""

def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change.")
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme

@classmethod
def get_name(cls) -> str:
return "inc"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "INCConfig":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["INCLinearMethod"]:
if isinstance(layer, LinearBase):
return INCLinearMethod(self)
elif isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod()
return None

def get_scaled_act_names(self) -> List[str]:
return []

def get_min_capability(self) -> int:
# Minimum compute capability (7.5, Turing); required by the QuantizationConfig interface.
return 75

@staticmethod
def get_config_filenames() -> List[str]:
return []

class INCLinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
quant_config: The quantization config.
"""

def __init__(self, quant_config: INCConfig, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
if self.separate_bias_add:
if bias is not None:
return F.linear(x, weight) + bias
return F.linear(x, weight)
return F.linear(x, weight, bias)
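
The new `INCConfig` mostly records how the checkpoint was serialized and hands linear layers to `INCLinearMethod` (a plain `F.linear`; the FP8 handling itself is presumably applied by INC at runtime), while `FusedMoE` layers fall back to `UnquantizedFusedMoEMethod`. A hedged sketch of how the config would be built from a HuggingFace-style `quantization_config` dict — the dict contents are illustrative, and running it assumes this vLLM fork is installed:

```python
from vllm.model_executor.layers.quantization.inc import INCConfig

# Illustrative quantization_config block; "quant_method" and "activation_scheme"
# are the two keys INCConfig.from_config() reads.
hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "static",
}

inc_config = INCConfig.from_config(hf_quant_config)
print(inc_config.get_name())                    # "inc"
print(inc_config.is_checkpoint_fp8_serialized)  # True, since "fp8" is in quant_method
print(inc_config.activation_scheme)             # "static"
```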
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -24,7 +24,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and model_config.quantization != "fp8"
and model_config.quantization not in ["fp8", "inc"]
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
