Skip to content

Commit

Permalink
fix: support non-textual requests (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Dec 13, 2024
1 parent 1046b51 commit 492e8a2
Show file tree
Hide file tree
Showing 5 changed files with 437 additions and 74 deletions.
111 changes: 68 additions & 43 deletions aidial_analytics_realtime/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from influxdb_client import Point
from langid.langid import LanguageIdentifier, model
from typing_extensions import assert_never

from aidial_analytics_realtime.dial import (
get_chat_completion_request_contents,
get_chat_completion_response_contents,
get_embeddings_request_contents,
)
from aidial_analytics_realtime.rates import RatesCalculator
from aidial_analytics_realtime.topic_model import TopicModel

Expand All @@ -19,44 +25,52 @@ class RequestType(Enum):
EMBEDDING = 2


def detect_lang(request, response, request_type):
if request_type == RequestType.CHAT_COMPLETION:
text = (
request["messages"][-1]["content"]
+ "\n\n"
+ response["choices"][0]["message"]["content"]
)
else:
text = (
request["input"]
if isinstance(request["input"], str)
else "\n\n".join(request["input"])
)
def detect_lang(
logger: Logger, request: dict, response: dict, request_type: RequestType
) -> str:
match request_type:
case RequestType.CHAT_COMPLETION:
request_contents = get_chat_completion_request_contents(
logger, request
)
response_content = get_chat_completion_response_contents(
logger, response
)
text = "\n\n".join(request_contents[-1:] + response_content)
case RequestType.EMBEDDING:
text = "\n\n".join(get_embeddings_request_contents(logger, request))
case _:
assert_never(request_type)

return detect_lang_by_text(text)
return to_string(detect_lang_by_text(text))


def detect_lang_by_text(text: str) -> str | None:
    """Classify the language of *text* with the langid identifier.

    Returns None — rather than raising — when the text is blank, when the
    classifier errors out, or when its confidence is at or below 0.998;
    callers map None to the "undefined" sentinel via to_string().
    """
    text = text.strip()

    if not text:
        return None

    try:
        lang, prob = identifier.classify(text)

        # High threshold: only report a language we are nearly certain about.
        if prob > 0.998:
            return lang
    except Exception:
        # Best-effort detection: classification failure is not an error.
        pass

    return None


def to_string(obj: str | None):
return obj if obj else "undefined"
def to_string(obj: str | None) -> str:
return obj or "undefined"


def build_execution_path(path: list | None):
    """Render an execution path as a "/"-joined string.

    A missing or empty path becomes "undefined"; individual None segments
    are likewise rendered as "undefined" via to_string().
    """
    if not path:
        return "undefined"
    return "/".join(to_string(segment) for segment in path)


def make_point(
logger: Logger,
deployment: str,
model: str,
project_id: str,
Expand All @@ -78,26 +92,33 @@ def make_point(
topic = None
response_content = ""
request_content = ""
if request_type == RequestType.CHAT_COMPLETION:
response_content = response["choices"][0]["message"]["content"]
request_content = "\n".join(
[message["content"] for message in request["messages"]]
)
if chat_id:
topic = topic_model.get_topic(request["messages"], response_content)
else:
request_content = (
request["input"]
if isinstance(request["input"], str)
else "\n".join(request["input"])
)
if chat_id:
topic = topic_model.get_topic_by_text(
request["input"]
if isinstance(request["input"], str)
else "\n\n".join(request["input"])
match request_type:
case RequestType.CHAT_COMPLETION:
response_contents = get_chat_completion_response_contents(
logger, response
)
request_contents = get_chat_completion_request_contents(
logger, request
)

request_content = "\n".join(request_contents)
response_content = "\n".join(response_contents)

if chat_id:
topic = topic_model.get_topic_by_text(
"\n\n".join(request_contents + response_contents)
)
case RequestType.EMBEDDING:
request_contents = get_embeddings_request_contents(logger, request)

request_content = "\n".join(request_contents)
if chat_id:
topic = topic_model.get_topic_by_text(
"\n\n".join(request_contents)
)
case _:
assert_never(request_type)

price = Decimal(0)
deployment_price = Decimal(0)
if usage is not None and usage.get("price") is not None:
Expand Down Expand Up @@ -126,7 +147,7 @@ def make_point(
(
"undefined"
if not trace
else to_string(trace.get("core_parent_span_id", None))
else to_string(trace.get("core_parent_span_id"))
),
)
.tag("project_id", project_id)
Expand All @@ -135,7 +156,7 @@ def make_point(
(
"undefined"
if not chat_id
else detect_lang(request, response, request_type)
else detect_lang(logger, request, response, request_type)
),
)
.tag("upstream", to_string(upstream_url))
Expand Down Expand Up @@ -212,7 +233,7 @@ def make_rate_point(


async def parse_usage_per_model(response: dict):
statistics = response.get("statistics", None)
statistics = response.get("statistics")
if statistics is None:
return []

Expand Down Expand Up @@ -252,6 +273,7 @@ async def on_message(
usage_per_model = await parse_usage_per_model(response)
if token_usage is not None:
point = make_point(
logger,
deployment,
model,
project_id,
Expand All @@ -273,6 +295,7 @@ async def on_message(
await influx_writer(point)
elif len(usage_per_model) == 0:
point = make_point(
logger,
deployment,
model,
project_id,
Expand All @@ -284,7 +307,7 @@ async def on_message(
request,
response,
type,
response.get("usage", None),
response.get("usage"),
topic_model,
rates_calculator,
parent_deployment,
Expand All @@ -294,6 +317,7 @@ async def on_message(
await influx_writer(point)
else:
point = make_point(
logger,
deployment,
model,
project_id,
Expand All @@ -316,6 +340,7 @@ async def on_message(

for usage in usage_per_model:
point = make_point(
logger,
deployment,
usage["model"],
project_id,
Expand Down
25 changes: 17 additions & 8 deletions aidial_analytics_realtime/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import uvicorn
from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse

from aidial_analytics_realtime.analytics import (
RequestType,
Expand Down Expand Up @@ -216,8 +217,7 @@ async def on_log_message(
execution_path = message.get("execution_path", None)
deployment = message.get("deployment", "")

match = re.search(RATE_PATTERN, uri)
if match:
if re.search(RATE_PATTERN, uri):
await on_rate_message(
deployment,
project_id,
Expand All @@ -230,8 +230,7 @@ async def on_log_message(
influx_writer,
)

match = re.search(CHAT_COMPLETION_PATTERN, uri)
if match:
elif re.search(CHAT_COMPLETION_PATTERN, uri):
await on_chat_completion_message(
deployment,
project_id,
Expand All @@ -251,8 +250,7 @@ async def on_log_message(
execution_path,
)

match = re.search(EMBEDDING_PATTERN, uri)
if match:
elif re.search(EMBEDDING_PATTERN, uri):
await on_embedding_message(
deployment,
project_id,
Expand All @@ -272,6 +270,9 @@ async def on_log_message(
execution_path,
)

else:
logger.warning(f"Unsupported message type: {uri!r}")


@app.post("/data")
async def on_log_messages(
Expand All @@ -282,7 +283,8 @@ async def on_log_messages(
):
data = await request.json()

for item in data:
statuses = []
for idx, item in enumerate(data):
try:
await on_log_message(
json.loads(item["message"]),
Expand All @@ -291,7 +293,14 @@ async def on_log_messages(
rates_calculator,
)
except Exception as e:
logging.exception(e)
logging.exception(f"Error processing message #{idx}")
statuses.append({"status": "error", "error": str(e)})
else:
statuses.append({"status": "success"})

# Returning 200 code even if processing of some messages has failed,
# since the log broker that sends the messages may decide to retry the failed requests.
return JSONResponse(content=statuses, status_code=200)


@app.get("/health")
Expand Down
53 changes: 53 additions & 0 deletions aidial_analytics_realtime/dial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from logging import Logger
from typing import List


def get_chat_completion_request_contents(
    logger: Logger, request: dict
) -> List[str]:
    """Collect the textual contents of every message in a chat request."""
    contents: List[str] = []
    for message in request["messages"]:
        contents.extend(_get_chat_completion_message_contents(logger, message))
    return contents


def get_chat_completion_response_contents(
    logger: Logger, response: dict
) -> List[str]:
    """Extract the textual contents of the first choice of a chat response."""
    first_choice_message = response["choices"][0]["message"]
    return _get_chat_completion_message_contents(logger, first_choice_message)


def get_embeddings_request_contents(logger: Logger, request: dict) -> List[str]:
    """Extract the textual inputs of an embeddings request.

    A plain-string input yields a one-element list; a list input is filtered
    down to its string elements (non-strings, e.g. token arrays, are
    silently skipped); any other type is logged and yields an empty list.
    """
    inp = request.get("input")

    if isinstance(inp, str):
        return [inp]

    if isinstance(inp, list):
        return [item for item in inp if isinstance(item, str)]

    logger.warning(f"Unexpected type of embeddings input: {type(inp)}")
    return []


def _get_chat_completion_message_contents(
logger: Logger, message: dict
) -> List[str]:
content = message.get("content")
if content is None:
return []
elif isinstance(content, str):
return [content]
elif isinstance(content, list):
ret: List[str] = []
for content_part in content:
if isinstance(content_part, dict):
if content_part.get("type") == "text" and (
text := content_part.get("content")
):
ret.extend(text)
return ret
else:
logger.warning(f"Unexpected message content type: {type(content)}")
return []
8 changes: 0 additions & 8 deletions aidial_analytics_realtime/topic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@ def __init__(
)
self.model.transform(["test"]) # Make sure the model is loaded

def get_topic(self, request_messages, response_content):
text = "\n\n".join(
[message["content"] for message in request_messages]
+ [response_content]
)

return self.get_topic_by_text(text)

def get_topic_by_text(self, text):
topics, _ = self.model.transform([text])
topic = self.model.get_topic_info(topics[0])
Expand Down
Loading

0 comments on commit 492e8a2

Please sign in to comment.