From 2561ef416141daec3d63c5dd401c860d04eec9f6 Mon Sep 17 00:00:00 2001 From: Kim Tran Date: Tue, 2 Jul 2024 14:24:25 -0400 Subject: [PATCH] Fix missing logging system message in anthropic --- log10/_httpx_utils.py | 4 ++++ tests/test_anthropic.py | 10 +++++++--- tests/utils.py | 18 ++++++++++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/log10/_httpx_utils.py b/log10/_httpx_utils.py index 84a8062..3952472 100644 --- a/log10/_httpx_utils.py +++ b/log10/_httpx_utils.py @@ -153,6 +153,10 @@ async def _try_post_request_async(url: str, payload: dict = {}) -> httpx.Respons def format_anthropic_request(request_content) -> str: + system_message = request_content.get("system", "") + if system_message: + request_content["messages"].insert(0, {"role": "system", "content": system_message}) + for message in request_content.get("messages", []): new_content = [] message_content = message.get("content") diff --git a/tests/test_anthropic.py b/tests/test_anthropic.py index ec30111..015b646 100644 --- a/tests/test_anthropic.py +++ b/tests/test_anthropic.py @@ -31,18 +31,22 @@ def test_completions_create(session, anthropic_legacy_model): @pytest.mark.chat def test_messages_create(session, anthropic_model): client = Anthropic() - + system_message = "Respond only in Yoda-speak." message = client.messages.create( model=anthropic_model, max_tokens=1000, temperature=0.0, - system="Respond only in Yoda-speak.", + system=system_message, messages=[{"role": "user", "content": "How are you today?"}], ) text = message.content[0].text assert isinstance(text, str) - _LogAssertion(completion_id=session.last_completion_id(), message_content=text).assert_chat_response() + log_assertion = _LogAssertion( + completion_id=session.last_completion_id(), message_content=text, system_message=system_message + ) + log_assertion.assert_chat_response() + log_assertion.assert_system_message_request() @pytest.mark.chat diff --git a/tests/utils.py b/tests/utils.py index 534d855..b259015 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -25,6 +25,7 @@ def __init__(self, *args, **kwargs): self._message_content = kwargs.get("message_content", "") self._text = kwargs.get("text", "") self._function_args = kwargs.get("function_args", []) + self._system_message = kwargs.get("system_message", "") assert self._completion_id, "No completion id provided." assert is_valid_uuid(self._completion_id), "Completion ID should be found and valid uuid." @@ -52,6 +53,23 @@ def assert_text_response(self): self._text == text ), f"Text does not match the generated completion for completion {self._completion_id}." + def assert_system_message_request(self): + if not self._system_message: + return + + self.get_completion() + assert self.data.get("request", {}), f"No request logged for completion {self._completion_id}." + request = self.data["request"] + assert request.get("messages", ""), f"No request message logged for completion {self._completion_id}." + system_message = request["messages"][0] + assert system_message.get( + "content", "" + ), f"No system message content logged for completion {self._completion_id}." + content = system_message["content"] + assert ( + self._system_message == content + ), f"System message content does not match the generated completion for completion {self._completion_id}." + def assert_chat_response(self): assert self._message_content, "No output generated from the model." self.get_completion()