Skip to content

Commit

Permalink
create async task instead of await for log10 post calls, require to g…
Browse files Browse the repository at this point in the history
…ather all

pending tasks at the end
  • Loading branch information
wenzhe-log10 committed Jun 4, 2024
1 parent 59a22e5 commit 0931b01
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
4 changes: 3 additions & 1 deletion examples/logging/openai_async_stream_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import openai
from openai import AsyncOpenAI

from log10._httpx_utils import gather_pending_async_tasks
from log10.load import log10


Expand All @@ -14,11 +15,12 @@
async def main():
stream = await client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Count to 50."}],
messages=[{"role": "user", "content": "Count to 20."}],
stream=True,
)
async for chunk in stream:
print(chunk.choices[0].delta.content or "", end="", flush=True)
await gather_pending_async_tasks()


asyncio.run(main())
18 changes: 15 additions & 3 deletions log10/_httpx_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import time
Expand Down Expand Up @@ -220,7 +221,7 @@ async def log_request(request: Request):
}
if get_log10_session_tags():
log_row["tags"] = get_log10_session_tags()
await _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row)
asyncio.create_task(_try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row))


class _LogResponse(Response):
Expand Down Expand Up @@ -304,7 +305,10 @@ async def aiter_bytes(self, *args, **kwargs):
}
if get_log10_session_tags():
log_row["tags"] = get_log10_session_tags()
await _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row)
asyncio.create_task(
_try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row)
)

yield chunk

def is_response_end_reached(self, text: str):
Expand Down Expand Up @@ -502,7 +506,9 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
}
if get_log10_session_tags():
log_row["tags"] = get_log10_session_tags()
await _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row)
asyncio.create_task(
_try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row)
)
return response
elif response.headers.get("content-type").startswith("text/event-stream"):
return _LogResponse(
Expand All @@ -515,3 +521,9 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:

# In case of an error, get out of the way
return response


async def gather_pending_async_tasks():
pending = asyncio.all_tasks()
pending.remove(asyncio.current_task())
await asyncio.gather(*pending)

0 comments on commit 0931b01

Please sign in to comment.