Skip to content

Commit

Permalink
feat: Agent prompt completion
Browse files Browse the repository at this point in the history
  • Loading branch information
RWallan committed May 20, 2024
1 parent 257af03 commit 9f5f89f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 2 deletions.
43 changes: 43 additions & 0 deletions openiziai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@

from openai import OpenAI
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic.dataclasses import dataclass

from openiziai.schemas import GPTModel
from openiziai.task import Task


@dataclass
class PromptResponse:
id: str
prompt: str
response: str | None
temperature: float
tokens: int | None
fine_tuned_model: str


class Agent(BaseModel):
client: OpenAI
model: Optional[GPTModel] = None
Expand Down Expand Up @@ -51,3 +62,35 @@ def _build_template(self) -> str:
"""

return template

def prompt(
self, prompt: str, temperature: float = 0.5, max_tokens: int = 1000
) -> PromptResponse:
messages = [
{
'role': 'system',
'content': self._template,
},
{
'role': 'user',
'content': prompt,
},
]

result = self.client.chat.completions.create(
messages=messages, # pyright: ignore
model=self._fine_tuned_model,
temperature=temperature,
max_tokens=max_tokens
)

response = PromptResponse(
id=result.id,
prompt=prompt,
response=result.choices[0].message.content,
temperature=temperature,
tokens = result.usage.total_tokens if result.usage else None,
fine_tuned_model=self._fine_tuned_model,
)

return response
24 changes: 23 additions & 1 deletion tests/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_initiate_with_model(client, valid_task):

agent = Agent(client=client, model=gpt_model)

assert agent.fine_tuned_model == 'model'
assert agent._fine_tuned_model == 'model'
assert valid_task.short_backstory in agent._template
assert valid_task.goal in agent._template
assert valid_task.role in agent._template
Expand All @@ -35,3 +35,25 @@ def test_agent_initialization_without_model_with_task(client, valid_task):
def test_agent_initialization_without_model_and_task_raises_error(client):
with pytest.raises(ValidationError):
Agent(client=client)


def test_agent_prompt(openai_chat, valid_task):
gpt_model = GPTModel(
name='model', base_model='gpt-3.5-turbo', task=valid_task
)
agent = Agent(
client=openai_chat,
model=gpt_model,
)

expected_token = 500
expected_temperature = 0.5

result = agent.prompt(prompt='teste')

assert result.id == '123'
assert result.prompt == 'teste'
assert result.temperature == expected_temperature
assert result.tokens == expected_token
assert result.response
assert result.fine_tuned_model == 'model'
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def openai_chat(client):
client.chat.completions = MagicMock()
client.chat.completions.create = MagicMock(
return_value=MagicMock(
id='123',
choices=[
MagicMock(
message=MagicMock(
Expand All @@ -31,7 +32,8 @@ def openai_chat(client):
})
)
)
]
],
usage=MagicMock(total_tokens=500),
)
)
return client
Expand Down

0 comments on commit 9f5f89f

Please sign in to comment.