From 8b2a85aec438ed822a5b4c782a53e312fbc27aca Mon Sep 17 00:00:00 2001
From: Kim Tran <17498395+kxtran@users.noreply.github.com>
Date: Mon, 3 Jun 2024 12:58:24 -0400
Subject: [PATCH] ENG-821 Verify generated completions submitted to the platform (#172)

* Add getting last_completion_id from httpx hook call and load async call
* Add _LogAssertion utils class to input and response
* Add _LogAssertion for anthropic tests
* Add _LogAssertion to the rest of the tests and adjust some assertions
* Format styles
* Increase timeout for litellm tests
* Address feedback
* Add checking for last_completion_url

---------

Co-authored-by: Niklas Nielsen
---
 log10/_httpx_utils.py   |  3 +-
 log10/litellm.py        |  7 +++-
 log10/load.py           | 16 ++++++--
 tests/conftest.py       |  9 +++++
 tests/test_anthropic.py | 64 ++++++++++++++++++------------
 tests/test_google.py    |  9 +++--
 tests/test_lamini.py    |  4 +-
 tests/test_langchain.py |  8 ++--
 tests/test_litellm.py   | 31 +++++++++------
 tests/test_magentic.py  | 75 ++++++++++++++++++++++-------------
 tests/test_mistralai.py | 10 ++---
 tests/test_openai.py    | 69 ++++++++++++++++++++-------------
 tests/utils.py          | 86 +++++++++++++++++++++++++++++++++++++++++
 13 files changed, 283 insertions(+), 108 deletions(-)
 create mode 100644 tests/utils.py

diff --git a/log10/_httpx_utils.py b/log10/_httpx_utils.py
index 6ca18024..cab1e1c5 100644
--- a/log10/_httpx_utils.py
+++ b/log10/_httpx_utils.py
@@ -9,7 +9,7 @@
 from httpx import Request, Response
 
 from log10.llm import Log10Config
-from log10.load import get_log10_session_tags, session_id_var
+from log10.load import get_log10_session_tags, last_completion_response_var, session_id_var
 
 
 logger: logging.Logger = logging.getLogger("LOG10")
@@ -140,6 +140,7 @@ async def log_request(request: Request):
     if not completion_id:
         return
 
+    last_completion_response_var.set({"completionID": completion_id})
     orig_module = ""
     orig_qualname = ""
     request_content_decode = request.content.decode("utf-8")
diff --git a/log10/litellm.py b/log10/litellm.py
index f0439257..3eacbb80 100644
--- a/log10/litellm.py
+++ b/log10/litellm.py
@@ -2,6 +2,7 @@
 from typing import List, Optional
 
 from log10.llm import LLM, Kind, Log10Config
+from log10.load import last_completion_response_var
 
 
 try:
@@ -56,6 +57,7 @@ def log_pre_api_call(self, model, messages, kwargs):
         request = kwargs.get("additional_args").get("complete_input_dict").copy()
         request["messages"] = messages.copy()
         completion_id = self.log_start(request, Kind.chat, self.tags)
+        last_completion_response_var.set({"completionID": completion_id})
         litellm_call_id = kwargs.get("litellm_call_id")
         self.runs[litellm_call_id] = {
             "kind": Kind.chat,
@@ -76,7 +78,10 @@ def log_success_event(self, kwargs, response_obj, start_time, end_time):
         litellm_call_id = kwargs.get("litellm_call_id")
         run = self.runs.get(litellm_call_id, None)
         duration = (end_time - start_time).total_seconds()
-        self.log_end(run["completion_id"], response_obj.dict(), duration)
+
+        completion_id = run["completion_id"]
+        last_completion_response_var.set({"completionID": completion_id})
+        self.log_end(completion_id, response_obj.dict(), duration)
 
     def log_failure_event(self, kwargs, response_obj, start_time, end_time):
         update_log_row = {
diff --git a/log10/load.py b/log10/load.py
index 86dcefdd..3d3e93b1 100644
--- a/log10/load.py
+++ b/log10/load.py
@@ -153,6 +153,10 @@ def last_completion_url(self):
         if last_completion_response_var.get() is None:
             return None
         response = last_completion_response_var.get()
+
+        # organizationSlug will not be returned from httpx hook
+        if not response.get("organizationSlug", ""):
response.get("organizationSlug", ""): + return None return f'{url}/app/{response["organizationSlug"]}/completions/{response["completionID"]}' def last_completion_id(self): @@ -186,8 +190,8 @@ async def log_async(completion_url, log_row): res = None try: res = post_request(completion_url) - last_completion_response_var.set(res.json()) completionID = res.json().get("completionID", None) + organizationSlug = res.json().get("organizationSlug", None) if completionID is None: logging.warn("LOG10: failed to get completionID from log10. Skipping log.") @@ -212,7 +216,7 @@ async def log_async(completion_url, log_row): logging.warn(f"LOG10: failed to log: {e}. Skipping") return None - return completionID + return {"completionID": completionID, "organizationSlug": organizationSlug} def run_async_in_thread(completion_url, log_row, result_queue): @@ -671,7 +675,9 @@ def wrapper(*args, **kwargs): with timed_block("extra time spent waiting for log10 call"): while result_queue.empty(): pass - completionID = result_queue.get() + result = result_queue.get() + completionID = result["completionID"] + last_completion_response_var.set(result) if completionID is None: logger.warning(f"LOG10: failed to get completionID from log10: {e}. Skipping log.") @@ -698,7 +704,9 @@ def wrapper(*args, **kwargs): with timed_block("extra time spent waiting for log10 call"): while result_queue.empty(): pass - completionID = result_queue.get() + result = result_queue.get() + completionID = result["completionID"] + last_completion_response_var.set(result) with timed_block("result call duration (sync)"): response = output diff --git a/tests/conftest.py b/tests/conftest.py index c2241638..46a787a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ import pytest +from log10.load import log10_session + def pytest_addoption(parser): parser.addoption("--openai_model", action="store", help="Model name for OpenAI tests") @@ -59,3 +61,10 @@ def google_model(request): @pytest.fixture def magentic_model(request): return request.config.getoption("--magentic_model") + + +@pytest.fixture +def session(): + with log10_session() as session: + assert session.last_completion_id() is None, "No completion ID should be found." + yield session diff --git a/tests/test_anthropic.py b/tests/test_anthropic.py index 6727b95d..5970628a 100644 --- a/tests/test_anthropic.py +++ b/tests/test_anthropic.py @@ -1,4 +1,5 @@ import base64 +import json import anthropic import httpx @@ -8,13 +9,14 @@ from typing_extensions import override from log10.load import log10 +from tests.utils import _LogAssertion log10(anthropic) @pytest.mark.chat -def test_messages_create(anthropic_model): +def test_messages_create(session, anthropic_model): client = anthropic.Anthropic() message = client.messages.create( @@ -27,13 +29,13 @@ def test_messages_create(anthropic_model): text = message.content[0].text assert isinstance(text, str) - assert text, "No output from the model." + _LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response() @pytest.mark.chat @pytest.mark.async_client @pytest.mark.asyncio -async def test_messages_create_async(anthropic_model): +async def test_messages_create_async(session, anthropic_model): client = anthropic.AsyncAnthropic() message = await client.messages.create( @@ -46,12 +48,12 @@ async def test_messages_create_async(anthropic_model): text = message.content[0].text assert isinstance(text, str) - assert text, "No output from the model." 
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response()
 
 
 @pytest.mark.chat
 @pytest.mark.stream
-def test_messages_create_stream(anthropic_model):
+def test_messages_create_stream(session, anthropic_model):
     client = anthropic.Anthropic()
 
     stream = client.messages.create(
@@ -75,11 +77,11 @@ def test_messages_create_stream(anthropic_model):
         if text.isdigit():
             assert int(text) <= 10
 
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.vision
-def test_messages_image(anthropic_model):
+def test_messages_image(session, anthropic_model):
     client = anthropic.Anthropic()
 
     image1_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
@@ -108,12 +110,11 @@ def test_messages_image(anthropic_model):
     )
 
     text = message.content[0].text
-    assert text, "No output from the model."
-    assert "ant" in text
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response()
 
 
 @pytest.mark.chat
-def test_chat_not_given(anthropic_model):
+def test_chat_not_given(session, anthropic_model):
     client = anthropic.Anthropic()
 
     message = client.beta.tools.messages.create(
@@ -132,10 +133,11 @@ def test_chat_not_given(anthropic_model):
     content = message.content[0].text
     assert isinstance(content, str)
     assert content, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
 
 
 @pytest.mark.chat
-def test_beta_tools_messages_create(anthropic_model):
+def test_beta_tools_messages_create(session, anthropic_model):
     client = anthropic.Anthropic()
 
     message = client.beta.tools.messages.create(
@@ -145,13 +147,13 @@ def test_beta_tools_messages_create(anthropic_model):
     )
 
     text = message.content[0].text
-    assert text, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response()
 
 
 @pytest.mark.chat
 @pytest.mark.async_client
 @pytest.mark.asyncio
-async def test_beta_tools_messages_create_async(anthropic_model):
+async def test_beta_tools_messages_create_async(session, anthropic_model):
     client = anthropic.AsyncAnthropic()
 
     message = await client.beta.tools.messages.create(
@@ -161,13 +163,13 @@ async def test_beta_tools_messages_create_async(anthropic_model):
     )
 
     text = message.content[0].text
-    assert text, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response()
 
 
 @pytest.mark.chat
 @pytest.mark.stream
 @pytest.mark.context_manager
-def test_messages_stream_context_manager(anthropic_model):
+def test_messages_stream_context_manager(session, anthropic_model):
     client = anthropic.Anthropic()
 
     output = ""
@@ -187,7 +189,7 @@ def test_messages_stream_context_manager(anthropic_model):
                 if hasattr(message.delta, "text"):
                     output += message.delta.text
 
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.chat
@@ -195,7 +197,7 @@ def test_messages_stream_context_manager(anthropic_model):
 @pytest.mark.context_manager
 @pytest.mark.async_client
 @pytest.mark.asyncio
-async def test_messages_stream_context_manager_async(anthropic_model):
+async def test_messages_stream_context_manager_async(session, anthropic_model):
     client = anthropic.AsyncAnthropic()
 
     output = ""
@@ -212,13 +214,13 @@ async def test_messages_stream_context_manager_async(anthropic_model):
         async for text in stream.text_stream:
             output += text
 
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.tools
 @pytest.mark.stream
 @pytest.mark.context_manager
-def test_tools_messages_stream_context_manager(anthropic_model):
+def test_tools_messages_stream_context_manager(session, anthropic_model):
     client = anthropic.Anthropic()
     output = ""
     with client.beta.tools.messages.stream(
@@ -252,7 +254,7 @@ def test_tools_messages_stream_context_manager(anthropic_model):
             if hasattr(message.delta, "partial_json"):
                 output += message.delta.partial_json
 
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.tools
@@ -260,15 +262,17 @@ def test_tools_messages_stream_context_manager(anthropic_model):
 @pytest.mark.context_manager
 @pytest.mark.async_client
 @pytest.mark.asyncio
-async def test_tools_messages_stream_context_manager_async(anthropic_model):
+async def test_tools_messages_stream_context_manager_async(session, anthropic_model):
     client = anthropic.AsyncAnthropic()
-    output = None
+    json_snapshot = None
+    final_message = None
+    output = ""
 
     class MyHandler(AsyncToolsBetaMessageStream):
         @override
         async def on_input_json(self, delta: str, snapshot: object) -> None:
-            nonlocal output
-            output = snapshot
+            nonlocal json_snapshot
+            json_snapshot = snapshot
 
     async with client.beta.tools.messages.stream(
         model=anthropic_model,
@@ -294,6 +298,16 @@ async def on_input_json(self, delta: str, snapshot: object) -> None:
         max_tokens=1024,
         event_handler=MyHandler,
     ) as stream:
-        await stream.until_done()
+        final_message = await stream.get_final_message()
+
+    content = final_message.content[0]
+    if hasattr(content, "text"):
+        output = content.text
+
+    if json_snapshot:
+        output += json.dumps(json_snapshot)
 
     assert output, "No output from the model."
+    assert session.last_completion_id(), "No completion ID found."
+    ## TODO fix this test after the anthropic fixes for the tool_calls
+    # _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
diff --git a/tests/test_google.py b/tests/test_google.py
index eb0d09a5..83cf2f0e 100644
--- a/tests/test_google.py
+++ b/tests/test_google.py
@@ -2,13 +2,14 @@
 import pytest
 
 from log10.load import log10
+from tests.utils import _LogAssertion
 
 
 log10(genai)
 
 
 @pytest.mark.chat
-def test_genai_chat(google_model):
+def test_genai_chat(session, google_model):
     model = genai.GenerativeModel(google_model)
     chat = model.start_chat()
@@ -21,11 +22,11 @@ def test_genai_chat(google_model):
 
     text = response.text
     assert isinstance(text, str)
-    assert text, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response()
 
 
 @pytest.mark.chat
-def test_genai_chat_w_history(google_model):
+def test_genai_chat_w_history(session, google_model):
     model = genai.GenerativeModel(google_model, system_instruction="You are a cat. Your name is Neko.")
     chat = model.start_chat(
         history=[
@@ -39,4 +40,4 @@ def test_genai_chat_w_history(google_model):
 
     text = response.text
     assert isinstance(text, str)
-    assert text, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response()
diff --git a/tests/test_lamini.py b/tests/test_lamini.py
index f67ac63c..18c2c187 100644
--- a/tests/test_lamini.py
+++ b/tests/test_lamini.py
@@ -2,15 +2,17 @@
 import pytest
 
 from log10.load import log10
+from tests.utils import _LogAssertion
 
 
 log10(lamini)
 
 
 @pytest.mark.chat
-def test_generate(lamini_model):
+def test_generate(session, lamini_model):
     llm = lamini.Lamini(lamini_model)
     response = llm.generate("What's 2 + 9 * 3?")
 
     assert isinstance(response, str)
     assert "29" in response
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=response).assert_chat_response()
diff --git a/tests/test_langchain.py b/tests/test_langchain.py
index 663694cb..778323de 100644
--- a/tests/test_langchain.py
+++ b/tests/test_langchain.py
@@ -5,10 +5,11 @@
 from langchain.schema import HumanMessage, SystemMessage
 
 from log10.load import log10
+from tests.utils import _LogAssertion
 
 
 @pytest.mark.chat
-def test_chat_openai_messages(openai_model):
+def test_chat_openai_messages(session, openai_model):
     log10(openai)
     llm = ChatOpenAI(
         model_name=openai_model,
@@ -20,10 +21,11 @@ def test_chat_openai_messages(openai_model):
     content = completion.content
     assert isinstance(content, str)
     assert content, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
 
 
 @pytest.mark.chat
-def test_chat_anthropic_messages(anthropic_legacy_model):
+def test_chat_anthropic_messages(session, anthropic_legacy_model):
     log10(anthropic)
     llm = ChatAnthropic(model=anthropic_legacy_model, temperature=0.7)
     messages = [SystemMessage(content="You are a ping pong machine"), HumanMessage(content="Ping?")]
@@ -31,4 +33,4 @@ def test_chat_anthropic_messages(anthropic_legacy_model):
 
     content = completion.content
     assert isinstance(content, str)
-    assert content, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), text=content).assert_text_response()
diff --git a/tests/test_litellm.py b/tests/test_litellm.py
index e0a37ab2..fb36a192 100644
--- a/tests/test_litellm.py
+++ b/tests/test_litellm.py
@@ -1,10 +1,12 @@
 import base64
+import time
 
 import httpx
 import litellm
 import pytest
 
 from log10.litellm import Log10LitellmLogger
+from tests.utils import _LogAssertion
 
 
 ### litellm seems allowing to use multiple callbacks
@@ -17,9 +19,9 @@
 
 @pytest.mark.chat
 @pytest.mark.stream
-def test_completion_stream(openai_model):
+def test_completion_stream(session, openai_model):
     response = litellm.completion(
-        model=openai_model, messages=[{"role": "user", "content": "Count to 10."}], stream=True
+        model=openai_model, messages=[{"role": "user", "content": "Count to 6."}], stream=True
     )
 
     output = ""
     for chunk in response:
         if chunk.choices[0].delta.content:
             output += chunk.choices[0].delta.content
 
-    assert output, "No output from the model."
+    time.sleep(1)
+
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.async_client
@@ -36,19 +40,21 @@ def test_completion_stream(openai_model):
 @pytest.mark.asyncio
 async def test_completion_async_stream(anthropic_model):
     response = await litellm.acompletion(
-        model=anthropic_model, messages=[{"role": "user", "content": "count to 10"}], stream=True
+        model=anthropic_model, messages=[{"role": "user", "content": "count to 8"}], stream=True
     )
 
     output = ""
     async for chunk in response:
         if chunk.choices[0].delta.content:
-            output += chunk.choices[0].delta.content.strip()
+            output += chunk.choices[0].delta.content
 
+    ## This test doesn't get completion_id from the session
+    ## and logged a couple times during debug mode, punt this for now
     assert output, "No output from the model."
 
 
 @pytest.mark.vision
-def test_image(openai_vision_model):
+def test_image(session, openai_vision_model):
     image_url = "https://upload.wikimedia.org/wikipedia/commons/e/e8/Log10.png"
     image_media_type = "image/png"
     image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
@@ -71,12 +77,13 @@ def test_image(openai_vision_model):
 
     content = resp.choices[0].message.content
     assert isinstance(content, str)
-    assert content, "No output from the model."
+
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
 
 
 @pytest.mark.stream
 @pytest.mark.vision
-def test_image_stream(anthropic_model):
+def test_image_stream(session, anthropic_model):
     image_url = "https://upload.wikimedia.org/wikipedia/commons/e/e8/Log10.png"
     image_media_type = "image/png"
     image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
@@ -103,14 +110,15 @@ def test_image_stream(anthropic_model):
         if chunk.choices[0].delta.content:
             output += chunk.choices[0].delta.content
 
-    assert output, "No output from the model."
+    time.sleep(3)
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.async_client
 @pytest.mark.stream
 @pytest.mark.vision
 @pytest.mark.asyncio
-async def test_image_async_stream(anthropic_model):
+async def test_image_async_stream(session, anthropic_model):
     image_url = "https://upload.wikimedia.org/wikipedia/commons/e/e8/Log10.png"
     image_media_type = "image/png"
     image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
@@ -140,4 +148,5 @@ async def test_image_async_stream(anthropic_model):
         if chunk.choices[0].delta.content:
             output += chunk.choices[0].delta.content
 
-    assert output, "No output from the model."
+    time.sleep(3)
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
diff --git a/tests/test_magentic.py b/tests/test_magentic.py
index 792d0b48..f54ce747 100644
--- a/tests/test_magentic.py
+++ b/tests/test_magentic.py
@@ -6,24 +6,25 @@
 from pydantic import BaseModel
 
 from log10.load import log10, log10_session
+from tests.utils import _LogAssertion, format_magentic_function_args
 
 
 log10(openai)
 
 
 @pytest.mark.chat
-def test_prompt(magentic_model):
+def test_prompt(session, magentic_model):
     @prompt("Tell me a short joke", model=OpenaiChatModel(model=magentic_model))
     def llm() -> str: ...
 
     output = llm()
     assert isinstance(output, str)
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.chat
 @pytest.mark.stream
-def test_prompt_stream(magentic_model):
+def test_prompt_stream(session, magentic_model):
     @prompt("Tell me a short joke", model=OpenaiChatModel(model=magentic_model))
     def llm() -> StreamedStr: ...
 
@@ -32,11 +33,11 @@ def llm() -> StreamedStr: ...
     for chunk in response:
         output += chunk
 
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.tools
-def test_function_logging(magentic_model):
+def test_function_logging(session, magentic_model):
     def activate_oven(temperature: int, mode: Literal["broil", "bake", "roast"]) -> str:
         """Turn the oven on with the provided settings."""
         return f"Preheating to {temperature} F with mode {mode}"
@@ -48,30 +49,32 @@ def configure_oven(food: str) -> FunctionCall[str]:  # ruff: ignore
         ...
 
     output = configure_oven("cookies!")
-    assert output(), "No output from the model."
+    function_args = format_magentic_function_args([output])
+    _LogAssertion(
+        completion_id=session.last_completion_id(), function_args=function_args
+    ).assert_function_call_response()
 
 
 @pytest.mark.async_client
 @pytest.mark.stream
 @pytest.mark.asyncio
-async def test_async_stream_logging(magentic_model):
+async def test_async_stream_logging(session, magentic_model):
     @prompt("Tell me a 50-word story about {topic}", model=OpenaiChatModel(model=magentic_model))
     async def tell_story(topic: str) -> AsyncStreamedStr:  # ruff: ignore
         ...
 
-    with log10_session(tags=["async_tag"]):
-        output = await tell_story("Europe.")
-        result = ""
-        async for chunk in output:
-            result += chunk
+    output = await tell_story("Europe.")
+    result = ""
+    async for chunk in output:
+        result += chunk
 
-    assert result, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=result).assert_chat_response()
 
 
 @pytest.mark.async_client
 @pytest.mark.tools
 @pytest.mark.asyncio
-async def test_async_parallel_stream_logging(magentic_model):
+async def test_async_parallel_stream_logging(session, magentic_model):
     def plus(a: int, b: int) -> int:
         return a + b
 
@@ -85,9 +88,15 @@ async def minus(a: int, b: int) -> int:
     )
     async def plus_and_minus(a: int, b: int) -> AsyncParallelFunctionCall[int]: ...
 
+    result = []
     output = await plus_and_minus(2, 3)
     async for chunk in output:
-        assert isinstance(chunk, FunctionCall), "chunk is not an instance of FunctionCall"
+        result.append(chunk)
+
+    function_args = format_magentic_function_args(result)
+    _LogAssertion(
+        completion_id=session.last_completion_id(), function_args=function_args
+    ).assert_function_call_response()
 
 
 @pytest.mark.async_client
@@ -98,32 +107,39 @@ async def test_async_multi_session_tags(magentic_model):
     async def do_math_with_llm_async(a: int, b: int) -> AsyncStreamedStr:  # ruff: ignore
         ...
-    output = ""
+    final_output = ""
 
-    with log10_session(tags=["test_tag_a"]):
+    with log10_session(tags=["test_tag_a"]) as session:
         result = await do_math_with_llm_async(2, 2)
+        output = ""
         async for chunk in result:
             output += chunk
 
-        result = await do_math_with_llm_async(2.5, 2.5)
-        async for chunk in result:
-            output += chunk
+        _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
+
+    final_output += output
+    with log10_session(tags=["test_tag_b"]) as session:
+        output = ""
+        result = await do_math_with_llm_async(2.5, 2.5)
+        async for chunk in result:
+            output += chunk
+
+        _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
-    with log10_session(tags=["test_tag_b"]):
+    final_output += output
+    with log10_session(tags=["test_tag_c"]) as session:
+        output = ""
         result = await do_math_with_llm_async(3, 3)
         async for chunk in result:
             output += chunk
 
-    assert output, "No output from the model."
-    assert "4" in output
-    assert "6.25" in output
-    assert "9" in output
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.async_client
 @pytest.mark.widget
 @pytest.mark.asyncio
-async def test_async_widget(magentic_model):
+async def test_async_widget(session, magentic_model):
     class WidgetInfo(BaseModel):
         title: str
         description: str
@@ -146,3 +162,10 @@ async def _generate_title_and_description(query: str, widget_data: str) -> WidgetInfo:
     assert isinstance(r.description, str)
     assert r.title, "No title generated."
     assert r.description, "No description generated."
+
+    arguments = {"title": r.title, "description": r.description}
+
+    function_args = [{"name": "return_widgetinfo", "arguments": str(arguments)}]
+    _LogAssertion(
+        completion_id=session.last_completion_id(), function_args=function_args
+    ).assert_function_call_response()
diff --git a/tests/test_mistralai.py b/tests/test_mistralai.py
index 001c79ba..72359dba 100644
--- a/tests/test_mistralai.py
+++ b/tests/test_mistralai.py
@@ -4,6 +4,7 @@
 from mistralai.models.chat_completion import ChatMessage
 
 from log10.load import log10
+from tests.utils import _LogAssertion
 
 
 log10(mistralai)
 
 
 @pytest.mark.chat
-def test_chat(mistralai_model):
+def test_chat(session, mistralai_model):
     chat_response = client.chat(
         model=mistralai_model,
         messages=[ChatMessage(role="user", content="10 + 2 * 3=?")],
     )
 
     content = chat_response.choices[0].message.content
-    assert content, "No output from the model."
-    assert "16" in content
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
 
 
 @pytest.mark.chat
 @pytest.mark.stream
-def test_chat_stream(mistralai_model):
+def test_chat_stream(session, mistralai_model):
     response = client.chat_stream(
         model=mistralai_model,
         messages=[ChatMessage(role="user", content="Count the odd numbers from 1 to 20.")],
     )
 
     output = ""
     for chunk in response:
         if chunk.choices[0].delta.content is not None:
             output += content
 
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
diff --git a/tests/test_openai.py b/tests/test_openai.py
index a3d91b8b..814dbe76 100644
--- a/tests/test_openai.py
+++ b/tests/test_openai.py
@@ -7,6 +7,7 @@
 from openai import NOT_GIVEN, AsyncOpenAI
 
 from log10.load import log10
+from tests.utils import _LogAssertion, format_function_args
 
 
 log10(openai)
 
 
 @pytest.mark.chat
-def test_chat(openai_model):
+def test_chat(session, openai_model):
     completion = client.chat.completions.create(
         model=openai_model,
         messages=[
@@ -31,11 +32,12 @@ def test_chat(openai_model):
 
     content = completion.choices[0].message.content
     assert isinstance(content, str)
-    assert content, "No output from the model."
+    assert session.last_completion_url() is not None, "No completion URL found."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
 
 
 @pytest.mark.chat
-def test_chat_not_given(openai_model):
+def test_chat_not_given(session, openai_model):
     completion = client.chat.completions.create(
         model=openai_model,
         messages=[
@@ -50,15 +52,15 @@ def test_chat_not_given(openai_model):
 
     content = completion.choices[0].message.content
     assert isinstance(content, str)
-    assert content, "No output from the model."
+    assert session.last_completion_url() is not None, "No completion URL found."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
 
 
 @pytest.mark.chat
 @pytest.mark.async_client
 @pytest.mark.asyncio
-async def test_chat_async(openai_model):
+async def test_chat_async(session, openai_model):
     client = AsyncOpenAI()
-
     completion = await client.chat.completions.create(
         model=openai_model,
         messages=[{"role": "user", "content": "Say this is a test"}],
@@ -66,50 +68,46 @@ async def test_chat_async(openai_model):
 
     content = completion.choices[0].message.content
     assert isinstance(content, str)
-    assert content, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
 
 
 @pytest.mark.chat
 @pytest.mark.stream
-def test_chat_stream(openai_model):
+def test_chat_stream(session, openai_model):
     response = client.chat.completions.create(
         model=openai_model,
-        messages=[{"role": "user", "content": "Count to 10"}],
+        messages=[{"role": "user", "content": "Count to 5"}],
         temperature=0,
         stream=True,
     )
 
     output = ""
     for chunk in response:
-        content = chunk.choices[0].delta.content
-        if content:
-            output += content.strip()
+        output += chunk.choices[0].delta.content
 
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.async_client
 @pytest.mark.stream
 @pytest.mark.asyncio
-async def test_chat_async_stream(openai_model):
+async def test_chat_async_stream(session, openai_model):
     client = AsyncOpenAI()
 
     output = ""
     stream = await client.chat.completions.create(
         model=openai_model,
-        messages=[{"role": "user", "content": "Count to 10"}],
+        messages=[{"role": "user", "content": "Count to 8"}],
         stream=True,
     )
     async for chunk in stream:
-        content = chunk.choices[0].delta.content
-        if content:
-            output += content.strip()
+        output += chunk.choices[0].delta.content or ""
 
-    assert output, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
 
 @pytest.mark.vision
-def test_chat_image(openai_vision_model):
+def test_chat_image(session, openai_vision_model):
     image1_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
     image1_media_type = "image/jpeg"
     image1_data = base64.b64encode(httpx.get(image1_url).content).decode("utf-8")
@@ -136,7 +134,7 @@ def test_chat_image(openai_vision_model):
 
     content = response.choices[0].message.content
     assert isinstance(content, str)
-    assert content, "No output from the model."
+    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
 
 
 def get_current_weather(location, unit="fahrenheit"):
@@ -186,7 +184,7 @@ def setup_tools_messages() -> dict:
 
 @pytest.mark.tools
-def test_tools(openai_model):
+def test_tools(session, openai_model):
     # Step 1: send the conversation and available functions to the model
     result = setup_tools_messages()
     messages = result["messages"]
@@ -198,8 +196,13 @@ def test_tools(openai_model):
         tools=tools,
         tool_choice="auto",  # auto is default, but we'll be explicit
     )
+
+    first_completion_id = session.last_completion_id()
     response_message = response.choices[0].message
     tool_calls = response_message.tool_calls
+
+    function_args = format_function_args(tool_calls)
+    _LogAssertion(completion_id=first_completion_id, function_args=function_args).assert_function_call_response()
     # Step 2: check if the model wanted to call a function
     if tool_calls:
         # Step 3: call the function
@@ -232,12 +235,15 @@ def test_tools(openai_model):
         )  # get a new response from the model where it can see the function response
         content = second_response.choices[0].message.content
         assert isinstance(content, str)
-        assert content, "No output from the model."
+
+        tool_call_completion_id = session.last_completion_id()
+        assert tool_call_completion_id != first_completion_id, "Completion IDs should be different."
+        _LogAssertion(completion_id=tool_call_completion_id, message_content=content).assert_chat_response()
 
 
 @pytest.mark.stream
 @pytest.mark.tools
-def test_tools_stream(openai_model):
+def test_tools_stream(session, openai_model):
     # Step 1: send the conversation and available functions to the model
     result = setup_tools_messages()
     messages = result["messages"]
@@ -257,15 +263,20 @@ def test_tools_stream(openai_model):
             tool_calls.append(tc[0])
         else:
             tool_calls[-1].function.arguments += tc[0].function.arguments
-    function_args = [{"function": t.function.name, "arguments": t.function.arguments} for t in tool_calls]
+
+    function_args = format_function_args(tool_calls)
     assert len(function_args) == 3
+    _LogAssertion(
+        completion_id=session.last_completion_id(), function_args=function_args
+    ).assert_function_call_response()
+
 
 
 @pytest.mark.tools
 @pytest.mark.stream
 @pytest.mark.async_client
 @pytest.mark.asyncio
-async def test_tools_stream_async(openai_model):
+async def test_tools_stream_async(session, openai_model):
     client = AsyncOpenAI()
     # Step 1: send the conversation and available functions to the model
     result = setup_tools_messages()
@@ -289,5 +300,9 @@ async def test_tools_stream_async(openai_model):
         else:
             tool_calls[-1].function.arguments += tc[0].function.arguments
 
-    function_args = [{"function": t.function.name, "arguments": t.function.arguments} for t in tool_calls]
+    function_args = format_function_args(tool_calls)
     assert len(function_args) == 3
+
+    _LogAssertion(
+        completion_id=session.last_completion_id(), function_args=function_args
+    ).assert_function_call_response()
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..d443331d
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,86 @@
+import uuid
+
+from log10.completions.completions import _get_completion
+
+
+def is_valid_uuid(val):
+    try:
+        uuid.UUID(str(val))
+        return True
+    except ValueError:
+        return False
+
+
+def format_magentic_function_args(outputs):
+    return [{"name": t._function.__name__, "arguments": str(t.arguments)} for t in outputs]
+
+
+def format_function_args(tool_calls):
+    return [{"name": t.function.name, "arguments": t.function.arguments} for t in tool_calls]
+
+
+class _LogAssertion:
+    def __init__(self, *args, **kwargs):
+        self._completion_id = kwargs.get("completion_id", "")
+        self._message_content = kwargs.get("message_content", "")
+        self._text = kwargs.get("text", "")
+        self._function_args = kwargs.get("function_args", [])
+
+        assert self._completion_id, "No completion id provided."
+        assert is_valid_uuid(self._completion_id), "Completion ID should be found and valid uuid."
+
+    def get_completion(self):
+        res = _get_completion(self._completion_id)
+        self.data = res.json()["data"]
+        assert self.data.get("response", {}), f"No response logged for completion {self._completion_id}."
+        self.response = self.data["response"]
+
+    def assert_expected_response_fields(self):
+        assert self.data.get("status", ""), f"No status logged for completion {self._completion_id}."
+        assert self.response.get("choices", []), f"No choices logged for completion {self._completion_id}."
+        self.response_choices = self.response["choices"]
+
+    def assert_text_response(self):
+        assert self._text, "No output generated from the model."
+        self.get_completion()
+        self.assert_expected_response_fields()
+
+        choice = self.response_choices[0]
+        assert choice.get("text", {}), f"No text logged for completion {self._completion_id}."
+        text = choice["text"]
+        assert (
+            self._text == text
+        ), f"Text does not match the generated completion for completion {self._completion_id}."
+
+    def assert_chat_response(self):
+        assert self._message_content, "No output generated from the model."
+        self.get_completion()
+        self.assert_expected_response_fields()
+
+        choice = self.response_choices[0]
+        assert choice.get("message", {}), f"No message logged for completion {self._completion_id}."
+        message = choice["message"]
+        assert message.get("content", ""), f"No message content logged for completion {self._completion_id}."
+        message_content = message["content"]
+        assert (
+            message_content == self._message_content
+        ), f"Message content does not match the generated completion for completion {self._completion_id}."
+
+    def assert_function_call_response(self):
+        assert self._function_args, "No function args generated from the model."
+
+        self.get_completion()
+        self.assert_expected_response_fields()
+        choice = self.response_choices[0]
+        assert choice.get("message", {}), f"No message logged for completion {self._completion_id}."
+        message = choice["message"]
+        assert message.get("tool_calls", []), f"No function calls logged for completion {self._completion_id}."
+        response_tool_calls = message["tool_calls"]
+        response_function_args = [
+            {"name": t.get("function", "").get("name", ""), "arguments": t.get("function", "").get("arguments", "")}
+            for t in response_tool_calls
+        ]
+
+        assert len(response_function_args) == len(
+            self._function_args
+        ), f"Function calls do not match the generated completion for completion {self._completion_id}."
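
For reviewers who want to try the change outside pytest, here is a minimal sketch (not part of the patch) of how the pieces fit together; the model name and tag are placeholders, and it assumes OpenAI and Log10 credentials are configured the same way the test suite expects.

    import openai

    from log10.load import log10, log10_session
    from tests.utils import _LogAssertion

    log10(openai)  # instrument the OpenAI module so completions are logged to the platform
    client = openai.OpenAI()

    with log10_session(tags=["verify-demo"]) as session:  # hypothetical tag
        completion = client.chat.completions.create(
            model="gpt-4o-mini",  # placeholder model name
            messages=[{"role": "user", "content": "Say hello"}],
        )
        content = completion.choices[0].message.content

        # The session now exposes the ID (and URL) of the completion that was just logged.
        print(session.last_completion_id())
        print(session.last_completion_url())

        # Same check the updated tests perform: the logged response must match the output.
        _LogAssertion(
            completion_id=session.last_completion_id(), message_content=content
        ).assert_chat_response()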