From 6154031a858fd12ecef87c065baa96aa687deafc Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Sat, 2 Nov 2024 05:49:32 -0700
Subject: [PATCH] feat: add TP support for bitsandbytes

---
 aphrodite/common/config.py                |  6 ----
 aphrodite/modeling/layers/linear.py       | 23 +++++++++++---
 aphrodite/modeling/model_loader/loader.py | 38 ++++++++++++++++++++++-
 3 files changed, 55 insertions(+), 12 deletions(-)

diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py
index 88d3e2e8d..c7976fa8e 100644
--- a/aphrodite/common/config.py
+++ b/aphrodite/common/config.py
@@ -456,12 +456,6 @@ def verify_with_parallel_config(
                 "Pipeline parallelism is only supported for the following "
                 f" architectures: {_PP_SUPPORTED_MODELS}.")
 
-        if self.quantization == "bitsandbytes" and (
-                parallel_config.tensor_parallel_size > 1
-                or parallel_config.pipeline_parallel_size > 1):
-            raise ValueError(
-                "BitsAndBytes quantization with TP/PP is not supported yet.")
-
         if self.quantization == "bitsandbytes" and self.enforce_eager is False:
             raise ValueError(
                 "BitsAndBytes with enforce_eager=False is not supported yet.")
diff --git a/aphrodite/modeling/layers/linear.py b/aphrodite/modeling/layers/linear.py
index 9e1ac3732..bc0f09337 100644
--- a/aphrodite/modeling/layers/linear.py
+++ b/aphrodite/modeling/layers/linear.py
@@ -523,6 +523,8 @@ def weight_loader(self,
                     param, shard_size, shard_offset)
 
             use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
             if use_bitsandbytes:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * \
@@ -547,8 +549,11 @@ def weight_loader(self,
                     loaded_weight.shape[output_dim], tp_rank, tp_size)
             else:
                 start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -894,6 +899,8 @@ def weight_loader(self,
                     param, shard_size, shard_offset)
 
             use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
             if use_bitsandbytes:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
@@ -934,8 +941,11 @@ def weight_loader(self,
             else:
                 shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
         # Special case for for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1044,6 +1054,7 @@ def __init__(self,
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
@@ -1058,7 +1069,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
-        if input_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[input_dim]
             if self.quant_config is None:
                 start_idx = get_current_tp_rank_partition_offset(
diff --git a/aphrodite/modeling/model_loader/loader.py b/aphrodite/modeling/model_loader/loader.py
index abb45ee11..dc27271fe 100644
--- a/aphrodite/modeling/model_loader/loader.py
+++ b/aphrodite/modeling/model_loader/loader.py
@@ -25,6 +25,8 @@
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
 from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.model_loader.tensorizer import (
     TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
     serialize_aphrodite_model, tensorizer_weights_iterator)
@@ -661,6 +663,8 @@ def save_model(
 class BitsAndBytesModelLoader(BaseModelLoader):
     """Model loader to load model weights with BitAndBytes quantization."""
+    # TODO: these module names are for Llama only,
+    # change so that it works with other models as well
     default_target_modules = [
         "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
         "o_proj"
     ]
@@ -846,13 +850,39 @@ def _parse_quant_state(param_name: str,
                 yield weight_name, weight_tensor
 
         def generator() -> Generator:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
             for weight_name, weight_tensor in self._hf_weight_iter(
                     hf_weights_files, use_safetensors):
                 if any(target_module in weight_name
                        for target_module in self.target_modules):
                     weight_name = weight_name.replace(".weight", ".qweight")
+                    # weight partitions of different modules occur at
+                    # different dimensions
+                    # TODO: these module names are for Llama only,
+                    # change so that it works with other models as well
+                    if 'down_proj' in weight_name or 'o_proj' in weight_name:
+                        total_size = weight_tensor.size(-1)
+                        start_index = total_size // tp_size * tp_rank
+                        end_index = total_size // tp_size * (tp_rank + 1)
+                        weight_sub_tensor = weight_tensor[...,
+                                                          start_index:end_index]
+                    else:
+                        total_size = weight_tensor.size(0)
+                        start_index = total_size // tp_size * tp_rank
+                        end_index = total_size // tp_size * (tp_rank + 1)
+                        weight_sub_tensor = weight_tensor[ \
+                            start_index:end_index, ...]
                     # bitsandbytes requires data in GPU
-                    loaded_weight = weight_tensor.cuda().data
+                    if weight_sub_tensor.is_cuda:
+                        loaded_weight = weight_sub_tensor
+                    else:
+                        loaded_weight = weight_sub_tensor.cuda()
+                    # remove the following after the issue is fixed:
+                    # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
+                    if loaded_weight.is_contiguous() is False:
+                        loaded_weight = loaded_weight.contiguous()
+
                     with set_default_torch_dtype(torch.float32):
                         processed_weight, quant_state = quantize_4bit(
                             loaded_weight,
@@ -867,6 +897,12 @@ def generator() -> Generator:
 
+        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+            raise ValueError(
+                "Prequantized BitsAndBytes models are not supported with "
+                "tensor parallelism. Please try pipeline parallelism instead.")
+
         if pre_quant:
             return quantized_checkpoint(), quant_state_dict
+
         return generator(), quant_state_dict
 
     def _load_weights(self, model_config: ModelConfig,
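Note (illustration, not part of the patch): the loader hunks above slice each
target weight along its tensor-parallel dimension before handing it to
quantize_4bit, so every rank quantizes only its own shard, which is why the
weight_loader hunks skip narrow() when use_bitsandbytes_4bit is set. A minimal
standalone sketch of that slicing, assuming the Llama-style module names that
the TODO comments call out; the helper name shard_for_tp_rank is hypothetical:

    import torch

    def shard_for_tp_rank(weight_name: str, weight: torch.Tensor,
                          tp_rank: int, tp_size: int) -> torch.Tensor:
        # down_proj and o_proj are row-parallel, so their [out, in] weight is
        # split along the input (last) dim; the other target modules are
        # column-parallel and are split along the output (first) dim.
        if "down_proj" in weight_name or "o_proj" in weight_name:
            total_size = weight.size(-1)
            start = total_size // tp_size * tp_rank
            end = total_size // tp_size * (tp_rank + 1)
            return weight[..., start:end]
        total_size = weight.size(0)
        start = total_size // tp_size * tp_rank
        end = total_size // tp_size * (tp_rank + 1)
        return weight[start:end, ...]

The integer-division split mirrors the generator() hunk: a dimension that is
not divisible by tp_size would drop its remainder rows or columns, matching
the simple partitioning the patch uses.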