Commit

feat: improved logs

adubovik committed Dec 27, 2024
1 parent 492e8a2 commit f0554f7
Showing 9 changed files with 210 additions and 49 deletions.
34 changes: 20 additions & 14 deletions aidial_analytics_realtime/analytics.py
@@ -1,8 +1,8 @@
+import logging
 from datetime import datetime
 from decimal import Decimal
 from enum import Enum
 from logging import Logger
-from typing import Awaitable, Callable
 from uuid import uuid4
 
 from influxdb_client import Point
@@ -14,8 +14,11 @@
     get_chat_completion_response_contents,
     get_embeddings_request_contents,
 )
+from aidial_analytics_realtime.influx_writer import InfluxWriterAsync
 from aidial_analytics_realtime.rates import RatesCalculator
 from aidial_analytics_realtime.topic_model import TopicModel
+from aidial_analytics_realtime.utils.log_config import with_prefix
+from aidial_analytics_realtime.utils.timer import Timer
 
 identifier = LanguageIdentifier.from_modelstring(model, norm_probs=True)
 
@@ -42,20 +45,25 @@ def detect_lang(
         case _:
             assert_never(request_type)
 
-    return to_string(detect_lang_by_text(text))
+    return to_string(detect_lang_by_text(logger, text))
 
 
-def detect_lang_by_text(text: str) -> str | None:
+def detect_lang_by_text(logger: logging.Logger, text: str) -> str | None:
     text = text.strip()
 
     if not text:
         return None
 
+    logger = with_prefix(logger, "[langid]")
+
     try:
-        lang, prob = identifier.classify(text)
+        with Timer(logger.info):
+            lang, prob = identifier.classify(text)
+
         if prob > 0.998:
             return lang
-    except Exception:
+    except Exception as e:
+        logger.error(f"error: {str(e)}")
         pass
 
     return None
@@ -106,15 +114,15 @@ def make_point(
 
             if chat_id:
                 topic = topic_model.get_topic_by_text(
-                    "\n\n".join(request_contents + response_contents)
+                    logger, "\n\n".join(request_contents + response_contents)
                 )
         case RequestType.EMBEDDING:
             request_contents = get_embeddings_request_contents(logger, request)
 
             request_content = "\n".join(request_contents)
             if chat_id:
                 topic = topic_model.get_topic_by_text(
-                    "\n\n".join(request_contents)
+                    logger, "\n\n".join(request_contents)
                 )
         case _:
             assert_never(request_type)
@@ -249,7 +257,7 @@ async def parse_usage_per_model(response: dict):
 
 async def on_message(
     logger: Logger,
-    influx_writer: Callable[[Point], Awaitable[None]],
+    influx_writer: InfluxWriterAsync,
     deployment: str,
     model: str,
     project_id: str,
@@ -268,8 +276,6 @@ async def on_message(
     trace: dict | None,
     execution_path: list | None,
 ):
-    logger.info(f"Chat completion response length {len(response)}")
-
    usage_per_model = await parse_usage_per_model(response)
    if token_usage is not None:
        point = make_point(
@@ -292,7 +298,7 @@ async def on_message(
             trace,
             execution_path,
         )
-        await influx_writer(point)
+        await influx_writer(logger, point)
     elif len(usage_per_model) == 0:
         point = make_point(
             logger,
@@ -314,7 +320,7 @@ async def on_message(
             trace,
             execution_path,
         )
-        await influx_writer(point)
+        await influx_writer(logger, point)
     else:
         point = make_point(
             logger,
@@ -336,7 +342,7 @@ async def on_message(
             trace,
             execution_path,
         )
-        await influx_writer(point)
+        await influx_writer(logger, point)
 
         for usage in usage_per_model:
             point = make_point(
@@ -359,4 +365,4 @@ async def on_message(
                 trace,
                 execution_path,
             )
-            await influx_writer(point)
+            await influx_writer(logger, point)
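Note: the Timer and with_prefix helpers imported throughout this commit live in aidial_analytics_realtime/utils/timer.py and utils/log_config.py, two of the 9 changed files that are not expanded in this view. Below is a minimal sketch of what the call sites require; the implementation details are inferred from usage, not taken from the actual diff.

# Hypothetical sketch only: inferred from the call sites above, not from the
# actual utils/timer.py and utils/log_config.py changes (not shown here).
import logging
import time
from typing import Callable


class Timer:
    # Logs elapsed wall-clock time on exit; supports both "with" and
    # "async with", since the commit uses both forms.
    def __init__(self, log: Callable[[str], None]):
        self._log = log

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self._log(f"elapsed: {time.perf_counter() - self._start:.3f}s")

    async def __aenter__(self):
        return self.__enter__()

    async def __aexit__(self, *exc):
        self.__exit__(*exc)


def with_prefix(logger: logging.Logger, prefix: str) -> logging.LoggerAdapter:
    # Wraps a logger so every message is prefixed, e.g. "[langid] ..." or
    # "[2/5] ...". LoggerAdapter keeps .info/.error/.isEnabledFor available,
    # and adapters nest, so prefixes compose.
    class _PrefixAdapter(logging.LoggerAdapter):
        def process(self, msg, kwargs):
            return f"{self.extra['prefix']} {msg}", kwargs

    return _PrefixAdapter(logger, {"prefix": prefix})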
104 changes: 86 additions & 18 deletions aidial_analytics_realtime/app.py
@@ -1,9 +1,12 @@
+import asyncio
 import contextlib
 import json
 import logging
 import re
 from datetime import datetime
 
+import aiohttp
+import starlette.requests
 import uvicorn
 from fastapi import Depends, FastAPI, Request
 from fastapi.responses import JSONResponse
@@ -21,7 +24,12 @@
 from aidial_analytics_realtime.time import parse_time
 from aidial_analytics_realtime.topic_model import TopicModel
 from aidial_analytics_realtime.universal_api_utils import merge
-from aidial_analytics_realtime.utils.log_config import configure_loggers, logger
+from aidial_analytics_realtime.utils.log_config import (
+    app_logger,
+    configure_loggers,
+    with_prefix,
+)
+from aidial_analytics_realtime.utils.timer import Timer
 
 RATE_PATTERN = r"/v1/(.+?)/rate"
 CHAT_COMPLETION_PATTERN = r"/openai/deployments/(.+?)/chat/completions"
@@ -49,6 +57,7 @@ async def lifespan(app: FastAPI):
 
 
 async def on_rate_message(
+    logger: logging.Logger,
     deployment: str,
     project_id: str,
     chat_id: str,
@@ -59,7 +68,7 @@ async def on_rate_message(
     response: dict,
     influx_writer: InfluxWriterAsync,
 ):
-    logger.info(f"Rate message length {len(request) + len(response)}")
+    app_logger.info(f"Rate message length {len(request) + len(response)}")
     request_body = json.loads(request["body"])
     point = make_rate_point(
         deployment,
@@ -70,10 +79,11 @@
         timestamp,
         request_body,
     )
-    await influx_writer(point)
+    await influx_writer(logger, point)
 
 
 async def on_chat_completion_message(
+    logger: logging.Logger,
     deployment: str,
     project_id: str,
     chat_id: str,
@@ -149,6 +159,7 @@ async def on_chat_completion_message(
 
 
 async def on_embedding_message(
+    logger: logging.Logger,
     deployment: str,
     project_id: str,
     chat_id: str,
@@ -193,6 +204,7 @@ async def on_embedding_message(
 
 
 async def on_log_message(
+    logger: logging.Logger,
     message: dict,
     influx_writer: InfluxWriterAsync,
     topic_model: TopicModel,
@@ -219,6 +231,7 @@
 
     if re.search(RATE_PATTERN, uri):
         await on_rate_message(
+            logger,
             deployment,
             project_id,
             chat_id,
@@ -232,6 +245,7 @@
 
     elif re.search(CHAT_COMPLETION_PATTERN, uri):
         await on_chat_completion_message(
+            logger,
             deployment,
             project_id,
             chat_id,
@@ -252,6 +266,7 @@
 
     elif re.search(EMBEDDING_PATTERN, uri):
         await on_embedding_message(
+            logger,
             deployment,
             project_id,
             chat_id,
@@ -271,7 +286,7 @@
         )
 
     else:
-        logger.warning(f"Unsupported message type: {uri!r}")
+        app_logger.warning(f"Unsupported message type: {uri!r}")
 
 
 @app.post("/data")
@@ -281,28 +296,81 @@ async def on_log_messages(
     topic_model: TopicModel = Depends(),
     rates_calculator: RatesCalculator = Depends(),
 ):
+    request_logger = app_logger
+
     data = await request.json()
 
-    statuses = []
-    for idx, item in enumerate(data):
-        try:
-            await on_log_message(
-                json.loads(item["message"]),
-                influx_writer,
-                topic_model,
-                rates_calculator,
-            )
-        except Exception as e:
-            logging.exception(f"Error processing message #{idx}")
-            statuses.append({"status": "error", "error": str(e)})
-        else:
-            statuses.append({"status": "success"})
+    n = len(data)
+    request_logger.info(f"number of messages: {n}")
+
+    statuses: list[dict] = []
+
+    async with Timer(request_logger.info):
+        for i, item in enumerate(data, start=1):
+            message_logger = with_prefix(request_logger, f"[{i}/{n}]")
+
+            async with Timer(message_logger.info):
+                status = await process_message(
+                    message_logger,
+                    json.loads(item["message"]),
+                    influx_writer,
+                    topic_model,
+                    rates_calculator,
+                )
+
+            statuses.append(status)
+
+    if request_logger.isEnabledFor(logging.DEBUG):
+        request_logger.debug(f"response: {json.dumps(statuses)}")
 
     # Returning 200 code even if processing of some messages has failed,
     # since the log broker that sends the messages may decide to retry the failed requests.
     return JSONResponse(content=statuses, status_code=200)
 
 
+async def process_message(
+    logger: logging.Logger,
+    message: dict,
+    influx_writer: InfluxWriterAsync,
+    topic_model: TopicModel,
+    rates_calculator: RatesCalculator,
+) -> dict:
+    try:
+        await on_log_message(
+            logger,
+            message,
+            influx_writer,
+            topic_model,
+            rates_calculator,
+        )
+        logger.info("success")
+        return {"status": "success"}
+    except starlette.requests.ClientDisconnect as e:
+        logger.error("client disconnect")
+        return {
+            "status": "error",
+            "error": str(e),
+            "reason": "client disconnect",
+        }
+    except aiohttp.ClientConnectionError as e:
+        logger.error("connection error")
+        return {
+            "status": "error",
+            "error": str(e),
+            "reason": "connection error",
+        }
+    except asyncio.TimeoutError as e:
+        logger.error("timeout")
+        return {
+            "status": "error",
+            "error": str(e),
+            "reason": "timeout",
+        }
+    except Exception as e:
+        logger.exception("caught exception")
+        return {"status": "error", "error": str(e)}
+
+
 @app.get("/health")
 def health():
     return {"status": "ok"}
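For reference, the /data handler above implies this request/response contract: the POST body is a JSON array whose items hold a JSON-encoded log record under the "message" key, and the response is one status object per item, always with HTTP 200 so the log broker does not retry. A hedged client sketch follows; the host/port are placeholders, and the inner record shape is not part of this diff.

# Hypothetical client call; the URL is a placeholder, and the diff only
# shows that each item carries a JSON-encoded record under "message".
import json

import requests

payload = [{"message": json.dumps({})}]  # a real record carries the logged request/response

resp = requests.post("http://localhost:5001/data", json=payload)
print(resp.status_code)  # always 200, even when individual messages fail
print(resp.json())       # e.g. [{"status": "success"}] or [{"status": "error", "error": "..."}]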
11 changes: 8 additions & 3 deletions aidial_analytics_realtime/influx_writer.py
@@ -1,10 +1,14 @@
 import os
+from logging import Logger
 from typing import Awaitable, Callable, Tuple
 
 from influxdb_client import Point
 from influxdb_client.client.influxdb_client_async import InfluxDBClientAsync
 
-InfluxWriterAsync = Callable[[Point], Awaitable[None]]
+from aidial_analytics_realtime.utils.log_config import with_prefix
+from aidial_analytics_realtime.utils.timer import Timer
+
+InfluxWriterAsync = Callable[[Logger, Point], Awaitable[None]]
 
 
 def create_influx_writer() -> Tuple[InfluxDBClientAsync, InfluxWriterAsync]:
@@ -18,7 +22,8 @@ def create_influx_writer() -> Tuple[InfluxDBClientAsync, InfluxWriterAsync]:
     )
     influx_write_api = client.write_api()
 
-    async def influx_writer_impl(record: Point):
-        await influx_write_api.write(bucket=influx_bucket, record=record)
+    async def influx_writer_impl(logger: Logger, record: Point):
+        with Timer(with_prefix(logger, "[influx]").info):
+            await influx_write_api.write(bucket=influx_bucket, record=record)
 
     return client, influx_writer_impl
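The signature change above ripples to every call site: an InfluxWriterAsync now receives the per-message logger first, so the "[influx]" timing line lands in that message's log context. A hedged usage sketch, with made-up measurement and logger names:

# Hypothetical usage; assumes the Influx env configuration consumed by
# create_influx_writer() is set. The measurement/field names are illustrative.
import asyncio
import logging

from influxdb_client import Point

from aidial_analytics_realtime.influx_writer import create_influx_writer


async def main() -> None:
    client, influx_writer = create_influx_writer()
    async with client:
        point = Point("example_measurement").field("value", 1)
        await influx_writer(logging.getLogger("example"), point)  # logger first


asyncio.run(main())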
26 changes: 18 additions & 8 deletions aidial_analytics_realtime/topic_model.py
@@ -1,7 +1,11 @@
+import logging
 import os
 
 from bertopic import BERTopic
 
+from aidial_analytics_realtime.utils.log_config import with_prefix
+from aidial_analytics_realtime.utils.timer import Timer
+
 
 class TopicModel:
     def __init__(
@@ -18,14 +22,20 @@ def __init__(
         self.model = BERTopic.load(
             topic_model_name, topic_embeddings_model_name
         )
-        self.model.transform(["test"])  # Make sure the model is loaded
 
-    def get_topic_by_text(self, text):
-        topics, _ = self.model.transform([text])
-        topic = self.model.get_topic_info(topics[0])
+        # Disable tqdm progress bars on batch encoding
+        self.model.verbose = False
+
+        # Make sure the model is loaded
+        self.model.transform(["test"])
+
+    def get_topic_by_text(self, logger: logging.Logger, text):
+        with Timer(with_prefix(logger, "[topic]").info):
+            topics, _ = self.model.transform([text])
+            topic = self.model.get_topic_info(topics[0])
 
-        if "GeneratedName" in topic:
-            # "GeneratedName" is an expected name for the human readable topic representation
-            return topic["GeneratedName"][0][0][0]
+            if "GeneratedName" in topic:
+                # "GeneratedName" is an expected name for the human readable topic representation
+                return topic["GeneratedName"][0][0][0]
 
-        return topic["Name"][0]
+            return topic["Name"][0]