Move HPU specific LoRA ops to vllm.hpu.ops module
SanjuCSudhakaran committed Aug 20, 2024
1 parent 557a23e commit 78436a6
Showing 3 changed files with 96 additions and 90 deletions.
75 changes: 75 additions & 0 deletions vllm/hpu/ops.py
@@ -191,3 +191,78 @@ def prompt_attention(
valid_seq_lengths, 'right')
attn_weights = attn_weights.transpose(1, 2)
return attn_weights


def dispatch_bgmv_linear(
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indices: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
`wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices
stacked into single tensors, assuming same rank. HPU handles no-LoRA
requests using zero valued A and B tensors. These zero valued tensors are
appended at the end of `wa_t_all` and `wb_t_all` during initialization. For
custom BGMV, the corresponding `wa` and `wb` for each batch is created
based on the lora_index of each sample.
For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The `wa` tensor for a batch of size batch_Size will have
a shape of (batch_size, num_layers, hidden_dim, lora_rank)
This method avoids for-loop as well as graph breaks.
"""
assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
max_loras = wa_t_all.size(0)
# Wrap-around for negative indices
indices = indices % max_loras
wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)
wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)

x = x.unsqueeze(1)
out = x @ wa
out = out @ wb
out = out.squeeze(1)
y += out * scale


def dispatch_bgmv_embedding(
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
indices: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
`wa_t_all` contains all LoRA A weight matrices stacked into a single tensor
assuming same rank. HPU handles no-LoRA requests using zero valued A
tensor. This zero valued tensor is appended at the end of `wa_t_all` during
initialization. For custom BGMV, the corresponding wa for each batch is
created based on the lora_index of the sample.
For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The wa tensor for a batch of size batch_Size will have a
shape of (batch_size, num_layers, lora_rank, hidden_dim)
This method avoids for-loop as well as graph breaks.
"""
assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
max_loras = wa_t_all.size(0)
# Wrap-around for negative indices
indices = indices % max_loras
wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)

x = x.unsqueeze(1)
out = x @ wa
out = out.squeeze(1)
y += out * scale
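
A minimal usage sketch of the ops above, runnable on CPU for illustration
(the real ops target HPU). The concrete sizes and the `out_dim` name are
assumptions made for this example; the zero-valued final adapter slot and the
negative-index wrap-around come straight from the code above.

import torch

from vllm.hpu.ops import dispatch_bgmv_linear

num_loras, num_layers, lora_rank = 4, 1, 8
hidden_dim, out_dim, batch_size = 32, 16, 3  # illustrative sizes

# Stacked LoRA A/B weights; the last slot is the zero-valued no-LoRA adapter.
wa_t_all = torch.randn(num_loras, num_layers, lora_rank, hidden_dim)
wb_t_all = torch.randn(num_loras, num_layers, out_dim, lora_rank)
wa_t_all[-1].zero_()
wb_t_all[-1].zero_()

x = torch.randn(batch_size, hidden_dim)
y = torch.zeros(batch_size, out_dim)
# -1 wraps to num_loras - 1, the zero adapter, so row 2 gets no LoRA delta.
indices = torch.tensor([0, 2, -1])
dispatch_bgmv_linear(y, x, wa_t_all, wb_t_all, indices, 0, 1.0)
assert torch.all(y[-1] == 0)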
89 changes: 9 additions & 80 deletions vllm/lora/layers.py
@@ -29,6 +29,9 @@
VocabParallelEmbedding)
from vllm.utils import is_hpu

if is_hpu():
from vllm.hpu.ops import dispatch_bgmv_embedding, dispatch_bgmv_linear

if TYPE_CHECKING:
pass

@@ -64,81 +67,6 @@ def dec(*args, **kwargs):
return dec


def custom_bgmv(
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indices: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
`wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices
stacked into single tensors, assuming same rank. HPU handles no-LoRA
requests using zero valued A and B tensors. These zero valued tensors are
appended at the end of `wa_t_all` and `wb_t_all` during initialization. For
custom BGMV, the corresponding `wa` and `wb` for each batch is created
based on the lora_index of each sample.
For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The `wa` tensor for a batch of size batch_Size will have
a shape of (batch_size, num_layers, hidden_dim, lora_rank)
This method avoids for-loop as well as graph breaks.
"""
assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
max_loras = wa_t_all.size(0)
# Wrap-around for negative indices
indices = indices % max_loras
wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)
wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)

x = x.unsqueeze(1)
out = x @ wa
out = out @ wb
out = out.squeeze(1)
y += out * scale


def custom_bgmv_embed(
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
indices: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
`wa_t_all` contains all LoRA A weight matrices stacked into a single tensor
assuming same rank. HPU handles no-LoRA requests using zero valued A
tensor. This zero valued tensor is appended at the end of `wa_t_all` during
initialization. For custom BGMV, the corresponding wa for each batch is
created based on the lora_index of the sample.
For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The wa tensor for a batch of size batch_Size will have a
shape of (batch_size, num_layers, lora_rank, hidden_dim)
This method avoids for-loop as well as graph breaks.
"""
assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
max_loras = wa_t_all.size(0)
# Wrap-around for negative indices
indices = indices % max_loras
wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)

x = x.unsqueeze(1)
out = x @ wa
out = out.squeeze(1)
y += out * scale


def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
@@ -166,7 +94,8 @@ def _apply_lora(
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
if is_hpu():
custom_bgmv(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
dispatch_bgmv_linear(output, x, lora_a_stacked, lora_b_stacked,
indices, 0, 1.0)
else:
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
return output.view_as(org_output)
@@ -207,7 +136,7 @@ def _apply_lora_packed_nslice(
offset_left = 0
for slice_idx in range(len(output_slices)):
if is_hpu():
custom_bgmv(
dispatch_bgmv_linear(
output[:, offset_left:offset_left + output_slices[slice_idx]],
x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx],
indices, 0, 1.0)
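
The hunk above ends mid-loop; as a sketch under that assumption, the full
per-slice dispatch advances `offset_left` by each slice width (the increment
and the non-HPU fallback branch sit outside the lines shown):

# Sketch: each stacked projection writes into its own column slice of
# `output`; the in-place `y += ...` inside dispatch_bgmv_linear updates
# the view, and therefore `output` itself.
offset_left = 0
for slice_idx in range(len(output_slices)):
    out_view = output[:, offset_left:offset_left + output_slices[slice_idx]]
    dispatch_bgmv_linear(out_view, x, lora_a_stacked[slice_idx],
                         lora_b_stacked[slice_idx], indices, 0, 1.0)
    offset_left += output_slices[slice_idx]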
@@ -416,9 +345,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1], -1)
if is_hpu():
custom_bgmv_embed(full_output, full_lora_a_embeddings,
self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
dispatch_bgmv_embedding(full_output, full_lora_a_embeddings,
self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
else:
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
22 changes: 12 additions & 10 deletions vllm/worker/habana_model_runner.py
@@ -1556,16 +1556,18 @@ def execute_model(
selected_token_indices=sampling_metadata.selected_token_indices
)

from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
property = vars(self.model.model)
model = list(property['_modules'].values())[0]
property = vars(model)
modules = list(property['_modules'].values())
for module in modules:
if isinstance(module, VocabParallelEmbeddingWithLoRA):
for i in range(0, 4):
module.indices_len[
i] = sampling_metadata.selected_token_indices.numel()
if self.lora_config:
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
property = vars(self.model.model)
model = list(property['_modules'].values())[0]
property = vars(model)
modules = list(property['_modules'].values())
for module in modules:
if isinstance(module, VocabParallelEmbeddingWithLoRA):
for i in range(0, 4):
module.indices_len[
i] = sampling_metadata.selected_token_indices.numel(
)

# Compute the logits.
with self.profiler.record_event(
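
As an aside, the vars()/_modules unwrapping in this hunk could be expressed
as a recursive traversal; a sketch, assuming the wrapped model is a
torch.nn.Module (whose modules() iterator already recurses into submodules):

# Sketch of an equivalent update; `n` mirrors the numel() value the diff
# assigns to the first four indices_len entries.
if self.lora_config:
    from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
    n = sampling_metadata.selected_token_indices.numel()
    for module in self.model.model.modules():
        if isinstance(module, VocabParallelEmbeddingWithLoRA):
            module.indices_len[:4] = [n] * 4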
