Add option to limit number of buckets #156

Closed
wants to merge 28 commits into from

Commits (28)
73390bf
Remove redundant torch.device
kzawora-intel Jul 30, 2024
1546cd5
Add multiprocessing HPU executor
kzawora-intel Jul 31, 2024
5fe745d
Merge remote-tracking branch 'origin/habana_main' into private/kzawor…
kzawora-intel Jul 31, 2024
5b5dab1
typo
kzawora-intel Jul 31, 2024
b0c4d4c
Add Gaudi documentation for 1.17
kzawora-intel Aug 1, 2024
f532da1
Merge remote-tracking branch 'origin/habana_main' into HEAD
kzawora-intel Aug 5, 2024
879dc38
WIP: Add bucket limit functionality
kzawora-intel Aug 5, 2024
77019dd
functional fixes
kzawora-intel Aug 5, 2024
0813521
Revert "Remove redundant torch.device"
kzawora-intel Aug 5, 2024
2826023
Revert "Add multiprocessing HPU executor"
kzawora-intel Aug 5, 2024
088f372
Revert "typo"
kzawora-intel Aug 5, 2024
5829434
Revert "Add Gaudi documentation for 1.17"
kzawora-intel Aug 5, 2024
4d1fc07
Revert "Add multiprocessing HPU executor"
kzawora-intel Aug 5, 2024
9321afe
Revert "Revert "Remove redundant torch.device""
kzawora-intel Aug 5, 2024
887fe50
format.sh
kzawora-intel Aug 5, 2024
4298c92
return list
kzawora-intel Aug 5, 2024
4e8a6aa
Update habana_model_runner.py
kzawora-intel Aug 14, 2024
b03ef68
Update habana_model_runner.py
kzawora-intel Aug 14, 2024
9ade5a7
Update habana_model_runner.py
kzawora-intel Aug 14, 2024
c4a2f2e
Update habana_model_runner.py
kzawora-intel Aug 14, 2024
f33a653
Update habana_model_runner.py
kzawora-intel Aug 14, 2024
5a7247c
Update habana_model_runner.py
kzawora-intel Aug 14, 2024
b84fa26
Update habana_model_runner.py
kzawora-intel Aug 14, 2024
3ed24b3
Merge branch 'habana_main' into private/kzawora/bucketing_limit
adobrzyniewicz-habana Aug 27, 2024
8a1318f
token budgeting integration
kzawora-intel Aug 27, 2024
14d7ef7
i ve made an oopsie woopsie
kzawora-intel Aug 27, 2024
c113c13
i enjoy typing
kzawora-intel Aug 27, 2024
174b706
fix case where max_num_batched_tokens is not divisible by bucket batc…
kzawora-intel Aug 27, 2024
278 changes: 229 additions & 49 deletions vllm/worker/habana_model_runner.py
@@ -15,6 +15,7 @@
Optional, Set, Tuple, Type, TypeVar, Union)

import habana_frameworks.torch as htorch
import numpy as np
import torch

from vllm.attention import AttentionMetadata, get_attn_backend
@@ -60,7 +61,7 @@ def read_bucket_settings(phase: str, dim: str, **defaults):
param is either 'min', 'step', 'max' or 'limit'
example env variable: VLLM_DECODE_BS_BUCKET_STEP=128
"""
params = ['min', 'step', 'max']
params = ['min', 'step', 'max', 'limit']
values = [
int(
os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(),
@@ -94,14 +95,169 @@ def warmup_range(config: Tuple[int, int, int]):
return list(ramp_up_tw) + list(stable)


def warmup_buckets(bs_bucket_config, seq_bucket_config,
max_num_batched_tokens):
buckets = itertools.product(warmup_range(bs_bucket_config),
warmup_range(seq_bucket_config))
# Remove buckets exceeding batch token budget
filtered_buckets = filter(
lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
buckets)
def warmup_range_with_limit(config: Tuple[int, int, int, int], fill=True):
"""
NOTE(kzawora): we use exponential spacing for buckets, where the scaled
power returns bmin for the first bucket iteration and bmax for the last
iteration, with the elements in between determined by the exponent, while
the base stays unchanged. Note that after padding to bstep, duplicates may
occur. Handling of duplicates is configured by the fill parameter.

If fill is False, duplicates are removed and fewer buckets are returned.

If fill is True, duplicates are resolved by selecting the closest (larger
or smaller) unused bucket. If duplicate resolution is not possible, fewer
buckets are returned. In that case, buckets are guaranteed to be linearly
spaced.

Example (bmin=128, bstep=128, bmax=2048, num_buckets=10):

There are 16 possible buckets (2048/128), and we'll attempt to select 10 of
them with exponential spacing.
base = (bmax/bmin) ** (1/(num_buckets-1)); (2048/128) ** (1/9) = 1.36079
exponent = i
power = base ** exponent
scaled_power = b_min * power

For i == 0 (first bucket), power is 1.36079 ** 0 = 1;
scaled_power is 1 * 128 = 128 (==bmin)
For i == 9 (last bucket), power is 1.36079 ** 9 = 16;
scaled_power is 16 * 128 = 2048 (==bmax)

So, computing for all buckets:
scaled_powers_unpadded = [bmin*base^0(==bmin), bmin*base^1, bmin*base^2, ..., bmin*base^9(==bmax)]
scaled_powers_unpadded = [128.00, 174.18, 237.02, 322.54, 438.91, 597.26, 812.75, 1105.98, 1505.01, 2048.00]

if fill is False:
scaled_powers_padded = [ 128, 256, 256, 384, 512, 640, 896, 1152, 1536, 2048]
^_______^
duplicates

buckets = [ 128, 256, 384, 512, 640, 896, 1152, 1536, 2048]
^
duplicate bucket removed

len(buckets) = 9, num_buckets = 10

if fill is True:
buckets = [ 128, 256, 384, 512, 640, 768, 896, 1152, 1536, 2048]
^_______^_______^_______^
closest unused buckets selected
^_______^_______^
these become duplicates once previous duplicates are resolved

In this case we'll have four duplicated buckets:

174.18 -> 256, optimal bucket,
237.02 -> (256) -> 384, taking closest available bucket,
as optimal bucket 256 was already captured by 174.18,
322.54 -> (384) -> 512, taking closest available bucket,
as optimal bucket 384 was already captured by 237.02,
438.91 -> (512) -> 640, taking closest available bucket,
as optimal bucket 512 was already captured by 322.54,
597.26 -> (640) -> 768, taking closest available bucket,
as optimal bucket 640 was already captured by 438.91,
812.75 -> 896, optimal bucket

len(buckets) = 10, num_buckets = 10

In this case, the end result has the same buckets as fill=False,
but with additional bucket 768 added.
The difference is more pronounced for larger ranges and larger number
of buckets.

""" # noqa: E501

bmin, bstep, bmax, num_buckets = config
linear_buckets = set(np.arange(bmin, bmax + 1, step=bstep))
assert num_buckets > 0, "num_buckets must be a positive integer"
if num_buckets == 1:
return [bmax]
buckets: Set[int] = set()
for i in range(num_buckets):
power_unpadded = bmin * np.float_power(
bmax / bmin, (1. / float(num_buckets - 1)) * i)
bucket = math.ceil(power_unpadded / bstep) * bstep
if fill and bucket in buckets:
available_buckets = linear_buckets.difference(buckets)
if len(available_buckets) == 0:
break # there are no more unique buckets, let's exit now
new_bucket = min(available_buckets,
key=lambda x: abs(x - power_unpadded))
buckets.add(new_bucket)
else:
buckets.add(bucket)
return list(sorted(buckets))
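
As a sanity check on the docstring example, here is a minimal standalone sketch (illustrative only, not part of the diff; it assumes only the standard library's math module and mirrors the exponent computation used above) that reproduces the unpadded and padded values for bmin=128, bstep=128, bmax=2048, num_buckets=10:

import math

# Exponential spacing, then padding each value up to a multiple of bstep.
bmin, bstep, bmax, num_buckets = 128, 128, 2048, 10
unpadded = [
    bmin * (bmax / bmin) ** ((1.0 / (num_buckets - 1)) * i)
    for i in range(num_buckets)
]
padded = [math.ceil(p / bstep) * bstep for p in unpadded]
print([round(p, 2) for p in unpadded])
# [128.0, 174.18, 237.02, 322.54, 438.91, 597.26, 812.75, 1105.98, 1505.01, 2048.0]
print(padded)
# [128, 256, 256, 384, 512, 640, 896, 1152, 1536, 2048] -- duplicate 256, resolved per fill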


def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens,
block_size, max_model_len):
# NOTE(kzawora): If either max_model_len or max_num_batched_tokens is
# not divisible by block_size, we round it up here to block_size

bs_buckets = warmup_range(bs_bucket_config[:3])
seq_len_buckets = warmup_range(seq_bucket_config[:3])
if bs_bucket_config[3] != 0 and len(bs_buckets) > bs_bucket_config[3]:
bs_buckets = warmup_range_with_limit(bs_bucket_config)
if seq_bucket_config[3] != 0 and len(
seq_len_buckets) > seq_bucket_config[3]:
seq_len_buckets = warmup_range_with_limit(seq_bucket_config)

buckets = list(itertools.product(bs_buckets, seq_len_buckets))

# Remove buckets that either exceed the token budget
# or whose sequence length exceeds the max model length
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * bucket[1] <= math.ceil(
max_num_batched_tokens / block_size) * block_size and bucket[1]
<= max_model_len, buckets))

# NOTE(kzawora): We need to make sure we cover the largest valid scenario
# (max_num_batched_tokens/bs). If we drop out-of-bound buckets, we might
# be left with a last bucket that's too small and encounter recompilations
# when exceeding that bucket while still being in the valid range.

# Example:
# block_size = 128
# max_model_len = 768 (6 blocks)
# max_num_batched_tokens = 1536 (12 blocks)
# bs_buckets = [1, 2, 4, 5]
# seq_len_buckets = [128, 256, 512, 1024]
# buckets = [
# (1,128), (1,256), (1,512), (1,1024), # last one is invalid, 1024 > 768
# (2,128), (2,256), (2,512), (2,1024) # last one is invalid, 2048 > 1536
# (4,128), (4,256), (4,512), (4,1024) # last two are invalid, 2048 > 1536
# (5,128), (5,256), (5,512), (5,1024) # last two are invalid, 2560 > 1536
# ]
# filtered_buckets = [
# (1,128), (1,256), (1,512), # 512-768 is valid and not bucketed
# (2,128), (2,256), (2,512), # 512-768 is valid and not bucketed
# (4,128), (4,256), # 256-384 is valid and not bucketed
# (5,128), (5,256), # 256-307 is valid and not bucketed
# ]
# boundary_buckets = [
# (1,768), # covers 512-768 range, constrained by max_model_len
# (2,768) # covers 512-768 range, constrained by budget and max_model_len
# (4,384) # covers 256-384 range, constrained by token budget
# (5,384) # covers 256-307 range, constrained by token budget,
# # rounded up to block_size(=128)
# ]
# While the (5,384) bucket exceeds the token budget of 1536 (5*384=1920),
# it is the smallest bucket that can handle all valid use cases on bs=5,
# as all buckets for decode phase need to be aligned to block_size.

# Generate the largest valid bucket for all valid batch sizes.
# DO NOT use bs_buckets in here as it can include invalid batches
valid_bs_buckets = set(list(zip(*filtered_buckets))[0])
boundary_buckets = [
(bs,
min(
math.ceil(max_model_len / block_size) * block_size,
math.ceil(max_num_batched_tokens / (bs * block_size)) *
block_size)) for bs in valid_bs_buckets
]
filtered_buckets = list(set([*filtered_buckets, *boundary_buckets]))

return list(
sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
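
A minimal sketch of just the boundary-bucket step, plugging in the example values from the comment above (illustrative only; the variable names mirror the code but the inputs are the worked example's):

import math

# Boundary buckets for the worked example: one per valid batch size.
block_size, max_model_len, max_num_batched_tokens = 128, 768, 1536
valid_bs_buckets = [1, 2, 4, 5]
boundary_buckets = [
    (bs,
     min(math.ceil(max_model_len / block_size) * block_size,
         math.ceil(max_num_batched_tokens / (bs * block_size)) * block_size))
    for bs in valid_bs_buckets
]
print(boundary_buckets)  # [(1, 768), (2, 768), (4, 384), (5, 384)]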

@@ -118,15 +274,6 @@ def round_up(value: int, k: int):
return (value + k - 1) // k * k


def find_bucket(value: int, config: Tuple[int, int, int]):
bmin, bstep, bmax = config
if value < bstep:
result = min(next_pow2(value), bstep)
else:
result = round_up(value, bstep)
return result


def subtuple(obj: object,
typename: str,
to_copy: List[str],
@@ -509,31 +656,39 @@ def _setup_buckets(self) -> None:
step=32,
max=min(
self.max_num_seqs,
max_bucket_cfg))
max_bucket_cfg),
limit=0)
self.decode_bs_bucket_cfg = read_bucket_settings('decode',
'bs',
min=1,
step=128,
max=self.max_num_seqs)
self.prompt_seq_bucket_cfg = read_bucket_settings('prompt',
'seq',
min=self.block_size,
step=self.block_size,
max=1024)
self.decode_seq_bucket_cfg = read_bucket_settings('decode',
'seq',
min=self.block_size,
step=self.block_size,
max=2048)
max=self.max_num_seqs,
limit=0)
self.prompt_seq_bucket_cfg = read_bucket_settings(
'prompt',
'seq',
min=self.block_size,
step=self.block_size,
max=self.max_model_len,
limit=0)
self.decode_seq_bucket_cfg = read_bucket_settings(
'decode',
'seq',
min=self.block_size,
step=self.block_size,
max=self.max_model_len,
limit=0)
self.graphed_buckets: Set[Any] = set()

msg = ("Prompt bucket config (min, step, max_warmup) "
msg = ("Prompt bucket config (min, step, max_warmup, limit) "
f"bs:{self.prompt_bs_bucket_cfg}, "
f"seq:{self.prompt_seq_bucket_cfg}")
logger.info(msg)
self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
self.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)
self.max_num_batched_tokens,
self.block_size,
self.max_model_len)

if self.lora_config:
self.prompt_buckets[:] = [
@@ -545,18 +700,32 @@ def _setup_buckets(self) -> None:
f"prompt buckets: {list(sorted(self.prompt_buckets))}")
logger.info(msg)

msg = ("Decode bucket config (min, step, max_warmup) "
msg = ("Decode bucket config (min, step, max_warmup, limit) "
f"bs:{self.decode_bs_bucket_cfg}, "
f"seq:{self.decode_seq_bucket_cfg}")
logger.info(msg)
self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
self.decode_seq_bucket_cfg,
self.max_num_batched_tokens)
self.max_num_batched_tokens,
self.block_size,
self.max_model_len)
if self.lora_config:
self.decode_buckets[:] = [
bucket for bucket in self.decode_buckets
if self._is_valid_bucket(bucket)
]

# Find the smallest warmed-up bucket value that is >= x along the given
# dimension (0 = batch size, 1 = sequence length) of the prompt or
# decode buckets.
find_bucket = lambda bucket, x: next(p for p in sorted(bucket)
if p >= x)
get_bucket_dim = lambda bucket, dim: [b[dim] for b in bucket]
self.find_bs_bucket = lambda x, is_prompt: find_bucket(
get_bucket_dim(
self.prompt_buckets
if is_prompt else self.decode_buckets, 0), x)
self.find_seq_bucket = lambda x, is_prompt: find_bucket(
get_bucket_dim(
self.prompt_buckets
if is_prompt else self.decode_buckets, 1), x)
msg = (f"Generated {len(self.decode_buckets)} decode buckets: "
f"{list(sorted(self.decode_buckets))}")
logger.info(msg)
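
The lookup helpers above reduce bucket selection to "smallest warmed-up value >= x". A toy illustration of the same logic, with assumed bucket values that are not taken from the diff:

# Toy illustration of the new bucket lookup.
buckets = [(1, 128), (2, 128), (4, 256), (8, 256)]
bs_values = [b[0] for b in buckets]  # [1, 2, 4, 8]
find_bucket = lambda values, x: next(p for p in sorted(values) if p >= x)
print(find_bucket(bs_values, 3))  # 4 -- a batch of 3 is padded up to the bs=4 bucket
# Note: next() raises StopIteration if x exceeds the largest bucket value.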
Expand Down Expand Up @@ -688,9 +857,8 @@ def _prepare_prompt(
multi_modal_input = None

max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
max_prompt_len = max(
find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
self.block_size)
max_prompt_len = max(self.find_seq_bucket(max(seq_lens), True),
self.block_size)

for seq_group_metadata, context_len in zip(seq_group_metadata_list,
context_lens):
@@ -896,9 +1064,7 @@ def prepare_input_tensors(
self.profiler.start('internal', base_event_name)

real_batch_size = len(seq_group_metadata_list)
bucket_cfg = self.prompt_bs_bucket_cfg if is_prompt else \
self.decode_bs_bucket_cfg
batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
batch_size_padded = self.find_bs_bucket(real_batch_size, is_prompt)
batch_size_padding = batch_size_padded - real_batch_size
seq_group_metadata_list = seq_group_metadata_list.copy()
seq_group_metadata_list.extend(seq_group_metadata_list[0]
@@ -1087,16 +1253,30 @@ def create_dummy_seq_group_metadata(self,
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
max_batch_size = self.prompt_bs_bucket_cfg[-1]
max_seq_len = self.prompt_seq_bucket_cfg[-1]
if self.lora_config:
max_seq_len = self.max_num_batched_tokens // max_batch_size

self.warmup_scenario(max_batch_size,
max_seq_len,
True,
kv_caches,
is_profile_run=True)
# Take the two largest buckets in terms of token count:
# the first with the largest batch size, the second with the largest
# sequence length. If both are the same bucket, only a single warmup
# run is executed.
warmup_scenarios = set()
warmup_scenarios.add(
max(self.prompt_buckets,
key=lambda item: (item[0], item[0] * item[1])))
warmup_scenarios.add(
max(self.prompt_buckets,
key=lambda item: (item[1], item[0] * item[1])))
for idx, (max_batch_size, max_seq_len) in enumerate(warmup_scenarios):
msg = (f'[{idx+1}/{len(warmup_scenarios)}] '
f'Executing profile run for bs={max_batch_size}, '
f'seq_len={max_seq_len}, '
f'num_batched_tokens={max_batch_size*max_seq_len}')
logger.info(msg)
if self.lora_config:
max_seq_len = self.max_num_batched_tokens // max_batch_size
self.warmup_scenario(max_batch_size,
max_seq_len,
True,
kv_caches,
is_profile_run=True)
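
A small sketch of the scenario selection above, with assumed prompt buckets (illustrative only; the set deduplicates the two picks when they coincide):

# Pick at most two profile scenarios: largest batch size and largest seq length.
prompt_buckets = [(1, 128), (2, 512), (4, 256), (8, 128)]
scenarios = set()
scenarios.add(max(prompt_buckets, key=lambda b: (b[0], b[0] * b[1])))  # (8, 128)
scenarios.add(max(prompt_buckets, key=lambda b: (b[1], b[0] * b[1])))  # (2, 512)
print(scenarios)  # two distinct warmup runs; a single run if both picks coincide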

def warmup_scenario(self,
batch_size,