Skip to content

Commit

Permalink
Handle async anthropic function call via httpx hook
Browse files Browse the repository at this point in the history
  • Loading branch information
kxtran committed May 23, 2024
1 parent 59817f0 commit 9563ed4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 30 deletions.
65 changes: 58 additions & 7 deletions log10/_httpx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ async def _try_post_request_async(url: str, payload: dict = {}) -> httpx.Respons
logger.error(f"Failed to insert in log10: {payload} with error {err}")


def format_anthropic_tools_request(request_content: dict) -> str:
    """Serialize an Anthropic request body with its tools in function-call form.

    Each entry of ``request_content["tools"]`` is rewritten from the Anthropic
    layout (``name`` / ``description`` / ``input_schema``) into
    ``{"type": "function", "function": {"name", "description", "parameters"}}``.
    All other keys of the request are passed through untouched.

    Args:
        request_content: The decoded JSON body of the outgoing request.

    Returns:
        The JSON-encoded request body with the converted ``tools`` list. If the
        request has no ``tools`` key it is serialized as-is.

    Raises:
        KeyError: If a tool lacks ``name`` or ``input_schema``.
    """
    if "tools" not in request_content:
        return json.dumps(request_content)

    new_tools = []
    for tool in request_content["tools"]:
        function = {"name": tool["name"]}
        # A tool may be declared without a description; leave the key out
        # instead of failing the whole log call with a KeyError.
        if "description" in tool:
            function["description"] = tool["description"]
        function["parameters"] = tool["input_schema"]
        new_tools.append({"type": "function", "function": function})

    # Build a new mapping rather than assigning into the caller's dict, so
    # formatting a payload for logging never alters the parsed request.
    # Replacing an existing key keeps its original position in the output.
    return json.dumps({**request_content, "tools": new_tools})


async def get_completion_id(request: Request):
host = request.headers.get("host")
if "anthropic" in host and "/v1/messages" not in str(request.url):
Expand All @@ -130,6 +142,7 @@ async def log_request(request: Request):

orig_module = ""
orig_qualname = ""
request_content_decode = request.content.decode("utf-8")
host = request.headers.get("host")
if "openai" in host:
if "chat" in str(request.url):
Expand All @@ -145,6 +158,14 @@ async def log_request(request: Request):
kind = "chat"
orig_module = "anthropic.resources.beta.tools"
orig_qualname = "Messages.stream"
request_content = json.loads(request_content_decode)
if "tools" in request_content:
orig_module = "anthropic.resources.beta.tools"
orig_qualname = "Messages.stream"
request_content_decode = format_anthropic_tools_request(request_content)
else:
orig_module = "anthropic.resources.messages"
orig_qualname = "Messages.stream"
else:
logger.warning("Currently logging is only available for anthropic async")
return
Expand All @@ -156,7 +177,7 @@ async def log_request(request: Request):
"kind": kind,
"orig_module": orig_module,
"orig_qualname": orig_qualname,
"request": request.content.decode("utf-8"),
"request": request_content_decode,
"session_id": sessionID,
}
if get_log10_session_tags():
Expand Down Expand Up @@ -198,7 +219,6 @@ async def aiter_bytes(self, *args, **kwargs):
r_json = self.parse_response_data(responses)

response_json = r_json.copy()
response_json["object"] = "chat.completion"
# r_json is the last response before "data: [DONE]"

if self.full_content:
Expand All @@ -216,13 +236,19 @@ async def aiter_bytes(self, *args, **kwargs):
"arguments": self.full_argument,
}

request_content_decode = self.request.content.decode("utf-8")
if "anthropic" in self.request.headers.get("host"):
request_content = json.loads(request_content_decode)
if "tools" in request_content:
request_content_decode = format_anthropic_tools_request(request_content)

log_row = {
"response": json.dumps(response_json),
"status": "finished",
"duration": duration,
"stacktrace": json.dumps(stacktrace),
"kind": "chat",
"request": self.request.content.decode("utf-8"),
"request": request_content_decode,
"session_id": sessionID,
}
if get_log10_session_tags():
Expand Down Expand Up @@ -252,9 +278,10 @@ def parse_anthropic_responses(self, responses: list[str]):
finish_reason = None
input_tokens = 0
output_tokens = 0

tool_call = {}
arguments = ""
for r in responses:
if self.is_anthropic_response_end_reached(r):
if not r:
break

data_index = r.find("data:")
Expand All @@ -268,12 +295,35 @@ def parse_anthropic_responses(self, responses: list[str]):
model = r_json["message"]["model"]
input_tokens = r_json["message"]["usage"]["input_tokens"]
elif type == "content_block_start":
self.full_content += r_json["content_block"]["text"]
content_block = r_json["content_block"]
type = content_block["type"]
if type == "tool_use":
id = content_block["id"]
tool_call = {
"id": id,
"type": "function",
"function": {"name": content_block["name"], "arguments": ""},
}
if "text" in content_block:
self.full_content += content_block["text"]
elif type == "content_block_delta":
self.full_content += r_json["delta"]["text"]
delta = r_json["delta"]
if "text" in delta:
self.full_content += delta["text"]
if "partial_json" in delta:
if self.full_content:
self.full_content += delta["partial_json"]
else:
arguments += delta["partial_json"]
elif type == "message_delta":
finish_reason = r_json["delta"]["stop_reason"]
output_tokens = r_json["usage"]["output_tokens"]
elif type == "content_block_end" or type == "message_end":
if tool_call:
tool_call["function"]["arguments"] = arguments
self.tool_calls.append(tool_call)
tool_call = {}
arguments = ""

return {
"id": message_id,
Expand Down Expand Up @@ -325,6 +375,7 @@ def parse_openai_responses(self, responses: list[str]):
idx = tc[0].get("index")
self.tool_calls[idx]["function"]["arguments"] += tc[0]["function"]["arguments"]

r_json["object"] = "chat.completion"
return r_json

def parse_response_data(self, responses: list[str]):
Expand Down
23 changes: 0 additions & 23 deletions log10/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,29 +403,6 @@ def __next__(self):
self._process_chunk(chunk)
return chunk

async def __aenter__(self):
self.response = await self.response.__aenter__()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.response.__aexit__(exc_type, exc_value, traceback)
return

def __aiter__(self):
return self

async def __anext__(self):
try:
chunk = await self.response.__anext__()
self._process_chunk(chunk)
return chunk
except StopAsyncIteration:
raise StopAsyncIteration from None

async def until_done(self):
async for _ in self:
pass

def _process_chunk(self, chunk):
if chunk.type == "message_start":
self.model = chunk.message.model
Expand Down

0 comments on commit 9563ed4

Please sign in to comment.