diff --git a/examples/logging/openai_async_stream_logging.py b/examples/logging/openai_async_stream_logging.py index 960dc784..526cbebf 100644 --- a/examples/logging/openai_async_stream_logging.py +++ b/examples/logging/openai_async_stream_logging.py @@ -3,6 +3,7 @@ import openai from openai import AsyncOpenAI +from log10._httpx_utils import gather_pending_async_tasks from log10.load import log10 @@ -14,11 +15,12 @@ async def main(): stream = await client.chat.completions.create( model="gpt-4", - messages=[{"role": "user", "content": "Count to 50."}], + messages=[{"role": "user", "content": "Count to 20."}], stream=True, ) async for chunk in stream: print(chunk.choices[0].delta.content or "", end="", flush=True) + await gather_pending_async_tasks() asyncio.run(main()) diff --git a/log10/_httpx_utils.py b/log10/_httpx_utils.py index 00402561..dc191d55 100644 --- a/log10/_httpx_utils.py +++ b/log10/_httpx_utils.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import time @@ -220,7 +221,7 @@ async def log_request(request: Request): } if get_log10_session_tags(): log_row["tags"] = get_log10_session_tags() - await _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row) + asyncio.create_task(_try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row)) class _LogResponse(Response): @@ -304,7 +305,10 @@ async def aiter_bytes(self, *args, **kwargs): } if get_log10_session_tags(): log_row["tags"] = get_log10_session_tags() - await _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row) + asyncio.create_task( + _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row) + ) + yield chunk def is_response_end_reached(self, text: str): @@ -502,7 +506,9 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: } if get_log10_session_tags(): log_row["tags"] = get_log10_session_tags() - await _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row) + asyncio.create_task( + _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=log_row) + ) return response elif response.headers.get("content-type").startswith("text/event-stream"): return _LogResponse( @@ -515,3 +521,9 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: # In case of an error, get out of the way return response + + +async def gather_pending_async_tasks(): + pending = asyncio.all_tasks() + pending.remove(asyncio.current_task()) + await asyncio.gather(*pending)