
Commit

Remove Global variable
hlahkar committed Sep 2, 2024
1 parent bdca75e commit a8f1d7d
Showing 3 changed files with 46 additions and 30 deletions.
3 changes: 0 additions & 3 deletions vllm/decode.py

This file was deleted.

53 changes: 36 additions & 17 deletions vllm/hpu/ops.py
@@ -12,7 +12,6 @@
import torch.nn.functional as F

from vllm.logger import init_logger
import vllm.decode as decode

logger = init_logger(__name__)
HPUFusedRMSNorm = None
@@ -194,6 +193,18 @@ def prompt_attention(
return attn_weights


class LoraMask:
lora_mask = None

@staticmethod
def setLoraMask(mask):
LoraMask.lora_mask = mask

@staticmethod
def getLoraMask():
return LoraMask.lora_mask


def dispatch_bgmv_linear(
y: torch.Tensor,
x: torch.Tensor,
@@ -207,33 +218,41 @@ def dispatch_bgmv_linear(
`wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices
stacked into single tensors, assuming same rank. HPU handles no-LoRA
requests using zero valued A and B tensors. These zero valued tensors are
appended at the end of `wa_t_all` and `wb_t_all` during initialization. For
custom BGMV, the corresponding `wa` and `wb` for each batch is created
based on the lora_index of each sample.
For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The `wa` tensor for a batch of size batch_Size will have
a shape of (batch_size, num_layers, hidden_dim, lora_rank)
This method avoids for-loop as well as graph breaks.
appended at the end of `wa_t_all` and `wb_t_all` during initialization.
"""

assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
max_loras = wa_t_all.size(0)
# Wrap-around for negative indices
if decode.mask is not None:
mask = LoraMask.getLoraMask()
if mask is not None:
"""
We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank]
and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also
have a loraMask of shape [batch_size, num_layers * lora_rank]
"""
wa = wa_t_all[:, 0, :, :]
wb = wb_t_all[:, 0, :, :].transpose(0, 1)
wb = wb_t_all[:, 0, :, :].transpose(1, 2)
wa_shape = wa.shape
wb_shape = wb.shape
wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1)
wb = wb.reshape(wb_shape[0], wb_shape[1] * wb_shape[2]).transpose(0, 1)
wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2])
out = x @ wa
assert (out.shape == decode.mask.shape)
out = out * decode.mask
assert (out.shape == mask.shape)
out = out * mask
out = out @ wb
else:
"""For custom BGMV, the corresponding `wa` and `wb` for each batch is
created based on the lora_index of each sample.
For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The `wa` tensor for a batch of size batch_Size will have
a shape of (batch_size, num_layers, hidden_dim, lora_rank)
This method avoids for-loop as well as graph breaks.
"""
indices = indices % max_loras
wa = torch.index_select(wa_t_all, 0,
indices)[:, 0, :, :].transpose(-1, -2)
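For orientation, here is a minimal, self-contained sketch of the masked, loop-free BGMV path this hunk switches to. The tensor layout follows the docstring above (`wa_t_all` stacked as (num_loras, num_layers, lora_rank, hidden_dim), with the last slot zeroed for no-LoRA requests); the concrete sizes, adapter indices, and the way the mask rows are filled are illustrative assumptions, not part of the commit.

import torch

# Illustrative sizes; not taken from the commit.
batch, hidden, out_dim, rank, max_loras = 4, 64, 64, 8, 3   # last adapter slot stays zero (no-LoRA)

x = torch.randn(batch, hidden)
wa_t_all = torch.randn(max_loras, 1, rank, hidden)    # (num_loras, num_layers, lora_rank, hidden_dim)
wb_t_all = torch.randn(max_loras, 1, out_dim, rank)   # (num_loras, num_layers, hidden_dim, lora_rank)
wa_t_all[-1].zero_()                                  # zero-valued A/B handle no-LoRA requests
wb_t_all[-1].zero_()

# Assumed fill pattern: each row enables only the rank-wide slice of that sample's adapter.
lora_indices = [0, 1, 2, 2]                           # the last two samples use the zero adapter
mask = torch.zeros(batch, max_loras * rank)
for i, idx in enumerate(lora_indices):
    mask[i, idx * rank:(idx + 1) * rank] = 1

# The masked path from the diff: one dense matmul per side, no per-sample loop or graph break.
wa = wa_t_all[:, 0, :, :].reshape(max_loras * rank, hidden).transpose(0, 1)    # (hidden, num_loras*rank)
wb = wb_t_all[:, 0, :, :].transpose(1, 2).reshape(max_loras * rank, out_dim)   # (num_loras*rank, out_dim)
out = ((x @ wa) * mask) @ wb
print(out.shape)                                      # torch.Size([4, 64])

Because the mask zeroes every column slice except the chosen adapter's, row i of the result reduces to that sample's own LoRA A/B product, while the whole batch stays a single dense computation.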
20 changes: 10 additions & 10 deletions vllm/worker/habana_model_runner.py
@@ -22,6 +22,7 @@
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed.parallel_state import get_world_group
from vllm.hpu.ops import LoraMask as LoraMask
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@@ -33,7 +34,6 @@
SequenceGroupMetadata)
from vllm.utils import (HabanaMemoryProfiler, format_bytes,
is_pin_memory_available, make_tensor_with_pad)
import vllm.decode as decode
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
@@ -230,11 +230,11 @@ def forward(self, *args, **kwargs):
input_ids.size(1),
input_ids.device,
torch.bfloat16)
decode.mask = kwargs.pop('mask')
LoraMask.setLoraMask(kwargs.pop('mask'))
hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(0, selected_token_indices)
return hidden_states, decode.mask
return hidden_states

def compute_logits(self, *args, **kwargs):
return self.model.compute_logits(*args, **kwargs)
@@ -339,7 +339,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"real_batch_size": self.real_batch_size,
"batch_size_padded": self.batch_size_padded,
"virtual_engine": self.virtual_engine,
"mask": mask
"mask": self.mask
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@@ -633,8 +633,6 @@ def _prepare_prompt(

if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
mask = None
counter = 0

for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
@@ -747,6 +745,8 @@ def _prepare_prompt(
find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
self.block_size)

mask: torch.Tensor = None
counter = 0
if self.lora_config:
mask = torch.zeros(len(seq_group_metadata_list) * max_prompt_len,
(self.lora_config.max_loras + 1) *
@@ -860,7 +860,7 @@ def _prepare_decode(

if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
mask = None
mask: torch.Tensor = None
counter = 0

if self.lora_config:
@@ -1250,7 +1250,7 @@ def warmup_scenario(self,
if dummy_lora_requests_per_seq else None)
for i in range(batch_size)
]
#torch.hpu.synchronize()
torch.hpu.synchronize()
for _ in range(times):
inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, kv_caches, warmup_mode=True)
@@ -1668,7 +1668,7 @@ def execute_model(
else:
model_event_name = 'model_executable'
with self.profiler.record_event('internal', model_event_name):
hidden_states, _ = self.model.forward(
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.selected_token_indices
)
@@ -1682,7 +1682,7 @@
module.indices_len[
i] = sampling_metadata.selected_token_indices.numel(
)
decode.mask = None
LoraMask.setLoraMask(None)

# Compute the logits.
with self.profiler.record_event(
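To tie the runner-side changes together: `_prepare_prompt` and `_prepare_decode` allocate the mask, it travels in the broadcastable tensor dict under "mask", `HpuModelAdapter.forward` installs it via `LoraMask.setLoraMask(kwargs.pop('mask'))` in place of the old `decode.mask` global, and `execute_model` clears it with `LoraMask.setLoraMask(None)` after the forward pass. A rough sketch of what the prompt-side mask construction could look like follows; the zeros shape matches the allocation visible above, but the row/column fill and the helper name `build_prompt_lora_mask` are assumptions, since that part of the hunk is collapsed in this view.

import torch

def build_prompt_lora_mask(lora_ids, max_prompt_len, max_loras, max_lora_rank,
                           dtype=torch.bfloat16):
    # Hypothetical helper: row block i enables only the rank-wide column slice
    # of sequence i's adapter slot (the extra slot after max_loras is the zero adapter).
    mask = torch.zeros(len(lora_ids) * max_prompt_len,
                       (max_loras + 1) * max_lora_rank, dtype=dtype)
    for i, lora_id in enumerate(lora_ids):
        rows = slice(i * max_prompt_len, (i + 1) * max_prompt_len)
        cols = slice(lora_id * max_lora_rank, (lora_id + 1) * max_lora_rank)
        mask[rows, cols] = 1
    return mask

mask = build_prompt_lora_mask(lora_ids=[0, 2], max_prompt_len=128,
                              max_loras=2, max_lora_rank=8)
# LoraMask.setLoraMask(mask)    # HpuModelAdapter.forward (replaces decode.mask = mask)
# hidden_states = self.model(*args, **kwargs)
# LoraMask.setLoraMask(None)    # execute_model, after sampling indices are consumed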
