Commit 14d7ef7
I've made an oopsie woopsie
kzawora-intel committed Aug 27, 2024
1 parent 8a1318f · commit 14d7ef7
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions vllm/worker/habana_model_runner.py
@@ -169,7 +169,6 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], fill=True):

     bmin, bstep, bmax, num_buckets = config
     linear_buckets = set(np.arange(bmin, bmax + 1, step=bstep))
-    print(len(linear_buckets))
     assert num_buckets > 0, "num_buckets must be a positive integer"
     if num_buckets == 1:
         return [bmax]
@@ -191,25 +190,27 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], fill=True):


 def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens,
-                   block_size):
+                   block_size, max_model_len):
     bs_buckets = warmup_range(bs_bucket_config[:3])
     seq_len_buckets = warmup_range(seq_bucket_config[:3])
     if bs_bucket_config[3] != 0 and len(bs_buckets) > bs_bucket_config[3]:
         bs_buckets = warmup_range_with_limit(bs_bucket_config)
     if seq_bucket_config[3] != 0 and len(
             seq_len_buckets) > seq_bucket_config[3]:
         seq_len_buckets = warmup_range_with_limit(seq_bucket_config)
-    diagonal_buckets = [
-        (bs,
-         math.ceil(max_num_batched_tokens / (bs * block_size)) * block_size)
-        for bs in bs_buckets
-    ]
-    print(diagonal_buckets)
 
+    boundary_buckets = [(bs,
+                         min(
+                             max_model_len,
+                             math.ceil(max_num_batched_tokens /
+                                       (bs * block_size)) * block_size))
+                        for bs in bs_buckets]
     buckets = list(itertools.product(bs_buckets, seq_len_buckets))
-    buckets = list(set([*buckets, *diagonal_buckets]))
+    buckets = list(set([*buckets, *boundary_buckets]))
     filtered_buckets = filter(
         lambda bucket: bucket[0] * bucket[1] <= math.ceil(
-            max_num_batched_tokens / block_size) * block_size, buckets)
+            max_num_batched_tokens / block_size) * block_size and bucket[1] <=
+        max_model_len, buckets)
     return list(
         sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
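
For context on what the fixed warmup_buckets now produces, here is a minimal standalone sketch of the new boundary-bucket and filtering logic. The helper name and all parameter values below are hypothetical, and warmup_range's output is stubbed with plain lists:

import itertools
import math

def bucket_sketch(bs_buckets, seq_len_buckets, max_num_batched_tokens,
                  block_size, max_model_len):
    # For each batch size, the largest useful seq_len: rounded up to a
    # multiple of block_size, but never above the model's context window.
    boundary_buckets = [(bs,
                         min(max_model_len,
                             math.ceil(max_num_batched_tokens /
                                       (bs * block_size)) * block_size))
                        for bs in bs_buckets]
    buckets = set(itertools.product(bs_buckets, seq_len_buckets))
    buckets |= set(boundary_buckets)
    token_cap = math.ceil(max_num_batched_tokens / block_size) * block_size
    # Keep only buckets that fit the token budget and respect max_model_len.
    filtered = [b for b in buckets
                if b[0] * b[1] <= token_cap and b[1] <= max_model_len]
    return sorted(filtered, key=lambda b: (b[0] * b[1], b[1], b[0]))

# Hypothetical values: bs buckets [1, 2, 4], seq_len buckets [128, 2048, 4096].
print(bucket_sketch([1, 2, 4], [128, 2048, 4096],
                    max_num_batched_tokens=4096,
                    block_size=128, max_model_len=2048))
# -> [(1, 128), (2, 128), (4, 128), (1, 2048), (4, 1024), (2, 2048)]

With max_model_len=2048, the (1, 4096) pair that the old diagonal_buckets logic would have emitted, and that the old filter would have kept since 1 * 4096 fits the token budget, no longer survives.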

@@ -639,7 +640,8 @@ def _setup_buckets(self) -> None:
         self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
                                              self.prompt_seq_bucket_cfg,
                                              self.max_num_batched_tokens,
-                                             self.block_size)
+                                             self.block_size,
+                                             self.max_model_len)
 
         if self.lora_config:
             self.prompt_buckets[:] = [
@@ -658,7 +660,8 @@ def _setup_buckets(self) -> None:
         self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
                                              self.decode_seq_bucket_cfg,
                                              self.max_num_batched_tokens,
-                                             self.block_size)
+                                             self.block_size,
+                                             self.max_model_len)
         if self.lora_config:
             self.decode_buckets[:] = [
                 bucket for bucket in self.decode_buckets
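
And a before/after illustration of why threading max_model_len into the filter matters; again a standalone sketch with hypothetical numbers, not the repository's test code:

import math

max_num_batched_tokens, block_size, max_model_len = 4096, 128, 2048
bucket = (1, 4096)  # batch size 1, seq_len 4096

token_cap = math.ceil(max_num_batched_tokens / block_size) * block_size

# Old filter: only the token budget was checked, so this bucket survived
# even though its seq_len is twice the model's context window.
old_keep = bucket[0] * bucket[1] <= token_cap

# New filter: the seq_len must also fit within max_model_len.
new_keep = bucket[0] * bucket[1] <= token_cap and bucket[1] <= max_model_len

print(old_keep, new_keep)  # True False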
