From c3577af3b52bd93b69dcc224f77179133bcdfc49 Mon Sep 17 00:00:00 2001
From: Vivek Goel
Date: Fri, 27 Sep 2024 12:28:36 +0530
Subject: [PATCH] Fix runtime errors reported when using long input sequence
 lengths with LoRA (#339)

This PR has the following fixes:
- Increase the size of the indices tensors used to maintain multi-lora state
  information from max_num_batched_tokens to 3*max_num_batched_tokens. The
  increase provides a buffer for the padding applied in the batch and
  sequence dimensions.
- Move the logic that removes padding from lora_logits out of execute_model()
  and back into the LogitsProcessorWithLoRA class; this fixes a race condition
  caused by updating the multi-lora state information directly.

FIX https://github.com/HabanaAI/vllm-fork/issues/237
---
 vllm/lora/layers.py                |  2 ++
 vllm/lora/models.py                |  2 +-
 vllm/worker/habana_model_runner.py | 20 ++++++--------------
 3 files changed, 9 insertions(+), 15 deletions(-)

diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index b3758ad883d5..06160367054e 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -1203,6 +1203,8 @@ def _get_logits(
         ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                       posinf=float("inf"),
                                                       neginf=float("-inf")))
+        if current_platform.is_hpu():
+            lora_logits = lora_logits[:logits.shape[0], :]
         logits[:,
                self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
                lora_logits.shape[1], ] = lora_logits
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 546a4c402aed..582170a2df62 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -432,7 +432,7 @@ def __init__(
         self.long_lora_context: Optional[LongContextLoRAContext] = None
         if current_platform.is_hpu():
             self.punica_wrapper = GaudiPunicaWrapper(
-                max_num_batched_tokens,
+                3 * max_num_batched_tokens,
                 max_batches=self.max_num_seqs,
                 device="hpu")
         else:
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index d3d297368884..bfbe4085ddd3 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -1203,9 +1203,9 @@ def prepare_input_tensors(

         if self.lora_config:
             lora_mapping = LoRAMapping(
-                lora_index_mapping,
-                lora_prompt_mapping,
-            )
+                **dict(index_mapping=lora_index_mapping,
+                       prompt_mapping=lora_prompt_mapping,
+                       is_prefill=(num_prefills > 0)))
         else:
             lora_mapping = None

@@ -1370,9 +1370,9 @@ def warmup_scenario(self,
         times = 3 if use_graphs or is_pt_profiler_run else 1
         if self.lora_config and not is_lora_profile_run:
             lora_mapping = LoRAMapping(
-                [0] * batch_size * seq_len,
-                [0] * batch_size * seq_len,
-            )
+                **dict(index_mapping=[0] * batch_size * seq_len,
+                       prompt_mapping=[0] * batch_size * seq_len,
+                       is_prefill=is_prompt))
             self.set_active_loras(set(), lora_mapping)
         if is_prompt:
             seqs = [
@@ -1915,14 +1915,6 @@ def execute_model(
         )

         if self.lora_config:
-            from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
-            modules = unwrap_model(self.model.model)
-            for module in modules:
-                if isinstance(module, VocabParallelEmbeddingWithLoRA):
-                    for i in range(0, len(module.punica_wrapper.indices_len)):
-                        module.punica_wrapper.indices_len[
-                            i] = sampling_metadata.selected_token_indices.numel(
-                            )
             lora_logits_mask: torch.Tensor = model_input.lora_logits_mask
             LoraMask.setLoraMask(
                 lora_logits_mask.index_select(