From da03d8b8fa14fbc1cb276d19849a6c40b86a8b0e Mon Sep 17 00:00:00 2001
From: Himangshu Lahkar <49579433+hlahkar@users.noreply.github.com>
Date: Thu, 3 Oct 2024 16:09:10 +0530
Subject: [PATCH] Lora Mask based on lora index (#348)

Changes the filling of the LoRA mask from lora_id to lora_index. This is
needed to ensure that the mask does not fail when a lora_id is greater
than max_loras.
---
 vllm/worker/habana_model_runner.py | 211 +++++++++++++++--------------
 1 file changed, 113 insertions(+), 98 deletions(-)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 79133aaf8f0f2..2d72be5690664 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -350,8 +350,7 @@ class PreparePromptMetadata(NamedTuple):
     lora_requests: Set[LoRARequest]
     multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]]
     slot_mapping: List[List[int]]
-    lora_mask: Optional[torch.Tensor]
-    lora_logits_mask: Optional[torch.Tensor]
+    lora_ids: List[int]
 
     @classmethod
     def empty(cls):
@@ -365,8 +364,7 @@ def empty(cls):
             lora_requests=set(),
             multi_modal_kwargs=None,
             slot_mapping=[],
-            lora_mask=None,
-            lora_logits_mask=None)
+            lora_ids=[])
 
 
 class PrepareDecodeMetadata(NamedTuple):
@@ -377,8 +375,7 @@ class PrepareDecodeMetadata(NamedTuple):
     lora_prompt_mapping: List[List[int]]
     lora_requests: Set[LoRARequest]
     slot_mapping: List[List[int]]
-    lora_mask: Optional[torch.Tensor]
-    lora_logits_mask: Optional[torch.Tensor]
+    lora_ids: List[int]
 
     @classmethod
     def empty(cls):
@@ -389,8 +386,7 @@ def empty(cls):
             lora_prompt_mapping=[],
             lora_requests=set(),
             slot_mapping=[],
-            lora_mask=None,
-            lora_logits_mask=None)
+            lora_ids=[])
 
 
 # How batches are constructed.
@@ -425,8 +421,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
     real_batch_size: Optional[int] = None
     batch_size_padded: Optional[int] = None
     virtual_engine: int = 0
-    lora_mask: Optional[torch.Tensor] = None
-    lora_logits_mask: Optional[torch.Tensor] = None
+    lora_ids: Optional[List[int]] = None
     async_callback: Optional[Callable] = None
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
@@ -439,8 +434,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
             "real_batch_size": self.real_batch_size,
             "batch_size_padded": self.batch_size_padded,
             "virtual_engine": self.virtual_engine,
-            "lora_mask": self.lora_mask,
-            "lora_logits_mask": self.lora_logits_mask,
+            "lora_ids": self.lora_ids,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         return tensor_dict
@@ -474,8 +468,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
             "lora_requests": self.lora_requests,
             "lora_mapping": self.lora_mapping,
             "multi_modal_kwargs": self.multi_modal_kwargs,
-            "lora_mask": self.lora_mask,
-            "lora_logits_mask": self.lora_logits_mask,
+            "lora_ids": self.lora_ids,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         _add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -836,38 +829,14 @@ def _prepare_prompt(
             find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
             self.block_size)
 
-        lora_mask: torch.Tensor = None
-        lora_logits_mask: torch.Tensor = None
-        counter = 0
-        if self.lora_config:
-            lora_mask = torch.zeros(
-                len(seq_group_metadata_list) * max_prompt_len,
-                (self.lora_config.max_loras) * self.lora_config.max_lora_rank,
-                dtype=self.lora_config.lora_dtype)
-            lora_logits_mask = torch.zeros(len(seq_group_metadata_list),
-                                           (self.lora_config.max_loras) *
-                                           self.lora_config.max_lora_rank,
-                                           dtype=self.lora_config.lora_dtype)
-
-            ones = torch.ones(max_prompt_len,
-                              self.lora_config.max_lora_rank,
-                              dtype=self.lora_config.lora_dtype)
-            logit_ones = torch.ones(1,
-                                    self.lora_config.max_lora_rank,
-                                    dtype=self.lora_config.lora_dtype)
+        lora_ids: List[int] = []
 
         for seq_group_metadata, context_len in zip(seq_group_metadata_list,
                                                    context_lens):
             lora_id = seq_group_metadata.lora_int_id
+            lora_ids.append(lora_id)
             if lora_id > 0:
                 lora_requests.add(seq_group_metadata.lora_request)
-                start_row = counter * max_prompt_len
-                end_row = start_row + max_prompt_len
-                start_col = (lora_id - 1) * self.lora_config.max_lora_rank
-                end_col = start_col + self.lora_config.max_lora_rank
-                lora_mask[start_row:end_row, start_col:end_col] = ones
-                lora_logits_mask[counter, start_col:end_col] = logit_ones
-            counter = counter + 1
 
             lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
             lora_prompt_mapping.extend(
@@ -875,10 +844,6 @@ def _prepare_prompt(
                 (max_prompt_len - context_len
                  if seq_group_metadata.sampling_params.prompt_logprobs else 1))
 
-        if lora_mask is not None:
-            lora_mask = lora_mask.to('hpu')
-            lora_logits_mask = lora_logits_mask.to('hpu')
-
         input_tokens = make_tensor_with_pad(input_tokens,
                                             max_len=max_prompt_len,
                                             pad=0,
@@ -919,20 +884,17 @@ def _prepare_prompt(
         )
         multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
 
-        return PreparePromptMetadata(
-            input_tokens=input_tokens,
-            input_positions=input_positions,
-            attn_metadata=attn_metadata,
-            seq_lens=seq_lens,
-            query_lens=query_lens,
-            lora_index_mapping=lora_index_mapping,
-            lora_prompt_mapping=lora_prompt_mapping,
-            lora_requests=lora_requests,
-            multi_modal_kwargs=multi_modal_kwargs,
-            slot_mapping=slot_mapping,
-            lora_mask=lora_mask,
-            lora_logits_mask=lora_logits_mask,
-        )
+        return PreparePromptMetadata(input_tokens=input_tokens,
+                                     input_positions=input_positions,
+                                     attn_metadata=attn_metadata,
+                                     seq_lens=seq_lens,
+                                     query_lens=query_lens,
+                                     lora_index_mapping=lora_index_mapping,
+                                     lora_prompt_mapping=lora_prompt_mapping,
+                                     lora_requests=lora_requests,
+                                     multi_modal_kwargs=multi_modal_kwargs,
+                                     slot_mapping=slot_mapping,
+                                     lora_ids=lora_ids)
 
     def _prepare_decode(
         self,
@@ -949,18 +911,7 @@ def _prepare_decode(
         if len(seq_group_metadata_list) == 0:
             return PrepareDecodeMetadata.empty()
 
-        lora_mask: torch.Tensor = None
-        lora_logits_mask: torch.Tensor = None
-        counter = 0
-
-        if self.lora_config:
-            lora_mask = torch.zeros(len(seq_group_metadata_list),
-                                    (self.lora_config.max_loras) *
-                                    self.lora_config.max_lora_rank,
-                                    dtype=self.lora_config.lora_dtype)
-            ones = torch.ones(1,
-                              self.lora_config.max_lora_rank,
-                              dtype=self.lora_config.lora_dtype)
+        lora_ids: List[int] = []
 
         dummy_slots = itertools.cycle(
             range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
@@ -971,13 +922,10 @@ def _prepare_decode(
         for seq_group_metadata in seq_group_metadata_list:
             seq_ids = list(seq_group_metadata.seq_data.keys())
             lora_id = seq_group_metadata.lora_int_id
+            lora_ids.append(lora_id)
 
             if lora_id > 0:
                 lora_requests.add(seq_group_metadata.lora_request)
-                start_pos = (lora_id - 1) * self.lora_config.max_lora_rank
-                end_pos = start_pos + self.lora_config.max_lora_rank
-                lora_mask[counter, start_pos:end_pos] = ones
-            counter = counter + 1
 
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
@@ -1012,9 +960,6 @@ def _prepare_decode(
                 block_table = block_table[-sliding_window_blocks:]
             block_tables.append(block_table)
 
-        if lora_mask is not None:
-            lora_mask = lora_mask.to('hpu')
-            lora_logits_mask = lora_mask
         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.long,
                                     device=self.device)
@@ -1075,17 +1020,14 @@ def _prepare_decode(
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
         )
-        return PrepareDecodeMetadata(
-            input_tokens=input_tokens,
-            input_positions=input_positions,
-            attn_metadata=attn_metadata,
-            lora_index_mapping=lora_index_mapping,
-            lora_prompt_mapping=lora_prompt_mapping,
-            lora_requests=lora_requests,
-            slot_mapping=slot_mapping,
-            lora_mask=lora_mask,
-            lora_logits_mask=lora_logits_mask,
-        )
+        return PrepareDecodeMetadata(input_tokens=input_tokens,
+                                     input_positions=input_positions,
+                                     attn_metadata=attn_metadata,
+                                     lora_index_mapping=lora_index_mapping,
+                                     lora_prompt_mapping=lora_prompt_mapping,
+                                     lora_requests=lora_requests,
+                                     slot_mapping=slot_mapping,
+                                     lora_ids=lora_ids)
 
     def prepare_input_tensors(
         self,
@@ -1142,8 +1084,7 @@ def prepare_input_tensors(
                 lora_requests,
                 multi_modal_kwargs,
                 slot_mapping,
-                lora_mask,
-                lora_logits_mask,
+                lora_ids,
             ) = self._prepare_prompt(prefill_reqs)
             (
                 decode_input_tokens,
@@ -1153,8 +1094,7 @@ def prepare_input_tensors(
                 decode_lora_prompt_mapping,
                 decode_lora_requests,
                 decode_slot_mapping,
-                decode_lora_mask,
-                decode_lora_logits_mask,
+                decode_lora_ids,
             ) = self._prepare_decode(decode_reqs)
             sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                          seq_lens, query_lens,
@@ -1181,8 +1121,7 @@ def prepare_input_tensors(
             lora_index_mapping = decode_lora_index_mapping
             lora_prompt_mapping = decode_lora_prompt_mapping
             lora_requests = decode_lora_requests
-            lora_mask = decode_lora_mask
-            lora_logits_mask = decode_lora_logits_mask
+            lora_ids = decode_lora_ids
 
         # FIXME: We need to adjust selected_token_indices to accommodate
        # for padding
@@ -1252,8 +1191,7 @@ def prepare_input_tensors(
                 multi_modal_kwargs=multi_modal_kwargs,
                 real_batch_size=real_batch_size,
                 batch_size_padded=batch_size_padded,
-                lora_mask=lora_mask,
-                lora_logits_mask=lora_logits_mask), \
+                lora_ids=lora_ids), \
                 sampling_metadata
 
     def _seq_len(self, attn_metadata):
@@ -1853,6 +1791,76 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
             logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
                            phase, batch_size, seq_len)
 
+    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
+                         is_prompt: bool):
+        '''
+        This is a helper function to create the mask for LoRA computations.
+        The LoRA mask is needed to ensure we match the correct LoRA weights
+        for the request.
+        For the Prompt phase we have
+        lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
+        lora_logits_mask with shape (batch_size, max_loras * max_rank)
+        For the Decode phase we have both
+        lora_mask and lora_logits_mask with shape
+        (batch_size, max_loras * max_rank)
+        '''
+        lora_mask: Optional[torch.Tensor] = None
+        lora_logits_mask: Optional[torch.Tensor] = None
+        lora_index = 0
+
+        if self.lora_config:
+            if is_prompt:
+                lora_mask = torch.zeros(
+                    input_tokens.shape[0] * input_tokens.shape[1],
+                    (self.lora_config.max_loras) *\
+                    self.lora_config.max_lora_rank,
+                    dtype=self.lora_config.lora_dtype)
+                lora_logits_mask = torch.zeros(
+                    input_tokens.shape[0], (self.lora_config.max_loras) *
+                    self.lora_config.max_lora_rank,
+                    dtype=self.lora_config.lora_dtype)
+
+                ones = torch.ones(input_tokens.shape[1],
+                                  self.lora_config.max_lora_rank,
+                                  dtype=self.lora_config.lora_dtype)
+                logit_ones = torch.ones(1,
+                                        self.lora_config.max_lora_rank,
+                                        dtype=self.lora_config.lora_dtype)
+
+                for i in range(len(lora_ids)):
+                    if lora_ids[i] == 0:
+                        continue
+                    lora_index = self.lora_manager._adapter_manager.\
+                        lora_index_to_id.index(lora_ids[i])
+                    start_row = i * input_tokens.shape[1]
+                    end_row = start_row + input_tokens.shape[1]
+                    start_col = lora_index * self.lora_config.max_lora_rank
+                    end_col = start_col + self.lora_config.max_lora_rank
+                    lora_mask[start_row:end_row, start_col:end_col] = ones
+                    lora_logits_mask[i, start_col:end_col] = logit_ones
+                lora_mask = lora_mask.to('hpu')
+                lora_logits_mask = lora_logits_mask.to('hpu')
+            else:
+                lora_mask = torch.zeros(input_tokens.shape[0],
+                                        (self.lora_config.max_loras) *
+                                        self.lora_config.max_lora_rank,
+                                        dtype=self.lora_config.lora_dtype)
+                ones = torch.ones(1,
+                                  self.lora_config.max_lora_rank,
+                                  dtype=self.lora_config.lora_dtype)
+                for i in range(len(lora_ids)):
+                    if lora_ids[i] == 0:
+                        continue
+                    lora_index = self.lora_manager._adapter_manager.\
+                        lora_index_to_id.index(lora_ids[i])
+                    start_pos = lora_index * self.lora_config.max_lora_rank
+                    end_pos = start_pos + self.lora_config.max_lora_rank
+                    lora_mask[i, start_pos:end_pos] = ones
+                lora_mask = lora_mask.to('hpu')
+                lora_logits_mask = lora_mask
+
+        return lora_mask, lora_logits_mask
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -1887,13 +1895,21 @@ def execute_model(
         seq_len = self._seq_len(attn_metadata)
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
         self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
+
+        lora_mask: Optional[torch.Tensor] = None
+        lora_logits_mask: Optional[torch.Tensor] = None
+        if self.lora_config:
+            assert model_input.lora_ids is not None
+            lora_mask, lora_logits_mask = self.create_lora_mask(
+                input_tokens, model_input.lora_ids, attn_metadata.is_prompt)
+
         execute_model_kwargs = {
             "input_ids": input_tokens,
             "positions": input_positions,
             "kv_caches": kv_caches,
             "attn_metadata": self.trim_attn_metadata(attn_metadata),
             "intermediate_tensors": intermediate_tensors,
-            "lora_mask": model_input.lora_mask,
+            "lora_mask": lora_mask,
             **(model_input.multi_modal_kwargs or {}),
         }
         if htorch.utils.internal.is_lazy():
@@ -1915,7 +1931,6 @@ def execute_model(
         )
 
         if self.lora_config:
-            lora_logits_mask: torch.Tensor = model_input.lora_logits_mask
             LoraMask.setLoraMask(
                 lora_logits_mask.index_select(
                     0, sampling_metadata.selected_token_indices))
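
The key change in create_lora_mask is that the column offset of each ones block is derived from the adapter's slot index (its position in the adapter manager's lora_index_to_id table) instead of from the raw lora_id, so writes always land inside the max_loras * max_lora_rank columns. Below is a minimal standalone sketch of the decode-phase fill; build_decode_lora_mask and the plain lora_index_to_id list are illustrative stand-ins for the model runner and its adapter manager, and the dtype default is an assumption:

    import torch
    from typing import List, Optional

    def build_decode_lora_mask(lora_ids: List[int],
                               lora_index_to_id: List[Optional[int]],
                               max_loras: int,
                               max_lora_rank: int,
                               dtype: torch.dtype = torch.bfloat16):
        # One row per sequence; max_loras * max_lora_rank columns, i.e. one
        # slot of max_lora_rank columns per loadable adapter.
        mask = torch.zeros(len(lora_ids), max_loras * max_lora_rank, dtype=dtype)
        for i, lora_id in enumerate(lora_ids):
            if lora_id == 0:  # lora_id 0 means "no adapter" for this sequence
                continue
            # Index by slot position, not by the unbounded lora_id itself.
            lora_index = lora_index_to_id.index(lora_id)
            start = lora_index * max_lora_rank
            mask[i, start:start + max_lora_rank] = 1.0
        return mask

    # lora_id 57 exceeds max_loras=2, but its slot index (1) still lands
    # inside the mask; the old (lora_id - 1) * max_lora_rank offset would not.
    mask = build_decode_lora_mask([57, 0], lora_index_to_id=[3, 57],
                                  max_loras=2, max_lora_rank=8)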
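
For comparison, the old fill used start_col = (lora_id - 1) * max_lora_rank, which grows without bound as adapters are swapped in and out while the mask only ever has max_loras * max_lora_rank columns; for a lora_id above max_loras the target slice falls outside the mask, so assigning the ones block into it fails, which appears to be the failure the commit message refers to. Building the masks inside execute_model from the broadcastable lora_ids list also means the prepare-phase metadata no longer has to carry device tensors.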