diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index bbbb46c32a378..1ee56610d9ee5 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -219,50 +219,23 @@ def dispatch_bgmv_linear(
     stacked into single tensors, assuming same rank. HPU handles no-LoRA
     requests using zero valued A and B tensors. These zero valued tensors are
     appended at the end of `wa_t_all` and `wb_t_all` during initialization.
+    We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank]
+    and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also
+    have a loraMask of shape [batch_size, num_layers * lora_rank]
     """

     assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
-    max_loras = wa_t_all.size(0)
-    # Wrap-around for negative indices
     mask = LoraMask.getLoraMask()
-    if mask is not None:
-        """
-        We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank]
-        and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also
-        have a loraMask of shape [batch_size, num_layers * lora_rank]
-        """
-        wa = wa_t_all[:, 0, :, :]
-        wb = wb_t_all[:, 0, :, :].transpose(1, 2)
-        wa_shape = wa.shape
-        wb_shape = wb.shape
-        wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1)
-        wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2])
-        out = x @ wa
-        assert (out.shape == mask.shape)
-        out = out * mask
-        out = out @ wb
-    else:
-        """For custom BGMV, the corresponding `wa` and `wb` for each batch is
-        created based on the lora_index of each sample.
-
-        For example:
-        `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
-        hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
-        no-LoRA case. The `wa` tensor for a batch of size batch_Size will have
-        a shape of (batch_size, num_layers, hidden_dim, lora_rank)
-
-        This method avoids for-loop as well as graph breaks.
-        """
-        indices = indices % max_loras
-        wa = torch.index_select(wa_t_all, 0,
-                                indices)[:, 0, :, :].transpose(-1, -2)
-        wb = torch.index_select(wb_t_all, 0,
-                                indices)[:, 0, :, :].transpose(-1, -2)
-
-        x = x.unsqueeze(1)
-        out = x @ wa
-        out = out @ wb
-        out = out.squeeze(1)
+    wa = wa_t_all[:, 0, :, :]
+    wb = wb_t_all[:, 0, :, :].transpose(1, 2)
+    wa_shape = wa.shape
+    wb_shape = wb.shape
+    wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1)
+    wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2])
+    out = x @ wa
+    assert (out.shape == mask.shape)
+    out = out * mask
+    out = out @ wb

     y += out * scale
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 4b65a7ef46721..2f2c439dfbbe7 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -255,6 +255,7 @@ class PreparePromptMetadata(NamedTuple):
     multi_modal_input: Optional[torch.Tensor]
     slot_mapping: List[List[int]]
     lora_mask: Optional[torch.Tensor]
+    lora_logits_mask: Optional[torch.Tensor]

     @classmethod
     def empty(cls):
@@ -268,7 +269,8 @@ def empty(cls):
                                      lora_requests=set(),
                                      multi_modal_input=None,
                                      slot_mapping=[],
-                                     lora_mask=None)
+                                     lora_mask=None,
+                                     lora_logits_mask=None)


 class PrepareDecodeMetadata(NamedTuple):
@@ -280,6 +282,7 @@ class PrepareDecodeMetadata(NamedTuple):
     lora_requests: Set[LoRARequest]
     slot_mapping: List[List[int]]
     lora_mask: Optional[torch.Tensor]
+    lora_logits_mask: Optional[torch.Tensor]

     @classmethod
     def empty(cls):
@@ -292,6 +295,7 @@ def empty(cls):
                                      lora_requests=set(),
                                      slot_mapping=[],
                                      lora_mask=None,
+                                     lora_logits_mask=None
                                      )


@@ -328,6 +332,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
     batch_size_padded: Optional[int] = None
     virtual_engine: int = 0
     lora_mask: Optional[torch.Tensor] = None
+    lora_logits_mask: Optional[torch.Tensor] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -340,6 +345,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
             "batch_size_padded": self.batch_size_padded,
             "virtual_engine": self.virtual_engine,
             "lora_mask": self.lora_mask,
+            "lora_logits_mask": self.lora_logits_mask,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         return tensor_dict
@@ -374,6 +380,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
             "lora_mapping": self.lora_mapping,
             "multi_modal_kwargs": self.multi_modal_kwargs,
             "lora_mask": self.lora_mask,
+            "lora_logits_mask": self.lora_logits_mask,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         _add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -747,6 +754,7 @@ def _prepare_prompt(
                                                      self.block_size)

         lora_mask: torch.Tensor = None
+        lora_logits_mask: torch.Tensor = None
         counter = 0
         if self.lora_config:
             lora_mask = torch.zeros(len(seq_group_metadata_list) *
@@ -754,9 +762,17 @@ def _prepare_prompt(
                                     (self.lora_config.max_loras + 1) *
                                     self.lora_config.max_lora_rank,
                                     dtype=self.lora_config.lora_dtype)
+            lora_logits_mask = torch.zeros(len(seq_group_metadata_list),
+                                           (self.lora_config.max_loras + 1) *
+                                           self.lora_config.max_lora_rank,
+                                           dtype=self.lora_config.lora_dtype)
+
             ones = torch.ones(max_prompt_len,
                               self.lora_config.max_lora_rank,
                               dtype=self.lora_config.lora_dtype)
+            logit_ones = torch.ones(1,
+                                    self.lora_config.max_lora_rank,
+                                    dtype=self.lora_config.lora_dtype)
         for seq_group_metadata, context_len in zip(seq_group_metadata_list,
                                                    context_lens):
             lora_id = seq_group_metadata.lora_int_id
@@ -768,6 +784,7 @@ def _prepare_prompt(
                 start_col = (lora_id - 1) * self.lora_config.max_lora_rank
                 end_col = start_col + self.lora_config.max_lora_rank
                 lora_mask[start_row:end_row, start_col:end_col] = ones
+                lora_logits_mask[counter, start_col:end_col] = logit_ones
                 counter = counter + 1

             lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
@@ -778,6 +795,7 @@ def _prepare_prompt(

         if lora_mask is not None:
             lora_mask = lora_mask.to('hpu')
+            lora_logits_mask = lora_logits_mask.to('hpu')

         input_tokens = make_tensor_with_pad(input_tokens,
                                             max_len=max_prompt_len,
@@ -845,6 +863,7 @@ def _prepare_prompt(
             multi_modal_input=multi_modal_input,
             slot_mapping=slot_mapping,
             lora_mask=lora_mask,
+            lora_logits_mask=lora_logits_mask,
         )

     def _prepare_decode(
@@ -863,6 +882,7 @@ def _prepare_decode(
         if len(seq_group_metadata_list) == 0:
             return PrepareDecodeMetadata.empty()
         lora_mask: torch.Tensor = None
+        lora_logits_mask: torch.Tensor = None
         counter = 0

         if self.lora_config:
@@ -917,6 +937,7 @@ def _prepare_decode(
         if lora_mask is not None:
             lora_mask = lora_mask.to('hpu')
+            lora_logits_mask = lora_mask

         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)
@@ -963,6 +984,7 @@ def _prepare_decode(
             lora_requests=lora_requests,
             slot_mapping=slot_mapping,
             lora_mask=lora_mask,
+            lora_logits_mask=lora_logits_mask,
         )

     def prepare_input_tensors(
@@ -1018,6 +1040,7 @@ def prepare_input_tensors(
                 multi_modal_input,
                 slot_mapping,
                 lora_mask,
+                lora_logits_mask,
             ) = self._prepare_prompt(prefill_reqs)
             (
                 decode_input_tokens,
@@ -1028,6 +1051,7 @@ def prepare_input_tensors(
                 decode_lora_requests,
                 decode_slot_mapping,
                 decode_lora_mask,
+                decode_lora_logits_mask,
             ) = self._prepare_decode(decode_reqs)
             sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                          seq_lens, query_lens,
@@ -1055,6 +1079,7 @@ def prepare_input_tensors(
                 lora_prompt_mapping = decode_lora_prompt_mapping
                 lora_requests = decode_lora_requests
                 lora_mask = decode_lora_mask
+                lora_logits_mask = decode_lora_logits_mask

             # FIXME: We need to adjust selected_token_indices to accommodate
             # for padding
@@ -1124,7 +1149,9 @@ def prepare_input_tensors(
             multi_modal_kwargs=multi_modal_input,
             real_batch_size=real_batch_size,
             batch_size_padded=batch_size_padded,
-            lora_mask=lora_mask), sampling_metadata
+            lora_mask=lora_mask,
+            lora_logits_mask=lora_logits_mask), \
+            sampling_metadata

     def _seq_len(self, attn_metadata):
         if attn_metadata.num_prefills != 0:
@@ -1198,6 +1225,7 @@ def profile_run(self) -> None:
                              True,
                              kv_caches,
                              is_profile_run=True)
+        return

     def warmup_scenario(self,
                         batch_size,
@@ -1694,7 +1722,9 @@ def execute_model(
                         module.indices_len[
                             i] = sampling_metadata.selected_token_indices.numel(
                             )
-                LoraMask.setLoraMask(None)
+                lora_logits_mask: torch.Tensor = model_input.lora_logits_mask
+                LoraMask.setLoraMask(lora_logits_mask.index_select(
+                    0, sampling_metadata.selected_token_indices))

             # Compute the logits.
             with self.profiler.record_event(
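
The self-contained sketch below walks through the shape bookkeeping behind the mask-based dispatch_bgmv_linear path introduced above. It is an illustration only, not part of the patch; the tensor sizes, the zeroed last adapter slot, and the example mask assignments are assumptions chosen for demonstration.

# Illustrative sketch (not part of the patch): shape walk-through of the
# mask-based dispatch_bgmv_linear path in vllm/hpu/ops.py above.
# All sizes below are example assumptions, not values from the source.
import torch

batch_size, hidden_dim, lora_rank = 4, 16, 8
max_loras = 2                # plus one zero-valued slot for no-LoRA requests
num_slots = max_loras + 1

# Stacked adapters for a single layer:
#   wa_t_all: (num_slots, 1, lora_rank, hidden_dim)
#   wb_t_all: (num_slots, 1, hidden_dim, lora_rank)
# The last slot is zero valued so tokens routed to it get a zero LoRA delta.
wa_t_all = torch.randn(num_slots, 1, lora_rank, hidden_dim)
wb_t_all = torch.randn(num_slots, 1, hidden_dim, lora_rank)
wa_t_all[-1].zero_()
wb_t_all[-1].zero_()
x = torch.randn(batch_size, hidden_dim)

# Flatten exactly as the new kernel does:
#   wa: (hidden_dim, num_slots * lora_rank)
#   wb: (num_slots * lora_rank, hidden_dim)
wa = wa_t_all[:, 0, :, :].reshape(num_slots * lora_rank,
                                  hidden_dim).transpose(0, 1)
wb = wb_t_all[:, 0, :, :].transpose(1, 2).reshape(num_slots * lora_rank,
                                                  hidden_dim)

# The mask has shape (batch_size, num_slots * lora_rank). A row of ones over
# the rank-sized block of adapter i routes that token through adapter i; an
# all-zero row leaves the intermediate activations zeroed out.
mask = torch.zeros(batch_size, num_slots * lora_rank)
mask[0, 0:lora_rank] = 1.0                # token 0 -> adapter 0
mask[1, lora_rank:2 * lora_rank] = 1.0    # token 1 -> adapter 1

out = ((x @ wa) * mask) @ wb              # (batch_size, hidden_dim)
assert out.shape == (batch_size, hidden_dim)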