Skip to content

Commit

Permalink
i enjoy typing
Browse files Browse the repository at this point in the history
  • Loading branch information
kzawora-intel committed Aug 27, 2024
1 parent 14d7ef7 commit c113c13
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,34 @@ def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens,
seq_len_buckets) > seq_bucket_config[3]:
seq_len_buckets = warmup_range_with_limit(seq_bucket_config)

# NOTE(kzawora): When using sequence limits, we need to make sure we still
# cover the largest valid scenario (max_num_batched_tokens/bs). If we drop
# out-of-bound buckets, we might be left with last bucket that's too small
# and encounter recompilations when exceeding that bucket and still being
# in valid range

# Example:
# max_model_len = 768
# max_num_batched_tokens = 1536
# bs_buckets = [1, 2, 4]
# seq_len_buckets = [128, 256, 512, 1024]
# buckets = [
# (1,128), (1,256), (1,512), (1,1024), # last one is invalid, 1024 > 768
# (2,128), (2,256), (2,512), (2,1024) # last one is invalid, 2048 > 1536
# (4,128), (4,256), (4,512), (4,1024) # last two are invalid, 2048 > 1536
# ]
# filtered_buckets = [
# (1,128), (1,256), (1,512), # 512-768 is valid and not bucketed
# (2,128), (2,256), (2,512), # 512-768 is valid and not bucketed
# (4,128), (4,256), # 256-384 is valid and not bucketed
# ]
# boundary_buckets = [
# (1,768), # covers 512-768 range, constrained by max_model_len
# (2,768) # covers 512-768 range, constrained by budget and max_model_len
# (4,384) # covers 256-384 range, constrained by token budget
# ]

# Generate the largest valid bucket for all batch sizes.
boundary_buckets = [(bs,
min(
max_model_len,
Expand All @@ -207,6 +235,9 @@ def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens,
for bs in bs_buckets]
buckets = list(itertools.product(bs_buckets, seq_len_buckets))
buckets = list(set([*buckets, *boundary_buckets]))

# Remove buckets that either exceed token budget
# or are greater than max model length
filtered_buckets = filter(
lambda bucket: bucket[0] * bucket[1] <= math.ceil(
max_num_batched_tokens / block_size) * block_size and bucket[1] <=
Expand Down

0 comments on commit c113c13

Please sign in to comment.