Warmup for multi-step scheduling
tzielinski-habana committed Nov 14, 2024
1 parent eca9a83 commit 64c7139
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions vllm/worker/hpu_model_runner.py
@@ -1491,7 +1491,27 @@ def warmup_scenario(self,
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, kv_caches, warmup_mode=True)
+            is_single_step = \
+                self.vllm_config.scheduler_config.num_scheduler_steps == 1
+            if is_prompt or is_single_step:
+                self.execute_model(inputs, kv_caches, warmup_mode=True)
+            else:  # decode with multi-step
+                inputs = dataclasses.replace(inputs,
+                                             is_first_multi_step=True,
+                                             is_last_step=False)
+                self.execute_model(inputs,
+                                   kv_caches,
+                                   warmup_mode=True,
+                                   num_steps=2,
+                                   seqs=seqs)
+                inputs = dataclasses.replace(inputs,
+                                             is_first_multi_step=False,
+                                             is_last_step=True)
+                self.execute_model(inputs,
+                                   kv_caches,
+                                   warmup_mode=True,
+                                   num_steps=2,
+                                   seqs=seqs)
             torch.hpu.synchronize()
             if profiler:
                 profiler.step()
@@ -2019,6 +2039,7 @@ def execute_model(
         num_steps: int = 1,
         warmup_mode=False,
         previous_hidden_states: Optional[torch.Tensor] = None,
+        seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
@@ -2164,9 +2185,16 @@ def try_revert_dummy_output_tokens():
                 htorch.core.mark_step()
                 if i < num_steps - 1:
                     if i == 0:
-                        ctx = model_input.async_callback.keywords[  # type: ignore
-                            "ctx"]
-                        seq_group_metadata_list = ctx.seq_group_metadata_list
+                        if model_input.async_callback is not None:
+                            ctx = model_input.async_callback.keywords[  # type: ignore
+                                "ctx"]
+                            seq_group_metadata_list = \
+                                ctx.seq_group_metadata_list
+                        elif seqs is not None:
+                            seq_group_metadata_list = seqs
+                        else:
+                            raise RuntimeError(
+                                "seq_group_metadata_list is uninitialized")
                         # Cache the original output token ids
                         for i, seq_group_metadata in enumerate(
                                 seq_group_metadata_list):
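For reference, a minimal self-contained sketch of the pattern this commit introduces in warmup_scenario: when multi-step scheduling is enabled, a decode warmup is driven as two execute_model calls, with dataclasses.replace toggling is_first_multi_step/is_last_step on an otherwise identical model input, and seqs forwarded so the multi-step loop can obtain seq_group_metadata_list when no async callback is attached. FakeModelInput, fake_execute_model, and warmup_decode_multi_step below are illustrative stand-ins, not vLLM APIs.

# Illustrative stand-ins only -- not the vLLM classes. This traces the
# control flow added by the commit: first call flagged as the first (but
# not last) multi-step iteration, second call flagged as the last step.
import dataclasses
from typing import List


@dataclasses.dataclass
class FakeModelInput:  # stand-in for the HPU model input dataclass
    token_ids: List[int]
    is_first_multi_step: bool = True
    is_last_step: bool = True


def fake_execute_model(inputs: FakeModelInput, *, warmup_mode: bool,
                       num_steps: int = 1, seqs=None) -> None:
    # A real runner would launch the model here; we only trace the flags.
    print(f"execute_model(first={inputs.is_first_multi_step}, "
          f"last={inputs.is_last_step}, num_steps={num_steps}, "
          f"warmup={warmup_mode}, seqs={'set' if seqs is not None else 'None'})")


def warmup_decode_multi_step(inputs: FakeModelInput, seqs) -> None:
    # Phase 1: first multi-step iteration, not yet the last one.
    inputs = dataclasses.replace(inputs,
                                 is_first_multi_step=True,
                                 is_last_step=False)
    fake_execute_model(inputs, warmup_mode=True, num_steps=2, seqs=seqs)
    # Phase 2: follow-up iteration, flagged as the last step so the
    # runner can finalize the multi-step sequence.
    inputs = dataclasses.replace(inputs,
                                 is_first_multi_step=False,
                                 is_last_step=True)
    fake_execute_model(inputs, warmup_mode=True, num_steps=2, seqs=seqs)


if __name__ == "__main__":
    warmup_decode_multi_step(FakeModelInput(token_ids=[1, 2, 3]),
                             seqs=["dummy_seq_group_metadata"])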
