Lora Mask based on lora index (#348)
Changes the filling of the LoRA mask to use lora_index instead of lora_id. This is
needed so the mask does not go out of bounds when a lora_id is greater
than max_loras.
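
For illustration only (not part of the commit): a minimal sketch of the decode-phase mask fill after this change. build_decode_lora_mask is a hypothetical helper, lora_index_to_id stands in for the adapter manager's slot list referenced in the diff, and the dtype is a placeholder.

import torch
from typing import List

def build_decode_lora_mask(lora_ids: List[int],
                           lora_index_to_id: List[int],
                           max_loras: int,
                           max_lora_rank: int,
                           dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # One row per sequence; columns are grouped into one slot per loaded LoRA.
    mask = torch.zeros(len(lora_ids), max_loras * max_lora_rank, dtype=dtype)
    for row, lora_id in enumerate(lora_ids):
        if lora_id == 0:  # 0 means the sequence uses no LoRA
            continue
        # Old behaviour: start = (lora_id - 1) * max_lora_rank, which indexes
        # past the end of the mask whenever lora_id > max_loras.
        # New behaviour: use the slot the adapter is actually loaded into.
        slot = lora_index_to_id.index(lora_id)
        start = slot * max_lora_rank
        mask[row, start:start + max_lora_rank] = 1
    return mask

# Example: adapter id 7 lives in slot 1 of a 2-slot (max_loras=2) configuration.
mask = build_decode_lora_mask([0, 7], lora_index_to_id=[3, 7],
                              max_loras=2, max_lora_rank=8)
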
hlahkar authored Oct 3, 2024
1 parent 25f4ed9 commit da03d8b
Showing 1 changed file with 113 additions and 98 deletions.
211 changes: 113 additions & 98 deletions vllm/worker/habana_model_runner.py
@@ -350,8 +350,7 @@ class PreparePromptMetadata(NamedTuple):
lora_requests: Set[LoRARequest]
multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]]
slot_mapping: List[List[int]]
lora_mask: Optional[torch.Tensor]
lora_logits_mask: Optional[torch.Tensor]
lora_ids: List[int]

@classmethod
def empty(cls):
@@ -365,8 +364,7 @@ def empty(cls):
lora_requests=set(),
multi_modal_kwargs=None,
slot_mapping=[],
lora_mask=None,
lora_logits_mask=None)
lora_ids=[])


class PrepareDecodeMetadata(NamedTuple):
@@ -377,8 +375,7 @@ class PrepareDecodeMetadata(NamedTuple):
lora_prompt_mapping: List[List[int]]
lora_requests: Set[LoRARequest]
slot_mapping: List[List[int]]
lora_mask: Optional[torch.Tensor]
lora_logits_mask: Optional[torch.Tensor]
lora_ids: List[int]

@classmethod
def empty(cls):
@@ -389,8 +386,7 @@ def empty(cls):
lora_prompt_mapping=[],
lora_requests=set(),
slot_mapping=[],
lora_mask=None,
lora_logits_mask=None)
lora_ids=[])


# How batches are constructed.
@@ -425,8 +421,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
real_batch_size: Optional[int] = None
batch_size_padded: Optional[int] = None
virtual_engine: int = 0
lora_mask: Optional[torch.Tensor] = None
lora_logits_mask: Optional[torch.Tensor] = None
lora_ids: Optional[List[int]] = None
async_callback: Optional[Callable] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
@@ -439,8 +434,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"real_batch_size": self.real_batch_size,
"batch_size_padded": self.batch_size_padded,
"virtual_engine": self.virtual_engine,
"lora_mask": self.lora_mask,
"lora_logits_mask": self.lora_logits_mask,
"lora_ids": self.lora_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@@ -474,8 +468,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"lora_mask": self.lora_mask,
"lora_logits_mask": self.lora_logits_mask,
"lora_ids": self.lora_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -836,49 +829,21 @@ def _prepare_prompt(
find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
self.block_size)

lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
counter = 0
if self.lora_config:
lora_mask = torch.zeros(
len(seq_group_metadata_list) * max_prompt_len,
(self.lora_config.max_loras) * self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
lora_logits_mask = torch.zeros(len(seq_group_metadata_list),
(self.lora_config.max_loras) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)

ones = torch.ones(max_prompt_len,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
logit_ones = torch.ones(1,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
lora_ids: List[int] = []
for seq_group_metadata, context_len in zip(seq_group_metadata_list,
context_lens):
lora_id = seq_group_metadata.lora_int_id
lora_ids.append(lora_id)

if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
start_row = counter * max_prompt_len
end_row = start_row + max_prompt_len
start_col = (lora_id - 1) * self.lora_config.max_lora_rank
end_col = start_col + self.lora_config.max_lora_rank
lora_mask[start_row:end_row, start_col:end_col] = ones
lora_logits_mask[counter, start_col:end_col] = logit_ones
counter = counter + 1

lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
lora_prompt_mapping.extend(
[lora_id] *
(max_prompt_len - context_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))

if lora_mask is not None:
lora_mask = lora_mask.to('hpu')
lora_logits_mask = lora_logits_mask.to('hpu')

input_tokens = make_tensor_with_pad(input_tokens,
max_len=max_prompt_len,
pad=0,
@@ -919,20 +884,17 @@ def _prepare_prompt(
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

return PreparePromptMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
slot_mapping=slot_mapping,
lora_mask=lora_mask,
lora_logits_mask=lora_logits_mask,
)
return PreparePromptMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
slot_mapping=slot_mapping,
lora_ids=lora_ids)

def _prepare_decode(
self,
@@ -949,18 +911,7 @@ def _prepare_decode(

if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
counter = 0

if self.lora_config:
lora_mask = torch.zeros(len(seq_group_metadata_list),
(self.lora_config.max_loras) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
ones = torch.ones(1,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
lora_ids: List[int] = []

dummy_slots = itertools.cycle(
range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
@@ -971,13 +922,10 @@

seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id
lora_ids.append(lora_id)

if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
start_pos = (lora_id - 1) * self.lora_config.max_lora_rank
end_pos = start_pos + self.lora_config.max_lora_rank
lora_mask[counter, start_pos:end_pos] = ones
counter = counter + 1

for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
@@ -1012,9 +960,6 @@ def _prepare_decode(
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)

if lora_mask is not None:
lora_mask = lora_mask.to('hpu')
lora_logits_mask = lora_mask
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
@@ -1075,17 +1020,14 @@ def _prepare_decode(
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
)
return PrepareDecodeMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
lora_mask=lora_mask,
lora_logits_mask=lora_logits_mask,
)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
lora_ids=lora_ids)

def prepare_input_tensors(
self,
@@ -1142,8 +1084,7 @@ def prepare_input_tensors(
lora_requests,
multi_modal_kwargs,
slot_mapping,
lora_mask,
lora_logits_mask,
lora_ids,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
@@ -1153,8 +1094,7 @@ def prepare_input_tensors(
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
decode_lora_mask,
decode_lora_logits_mask,
decode_lora_ids,
) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
seq_lens, query_lens,
@@ -1181,8 +1121,7 @@ def prepare_input_tensors(
lora_index_mapping = decode_lora_index_mapping
lora_prompt_mapping = decode_lora_prompt_mapping
lora_requests = decode_lora_requests
lora_mask = decode_lora_mask
lora_logits_mask = decode_lora_logits_mask
lora_ids = decode_lora_ids

# FIXME: We need to adjust selected_token_indices to accommodate
# for padding
@@ -1252,8 +1191,7 @@ def prepare_input_tensors(
multi_modal_kwargs=multi_modal_kwargs,
real_batch_size=real_batch_size,
batch_size_padded=batch_size_padded,
lora_mask=lora_mask,
lora_logits_mask=lora_logits_mask), \
lora_ids=lora_ids), \
sampling_metadata

def _seq_len(self, attn_metadata):
@@ -1853,6 +1791,76 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
phase, batch_size, seq_len)

def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
is_prompt: bool):
'''
This is a helper function to create the mask for lora computations.
Lora Mask is needed to ensure we match the correct lora weights for
the request.
For Prompt phase we have
lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
lora_logits_mask with shape (batch_size, max_loras * max_rank)
For Decode phase we have both
lora_mask and lora_logits_mask with shape
(batch_size, max_loras * max_rank)
'''
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
lora_index = 0

if self.lora_config:
if is_prompt:
lora_mask = torch.zeros(
input_tokens.shape[0] * input_tokens.shape[1],
(self.lora_config.max_loras) *\
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
lora_logits_mask = torch.zeros(
input_tokens.shape[0], (self.lora_config.max_loras) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)

ones = torch.ones(input_tokens.shape[1],
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
logit_ones = torch.ones(1,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)

for i in range(len(lora_ids)):
if lora_ids[i] == 0:
continue
lora_index = self.lora_manager._adapter_manager.\
lora_index_to_id.index(lora_ids[i])
start_row = i * input_tokens.shape[1]
end_row = start_row + input_tokens.shape[1]
start_col = lora_index * self.lora_config.max_lora_rank
end_col = start_col + self.lora_config.max_lora_rank
lora_mask[start_row:end_row, start_col:end_col] = ones
lora_logits_mask[i, start_col:end_col] = logit_ones
lora_mask = lora_mask.to('hpu')
lora_logits_mask = lora_logits_mask.to('hpu')
else:
lora_mask = torch.zeros(input_tokens.shape[0],
(self.lora_config.max_loras) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
ones = torch.ones(1,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
for i in range(len(lora_ids)):
if lora_ids[i] == 0:
continue
lora_index = self.lora_manager._adapter_manager.\
lora_index_to_id.index(lora_ids[i])
start_pos = lora_index * self.lora_config.max_lora_rank
end_pos = start_pos + self.lora_config.max_lora_rank
lora_mask[i, start_pos:end_pos] = ones
lora_mask = lora_mask.to('hpu')
lora_logits_mask = lora_mask

return lora_mask, lora_logits_mask

@torch.inference_mode()
def execute_model(
self,
@@ -1887,13 +1895,21 @@ def execute_model(
seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, is_prompt, warmup_mode)

lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
if self.lora_config:
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids, attn_metadata.is_prompt)

execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": model_input.lora_mask,
"lora_mask": lora_mask,
**(model_input.multi_modal_kwargs or {}),
}
if htorch.utils.internal.is_lazy():
@@ -1915,7 +1931,6 @@
)

if self.lora_config:
lora_logits_mask: torch.Tensor = model_input.lora_logits_mask
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))
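Editorial note (not part of the commit): a quick shape check of the masks described in create_lora_mask's docstring, using made-up values for batch_size, seq_len, max_loras, and max_lora_rank.

import torch

# Made-up sizes for illustration.
batch_size, seq_len = 4, 128
max_loras, max_lora_rank = 2, 8

# Prompt phase: one mask row per token plus a per-sequence logits mask.
prompt_lora_mask = torch.zeros(batch_size * seq_len, max_loras * max_lora_rank)
prompt_lora_logits_mask = torch.zeros(batch_size, max_loras * max_lora_rank)

# Decode phase: one row per sequence; the logits mask reuses the same tensor.
decode_lora_mask = torch.zeros(batch_size, max_loras * max_lora_rank)
decode_lora_logits_mask = decode_lora_mask

print(prompt_lora_mask.shape)         # torch.Size([512, 16])
print(prompt_lora_logits_mask.shape)  # torch.Size([4, 16])
print(decode_lora_mask.shape)         # torch.Size([4, 16])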
