
Commit

Fix runtime errors reported when using long input sequence lengths with LoRA (#343)

This PR has the following fixes (a simplified sketch of both appears after the description):

- Increase the size of the indices tensors used to maintain multi-LoRA state
information from max_num_batched_tokens to 3 * max_num_batched_tokens. The
larger size provides a buffer for the padding applied in the batch and
sequence dimensions.

- Move the logic that removes padding from lora_logits out of execute_model()
and back into the LogitsProcessorWithLoRA class. This fixes a race condition
caused by updating the multi-LoRA state information directly.

FIX #237
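
The snippet below is a rough, standalone sketch of why both changes are needed. The variable names (indices_buffer, logits, lora_logits) and sizes are illustrative only and do not correspond to vLLM's actual classes or call sites.

import torch

# Hypothetical sketch, not vLLM code: illustrates the two fixes described above.

max_num_batched_tokens = 16

# Fix 1: on HPU the batch and sequence dimensions are padded, so the flattened
# token count can exceed max_num_batched_tokens. Sizing the multi-LoRA index
# buffers at 3 * max_num_batched_tokens leaves headroom for that padding.
indices_buffer = torch.zeros(3 * max_num_batched_tokens, dtype=torch.long)

# Fix 2: lora_logits is computed for the padded batch and may therefore have
# more rows than the unpadded logits tensor. Slicing off the padded rows inside
# the logits processor (rather than patching shared multi-LoRA state from
# execute_model()) avoids the race condition on that shared state.
logits = torch.randn(10, 32000)        # unpadded rows
lora_logits = torch.randn(24, 256)     # padded rows (more than logits.shape[0])
lora_logits = lora_logits[:logits.shape[0], :]   # drop the padding rows
assert lora_logits.shape[0] == logits.shape[0]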
vivekgoe committed Sep 27, 2024
1 parent 5953449 commit b70dcba
Showing 3 changed files with 5 additions and 9 deletions.
2 changes: 2 additions & 0 deletions vllm/lora/layers.py
@@ -1250,6 +1250,8 @@ def _get_logits(
nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
+if is_hpu():
+    lora_logits = lora_logits[:logits.shape[0], :]
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
4 changes: 3 additions & 1 deletion vllm/lora/models.py
@@ -24,7 +24,7 @@
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
-from vllm.utils import get_device, is_pin_memory_available
+from vllm.utils import get_device, is_hpu, is_pin_memory_available

logger = init_logger(__name__)

@@ -829,6 +829,8 @@ def create_lora_manager(
"""Create a LoRA adapter for a given model."""
if not hasattr(model, "supported_lora_modules"):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
+if is_hpu():
+    max_num_batched_tokens = 3 * max_num_batched_tokens
lora_manager = lora_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
8 changes: 0 additions & 8 deletions vllm/worker/habana_model_runner.py
@@ -1937,14 +1937,6 @@ def execute_model(
)

if self.lora_config:
-from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
-modules = unwrap_model(self.model.model)
-for module in modules:
-    if isinstance(module, VocabParallelEmbeddingWithLoRA):
-        for i in range(0, len(module.indices_len)):
-            module.indices_len[
-                i] = sampling_metadata.selected_token_indices.numel(
-                )
lora_logits_mask: torch.Tensor = model_input.lora_logits_mask
LoraMask.setLoraMask(
lora_logits_mask.index_select(
