prune lora files
kzawora-intel committed Sep 23, 2024
1 parent 28df6fd commit c6d2d5a
Showing 3 changed files with 17 additions and 148 deletions.
vllm/lora/layers.py: 25 changes (8 additions, 17 deletions)
@@ -28,10 +28,6 @@
     LinearScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.platforms import current_platform
-
-if current_platform.is_hpu():
-    from vllm_hpu_extension.punica_hpu import GaudiPunicaWrapper
 
 if TYPE_CHECKING:
     pass
@@ -228,7 +224,6 @@ def set_lora(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = x > self.base_layer.org_vocab_size - 1
-        embeddings_indices = None
         embeddings_indices = self.punica_wrapper.embeddings_indices
         indices = embeddings_indices[1].view_as(x)
         full_lora_a_embeddings = F.embedding(
@@ -246,19 +241,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if full_lora_a_embeddings.ndim == 3:
             full_lora_a_embeddings = full_lora_a_embeddings.view(
                 full_lora_a_embeddings.shape[0] *
-                full_lora_a_embeddings.shape[1], -1)
+                full_lora_a_embeddings.shape[1],
+                -1,
+            )
+
         # Embedding layer only need expand op
-        if current_platform.is_hpu():
-            assert isinstance(self.punica_wrapper, GaudiPunicaWrapper)
-            self.punica_wrapper.add_lora_embedding(full_output,
-                                                   full_lora_a_embeddings,
-                                                   self.lora_b_stacked,
-                                                   add_input=True)
-        else:
-            self.punica_wrapper.add_expand(full_output,
-                                           full_lora_a_embeddings,
-                                           self.lora_b_stacked,
-                                           add_input=True)
+        self.punica_wrapper.add_expand(full_output,
+                                       full_lora_a_embeddings,
+                                       self.lora_b_stacked,
+                                       add_input=True)
         return full_output.view_as(full_output_org)
 
     @classmethod
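For orientation only, a minimal sketch (not part of the commit) of the net effect of the layers.py change: the HPU-only GaudiPunicaWrapper.add_lora_embedding branch is gone, and the LoRA embedding forward now always routes through the generic PunicaWrapper.add_expand call shown in the diff. The helper name and the duck-typed punica_wrapper argument below are illustrative assumptions, not vLLM API.

import torch


def apply_embedding_lora(punica_wrapper, full_output: torch.Tensor,
                         full_lora_a_embeddings: torch.Tensor,
                         lora_b_stacked: torch.Tensor) -> torch.Tensor:
    # Before this commit an is_hpu() check dispatched to
    # GaudiPunicaWrapper.add_lora_embedding; after it, the expand op is
    # unconditional on every platform.
    punica_wrapper.add_expand(full_output,
                              full_lora_a_embeddings,
                              lora_b_stacked,
                              add_input=True)
    return full_output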
vllm/lora/models.py: 128 changes (4 additions, 124 deletions)
@@ -4,7 +4,7 @@
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type
 
 import safetensors.torch
 import torch
@@ -26,12 +26,8 @@
     parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.model_executor.models.utils import PPMissingLayer
-from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 
-if current_platform.is_hpu():
-    from vllm_hpu_extension.punica_hpu import GaudiPunicaWrapper
-
 logger = init_logger(__name__)
 
 _GLOBAL_LORA_ID = 0
@@ -49,116 +45,6 @@ class LongContextLoRAContext:
     offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
 
 
-def convert_mapping(
-    mapping: LoRAMapping,
-    lora_index_to_id: List[Optional[int]],
-    max_loras: int,
-    vocab_size: int,
-    extra_vocab_size: int,
-    long_lora_context: Optional[LongContextLoRAContext] = None,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-           Optional[torch.Tensor], List[int]]:
-    """Converts LoRAMapping to index tensors.
-
-    Args:
-        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
-        lora_index_to_id: List mapping LoRA ids to LoRA indices.
-        max_loras: Maximum number of LoRAs.
-        vocab_size: Model vocab size.
-        extra_vocab_size: Extra vocab size each LoRA can have.
-        long_lora_context: Passed if there are long context lora in a batch.
-
-    Returns:
-        A tuple of tensors:
-            base_indices: Tensor of shape [batch_size] mapping batch rows to
-                LoRA indices.
-            sampler_indices: Tensor of shape [batch_size] mapping requests to
-                LoRA indices for sampler. For generation, this will be the
-                same as base_indicies. For prefill, this will map requests
-                to LoRA indices.
-            sampler_indices_padded: Tensor of shape [batch_size] mapping
-                requests to LoRA indices for sampler with padding.
-                Same as sampler_indicies, but -1 is replaced with
-                max_loras.
-            embeddings_indices: Tensor of shape [2, batch_size] mapping
-                requests to embedding indices. First row is for embeddings
-                added by the LoRAs, second row is for the LoRA.lora_a
-                embeddings.
-            long_lora_indices: Tensor of shape [batch_size] mapping
-                requests to RoPE offsets and rot dims for long LoRAs.
-                None if long context lora doesn't exist.
-            indices_len: List of lengths of the above tensors.
-                Used to index into each tensor. It contains length for
-                (base_indices, sampler_indices, sampler_indices_padded,
-                embeddings_indices, long_lora_indices). If long_lora doesn't
-                exist, it only contains first 4 entries.
-    """
-    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
-    embedding_indices = index_mapping_indices.copy()
-    lora_indices = index_mapping_indices.copy()
-    long_lora_offsets: Optional[torch.Tensor] = None
-    device = "hpu" if current_platform.is_hpu() else "cuda"
-    if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
-    prompt_mapping: List[int] = [
-        lora_index_to_id.index(x) if x > 0 else -1
-        for x in mapping.prompt_mapping
-    ]
-    lora_idx = None
-    for i in range(len(index_mapping_indices)):
-        # TODO index can be slow. optimize
-        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
-                    if index_mapping_indices[i] > 0 else -1)
-        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
-        lora_indices[i] = lora_idx
-        if long_lora_context:
-            assert long_lora_offsets is not None
-            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
-                index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
-
-    indices_list: List[Union[List[int], torch.Tensor]] = [
-        index_mapping_indices, lora_indices, embedding_indices
-    ]
-    if long_lora_context:
-        assert long_lora_offsets is not None
-        indices_list.append(long_lora_offsets)
-    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
-    prompt_mapping_tensor = torch.tensor(prompt_mapping,
-                                         device=device,
-                                         dtype=torch.long)
-    embeddings_indices = torch.stack([
-        indices[2] * extra_vocab_size,
-        indices[2] * (vocab_size + extra_vocab_size)
-    ])
-    embeddings_indices[embeddings_indices == -1] = max_loras - 1
-    base_indices = indices[1]
-    sampler_indices = prompt_mapping_tensor
-    sampler_indices_padded = sampler_indices.clone()
-    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
-    sampler_indices_padded = (
-        torch.arange(
-            0, len(sampler_indices_padded), device=device, dtype=torch.long) +
-        (sampler_indices_padded * len(sampler_indices_padded)))
-    long_lora_indices = None
-    long_lora_indices_len: Optional[int] = None
-    if long_lora_context:
-        long_lora_indices = indices[3]
-        long_lora_indices_len = long_lora_indices.shape[-1]
-    # Contain length of indices tensors. Used to index into each tensor.
-    indices_len = [
-        base_indices.shape[-1], sampler_indices.shape[-1],
-        sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
-    ]
-    if long_lora_indices_len is not None:
-        indices_len.append(long_lora_indices_len)
-
-    return (base_indices, sampler_indices, sampler_indices_padded,
-            embeddings_indices, long_lora_indices, indices_len)
-
-
 def get_lora_id():
     global _GLOBAL_LORA_ID
     _GLOBAL_LORA_ID += 1
@@ -430,15 +316,9 @@ def __init__(
         self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
         self.vocab_size = vocab_size
         self.long_lora_context: Optional[LongContextLoRAContext] = None
-        if current_platform.is_hpu():
-            self.punica_wrapper = GaudiPunicaWrapper(
-                max_num_batched_tokens,
-                max_batches=self.max_num_seqs,
-                device="hpu")
-        else:
-            self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
-                                                max_batches=self.max_num_seqs,
-                                                device="cuda")
+        self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
+                                            max_batches=self.max_num_seqs,
+                                            device="cuda")
         # Scaling factor -> offset to the sin_cos_cache to it.
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
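The convert_mapping helper deleted above duplicated index bookkeeping that, after this commit, lives only in vllm/lora/punica.py (see the next file). For readers unfamiliar with those index tensors, the following standalone toy reproduces the core of that bookkeeping for a small batch. It is not part of the commit: the concrete values, the CPU device, and the simplified variable names are invented for the demo, while the arithmetic mirrors the removed code.

import torch
from typing import List, Optional

# Toy stand-ins for LoRAMapping.index_mapping / .prompt_mapping and the
# manager state; all concrete values here are made up for illustration.
index_mapping: List[int] = [1, 1, 2, 2, 0]   # per-token LoRA ids (0 = no LoRA)
prompt_mapping: List[int] = [1, 2, 0]        # per-request LoRA ids
lora_index_to_id: List[Optional[int]] = [1, 2, None, None]
max_loras = 4
vocab_size, extra_vocab_size = 32000, 256
device = "cpu"  # the removed code picked "hpu" or "cuda"; CPU keeps the demo portable

# base_indices: which LoRA slot each token row uses (-1 = no LoRA).
base_indices = torch.tensor(
    [lora_index_to_id.index(x) if x > 0 else -1 for x in index_mapping],
    dtype=torch.long, device=device)

# sampler_indices / sampler_indices_padded: per-request slots for the sampler;
# -1 is padded to max_loras - 1 and then flattened with an arange offset.
sampler_indices = torch.tensor(
    [lora_index_to_id.index(x) if x > 0 else -1 for x in prompt_mapping],
    dtype=torch.long, device=device)
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
    torch.arange(len(sampler_indices_padded), device=device) +
    sampler_indices_padded * len(sampler_indices_padded))

# embeddings_indices: row 0 offsets tokens into the LoRA-added vocab,
# row 1 offsets them into the per-LoRA lora_a embedding table.
embedding_slots = torch.tensor(
    [lora_index_to_id.index(x) if x > 0 else 0 for x in index_mapping],
    dtype=torch.long, device=device)
embeddings_indices = torch.stack([
    embedding_slots * extra_vocab_size,
    embedding_slots * (vocab_size + extra_vocab_size),
])

print(base_indices)            # tensor([ 0,  0,  1,  1, -1])
print(sampler_indices_padded)  # tensor([ 0,  4, 11])
print(embeddings_indices[0])   # tensor([  0,   0, 256, 256,   0])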
vllm/lora/punica.py: 12 changes (5 additions, 7 deletions)
@@ -10,7 +10,6 @@
 import torch
 
 from vllm.triton_utils import HAS_TRITON
-from vllm.utils import get_device
 
 if HAS_TRITON:
     from vllm.lora.ops.bgmv_expand import bgmv_expand
@@ -105,7 +104,7 @@ def convert_mapping(
     long_lora_offsets: Optional[torch.Tensor] = None
     if long_lora_context:
         long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=get_device(),
+                                        device="cuda",
                                         dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
@@ -132,9 +131,9 @@
     if long_lora_context:
         assert long_lora_offsets is not None
         indices_list.append(long_lora_offsets)
-    indices = torch.tensor(indices_list, dtype=torch.long, device=get_device())
+    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
     prompt_mapping_tensor = torch.tensor(prompt_mapping,
-                                         device=get_device(),
+                                         device="cuda",
                                          dtype=torch.long)
     embeddings_indices = torch.stack([
         indices[2] * extra_vocab_size,
@@ -146,9 +145,8 @@
     sampler_indices_padded = sampler_indices.clone()
     sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
     sampler_indices_padded = torch.arange(
-        0, len(sampler_indices_padded), device=get_device(),
-        dtype=torch.long) + (sampler_indices_padded *
-                             len(sampler_indices_padded))
+        0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
+            sampler_indices_padded * len(sampler_indices_padded))
     long_lora_indices = None
     long_lora_indices_len: Optional[int] = None
     if long_lora_context:
