Fix runtime errors reported when using long input sequence lengths with LoRA (#339)

This PR contains the following fixes:
- Increase the size of the indices tensors used to maintain multi-lora state
  information from max_num_batched_tokens to 3*max_num_batched_tokens. The
  extra room acts as a buffer for the padding applied in the batch and
  sequence dimensions (see the sizing sketch below).
- Move the logic that removes padding from lora_logits out of execute_model()
  and back into the LogitsProcessorWithLoRA class; this fixes a race condition
  caused by updating multi-lora state information directly.

FIX #237
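A minimal sizing sketch (not vLLM code; the bucket sizes below are hypothetical, chosen only for illustration) of why index buffers sized exactly to max_num_batched_tokens can overflow once the batch and sequence dimensions are padded:

def padded_token_slots(batch_size: int, seq_len: int,
                       batch_bucket: int = 8, seq_bucket: int = 1024) -> int:
    """Token slots after rounding both dimensions up to their buckets."""
    padded_batch = -(-batch_size // batch_bucket) * batch_bucket  # ceil division
    padded_seq = -(-seq_len // seq_bucket) * seq_bucket
    return padded_batch * padded_seq

max_num_batched_tokens = 4096
real_tokens = 2 * 2048                  # two sequences right at the limit
slots = padded_token_slots(2, 2048)     # 8 * 2048 = 16384 slots after padding

assert real_tokens <= max_num_batched_tokens
assert slots > max_num_batched_tokens   # a buffer sized to the limit overflows

Sizing the index buffers to a multiple of max_num_batched_tokens (3x in this commit) leaves headroom for that padding.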
vivekgoe committed Sep 27, 2024
1 parent 1c6bada commit c3577af
Showing 3 changed files with 9 additions and 15 deletions.
vllm/lora/layers.py: 2 changes (2 additions, 0 deletions)
@@ -1203,6 +1203,8 @@ def _get_logits(
         ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                       posinf=float("inf"),
                                                       neginf=float("-inf")))
+        if current_platform.is_hpu():
+            lora_logits = lora_logits[:logits.shape[0], :]
         logits[:,
                self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
                lora_logits.shape[1], ] = lora_logits
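A minimal sketch of what the added lines do, with assumed shapes (the real tensors come from LogitsProcessorWithLoRA._get_logits): on HPU, lora_logits is computed over padded token positions, so it is trimmed to match the unpadded logits before being written back into the logits tensor, without touching any shared multi-lora state.

import torch

logits = torch.randn(6, 32000)       # 6 real (unpadded) token positions
lora_logits = torch.randn(8, 256)    # padded to 8 positions on HPU

lora_logits = lora_logits[:logits.shape[0], :]   # drop the padded rows
assert lora_logits.shape[0] == logits.shape[0]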
vllm/lora/models.py: 2 changes (1 addition, 1 deletion)
@@ -432,7 +432,7 @@ def __init__(
         self.long_lora_context: Optional[LongContextLoRAContext] = None
         if current_platform.is_hpu():
             self.punica_wrapper = GaudiPunicaWrapper(
-                max_num_batched_tokens,
+                3 * max_num_batched_tokens,
                 max_batches=self.max_num_seqs,
                 device="hpu")
         else:
vllm/worker/habana_model_runner.py: 20 changes (6 additions, 14 deletions)
@@ -1203,9 +1203,9 @@ def prepare_input_tensors(

         if self.lora_config:
             lora_mapping = LoRAMapping(
-                lora_index_mapping,
-                lora_prompt_mapping,
-            )
+                **dict(index_mapping=lora_index_mapping,
+                       prompt_mapping=lora_prompt_mapping,
+                       is_prefill=(num_prefills > 0)))
         else:
             lora_mapping = None

@@ -1370,9 +1370,9 @@ def warmup_scenario(self,
         times = 3 if use_graphs or is_pt_profiler_run else 1
         if self.lora_config and not is_lora_profile_run:
             lora_mapping = LoRAMapping(
-                [0] * batch_size * seq_len,
-                [0] * batch_size * seq_len,
-            )
+                **dict(index_mapping=[0] * batch_size * seq_len,
+                       prompt_mapping=[0] * batch_size * seq_len,
+                       is_prefill=is_prompt))
             self.set_active_loras(set(), lora_mapping)
         if is_prompt:
             seqs = [
@@ -1915,14 +1915,6 @@ def execute_model(
         )

         if self.lora_config:
-            from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
-            modules = unwrap_model(self.model.model)
-            for module in modules:
-                if isinstance(module, VocabParallelEmbeddingWithLoRA):
-                    for i in range(0, len(module.punica_wrapper.indices_len)):
-                        module.punica_wrapper.indices_len[
-                            i] = sampling_metadata.selected_token_indices.numel(
-                            )
             lora_logits_mask: torch.Tensor = model_input.lora_logits_mask
             LoraMask.setLoraMask(
                 lora_logits_mask.index_select(
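A conceptual sketch (hypothetical names, not vLLM code) of the race the deletion above addresses: the removed block rewrote punica_wrapper.indices_len, which is shared multi-lora state, from execute_model(), so another component could read a half-updated value. Deriving the trimmed view locally where it is consumed, as the layers.py change does, leaves shared state untouched.

from dataclasses import dataclass, field
from typing import List

@dataclass
class SharedLoraState:
    # Hypothetical stand-in for the shared indices bookkeeping.
    indices_len: List[int] = field(default_factory=lambda: [0, 0, 0, 0])

def racy_fixup(state: SharedLoraState, num_selected_tokens: int) -> None:
    # Race-prone: rewrites shared state in place while other code may read it.
    for i in range(len(state.indices_len)):
        state.indices_len[i] = num_selected_tokens

def local_trim(lora_logits, num_real_rows: int):
    # Safer: compute the unpadded view locally; nothing shared is mutated.
    return lora_logits[:num_real_rows]

state = SharedLoraState()
racy_fixup(state, num_selected_tokens=6)             # all four entries now read 6
print(local_trim(list(range(8)), num_real_rows=6))   # [0, 1, 2, 3, 4, 5]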
