Add quantized mixtral support (vllm-project#2673)
WoosukKwon authored Jan 31, 2024
1 parent 105a40f commit 3dad944
Showing 3 changed files with 422 additions and 4 deletions.
13 changes: 9 additions & 4 deletions vllm/model_executor/model_loader.py
@@ -4,7 +4,6 @@
 
 import torch
 import torch.nn as nn
-from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, LoRAConfig
 from vllm.model_executor.models import ModelRegistry
@@ -21,8 +20,14 @@ def _set_default_torch_dtype(dtype: torch.dtype):
     torch.set_default_dtype(old_dtype)
 
 
-def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
-    architectures = getattr(config, "architectures", [])
+def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
+    architectures = getattr(model_config.hf_config, "architectures", [])
+    # Special handling for quantized Mixtral.
+    # FIXME(woosuk): This is a temporary hack.
+    if (model_config.quantization is not None
+            and "MixtralForCausalLM" in architectures):
+        architectures = ["QuantMixtralForCausalLM"]
+
     for arch in architectures:
         model_cls = ModelRegistry.load_model_cls(arch)
         if model_cls is not None:
@@ -34,7 +39,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 
 def get_model(model_config: ModelConfig,
               lora_config: Optional[LoRAConfig] = None) -> nn.Module:
-    model_class = _get_model_architecture(model_config.hf_config)
+    model_class = _get_model_architecture(model_config)
 
     # Get the (maybe quantized) linear method.
     linear_method = None
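For readers skimming the diff, the effect of the new branch can be seen in a tiny standalone sketch; SimpleNamespace stands in for vLLM's ModelConfig here, which is an assumption made only for illustration:

from types import SimpleNamespace

def pick_architecture(model_config) -> list:
    # Mirror of the dispatch added in _get_model_architecture (sketch only).
    architectures = getattr(model_config.hf_config, "architectures", [])
    # With quantization enabled, reroute Mixtral to the implementation
    # registered under "QuantMixtralForCausalLM".
    if (model_config.quantization is not None
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]
    return architectures

# Unquantized Mixtral keeps the stock architecture.
cfg = SimpleNamespace(
    quantization=None,
    hf_config=SimpleNamespace(architectures=["MixtralForCausalLM"]))
assert pick_architecture(cfg) == ["MixtralForCausalLM"]

# With quantization requested (e.g. "awq"), the quantized class is picked.
cfg.quantization = "awq"
assert pick_architecture(cfg) == ["QuantMixtralForCausalLM"]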
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -30,6 +30,7 @@
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
+    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
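The registry maps an architecture name to a (module name, class name) pair, so the new entry points "QuantMixtralForCausalLM" at the MixtralForCausalLM class in a mixtral_quant module. A rough sketch of how such a lazy lookup works in general (an illustration of the pattern, not the exact ModelRegistry code):

import importlib
from typing import Optional, Type

# Subset of the registry, copied from the diff above.
_MODELS = {
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
}

def load_model_cls(model_arch: str) -> Optional[Type]:
    # Import the model module only when that architecture is requested.
    if model_arch not in _MODELS:
        return None
    module_name, cls_name = _MODELS[model_arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)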
412 changes: 412 additions & 0 deletions vllm/model_executor/models/mixtral_quant.py
New file with the quantized Mixtral implementation; large diffs are not rendered by default.
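With the registry entry and the dispatch hook in place, a quantized Mixtral checkpoint should load through the normal vLLM entry point. A minimal usage sketch, assuming an AWQ-quantized Mixtral repository is available (the checkpoint name below is illustrative, not part of this commit):

from vllm import LLM, SamplingParams

# Illustrative checkpoint name; any AWQ-quantized Mixtral repo would do.
# Passing quantization="awq" makes model_config.quantization non-None,
# which triggers the new QuantMixtralForCausalLM dispatch.
llm = LLM(model="TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ",
          quantization="awq")

outputs = llm.generate(["Explain mixture-of-experts in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)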
