feat: add TP support for bitsandbytes
AlpinDale committed Nov 2, 2024
1 parent f98e7b2 commit 6154031
Showing 3 changed files with 55 additions and 12 deletions.
6 changes: 0 additions & 6 deletions aphrodite/common/config.py
@@ -456,12 +456,6 @@ def verify_with_parallel_config(
                "Pipeline parallelism is only supported for the following "
                f" architectures: {_PP_SUPPORTED_MODELS}.")

-        if self.quantization == "bitsandbytes" and (
-                parallel_config.tensor_parallel_size > 1
-                or parallel_config.pipeline_parallel_size > 1):
-            raise ValueError(
-                "BitsAndBytes quantization with TP/PP is not supported yet.")
-
        if self.quantization == "bitsandbytes" and self.enforce_eager is False:
            raise ValueError(
                "BitsAndBytes with enforce_eager=False is not supported yet.")
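With the check above removed, on-the-fly bitsandbytes quantization is no longer rejected when tensor parallelism is requested (the enforce_eager check that remains above is unchanged). A minimal usage sketch, assuming Aphrodite keeps its vLLM-style LLM entry point and engine argument names; the model name is only an example and the snippet is not part of this commit:

# Hypothetical usage sketch, not part of this commit.
from aphrodite import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",   # example Llama-style model
    quantization="bitsandbytes",        # quantize to NF4 while loading
    load_format="bitsandbytes",         # route through BitsAndBytesModelLoader
    tensor_parallel_size=2,             # previously raised ValueError for bitsandbytes
    enforce_eager=True,                 # still required by the check kept above
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)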
23 changes: 18 additions & 5 deletions aphrodite/modeling/layers/linear.py
@@ -523,6 +523,8 @@ def weight_loader(self,
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
            if use_bitsandbytes:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
@@ -547,8 +549,11 @@ def weight_loader(self,
                    loaded_weight.shape[output_dim], tp_rank, tp_size)
            else:
                start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
@@ -894,6 +899,8 @@ def weight_loader(self,
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
            if use_bitsandbytes:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
@@ -934,8 +941,11 @@ def weight_loader(self,
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
                start_idx = shard_id * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
        # Special case for for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
@@ -1044,6 +1054,7 @@ def __init__(self,
    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
@@ -1058,7 +1069,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
-        if input_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
            shard_size = param_data.shape[input_dim]
            if self.quant_config is None:
                start_idx = get_current_tp_rank_partition_offset(
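The thread running through the linear.py changes above: when a parameter is tagged use_bitsandbytes_4bit, the weight iterator has already handed this rank its own slice, so the usual narrow() along the sharded dimension must be skipped, otherwise the tensor would be sharded twice. A standalone sketch of that decision with a hypothetical load_shard helper and made-up shapes (not Aphrodite's actual weight_loader):

import torch

def load_shard(loaded_weight: torch.Tensor, output_dim: int, tp_rank: int,
               tp_size: int, use_bitsandbytes_4bit: bool) -> torch.Tensor:
    if use_bitsandbytes_4bit:
        # The loader already yielded only this rank's portion; narrowing
        # again would slice an already-sliced tensor.
        return loaded_weight
    shard_size = loaded_weight.shape[output_dim] // tp_size
    start_idx = tp_rank * shard_size
    return loaded_weight.narrow(output_dim, start_idx, shard_size)

full = torch.randn(8, 4)    # full checkpoint weight
local = full[2:4]           # what the bitsandbytes path would yield for rank 1 of 4
assert torch.equal(load_shard(full, 0, 1, 4, False),
                   load_shard(local, 0, 1, 4, True))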
38 changes: 37 additions & 1 deletion aphrodite/modeling/model_loader/loader.py
@@ -25,6 +25,8 @@
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.model_loader.tensorizer import (
    TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
    serialize_aphrodite_model, tensorizer_weights_iterator)
@@ -661,6 +663,8 @@ def save_model(
class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization."""

+    # TODO: these module names are for Llama only,
+    # change so that it works with other models as well
    default_target_modules = [
        "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
        "o_proj"
@@ -846,13 +850,39 @@ def _parse_quant_state(param_name: str,
                yield weight_name, weight_tensor

        def generator() -> Generator:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
            for weight_name, weight_tensor in self._hf_weight_iter(
                    hf_weights_files, use_safetensors):
                if any(target_module in weight_name
                       for target_module in self.target_modules):
                    weight_name = weight_name.replace(".weight", ".qweight")
+                    # weight partitions of different modules occur at
+                    # different dimensions
+                    # TODO: these module names are for Llama only,
+                    # change so that it works with other models as well
+                    if 'down_proj' in weight_name or 'o_proj' in weight_name:
+                        total_size = weight_tensor.size(-1)
+                        start_index = total_size // tp_size * tp_rank
+                        end_index = total_size // tp_size * (tp_rank + 1)
+                        weight_sub_tensor = weight_tensor[...,
+                                                          start_index:end_index]
+                    else:
+                        total_size = weight_tensor.size(0)
+                        start_index = total_size // tp_size * tp_rank
+                        end_index = total_size // tp_size * (tp_rank + 1)
+                        weight_sub_tensor = weight_tensor[
+                            start_index:end_index, ...]
                    # bitsandbytes requires data in GPU
-                    loaded_weight = weight_tensor.cuda().data
+                    if weight_sub_tensor.is_cuda:
+                        loaded_weight = weight_sub_tensor
+                    else:
+                        loaded_weight = weight_sub_tensor.cuda()
+                    # remove the following after the issue is fixed:
+                    # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
+                    if loaded_weight.is_contiguous() is False:
+                        loaded_weight = loaded_weight.contiguous()

                    with set_default_torch_dtype(torch.float32):
                        processed_weight, quant_state = quantize_4bit(
                            loaded_weight,
@@ -867,6 +897,12 @@ def generator() -> Generator:

        if pre_quant:
            return quantized_checkpoint(), quant_state_dict
+
+        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+            raise ValueError(
+                "Prequantized BitsAndBytes models are not supported with "
+                "Tensor Parallel. Please try Pipeline Parallel instead.")
+
        return generator(), quant_state_dict

    def _load_weights(self, model_config: ModelConfig,
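The generator() change above is where the tensor-parallel split happens: Llama's row-parallel projections (down_proj, o_proj) are sliced along the last dimension (input features), the remaining target modules along dim 0 (output features), and the rank-local slice is moved to the GPU and made contiguous before quantize_4bit. A standalone sketch of that slicing arithmetic with a hypothetical tp_slice helper and example shapes (not the loader's actual code):

import torch

def tp_slice(weight_name: str, weight: torch.Tensor,
             tp_rank: int, tp_size: int) -> torch.Tensor:
    # Row-parallel Llama modules are split along input features (last dim),
    # everything else along output features (dim 0).
    row_parallel = "down_proj" in weight_name or "o_proj" in weight_name
    total = weight.size(-1) if row_parallel else weight.size(0)
    start = total // tp_size * tp_rank
    end = total // tp_size * (tp_rank + 1)
    shard = weight[..., start:end] if row_parallel else weight[start:end, ...]
    return shard.contiguous()   # quantize_4bit needs a contiguous input

w_down = torch.randn(4096, 11008)   # hypothetical [out, in] down_proj shape
w_gate = torch.randn(11008, 4096)   # hypothetical gate_proj shape
print(tp_slice("model.layers.0.mlp.down_proj.weight", w_down, 0, 2).shape)  # torch.Size([4096, 5504])
print(tp_slice("model.layers.0.mlp.gate_proj.weight", w_gate, 0, 2).shape)  # torch.Size([5504, 4096])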
