Commit

Revert "Ensure buckets do not exceed the batch token limit (#206)"
This reverts commit aefd336.
szutenberg committed Aug 27, 2024
1 parent aefd336 commit 7d59fc1
Showing 1 changed file with 4 additions and 12 deletions.
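
For context, the change being reverted (#206) filtered the warmup buckets against the batch token budget, so that no warmed-up (batch_size, seq_len) shape could exceed max_num_batched_tokens. A minimal standalone sketch of that filter, with illustrative ranges and budget standing in for the real warmup_range output and engine configuration:

import itertools

# Illustrative values only; in vllm/worker/habana_model_runner.py these come
# from warmup_range(...) and the engine configuration.
batch_sizes = [1, 2, 4, 8, 16, 32, 64]
seq_lens = [128, 256, 512, 1024, 2048]
max_num_batched_tokens = 4096  # example budget, not a vLLM default

buckets = itertools.product(batch_sizes, seq_lens)
# The filter removed by this revert: drop (bs, seq) buckets whose total
# token count exceeds the batch token budget.
kept = [b for b in buckets if b[0] * b[1] <= max_num_batched_tokens]
print(kept)  # oversized buckets such as (64, 2048) are excluded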
16 changes: 4 additions & 12 deletions vllm/worker/habana_model_runner.py
@@ -94,16 +94,10 @@ def warmup_range(config: Tuple[int, int, int]):
     return list(ramp_up_tw) + list(stable)
 
 
-def warmup_buckets(bs_bucket_config, seq_bucket_config,
-                   max_num_batched_tokens):
+def warmup_buckets(bs_bucket_config, seq_bucket_config):
     buckets = itertools.product(warmup_range(bs_bucket_config),
                                 warmup_range(seq_bucket_config))
-    # Remove buckets exceeding batch token budget
-    filtered_buckets = filter(
-        lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
-        buckets)
-    return list(
-        sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
+    return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
 
 
 def next_pow2(value: int):
@@ -532,8 +526,7 @@ def _setup_buckets(self) -> None:
                f"seq:{self.prompt_seq_bucket_cfg}")
         logger.info(msg)
         self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
-                                             self.prompt_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+                                             self.prompt_seq_bucket_cfg)
 
         if self.lora_config:
             self.prompt_buckets[:] = [
@@ -550,8 +543,7 @@ def _setup_buckets(self) -> None:
                f"seq:{self.decode_seq_bucket_cfg}")
         logger.info(msg)
         self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
-                                             self.decode_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+                                             self.decode_seq_bucket_cfg)
         if self.lora_config:
             self.decode_buckets[:] = [
                 bucket for bucket in self.decode_buckets
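
After the revert, warmup_buckets again returns every (batch_size, seq_len) combination, sorted only by total token count, with no check against max_num_batched_tokens. A rough sketch of the post-revert behavior, using placeholder ranges standing in for the real warmup_range output:

import itertools

def warmup_buckets_after_revert(bs_range, seq_range):
    # Post-revert behavior: no token-budget filtering; sort by total tokens
    # (bs * seq), then by seq, then by bs, matching the key in the diff above.
    buckets = itertools.product(bs_range, seq_range)
    return sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))

# Placeholder ranges; the real ones come from warmup_range(bucket_cfg).
print(warmup_buckets_after_revert([1, 2, 4], [128, 256]))
# [(1, 128), (2, 128), (1, 256), (4, 128), (2, 256), (4, 256)]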
