From 76367b5ae769aa368f21d336afbb33709bbe8444 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Tue, 22 Oct 2024 16:58:08 +0000 Subject: [PATCH] wip --- aphrodite/engine/args_tools.py | 5 ++++ aphrodite/engine/async_aphrodite.py | 29 ++++++++++++++------ aphrodite/engine/metrics.py | 41 +++++++++++++++++++++++++++- aphrodite/engine/metrics_types.py | 42 +++++++++++++++++++++++++++-- 4 files changed, 106 insertions(+), 11 deletions(-) diff --git a/aphrodite/engine/args_tools.py b/aphrodite/engine/args_tools.py index fff6a822f..c25bcba7e 100644 --- a/aphrodite/engine/args_tools.py +++ b/aphrodite/engine/args_tools.py @@ -1093,6 +1093,7 @@ class AsyncEngineArgs(EngineArgs): engine_use_ray: bool = False disable_log_requests: bool = False + per_request_logging: bool = False uvloop: bool = False @staticmethod @@ -1107,6 +1108,10 @@ def add_cli_args(parser: FlexibleArgumentParser, parser.add_argument('--disable-log-requests', action='store_true', help='Disable logging requests.') + parser.add_argument('--per-request-logging', + action='store_true', + help='Switch to per-request logging instead of ' + 'global logging.') parser.add_argument( "--uvloop", action="store_true", diff --git a/aphrodite/engine/async_aphrodite.py b/aphrodite/engine/async_aphrodite.py index 7934a2548..63964e38c 100644 --- a/aphrodite/engine/async_aphrodite.py +++ b/aphrodite/engine/async_aphrodite.py @@ -498,6 +498,7 @@ class AsyncAphrodite: _engine_class: Type[_AsyncAphrodite] = _AsyncAphrodite def __init__(self, + per_request_logging: bool, worker_use_ray: bool, engine_use_ray: bool, *args, @@ -507,6 +508,7 @@ def __init__(self, self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray self.log_requests = log_requests + self.per_request_logging = per_request_logging self.engine = self._init_engine(*args, **kwargs) self.background_loop: Optional[asyncio.Future] = None @@ -599,6 +601,7 @@ def from_engine_args( executor_class = cls._get_executor_cls(engine_config) # Create the async LLM engine. engine = cls( + engine_args.per_request_logging, executor_class.uses_ray, engine_args.engine_use_ray, **engine_config.to_dict(), @@ -914,14 +917,22 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - inputs, - sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ): - yield AphroditeEngine.validate_output(output, RequestOutput) + if self.per_request_logging: + request_id = f"{request_id}_{uuid4()}" + self.engine.stat_logger.start_request(request_id) + + try: + async for output in await self.add_request( + request_id, + inputs, + sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ): + yield AphroditeEngine.validate_output(output, RequestOutput) + finally: + if self.per_request_logging: + self.engine.stat_logger.end_request() async def encode( self, @@ -1066,6 +1077,8 @@ async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None) -> None: + if not self.per_request_logging: + if self.engine_use_ray: await self.engine.do_log_stats.remote( # type: ignore scheduler_outputs, model_output) diff --git a/aphrodite/engine/metrics.py b/aphrodite/engine/metrics.py index 8daecce32..69fb3389a 100644 --- a/aphrodite/engine/metrics.py +++ b/aphrodite/engine/metrics.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from typing import Counter as CollectionsCounter from typing import Dict, List, Optional, Union @@ -178,6 +178,35 @@ def __init__(self, labelnames: List[str], max_model_len: int): multiprocess_mode="sum", ) + if "request_id" not in labelnames: + labelnames.append("request_id") + + self.gauge_per_request_duration = self._gauge_cls( + name="aphrodite:per_request_duration_seconds", + documentation="Duration of each request in seconds.", + labelnames=labelnames, + multiprocess_mode="livesum") + self.gauge_per_request_prompt_throughput = self._gauge_cls( + name="aphrodite:per_request_prompt_throughput", + documentation="Prompt throughput for each request.", + labelnames=labelnames, + multiprocess_mode="livesum") + self.gauge_per_request_generation_throughput = self._gauge_cls( + name="aphrodite:per_request_generation_throughput", + documentation="Generation throughput for each request.", + labelnames=labelnames, + multiprocess_mode="livesum") + self.gauge_per_request_gpu_cache_usage = self._gauge_cls( + name="aphrodite:per_request_gpu_cache_usage", + documentation="GPU cache usage for each request.", + labelnames=labelnames, + multiprocess_mode="livesum") + self.gauge_per_request_cpu_cache_usage = self._gauge_cls( + name="aphrodite:per_request_cpu_cache_usage", + documentation="CPU cache usage for each request.", + labelnames=labelnames, + multiprocess_mode="livesum") + # end-metrics-definitions @@ -527,6 +556,16 @@ def info(self, type: str, obj: SupportsMetricsInfo) -> None: multiprocess_mode="mostrecent") info_gauge.labels(**metrics_info).set(1) + def log_per_request(self, request_id: str, + stats: Dict[str, Any]) -> None: + if self.per_request_logging: + labels = {**self.labels, "request_id": request_id} + self.metrics.gauge_per_request_duration.labels(**labels).set(stats["duration"]) + self.metrics.gauge_per_request_prompt_throughput.labels(**labels).set(stats["prompt_throughput"]) + self.metrics.gauge_per_request_generation_throughput.labels(**labels).set(stats["generation_throughput"]) + self.metrics.gauge_per_request_gpu_cache_usage.labels(**labels).set(stats["gpu_cache_usage"]) + self.metrics.gauge_per_request_cpu_cache_usage.labels(**labels).set(stats["cpu_cache_usage"]) + class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" diff --git a/aphrodite/engine/metrics_types.py b/aphrodite/engine/metrics_types.py index 5d337f2e1..8f58fdebc 100644 --- a/aphrodite/engine/metrics_types.py +++ b/aphrodite/engine/metrics_types.py @@ -11,7 +11,7 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional, Protocol +from typing import Any, Dict, List, Optional, Protocol from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -49,16 +49,54 @@ def metrics_info(self) -> Dict[str, str]: ... class StatLoggerBase(ABC): """Base class for StatLogger.""" - def __init__(self, local_interval: float) -> None: + def __init__(self, local_interval: float, + per_request_logging: bool = False) -> None: # Tracked stats over current local logging interval. self.num_prompt_tokens: List[int] = [] self.num_generation_tokens: List[int] = [] self.last_local_log = time.time() self.local_interval = local_interval + self.per_request_logging = per_request_logging + self.current_request_stats: Optional[Dict[str, Any]] = None self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + @abstractmethod def log(self, stats: Stats) -> None: raise NotImplementedError + + @abstractmethod + def log_per_request(self, request_id: str, + stats: Dict[str, Any]) -> None: + raise NotImplementedError + + def start_request(self, request_id: str) -> None: + if self.per_request_logging: + self.current_request_stats = { + "request_id": request_id, + "start_time": time.time(), + "prompt_tokens": 0.0, + "generation_tokens": 0.0, + "gpu_cache_usage": 0.0, + "cpu_cache_usage": 0.0 + } + + def end_request(self) -> None: + if self.per_request_logging and self.current_request_stats: + end_time = time.time() + duration = end_time - self.current_request_stats["start_time"] + prompt_throughput = self.current_request_stats["prompt_tokens"] + generation_throughput = self.current_request_stats["generation_tokens"] + + self.log_per_request(self.current_request_stats["request_id"], { + "duration": duration, + "prompt_throughput": prompt_throughput, + "generation_throughput": generation_throughput, + "gpu_cache_usage": self.current_request_stats["gpu_cache_usage"], + "cpu_cache_usage": self.current_request_stats["cpu_cache_usage"] + }) + + self.current_request_stats = None + @abstractmethod def info(self, type: str, obj: SupportsMetricsInfo) -> None: raise NotImplementedError