Commit

revert FP8 changes
kzawora-intel committed Oct 4, 2024
1 parent c07cbc6 commit d90bbce
Showing 4 changed files with 3 additions and 130 deletions.
10 changes: 2 additions & 8 deletions vllm/model_executor/layers/linear.py
@@ -285,7 +285,6 @@ def __init__(self,
                          quant_config, prefix)
 
         self.gather_output = gather_output
-        self.collective_func = tensor_model_parallel_all_gather
 
         # Divide the weight matrix along the last dimension.
         tp_size = get_tensor_model_parallel_world_size()
@@ -368,7 +367,7 @@ def forward(self, input_):
         output_parallel = self.quant_method.apply(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            output = self.collective_func(output_parallel)
+            output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
@@ -974,7 +973,6 @@ def __init__(self,
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
-        self.collective_func = tensor_model_parallel_all_reduce
 
         # Divide the weight matrix along the last dimension.
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -1053,18 +1051,14 @@ def weight_loader_v2(self, param: BasevLLMParameter,
 
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
-    def resolve_input(self, input_):
+    def forward(self, input_):
         if self.input_is_parallel:
             input_parallel = input_
         else:
             tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[tp_rank].contiguous()
-        return input_parallel
-
-    def forward(self, input_):
-        input_parallel = self.resolve_input(input_)
 
         # Matrix multiply.
         assert self.quant_method is not None
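
The linear.py hunks above restore the direct collective calls and fold resolve_input back into RowParallelLinear.forward. Below is a minimal, single-process sketch of the tensor-parallel arithmetic those paths perform, assuming only PyTorch; torch.cat and a plain sum stand in for vLLM's tensor_model_parallel_all_gather / all_reduce collectives, so this is an illustration of the math rather than the actual distributed code.

# Single-process sketch of the column-parallel and row-parallel linear math.
# The real layers use distributed collectives; torch.cat / sum stand in here.
import torch

tp_size = 2
x = torch.randn(4, 8)                                       # [batch, hidden]

# ColumnParallelLinear: each rank holds a slice of the output dimension;
# gathering concatenates the per-rank outputs along the last dim.
col_shards = [torch.randn(8, 6) for _ in range(tp_size)]    # per-rank [hidden, out/tp]
col_outputs = [x @ w for w in col_shards]
gathered = torch.cat(col_outputs, dim=-1)                   # stands in for all-gather
assert gathered.shape == (4, 6 * tp_size)

# RowParallelLinear with input_is_parallel=False: split the input along the
# last dimension (like split_tensor_along_last_dim), multiply per rank, then
# sum the partial results (stands in for all-reduce).
row_shards = [torch.randn(8 // tp_size, 5) for _ in range(tp_size)]
x_splits = torch.chunk(x, tp_size, dim=-1)
partials = [x_splits[r].contiguous() @ row_shards[r] for r in range(tp_size)]
reduced = torch.stack(partials).sum(dim=0)                  # stands in for all-reduce
assert reduced.shape == (4, 5)
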
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -21,7 +21,6 @@
     GPTQMarlinConfig)
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQMarlin24Config)
-from vllm.model_executor.layers.quantization.inc import INCConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
 from vllm.model_executor.layers.quantization.neuron_quant import (
@@ -47,7 +46,6 @@
     "gptq": GPTQConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
-    "inc": INCConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
     "neuron_quant": NeuronQuantConfig,
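
With the INCConfig import and the "inc" entry reverted, the registry above no longer resolves that method name. The sketch below shows the lookup pattern such a mapping implies, using a hypothetical get_quantization_config helper and a DummyConfig placeholder rather than vLLM's real classes; the project's own helper and error message may differ.

# Hypothetical stand-in for the registry lookup pattern implied by the mapping
# above; not vLLM's actual classes or error text.
from typing import Dict, Type

class DummyConfig:  # placeholder for the real QuantizationConfig classes
    pass

QUANTIZATION_METHODS: Dict[str, Type[DummyConfig]] = {
    "gptq": DummyConfig,
    "compressed-tensors": DummyConfig,
    "bitsandbytes": DummyConfig,
    # "inc": INCConfig,   # entry removed by this revert
}

def get_quantization_config(quantization: str) -> Type[DummyConfig]:
    # Unknown method names fail fast instead of silently falling through.
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]

get_quantization_config("gptq")     # resolves
# get_quantization_config("inc")    # raises ValueError after the revert
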
119 changes: 0 additions & 119 deletions vllm/model_executor/layers/quantization/inc.py

This file was deleted.

2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -23,7 +23,7 @@ def get_model_architecture(
     architectures = getattr(model_config.hf_config, "architectures", [])
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
-    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "inc"]
+    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]
 
     if (model_config.quantization is not None
             and model_config.quantization not in mixtral_supported
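
The utils.py hunk narrows mixtral_supported so "inc" no longer counts as a supported quantization for Mixtral. The sketch below mimics that gate; the diff only shows the first two clauses of the condition, so the architecture check and the boolean return here are assumptions added purely for illustration.

# Hedged sketch of the gate above; the final clause is an assumption, since the
# diff cuts off before the rest of the condition.
from typing import List, Optional

def uses_quant_mixtral_fallback(quantization: Optional[str],
                                architectures: List[str]) -> bool:
    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]  # "inc" removed
    return (quantization is not None
            and quantization not in mixtral_supported
            and "MixtralForCausalLM" in architectures)  # assumed final clause

print(uses_quant_mixtral_fallback("inc", ["MixtralForCausalLM"]))   # True after the revert
print(uses_quant_mixtral_fallback("fp8", ["MixtralForCausalLM"]))   # False
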

0 comments on commit d90bbce
