From 7d59fc117c19549590c1cdade0a7512a2468a90a Mon Sep 17 00:00:00 2001
From: Michal Szutenberg <37601244+szutenberg@users.noreply.github.com>
Date: Tue, 27 Aug 2024 17:58:11 +0200
Subject: [PATCH] Revert "Ensure buckets do not exceed the batch token limit (#206)"

This reverts commit aefd336798248d519ddc4cc5662c9aa03a9dbfad.
---
 vllm/worker/habana_model_runner.py | 16 ++++------------
 1 file changed, 4 insertions(+), 12 deletions(-)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 62a9e814a5ac4..7f7f15bea86fa 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -94,16 +94,10 @@ def warmup_range(config: Tuple[int, int, int]):
     return list(ramp_up_tw) + list(stable)
 
 
-def warmup_buckets(bs_bucket_config, seq_bucket_config,
-                   max_num_batched_tokens):
+def warmup_buckets(bs_bucket_config, seq_bucket_config):
     buckets = itertools.product(warmup_range(bs_bucket_config),
                                 warmup_range(seq_bucket_config))
-    # Remove buckets exceeding batch token budget
-    filtered_buckets = filter(
-        lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
-        buckets)
-    return list(
-        sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
+    return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
 
 
 def next_pow2(value: int):
@@ -532,8 +526,7 @@ def _setup_buckets(self) -> None:
                f"seq:{self.prompt_seq_bucket_cfg}")
         logger.info(msg)
         self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
-                                             self.prompt_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+                                             self.prompt_seq_bucket_cfg)
 
         if self.lora_config:
             self.prompt_buckets[:] = [
@@ -550,8 +543,7 @@ def _setup_buckets(self) -> None:
                f"seq:{self.decode_seq_bucket_cfg}")
         logger.info(msg)
         self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
-                                             self.decode_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+                                             self.decode_seq_bucket_cfg)
         if self.lora_config:
             self.decode_buckets[:] = [
                 bucket for bucket in self.decode_buckets