
Commit

fix case where max_num_batched_tokens is not divisible by bucket batch size
kzawora-intel committed Aug 27, 2024
1 parent c113c13 commit 174b706
Showing 1 changed file with 42 additions and 26 deletions.
68 changes: 42 additions & 26 deletions vllm/worker/habana_model_runner.py
@@ -191,6 +191,9 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], fill=True):
 
 def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens,
                    block_size, max_model_len):
+    # NOTE(kzawora): If either max_model_len or max_num_batched_tokens are
+    # not divisible by block_size, we round them up here to block_size
+
     bs_buckets = warmup_range(bs_bucket_config[:3])
     seq_len_buckets = warmup_range(seq_bucket_config[:3])
     if bs_bucket_config[3] != 0 and len(bs_buckets) > bs_bucket_config[3]:
@@ -199,49 +202,62 @@ def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens,
             seq_len_buckets) > seq_bucket_config[3]:
         seq_len_buckets = warmup_range_with_limit(seq_bucket_config)
 
-    # NOTE(kzawora): When using sequence limits, we need to make sure we still
-    # cover the largest valid scenario (max_num_batched_tokens/bs). If we drop
-    # out-of-bound buckets, we might be left with last bucket that's too small
-    # and encounter recompilations when exceeding that bucket and still being
-    # in valid range
-
+    buckets = list(itertools.product(bs_buckets, seq_len_buckets))
+
+    # Remove buckets that either exceed token budget
+    # or are greater than max model length
+    filtered_buckets = list(
+        filter(
+            lambda bucket: bucket[0] * bucket[1] <= math.ceil(
+                max_num_batched_tokens / block_size) * block_size and bucket[1]
+            <= max_model_len, buckets))
+
+    # NOTE(kzawora): We need to make sure we cover the largest valid scenario
+    # (max_num_batched_tokens/bs). If we drop out-of-bound buckets, we might
+    # be left with last bucket that's too small and encounter recompilations
+    # when exceeding that bucket and still being in the valid range.
+
     # Example:
-    # max_model_len = 768
-    # max_num_batched_tokens = 1536
-    # bs_buckets = [1, 2, 4]
+    # block_size = 128
+    # max_model_len = 768 (6 blocks)
+    # max_num_batched_tokens = 1536 (12 blocks)
+    # bs_buckets = [1, 2, 4, 5]
     # seq_len_buckets = [128, 256, 512, 1024]
     # buckets = [
     #   (1,128), (1,256), (1,512), (1,1024), # last one is invalid, 1024 > 768
     #   (2,128), (2,256), (2,512), (2,1024)  # last one is invalid, 2048 > 1536
     #   (4,128), (4,256), (4,512), (4,1024)  # last two are invalid, 2048 > 1536
+    #   (5,128), (5,256), (5,512), (5,1024)  # last two are invalid, 2560 > 1536
     # ]
     # filtered_buckets = [
     #   (1,128), (1,256), (1,512), # 512-768 is valid and not bucketed
     #   (2,128), (2,256), (2,512), # 512-768 is valid and not bucketed
     #   (4,128), (4,256),          # 256-384 is valid and not bucketed
+    #   (5,128), (5,256),          # 256-307 is valid and not bucketed
     # ]
     # boundary_buckets = [
     #   (1,768), # covers 512-768 range, constrained by max_model_len
     #   (2,768)  # covers 512-768 range, constrained by budget and max_model_len
     #   (4,384)  # covers 256-384 range, constrained by token budget
-    # ]
-
-    # Generate the largest valid bucket for all batch sizes.
-    boundary_buckets = [(bs,
-                         min(
-                             max_model_len,
-                             math.ceil(max_num_batched_tokens /
-                                       (bs * block_size)) * block_size))
-                        for bs in bs_buckets]
-    buckets = list(itertools.product(bs_buckets, seq_len_buckets))
-    buckets = list(set([*buckets, *boundary_buckets]))
+    #   (5,384)  # covers 256-307 range, constrained by token budget,
+    #            # rounded up to block_size(=128)
+    # ]
+    # While (5,384) bucket exceeds the token budget of 1536 (5*384=1920),
+    # it is the smallest bucket that can handle all valid use cases on bs=5,
+    # as all buckets for decode phase need to be aligned to block_size.
+
+    # Generate the largest valid bucket for all valid batch sizes.
+    # DO NOT use bs_buckets in here as it can include invalid batches
+    valid_bs_buckets = set(list(zip(*filtered_buckets))[0])
+    boundary_buckets = [
+        (bs,
+         min(
+             math.ceil(max_model_len / block_size) * block_size,
+             math.ceil(max_num_batched_tokens / (bs * block_size)) *
+             block_size)) for bs in valid_bs_buckets
+    ]
+    filtered_buckets = list(set([*filtered_buckets, *boundary_buckets]))
 
-    # Remove buckets that either exceed token budget
-    # or are greater than max model length
-    filtered_buckets = filter(
-        lambda bucket: bucket[0] * bucket[1] <= math.ceil(
-            max_num_batched_tokens / block_size) * block_size and bucket[1] <=
-        max_model_len, buckets)
     return list(
         sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
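
To make the arithmetic in the example comments easy to check, below is a minimal standalone sketch of the bucket-generation logic this commit introduces. It is an illustration, not the code from vllm/worker/habana_model_runner.py: the function name make_buckets and the __main__ driver are hypothetical, while the filtering rule, the boundary-bucket formula, and the example values (block_size=128, max_model_len=768, max_num_batched_tokens=1536) are taken from the diff above.

import itertools
import math


def make_buckets(bs_buckets, seq_len_buckets, block_size, max_model_len,
                 max_num_batched_tokens):
    # Illustrative re-implementation of the bucket logic added in this
    # commit; structure is simplified and names are not from the repository.
    token_budget = math.ceil(max_num_batched_tokens / block_size) * block_size

    # Cartesian product of batch-size and sequence-length buckets.
    buckets = list(itertools.product(bs_buckets, seq_len_buckets))

    # Drop buckets that exceed the token budget or max_model_len.
    filtered = [(bs, seq) for bs, seq in buckets
                if bs * seq <= token_budget and seq <= max_model_len]

    # For every batch size that survived filtering, add a boundary bucket:
    # the smallest block-aligned seq_len covering the largest valid scenario
    # for that batch size. It may exceed the token budget, e.g. (5, 384)
    # with 5*384 = 1920 > 1536, and is deliberately kept anyway.
    valid_bs = {bs for bs, _ in filtered}
    boundary = [(bs,
                 min(
                     math.ceil(max_model_len / block_size) * block_size,
                     math.ceil(max_num_batched_tokens / (bs * block_size)) *
                     block_size)) for bs in valid_bs]

    return sorted(set(filtered + boundary),
                  key=lambda b: (b[0] * b[1], b[1], b[0]))


if __name__ == "__main__":
    # Values taken from the example in the diff comments above.
    print(make_buckets(bs_buckets=[1, 2, 4, 5],
                       seq_len_buckets=[128, 256, 512, 1024],
                       block_size=128,
                       max_model_len=768,
                       max_num_batched_tokens=1536))
    # The result contains the boundary buckets (1,768), (2,768), (4,384)
    # and (5,384) in addition to the filtered product buckets.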

