Add padding-aware scheduling
kzawora-intel committed Oct 15, 2024
1 parent 9777c9f commit 38b044b
Showing 4 changed files with 257 additions and 54 deletions.
18 changes: 17 additions & 1 deletion vllm/config.py
@@ -928,6 +928,9 @@ class SchedulerConfig:
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
max_num_prefill_seqs: Maximum number of prefill sequences to be
processed in a single iteration. Used only with padding-aware
scheduling.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
@@ -951,11 +954,14 @@ class SchedulerConfig:
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1
policy: The scheduling policy to use. "fcfs" (default) or "priority".
        use_padding_aware_scheduling: If True, the scheduler will account for
            padded tokens when scheduling prefills.
"""

def __init__(self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_num_prefill_seqs: Optional[int],
max_model_len: int,
use_v2_block_manager: bool = True,
num_lookahead_slots: int = 0,
@@ -967,7 +973,8 @@ def __init__(self,
num_scheduler_steps: int = 1,
multi_step_stream_outputs: bool = False,
send_delta_data: bool = False,
policy: str = "fcfs") -> None:
policy: str = "fcfs",
                 use_padding_aware_scheduling: bool = False) -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
if num_scheduler_steps > 1:
@@ -1006,6 +1013,7 @@ def __init__(self,
self.max_num_batched_tokens)

self.max_num_seqs = max_num_seqs
self.max_num_prefill_seqs = max_num_prefill_seqs
self.max_model_len = max_model_len
self.use_v2_block_manager = use_v2_block_manager
self.num_lookahead_slots = num_lookahead_slots
@@ -1017,6 +1025,7 @@ def __init__(self,
self.multi_step_stream_outputs = multi_step_stream_outputs
self.send_delta_data = send_delta_data
self.policy = policy
self.use_padding_aware_scheduling = use_padding_aware_scheduling
self._verify_args()

def _verify_args(self) -> None:
@@ -1047,6 +1056,13 @@ def _verify_args(self) -> None:
"num_scheduler_steps "
f"({self.num_scheduler_steps}) must be greater than or "
"equal to 1.")
if self.max_num_prefill_seqs is not None \
and not self.use_padding_aware_scheduling:
            raise ValueError("max_num_prefill_seqs can only be used with "
                             "padding-aware scheduling.")
if self.use_padding_aware_scheduling and self.chunked_prefill_enabled:
            raise ValueError("Padding-aware scheduling currently "
                             "does not work with chunked prefill.")

@property
def is_multi_step(self) -> bool:
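
For reference, a minimal sketch of how the new validation in _verify_args behaves. The argument values are illustrative and assume SchedulerConfig can be constructed directly with the defaults shown above; other checks in _verify_args (e.g., the token budget vs. max_model_len) may also apply.

from vllm.config import SchedulerConfig

# Valid: max_num_prefill_seqs together with padding-aware scheduling.
config = SchedulerConfig(
    max_num_batched_tokens=4096,
    max_num_seqs=128,
    max_num_prefill_seqs=16,
    max_model_len=2048,
    use_padding_aware_scheduling=True,
)

# Invalid: max_num_prefill_seqs without padding-aware scheduling trips
# the new check in _verify_args().
try:
    SchedulerConfig(
        max_num_batched_tokens=4096,
        max_num_seqs=128,
        max_num_prefill_seqs=16,
        max_model_len=2048,
    )
except ValueError as err:
    print(err)
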
146 changes: 140 additions & 6 deletions vllm/core/scheduler.py
@@ -11,6 +11,7 @@
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
@@ -101,6 +102,118 @@ def num_curr_seqs(self):
return self._num_curr_seqs


@dataclass
class PaddingAwareSchedulingBudget(SchedulingBudget):
max_num_prefill_seqs: Optional[int] = None
_prefill_request_ids_max_seq_lens: Dict[str,
int] = field(default_factory=dict)
_max_seq_len: int = 0
_num_curr_prefill_seqs: int = 0

def _debug_validator(self):
assert self._num_curr_prefill_seqs == len(
self._prefill_request_ids_max_seq_lens)
if self._num_curr_prefill_seqs != 0:
assert self._max_seq_len == max(
self._prefill_request_ids_max_seq_lens.values())

def _generic_padding_fn(self, batch_size, max_seq_len) -> int:
return batch_size * max_seq_len

def _hpu_padding_fn(self, batch_size, max_seq_len):
from vllm.worker.hpu_model_runner import (find_bucket,
HPUBucketingGlobalState)
padded_bs = batch_size
padded_seq = max_seq_len

hpu_bucketing_global_state = HPUBucketingGlobalState()

bs_cfg = hpu_bucketing_global_state.prompt_bs_bucket_cfg
if bs_cfg is not None:
padded_bs = find_bucket(batch_size, bs_cfg)
else:
logger.warning(
"prompt_bs_bucket_cfg was not set! Using unpadded batch size.")
seq_cfg = hpu_bucketing_global_state.prompt_seq_bucket_cfg
if seq_cfg is not None:
padded_seq = find_bucket(max_seq_len, seq_cfg)
else:
logger.warning("prompt_seq_bucket_cfg was not set! "
"Using unpadded sequence length.")
return padded_bs * padded_seq

def _padding_fn_selector(self):
if current_platform.is_hpu():
return self._hpu_padding_fn
return self._generic_padding_fn

def _maybe_update_max_seq_len(self,
new_seq_max_seq_len: Optional[int] = None):
if new_seq_max_seq_len is not None \
and new_seq_max_seq_len > self._max_seq_len:
self._max_seq_len = new_seq_max_seq_len
return
        self._max_seq_len = max(
            self._prefill_request_ids_max_seq_lens.values(), default=0)

def add_prefill_seqs(self, req_id, num_curr_prefill_seqs, max_seq_len):
self._prefill_request_ids_max_seq_lens[req_id] = max_seq_len
self._num_curr_prefill_seqs += num_curr_prefill_seqs
self._maybe_update_max_seq_len(max_seq_len)
self._debug_validator()

def subtract_prefill_seqs(self, req_id, num_curr_prefill_seqs):
if req_id in self._prefill_request_ids_max_seq_lens:
popped_seq_len = self._prefill_request_ids_max_seq_lens.pop(req_id)
self._num_curr_prefill_seqs -= num_curr_prefill_seqs
if popped_seq_len == self._max_seq_len:
self._maybe_update_max_seq_len()
self._debug_validator()

def can_schedule(self,
*args,
num_new_tokens: int,
num_new_seqs: int,
is_prefill: bool = False,
max_seq_len: int = 0):
can_parent_schedule = super().can_schedule(
*args, num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs)
if not can_parent_schedule or not is_prefill:
return can_parent_schedule
new_batch_size = self._num_curr_prefill_seqs + num_new_seqs
new_max_seq_len = max(max(self._max_seq_len, max_seq_len), 1)
padding_fn = self._padding_fn_selector()
num_new_padded_tokens = padding_fn(new_batch_size, new_max_seq_len)
self._debug_validator()
result = self.num_batched_tokens + num_new_padded_tokens \
<= self.token_budget
curr_padded_tokens = padding_fn(self._num_curr_prefill_seqs,
self._max_seq_len)
stats = f"curr_batch_size: {self._num_curr_prefill_seqs}, curr_max_seq_len: {self._max_seq_len}, curr_num_padded_tokens: {curr_padded_tokens} | new_batch_size: {new_batch_size}, new_max_seq_len: {new_max_seq_len}, new_num_padded_tokens: {num_new_padded_tokens}" # noqa: E501
if not result:
msg = f"[PaddingAwareScheduler DEBUG] CANNOT schedule prefill sequence. Reason: Exceeded token budget ({self.token_budget}) after padding. | {stats}" # noqa: E501
logger.info(msg)
if self.max_num_prefill_seqs is not None and result:
result = self._num_curr_prefill_seqs + num_new_seqs \
<= self.max_num_prefill_seqs
if not result:
msg = f"[PaddingAwareScheduler DEBUG] CANNOT schedule prefill sequence. Reason: Exceeded max_num_prefill_seqs ({self.max_num_prefill_seqs}) after padding. | {stats}" # noqa: E501
logger.info(msg)
#else:
# msg = f"[PaddingAwareScheduler DEBUG] CAN schedule sequence. | {stats}" # noqa: E501
# logger.info(msg)
self._debug_validator()
return result

@property
def max_seq_len(self):
return self._max_seq_len

@property
def num_curr_prefill_seqs(self):
return self._num_curr_prefill_seqs


@dataclass
class ScheduledSequenceGroup:
# A sequence group that's scheduled.
@@ -937,9 +1050,18 @@ def _schedule_prefills(
continue

num_new_seqs = seq_group.get_max_num_running_seqs()
max_prefill_seq_len = None
can_schedule_kwargs = {
'num_new_tokens': num_new_tokens,
'num_new_seqs': num_new_seqs
}
if self.scheduler_config.use_padding_aware_scheduling:
max_prefill_seq_len = max(
[seq.get_num_new_tokens() for seq in seq_group.get_seqs()])
can_schedule_kwargs['is_prefill'] = True
can_schedule_kwargs['max_seq_len'] = max_prefill_seq_len
if (num_new_tokens == 0
or not budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
or not budget.can_schedule(**can_schedule_kwargs)):
break

# Can schedule this request.
@@ -970,6 +1092,10 @@ def _schedule_prefills(
token_chunk_size=num_new_tokens))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
if self.scheduler_config.use_padding_aware_scheduling:
assert isinstance(budget, PaddingAwareSchedulingBudget)
budget.add_prefill_seqs(seq_group.request_id, num_new_seqs,
max_prefill_seq_len)

# Queue requests that couldn't be scheduled.
waiting_queue.extendleft(leftover_waiting_sequences)
@@ -991,10 +1117,18 @@ def _schedule_default(self) -> SchedulerOutputs:
be swapped or preempted.
"""
# Include running requests to the budget.
budget = SchedulingBudget(
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
budget: SchedulingBudget
if self.scheduler_config.use_padding_aware_scheduling:
budget = PaddingAwareSchedulingBudget(
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
max_num_prefill_seqs=self.scheduler_config.max_num_prefill_seqs
)
else:
budget = SchedulingBudget(
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
# Make sure we include num running seqs before scheduling prefill,
# so that we don't schedule beyond max_num_seqs for prefill.
for seq_group in self.running:
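
As a standalone illustration (not the vLLM class itself) of the accounting that PaddingAwareSchedulingBudget.can_schedule adds on top of the plain token budget: without HPU bucketing, the generic padding function charges every prefill in the batch as if it were padded to the longest sequence. The numbers below are made up.

# Standalone sketch of the generic (non-HPU) padding-aware check.

def padded_prefill_tokens(batch_size: int, max_seq_len: int) -> int:
    # Mirrors _generic_padding_fn: every sequence is charged at the
    # length of the longest prefill in the batch.
    return batch_size * max_seq_len

token_budget = 2048            # max_num_batched_tokens
prompt_lens = [100, 400, 700]  # three waiting prefill prompts

actual_tokens = sum(prompt_lens)                         # 1200
padded_tokens = padded_prefill_tokens(len(prompt_lens),
                                      max(prompt_lens))  # 3 * 700 = 2100

print(actual_tokens <= token_budget)  # True  - a plain token count would fit
print(padded_tokens <= token_budget)  # False - the padded batch does not

On HPU, _hpu_padding_fn additionally rounds the batch size and the longest sequence length up to the nearest bucket via find_bucket, so the padded estimate can only grow; the comparison against the token budget is otherwise the same.
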
19 changes: 18 additions & 1 deletion vllm/engine/arg_utils.py
@@ -113,11 +113,13 @@ class EngineArgs:
enable_prefix_caching: bool = False
disable_sliding_window: bool = False
use_v2_block_manager: bool = True
use_padding_aware_scheduling: bool = False
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_num_prefill_seqs: Optional[int] = None
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
revision: Optional[str] = None
@@ -387,6 +389,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action='store_true',
help='Use BlockSpaceMangerV2. By default this is set to True. '
'Set to False to use BlockSpaceManagerV1')
parser.add_argument(
'--use-padding-aware-scheduling',
default=EngineArgs.use_padding_aware_scheduling,
action='store_true',
            help=('Use padding-aware scheduling. If set, the scheduler '
                  'will account for padded tokens when scheduling '
                  'prefills. By default this is set to False.'))
parser.add_argument(
'--num-lookahead-slots',
type=int,
@@ -441,6 +450,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument(
'--max-num-prefill-seqs',
type=int,
default=EngineArgs.max_num_prefill_seqs,
help=('Maximum number of prefill sequences per '
'iteration. Can be used only with padding-aware '
'scheduling. Must be <= max_num_seqs.'))
parser.add_argument(
'--max-logprobs',
type=int,
@@ -1033,6 +1049,7 @@ def create_engine_config(self) -> EngineConfig:
scheduler_config = SchedulerConfig(
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_num_prefill_seqs=self.max_num_prefill_seqs,
max_model_len=model_config.max_model_len,
use_v2_block_manager=self.use_v2_block_manager,
num_lookahead_slots=num_lookahead_slots,
Expand All @@ -1046,7 +1063,7 @@ def create_engine_config(self) -> EngineConfig:
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
policy=self.scheduling_policy,
)
use_padding_aware_scheduling=self.use_padding_aware_scheduling)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
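
A hypothetical usage sketch of the new options through EngineArgs, equivalent to passing --use-padding-aware-scheduling and --max-num-prefill-seqs 8 on the command line. The model name and values are placeholders, and building the engine config requires the model to be resolvable.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",   # placeholder model
    max_num_seqs=64,
    max_num_prefill_seqs=8,      # must be <= max_num_seqs
    use_padding_aware_scheduling=True,
)

engine_config = engine_args.create_engine_config()
scheduler_config = engine_config.scheduler_config
print(scheduler_config.use_padding_aware_scheduling)  # True
print(scheduler_config.max_num_prefill_seqs)          # 8
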