diff --git a/vllm/config.py b/vllm/config.py index 6acb70ad047b2..7aa3977a497ea 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1326,7 +1326,7 @@ def verify_with_model_config(self, model_config: ModelConfig): model_config.quantization) def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): - if scheduler_config.max_num_batched_tokens > 65528: + if not is_hpu() and scheduler_config.max_num_batched_tokens > 65528: raise ValueError( "Due to limitations of the custom LoRA CUDA kernel, " "max_num_batched_tokens must be <= 65528 when " diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 662c53486b4ca..1ee56610d9ee5 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -193,6 +193,18 @@ def prompt_attention( return attn_weights +class LoraMask: + lora_mask = None + + @staticmethod + def setLoraMask(mask): + LoraMask.lora_mask = mask + + @staticmethod + def getLoraMask(): + return LoraMask.lora_mask + + def dispatch_bgmv_linear( y: torch.Tensor, x: torch.Tensor, @@ -206,29 +218,24 @@ def dispatch_bgmv_linear( `wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices stacked into single tensors, assuming same rank. HPU handles no-LoRA requests using zero valued A and B tensors. These zero valued tensors are - appended at the end of `wa_t_all` and `wb_t_all` during initialization. For - custom BGMV, the corresponding `wa` and `wb` for each batch is created - based on the lora_index of each sample. - - For example: - `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank, - hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles - no-LoRA case. The `wa` tensor for a batch of size batch_Size will have - a shape of (batch_size, num_layers, hidden_dim, lora_rank) - - This method avoids for-loop as well as graph breaks. + appended at the end of `wa_t_all` and `wb_t_all` during initialization. + We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank] + and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also + have a loraMask of shape [batch_size, num_layers * lora_rank] """ - assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' - max_loras = wa_t_all.size(0) - # Wrap-around for negative indices - indices = indices % max_loras - wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) - wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) - x = x.unsqueeze(1) + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + mask = LoraMask.getLoraMask() + wa = wa_t_all[:, 0, :, :] + wb = wb_t_all[:, 0, :, :].transpose(1, 2) + wa_shape = wa.shape + wb_shape = wb.shape + wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1) + wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2]) out = x @ wa + assert (out.shape == mask.shape) + out = out * mask out = out @ wb - out = out.squeeze(1) y += out * scale @@ -265,4 +272,4 @@ def dispatch_bgmv_embedding( x = x.unsqueeze(1) out = x @ wa out = out.squeeze(1) - y += out * scale \ No newline at end of file + y += out * scale diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 4a45f3fda88f1..aa01e9fb77af2 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -327,6 +327,17 @@ def set_mapping( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 embedding_len = self.indices_len[3] + # NOTE(vgoel): These asserts can be skipped when upstreaming. + # Can be removed from vllm-fork also once lora functionality + # on Gaudi stabilizes. 
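The reshaped-weights-plus-mask formulation in dispatch_bgmv_linear above replaces the per-sample index_select gather with two dense matmuls and an elementwise multiply, which avoids the per-batch for-loop and keeps the graph static on HPU. Below is a minimal sketch of that computation; the shapes, the token-to-LoRA assignment, and scale = 1.0 are illustrative assumptions, not values taken from this patch.

import torch

# Illustrative shapes (assumptions for the sketch, single LoRA layer).
num_loras, rank, hidden, out_dim, tokens = 3, 8, 16, 16, 4

wa_t_all = torch.randn(num_loras, 1, rank, hidden)   # stacked LoRA A weights
wb_t_all = torch.randn(num_loras, 1, out_dim, rank)  # stacked LoRA B weights
x = torch.randn(tokens, hidden)
y = torch.zeros(tokens, out_dim)

# Flatten the per-LoRA dimension the same way dispatch_bgmv_linear does.
wa = wa_t_all[:, 0, :, :].reshape(num_loras * rank, hidden).transpose(0, 1)
wb = wb_t_all[:, 0, :, :].transpose(1, 2).reshape(num_loras * rank, out_dim)

# Each token selects its adapter purely via the mask; -1 means "no LoRA".
lora_index = [0, 2, -1, 1]
mask = torch.zeros(tokens, num_loras * rank)
for row, idx in enumerate(lora_index):
    if idx >= 0:
        mask[row, idx * rank:(idx + 1) * rank] = 1.0

out = (x @ wa) * mask          # zero out every LoRA slot except the active one
y += (out @ wb) * 1.0          # scale assumed to be 1.0 in this sketch

Because a token's mask row is zero outside its own adapter's rank columns (and entirely zero for no-LoRA tokens, matching the zero-valued A/B slot appended at initialization), the masked product contributes exactly the active adapter's delta.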
+ if is_hpu(): + emb_len = embedding_len + x_shape = x.shape + ind_shape = self.embeddings_indices[1].shape + assert embedding_len == x.shape[0] * x.shape[1], \ + f"Extra Info: {emb_len}, {x_shape}, {ind_shape}" + assert embedding_len <= self.embeddings_indices[1].shape[0], \ + f"Extra Info: {emb_len}, {x.shape}, {ind_shape}" indices = self.embeddings_indices[1][:embedding_len].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index dec1b65858eb4..e03c9167ad308 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -22,6 +22,7 @@ ModelConfig, MultiModalConfig, ParallelConfig, SchedulerConfig) from vllm.distributed.parallel_state import get_world_group +from vllm.hpu.ops import LoraMask as LoraMask from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -229,6 +230,7 @@ def forward(self, *args, **kwargs): input_ids.size(1), input_ids.device, torch.bfloat16) + LoraMask.setLoraMask(kwargs.pop('lora_mask')) hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.index_select(0, selected_token_indices) @@ -252,21 +254,23 @@ class PreparePromptMetadata(NamedTuple): lora_requests: Set[LoRARequest] multi_modal_input: Optional[torch.Tensor] slot_mapping: List[List[int]] + lora_mask: Optional[torch.Tensor] + lora_logits_mask: Optional[torch.Tensor] @classmethod def empty(cls): - return PreparePromptMetadata( - input_tokens=[], - input_positions=[], - attn_metadata=None, - seq_lens=[], - query_lens=[], - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - multi_modal_input=None, - slot_mapping=[], - ) + return PreparePromptMetadata(input_tokens=[], + input_positions=[], + attn_metadata=None, + seq_lens=[], + query_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_input=None, + slot_mapping=[], + lora_mask=None, + lora_logits_mask=None) class PrepareDecodeMetadata(NamedTuple): @@ -277,18 +281,20 @@ class PrepareDecodeMetadata(NamedTuple): lora_prompt_mapping: List[List[int]] lora_requests: Set[LoRARequest] slot_mapping: List[List[int]] + lora_mask: Optional[torch.Tensor] + lora_logits_mask: Optional[torch.Tensor] @classmethod def empty(cls): - return PrepareDecodeMetadata( - input_tokens=[], - input_positions=[], - attn_metadata=None, - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - slot_mapping=[], - ) + return PrepareDecodeMetadata(input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + lora_mask=None, + lora_logits_mask=None) # How batches are constructed. 
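The two metadata fields added above carry per-token and per-sequence masks into the model: lora_mask has one row per padded prompt token and lora_logits_mask one row per sequence, each with (max_loras + 1) * max_lora_rank columns. A small sketch of one possible prompt-phase layout follows; the batch size, rank, prompt length, adapter ids, and dtype are illustrative assumptions, not values from this patch.

import torch

max_loras, rank, max_prompt_len, batch = 2, 4, 3, 2
cols = (max_loras + 1) * rank      # extra slot for the zero-valued "no-LoRA" adapter

lora_mask = torch.zeros(batch * max_prompt_len, cols, dtype=torch.bfloat16)
lora_logits_mask = torch.zeros(batch, cols, dtype=torch.bfloat16)

lora_ids = [1, 0]                  # sequence 0 uses adapter 1, sequence 1 uses none
for row, lora_id in enumerate(lora_ids):
    if lora_id > 0:
        start = (lora_id - 1) * rank
        lora_mask[row * max_prompt_len:(row + 1) * max_prompt_len,
                  start:start + rank] = 1.0
        lora_logits_mask[row, start:start + rank] = 1.0

The forward wrapper then pops lora_mask from the kwargs and publishes it via LoraMask.setLoraMask, so dispatch_bgmv_linear can fetch it with getLoraMask instead of threading it through every layer's signature; the hunks that follow build these masks in _prepare_prompt and _prepare_decode.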
@@ -323,6 +329,8 @@ class ModelInputForHPU(ModelRunnerInputBase): real_batch_size: Optional[int] = None batch_size_padded: Optional[int] = None virtual_engine: int = 0 + lora_mask: Optional[torch.Tensor] = None + lora_logits_mask: Optional[torch.Tensor] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -333,7 +341,9 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "multi_modal_kwargs": self.multi_modal_kwargs, "real_batch_size": self.real_batch_size, "batch_size_padded": self.batch_size_padded, - "virtual_engine": self.virtual_engine + "virtual_engine": self.virtual_engine, + "lora_mask": self.lora_mask, + "lora_logits_mask": self.lora_logits_mask, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -367,6 +377,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "lora_mask": self.lora_mask, + "lora_logits_mask": self.lora_logits_mask, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -739,12 +751,39 @@ def _prepare_prompt( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), self.block_size) + lora_mask: torch.Tensor = None + lora_logits_mask: torch.Tensor = None + counter = 0 + if self.lora_config: + lora_mask = torch.zeros(len(seq_group_metadata_list) * + max_prompt_len, + (self.lora_config.max_loras + 1) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + lora_logits_mask = torch.zeros(len(seq_group_metadata_list), + (self.lora_config.max_loras + 1) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + ones = torch.ones(max_prompt_len, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + logit_ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) for seq_group_metadata, context_len in zip(seq_group_metadata_list, context_lens): lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) + start_row = counter * max_prompt_len + end_row = start_row + max_prompt_len + start_col = (lora_id - 1) * self.lora_config.max_lora_rank + end_col = start_col + self.lora_config.max_lora_rank + lora_mask[start_row:end_row, start_col:end_col] = ones + lora_logits_mask[counter, start_col:end_col] = logit_ones + counter = counter + 1 lora_index_mapping += [lora_id] * (max_prompt_len - context_len) lora_prompt_mapping.extend( @@ -752,6 +791,10 @@ def _prepare_prompt( (max_prompt_len - context_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + if lora_mask is not None: + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_logits_mask.to('hpu') + input_tokens = make_tensor_with_pad(input_tokens, max_len=max_prompt_len, pad=0, @@ -817,6 +860,8 @@ def _prepare_prompt( lora_requests=lora_requests, multi_modal_input=multi_modal_input, slot_mapping=slot_mapping, + lora_mask=lora_mask, + lora_logits_mask=lora_logits_mask, ) def _prepare_decode( @@ -834,6 +879,18 @@ def _prepare_decode( if len(seq_group_metadata_list) == 0: return PrepareDecodeMetadata.empty() + lora_mask: torch.Tensor = None + lora_logits_mask: torch.Tensor = None + counter = 0 + + if self.lora_config: + lora_mask = torch.zeros(len(seq_group_metadata_list), + (self.lora_config.max_loras + 1) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + ones = 
torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -844,6 +901,10 @@ def _prepare_decode( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) + start_pos = (lora_id - 1) * self.lora_config.max_lora_rank + end_pos = start_pos + self.lora_config.max_lora_rank + lora_mask[counter, start_pos:end_pos] = ones + counter = counter + 1 for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] @@ -872,6 +933,9 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + if lora_mask is not None: + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_mask input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) @@ -917,6 +981,8 @@ def _prepare_decode( lora_prompt_mapping=lora_prompt_mapping, lora_requests=lora_requests, slot_mapping=slot_mapping, + lora_mask=lora_mask, + lora_logits_mask=lora_logits_mask, ) def prepare_input_tensors( @@ -971,6 +1037,8 @@ def prepare_input_tensors( lora_requests, multi_modal_input, slot_mapping, + lora_mask, + lora_logits_mask, ) = self._prepare_prompt(prefill_reqs) ( decode_input_tokens, @@ -980,6 +1048,8 @@ def prepare_input_tensors( decode_lora_prompt_mapping, decode_lora_requests, decode_slot_mapping, + decode_lora_mask, + decode_lora_logits_mask, ) = self._prepare_decode(decode_reqs) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, seq_lens, query_lens, @@ -1006,6 +1076,8 @@ def prepare_input_tensors( lora_index_mapping = decode_lora_index_mapping lora_prompt_mapping = decode_lora_prompt_mapping lora_requests = decode_lora_requests + lora_mask = decode_lora_mask + lora_logits_mask = decode_lora_logits_mask # FIXME: We need to adjust selected_token_indices to accommodate # for padding @@ -1065,17 +1137,19 @@ def prepare_input_tensors( attn_metadata = prefill_attn_metadata if \ prefill_attn_metadata is not None else decode_attn_metadata - return self._model_input_cls( - input_tokens=input_tokens, - seq_lens=seq_lens, - query_lens=query_lens, - input_positions=input_positions, - attn_metadata=attn_metadata, - lora_requests=lora_requests, - lora_mapping=lora_mapping, - multi_modal_kwargs=multi_modal_input, - real_batch_size=real_batch_size, - batch_size_padded=batch_size_padded), sampling_metadata + return self._model_input_cls(input_tokens=input_tokens, + seq_lens=seq_lens, + query_lens=query_lens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_requests=lora_requests, + lora_mapping=lora_mapping, + multi_modal_kwargs=multi_modal_input, + real_batch_size=real_batch_size, + batch_size_padded=batch_size_padded, + lora_mask=lora_mask, + lora_logits_mask=lora_logits_mask), \ + sampling_metadata def _seq_len(self, attn_metadata): if attn_metadata.num_prefills != 0: @@ -1149,6 +1223,7 @@ def profile_run(self) -> None: True, kv_caches, is_profile_run=True) + return def warmup_scenario(self, batch_size, @@ -1610,7 +1685,8 @@ def execute_model( "positions": input_positions, "kv_caches": kv_caches, "attn_metadata": self.trim_attn_metadata(attn_metadata), - "intermediate_tensors": intermediate_tensors + "intermediate_tensors": intermediate_tensors, + "lora_mask": model_input.lora_mask } if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) @@ -1644,6 +1720,10 @@ def execute_model( module.indices_len[ i] = sampling_metadata.selected_token_indices.numel( ) + lora_logits_mask: 
torch.Tensor = model_input.lora_logits_mask + LoraMask.setLoraMask( + lora_logits_mask.index_select( + 0, sampling_metadata.selected_token_indices)) # Compute the logits. with self.profiler.record_event(