Move Compute Logits to Mask Based Implementation
hlahkar committed Sep 4, 2024
1 parent 49ffde6 commit 881ef3f
Showing 2 changed files with 46 additions and 43 deletions.
53 changes: 13 additions & 40 deletions vllm/hpu/ops.py
@@ -219,50 +219,23 @@ def dispatch_bgmv_linear(
stacked into single tensors, assuming same rank. HPU handles no-LoRA
requests using zero valued A and B tensors. These zero valued tensors are
appended at the end of `wa_t_all` and `wb_t_all` during initialization.
We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank]
and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also
have a loraMask of shape [batch_size, num_layers * lora_rank]
"""

assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
max_loras = wa_t_all.size(0)
# Wrap-around for negative indices
mask = LoraMask.getLoraMask()
if mask is not None:
"""
We reshape w_a_t_all to [hidden_dim, num_layers * lora_rank]
and w_b_t_all to [num_layers * lora_rank, hidden_dim]. We also
have a loraMask of shape [batch_size, num_layers * lora_rank]
"""
wa = wa_t_all[:, 0, :, :]
wb = wb_t_all[:, 0, :, :].transpose(1, 2)
wa_shape = wa.shape
wb_shape = wb.shape
wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1)
wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2])
out = x @ wa
assert (out.shape == mask.shape)
out = out * mask
out = out @ wb
else:
"""For custom BGMV, the corresponding `wa` and `wb` for each batch is
created based on the lora_index of each sample.
For example:
`wa_t_all` is tensor of shape (num_loras, num_layers, lora_rank,
hidden_dim), where `wa_t_all[-1]` is zero valued tensor which handles
no-LoRA case. The `wa` tensor for a batch of size batch_Size will have
a shape of (batch_size, num_layers, hidden_dim, lora_rank)
This method avoids for-loop as well as graph breaks.
"""
indices = indices % max_loras
wa = torch.index_select(wa_t_all, 0,
indices)[:, 0, :, :].transpose(-1, -2)
wb = torch.index_select(wb_t_all, 0,
indices)[:, 0, :, :].transpose(-1, -2)

x = x.unsqueeze(1)
out = x @ wa
out = out @ wb
out = out.squeeze(1)
wa = wa_t_all[:, 0, :, :]
wb = wb_t_all[:, 0, :, :].transpose(1, 2)
wa_shape = wa.shape
wb_shape = wb.shape
wa = wa.reshape(wa_shape[0] * wa_shape[1], wa_shape[2]).transpose(0, 1)
wb = wb.reshape(wb_shape[0] * wb_shape[1], wb_shape[2])
out = x @ wa
assert (out.shape == mask.shape)
out = out * mask
out = out @ wb
y += out * scale
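
A minimal standalone sketch of the mask-based path in dispatch_bgmv_linear above, on toy shapes (the sizes and sample mask values are illustrative assumptions; the real tensors come from the LoRA config):

import torch

batch_size, hidden_dim, max_loras, rank = 2, 8, 3, 4
x = torch.randn(batch_size, hidden_dim)
# assumed layouts: wa_t_all (max_loras, num_layers, rank, hidden_dim),
#                  wb_t_all (max_loras, num_layers, hidden_dim, rank)
wa_t_all = torch.randn(max_loras, 1, rank, hidden_dim)
wb_t_all = torch.randn(max_loras, 1, hidden_dim, rank)

# flatten all LoRAs into two big matrices, as in the code above
wa = wa_t_all[:, 0, :, :].reshape(max_loras * rank, hidden_dim).transpose(0, 1)
wb = wb_t_all[:, 0, :, :].transpose(1, 2).reshape(max_loras * rank, hidden_dim)

# the mask zeroes every rank-wide slice except the one for each row's LoRA
mask = torch.zeros(batch_size, max_loras * rank)
mask[0, 0 * rank:1 * rank] = 1.0  # row 0 uses LoRA 0
mask[1, 2 * rank:3 * rank] = 1.0  # row 1 uses LoRA 2

out = ((x @ wa) * mask) @ wb      # (batch_size, hidden_dim)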


36 changes: 33 additions & 3 deletions vllm/worker/habana_model_runner.py
@@ -255,6 +255,7 @@ class PreparePromptMetadata(NamedTuple):
multi_modal_input: Optional[torch.Tensor]
slot_mapping: List[List[int]]
lora_mask: Optional[torch.Tensor]
lora_logits_mask: Optional[torch.Tensor]

@classmethod
def empty(cls):
@@ -268,7 +269,8 @@ def empty(cls):
lora_requests=set(),
multi_modal_input=None,
slot_mapping=[],
lora_mask=None)
lora_mask=None,
lora_logits_mask=None)


class PrepareDecodeMetadata(NamedTuple):
@@ -280,6 +282,7 @@ class PrepareDecodeMetadata(NamedTuple):
lora_requests: Set[LoRARequest]
slot_mapping: List[List[int]]
lora_mask: Optional[torch.Tensor]
lora_logits_mask: Optional[torch.Tensor]

@classmethod
def empty(cls):
@@ -292,6 +295,7 @@ def empty(cls):
lora_requests=set(),
slot_mapping=[],
lora_mask=None,
lora_logits_mask=None
)


@@ -328,6 +332,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
batch_size_padded: Optional[int] = None
virtual_engine: int = 0
lora_mask: Optional[torch.Tensor] = None
lora_logits_mask: Optional[torch.Tensor] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@@ -340,6 +345,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"batch_size_padded": self.batch_size_padded,
"virtual_engine": self.virtual_engine,
"lora_mask": self.lora_mask,
"lora_logits_mask": self.lora_logits_mask,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@@ -374,6 +380,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"lora_mask": self.lora_mask,
"lora_logits_mask": self.lora_logits_mask,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -747,16 +754,25 @@ def _prepare_prompt(
self.block_size)

lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
counter = 0
if self.lora_config:
lora_mask = torch.zeros(len(seq_group_metadata_list) *
max_prompt_len,
(self.lora_config.max_loras + 1) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
lora_logits_mask = torch.zeros(len(seq_group_metadata_list),
(self.lora_config.max_loras + 1) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)

ones = torch.ones(max_prompt_len,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
logit_ones = torch.ones(1,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
for seq_group_metadata, context_len in zip(seq_group_metadata_list,
context_lens):
lora_id = seq_group_metadata.lora_int_id
@@ -768,6 +784,7 @@ def _prepare_prompt(
start_col = (lora_id - 1) * self.lora_config.max_lora_rank
end_col = start_col + self.lora_config.max_lora_rank
lora_mask[start_row:end_row, start_col:end_col] = ones
lora_logits_mask[counter, start_col:end_col] = logit_ones
counter = counter + 1

lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
@@ -778,6 +795,7 @@ def _prepare_prompt(

if lora_mask is not None:
lora_mask = lora_mask.to('hpu')
lora_logits_mask = lora_logits_mask.to('hpu')

input_tokens = make_tensor_with_pad(input_tokens,
max_len=max_prompt_len,
@@ -845,6 +863,7 @@ def _prepare_prompt(
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
lora_mask=lora_mask,
lora_logits_mask=lora_logits_mask,
)
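
How the prompt-time lora_logits_mask built above ends up looking, on assumed toy values (illustrative only, not the runner's exact tensors): one row per sequence group, with a rank-wide block of ones at the column slot of that sequence's LoRA id.

import torch

num_seqs, max_loras, rank = 3, 2, 4
lora_ids = [1, 0, 2]  # 0 means the sequence uses no LoRA
lora_logits_mask = torch.zeros(num_seqs, (max_loras + 1) * rank)
for row, lora_id in enumerate(lora_ids):
    if lora_id > 0:
        start_col = (lora_id - 1) * rank
        lora_logits_mask[row, start_col:start_col + rank] = 1.0
# rows with lora_id == 0 stay all zeros, so those sequences get no LoRA contribution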

def _prepare_decode(
@@ -863,6 +882,7 @@ def _prepare_decode(
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
counter = 0

if self.lora_config:
@@ -917,6 +937,7 @@ def _prepare_decode(

if lora_mask is not None:
lora_mask = lora_mask.to('hpu')
lora_logits_mask = lora_mask
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
@@ -963,6 +984,7 @@ def _prepare_decode(
lora_requests=lora_requests,
slot_mapping=slot_mapping,
lora_mask=lora_mask,
lora_logits_mask=lora_logits_mask,
)
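
At decode time each sequence contributes exactly one token per step, so the token-level lora_mask can double as the logits mask (the lora_logits_mask = lora_mask assignment above). A toy sketch under that assumption (illustrative shapes and ids only):

import torch

num_decode_tokens, max_loras, rank = 2, 2, 4   # one token per sequence
lora_mask = torch.zeros(num_decode_tokens, (max_loras + 1) * rank)
lora_mask[0, 0:rank] = 1.0         # sequence 0 uses lora_id 1
lora_mask[1, rank:2 * rank] = 1.0  # sequence 1 uses lora_id 2
lora_logits_mask = lora_mask       # rows already align one-to-one with sequences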

def prepare_input_tensors(
@@ -1018,6 +1040,7 @@ def prepare_input_tensors(
multi_modal_input,
slot_mapping,
lora_mask,
lora_logits_mask,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
@@ -1028,6 +1051,7 @@ def prepare_input_tensors(
decode_lora_requests,
decode_slot_mapping,
decode_lora_mask,
decode_lora_logits_mask,
) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
seq_lens, query_lens,
@@ -1055,6 +1079,7 @@ def prepare_input_tensors(
lora_prompt_mapping = decode_lora_prompt_mapping
lora_requests = decode_lora_requests
lora_mask = decode_lora_mask
lora_logits_mask = decode_lora_logits_mask

# FIXME: We need to adjust selected_token_indices to accommodate
# for padding
@@ -1124,7 +1149,9 @@ def prepare_input_tensors(
multi_modal_kwargs=multi_modal_input,
real_batch_size=real_batch_size,
batch_size_padded=batch_size_padded,
lora_mask=lora_mask), sampling_metadata
lora_mask=lora_mask,
lora_logits_mask=lora_logits_mask), \
sampling_metadata

def _seq_len(self, attn_metadata):
if attn_metadata.num_prefills != 0:
@@ -1198,6 +1225,7 @@ def profile_run(self) -> None:
True,
kv_caches,
is_profile_run=True)
return

def warmup_scenario(self,
batch_size,
@@ -1694,7 +1722,9 @@ def execute_model(
module.indices_len[
i] = sampling_metadata.selected_token_indices.numel(
)
LoraMask.setLoraMask(None)
lora_logits_mask: torch.Tensor = model_input.lora_logits_mask
LoraMask.setLoraMask(lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))

# Compute the logits.
with self.profiler.record_event(
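
A toy illustration (assumed values) of the index_select in the new setLoraMask call above: it keeps only the rows of lora_logits_mask at the selected token indices, so the mask handed to compute_logits lines up with the hidden states chosen for sampling.

import torch

lora_logits_mask = torch.tensor([[1., 1., 0., 0.],
                                 [0., 0., 1., 1.],
                                 [0., 0., 0., 0.]])  # illustrative mask rows
selected_token_indices = torch.tensor([0, 2])
selected_mask = lora_logits_mask.index_select(0, selected_token_indices)
# selected_mask is rows 0 and 2 of lora_logits_mask, shape (2, 4)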
