diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 7f7f15bea86fa..62a9e814a5ac4 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -94,10 +94,16 @@ def warmup_range(config: Tuple[int, int, int]):
     return list(ramp_up_tw) + list(stable)
 
 
-def warmup_buckets(bs_bucket_config, seq_bucket_config):
+def warmup_buckets(bs_bucket_config, seq_bucket_config,
+                   max_num_batched_tokens):
     buckets = itertools.product(warmup_range(bs_bucket_config),
                                 warmup_range(seq_bucket_config))
-    return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
+    # Remove buckets exceeding batch token budget
+    filtered_buckets = filter(
+        lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
+        buckets)
+    return list(
+        sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
 
 
 def next_pow2(value: int):
@@ -526,7 +532,8 @@ def _setup_buckets(self) -> None:
                f"seq:{self.prompt_seq_bucket_cfg}")
         logger.info(msg)
         self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
-                                             self.prompt_seq_bucket_cfg)
+                                             self.prompt_seq_bucket_cfg,
+                                             self.max_num_batched_tokens)
 
         if self.lora_config:
             self.prompt_buckets[:] = [
@@ -543,7 +550,8 @@ def _setup_buckets(self) -> None:
                f"seq:{self.decode_seq_bucket_cfg}")
         logger.info(msg)
         self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
-                                             self.decode_seq_bucket_cfg)
+                                             self.decode_seq_bucket_cfg,
+                                             self.max_num_batched_tokens)
         if self.lora_config:
             self.decode_buckets[:] = [
                 bucket for bucket in self.decode_buckets