Mask based BGMV implementation #223

Merged (9 commits) on Sep 5, 2024
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -1326,7 +1326,7 @@ def verify_with_model_config(self, model_config: ModelConfig):
model_config.quantization)

def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
if not is_hpu() and scheduler_config.max_num_batched_tokens > 65528:
raise ValueError(
"Due to limitations of the custom LoRA CUDA kernel, "
"max_num_batched_tokens must be <= 65528 when "
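To illustrate the effect of this change, here is a minimal, self-contained sketch (not the vLLM source; is_hpu is stubbed and the error message paraphrased): on the HPU path the 65528-token cap imposed by the custom LoRA CUDA kernel no longer applies.

import sys


def is_hpu() -> bool:
    # Stub for this sketch; in vLLM this reports whether the HPU backend is active.
    return True


def verify_max_num_batched_tokens(max_num_batched_tokens: int) -> None:
    # The cap comes from the custom LoRA CUDA kernel; the HPU backend uses the
    # mask-based BGMV path added in this PR, so the limit is skipped there.
    if not is_hpu() and max_num_batched_tokens > 65528:
        raise ValueError("max_num_batched_tokens must be <= 65528 "
                         "when LoRA is enabled (CUDA kernel limitation).")


verify_max_num_batched_tokens(131072)  # passes on HPU; would raise on CUDA
print("token budget accepted", file=sys.stderr)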
47 changes: 27 additions & 20 deletions vllm/hpu/ops.py
@@ -193,6 +193,18 @@ def prompt_attention(
return attn_weights


class LoraMask:
lora_mask = None

@staticmethod
def setLoraMask(mask):
LoraMask.lora_mask = mask

@staticmethod
def getLoraMask():
return LoraMask.lora_mask
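
A brief usage sketch of this holder (hypothetical call site and sizes; in practice the mask is presumably built once per batch by the model runner and read back inside dispatch_bgmv_linear):

import torch
from vllm.hpu.ops import LoraMask

# Hypothetical sizes: 4 decode requests, 8 LoRA slots, rank 16.
batch_size, max_loras, lora_rank = 4, 8, 16
mask = torch.zeros(batch_size, max_loras * lora_rank)
lora_ids = [2, 0, 7, 7]  # LoRA slot picked for each request
for row, lora_id in enumerate(lora_ids):
    # A rank-wide block of ones selects that request's adapter;
    # an all-zero row corresponds to a no-LoRA request.
    mask[row, lora_id * lora_rank:(lora_id + 1) * lora_rank] = 1.0

LoraMask.setLoraMask(mask)              # stored once for the batch
assert LoraMask.getLoraMask() is mask   # later fetched by dispatch_bgmv_linear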


def dispatch_bgmv_linear(
y: torch.Tensor,
x: torch.Tensor,
@@ -206,29 +218,24 @@ def dispatch_bgmv_linear(
`wa_t_all` and `wb_t_all` contain all LoRA A and LoRA B weight matrices
stacked into single tensors, assuming same rank. HPU handles no-LoRA
requests using zero valued A and B tensors. These zero valued tensors are
appended at the end of `wa_t_all` and `wb_t_all` during initialization. For
custom BGMV, the corresponding `wa` and `wb` for each batch is created
based on the lora_index of each sample.

For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The `wa` tensor for a batch of size batch_Size will have
a shape of (batch_size, num_layers, hidden_dim, lora_rank)

This method avoids for-loop as well as graph breaks.
appended at the end of `wa_t_all` and `wb_t_all` during initialization.
We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank]
and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also
have a loraMask of shape [batch_size, num_layers * lora_rank]
"""
assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
max_loras = wa_t_all.size(0)
# Wrap-around for negative indices
indices = indices % max_loras
wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)
wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2)

x = x.unsqueeze(1)
assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
mask = LoraMask.getLoraMask()
wa = wa_t_all[:, 0, :, :]
wb = wb_t_all[:, 0, :, :].transpose(1, 2)
wa_shape = wa.shape
wb_shape = wb.shape
wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1)
wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2])
out = x @ wa
assert (out.shape == mask.shape)
out = out * mask
out = out @ wb
out = out.squeeze(1)
y += out * scale
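
To make the shape handling concrete, here is a self-contained numeric sketch (sizes are illustrative assumptions, not taken from the PR) showing that the reshape-plus-mask formulation reproduces the per-request index_select approach it replaces:

import torch

# Assumed sizes: 8 LoRA slots, 1 layer, rank 16, hidden 64, 4 tokens.
max_loras, rank, hidden, tokens = 8, 16, 64, 4
wa_t_all = torch.randn(max_loras, 1, rank, hidden)
wb_t_all = torch.randn(max_loras, 1, hidden, rank)
x = torch.randn(tokens, hidden)

# Same reshapes as dispatch_bgmv_linear above.
wa = wa_t_all[:, 0, :, :].reshape(max_loras * rank, hidden).transpose(0, 1)
wb = wb_t_all[:, 0, :, :].transpose(1, 2).reshape(max_loras * rank, hidden)

# The mask keeps one rank-wide block per token, i.e. one adapter per request.
lora_ids = [2, 0, 7, 7]
mask = torch.zeros(tokens, max_loras * rank)
for row, lid in enumerate(lora_ids):
    mask[row, lid * rank:(lid + 1) * rank] = 1.0

out = ((x @ wa) * mask) @ wb  # [4,64] @ [64,128] -> masked -> @ [128,64]

# Reference: gather each token's own A/B pair explicitly (pre-PR style).
ref = torch.stack([x[i] @ wa_t_all[lid, 0].t() @ wb_t_all[lid, 0].t()
                   for i, lid in enumerate(lora_ids)])
assert torch.allclose(out, ref, atol=1e-4)

A request mapped to the zero-valued adapter slot (the no-LoRA case mentioned in the docstring) simply contributes nothing, since its A and B blocks are zero.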


@@ -265,4 +272,4 @@ def dispatch_bgmv_embedding(
x = x.unsqueeze(1)
out = x @ wa
out = out.squeeze(1)
y += out * scale
y += out * scale
11 changes: 11 additions & 0 deletions vllm/lora/layers.py
@@ -327,6 +327,17 @@ def set_mapping(
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
embedding_len = self.indices_len[3]
# NOTE(vgoel): These asserts can be skipped when upstreaming.
# Can be removed from vllm-fork also once lora functionality
# on Gaudi stabilizes.
if is_hpu():
emb_len = embedding_len
x_shape = x.shape
ind_shape = self.embeddings_indices[1].shape
assert embedding_len == x.shape[0] * x.shape[1], \
f"Extra Info: {emb_len}, {x_shape}, {ind_shape}"
assert embedding_len <= self.embeddings_indices[1].shape[0], \
f"Extra Info: {emb_len}, {x.shape}, {ind_shape}"
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
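As a rough illustration of the invariant those new asserts encode (toy shapes and plain tensors, not vLLM's internal state): the cached embedding_len must match the current input shape, and the flattened embeddings-indices row must be long enough for the view_as(x) slice to succeed.

import torch

# Toy batch: 2 sequences of 5 tokens; embedding_len is normally precomputed
# and cached (indices_len[3]), so the asserts guard against it going stale
# relative to the tensors seen at forward time.
x = torch.zeros(2, 5, dtype=torch.long)
embedding_len = 10                                   # pretend this was cached earlier
embeddings_indices_row = torch.zeros(64, dtype=torch.long)  # stand-in for embeddings_indices[1]

assert embedding_len == x.shape[0] * x.shape[1], "cached length is stale"
assert embedding_len <= embeddings_indices_row.shape[0], "slice would overflow"
indices = embeddings_indices_row[:embedding_len].view_as(x)  # shape [2, 5]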