Merge branch 'refs/heads/habana_main' into dev/dlester/mixtral_main_1
Tiefen-boop committed Aug 29, 2024
2 parents be7f696 + 17cd625 commit f2710c9
Showing 1 changed file with 79 additions and 14 deletions.
vllm/worker/habana_model_runner.py (93 changes: 79 additions & 14 deletions)
@@ -94,10 +94,46 @@ def warmup_range(config: Tuple[int, int, int]):
return list(ramp_up_tw) + list(stable)


def warmup_buckets(bs_bucket_config, seq_bucket_config):
buckets = itertools.product(warmup_range(bs_bucket_config),
warmup_range(seq_bucket_config))
return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
def warmup_buckets(bs_bucket_config, seq_bucket_config,
max_num_batched_tokens):
buckets = list(
itertools.product(warmup_range(bs_bucket_config),
warmup_range(seq_bucket_config)))
if len(buckets) == 0:
msg = ("No buckets could be captured with following config "
f"(min, step, max_warmup): "
f"bs:{bs_bucket_config}, "
f"seq:{seq_bucket_config}")
raise ValueError(msg)

# Remove buckets exceeding batch token budget
filtered_buckets = list(
filter(lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
buckets))

if len(filtered_buckets) == 0:
# legacy case - we can handle this if we ignore max_num_batched_tokens
min_bucket_bs, min_bucket_seq = min(buckets,
key=lambda b: (b[0] * b[1]))
min_reqd_budget = min_bucket_bs * min_bucket_seq
msg = (
"The current bucketing configuration "
f"(min, step, max_warmup): "
f"bs:{bs_bucket_config}, "
f"seq:{seq_bucket_config} cannot be used with specified "
f"max_num_batched_tokens ({max_num_batched_tokens}), as the "
f"smallest bucket ({min_reqd_budget}) would exceed token budget. "
"Please increase max_num_batched_tokens or decrease bucket minimum "
"Ignoring max_num_batched_tokens at risk of out-of-memory errors.")
logger.error(msg)
return list(sorted(buckets, key=lambda b:
(b[0] * b[1], b[1], b[0]))), []

captured_buckets = list(
sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
omitted_buckets = list(
sorted([x for x in buckets if x not in filtered_buckets]))
return captured_buckets, omitted_buckets


def next_pow2(value: int):
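
To make the budget filtering in the new warmup_buckets concrete, here is a minimal standalone sketch of the same idea: build the (batch_size, seq_len) grid, keep the buckets whose token footprint fits the budget, and report the rest as omitted. The expand() helper and the bucket values below are illustrative stand-ins, not the warmup_range() expansion used above.

import itertools

def expand(cfg):
    # stand-in for warmup_range(): here cfg is already an explicit list of values
    return list(cfg)

def split_by_token_budget(bs_values, seq_values, max_num_batched_tokens):
    buckets = list(itertools.product(expand(bs_values), expand(seq_values)))
    captured = [b for b in buckets if b[0] * b[1] <= max_num_batched_tokens]
    omitted = [b for b in buckets if b not in captured]
    # same ordering as the diff: smallest batch*seq footprint first
    captured.sort(key=lambda b: (b[0] * b[1], b[1], b[0]))
    return captured, sorted(omitted)

captured, omitted = split_by_token_budget([1, 2, 4], [128, 512, 2048], 2048)
print(captured)  # [(1, 128), (2, 128), (4, 128), (1, 512), (2, 512), (4, 512), (1, 2048)]
print(omitted)   # [(2, 2048), (4, 2048)]
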
@@ -525,8 +561,9 @@ def _setup_buckets(self) -> None:
f"bs:{self.prompt_bs_bucket_cfg}, "
f"seq:{self.prompt_seq_bucket_cfg}")
logger.info(msg)
self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
self.prompt_seq_bucket_cfg)
self.prompt_buckets, prompt_omitted_buckets = warmup_buckets(
self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)

if self.lora_config:
self.prompt_buckets[:] = [
@@ -538,12 +575,21 @@ def _setup_buckets(self) -> None:
f"prompt buckets: {list(sorted(self.prompt_buckets))}")
logger.info(msg)

msg = (f"Omitted {len(prompt_omitted_buckets)} "
"prompt buckets due to exceeded token budget "
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
logger.info(msg)

msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
logger.debug(msg)

msg = ("Decode bucket config (min, step, max_warmup) "
f"bs:{self.decode_bs_bucket_cfg}, "
f"seq:{self.decode_seq_bucket_cfg}")
logger.info(msg)
self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
self.decode_seq_bucket_cfg)
self.decode_buckets, decode_omitted_buckets = warmup_buckets(
self.decode_bs_bucket_cfg, self.decode_seq_bucket_cfg,
self.max_num_batched_tokens)
if self.lora_config:
self.decode_buckets[:] = [
bucket for bucket in self.decode_buckets
@@ -553,6 +599,14 @@ def _setup_buckets(self) -> None:
f"{list(sorted(self.decode_buckets))}")
logger.info(msg)

msg = (f"Omitted {len(decode_omitted_buckets)} "
"decode buckets due to exceeded token budget "
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
logger.info(msg)

msg = f"Omitted decode buckets: {list(sorted(decode_omitted_buckets))}"
logger.debug(msg)

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -958,8 +1012,13 @@ def prepare_input_tensors(
paddings = [max_len - s for s in seq_lens]
paddings = [0] + paddings[:-1]
paddings = list(itertools.accumulate(paddings))
paddings_prompt_logprobs = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
if seq_group_metadata.sampling_params.prompt_logprobs is not None \
and seq_group_metadata.is_prompt:
paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
paddings = torch.tensor(
paddings,
paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
dtype=sampling_metadata.selected_token_indices.dtype,
device=sampling_metadata.selected_token_indices.device)
sampling_metadata.selected_token_indices.add_(paddings)
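
The new paddings_prompt_logprobs branch handles sequence groups that request prompt logprobs: for those, the cumulative padding offset is repeated once per prompt token rather than once per sequence, matching the larger set of selected token indices. A toy illustration with made-up lengths and a hypothetical per-group flag, not the runner's real metadata objects:

import itertools

seq_lens = [3, 5]                       # prompt lengths of two sequence groups
max_len = 5                             # bucketed (padded) length
wants_prompt_logprobs = [True, False]   # hypothetical per-group flag

paddings = [max_len - s for s in seq_lens]       # [2, 0]
paddings = [0] + paddings[:-1]                   # [0, 2]
paddings = list(itertools.accumulate(paddings))  # [0, 2]

paddings_prompt_logprobs = []
for i, wants in enumerate(wants_prompt_logprobs):
    if wants:
        # one offset per prompt token of group i
        paddings_prompt_logprobs += [paddings[i]] * seq_lens[i]

# fall back to the per-sequence offsets when no group asked for prompt logprobs
print(paddings_prompt_logprobs or paddings)  # [0, 0, 0]
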
@@ -1441,6 +1500,15 @@ def get_counter_dict(self, cache_config, duration, seq_len,
return counters


def unwrap_model(model):
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
return unwrap_model(model._orig_mod)
else:
model = list(vars(model)['_modules'].values())[0]
modules = list(vars(model)['_modules'].values())
return modules


class HabanaModelRunner(
HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
"""
@@ -1558,13 +1626,10 @@ def execute_model(

if self.lora_config:
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
property = vars(self.model.model)
model = list(property['_modules'].values())[0]
property = vars(model)
modules = list(property['_modules'].values())
modules = unwrap_model(self.model.model)
for module in modules:
if isinstance(module, VocabParallelEmbeddingWithLoRA):
for i in range(0, 4):
for i in range(0, len(module.indices_len)):
module.indices_len[
i] = sampling_metadata.selected_token_indices.numel(
)
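
The last hunk also drops the hardcoded range(0, 4) in favor of iterating over len(module.indices_len), so every tracked length is refreshed regardless of how many entries the LoRA layer keeps. A toy version of that update, using a hypothetical stand-in class rather than vLLM's VocabParallelEmbeddingWithLoRA:

class FakeLoRAEmbedding:                 # hypothetical stand-in, not the vLLM class
    def __init__(self, n):
        self.indices_len = [0] * n       # number of tracked lengths can vary

modules = [FakeLoRAEmbedding(4), FakeLoRAEmbedding(6)]
selected_token_count = 42                # stand-in for selected_token_indices.numel()

for module in modules:
    # refresh every entry, not just the first four
    for i in range(len(module.indices_len)):
        module.indices_len[i] = selected_token_count

print([m.indices_len for m in modules])
# [[42, 42, 42, 42], [42, 42, 42, 42, 42, 42]]
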
