Commit

Change mask to lora_mask
hlahkar committed Sep 4, 2024
1 parent ab369e3 commit 49ffde6
Showing 1 changed file with 32 additions and 31 deletions.
63 changes: 32 additions & 31 deletions vllm/worker/habana_model_runner.py
@@ -230,7 +230,7 @@ def forward(self, *args, **kwargs):
input_ids.size(1),
input_ids.device,
torch.bfloat16)
- LoraMask.setLoraMask(kwargs.pop('mask'))
+ LoraMask.setLoraMask(kwargs.pop('lora_mask'))
hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(0, selected_token_indices)
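The forward wrapper in the hunk above pops 'lora_mask' before delegating, so the wrapped model never receives a keyword it does not accept; the key string therefore has to match whatever execute_model puts into the kwargs dict (see the last hunk of this diff). Below is a minimal, self-contained sketch of that pop-then-delegate pattern; the toy MaskConsumingWrapper class is an assumption for illustration, not the adapter defined in this file.

```python
import torch

class MaskConsumingWrapper(torch.nn.Module):
    """Toy stand-in (an assumption) for the pattern shown in the hunk above:
    a side-channel kwarg is consumed before delegating to the wrapped model."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model
        self.last_lora_mask = None

    def forward(self, *args, **kwargs):
        # Consume the side-channel kwarg, then forward the rest untouched.
        self.last_lora_mask = kwargs.pop('lora_mask', None)
        return self.model(*args, **kwargs)

wrapped = MaskConsumingWrapper(torch.nn.Linear(4, 4))
out = wrapped(torch.randn(2, 4), lora_mask=torch.ones(2, 24))
print(out.shape, wrapped.last_lora_mask.shape)
```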
@@ -254,7 +254,7 @@ class PreparePromptMetadata(NamedTuple):
lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor]
slot_mapping: List[List[int]]
- mask: Optional[torch.Tensor]
+ lora_mask: Optional[torch.Tensor]

@classmethod
def empty(cls):
@@ -268,7 +268,7 @@ def empty(cls):
lora_requests=set(),
multi_modal_input=None,
slot_mapping=[],
- mask=None)
+ lora_mask=None)


class PrepareDecodeMetadata(NamedTuple):
@@ -279,7 +279,7 @@ class PrepareDecodeMetadata(NamedTuple):
lora_prompt_mapping: List[List[int]]
lora_requests: Set[LoRARequest]
slot_mapping: List[List[int]]
- mask: Optional[torch.Tensor]
+ lora_mask: Optional[torch.Tensor]

@classmethod
def empty(cls):
@@ -291,7 +291,7 @@ def empty(cls):
lora_prompt_mapping=[],
lora_requests=set(),
slot_mapping=[],
- mask=None,
+ lora_mask=None,
)


@@ -327,7 +327,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
real_batch_size: Optional[int] = None
batch_size_padded: Optional[int] = None
virtual_engine: int = 0
- mask: Optional[torch.Tensor] = None
+ lora_mask: Optional[torch.Tensor] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@@ -339,7 +339,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"real_batch_size": self.real_batch_size,
"batch_size_padded": self.batch_size_padded,
"virtual_engine": self.virtual_engine,
"mask": self.mask
"lora_mask": self.lora_mask,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@@ -373,7 +373,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"mask": self.mask
"lora_mask": self.lora_mask,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -746,13 +746,14 @@ def _prepare_prompt(
find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
self.block_size)

- mask: torch.Tensor = None
+ lora_mask: torch.Tensor = None
counter = 0
if self.lora_config:
- mask = torch.zeros(len(seq_group_metadata_list) * max_prompt_len,
- (self.lora_config.max_loras + 1) *
- self.lora_config.max_lora_rank,
- dtype=self.lora_config.lora_dtype)
+ lora_mask = torch.zeros(len(seq_group_metadata_list) *
+ max_prompt_len,
+ (self.lora_config.max_loras + 1) *
+ self.lora_config.max_lora_rank,
+ dtype=self.lora_config.lora_dtype)
ones = torch.ones(max_prompt_len,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
@@ -766,7 +767,7 @@ def _prepare_prompt(
end_row = start_row + max_prompt_len
start_col = (lora_id - 1) * self.lora_config.max_lora_rank
end_col = start_col + self.lora_config.max_lora_rank
- mask[start_row:end_row, start_col:end_col] = ones
+ lora_mask[start_row:end_row, start_col:end_col] = ones
counter = counter + 1

lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
@@ -775,8 +776,8 @@ def _prepare_prompt(
(max_prompt_len - context_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))

- if mask is not None:
- mask = mask.to('hpu')
+ if lora_mask is not None:
+ lora_mask = lora_mask.to('hpu')

input_tokens = make_tensor_with_pad(input_tokens,
max_len=max_prompt_len,
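For context on the hunks above: the prompt-phase lora_mask has one row per padded prompt token and one max_lora_rank-wide column block per LoRA slot, and a sequence's rows receive ones in the block that matches its LoRA id. The standalone sketch below only mirrors those shapes; the concrete sizes and the lora_ids list are illustrative assumptions, not values from the diff.

```python
import torch

# Illustrative sizes; these numbers are assumptions, not values from the diff.
num_seqs, max_prompt_len = 2, 4
max_loras, max_lora_rank = 2, 8
lora_ids = [1, 2]  # hypothetical LoRA id per sequence; 0 would mean "no LoRA"

# One row per padded prompt token, one rank-wide column block per LoRA slot.
lora_mask = torch.zeros(num_seqs * max_prompt_len,
                        (max_loras + 1) * max_lora_rank,
                        dtype=torch.bfloat16)
ones = torch.ones(max_prompt_len, max_lora_rank, dtype=torch.bfloat16)

for i, lora_id in enumerate(lora_ids):
    if lora_id > 0:
        start_row = i * max_prompt_len
        start_col = (lora_id - 1) * max_lora_rank
        lora_mask[start_row:start_row + max_prompt_len,
                  start_col:start_col + max_lora_rank] = ones

print(lora_mask.shape)  # torch.Size([8, 24])
# The real code then moves the tensor to the Gaudi device via .to('hpu').
```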
Expand Down Expand Up @@ -843,7 +844,7 @@ def _prepare_prompt(
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
- mask=mask,
+ lora_mask=lora_mask,
)

def _prepare_decode(
@@ -861,14 +862,14 @@ def _prepare_decode(

if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
- mask: torch.Tensor = None
+ lora_mask: torch.Tensor = None
counter = 0

if self.lora_config:
- mask = torch.zeros(len(seq_group_metadata_list),
- (self.lora_config.max_loras + 1) *
- self.lora_config.max_lora_rank,
- dtype=self.lora_config.lora_dtype)
+ lora_mask = torch.zeros(len(seq_group_metadata_list),
+ (self.lora_config.max_loras + 1) *
+ self.lora_config.max_lora_rank,
+ dtype=self.lora_config.lora_dtype)
ones = torch.ones(1,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
@@ -884,7 +885,7 @@ def _prepare_decode(
lora_requests.add(seq_group_metadata.lora_request)
start_pos = (lora_id - 1) * self.lora_config.max_lora_rank
end_pos = start_pos + self.lora_config.max_lora_rank
- mask[counter, start_pos:end_pos] = ones
+ lora_mask[counter, start_pos:end_pos] = ones
counter = counter + 1

for seq_id in seq_ids:
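Decode handles a single token per sequence, so the mask built in the hunks above shrinks to one row per sequence while keeping the same column layout. As before, the sketch below is a hedged illustration: the sizes and lora_ids are assumptions, only the shapes follow the diff.

```python
import torch

# Illustrative sizes; these numbers are assumptions, not values from the diff.
num_seqs = 3
max_loras, max_lora_rank = 2, 8
lora_ids = [2, 0, 1]  # hypothetical per-sequence LoRA ids; 0 means "no LoRA"

# One row per sequence (one decode token each), same column layout as the
# prompt-phase mask.
lora_mask = torch.zeros(num_seqs, (max_loras + 1) * max_lora_rank,
                        dtype=torch.bfloat16)

for row, lora_id in enumerate(lora_ids):
    if lora_id > 0:
        start_col = (lora_id - 1) * max_lora_rank
        lora_mask[row, start_col:start_col + max_lora_rank] = 1.0

print(lora_mask.shape)  # torch.Size([3, 24])
```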
@@ -914,8 +915,8 @@ def _prepare_decode(
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)

- if mask is not None:
- mask = mask.to('hpu')
+ if lora_mask is not None:
+ lora_mask = lora_mask.to('hpu')
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
@@ -961,7 +962,7 @@ def _prepare_decode(
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
- mask=mask,
+ lora_mask=lora_mask,
)

def prepare_input_tensors(
@@ -1016,7 +1017,7 @@ def prepare_input_tensors(
lora_requests,
multi_modal_input,
slot_mapping,
- mask,
+ lora_mask,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
@@ -1026,7 +1027,7 @@ def prepare_input_tensors(
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
- decode_mask,
+ decode_lora_mask,
) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
seq_lens, query_lens,
@@ -1053,7 +1054,7 @@ def prepare_input_tensors(
lora_index_mapping = decode_lora_index_mapping
lora_prompt_mapping = decode_lora_prompt_mapping
lora_requests = decode_lora_requests
- mask = decode_mask
+ lora_mask = decode_lora_mask

# FIXME: We need to adjust selected_token_indices to accommodate
# for padding
@@ -1123,7 +1124,7 @@ def prepare_input_tensors(
multi_modal_kwargs=multi_modal_input,
real_batch_size=real_batch_size,
batch_size_padded=batch_size_padded,
- mask=mask), sampling_metadata
+ lora_mask=lora_mask), sampling_metadata

def _seq_len(self, attn_metadata):
if attn_metadata.num_prefills != 0:
@@ -1659,7 +1660,7 @@ def execute_model(
"kv_caches": kv_caches,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"mask": model_input.mask
"lora_mask": model_input.lora_mask
}
if multi_modal_input is not None:
execute_model_kwargs.update(multi_modal_input)
