diff --git a/src/model.py b/src/model.py index e3fc62d..6a6c0e4 100644 --- a/src/model.py +++ b/src/model.py @@ -174,9 +174,10 @@ def init_engine(self): } # Add vLLM custom metrics engine_config = self.llm_engine.engine.model_config - self.llm_engine.add_logger( - "triton", VllmStatLogger(labels, engine_config.max_model_len) + self.vllm_metrics = VllmStatLogger( + labels, engine_config.max_model_len, self.logger ) + self.llm_engine.add_logger("triton", self.vllm_metrics) except pb_utils.TritonModelException as e: if "metrics not supported" in str(e): # Metrics are disabled at the server @@ -572,6 +573,9 @@ def finalize(self): self._response_thread.join() self._response_thread = None + # Shutdown the logger thread. + self.vllm_metrics.finalize() + # When using parallel tensors, the stub process may not shutdown due to # unreleased references, so manually run the garbage collector once. self.logger.log_info("[vllm] Running Garbage Collector on finalize...") diff --git a/src/utils/metrics.py b/src/utils/metrics.py index b1471a1..3ac0c43 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -24,6 +24,8 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import queue +import threading from typing import Dict, List, Union import triton_python_backend_utils as pb_utils @@ -170,11 +172,18 @@ def __init__(self, labels: List[str], max_model_len: int): class VllmStatLogger(VllmStatLoggerBase): """StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider.""" - # local_interval not used here. It's for vLLM logs to stdout. - def __init__(self, labels: Dict, max_model_len: int) -> None: + def __init__(self, labels: Dict, max_model_len: int, logger) -> None: # Tracked stats over current local logging interval. + # local_interval not used here. It's for vLLM logs to stdout. super().__init__(local_interval=0) self.metrics = TritonMetrics(labels, max_model_len) + self.logger = logger + + # Starting the metrics thread. It allows vLLM to keep making progress + # while reporting metrics to triton metrics service. + self._logger_queue = queue.Queue() + self._logger_thread = threading.Thread(target=self.logger_loop) + self._logger_thread.start() def info(self, type: str, obj: SupportsMetricsInfo) -> None: pass @@ -190,7 +199,7 @@ def _log_counter(self, counter, data: Union[int, float]) -> None: None """ if data != 0: - counter.increment(data) + self._logger_queue.put_nowait((counter, "increment", data)) def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None: """Convenience function for logging list to histogram. @@ -203,7 +212,7 @@ def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None None """ for datum in data: - histogram.observe(datum) + self._logger_queue.put_nowait((histogram, "observe", datum)) def log(self, stats: VllmStats) -> None: """Report stats to Triton metrics server. @@ -246,3 +255,24 @@ def log(self, stats: VllmStats) -> None: self._log_counter(metric, data) for metric, data in histogram_metrics: self._log_histogram(metric, data) + + def logger_loop(self): + while True: + item = self._logger_queue.get() + # To signal shutdown a None item will be added to the queue. + if item is None: + break + metric, command, data = item + if command == "increment": + metric.increment(data) + elif command == "observe": + metric.observe(data) + else: + self.logger.log_error(f"Undefined command name: {command}") + + def finalize(self): + # Shutdown the logger thread. + self._logger_queue.put(None) + if self._logger_thread is not None: + self._logger_thread.join() + self._logger_thread = None