[Bugfix] Fix Phi-3 BNB online quantization (#10417)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
jeejeelee authored Nov 19, 2024
1 parent 284203f commit 7eb719d
Showing 2 changed files with 19 additions and 3 deletions.
vllm/model_executor/layers/linear.py (12 changes: 9 additions & 3 deletions)
@@ -470,7 +470,8 @@ def weight_loader(self,
         needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
 
         if loaded_shard_id is None:
-            # Loaded weight is already fused on disk (qkv/mlp).
+            # Loaded weight is already fused on disk (mlp).
+            # (e.g., Phi-3's gate_up_proj).
             if output_dim is None:
                 if needs_scalar_to_array:
                     param_data, loaded_weight = adjust_scalar_to_fused_array(
@@ -480,6 +481,8 @@ def weight_loader(self,
                     param_data.copy_(loaded_weight)
                 return
             current_shard_offset = 0
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
             shard_offsets: List[Tuple[int, int, int]] = []
             for i, output_size in enumerate(self.output_sizes):
                 shard_offsets.append((i, current_shard_offset, output_size))
@@ -495,7 +498,9 @@ def weight_loader(self,
                     # Special case for Marlin.
                     shard_size, shard_offset = adjust_marlin_shard(
                         param, shard_size, shard_offset)
-
+                if use_bitsandbytes_4bit:
+                    shard_size = loaded_weight.shape[output_dim] // 2
+                    shard_offset = shard_size * shard_id
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -808,7 +813,8 @@ def weight_loader(self,
         needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
 
         if loaded_shard_id is None:
-            # Loaded weight is already fused on disk (qkv/mlp).
+            # Loaded weight is already fused on disk (qkv).
+            # (e.g., Phi-3's qkv_proj).
             if output_dim is None:
                 if needs_scalar_to_array:
                     param_data, loaded_weight = adjust_scalar_to_fused_array(
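
A likely reading of the shard arithmetic added above: when the parameter carries use_bitsandbytes_4bit, the fused tensor reaching weight_loader has already gone through the bitsandbytes path, so its size along output_dim need not match sum(self.output_sizes); the patch therefore derives the shard geometry from the loaded tensor itself, and because Phi-3's gate_up_proj consists of two equal halves, each shard is shape[output_dim] // 2 with offset = size * shard_id. Below is a minimal, self-contained sketch of that slicing with made-up sizes; it is an illustration, not vLLM's actual loader code.

import torch

# Hypothetical fused gate_up_proj stand-in: the configured (unpacked) sizes
# would sum to 8192 rows, but the tensor seen under bitsandbytes 4-bit is
# packed and smaller, here 4096 rows along output_dim.
output_sizes = [4096, 4096]   # gate and up, equal halves
output_dim = 0
loaded_weight = torch.empty(4096, 1024, dtype=torch.uint8)  # packed stand-in

# Shard geometry derived from the loaded tensor, mirroring the patch above.
for shard_id in range(len(output_sizes)):
    shard_size = loaded_weight.shape[output_dim] // 2
    shard_offset = shard_size * shard_id
    shard = loaded_weight.narrow(output_dim, shard_offset, shard_size)
    print(f"shard {shard_id}: offset={shard_offset}, size={shard_size}, "
          f"shape={tuple(shard.shape)}")
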
vllm/model_executor/models/phi3.py (10 changes: 10 additions & 0 deletions)
@@ -14,3 +14,13 @@ class Phi3ForCausalLM(LlamaForCausalLM):
             "gate_up_proj",
         ],
     }
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_up_proj.",
+        ".down_proj.",
+        ".qkv_proj.",
+        ".o_proj.",
+    ]
+    # Initialize an empty dict when there is no stacked parameter mapping.
+    bitsandbytes_stacked_params_mapping = {}
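
With these class attributes in place, vLLM's bitsandbytes loader knows which Phi-3 modules to quantize, so in-flight (online) 4-bit quantization of an unquantized Phi-3 checkpoint should work end to end. A hedged usage sketch follows; the model name, prompt, and sampling settings are illustrative, while the quantization and load_format values follow vLLM's documented bitsandbytes usage.

from vllm import LLM, SamplingParams

# In-flight bitsandbytes 4-bit quantization of an unquantized checkpoint.
# The model name below is illustrative; any Phi-3 causal-LM checkpoint is
# expected to load the same way after this fix.
llm = LLM(
    model="microsoft/Phi-3-mini-4k-instruct",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)

outputs = llm.generate(
    ["Summarize 4-bit quantization in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
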
