From 78436a602c9d6bb0239eff9285235211b9fe8045 Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Tue, 20 Aug 2024 09:17:55 +0300 Subject: [PATCH] Move HPU specific LoRA ops to vllm.hpu.ops module --- vllm/hpu/ops.py | 75 +++++++++++++++++++++++++ vllm/lora/layers.py | 89 +++--------------------------- vllm/worker/habana_model_runner.py | 22 ++++---- 3 files changed, 96 insertions(+), 90 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 2af5634a8d1a6..662c53486b4ca 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -191,3 +191,78 @@ def prompt_attention( valid_seq_lengths, 'right') attn_weights = attn_weights.transpose(1, 2) return attn_weights + + +def dispatch_bgmv_linear( + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indices: torch.LongTensor, + layer_idx: int, + scale: float, +): + """ + `wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices + stacked into single tensors, assuming same rank. HPU handles no-LoRA + requests using zero valued A and B tensors. These zero valued tensors are + appended at the end of `wa_t_all` and `wb_t_all` during initialization. For + custom BGMV, the corresponding `wa` and `wb` for each batch is created + based on the lora_index of each sample. + + For example: + `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank, + hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles + no-LoRA case. The `wa` tensor for a batch of size batch_Size will have + a shape of (batch_size, num_layers, hidden_dim, lora_rank) + + This method avoids for-loop as well as graph breaks. + """ + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + max_loras = wa_t_all.size(0) + # Wrap-around for negative indices + indices = indices % max_loras + wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + + x = x.unsqueeze(1) + out = x @ wa + out = out @ wb + out = out.squeeze(1) + y += out * scale + + +def dispatch_bgmv_embedding( + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + indices: torch.LongTensor, + layer_idx: int, + scale: float, +): + """ + `wa_t_all` contains all LoRA A weight matrices stacked into a single tensor + assuming same rank. HPU handles no-LoRA requests using zero valued A + tensor. This zero valued tensor is appended at the end of `wa_t_all` during + initialization. For custom BGMV, the corresponding wa for each batch is + created based on the lora_index of the sample. + + For example: + `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank, + hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles + no-LoRA case. The wa tensor for a batch of size batch_Size will have a + shape of (batch_size, num_layers, lora_rank, hidden_dim) + + + This method avoids for-loop as well as graph breaks. + """ + assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' + max_loras = wa_t_all.size(0) + # Wrap-around for negative indices + indices = indices % max_loras + wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) + + x = x.unsqueeze(1) + out = x @ wa + out = out.squeeze(1) + y += out * scale \ No newline at end of file diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 912cd0b47202a..4a45f3fda88f1 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -29,6 +29,9 @@ VocabParallelEmbedding) from vllm.utils import is_hpu +if is_hpu(): + from vllm.hpu.ops import dispatch_bgmv_embedding, dispatch_bgmv_linear + if TYPE_CHECKING: pass @@ -64,81 +67,6 @@ def dec(*args, **kwargs): return dec -def custom_bgmv( - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - indices: torch.LongTensor, - layer_idx: int, - scale: float, -): - """ - `wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices - stacked into single tensors, assuming same rank. HPU handles no-LoRA - requests using zero valued A and B tensors. These zero valued tensors are - appended at the end of `wa_t_all` and `wb_t_all` during initialization. For - custom BGMV, the corresponding `wa` and `wb` for each batch is created - based on the lora_index of each sample. - - For example: - `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank, - hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles - no-LoRA case. The `wa` tensor for a batch of size batch_Size will have - a shape of (batch_size, num_layers, hidden_dim, lora_rank) - - This method avoids for-loop as well as graph breaks. - """ - assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' - max_loras = wa_t_all.size(0) - # Wrap-around for negative indices - indices = indices % max_loras - wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) - wb = torch.index_select(wb_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) - - x = x.unsqueeze(1) - out = x @ wa - out = out @ wb - out = out.squeeze(1) - y += out * scale - - -def custom_bgmv_embed( - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - indices: torch.LongTensor, - layer_idx: int, - scale: float, -): - """ - `wa_t_all` contains all LoRA A weight matrices stacked into a single tensor - assuming same rank. HPU handles no-LoRA requests using zero valued A - tensor. This zero valued tensor is appended at the end of `wa_t_all` during - initialization. For custom BGMV, the corresponding wa for each batch is - created based on the lora_index of the sample. - - For example: - `wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank, - hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles - no-LoRA case. The wa tensor for a batch of size batch_Size will have a - shape of (batch_size, num_layers, lora_rank, hidden_dim) - - - This method avoids for-loop as well as graph breaks. - """ - assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' - max_loras = wa_t_all.size(0) - # Wrap-around for negative indices - indices = indices % max_loras - wa = torch.index_select(wa_t_all, 0, indices)[:, 0, :, :].transpose(-1, -2) - - x = x.unsqueeze(1) - out = x @ wa - out = out.squeeze(1) - y += out * scale - - def _apply_lora( x: torch.Tensor, lora_a_stacked: torch.Tensor, @@ -166,7 +94,8 @@ def _apply_lora( output = output.view(-1, output.shape[-1]) indices = indices.view(-1) if is_hpu(): - custom_bgmv(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + dispatch_bgmv_linear(output, x, lora_a_stacked, lora_b_stacked, + indices, 0, 1.0) else: add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) return output.view_as(org_output) @@ -207,7 +136,7 @@ def _apply_lora_packed_nslice( offset_left = 0 for slice_idx in range(len(output_slices)): if is_hpu(): - custom_bgmv( + dispatch_bgmv_linear( output[:, offset_left:offset_left + output_slices[slice_idx]], x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], indices, 0, 1.0) @@ -416,9 +345,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1], -1) if is_hpu(): - custom_bgmv_embed(full_output, full_lora_a_embeddings, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + dispatch_bgmv_embedding(full_output, full_lora_a_embeddings, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) else: bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, self.indices[:self.indices_len[0]], 0, 1.0) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index ce7a0ad8dd1fc..d129bb5cbc0ca 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1556,16 +1556,18 @@ def execute_model( selected_token_indices=sampling_metadata.selected_token_indices ) - from vllm.lora.layers import VocabParallelEmbeddingWithLoRA - property = vars(self.model.model) - model = list(property['_modules'].values())[0] - property = vars(model) - modules = list(property['_modules'].values()) - for module in modules: - if isinstance(module, VocabParallelEmbeddingWithLoRA): - for i in range(0, 4): - module.indices_len[ - i] = sampling_metadata.selected_token_indices.numel() + if self.lora_config: + from vllm.lora.layers import VocabParallelEmbeddingWithLoRA + property = vars(self.model.model) + model = list(property['_modules'].values())[0] + property = vars(model) + modules = list(property['_modules'].values()) + for module in modules: + if isinstance(module, VocabParallelEmbeddingWithLoRA): + for i in range(0, 4): + module.indices_len[ + i] = sampling_metadata.selected_token_indices.numel( + ) # Compute the logits. with self.profiler.record_event(