diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py index d2059057..fbe6675f 100644 --- a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py +++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py @@ -125,6 +125,17 @@ def test_vllm_metrics(self): # vllm:generation_tokens_total self.assertEqual(metrics_dict["vllm:generation_tokens_total"], 48) + # vllm:time_to_first_token_seconds + self.assertEqual(metrics_dict["vllm:time_to_first_token_seconds_count"], 3) + self.assertTrue( + 0 < metrics_dict["vllm:time_to_first_token_seconds_sum"] < 0.0005 + ) + # vllm:time_per_output_token_seconds + self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45) + self.assertTrue( + 0 <= metrics_dict["vllm:time_per_output_token_seconds_sum"] <= 0.005 + ) + def test_vllm_metrics_disabled(self): # Test vLLM metrics self.vllm_infer( diff --git a/src/utils/metrics.py b/src/utils/metrics.py index fc6e69bd..0374fa3b 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -24,7 +24,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Dict, Union +from typing import Dict, List, Union import triton_python_backend_utils as pb_utils from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase @@ -46,6 +46,16 @@ def __init__(self, labels): description="Number of generation tokens processed.", kind=pb_utils.MetricFamily.COUNTER, ) + self.histogram_time_to_first_token_family = pb_utils.MetricFamily( + name="vllm:time_to_first_token_seconds", + description="Histogram of time to first token in seconds.", + kind=pb_utils.MetricFamily.HISTOGRAM, + ) + self.histogram_time_per_output_token_family = pb_utils.MetricFamily( + name="vllm:time_per_output_token_seconds", + description="Histogram of time per output token in seconds.", + kind=pb_utils.MetricFamily.HISTOGRAM, + ) # Initialize metrics # Iteration stats @@ -55,6 +65,49 @@ def __init__(self, labels): self.counter_generation_tokens = self.counter_generation_tokens_family.Metric( labels=labels ) + self.histogram_time_to_first_token = ( + self.histogram_time_to_first_token_family.Metric( + labels=labels, + buckets=[ + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + ], + ) + ) + self.histogram_time_per_output_token = ( + self.histogram_time_per_output_token_family.Metric( + labels=labels, + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + ], + ) + ) class VllmStatLogger(VllmStatLoggerBase): @@ -82,6 +135,19 @@ def _log_counter(self, counter, data: Union[int, float]) -> None: if data != 0: counter.increment(data) + def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None: + """Convenience function for logging list to histogram. + + Args: + histogram: A histogram metric instance. + data: A list of int or float data to observe into the histogram metric. + + Returns: + None + """ + for datum in data: + histogram.observe(datum) + def log(self, stats: VllmStats) -> None: """Report stats to Triton metrics server. @@ -97,3 +163,10 @@ def log(self, stats: VllmStats) -> None: self._log_counter( self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter ) + self._log_histogram( + self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter + ) + self._log_histogram( + self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter, + )