Skip to content

Commit

Permalink
use finalize in the test
Browse files Browse the repository at this point in the history
  • Loading branch information
wenzhe-log10 committed Jun 6, 2024
1 parent 3d744f2 commit 914565b
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 22 deletions.
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import pytest_asyncio

from log10.load import log10_session

Expand Down Expand Up @@ -68,3 +69,10 @@ def session():
with log10_session() as session:
assert session.last_completion_id() is None, "No completion ID should be found."
yield session


@pytest_asyncio.fixture()
def async_session():
    """Yield a fresh log10 session for async tests.

    Verifies the session starts with no logged completion before handing
    it to the test, mirroring the synchronous ``session`` fixture.
    """
    with log10_session() as current:
        assert current.last_completion_id() is None, "No completion ID should be found."
        yield current
22 changes: 14 additions & 8 deletions tests/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from anthropic.lib.streaming.beta import AsyncToolsBetaMessageStream
from typing_extensions import override

from log10._httpx_utils import finalize
from log10.load import log10
from tests.utils import _LogAssertion

Expand Down Expand Up @@ -35,7 +36,7 @@ def test_messages_create(session, anthropic_model):
@pytest.mark.chat
@pytest.mark.async_client
@pytest.mark.asyncio
async def test_messages_create_async(session, anthropic_model):
async def test_messages_create_async(async_session, anthropic_model):
client = anthropic.AsyncAnthropic()

