forked from vllm-project/vllm
Commit 3e503d0
Support Mixtral quantization using INC
1 parent: 6f047d8
Showing 4 changed files with 185 additions and 37 deletions.
@@ -0,0 +1,116 @@
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

class INCConfig(QuantizationConfig):
    """Config class for FP8 quantization via INC (Intel Neural Compressor)."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "inc"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "INCConfig":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return INCLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return UnquantizedFusedMoEMethod()
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_min_capability(self) -> int:
        # Require Turing (compute capability 7.5) or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

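# Usage sketch (illustrative, not executed here): a serialized FP8
# checkpoint would carry a quantization config dict with the two keys that
# from_config() above reads, e.g.:
#
#   config = INCConfig.from_config(
#       {"quant_method": "fp8", "activation_scheme": "static"})
#   assert config.is_checkpoint_fp8_serialized
#   assert config.activation_scheme == "static"
#
# For LinearBase layers, config.get_quant_method(layer, prefix="") then
# returns the INCLinearMethod defined below.
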
class INCLinearMethod(LinearMethodBase):
    """Linear method for FP8 quantization via INC.

    Supports loading FP8 checkpoints with a static weight scale and a
    dynamic/static activation scale. Also supports loading quantized
    FP16/BF16 model checkpoints with dynamic activation scaling; in that
    case the weight scaling factor is initialized after the model weights
    are loaded.

    Limitations:
    1. Only per-tensor quantization is supported due to torch._scaled_mm.
    2. Only the float8_e4m3fn data type is supported due to the limitation
       of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self,
                 quant_config: INCConfig,
                 separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        # Allocate the full-precision weight; quantization itself is
        # performed by INC outside of this method.
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)
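
For reference, a minimal sketch of how these hooks fit together. The toy layer and sizes below are illustrative assumptions; in vLLM, LinearBase layers invoke create_weights and apply internally rather than being called directly like this:

import torch
from torch.nn import Module

layer = Module()
method = INCLinearMethod(INCConfig(activation_scheme="dynamic"))
# Registers an empty bfloat16 weight of shape (32, 16) on the layer.
method.create_weights(layer,
                      input_size_per_partition=16,
                      output_partition_sizes=[32],
                      input_size=16,
                      output_size=32,
                      params_dtype=torch.bfloat16)
x = torch.randn(4, 16, dtype=torch.bfloat16)
# apply() is a plain F.linear; FP8 quantization is handled by INC.
out = method.apply(layer, x)
assert out.shape == (4, 32)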