ENG-784 Add anthropic async and tools stream api support #162

Merged: 18 commits, May 28, 2024
Changes from 7 commits
35 changes: 35 additions & 0 deletions examples/logging/anthropic_tools_stream.py
@@ -0,0 +1,35 @@
import anthropic

from log10.load import log10


log10(anthropic)


client = anthropic.Anthropic()

with client.beta.tools.messages.stream(
model="claude-3-haiku-20240307",
tools=[
{
"name": "get_weather",
"description": "Get the weather at a specific location",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Unit for the output",
},
},
"required": ["location"],
},
}
],
messages=[{"role": "user", "content": "What is the weather in SF?"}],
max_tokens=1024,
) as stream:
for message in stream:
print(message)
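
The PR also patches the async variant of this API (`AsyncMessages.stream`, wired up in `log10/load.py` below), so the corresponding async usage gets logged the same way. A minimal sketch of that async form, mirroring the sync example above and assuming the same beta tools API; the tool schema is abbreviated to one property:

import asyncio

import anthropic

from log10.load import log10


log10(anthropic)


client = anthropic.AsyncAnthropic()

weather_tool = {
    "name": "get_weather",
    "description": "Get the weather at a specific location",
    "input_schema": {
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
        },
        "required": ["location"],
    },
}


async def main():
    # Async mirror of the sync example: async context manager, async iteration.
    async with client.beta.tools.messages.stream(
        model="claude-3-haiku-20240307",
        tools=[weather_tool],
        messages=[{"role": "user", "content": "What is the weather in SF?"}],
        max_tokens=1024,
    ) as stream:
        async for message in stream:
            print(message)


asyncio.run(main())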
69 changes: 64 additions & 5 deletions log10/load.py
@@ -374,6 +374,14 @@ class AnthropicStreamingResponseWrapper:
Wraps a streaming response object to log the final result and duration to log10.
"""

def __enter__(self):
self.response = self.response.__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback):
self.response.__exit__(exc_type, exc_value, traceback)
return

def __init__(self, completion_url, completionID, response, partial_log_row):
self.completionID = completionID
self.completion_url = completion_url
@@ -392,16 +400,49 @@ def __iter__(self):

def __next__(self):
chunk = next(self.response)
self._process_chunk(chunk)
return chunk

async def __aenter__(self):
self.response = await self.response.__aenter__()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.response.__aexit__(exc_type, exc_value, traceback)
return

def __aiter__(self):
return self

async def __anext__(self):
try:
chunk = await self.response.__anext__()
self._process_chunk(chunk)
return chunk
except StopAsyncIteration:
raise StopAsyncIteration from None

async def until_done(self):
async for _ in self:
pass

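    # Chunk handling below follows Anthropic's streaming event sequence:
    # message_start carries the model, message id, and input-token usage;
    # content_block_delta appends incremental text (or a tool call's
    # partial_json); message_delta carries the stop reason and output-token
    # usage; message_stop / content_block_stop assemble the final response
    # and post it to the log10 completion URL.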
def _process_chunk(self, chunk):
if chunk.type == "message_start":
self.model = chunk.message.model
self.message_id = chunk.message.id
self.input_tokens = chunk.message.usage.input_tokens
if chunk.type == "content_block_start":
if hasattr(chunk.content_block, "text"):
self.final_result += chunk.content_block.text
elif chunk.type == "message_delta":
self.finish_reason = chunk.delta.stop_reason
self.output_tokens = chunk.usage.output_tokens
elif chunk.type == "content_block_delta":
self.final_result += chunk.delta.text
elif chunk.type == "message_stop":
if hasattr(chunk.delta, "text"):
self.final_result += chunk.delta.text
if hasattr(chunk.delta, "partial_json"):
self.final_result += chunk.delta.partial_json
elif chunk.type == "message_stop" or chunk.type == "content_block_stop":
response = {
"id": self.message_id,
"object": "chat",
@@ -429,8 +470,6 @@ def __next__(self):
if res.status_code != 200:
logger.error(f"Failed to insert in log10: {self.partial_log_row} with error {res.text}. Skipping")



def flatten_messages(messages):
flat_messages = []
@@ -506,6 +545,17 @@ def _init_log_row(func, *args, **kwargs):
else:
new_content.append(c)
m["content"] = new_content
if "tools" in kwargs_copy:
for t in kwargs_copy["tools"]:
new_function = {
"name": t["name"],
"description": t["description"],
"parameters": {
"properties": t["input_schema"]["properties"],
},
}
t["function"] = new_function
t.pop("input_schema", None)
elif "vertexai" in func.__module__:
if func.__name__ == "_send_message":
# get model name save in ChatSession instance
@@ -639,7 +689,7 @@ def wrapper(*args, **kwargs):
response = output
# Adjust the Anthropic output to match OAI completion output
if "anthropic" in func.__module__:
if type(output).__name__ == "Stream":
if type(output).__name__ == "Stream" or "MessageStreamManager" in type(output).__name__:
log_row["response"] = response
log_row["status"] = "finished"
return AnthropicStreamingResponseWrapper(
@@ -648,6 +698,7 @@
response=response,
partial_log_row=log_row,
)

from log10.anthropic import Anthropic

response = Anthropic.prepare_response(output, input_prompt=kwargs.get("prompt", ""))
@@ -892,6 +943,14 @@ def log10(module, DEBUG_=False, USE_ASYNC_=True):
attr = module.resources.messages.Messages
method = getattr(attr, "create")
setattr(attr, "create", intercepting_decorator(method))

attr = module.resources.beta.tools.Messages
method = getattr(attr, "stream")
setattr(attr, "stream", intercepting_decorator(method))

attr = module.resources.beta.tools.AsyncMessages
method = getattr(attr, "stream")
setattr(attr, "stream", intercepting_decorator(method))
elif module.__name__ == "lamini":
attr = module.api.utils.completion.Completion
method = getattr(attr, "generate")
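For context on the `log10()` hook added above: the pattern is plain method interception, fetching a class attribute with `getattr` and replacing it with a decorated version via `setattr`. A self-contained sketch of that pattern with generic stand-in names (hypothetical wrapper, not log10's actual internals):

import functools
import time


def intercepting_decorator(method):
    # Hypothetical stand-in for log10's wrapper: time the call, log it,
    # and return the original result unchanged.
    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = method(*args, **kwargs)
        print(f"{method.__qualname__} took {time.perf_counter() - start:.3f}s")
        return result

    return wrapper


class Messages:
    def stream(self, **kwargs):
        return iter(["chunk-1", "chunk-2"])


# The same getattr/setattr move the PR applies to
# module.resources.beta.tools.Messages and AsyncMessages:
method = getattr(Messages, "stream")
setattr(Messages, "stream", intercepting_decorator(method))

for chunk in Messages().stream(model="demo"):
    print(chunk)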
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.