From 9f5f89f9412783967137a4224dd10dd6cca93b8b Mon Sep 17 00:00:00 2001
From: RWallan <3am.richardwallan@gmail.com>
Date: Mon, 20 May 2024 18:18:01 -0300
Subject: [PATCH] feat: Agent prompt completion

---
 openiziai/agents/agent.py  | 43 ++++++++++++++++++++++++++++++++++++++
 tests/agents/test_agent.py | 24 ++++++++++++++++++++-
 tests/conftest.py          |  4 +++-
 3 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/openiziai/agents/agent.py b/openiziai/agents/agent.py
index 0e231ad..fd7c502 100644
--- a/openiziai/agents/agent.py
+++ b/openiziai/agents/agent.py
@@ -2,11 +2,22 @@
 
 from openai import OpenAI
 from pydantic import BaseModel, ConfigDict, model_validator
+from pydantic.dataclasses import dataclass
 
 from openiziai.schemas import GPTModel
 from openiziai.task import Task
 
 
+@dataclass
+class PromptResponse:
+    id: str
+    prompt: str
+    response: str | None
+    temperature: float
+    tokens: int | None
+    fine_tuned_model: str
+
+
 class Agent(BaseModel):
     client: OpenAI
     model: Optional[GPTModel] = None
@@ -51,3 +62,35 @@ def _build_template(self) -> str:
         """
 
         return template
+
+    def prompt(
+        self, prompt: str, temperature: float = 0.5, max_tokens: int = 1000
+    ) -> PromptResponse:
+        messages = [
+            {
+                'role': 'system',
+                'content': self._template,
+            },
+            {
+                'role': 'user',
+                'content': prompt,
+            },
+        ]
+
+        result = self.client.chat.completions.create(
+            messages=messages,  # pyright: ignore
+            model=self._fine_tuned_model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+        response = PromptResponse(
+            id=result.id,
+            prompt=prompt,
+            response=result.choices[0].message.content,
+            temperature=temperature,
+            tokens=result.usage.total_tokens if result.usage else None,
+            fine_tuned_model=self._fine_tuned_model,
+        )
+
+        return response
diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py
index 882049d..3c30a23 100644
--- a/tests/agents/test_agent.py
+++ b/tests/agents/test_agent.py
@@ -12,7 +12,7 @@ def test_initiate_with_model(client, valid_task):
     )
     agent = Agent(client=client, model=gpt_model)
 
-    assert agent.fine_tuned_model == 'model'
+    assert agent._fine_tuned_model == 'model'
     assert valid_task.short_backstory in agent._template
     assert valid_task.goal in agent._template
     assert valid_task.role in agent._template
@@ -35,3 +35,25 @@ def test_agent_initialization_without_model_with_task(client, valid_task):
 def test_agent_initialization_without_model_and_task_raises_error(client):
     with pytest.raises(ValidationError):
         Agent(client=client)
+
+
+def test_agent_prompt(openai_chat, valid_task):
+    gpt_model = GPTModel(
+        name='model', base_model='gpt-3.5-turbo', task=valid_task
+    )
+    agent = Agent(
+        client=openai_chat,
+        model=gpt_model,
+    )
+
+    expected_token = 500
+    expected_temperature = 0.5
+
+    result = agent.prompt(prompt='teste')
+
+    assert result.id == '123'
+    assert result.prompt == 'teste'
+    assert result.temperature == expected_temperature
+    assert result.tokens == expected_token
+    assert result.response
+    assert result.fine_tuned_model == 'model'
diff --git a/tests/conftest.py b/tests/conftest.py
index ae793c1..dcc7a21 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,7 @@ def openai_chat(client):
     client.chat.completions = MagicMock()
     client.chat.completions.create = MagicMock(
         return_value=MagicMock(
+            id='123',
             choices=[
                 MagicMock(
                     message=MagicMock(
@@ -31,7 +32,8 @@
                         })
                     )
                 )
-            ]
+            ],
+            usage=MagicMock(total_tokens=500),
         )
     )
     return client