message = await client.messages.create(
Expand All @@ -48,7 +49,9 @@ async def test_messages_create_async(session, anthropic_model):

text = message.content[0].text
assert isinstance(text, str)
_LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response()

await finalize()
_LogAssertion(completion_id=async_session.last_completion_id(), message_content=text).assert_chat_response()


@pytest.mark.chat
Expand Down Expand Up @@ -153,7 +156,7 @@ def test_beta_tools_messages_create(session, anthropic_model):
@pytest.mark.chat
@pytest.mark.async_client
@pytest.mark.asyncio
async def test_beta_tools_messages_create_async(session, anthropic_model):
async def test_beta_tools_messages_create_async(async_session, anthropic_model):
client = anthropic.AsyncAnthropic()

message = await client.beta.tools.messages.create(
Expand All @@ -163,7 +166,8 @@ async def test_beta_tools_messages_create_async(session, anthropic_model):
)

text = message.content[0].text
_LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response()
await finalize()
_LogAssertion(completion_id=async_session.last_completion_id(), message_content=text).assert_chat_response()


@pytest.mark.chat
Expand Down Expand Up @@ -197,7 +201,7 @@ def test_messages_stream_context_manager(session, anthropic_model):
@pytest.mark.context_manager
@pytest.mark.async_client
@pytest.mark.asyncio
async def test_messages_stream_context_manager_async(session, anthropic_model):
async def test_messages_stream_context_manager_async(async_session, anthropic_model):
client = anthropic.AsyncAnthropic()

output = ""
Expand All @@ -214,7 +218,8 @@ async def test_messages_stream_context_manager_async(session, anthropic_model):
async for text in stream.text_stream:
output += text

_LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
await finalize()
_LogAssertion(completion_id=async_session.last_completion_id(), message_content=output).assert_chat_response()


@pytest.mark.tools
Expand Down Expand Up @@ -262,7 +267,7 @@ def test_tools_messages_stream_context_manager(session, anthropic_model):
@pytest.mark.context_manager
@pytest.mark.async_client
@pytest.mark.asyncio
async def test_tools_messages_stream_context_manager_async(session, anthropic_model):
async def test_tools_messages_stream_context_manager_async(async_session, anthropic_model):
client = anthropic.AsyncAnthropic()
json_snapshot = None
final_message = None
Expand Down Expand Up @@ -307,7 +312,8 @@ async def on_input_json(self, delta: str, snapshot: object) -> None:
if json_snapshot:
output += json.dumps(json_snapshot)

await finalize()
assert output, "No output from the model."
assert session.last_completion_id(), "No completion ID found."
assert async_session.last_completion_id(), "No completion ID found."
## TODO fix this test after the anthropic fixes for the tool_calls
# _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
9 changes: 7 additions & 2 deletions tests/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import litellm
import pytest

from log10._httpx_utils import finalize
from log10.litellm import Log10LitellmLogger
from tests.utils import _LogAssertion

Expand Down Expand Up @@ -50,6 +51,7 @@ async def test_completion_async_stream(anthropic_model):

## This test doesn't get completion_id from the session
## and logged a couple times during debug mode, punt this for now
await finalize()
assert output, "No output from the model."


Expand Down Expand Up @@ -78,6 +80,8 @@ def test_image(session, openai_vision_model):
content = resp.choices[0].message.content
assert isinstance(content, str)

# Wait for the completion to be logged
time.sleep(3)
_LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()


Expand Down Expand Up @@ -118,7 +122,7 @@ def test_image_stream(session, anthropic_model):
@pytest.mark.stream
@pytest.mark.vision
@pytest.mark.asyncio
async def test_image_async_stream(session, anthropic_model):
async def test_image_async_stream(async_session, anthropic_model):
image_url = "https://upload.wikimedia.org/wikipedia/commons/e/e8/Log10.png"
image_media_type = "image/png"
image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
Expand Down Expand Up @@ -149,4 +153,5 @@ async def test_image_async_stream(session, anthropic_model):
output += chunk.choices[0].delta.content

time.sleep(3)
_LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
await finalize()
_LogAssertion(completion_id=async_session.last_completion_id(), message_content=output).assert_chat_response()
19 changes: 13 additions & 6 deletions tests/test_magentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from magentic import AsyncParallelFunctionCall, AsyncStreamedStr, FunctionCall, OpenaiChatModel, StreamedStr, prompt
from pydantic import BaseModel

from log10._httpx_utils import finalize
from log10.load import log10, log10_session
from tests.utils import _LogAssertion, format_magentic_function_args

Expand Down Expand Up @@ -58,7 +59,7 @@ def configure_oven(food: str) -> FunctionCall[str]: # ruff: ignore
@pytest.mark.async_client
@pytest.mark.stream
@pytest.mark.asyncio
async def test_async_stream_logging(session, magentic_model):
async def test_async_stream_logging(async_session, magentic_model):
@prompt("Tell me a 50-word story about {topic}", model=OpenaiChatModel(model=magentic_model))
async def tell_story(topic: str) -> AsyncStreamedStr: # ruff: ignore
...
Expand All @@ -68,13 +69,14 @@ async def tell_story(topic: str) -> AsyncStreamedStr: # ruff: ignore
async for chunk in output:
result += chunk

_LogAssertion(completion_id=session.last_completion_id(), message_content=result).assert_chat_response()
await finalize()
_LogAssertion(completion_id=async_session.last_completion_id(), message_content=result).assert_chat_response()


@pytest.mark.async_client
@pytest.mark.tools
@pytest.mark.asyncio
async def test_async_parallel_stream_logging(session, magentic_model):
async def test_async_parallel_stream_logging(async_session, magentic_model):
def plus(a: int, b: int) -> int:
return a + b

Expand All @@ -94,8 +96,9 @@ async def plus_and_minus(a: int, b: int) -> AsyncParallelFunctionCall[int]: ...
result.append(chunk)

function_args = format_magentic_function_args(result)
await finalize()
_LogAssertion(
completion_id=session.last_completion_id(), function_args=function_args
completion_id=async_session.last_completion_id(), function_args=function_args
).assert_function_call_response()


Expand All @@ -115,6 +118,7 @@ async def do_math_with_llm_async(a: int, b: int) -> AsyncStreamedStr: # ruff: i
async for chunk in result:
output += chunk

await finalize()
_LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()

final_output += output
Expand All @@ -124,6 +128,7 @@ async def do_math_with_llm_async(a: int, b: int) -> AsyncStreamedStr: # ruff: i
async for chunk in result:
output += chunk

await finalize()
_LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()

final_output += output
Expand All @@ -133,13 +138,14 @@ async def do_math_with_llm_async(a: int, b: int) -> AsyncStreamedStr: # ruff: i
async for chunk in result:
output += chunk

await finalize()
_LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()


@pytest.mark.async_client
@pytest.mark.widget
@pytest.mark.asyncio
async def test_async_widget(session, magentic_model):
async def test_async_widget(async_session, magentic_model):
class WidgetInfo(BaseModel):
title: str
description: str
Expand All @@ -166,6 +172,7 @@ async def _generate_title_and_description(query: str, widget_data: str) -> Widge
arguments = {"title": r.title, "description": r.description}

function_args = [{"name": "return_widgetinfo", "arguments": str(arguments)}]
await finalize()
_LogAssertion(
completion_id=session.last_completion_id(), function_args=function_args
completion_id=async_session.last_completion_id(), function_args=function_args
).assert_function_call_response()
16 changes: 10 additions & 6 deletions tests/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from openai import NOT_GIVEN, AsyncOpenAI

from log10._httpx_utils import finalize
from log10.load import log10
from tests.utils import _LogAssertion, format_function_args

Expand Down Expand Up @@ -59,7 +60,7 @@ def test_chat_not_given(session, openai_model):
@pytest.mark.chat
@pytest.mark.async_client
@pytest.mark.asyncio
async def test_chat_async(session, openai_model):
async def test_chat_async(async_session, openai_model):
client = AsyncOpenAI()
completion = await client.chat.completions.create(
model=openai_model,
Expand All @@ -68,7 +69,8 @@ async def test_chat_async(session, openai_model):

content = completion.choices[0].message.content
assert isinstance(content, str)
_LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
await finalize()
_LogAssertion(completion_id=async_session.last_completion_id(), message_content=content).assert_chat_response()


@pytest.mark.chat
Expand All @@ -91,7 +93,7 @@ def test_chat_stream(session, openai_model):
@pytest.mark.async_client
@pytest.mark.stream
@pytest.mark.asyncio
async def test_chat_async_stream(session, openai_model):
async def test_chat_async_stream(async_session, openai_model):
client = AsyncOpenAI()

output = ""
Expand All @@ -103,7 +105,8 @@ async def test_chat_async_stream(session, openai_model):
async for chunk in stream:
output += chunk.choices[0].delta.content or ""

_LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
await finalize()
_LogAssertion(completion_id=async_session.last_completion_id(), message_content=output).assert_chat_response()


@pytest.mark.vision
Expand Down Expand Up @@ -276,7 +279,7 @@ def test_tools_stream(session, openai_model):
@pytest.mark.stream
@pytest.mark.async_client
@pytest.mark.asyncio
async def test_tools_stream_async(session, openai_model):
async def test_tools_stream_async(async_session, openai_model):
client = AsyncOpenAI()
# Step 1: send the conversation and available functions to the model
result = setup_tools_messages()
Expand All @@ -303,6 +306,7 @@ async def test_tools_stream_async(session, openai_model):
function_args = format_function_args(tool_calls)
assert len(function_args) == 3

await finalize()
_LogAssertion(
completion_id=session.last_completion_id(), function_args=function_args
completion_id=async_session.last_completion_id(), function_args=function_args
).assert_function_call_response()

0 comments on commit 914565b

Please sign in to comment.